In [None]:
# Validate the fine tuned model

In [22]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

import os
from pathlib import Path

import numpy as np

from sklearn.metrics import roc_auc_score, precision_score

import esm
from esm.data import ESMStructuralSplitDataset

import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = torch.load('trained_model_1024_BCE_6ep.pth')
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [3]:
valid_dataset = ESMStructuralSplitDataset(
    split_level='superfamily', 
    cv_partition='4', 
    split='valid', 
    root_path = os.path.expanduser('~/.cache/torch/data/esm'),
)

In [6]:
contact_threshold = 15

masked_valid = []

for t in valid_dataset:
    mask = ~np.isnan(t["coords"].sum(axis=1))
    mdist = t["dist"][mask][:, mask]
    masked_entry = {
        "seq": ''.join(c for c, cm in zip(t["seq"], mask) if cm),
        "ssp": ''.join(c for c, cm in zip(t["ssp"], mask) if cm),
        "coords": t["coords"][mask],
        "dist": mdist < contact_threshold
    }
    masked_valid.append(masked_entry)

In [7]:
device = torch.device("cpu")
print(f"Using device {device}")

Using device cpu


In [13]:
len(masked_valid)

2985

In [31]:
batch_converter = alphabet.get_batch_converter()

num_epochs = 1

for epoch in range(num_epochs):
        
    valid_loss = 0
    vsize = 8
    
    all_targets = []
    all_predictions = []
    
    for b in tqdm(
        DataLoader(np.random.choice(len(masked_valid), size=vsize, replace=False), batch_size=8, shuffle=True),
                   ncols=40):
    
        batch_labels, batch_strs, batch_tokens = batch_converter([(i, masked_valid[i]["seq"]) for i in b])
        inputs = batch_tokens
        
        outputs = model(inputs.to(device), return_contacts=True)
        
        targets = torch.zeros_like(outputs["contacts"])
        
        for i_, ti in enumerate(b):
            cm = masked_valid[ti]["dist"]
            N = cm.shape[0]
            targets[i_, :N, :N] = torch.tensor(cm)
        
        all_targets.append(targets)
        all_predictions.append(outputs['contacts'][0].detach().numpy() > 0.5)
        
        # Calculates loss between predictions and true values
        loss = nn.BCELoss()(outputs["contacts"], targets)
        
        # Pool loss values from each batch
        valid_loss += loss.item()
        del batch_tokens
    
    # Print loss per epoch 
    average_loss_test = valid_loss / vsize
    auc_score = roc_auc_score(all_targets, all_predictions)
    
    print(f"{epoch+1}/{num_epochs}\t\t{average_loss_test:.4f}")
    print(f'AUC: {auc_score:.4f}')

100%|█████| 1/1 [00:20<00:00, 20.41s/it]
  return numpy.asarray(x, dtype=dtype)
  return numpy.asarray(x, dtype=dtype)
  array = numpy.asarray(array, order=order, dtype=dtype)


ValueError: only one element tensors can be converted to Python scalars

In [None]:
batch_labels, batch_strs, batch_tokens = batch_converter([(rand_example, rand_target["seq"])])
rand_target_c = rand_target['dist'] < contact_threshold

outputs = model(batch_tokens, return_contacts=True)

N = len(batch_strs[0])
fig, ax = plt.subplots(1, 2, figsize=(8, 3))
im = ax[0].imshow(outputs['contacts'][0].detach().numpy() > 0.5)
fig.colorbar(im)
ax[0].set_title("Predicted")
im = ax[1].imshow(rand_target_c)
fig.colorbar(im)
ax[1].set_title("Real")
plt.show()

In [None]:
# Compare benchmark with original model

In [None]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model.eval()