# Fine-tuning protein language models

In [1]:
# Uncomment & execute once to download data
# https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/
#!mkdir -p data
#!curl https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/data/graphpart_set.fasta -o data/graphpart_set.fasta
#!curl https://services.healthtech.dtu.dk/services/DeepLocPro-1.0/data/benchmarking_dataset.fasta -o data/benchmarking_dataset.fasta

In [2]:
# Q6. Describe the problem of predicting the subcellular location of (prokaryotic) proteins as described in Moreno2024 (https://doi.org/10.1101/2024.01.04.574157)

In [3]:
import numpy as np, pandas as pd, sklearn.preprocessing
import Bio.SeqIO.FastaIO # Biopython for reading fasta files
import datasets, evaluate, transformers # Hugging Face libraries https://doi.org/10.18653/v1/2020.emnlp-demos.6
from IPython.core.display import display, HTML



In [4]:
# Q7. How were the training/benchmark data sets constructed? How were the cross-validation folds defined?
def read_DeepLocPro(file, columns=None):
    with open(file) as handle:
        fasta_cols = ['header', 'sequence']
        df = pd.DataFrame.from_records([values for values in Bio.SeqIO.FastaIO.SimpleFastaParser(handle)], columns=fasta_cols)
    if columns is None:
        return df
    else:
        df[columns] = df['header'].str.split('|', expand=True)
        return df[columns + ['sequence']]

columns = ['uniprot_id', 'subcellular_location', 'organism_group']
df_graphpart = read_DeepLocPro('data/graphpart_set.fasta', columns=columns + ['fold_id']).astype({'fold_id': int}).sort_values('fold_id').reset_index(drop=True)
df_graphpart

Unnamed: 0,uniprot_id,subcellular_location,organism_group,fold_id,sequence
0,Q8A0Z3,Cytoplasmic,negative,0,MAVTMADITKLRKMTGAGMMDCKNALTEAEGDYDKAMEIIRKKGQA...
1,Q8A2N8,Cytoplasmic,negative,0,MIMSKETLIKSIREIPDFPIPGILFYDVTTLFKDPWCLQELSNIMF...
2,P32709,CYtoplasmicMembrane,negative,0,MTQTSAFHFESLVWDWPIAIYLFLIGISAGLVTLAVLLRRFYPQAG...
3,E7FHF8,Cytoplasmic,archaea,0,MKLGVFELTDCGGCALNLLFLYDKLLDLLEFYEIAEFHMATSKKSR...
4,E7FHU4,Cytoplasmic,archaea,0,MGKVRIGFYALTSCYGCQLQLAMMDELLQLIPNAEIVCWFMIDRDS...
...,...,...,...,...,...
11901,Q97F85,Cytoplasmic,positive,4,MRKLFTSESVTEGHPDKICDQISDAILDAILEKDPNGRVACETTVT...
11902,P33656,Cytoplasmic,positive,4,MKNKTEVKNGGEKKNSKKVSKEESAKEKNEKMKIVKNLIDKGKKSG...
11903,P13949,OuterMembrane,negative,4,MCALDRRERPLNSQSVNKYILNVQNIYRNSPVPVCVRNKNRKILYA...
11904,P42185,Extracellular,negative,4,MRLRFSVPLFFFGCVFVHGVFAGPFPPPGMSLPEYWGEEHVWWDGR...


In [5]:
subcellular_location_encoder = sklearn.preprocessing.LabelEncoder()
subcellular_location_encoder.fit(df_graphpart['subcellular_location'])
df_graphpart['label'] = subcellular_location_encoder.transform(df_graphpart['subcellular_location'])

In [6]:
# Q7. How were the training/benchmark data sets constructed? How were the cross-validation folds defined?
# Subsample training/eval data from the homology-partitioned sequences in the preprint
random_number = 4 # https://xkcd.com/221/
train_id = [2, 3, 4]
eval_id = [0]
test_id = [1]

df_train = df_graphpart.query('fold_id in @train_id')#.groupby('subcellular_location').sample(n=20, random_state=random_number)
df_eval = df_graphpart.query('fold_id in @eval_id')#.groupby('subcellular_location').sample(n=10, random_state=random_number)
df_test = df_graphpart.query('fold_id in @test_id')
print(len(df_train), 'records in training data:')
print(df_train['subcellular_location'].value_counts())
print()
print(len(df_eval), 'records in eval data:')
print(df_eval['subcellular_location'].value_counts())
print()
print(len(df_test), 'records in test data:')
print(df_test['subcellular_location'].value_counts())

6668 records in training data:
subcellular_location
Cytoplasmic            3707
CYtoplasmicMembrane    1488
Extracellular           625
OuterMembrane           457
Periplasmic             341
Cellwall                 50
Name: count, dtype: int64

2554 records in eval data:
subcellular_location
Cytoplasmic            1540
CYtoplasmicMembrane     525
Extracellular           195
OuterMembrane           163
Periplasmic             110
Cellwall                 21
Name: count, dtype: int64

2684 records in test data:
subcellular_location
Cytoplasmic            1638
CYtoplasmicMembrane     522
Extracellular           257
OuterMembrane           136
Periplasmic             115
Cellwall                 16
Name: count, dtype: int64


In [7]:
model_checkpoint = 'facebook/esm2_t6_8M_UR50D'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)

In [8]:
train_tokenized = tokenizer(df_train['sequence'].tolist(), truncation=True, max_length=1024)
eval_tokenized = tokenizer(df_eval['sequence'].tolist(), truncation=True, max_length=1024)
test_tokenized = tokenizer(df_test['sequence'].tolist(), truncation=True, max_length=1024)

In [9]:
train_dataset = datasets.Dataset.from_dict(train_tokenized).add_column('labels', df_train['label'].tolist())
eval_dataset = datasets.Dataset.from_dict(eval_tokenized).add_column('labels', df_eval['label'].tolist())
test_dataset = datasets.Dataset.from_dict(test_tokenized).add_column('labels', df_test['label'].tolist())

In [10]:
# Q8. Describe the difference between EsmModel, and EsmForSequenceClassification?
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=df_graphpart['label'].nunique())
#model

Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
#model_name = model_checkpoint.split('/')[-1]

args = transformers.TrainingArguments(
    #f'{model_name}-subcellular_location',
    output_dir='esm2_subcellular_location',
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)

In [12]:
# The paper uses accuracy and macro F1 score to characterise the performance; we will trace both throughout the training
metric_accuracy = evaluate.load('accuracy')
metric_f1 = evaluate.load('f1')

def compute_metrics(eval_pred): # https://huggingface.co/docs/transformers/en/training#evaluate
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    return {
        'accuracy': metric_accuracy.compute(predictions=predictions, references=labels)['accuracy'],
        'f1_macro': metric_f1.compute(predictions=predictions, references=labels, average='macro')['f1'],
    }

Downloading builder script:   0%|          | 0.00/6.79k [00:00<?, ?B/s]

In [13]:
trainer = transformers.Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = transformers.Trainer(


In [14]:
# We can now fine-tune the network, reporting the performance at the end of every epoch
retrained = trainer.train()
retrained

Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.4642,0.426776,0.89076,0.650232
2,0.3659,0.496763,0.891151,0.669842
3,0.2541,0.501464,0.891151,0.694605


TrainOutput(global_step=5001, training_loss=0.3957060627409117, metrics={'train_runtime': 438.5219, 'train_samples_per_second': 45.617, 'train_steps_per_second': 11.404, 'total_flos': 598238789673168.0, 'train_loss': 0.3957060627409117, 'epoch': 3.0})

In [15]:
#Q. Accuracy/F1 Macro are calculated on what? train, test or eval?

In [16]:
# Q9. How did the parameters change during re-training? Compare (a subset) of weights in the (retrained) model to model_esm

In [17]:
# We evaluate the fine-tuned model on the benchmark data set (globally)
# How does the fine-tuned model compare to DeepLocPro as reported in Table 3 of the preprint?
trainer.evaluate(eval_dataset=test_dataset)

{'eval_loss': 0.4456363320350647,
 'eval_accuracy': 0.9016393442622951,
 'eval_f1_macro': 0.6649352551667449,
 'eval_runtime': 20.378,
 'eval_samples_per_second': 131.71,
 'eval_steps_per_second': 32.928,
 'epoch': 3.0}

In [18]:
fold_id = set(df_graphpart.fold_id)
graphpart_labels = []
for test_id in sorted(fold_id):
    eval_id = (test_id + 1) % 5
    train_id = fold_id - set([eval_id, test_id])

    df_train = df_graphpart.query('fold_id in @train_id')#.groupby('subcellular_location').sample(n=10, random_state=random_number)
    df_eval = df_graphpart.query('fold_id == @eval_id')
    df_test = df_graphpart.query('fold_id == @test_id')
    print(train_id, eval_id, test_id, len(df_train), len(df_eval), len(df_test))

    model_checkpoint = 'facebook/esm2_t6_8M_UR50D'
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)

    train_tokenized = tokenizer(df_train['sequence'].tolist(), truncation=True, max_length=1024)
    eval_tokenized = tokenizer(df_eval['sequence'].tolist(), truncation=True, max_length=1024)
    test_tokenized = tokenizer(df_test['sequence'].tolist(), truncation=True, max_length=1024)

    train_dataset = datasets.Dataset.from_dict(train_tokenized).add_column('labels', df_train['label'].tolist())
    eval_dataset = datasets.Dataset.from_dict(eval_tokenized).add_column('labels', df_eval['label'].tolist())
    test_dataset = datasets.Dataset.from_dict(test_tokenized).add_column('labels', df_test['label'].tolist())

    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=df_graphpart['label'].nunique())
    #model_name = model_checkpoint.split('/')[-1]
    args = transformers.TrainingArguments(
        output_dir=f'esm2-subcellular_location-{eval_id}',
        eval_strategy='epoch',
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
    )

    trainer = transformers.Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    retrained = trainer.train()
    test_predictions = trainer.predict(test_dataset=test_dataset)
    test_labels = np.argmax(test_predictions.predictions, axis=-1)
    graphpart_labels += list(test_labels)

{2, 3, 4} 1 0 6668 2684 2554


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = transformers.Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.476,0.399561,0.904247,0.670197
2,0.3759,0.481482,0.895678,0.6636
3,0.2387,0.466033,0.90313,0.672721


{0, 3, 4} 2 1 7124 2098 2684


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = transformers.Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.4722,0.464613,0.881792,0.649585
2,0.312,0.51472,0.879886,0.716558
3,0.2417,0.496437,0.889895,0.741874


{0, 1, 4} 3 2 7229 2579 2098


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = transformers.Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.4623,0.527522,0.868941,0.635016
2,0.3235,0.494091,0.887166,0.654278
3,0.255,0.523508,0.883676,0.664307


{0, 1, 2} 4 3 7336 1991 2579


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = transformers.Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.4834,0.629983,0.844802,0.638483
2,0.3105,0.637493,0.866399,0.693083
3,0.2177,0.64462,0.87343,0.740562


{1, 2, 3} 0 4 7361 2554 1991


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = transformers.Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.4858,0.442423,0.88567,0.643349
2,0.3114,0.51118,0.888802,0.682085
3,0.2155,0.500282,0.898199,0.733145


In [19]:
df_graphpart['label_predicted'] = graphpart_labels
print(len(df_graphpart))
print(metric_accuracy.compute(predictions=df_graphpart.label_predicted.values, references=df_graphpart.label.values)['accuracy'])
print(metric_f1.compute(predictions=df_graphpart.label_predicted.values, references=df_graphpart.label.values, average='macro')['f1'])

11906
0.891819250797917
0.7068229932223565


In [40]:
# Show table with performance metrics split by organism to match Table 3 in preprint
def calculate_stats_(df):
    accuracy = metric_accuracy.compute(predictions=df.label_predicted.values, references=df.label.values)['accuracy']
    f1_macro = metric_f1.compute(predictions=df.label_predicted.values, references=df.label.values, average='macro')['f1']
    return pd.Series({
        'size': '{:d}'.format(len(df)),
        'accuracy': '{:.2f}'.format(accuracy),
        'f1_macro': '{:.2f}'.format(f1_macro),
    })

# Q10. Re-train on the whole data; compare to DeepLoc Pro
#print(df_graphpart.groupby('organism_group').apply(apply_).transpose()[['archaea', 'positive', 'negative']].to_string(float_format='%.2f'))
#display(HTML(df_graphpart.groupby('organism_group').apply(apply_).transpose()[['archaea', 'positive', 'negative']].to_html()))
pd.concat([
    calculate_stats_(df_graphpart),
    calculate_stats_(df_graphpart.query('organism_group == "archaea"')),
    calculate_stats_(df_graphpart.query('organism_group == "positive"')),
    calculate_stats_(df_graphpart.query('organism_group == "negative"')),
], axis=1).set_axis(['Overall', 'Archaea' , 'Gram pos', 'Gram neg'], axis=1)

Unnamed: 0,Overall,Archaea,Gram pos,Gram neg
size,11906.0,283.0,3206.0,8417.0
accuracy,0.89,0.84,0.92,0.88
f1_macro,0.71,0.43,0.48,0.69
