# Fine-tuning protein language models

In [None]:
# Uncomment & execute once to download data
# https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/
#!mkdir -p data
#!curl https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/data/graphpart_set.fasta -o data/graphpart_set.fasta
#!curl https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/data/benchmarking_dataset.fasta -o data/benchmarking_dataset.fasta

In [None]:
# Q6. Describe the problem of predicting the subcellular location of (prokaryotic) proteins as described in Moreno2024 (https://doi.org/10.1101/2024.01.04.574157)

In [None]:
import numpy as np, pandas as pd, sklearn.preprocessing
import Bio.SeqIO.FastaIO # Biopython for reading fasta files
import datasets, evaluate, transformers # Hugging Face libraries https://doi.org/10.18653/v1/2020.emnlp-demos.6

In [None]:
# Q7. How were the training/benchmark data sets constructed? How were the cross-validation folds defined?
def read_DeepLocPro(file, columns=None):
    """Read a DeepLocPro FASTA file into a DataFrame.

    Parameters
    ----------
    file : path of the FASTA file to parse.
    columns : optional sequence of names for the '|'-separated fields of each FASTA header.
        When None, the raw 'header' and 'sequence' columns are returned. Otherwise each header
        is split on '|' into these columns and the result holds ``columns + ['sequence']``
        (the raw 'header' column is dropped).

    Raises
    ------
    ValueError
        From pandas, if the number of '|'-separated header fields does not equal len(columns).
    """
    # Explicit encoding: open()'s default depends on the platform locale.
    with open(file, encoding='utf-8') as handle:
        df = pd.DataFrame.from_records(
            list(Bio.SeqIO.FastaIO.SimpleFastaParser(handle)),  # (title, sequence) pairs
            columns=['header', 'sequence'],
        )
    if columns is None:
        return df
    # Normalise to a list: a tuple key would be read by pandas as a single column label,
    # and tuple + list concatenation below would fail.
    columns = list(columns)
    df[columns] = df['header'].str.split('|', expand=True)
    return df[columns + ['sequence']]

# Names for the '|'-separated header fields shared by both files; the graphpart set has an extra trailing fold_id field.
columns = ['uniprot_id', 'subcellular_location', 'organism_group']
df_graphpart = read_DeepLocPro('data/graphpart_set.fasta', columns=columns + ['fold_id'])
df_benchmarking = read_DeepLocPro('data/benchmarking_dataset.fasta', columns=columns)
df_graphpart  # displayed as the cell output

In [None]:
# Map location names to integer class ids. The encoder is fitted on the graphpart (training) set only,
# so transform() raises ValueError if the benchmark contains a location that never occurs there.
subcellular_location_encoder = sklearn.preprocessing.LabelEncoder()
subcellular_location_encoder.fit(df_graphpart['subcellular_location'])
# Assigned per frame (no loop) so no loop variable is left behind in the notebook namespace.
df_benchmarking['label'] = subcellular_location_encoder.transform(df_benchmarking['subcellular_location'])
df_graphpart['label'] = subcellular_location_encoder.transform(df_graphpart['subcellular_location'])

In [None]:
# Q7. How were the training/benchmark data sets constructed? How were the cross-validation folds defined?
# Subsample training/eval data from the homology-partitioned sequences in the preprint
# Draw a small, class-balanced subsample from the homology-partitioned graphpart set: records in
# folds "1"-"4" form the training pool and every remaining fold forms the eval pool.
random_number = 4 # https://xkcd.com/221/
samples_per_location = 10  # records drawn per subcellular location from each pool
train_query = 'fold_id == "1" or fold_id == "2" or fold_id == "3" or fold_id == "4"'
# NOTE(review): sampling without replacement raises ValueError if a location has fewer than
# samples_per_location records in a pool.
df_train = df_graphpart.query(train_query).groupby('subcellular_location').sample(n=samples_per_location, random_state=random_number)
df_eval = df_graphpart.query(f'~({train_query})').groupby('subcellular_location').sample(n=samples_per_location, random_state=random_number)
print(len(df_train), 'records in training data:')
print(df_train['subcellular_location'].value_counts())
print()
print(len(df_eval), 'records in eval data:')
print(df_eval['subcellular_location'].value_counts())
print()
print(len(df_benchmarking), 'records in benchmarking data:')
print(df_benchmarking['subcellular_location'].value_counts())

In [None]:
# Hugging Face Hub id of a small ESM-2 protein language model (6 layers / 8M parameters, per its name),
# and the matching tokenizer used to tokenize every split.
model_checkpoint = 'facebook/esm2_t6_8M_UR50D'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# Tokenize the sequences of each split, truncating anything longer than max_length=1024 tokens.
# No padding is requested here, so encodings have variable lengths; padding per batch is presumably left
# to the Trainer's default data collator, which receives the tokenizer (TODO confirm for the installed
# transformers version).
train_tokenized = tokenizer(df_train['sequence'].tolist(), truncation=True, max_length=1024)
eval_tokenized = tokenizer(df_eval['sequence'].tolist(), truncation=True, max_length=1024)
benchmarking_tokenized = tokenizer(df_benchmarking['sequence'].tolist(), truncation=True, max_length=1024)

In [None]:
# Build Hugging Face datasets from the tokenized encodings and attach the integer class ids as 'labels'.
# add_column pairs rows by position, which is valid because each encoding was built from the same
# DataFrame's 'sequence' list in row order.
train_dataset = datasets.Dataset.from_dict(train_tokenized).add_column('labels', df_train['label'].tolist())
eval_dataset = datasets.Dataset.from_dict(eval_tokenized).add_column('labels', df_eval['label'].tolist())
benchmarking_dataset = datasets.Dataset.from_dict(benchmarking_tokenized).add_column('labels', df_benchmarking['label'].tolist())
benchmarking_dataset  # displayed as the cell output

In [None]:
# Q8. Describe the difference between EsmModel and EsmForSequenceClassification.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    # Size the classification head from the fitted encoder (every location in the graphpart data) rather
    # than from the locations that happen to occur in the benchmark, so every encoded label id has a logit.
    num_labels=len(subcellular_location_encoder.classes_),
    # Human-readable class names stored in the model config, in the encoder's id order.
    id2label={i: str(name) for i, name in enumerate(subcellular_location_encoder.classes_)},
    label2id={str(name): i for i, name in enumerate(subcellular_location_encoder.classes_)},
)
#model

In [None]:
# Load the same checkpoint a second time as a bare EsmModel. It is never passed to the Trainer, so it keeps
# the pretrained weights and serves as the reference for the weight comparison in Q9.
model_esm = transformers.EsmModel.from_pretrained(model_checkpoint)
#model_esm

In [None]:
#model_name = model_checkpoint.split('/')[-1]

import inspect

# transformers renamed the TrainingArguments keyword `evaluation_strategy` to `eval_strategy` (v4.41), and
# newer releases no longer accept the old name. Pass the evaluation schedule under whichever keyword the
# installed version defines, so the notebook runs on both older and newer releases.
_eval_strategy_kwarg = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(transformers.TrainingArguments).parameters
    else 'evaluation_strategy'
)

# Evaluate and checkpoint once per epoch (the two schedules must match for load_best_model_at_end), keeping
# the checkpoint with the best eval 'accuracy', i.e. the key returned by compute_metrics.
args = transformers.TrainingArguments(
    #f'{model_name}-subcellular_location',
    output_dir='esm2_subcellular_location',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    #per_device_train_batch_size=4,
    #per_device_eval_batch_size=4,
    **{_eval_strategy_kwarg: 'epoch'},
)

In [None]:
# The paper characterises performance with accuracy and macro-averaged F1 score; we track both throughout training
# Load the accuracy and F1 metric modules from Hugging Face `evaluate`; compute_metrics and apply_ reuse them.
metric_accuracy = evaluate.load('accuracy')
metric_f1 = evaluate.load('f1')

def compute_metrics(eval_pred):
    """Score one evaluation pass with accuracy and macro-averaged F1.

    Used as the Trainer's ``compute_metrics`` callback
    (https://huggingface.co/docs/transformers/en/training#evaluate). ``eval_pred`` unpacks into
    ``(logits, labels)``; each example's predicted class is the arg-max of its logits along axis 1.
    Returns a dict with keys 'accuracy' and 'f1_macro'.
    """
    logits, references = eval_pred
    predicted = np.argmax(logits, axis=1)
    accuracy = metric_accuracy.compute(predictions=predicted, references=references)
    f1 = metric_f1.compute(predictions=predicted, references=references, average='macro')
    return {'accuracy': accuracy['accuracy'], 'f1_macro': f1['f1']}

In [None]:
import inspect

# Trainer's `tokenizer=` keyword was superseded by `processing_class=` (transformers v4.46); pass the
# tokenizer under whichever keyword this installed version defines. The tokenizer matters here: the
# encodings were built without padding, and the Trainer's default collator presumably uses it to pad
# each batch.
_tokenizer_kwarg = (
    'processing_class'
    if 'processing_class' in inspect.signature(transformers.Trainer.__init__).parameters
    else 'tokenizer'
)

trainer = transformers.Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    **{_tokenizer_kwarg: tokenizer},
)

In [None]:
# We can now fine-tune the network, reporting the performance at the end of every epoch
# Fine-tune, evaluating on eval_dataset after every epoch (per args). With load_best_model_at_end=True, the
# model is expected to hold the best-accuracy checkpoint once training ends.
retrained = trainer.train()
retrained  # summary object returned by Trainer.train(), shown as the cell output

In [None]:
# Q9. How did the parameters change during re-training? Compare (a subset of) the weights in the retrained model to those in model_esm

In [None]:
# We evaluate the fine-tuned model on the benchmark data set (globally)
# How does the fine-tuned model compare to DeepLocPro as reported in Table 3 of the preprint?
# Score the benchmark set with the same compute_metrics callback (accuracy, macro F1) used during training.
trainer.evaluate(eval_dataset=benchmarking_dataset)

In [None]:
# We'll take a closer look at the individual predictions
# Predict class ids for every benchmark record and attach them next to the true labels.
benchmarking_predictions = trainer.predict(test_dataset=benchmarking_dataset)
# Positional assignment: relies on predictions coming back in benchmarking_dataset (= df_benchmarking) row order.
df_benchmarking['label_predicted'] = np.argmax(benchmarking_predictions.predictions, axis=-1)
# Number of correctly classified benchmark records.
print((df_benchmarking['label'] == df_benchmarking['label_predicted']).sum())
df_benchmarking

In [None]:
# Show a table of performance metrics split by organism group, matching Table 3 in the preprint
def apply_(df):
    """Summarise one group of benchmark predictions.

    Returns a Series with the group's record count ('size'), its accuracy ('accuracy') and its
    macro-averaged F1 ('f1_macro'), comparing ``df.label_predicted`` against ``df.label``.
    """
    y_true = df.label.values
    y_pred = df.label_predicted.values
    accuracy = metric_accuracy.compute(predictions=y_pred, references=y_true)['accuracy']
    f1_macro = metric_f1.compute(predictions=y_pred, references=y_true, average='macro')['f1']
    return pd.Series({'size': len(df), 'accuracy': accuracy, 'f1_macro': f1_macro})

# Q10. Re-train on the whole data set; compare to DeepLocPro
# Per-organism-group summary (size, accuracy, macro F1), one column per group in the order of Table 3.
# Only the label columns are handed to apply_: GroupBy.apply on a frame that still contains the grouping
# column is deprecated in pandas >= 2.2, and apply_ reads nothing else.
print(
    df_benchmarking.groupby('organism_group')[['label', 'label_predicted']]
    .apply(apply_)
    .transpose()[['archaea', 'positive', 'negative']]
    .to_string(float_format='%.2f')
)