# Clean up fine-tuning steps

In [None]:
!pip install fair-esm
# Clear cache
gc.collect()
torch.cuda.empty_cache()

model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
print(model)
for name, param in model.named_parameters():
    print(f"Parameter name: {name}, Size: {param.size()}")

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from esm import pretrained
from esm.data import ESMStructuralSplitDataset

# Set device to use GPU
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the structural holdout datasets.
# In fair-esm, ``download=True`` fetches and extracts a single archive that
# contains every split level (family / superfamily / fold) and all five CV
# partitions, and it returns early when the extracted data is already present.
# A single construction is therefore enough. Looping over all
# 3 levels x 5 partitions x 2 splits only repeated that check 30 times, and the
# objects it built were discarded by the next cell anyway.
ESMStructuralSplitDataset(
    split_level='superfamily',
    cv_partition='4',
    split='train',
    root_path=os.path.expanduser('~/.cache/torch/data/esm'),
    download=True,
)

# Load the split used for fine-tuning: superfamily-level holdout, CV partition 4.
esm_structural_train = ESMStructuralSplitDataset(
    split_level='superfamily',
    cv_partition='4',
    split='train',
    root_path=os.path.expanduser('~/.cache/torch/data/esm'),
)

esm_structural_valid = ESMStructuralSplitDataset(
    split_level='superfamily',
    cv_partition='4',
    split='valid',
    root_path=os.path.expanduser('~/.cache/torch/data/esm'),
)

# Inspect the first record of each split. Label the two sequences separately
# so the train and validation prints can be told apart.
elet = esm_structural_train[0]
elev = esm_structural_valid[0]
print(elet.keys())
print('train sequence', elet['seq'])
print('valid sequence', elev['seq'])

# Number of records in each split
print('train size', len(esm_structural_train))
print('valid size', len(esm_structural_valid))

# Names used by the training / evaluation cells below
train_dataset = esm_structural_train
valid_dataset = esm_structural_valid

In [None]:
# Freeze all parameters of the pretrained model so nothing in the backbone is
# updated during training.
for param in model.parameters():
    param.requires_grad = False

# Replace only the contact head's final linear layer with a freshly
# initialised one. Its shape is taken from the layer being replaced instead of
# hard-coding 120. In fair-esm that width is num_layers * attention_heads,
# which is 120 only for esm2_t6_8M_UR50D, so this keeps the cell correct if a
# different ESM-2 checkpoint is loaded.
# NOTE(review): this discards the pretrained contact-regression weights. If the
# intent is to fine-tune rather than retrain the head from scratch, keep the
# existing layer and only re-enable its gradients.
old_regression = model.contact_head.regression
model.contact_head.regression = nn.Linear(
    in_features=old_regression.in_features,
    out_features=old_regression.out_features,
)

# Set requires_grad=True only for the regression layer parameters to be trained
for param in model.contact_head.regression.parameters():
    param.requires_grad = True

In [None]:
# Converter that turns a list of (label, sequence) pairs into a token tensor.
batch_converter = alphabet.get_batch_converter()

# Pair each training sequence with its integer position as the label.
train_data = []
for idx in range(len(train_dataset)):
    record = train_dataset[idx]
    train_data.append((idx, record["seq"]))

# Tokenise the whole training set in one call.
batch_labels, batch_strs, batch_tokens = batch_converter(train_data)

In [None]:
# Contact definition used to turn each record's 'dist' map into binary targets.
# NOTE(review): assumes 'dist' is in Angstrom (8 A is the usual contact cutoff) — confirm.
CONTACT_THRESHOLD = 8.0
# Residue pairs closer than this along the chain are excluded from the loss.
MIN_SEPARATION = 6


def collate_contacts(items):
    """Tokenise a batch of structural-split records and build contact targets.

    Args:
        items: list of dataset records; each is read for its 'seq' string and
            its 'dist' pairwise-distance map (expected to be len(seq) x len(seq)).

    Returns:
        tokens:  token tensor produced by ``batch_converter``.
        targets: (B, L, L) float tensor, 1.0 where dist < CONTACT_THRESHOLD.
        mask:    (B, L, L) bool tensor selecting pairs that belong to the
                 sequence (not padding), have a known (non-NaN) distance and
                 are at least MIN_SEPARATION positions apart.
        L is the longest sequence length in the batch.
    """
    _, _, tokens = batch_converter([(str(k), item["seq"]) for k, item in enumerate(items)])
    max_len = max(len(item["seq"]) for item in items)
    targets = torch.zeros(len(items), max_len, max_len)
    mask = torch.zeros(len(items), max_len, max_len, dtype=torch.bool)
    for b, item in enumerate(items):
        n = len(item["seq"])
        dist = torch.as_tensor(item["dist"], dtype=torch.float32)
        positions = torch.arange(n)
        far_apart = (positions[:, None] - positions[None, :]).abs() >= MIN_SEPARATION
        # NaN distances compare False here but are masked out below.
        targets[b, :n, :n] = (dist < CONTACT_THRESHOLD).float()
        mask[b, :n, :n] = ~torch.isnan(dist) & far_apart
    return tokens, targets, mask


# Batch the raw dataset records so every batch carries tokens AND targets.
# A DataLoader over the pre-tokenised tensor alone yields plain tensors with no
# targets, which is why indexing the batch by key failed.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_contacts)

learning_rate = 0.001

# Put the model (including the replaced regression layer) on the chosen device.
model.to(device)

# Optimise only the parameters left trainable. Maybe try "AdamW" with weight decay.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=learning_rate)

# fair-esm's contact head ends in a sigmoid, so its outputs are contact
# probabilities in (0, 1). Binary cross-entropy is the matching objective for
# binary contact targets.
loss_fn = nn.BCELoss()

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()

    # Loss accumulated over the batches of this epoch
    total_loss = 0.0
    num_batches = 0

    for tokens, targets, mask in train_dataloader:
        if not mask.any():
            # No scorable residue pairs; a masked mean would be NaN.
            continue
        tokens = tokens.to(device)
        targets = targets.to(device)
        mask = mask.to(device)

        # Clear gradients accumulated by the previous batch
        optimizer.zero_grad()
        # (B, L, L) contact probabilities. Attention maps are only computed
        # when return_contacts=True.
        contact_probs = model(tokens, return_contacts=True)["contacts"]
        # Score only real, known-distance, sufficiently separated pairs
        loss = loss_fn(contact_probs[mask], targets[mask])
        # Backpropagation; only the regression layer receives gradients
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average loss over the batches processed this epoch
    average_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Contact definition used to build validation targets from each record's 'dist' map.
# NOTE(review): assumes 'dist' is in Angstrom (8 A is the usual contact cutoff) — confirm.
CONTACT_THRESHOLD = 8.0
# Residue pairs closer than this along the chain are excluded from scoring.
MIN_SEPARATION = 6
# Probability at or above which a predicted pair counts as a contact.
DECISION_THRESHOLD = 0.5


def collate_contacts(items):
    """Tokenise a batch of structural-split records and build contact targets.

    Args:
        items: list of dataset records; each is read for its 'seq' string and
            its 'dist' pairwise-distance map (expected to be len(seq) x len(seq)).

    Returns:
        tokens:  token tensor produced by ``batch_converter``.
        targets: (B, L, L) float tensor, 1.0 where dist < CONTACT_THRESHOLD.
        mask:    (B, L, L) bool tensor selecting pairs that belong to the
                 sequence (not padding), have a known (non-NaN) distance and
                 are at least MIN_SEPARATION positions apart.
        L is the longest sequence length in the batch.
    """
    _, _, tokens = batch_converter([(str(k), item["seq"]) for k, item in enumerate(items)])
    max_len = max(len(item["seq"]) for item in items)
    targets = torch.zeros(len(items), max_len, max_len)
    mask = torch.zeros(len(items), max_len, max_len, dtype=torch.bool)
    for b, item in enumerate(items):
        n = len(item["seq"])
        dist = torch.as_tensor(item["dist"], dtype=torch.float32)
        positions = torch.arange(n)
        far_apart = (positions[:, None] - positions[None, :]).abs() >= MIN_SEPARATION
        targets[b, :n, :n] = (dist < CONTACT_THRESHOLD).float()
        mask[b, :n, :n] = ~torch.isnan(dist) & far_apart
    return tokens, targets, mask


# Evaluate on the validation split itself; its records supply both the
# sequences and the distance maps the targets are built from.
valid_dataloader = DataLoader(valid_dataset, batch_size=16, shuffle=False, collate_fn=collate_contacts)

# Keep the model on the chosen device and in evaluation mode
model.to(device)
model.eval()

# Per-batch 1-D tensors of scored pairs; concatenated below. Batches have
# different padded sizes, so the 2-D maps cannot be stacked directly.
predictions = []
true_contacts = []

# Turn off gradients for evaluation
with torch.no_grad():
    for tokens, targets, mask in valid_dataloader:
        contact_probs = model(tokens.to(device), return_contacts=True)["contacts"].cpu()
        predictions.append(contact_probs[mask])
        true_contacts.append(targets[mask])

predictions = torch.cat(predictions)
true_contacts = torch.cat(true_contacts)

# sklearn classification metrics need discrete labels, so threshold the probabilities.
y_true = true_contacts.numpy().astype(int)
y_pred = (predictions >= DECISION_THRESHOLD).numpy().astype(int)

# Evaluation metrics. Contacts are typically a small fraction of pairs, so
# accuracy is dominated by non-contacts; precision/recall/F1 are more informative.
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")