In [20]:
from transformers import AutoTokenizer, AutoModel,  AutoModelForSequenceClassification
from Bio import SeqIO
import torch
from torch.utils.data import DataLoader, TensorDataset
import random
import numpy as np
import torch.nn as nn
# https://huggingface.co/blog/AmelieSchreiber/esmbind
# the minimum for the ESM2 is 650M if we want better performance than ESM1b with 650M as well.

In [34]:
dataset = TensorDataset(torch.arange(40, dtype=torch.float32).view(10, 4), torch.tensor([i for i in range(10)]))

In [7]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


In [14]:
torch.arange(40).shape

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39])

In [36]:
a = nn.Conv1d(1, 10, 2)
g = torch.Generator()
g.manual_seed(0)
dl = DataLoader(dataset, batch_size=4, worker_init_fn=seed_worker, generator=g, shuffle=True, num_workers=2)
for i in range(2):
    print("Epoch", i)
    for batch, label in dl:
       print(batch.unsqueeze(1).shape)
       print(batch.unsqueeze(1))
       o = a(batch.unsqueeze(1))
       print(o.shape)
       print(o)

Epoch 0
torch.Size([4, 1, 4])
tensor([[[12., 13., 14., 15.]],

        [[28., 29., 30., 31.]],

        [[20., 21., 22., 23.]],

        [[ 8.,  9., 10., 11.]]])
torch.Size([4, 10, 3])
tensor([[[  7.9515,   8.6100,   9.2686],
         [ -5.7440,  -6.2499,  -6.7559],
         [ -2.4015,  -2.5816,  -2.7617],
         [-12.2291, -13.2001, -14.1711],
         [ 10.4888,  11.3591,  12.2294],
         [  3.9718,   4.3172,   4.6625],
         [  1.2479,   1.3680,   1.4881],
         [ -0.7694,  -0.8526,  -0.9357],
         [-12.2041, -13.1510, -14.0978],
         [ -2.9194,  -3.1050,  -3.2906]],

        [[ 18.4879,  19.1464,  19.8049],
         [-13.8394, -14.3453, -14.8513],
         [ -5.2831,  -5.4632,  -5.6433],
         [-27.7649, -28.7359, -29.7069],
         [ 24.4138,  25.2841,  26.1544],
         [  9.4968,   9.8421,  10.1874],
         [  3.1696,   3.2897,   3.4098],
         [ -2.0995,  -2.1826,  -2.2658],
         [-27.3533, -28.3001, -29.2469],
         [ -5.8897,  -6.0753,  -6.

In [29]:
batch.unsqueeze(1)

tensor([[[22., 23.]],

        [[ 8.,  9.]],

        [[34., 35.]],

        [[12., 13.]]])

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
with open('../data/whole_sequence.fasta', 'r') as f:
    seqs = list(SeqIO.parse(f, 'fasta'))
seq = [str(s.seq) for s in seqs]

In [4]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
tok = tokenizer(seq[:2], padding=True, truncation=True, return_tensors="pt", is_split_into_words=False)

In [24]:
len(seq[1])

341

In [5]:
model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D", add_pooling_layer=False, output_hidden_states=True)
model.to(device)
model.eval()
n = AutoModelForSequenceClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=2)

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [51]:
out = n(**tok)

In [6]:
output = model(**tok)

In [27]:
output.hidden_states[-1].mean(dim=0).shape

torch.Size([343, 320])

In [77]:
tuple(output.hidden_states[-1].shape)

(2, 343, 320)

In [82]:
attention_weights = torch.nn.Linear(320, 1)

In [87]:
attention_scores = attention_weights(output.hidden_states[-1])
attention_weights = torch.softmax(attention_scores, -1)

TypeError: 'Tensor' object is not callable

In [89]:
attention_weights.shape

torch.Size([2, 343, 1])

In [29]:
_temp = output.hidden_states[-1].reshape(output.hidden_states[-1].shape[0], -1)
_temp.shape

torch.Size([2, 109760])

In [53]:
_temp[0]

tensor([ 0.1419,  0.5839, -0.0722,  ...,  0.4682, -0.6849, -0.3094],
       grad_fn=<SelectBackward0>)

In [32]:
(0, 2048 - _temp.shape[1])

(0, -107712)

In [66]:
o = torch.nn.functional.pad(_temp, (0, 2048 - _temp.shape[1]))

In [68]:
o[0][:10]

tensor([ 0.1419,  0.5839, -0.0722,  0.3390, -0.1853, -0.0982, -0.9235,  0.1019,
        -0.4527, -0.6959], grad_fn=<SliceBackward0>)

In [50]:
len(set(o[0].detach().numpy()).intersection(_temp[0].detach().numpy()))

109670

In [74]:
len(set(output.hidden_states[-1][0][0].detach().numpy()).intersection(_temp[0][:100].detach().numpy()))

100