In [1]:
import pandas as pd
import os
# import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import (
    AutoModel,
    AutoTokenizer,
    # DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from datasets import Dataset
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
import pickle
from saveAndLoad import *
import gc

In [2]:
checkpoint = 'facebook/esm2_t30_150M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint).to('cuda')
model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 640, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 640, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-29): 30 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=640, out_features=640, bias=True)
            (key): Linear(in_features=640, out_features=640, bias=True)
            (value): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
  

In [3]:
seqs = pickleLoad('../../aa/canonical_mut.pkl')
refs = pickleLoad('../../aa/canonical_ref.pkl')

loading data from ../../aa/canonical_mut.pkl
loading data from ../../aa/canonical_ref.pkl


In [4]:
mutSeqSet = set(seqs)
len(mutSeqSet)

656995

In [5]:
refSeqSet = set(refs)
len(refSeqSet)

1433

In [6]:
with torch.no_grad():
    idx = 569032
    print(len(seqs[idx]))
    inputs = tokenizer(seqs[idx])
    ids = torch.tensor([inputs['input_ids']]).to('cuda')
    att = torch.tensor([inputs['attention_mask']]).to('cuda')
    output = model(ids, attention_mask=att)
    print(len(ids))
    output

14509
1


In [17]:
ref_idx_map = pickleLoad('../../aa/idxMap_canonical_mut_to_ref.pkl')
idx = 54
with torch.no_grad():
    print(len(seqs[idx]))
    inputs = tokenizer(seqs[idx])
    ids = torch.tensor([inputs['input_ids']]).to('cuda')
    att = torch.tensor([inputs['attention_mask']]).to('cuda')
    output0 = model(ids, attention_mask=att)
    output0 = output0['pooler_output'].detach().to('cpu').numpy()[0]

with torch.no_grad():
    idx = ref_idx_map[idx]
    print(len(refs[idx]))
    inputs = tokenizer(refs[idx])
    ids = torch.tensor([inputs['input_ids']]).to('cuda')
    att = torch.tensor([inputs['attention_mask']]).to('cuda')
    output1 = model(ids, attention_mask=att)
    output1 = output1['pooler_output'].detach().to('cpu').numpy()[0]

print(F.cosine_similarity(torch.tensor(output0).unsqueeze(0),torch.tensor(output1).unsqueeze(0)))

for i,j in zip(output0[:50],output1[:50]):
    print(i,j)

loading data from ../../aa/idxMap_canonical_mut_to_ref.pkl
928
928
tensor([1.0000])
0.065024815 0.0650314
0.26945215 0.26973447
0.11472979 0.115435004
0.22270499 0.22282611
0.123758435 0.12371794
-0.16197532 -0.16178976
0.06918807 0.07076595
0.044536114 0.045179795
-0.09739348 -0.09797263
-0.12477154 -0.123642206
0.015398634 0.015090378
0.11419302 0.11459662
-0.06595691 -0.06460882
-0.011293549 -0.011246576
-0.12015294 -0.12091027
0.038723208 0.039486215
-0.0009982433 -0.0009895485
0.18213648 0.18227133
0.1296329 0.12970518
-0.24497712 -0.24443491
-0.1895033 -0.19017214
0.07730589 0.07757619
-0.017968405 -0.018914307
0.14176086 0.14151566
-0.06627409 -0.065241754
-0.075001255 -0.074546695
-0.10816459 -0.107831396
0.10311163 0.10303631
0.08187747 0.08124441
0.14174916 0.14173748
0.15100494 0.15118705
-0.13025278 -0.13047566
-0.29223242 -0.29288253
-0.1458308 -0.14635356
0.0818969 0.08190235
-0.18994533 -0.19003633
-0.00015120208 0.00085328496
-0.44124052 -0.44161108
-0.041721288 -0.0429

In [36]:
output['last_hidden_state'].shape

torch.Size([1, 14511, 640])

In [37]:
output['pooler_output'].shape

torch.Size([1, 640])

In [29]:
type(output['pooler_output'].detach().to('cpu').numpy()[0][0])

numpy.float32

In [15]:
max=5
n=0
longest = 0
longest_i = 0
lens = set()
for ni,i in enumerate(seqs):
    if i is None: continue
    lens.add(len(i))
    if len(i)>longest:
        longest = len(i)
        longest_i = ni
    if 900<len(i)<1000:
        if n<max:
            print(ni, len(i))
            n+=1
print('\n',longest,longest_i,'\n')
print(sorted(list(lens))[-15:])


9 976
54 928
57 976
59 976
83 976

 14511 568862 

[5809, 5883, 5889, 5890, 8789, 8794, 8795, 8796, 8797, 8798, 14506, 14507, 14508, 14509, 14511]
