# Clean up fine-tuning steps

In [1]:
import torch
import gc
import esm
from esm import pretrained

!pip install fair-esm

# Clear cache
gc.collect()
torch.cuda.empty_cache()

# Load pretrained model
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()

# Check layers in architecture
for name, param in model.named_parameters():
    print(f"Parameter name: {name}, Size: {param.size()}")

# Set device to use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Parameter name: embed_tokens.weight, Size: torch.Size([33, 320])
Parameter name: layers.0.self_attn.k_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.k_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.v_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.v_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.q_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.q_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.out_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.out_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn_layer_norm.weight, Size: torch.Size([320])
Parameter name: layers.0.self_attn_layer_norm.bias, Size: torch.Size([320])
Parameter name: layers.0.fc1.weight, Size: torch.Size([1280, 320])
Parameter name: layers.0.fc1.bias, Size: torch.Size([1280])
Parameter name: layers.0.fc2.weight, Size: tor

ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=120, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwis

In [2]:
import numpy as np
import pandas as pd
import os
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
from esm.data import ESMStructuralSplitDataset
import sklearn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


# Download structural holdout datasets (already downloaded)
# for split_level in ['family', 'superfamily', 'fold']:
#     for cv_partition in ['0', '1', '2', '3', '4']:
#         esm_structural_train = ESMStructuralSplitDataset(
#             split_level=split_level, 
#             cv_partition=cv_partition, 
#             split='train', 
#             root_path = os.path.expanduser('~/.cache/torch/data/esm'),
#             download=True
#         )
#         esm_structural_valid = ESMStructuralSplitDataset(
#             split_level=split_level, 
#             cv_partition=cv_partition, 
#             split='valid', 
#             root_path = os.path.expanduser('~/.cache/torch/data/esm'),
#             download=True
#         )

# Set train and validation datasets from the ESM library
esm_structural_train = ESMStructuralSplitDataset(
    split_level='superfamily', 
    cv_partition='0', 
    split='train', 
    root_path = os.path.expanduser('~/.cache/torch/data/esm'),
)

esm_structural_valid = ESMStructuralSplitDataset(
    split_level='superfamily', 
    cv_partition='0', 
    split='valid', 
    root_path = os.path.expanduser('~/.cache/torch/data/esm'),
)

# Check how many entries in dictionary
print(len(esm_structural_train))
print(len(esm_structural_valid))

# Training dataset downloaded from ESMStructuralSplitDataset already split into training and validation sets
train_dataset = esm_structural_train
valid_dataset = esm_structural_valid

12031
3266


In [3]:
# Use ESM-2's batch converter for tokneisation/padding
batch_converter = alphabet.get_batch_converter()
# Pull only the sequences in entire dataset for conversion
train_data = [(i, train_dataset[i]["seq"]) for i in range(len(train_dataset))]
# Tokenise sequences
batch_labels, batch_strs, batch_tokens = batch_converter(train_data)

In [4]:
# Put batch token tensor in a tensor dataset object
training_dataset = TensorDataset(batch_tokens)

In [5]:
# Check the batch tensor dimensions
print(batch_tokens)
print(batch_tokens.shape)
print(type(batch_tokens))
print(len(batch_tokens))

tensor([[ 0, 20, 21,  ...,  1,  1,  1],
        [ 0,  8, 20,  ...,  1,  1,  1],
        [ 0, 20, 16,  ...,  1,  1,  1],
        ...,
        [ 0, 21, 19,  ...,  1,  1,  1],
        [ 0, 20,  6,  ...,  1,  1,  1],
        [ 0, 20, 10,  ...,  1,  1,  1]])
torch.Size([12031, 884])
<class 'torch.Tensor'>
12031


In [6]:
print(training_dataset)
print(type(training_dataset))
print(len(training_dataset))

<torch.utils.data.dataset.TensorDataset object at 0x0000020C8D173A60>
<class 'torch.utils.data.dataset.TensorDataset'>
12031


In [7]:
# Freeze all parameters of the pretrained model
for param in model.parameters():
    param.requires_grad = False

# Last layer output fatures decided from ESM-2's layers, emb_layer_norm_after?
hidden_size = 320

# Modify last layer for our regression task
model.final_layer = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.ReLU(),
)

# Enable gradient computation for the parameters in the final_layer
for param in model.final_layer.parameters():
    param.requires_grad = True

In [None]:
# # Removed dataloader due to dictionary and list issues in training loop, not enough memory for CUDA to process full dataset of 12031
# # Clear cache
# gc.collect()
# torch.cuda.empty_cache()

# learning_rate = 0.001
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# loss_fn = nn.SmoothL1Loss()

# num_epochs = 10

# for epoch in range(num_epochs):
#     gc.collect()
#     torch.cuda.empty_cache()
#     model.train()
#     total_loss = 0
#     running_loss = 0.0
#     inputs = batch_tokens.to(device)
#     optimizer.zero_grad()
#     outputs = model(inputs)
#     loss = loss_fn(outputs, inputs)
#     loss.backward()
#     optimizer.step()
#     running_loss += loss.item()
#     total_loss += loss.item()
#     gc.collect()
#     torch.cuda.empty_cache()

#     average_loss = total_loss / len(batch_tokens)
#     print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
#     print(f"Epoch {epoch+1} loss: {running_loss:.4f}")

In [8]:
# Clear cache
gc.collect()
torch.cuda.empty_cache()


learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#loss_fn = nn.MSELoss()
loss_fn = nn.SmoothL1Loss()

num_epochs = 10

# Set a low batch size for memory efficency
batch_size = 8

train_dataloader = DataLoader(batch_tokens, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    gc.collect()
    torch.cuda.empty_cache()
    model.train()
    total_loss = 0
    running_loss = 0.0

    for batch_tokens in train_dataloader:

        #batch_tokens = torch.stack(batch_token_list).to(device) list of tensors issue
        
        # Reshape tensor to 2D shape [batch_size, sequence_length] to pass assert tokens.ndim == 2
        inputs = batch_tokens.reshape(batch_size, -1).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        output_tensor = outputs["logits"] # what tensor to access for loss function, maybe "logits"?
        
        # What is the origin of the loss input size (torch.Size([8, 884, 33])) is this from the model embedding or lm_head?
        # Parameter name: embed_tokens.weight, Size: torch.Size([33, 320])
        # Parameter name: lm_head.bias, Size: torch.Size([33])
        
        # RuntimeError: The size of tensor a (33) must match the size of tensor b (884) at non-singleton dimension 2
        # Using a target size (torch.Size([8, 884]))
        
        loss = loss_fn(output_tensor, inputs) # AttributeError: 'dict' object has no attribute 'size' - I accessed logits instead, does this need to be processed i.e. softmax?
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_loss += loss.item()
        gc.collect()
        torch.cuda.empty_cache()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
    print(f"Epoch {epoch+1} loss: {running_loss:.4f}")

  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)


RuntimeError: The size of tensor a (33) must match the size of tensor b (884) at non-singleton dimension 2

In [9]:
print(type(batch_tokens))

<class 'torch.Tensor'>


In [10]:
print(inputs)

tensor([[ 0, 20,  6,  ...,  1,  1,  1],
        [ 0,  6,  7,  ...,  1,  1,  1],
        [ 0, 17,  8,  ...,  1,  1,  1],
        ...,
        [ 0,  6,  8,  ...,  1,  1,  1],
        [ 0, 20,  8,  ...,  1,  1,  1],
        [ 0, 20, 15,  ...,  1,  1,  1]], device='cuda:0')


In [11]:
print(type(inputs))

<class 'torch.Tensor'>


In [12]:
print(inputs.shape)

torch.Size([8, 884])


In [13]:
print(type(outputs))

<class 'dict'>


In [14]:
print(outputs["logits"].shape)

torch.Size([8, 884, 33])


In [15]:
print(outputs)

{'logits': tensor([[[ 14.2351,  -7.3852,  -6.2489,  ..., -15.6001, -15.7696,  -7.3804],
         [ -7.4059, -15.0762,  -7.7357,  ..., -15.8675, -16.0993, -15.0751],
         [-11.7911, -19.6771, -10.7796,  ..., -16.2037, -16.2054, -19.6660],
         ...,
         [-10.3239, -17.0342, -10.4923,  ..., -16.2025, -16.2505, -17.0353],
         [-10.8626, -18.0419, -11.0570,  ..., -16.2881, -16.3453, -18.0377],
         [-10.8497, -18.7546, -11.6085,  ..., -16.3317, -16.3946, -18.7504]],

        [[ 15.3661,  -8.9175,  -6.0069,  ..., -15.3930, -15.5476,  -8.9246],
         [ -7.3944, -15.1793,  -6.9476,  ..., -15.6827, -15.9010, -15.1780],
         [-11.0365, -21.0386, -11.8327,  ..., -16.4465, -16.4803, -21.0372],
         ...,
         [-10.7051, -22.7104, -10.0268,  ..., -16.1865, -16.0903, -22.6981],
         [-12.1272, -23.1866, -11.0768,  ..., -16.1650, -16.1004, -23.1894],
         [-11.4274, -20.6906, -11.0352,  ..., -16.3272, -16.3195, -20.6995]],

        [[ 15.1409,  -7.3800,  -6

In [None]:
# ESM-2 tokeniser for data
batch_size = 8

valid_data = [(i, valid_dataset[i]["seq"]) for i in range(len(valid_dataset))]
v_batch_labels, v_batch_strs, v_batch_tokens = batch_converter(valid_data)
valid_dataloader = DataLoader(v_batch_tokens, batch_size=batch_size, shuffle=True)

# Set evaluation mode for dropout and batch normalisation
model.eval()

# Initialise lists   
predictions = []
true_contacts = []

# Turn off gradients for evaluation
with torch.no_grad():
    for batch in valid_dataloader:
        inputs = v_batch_tokens.to(device)  # Move batch to the specified device
        outputs = model(inputs)
        predictions.extend(outputs.tolist())

# Convert predictions and contact lists into tensors
predictions = torch.tensor(predictions)
true_contacts = torch.tensor(targets)

# Evaluation metrics
accuracy = accuracy_score(true_contacts, predictions)
precision = precision_score(true_contacts, predictions)
recall = recall_score(true_contacts, predictions)
f1 = f1_score(true_contacts, predictions)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")