In [None]:
# Install the transformers package WITH the ESM model. 
# It is unfortunately not available in the official release yet.
#!git clone -b add_esm-proper --single-branch https://github.com/liujas000/transformers.git 
!pip -q install ./transformers

In [118]:
# Load packages

import math
from sklearn.model_selection import train_test_split
from transformers import pipeline, ESMForTokenClassification, ESMTokenizer, ESMForMaskedLM, ESMForSequenceClassification, AdamW
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.functional import one_hot

# Use MPS or CUDA if available:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [69]:
# What is this notebook about?
generator = pipeline("text-generation", model = "gpt2", pad_token_id = 50256, num_return_sequences=1)
print(generator("This notebook is all about proteins, friends and ")[0]['generated_text'])

This notebook is all about proteins, friends and ersatz "funtime." How much time do we have to discuss all this protein stuff? Is it easy? How to make the best meat, a better dinner, and a better life for yourself


In [46]:
# Data preprocessing

# Get sequences, accession number and main category labels:

sequence = ""
sequences = list()
acc_num = list()
main_cat = list()

first = True
with open("../data/terp.faa") as file:
    
    first_acc = file.readline()
    acc_num.append(first_acc.split(">")[1].strip())
    main_cat.append(first_acc.split("_")[1].strip())

    for line in file:
        if line.startswith(">"):
            sequences.append(sequence)
            sequence = ""
            acc_num.append(line.split(">")[1].strip())
            main_cat.append(line.split("_")[1].strip())
        else:
            sequence += line.strip()
    
    # Add last sequence
    sequences.append(sequence)

# Create numbered labels for main categories:

main2label = {c: l for l, c in enumerate(sorted(set(main_cat)))}
label2main = {l: c for c, l in main2label.items()}

# Create class translation dictionary for accession numbers:

acc2class = dict()

with open("../data/class_vs_acc_v2.txt", "r") as file:
    for line in file:
        t_class = line.split("\t")[0]
        acc = line.split("\t")[1].strip()[1:]
        acc2class[acc] = t_class

# Create numbered labels for classes:
        
class2label = {c: l for l, c in enumerate(sorted(set(acc2class.values())))}
label2class = {l: c for c, l in class2label.items()}

print(
    f"The files contain:",
    f"{len(sequences)} sequences in",
    f"{len(set(main_cat))} main categories and",
    f"{len(set(acc2class.values()))} classes")

The files contain: 534 sequences in 10 main categories and 49 classes


In [None]:
# Possibly check class distribution here...

In [11]:
# Choose between category and class:
labels = main_cat
#labels = acc_num # This will translate to class later.

# Split into training and validation set. Is this necessary?

train_seq, val_seq, train_labels, val_labels = train_test_split(sequences, labels, test_size=.1)

print(f"Training size: {len(train_seq)} Validation size: {len(val_seq)}")

Training size: 480 Validation size: 54


In [131]:
# Tokenizer:
tokenizer = ESMTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D", do_lower_case=False)

Downloading:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'EsmTokenizer'. 
The class this function is called from is 'ESMTokenizer'.


In [126]:
class SequenceDataset(Dataset):
    def __init__(self, input_sequences, input_labels, categories=True):
        
        # Init is run once, when instantiating the dataset class.
        #
        # Either supply with:
        #  Main categories - category classification
        #  Accession numbers - class classifcation 
        
        # The xx2label turns the label from text to a number from 0-(N-1) 
        if categories:
            self.labels = [main2label[cat] for cat in input_labels]
        else:
            self.labels = [class2label[acc2class[acc]] for acc in input_labels]
        
        # Tokenize sequence and pad to longest sequence in dataset.
        # Return pytorch-type tensors
        self.sequences = tokenizer(
                                input_sequences,
                                padding = 'longest',
                                return_tensors = 'pt')
        # Save label type
        self.label_type_cat = categories
        
    def classes(self):
        
        # Returns the classes in the dataset (optional function)
        return self.labels

    def __len__(self):
        
        # Returns the number of samples in dataset (required)
        return len(self.labels)

    def __getitem__(self, idx):
        
        # Returns a sample at position idx (required)
        # The sample includes:
        # - Input id's for the sequence
        # - Attention mask (to only focus on the sequence and not padding)
        # - Label (one-hot encoded)
        
        input_ids = self.sequences['input_ids'][idx]
        attention_mask = self.sequences['attention_mask'][idx]
        label = torch.tensor(self.labels[idx])
        num_labels = len(main2label.values()) if self.label_type_cat else len(class2label.values())
        
        sample = dict()
        sample['input_ids'] = input_ids
        sample['attention_mask'] = attention_mask
        sample['label'] = one_hot(label,
                                  num_classes=num_labels).to(torch.float)
        return sample

In [127]:
train_dataset = SequenceDataset(train_seq, train_labels)
val_dataset = SequenceDataset(val_seq, val_labels)

In [129]:
# Example of returning a single sample:
train_dataset.__getitem__(1)

{'input_ids': tensor([ 0, 20, 14, 18,  5,  8, 11,  6, 23, 17, 14,  6,  4,  9, 15, 11, 10,  9,
          5,  5, 22,  9, 22,  5,  9,  5,  9,  6,  4, 12,  4,  8,  7, 14,  5, 10,
         10, 15, 20, 12, 10, 11, 10, 14,  9,  4, 22, 12,  8,  4, 12, 18, 14, 15,
          5,  8, 16, 21, 21,  4, 13,  4, 18, 23, 16, 22,  4, 18, 22,  5, 18,  4,
          7, 13, 13,  9, 18, 13, 13,  6, 14,  5,  6, 10, 13, 14,  4, 20, 23,  9,
          5,  5, 12, 11, 10,  4,  7, 13,  7, 18, 13,  6,  5,  5, 14, 21,  6, 14,
         20,  9, 16,  5,  4, 11,  6,  4, 10,  9, 10, 11, 23, 10, 13, 10,  8, 14,
         16, 22, 17, 10, 16, 18, 10, 10, 13, 11,  5,  5, 22,  4, 22, 11, 19, 19,
          5,  9,  5,  7,  9, 10,  5,  5,  6, 16,  7, 14,  8, 10,  7, 13, 18,  7,
         15, 21, 10, 10, 13,  8,  7,  5, 20, 16, 14, 18,  4, 13,  4, 21,  9, 12,
         11,  5,  6, 12, 13,  4, 14, 13,  8,  5, 10,  8,  4, 14,  5, 19, 12,  5,
          4, 10, 17,  5,  7, 11, 13, 21,  8,  6,  4, 23, 17, 13, 12, 23,  8, 18,
          9, 15

In [146]:
import sys

In [153]:
# Training using pytorch only

num_labels = len(set(train_dataset.classes()))
epochs = 1
batch_size = 100

model = ESMForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels = num_labels,
    problem_type = "multi_label_classification")

device = torch.device("cpu")
model.to(device)
model.train()

# Train loader returns an iter with the number of samples = batch size.
train_loader = DataLoader(train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

optim = AdamW(model.parameters(), lr=5e-5)

n = 0
for epoch in range(epochs):
    for batch in train_loader:
        n += 1

print(n)

sys.exit(1)

for epoch in range(epochs):
    for batch in train_loader:
        
        optim.zero_grad()
        input_ids = sample['input_ids'].to(device)
        attention_mask = sample['attention_mask'].to(device)
        label = sample['label'].to(device)
        
        outputs = model(input_ids, attention_mask = attention_mask, labels=label)
        
        print(outputs)
        
        loss = outputs[0]
        
        print(loss)
        
        loss.backward()
        
        print(loss)
        
        optim.step()

model.eval()


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing ESMForSequenceClassification: ['esm.encoder.layer.2.attention.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.4.attention.self.rotary_embeddings.inv_freq', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'esm.encoder.layer.5.attention.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.1.attention.self.rotary_embeddings.inv_freq', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'esm.encoder.layer.3.attention.self.rotary_embeddings.inv_freq', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias']
- This IS expected if you are initializing ESMForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ESMForSequenceClass

5


SystemExit: 1

In [143]:
print(outputs)
print(outputs[1].argmax())

SequenceClassifierOutput(loss=tensor(0.5423, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), logits=tensor([[-0.4787, -0.0347, -0.4638, -0.2717, -0.4413, -0.3257, -0.5093, -0.1469,
         -0.5806, -0.4681]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
tensor(1)


In [145]:
sample['label'].argmax()

tensor(7)

In [None]:
tokenizer = ESMTokenizer.from_pretrained("facebook/esm-1b", do_lower_case=False)
model = ESMForMaskedLM.from_pretrained("facebook/esm-1b")
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)
unmasker('QERLKSIVRILE<mask>SLGYNIVAT')