# Finetuning the ESM2 token classification model

### 0. Libraries & path

In [1]:
import torch
import evaluate
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset, load_metric, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import TrainingArguments, Trainer, DataCollatorForTokenClassification
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score

Init Plugin
Init Graph Optimizer
Init Kernel


In [2]:
# absolute path to project
absolute_path = '/Users/dimi/Documents/GitHub/PhageDEPOdetection/'

### 1. Functions

In [3]:
def get_labels(df , label = 1) :
    labels_df = []
    for _,row in df.iterrows():
        info = row["Boundaries"]
        seq_length = len(row["Full_seq"])
        if info == "Negative" :
            labels = [label] * seq_length
            labels_df.append(labels)
        elif info == "full_protein" or info == "full" :
            labels = [label] * seq_length
            labels_df.append(labels)
        elif info.count(":") > 0 : 
            start = int(info.split(":")[0])
            end = int(info.split(":")[1])
            labels = [0 if i < start or i >= end else label for i in range(seq_length)]
            labels_df.append(labels)
        else :
            start = int(info.split("_")[-2])
            end = int(info.split("_")[-1])
            labels = [0 if i < start or i >= end else label for i in range(seq_length)]
            labels_df.append(labels)
    return labels_df

In [4]:
def tokenize_function(inputs):
    """
    ESM finetuning needs input_ids, attention_mask, labels, position_ids.
    If the dataset does not contain these columns, we need to add them w this function.
    """
    return tokenizer(inputs['tokens'], add_special_tokens=False, padding='max_length', truncation=True)

In [5]:
## second attempt. If this doensn't work, problem with 0-preds that prevent recall etc..?
seqeval = evaluate.load("seqeval")
def compute_metrics(p):
    predictions, labels = p
    labels = labels.reshape((-1,))
    predictions = np.argmax(predictions, axis=2)
    predictions = predictions.reshape((-1,))
    predictions = predictions[labels!=-100]
    labels = labels[labels!=-100]
    results = seqeval.compute(predictions=predictions, references=labels)
    
    return {"precision": results["overall_precision"], "recall": results["overall_recall"], 
            "f1": results["overall_f1"], "accuracy": results["overall_accuracy"]}

### 2. Code to finetune

In [6]:
# load data
df_depo = pd.read_csv(absolute_path+'data/Phagedepo.Dataset.2007.tsv' , sep = '\t' , header = 0)

In [9]:
# assign and get the labels
df_beta_helix = df_depo[df_depo["Fold"] == "right-handed beta-helix"]
df_beta_prope = df_depo[df_depo["Fold"] == "6-bladed beta-propeller"]
df_beta_triple =  df_depo[df_depo["Fold"] == "triple-helix"]
df_negative = df_depo[df_depo["Fold"] == "Negative"]

labels_beta_helix = get_labels(df_beta_helix , label = 1)
seq_beta_helix = df_beta_helix["Full_seq"].to_list()
labels_beta_propeller = get_labels(df_beta_prope , label = 2)
seq_beta_propeller = df_beta_prope["Full_seq"].to_list()
labels_triple_helix = get_labels(df_beta_triple , label = 1)
seq_triple_helix = df_beta_triple["Full_seq"].to_list()
labels_negative = get_labels(df_negative , label = 0)
seq_negative = df_negative["Full_seq"].to_list()

# final input data
sequences = seq_beta_helix + seq_beta_propeller + seq_triple_helix + seq_negative
labels = labels_beta_helix + labels_beta_propeller + labels_triple_helix + labels_negative

In [8]:
# accomodate for max length of ESM2
max_length = 1024
sequences = [seq[:max_length] for seq in sequences]
labels = [lab[:max_length] for lab in labels]

In [11]:
# train-test split
train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.2, random_state = 243)
train_esm2 , train_CNV , esm2_labels , CNV_labels = train_test_split(train_sequences, train_labels, test_size=0.25, random_state = 243)

In [12]:
# initialize the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
data_collator = DataCollatorForTokenClassification(tokenizer)
model = AutoModelForTokenClassification.from_pretrained("facebook/esm2_t6_8M_UR50D", num_labels=3)

Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForTokenClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing EsmForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task 

In [13]:
# tokenize the data and make a dataset from it
train_tokenized = tokenizer(train_esm2, add_special_tokens=False, padding='max_length', truncation=True, max_length=1024)
test_tokenized = tokenizer(test_sequences, add_special_tokens=False, padding='max_length', truncation=True, max_length=1024)

train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)
train_dataset = train_dataset.add_column("labels", esm2_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

In [14]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1373
})

In [12]:
# set the training arguments
training_args = TrainingArguments(
    output_dir=absolute_path+'data/finetune',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4, #16
    per_device_eval_batch_size=8, #16
    num_train_epochs=2,
    weight_decay=0.001,
    #load_best_model_at_end=True,
    metric_for_best_model='f1',
)

In [13]:
# set the trainer and train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

***** Running training *****
  Num examples = 1373
  Num Epochs = 2
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 688
  Number of trainable parameters = 7738364


Epoch,Training Loss,Validation Loss


***** Running Evaluation *****
  Num examples = 458
  Batch size = 8


ValueError: Predictions and/or references don't match the expected format.
Expected format: {'predictions': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence')},
Input predictions: [0 0 0 ... 0 0 0],
Input references: [0 0 0 ... 0 0 0]

In [None]:
# save the model & tokenizer
trainer.save_model('data/finetune')
tokenizer.save_pretrained('data/finetune')