# Clean up fine-tuning steps

In [1]:
import torch
import gc
import esm
from esm import pretrained

!pip install fair-esm

# Clear cache
gc.collect()
torch.cuda.empty_cache()

# Set device to use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained model
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()

# Check layers in architecture
for name, param in model.named_parameters():
    print(f"Parameter name: {name}, Size: {param.size()}")
    
model.to(device)

Parameter name: embed_tokens.weight, Size: torch.Size([33, 320])
Parameter name: layers.0.self_attn.k_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.k_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.v_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.v_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.q_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.q_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn.out_proj.weight, Size: torch.Size([320, 320])
Parameter name: layers.0.self_attn.out_proj.bias, Size: torch.Size([320])
Parameter name: layers.0.self_attn_layer_norm.weight, Size: torch.Size([320])
Parameter name: layers.0.self_attn_layer_norm.bias, Size: torch.Size([320])
Parameter name: layers.0.fc1.weight, Size: torch.Size([1280, 320])
Parameter name: layers.0.fc1.bias, Size: torch.Size([1280])
Parameter name: layers.0.fc2.weight, Size: tor

ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=120, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwis

In [2]:
import numpy as np
import pandas as pd
import os
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
from esm.data import ESMStructuralSplitDataset
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


# Download structural holdout datasets (already downloaded)
# for split_level in ['family', 'superfamily', 'fold']:
#     for cv_partition in ['0', '1', '2', '3', '4']:
#         esm_structural_train = ESMStructuralSplitDataset(
#             split_level=split_level, 
#             cv_partition=cv_partition, 
#             split='train', 
#             root_path = os.path.expanduser('~/.cache/torch/data/esm'),
#             download=True
#         )
#         esm_structural_valid = ESMStructuralSplitDataset(
#             split_level=split_level, 
#             cv_partition=cv_partition, 
#             split='valid', 
#             root_path = os.path.expanduser('~/.cache/torch/data/esm'),
#             download=True
#         )

# Set train and validation datasets from the ESM library
esm_structural_train = ESMStructuralSplitDataset(
    split_level='superfamily', 
    cv_partition='0', 
    split='train', 
    root_path = os.path.expanduser('~/.cache/torch/data/esm'),
)

esm_structural_valid = ESMStructuralSplitDataset(
    split_level='superfamily', 
    cv_partition='0', 
    split='valid', 
    root_path = os.path.expanduser('~/.cache/torch/data/esm'),
)

# Check how many entries in dictionary
print(len(esm_structural_train))
print(len(esm_structural_valid))

# Training dataset downloaded from ESMStructuralSplitDataset already split into training and validation sets
train_dataset = esm_structural_train
valid_dataset = esm_structural_valid

12031
3266


In [3]:
# Use ESM-2's batch converter for tokneisation/padding
batch_converter = alphabet.get_batch_converter()
# Pull only the sequences in entire dataset for conversion
train_data = [(i, train_dataset[i]["seq"]) for i in range(len(train_dataset))]
# Tokenise sequences
batch_labels, batch_strs, batch_tokens = batch_converter(train_data)

# Make batch_labels a tensor
#batch_labels = torch.tensor(batch_labels)
#batch_tokens_tensor = torch.tensor(batch_tokens)
batch_tokens_tensor = torch.tensor(batch_tokens).clone().detach()

train_dataset = TensorDataset(batch_tokens_tensor)
#train_dataset = TensorDataset(batch_labels, batch_tokens)

# Check the tensor dimensions
print(batch_tokens)
print(batch_tokens.shape)

tensor([[ 0, 20, 21,  ...,  1,  1,  1],
        [ 0,  8, 20,  ...,  1,  1,  1],
        [ 0, 20, 16,  ...,  1,  1,  1],
        ...,
        [ 0, 21, 19,  ...,  1,  1,  1],
        [ 0, 20,  6,  ...,  1,  1,  1],
        [ 0, 20, 10,  ...,  1,  1,  1]])
torch.Size([12031, 884])


  batch_tokens_tensor = torch.tensor(batch_tokens)


In [4]:
print(batch_tokens[0].shape)
print(batch_tokens[1].shape)
train_dataset

torch.Size([884])
torch.Size([884])


<torch.utils.data.dataset.TensorDataset at 0x203360fe910>

In [5]:
# Freeze all parameters of the pretrained model
for param in model.parameters():
    param.requires_grad = False

# Last layer output fatures decided from ESM-2's layers?
hidden_size = 320

# Modify last layer for the regression task
model.final_layer = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.ReLU(),
)

# Enable gradient computation for the parameters in the final_layer
for param in model.final_layer.parameters():
    param.requires_grad = True

In [6]:
# Clear cache
gc.collect()
torch.cuda.empty_cache()

#train_dataloader = DataLoader([batch_tokens], batch_size=8, shuffle=True)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.SmoothL1Loss()

num_epochs = 10

for epoch in range(num_epochs):

    model.train()
    total_loss = 0
    running_loss = 0.0

    for batch in train_dataloader:
        # Is there a padding issue for the shape problem, do the sequences need to be equal?
        # Does (embed_tokens): Embedding(33, 320, padding_idx=1) from the model handle this issue?
        inputs = batch #.to(device)
        # Reshape tensors to 2D
        #inputs = inputs.unsqueeze(1)
        optimizer.zero_grad()
        outputs = model(inputs)
        # AssertionError: ---> 81 assert tokens.ndim == 2
        loss = loss_fn(outputs, inputs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_loss += loss.item()
        torch.cuda.empty_cache()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
    print(f"Epoch {epoch+1} loss: {running_loss / len(train_dataloader):.4f}")

AttributeError: 'list' object has no attribute 'ndim'

In [7]:
print(batch.shape)

AttributeError: 'list' object has no attribute 'shape'

In [None]:
batch_converter = alphabet.get_batch_converter()

train_data = [(i, train_dataset[i]["seq"]) for i in range(len(train_dataset))]

batch_labels, batch_strs, batch_tokens = batch_converter(train_data)


batch_labels = torch.tensor(batch_labels)


train_dataset = TensorDataset(batch_labels, batch_tokens)

print(batch_tokens)
print(batch_tokens.shape)

gc.collect()
torch.cuda.empty_cache()

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.SmoothL1Loss()

num_epochs = 10

for epoch in range(num_epochs):

    model.train()
    total_loss = 0
    running_loss = 0.0

    for batch in train_dataloader:
        inputs = batch.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, inputs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total_loss += loss.item()
        torch.cuda.empty_cache()

    average_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
    print(f"Epoch {epoch+1} loss: {running_loss / len(train_dataloader):.4f}")

In [None]:
# ESM-2 tokeniser for data
valid_data = [(i, valid_dataset[i]["seq"]) for i in range(len(valid_dataset))]
batch_labels, batch_strs, batch_tokens = batch_converter(valid_data)
valid_dataloader = DataLoader(batch_tokens, batch_size=16, shuffle=False)

# Set evaluation mode for dropout and batch normalisation
model.eval()

# Initialise lists   
predictions = []
true_contacts = []

# Turn off gradients for evaluation
with torch.no_grad():
    for batch in valid_dataloader:
        inputs = batch.to(device)  # Move batch to the specified device
        outputs = model(inputs)
        predictions.extend(outputs.tolist())

# Convert predictions and contact lists into tensors
predictions = torch.tensor(predictions)
true_contacts = torch.tensor(targets)

# Evaluation metrics
accuracy = accuracy_score(true_contacts, predictions)
precision = precision_score(true_contacts, predictions)
recall = recall_score(true_contacts, predictions)
f1 = f1_score(true_contacts, predictions)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")