In [1]:
# Install the transformers package WITH the ESM model. 
# It is unfortunately not available in the official release yet.
#!git clone -b add_esm-proper --single-branch https://github.com/liujas000/transformers.git 
#!pip -q install ./transformers

In [18]:
# Load packages
import time
import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from transformers import pipeline, ESMTokenizer, ESMForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup as linear_scheduler
from transformers import get_constant_schedule_with_warmup as constant_scheduler
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.functional import one_hot

# Pick the compute device: CUDA first, then Apple-Silicon MPS, otherwise CPU.
# The MPS lookup is guarded with getattr because older torch builds do not
# ship the torch.backends.mps module at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [2]:
# Timing helper function
def time_step(elapsed):
    '''
    Format a duration given in seconds as a clock-style string.

    The value is rounded to the nearest whole second and rendered through
    ``datetime.timedelta``, e.g. ``3661.4`` -> ``"1:01:01"``.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return f"{duration}"

In [3]:
# What is this notebook about? Ask a GPT-2 text-generation pipeline to
# complete the opening sentence and print the single returned completion.
prompt = "This notebook is all about proteins, friends and "
generator = pipeline("text-generation", model = "gpt2", pad_token_id = 50256, num_return_sequences=1)
completions = generator(prompt)
print(completions[0]['generated_text'])

This notebook is all about proteins, friends and vernaculars. You won't want to look around the kitchen to find out where someone's egg salad was made, you won't want to know where your chicken salad is made, or that the


In [7]:
# Data preprocessing

# Get sequences, accession number and main category labels.
#
# The FASTA file is read into three parallel lists (same index = same record):
#   sequences[i] - residue string of record i, wrapped lines joined together
#   acc_num[i]   - header text following ">"
#   main_cat[i]  - second "_"-separated field of the header line

sequences = list()
acc_num = list()
main_cat = list()

# Pieces of the record currently being read; None until a header was seen.
sequence_parts = None

with open("../data/terp.faa") as file:
    for line in file:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue

        if line.startswith(">"):
            # A new header closes the previous record, if there was one.
            if sequence_parts is not None:
                sequences.append("".join(sequence_parts))
            sequence_parts = []
            acc_num.append(line.split(">")[1].strip())
            main_cat.append(line.split("_")[1].strip())
        elif sequence_parts is None:
            raise ValueError("FASTA file must start with a '>' header line")
        else:
            sequence_parts.append(line)

# Add last sequence
if sequence_parts is not None:
    sequences.append("".join(sequence_parts))

# Every header must have produced exactly one sequence.
assert len(sequences) == len(acc_num) == len(main_cat)

# Number the main categories 0..N-1 in alphabetical order and keep the
# reverse lookup (number -> category name) alongside.

unique_main_cats = sorted(set(main_cat))
main2label = dict(zip(unique_main_cats, range(len(unique_main_cats))))
label2main = dict(enumerate(unique_main_cats))

# Create class translation dictionary for accession numbers.
# Each non-empty line is tab-separated: the first field is the class name,
# the second field is the accession with one leading character that is dropped.

acc2class = dict()

with open("../data/class_vs_acc_v2.txt", "r") as file:
    for line in file:
        if not line.strip():
            # Skip blank lines instead of failing with an IndexError.
            continue
        fields = line.split("\t")
        t_class = fields[0]
        acc = fields[1].strip()[1:]
        acc2class[acc] = t_class

# Number the classes 0..N-1 in alphabetical order and keep the reverse lookup.

unique_classes = sorted(set(acc2class.values()))
class2label = {name: idx for idx, name in enumerate(unique_classes)}
label2class = dict(enumerate(unique_classes))

# Summarise what was loaded.
n_sequences = len(sequences)
n_main_categories = len(set(main_cat))
n_classes = len(set(acc2class.values()))
print(
    "The files contain:",
    f"{n_sequences} sequences in",
    f"{n_main_categories} main categories and",
    f"{n_classes} classes")

The files contain: 534 sequences in 10 main categories and 49 classes


In [5]:
# Possibly check class distribution here...

In [8]:
# Choose between category and class:
labels = main_cat
#labels = acc_num # This will translate to class later.

# Hold out 10% of the records for validation.
# The split is seeded so that re-running the notebook from a fresh kernel
# yields the same train/validation partition.
# NOTE(review): the split is not stratified; rare labels may end up in only
# one of the two partitions.

train_seq, val_seq, train_labels, val_labels = train_test_split(
    sequences, labels, test_size=.1, random_state=42)

# Sequences and labels must stay aligned after splitting.
assert len(train_seq) == len(train_labels)
assert len(val_seq) == len(val_labels)

print(f"Training size: {len(train_seq)} Validation size: {len(val_seq)}")

Training size: 480 Validation size: 54


In [9]:
# Tokenizer: loaded from the same checkpoint name that the model uses,
# with lower-casing of the input switched off.
tokenizer = ESMTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D", do_lower_case=False)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'EsmTokenizer'. 
The class this function is called from is 'ESMTokenizer'.


In [10]:
class SequenceDataset(Dataset):
    """Tokenized sequences paired with one-hot encoded labels.

    Construction relies on notebook-level names being defined beforehand:
    ``tokenizer`` plus the lookup tables ``main2label`` (category mode) or
    ``class2label`` and ``acc2class`` (class mode).
    """

    def __init__(self, input_sequences, input_labels, categories=True):
        """Tokenize all sequences once and translate the labels to integers.

        Args:
            input_sequences: list of sequence strings.
            input_labels: one label per sequence. Main-category names when
                ``categories`` is True, accession numbers otherwise.
            categories: True for main-category classification, False for
                class classification (accession -> class -> label).

        Raises:
            ValueError: if the number of sequences and labels differ.
            KeyError: if a label is missing from the lookup tables.
        """
        if len(input_sequences) != len(input_labels):
            raise ValueError(
                f"Got {len(input_sequences)} sequences but {len(input_labels)} labels")

        # The xx2label turns the label from text to a number from 0-(N-1).
        # The width of the one-hot vector is the size of the whole label
        # vocabulary; it is fixed here instead of on every item access.
        if categories:
            self.labels = [main2label[cat] for cat in input_labels]
            self.num_labels = len(main2label)
        else:
            self.labels = [class2label[acc2class[acc]] for acc in input_labels]
            self.num_labels = len(class2label)

        # Tokenize sequence and pad to longest sequence in dataset.
        # Return pytorch-type tensors
        self.sequences = tokenizer(
                                input_sequences,
                                padding = 'longest',
                                return_tensors = 'pt')
        # Save label type
        self.label_type_cat = categories

    def classes(self):
        """Return the integer label of every sample, in dataset order."""
        return self.labels

    def __len__(self):
        """Return the number of samples in the dataset (required)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` (required).

        The sample is a dict with:
            'input_ids'      - token ids of the sequence
            'attention_mask' - mask taken from the tokenizer output
            'label'          - one-hot label as a float tensor of length
                               ``self.num_labels``
        """
        label = torch.tensor(self.labels[idx])

        sample = dict()
        sample['input_ids'] = self.sequences['input_ids'][idx]
        sample['attention_mask'] = self.sequences['attention_mask'][idx]
        sample['label'] = one_hot(label,
                                  num_classes = self.num_labels).to(torch.float)
        return sample

In [11]:
# Wrap the training and validation splits in SequenceDataset objects
# (category mode, which is the constructor default).
train_dataset = SequenceDataset(input_sequences=train_seq, input_labels=train_labels)
val_dataset = SequenceDataset(input_sequences=val_seq, input_labels=val_labels)


In [22]:
# SETTINGS:

# Model specifications:
# The classification head must be as wide as the one-hot label vectors the
# dataset emits, i.e. the size of the full label vocabulary. Counting only the
# labels that happen to occur in the training split gives a too-narrow head
# (and a shape mismatch in the loss) whenever the random split leaves a label
# out of the training set.
num_labels = len(main2label) if train_dataset.label_type_cat else len(class2label)

# NOTE(review): every sample carries exactly one label, yet the model is
# configured as "multi_label_classification" with one-hot float targets.
# Confirm this is intended rather than "single_label_classification".
model = ESMForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D",
    num_labels = num_labels,
    problem_type = "multi_label_classification")

# Data loading:
# The DataLoader splits the dataset by batch size,
# and returns an iter to go through each batch of samples.

epochs = 1
batch_size = 1

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Validation only measures performance, so the order of the samples is
# irrelevant; keep it fixed for repeatable runs.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False)

# Optimizer, learning rate and optional learning rate decay.

learning_rate = 5e-5
optimizer = AdamW(model.parameters(),
                  lr = learning_rate)
# Linear decay over the total number of optimizer steps (batches * epochs).
scheduler = linear_scheduler(optimizer,
                             num_warmup_steps = 0,
                             num_training_steps = len(train_loader) * epochs)

# If you don't want decay - uncomment this:
# scheduler = constant_scheduler(optimizer,
#                                num_warmup_steps = 0)


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing ESMForSequenceClassification: ['esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq', 'lm_head.dense.bias', 'lm_head.bias', 'esm.encoder.layer.3.attention.self.rotary_embeddings.inv_freq', 'lm_head.layer_norm.weight', 'esm.encoder.layer.2.attention.self.rotary_embeddings.inv_freq', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'esm.encoder.layer.4.attention.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.1.attention.self.rotary_embeddings.inv_freq', 'esm.encoder.layer.5.attention.self.rotary_embeddings.inv_freq']
- This IS expected if you are initializing ESMForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ESMForSequenceClassification from the check

In [24]:
# Training
#
# Runs `epochs` passes over train_loader, each followed by a full pass over
# val_loader, printing loss, accuracy and timing information.
#
# NOTE(review): the device is forced to CPU here, overriding any device
# chosen earlier in the notebook - confirm this is intended.
device = torch.device("cpu")
model.to(device)
total_t0 = time.time()

for epoch in range(epochs):

    print(f"#------ EPOCH {epoch+1} of {epochs}")
    print(f"\nTraining\n")

    # ---- Training step ----
    t0 = time.time()
    model.train()
    total_train_loss = 0.0

    for step, batch in enumerate(train_loader):

        # Progress report every 5th batch.
        if step % 5 == 0:
            time_elapsed = time_step(time.time() - t0)
            print(f"#--- Batch {step+1} of {len(train_loader)}\nTime: {time_elapsed}")

        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Named batch_labels so the notebook-level `labels` list is not clobbered.
        batch_labels = batch['label'].to(device)

        output = model(input_ids,
                       token_type_ids=None,
                       attention_mask=attention_mask,
                       labels=batch_labels)

        # .item() extracts a plain float; accumulating the tensor itself would
        # keep every step's autograd graph alive for the whole epoch.
        total_train_loss += output.loss.item()

        output.loss.backward()
        optimizer.step()
        scheduler.step() # Update learning rate

    avg_train_loss = total_train_loss / len(train_loader)
    print(f"Average training loss: {avg_train_loss:.2f}")

    # ---- Evaluation step ----
    print(f"\nValidating\n")

    model.eval()
    t0 = time.time()
    total_val_loss = 0.0
    val_true = []   # true label index of every validation sample
    val_pred = []   # predicted label index of every validation sample

    for batch in val_loader:

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        batch_labels = batch['label'].to(device)

        # Dont construct unnecessary compute-graph:
        with torch.no_grad():
            output = model(input_ids,
                           token_type_ids=None,
                           attention_mask=attention_mask,
                           labels=batch_labels)

        total_val_loss += output.loss.item()

        # One-hot targets and logits are both reduced to a label index.
        val_true.extend(batch_labels.argmax(dim=1).cpu().tolist())
        val_pred.extend(output.logits.argmax(dim=1).cpu().tolist())

    avg_val_loss = total_val_loss / len(val_loader)
    # Metrics are computed once over the whole validation set. Averaging
    # per-batch scores (especially with batch size 1) makes the balanced
    # accuracy collapse to plain accuracy.
    avg_val_accuracy = accuracy_score(val_true, val_pred)
    avg_balanced_val_accuracy = balanced_accuracy_score(val_true, val_pred)
    validation_time = time_step(time.time() - t0)

    print(f"Average validation loss: {avg_val_loss:.2f}")
    print(f"Average validation accuracy: {avg_val_accuracy:.2f}")
    print(f"Average balanced validation accuracy: {avg_balanced_val_accuracy:.2f}")
    print(f"Validation time: {validation_time}\n")

print("Success!\n")
print(f"Total training time: {time_step(time.time() - total_t0)}")

#------ EPOCH 1 of 1

Training

#--- Batch 1 of 480
Time: 0:00:00
#--- Batch 6 of 480
Time: 0:00:04
#--- Batch 11 of 480
Time: 0:00:08
#--- Batch 16 of 480
Time: 0:00:11
#--- Batch 21 of 480
Time: 0:00:15
#--- Batch 26 of 480
Time: 0:00:18
#--- Batch 31 of 480
Time: 0:00:22
#--- Batch 36 of 480
Time: 0:00:25
#--- Batch 41 of 480
Time: 0:00:29
#--- Batch 46 of 480
Time: 0:00:32
#--- Batch 51 of 480
Time: 0:00:35
#--- Batch 56 of 480
Time: 0:00:39
#--- Batch 61 of 480
Time: 0:00:42
#--- Batch 66 of 480
Time: 0:00:46
#--- Batch 71 of 480
Time: 0:00:49
#--- Batch 76 of 480
Time: 0:00:53
#--- Batch 81 of 480
Time: 0:00:56
#--- Batch 86 of 480
Time: 0:01:00
#--- Batch 91 of 480
Time: 0:01:03
#--- Batch 96 of 480
Time: 0:01:07
#--- Batch 101 of 480
Time: 0:01:11
#--- Batch 106 of 480
Time: 0:01:14
#--- Batch 111 of 480
Time: 0:01:18
#--- Batch 116 of 480
Time: 0:01:21
#--- Batch 121 of 480
Time: 0:01:25
#--- Batch 126 of 480
Time: 0:01:29
#--- Batch 131 of 480
Time: 0:01:32
#--- Batch 136 of 

