In [None]:
# standard libs
from datetime import datetime

# scientific libs
import numpy as np

# DL libs
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import esm
from torch.utils.tensorboard import SummaryWriter
from esm.data import ESMStructuralSplitDataset
from sklearn.metrics import roc_auc_score

# graph libs
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
"""Adapted from: https://github.com/facebookresearch/esm/blob/main/examples/esm_structural_dataset.ipynb"""

data_path = "./data/esm"

# Settings common to both splits: same split level, CV partition and data root
shared_split_args = {
    "split_level": 'superfamily',
    "cv_partition": '4',
    "root_path": data_path,
}

# Training and validation splits, differing only in the `split` argument
train_dataset = ESMStructuralSplitDataset(split='train', **shared_split_args)
valid_dataset = ESMStructuralSplitDataset(split='valid', **shared_split_args)

In [None]:
# Load the pretrained ESM-2 model together with its alphabet; the alphabet is
# used later to build the batch converter that tokenises sequences.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Alternative checkpoint, left disabled:
# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

In [None]:
# Freeze the pretrained weights: switch off gradient tracking on every
# parameter of the loaded model
for frozen_param in model.parameters():
    frozen_param.requires_grad_(False)

In [None]:
# Prepare the tokeniser and the contact definition
# batch_converter turns a list of (label, sequence) pairs into labels, strings and token tensors
batch_converter = alphabet.get_batch_converter()
# Distance cut-off: residue pairs whose distance is below this value count as contacts
contact_threshold = 15

In [None]:
# Run the code below to get examples of output:
# pick a random entry of the validation split and print its index and sequence
rand_example = np.random.randint(len(valid_dataset))
rand_target = valid_dataset[rand_example]
print(f"Data point {rand_example}, {rand_target['seq']}")

In [None]:
# Fixed example to compare: overrides the random pick with a hard-coded
# validation index, so the same entry is used every time this cell is run
rand_example = 229
rand_target = valid_dataset[rand_example]

In [None]:
# Tokenise the selected example as a batch of one
example_batch = [(rand_example, rand_target["seq"])]
batch_labels, batch_strs, batch_tokens = batch_converter(example_batch)

# Ground truth: pairs closer than the contact threshold
rand_target_c = rand_target['dist'] < contact_threshold

# Forward pass, requesting the contact predictions as part of the output
outputs = model(batch_tokens, return_contacts=True)

In [None]:
N = len(batch_strs[0])
fig, ax = plt.subplots(1, 2, figsize=(8, 3))

# Left panel: predictions thresholded at 0.5; right panel: thresholded true distances
predicted_map = outputs['contacts'][0].detach().numpy() > 0.5
for panel, contact_map, title in zip(ax, (predicted_map, rand_target_c), ("Predicted", "Real")):
    im = panel.imshow(contact_map)
    fig.colorbar(im)
    panel.set_title(title)

plt.show()

In [None]:
# Display the model (as the cell's last expression, its repr is rendered as the cell output)
model

In [None]:
# Swap the contact head's final linear layer for a freshly initialised one with
# the same number of input features and a single output
pretrained_head = model.contact_head.regression
model.contact_head.regression = nn.Linear(
    in_features=pretrained_head.in_features,
    out_features=1,
)

# The new layer is the only part of the model that should be trained:
# make sure its parameters track gradients
for head_param in model.contact_head.regression.parameters():
    head_param.requires_grad = True

In [None]:
# Remove all the NaNs: build cleaned copies of both splits in which residues
# without coordinates are dropped and distances are turned into boolean contacts.

def _mask_entry(data, threshold):
    """Return a copy of one dataset entry restricted to residues with coordinates.

    Args:
        data: dataset entry with the keys "seq", "ssp", "coords" and "dist".
        threshold: distance cut-off; pairs below it are marked True in "dist".

    Returns:
        dict with the same four keys: "seq"/"ssp" keep only the characters of
        retained residues, "coords" keeps only the retained rows, and "dist"
        is the boolean matrix `distance < threshold` over retained residues.
    """
    # A residue is kept when the sum of its coordinates along axis 1 is not NaN
    keep = ~np.isnan(data["coords"].sum(axis=1))
    # Select the same residues along both axes of the distance matrix
    kept_dist = data["dist"][keep][:, keep]
    return {
        "seq": "".join(c for c, k in zip(data["seq"], keep) if k),
        "ssp": "".join(c for c, k in zip(data["ssp"], keep) if k),
        "coords": data["coords"][keep],
        # Boolean values to only recognise distances up to the contact threshold
        "dist": kept_dist < threshold,
    }


# Same masking applied to both splits through the single helper above
masked_train = [_mask_entry(data, contact_threshold) for data in train_dataset]
masked_valid = [_mask_entry(data, contact_threshold) for data in valid_dataset]

In [None]:
# Pick the compute device: CUDA when a usable GPU is present, otherwise Apple
# MPS, otherwise the CPU.
# torch.cuda.is_available() is used instead of torch.backends.cuda.is_built():
# the latter only reports that PyTorch was compiled with CUDA support, so it
# would select "cuda" on a machine that has no usable GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device {device}")

In [None]:
# Move the model to the selected device; input tensors are moved to the same
# device at the point where they are fed to the model
model = model.to(device)

In [None]:
# Run name: fixed prefix plus a start-time stamp, so each run gets its own log directory
run_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
run_name = "finetune_esm2_t6_8M_UR50D_4-" + run_timestamp

# TensorBoard writer logging under ./runs/<run_name>
tb_writer = SummaryWriter(log_dir=f"./runs/{run_name}")

def memory_usage():
    """Return the accelerator memory currently allocated by PyTorch, in GB.

    The backend is probed in the same priority order as the device selection
    (CUDA first, then MPS). Previously the MPS allocator was queried
    unconditionally, which is wrong for CUDA or CPU runs.

    Returns:
        float: allocated bytes / 1e9 for CUDA or MPS; 0.0 when neither
        accelerator is available (CPU run).
    """
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1e9
    if torch.backends.mps.is_available():
        return torch.mps.current_allocated_memory() / 1e9
    # CPU-only: there is no accelerator allocator to query
    return 0.0

In [None]:
# Training loop
learning_rate = 0.003
# NOTE: despite its name, `batch_size` is the number of sequences sampled per
# epoch; each optimisation step uses `minibatch_size` sequences.
batch_size = 1024
minibatch_size = 8
num_epochs = 10


# Optimise only the parameters that still require gradients
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# Training uses the per-element loss (reduced manually with the padding mask);
# validation uses the plain mean because its single-sequence batches have no padding
loss_fn_none = nn.BCELoss(reduction="none").to(device)
loss_fn_mean = nn.BCELoss(reduction="mean").to(device)

print(f"Epoch\tTrain loss\tTest loss")
for epoch in range(num_epochs):
    # Initialise losses
    total_loss = 0
    valid_loss = 0
    num_batches = 0
    # Number of validation entries evaluated per epoch (the first ones of masked_valid)
    validation_size = 1

    # Set model to training mode
    model.train()

    # Randomly select this epoch's sequences without replacement; the sample
    # size is capped at the dataset size so np.random.choice cannot raise
    epoch_indices = np.random.choice(
        len(masked_train),
        size=min(batch_size, len(masked_train)),
        replace=False
    )

    # Training on the selected sequences in mini-batches
    for b in tqdm(
        DataLoader(
            epoch_indices,
            batch_size=minibatch_size,
            shuffle=True
        ), ncols=40):
        # Tokenise input sequences
        batch_labels, batch_strs, batch_tokens = batch_converter([(i, masked_train[i]["seq"]) for i in b])

        # Clear gradients accumulated by the previous batch
        optimizer.zero_grad()

        # Output predictions for batch
        outputs = model(batch_tokens.to(device), return_contacts=True)

        # Initialise objects to 0 to match the format of contact output tensor
        targets = torch.zeros_like(outputs["contacts"])
        mask = torch.zeros_like(outputs["contacts"])
        src_mask = torch.zeros_like(outputs["contacts"])

        # Pull masked and boolean (dist threshold) values from training sequences
        for i_, ti in enumerate(b):
            cm = masked_train[ti]["dist"]
            N = cm.shape[0]
            targets[i_, :N, :N] = torch.tensor(cm)
            # 1 on the real (non-padded) part of this sequence's contact map
            mask[i_, :N, :N] = 1
            # Only pairs at least 12 residues apart are kept as positive targets
            row_up, col_up = torch.triu_indices(N, N, offset=12)
            row_low, col_low = torch.tril_indices(N, N, offset=-12)
            src_mask[i_, row_up, col_up] = 1
            src_mask[i_, row_low, col_low] = 1

        # Zero the short-range targets once for the whole batch (this was
        # previously recomputed for every sequence, with the same result)
        # NOTE(review): short-range pairs still count in the loss through `mask`,
        # so they are trained as non-contacts, while the validation targets
        # below are not zeroed this way - confirm this asymmetry is intended.
        targets = targets * src_mask

        del src_mask

        # Calculates bce loss between predictions and true values
        loss = loss_fn_none(outputs["contacts"], targets.to(device))

        del targets, outputs

        # Manual mean over the non-padded entries only: dividing by mask.sum()
        # keeps the loss independent of how much padding the batch contains
        loss = (loss * mask).sum() / mask.sum().clamp(min=1)

        del mask

        # Pool loss values from each batch
        total_loss += loss.item()
        num_batches += 1

        # Backpropagation
        loss.backward()

        # Updates last layer parameters to reduce loss
        optimizer.step()

    # Set model to evaluation mode
    model.eval()

    # No gradients are needed for validation
    with torch.no_grad():
        for b in DataLoader(range(validation_size), batch_size=1, shuffle=True):
            batch_labels, batch_strs, batch_tokens = batch_converter([(i, masked_valid[i]["seq"]) for i in b])

            # contacts
            outputs = model(batch_tokens.to(device), return_contacts=True)

            del batch_tokens

            # Calculates loss between predictions and true values
            targets = torch.tensor(np.array([masked_valid[i]["dist"] for i in b]), dtype=torch.float32).to(device)
            loss = loss_fn_mean(outputs["contacts"], targets)

            del targets

            # Pool loss values from each batch
            valid_loss += loss.item()

    # Average the per-batch mean losses over the number of batches
    # (total_loss accumulates one value per batch, not one per sequence)
    average_loss = total_loss / max(num_batches, 1)
    average_loss_test = valid_loss / validation_size

    # Write to Tensorboard logs
    tb_writer.add_scalar("Loss/train", average_loss, epoch)
    tb_writer.add_scalar("Loss/test", average_loss_test, epoch)
    tb_writer.add_scalar("Memory usage (GB)", memory_usage(), epoch)

    print(f"{epoch+1}/{num_epochs}\t{average_loss:.4f}\t\t{average_loss_test:.4f}")

# Make sure all logged scalars are written to disk
tb_writer.flush()

In [None]:
# Re-run the currently selected example through the fine-tuned model
example_batch = [(rand_example, rand_target["seq"])]
batch_labels, batch_strs, batch_tokens = batch_converter(example_batch)
rand_target_c = rand_target['dist'] < contact_threshold

# Tokens are moved to the model's device before the forward pass
outputs = model(batch_tokens.to(device), return_contacts=True)

N = len(batch_strs[0])
fig, ax = plt.subplots(1, 2, figsize=(8, 3))

# Left panel: predictions thresholded at 0.5 (brought back to the CPU for
# plotting); right panel: thresholded true distances
predicted_map = outputs['contacts'][0].to('cpu').detach().numpy() > 0.5
for panel, contact_map, title in zip(ax, (predicted_map, rand_target_c), ("Predicted", "Real")):
    im = panel.imshow(contact_map)
    fig.colorbar(im)
    panel.set_title(title)

plt.show()

In [None]:
# Random example testing: pick a random validation entry and compare the
# fine-tuned model's predicted contacts with the true ones
rand_example = np.random.randint(len(valid_dataset))
rand_target = valid_dataset[rand_example]

print(f"Data point {rand_example}, {rand_target['seq']}")

batch_labels, batch_strs, batch_tokens = batch_converter([(rand_example, rand_target["seq"])])
rand_target_c = rand_target['dist'] < contact_threshold

# The model lives on `device` by now, so the tokens must be moved there too
# (feeding CPU tokens to a cuda/mps model fails); no gradients are needed for
# inference
with torch.no_grad():
    outputs = model(batch_tokens.to(device), return_contacts=True)

N = len(batch_strs[0])
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
# Bring the prediction back to the CPU before converting it to numpy
im = ax[0].imshow(outputs['contacts'][0].detach().to('cpu').numpy() > 0.5)
# Attach each colorbar to its own panel explicitly
fig.colorbar(im, ax=ax[0])
ax[0].set_title("Predicted")
im = ax[1].imshow(rand_target_c)
fig.colorbar(im, ax=ax[1])
ax[1].set_title("Real")
plt.show()

In [None]:
# torch.save(model, 'trained_model_1024_BCE_6ep.pth')