## Measuring model ddG accuracy on protein variants
The goal of this notebook is to experiment with model ddG prediction capabilities, and exploring the possibility of predicting ddG of variants that were not supported previously.

In [20]:
#imports
import sys
sys.path.append('..')    # add parent directory to path
import numpy as np
import matplotlib.pyplot as plt
from model.hydro_net import PEM
from model.model_cfg import CFG
from Utils.train_utils import *
from Utils.pdb_parser import get_pdb_data
import torch
import pandas as pd
import os
import glob
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


In [21]:
#load model
# Import the model
model = PEM(layers=CFG.num_layers,gaussian_coef=CFG.gaussian_coef).to(CFG.device)
# Upload model weights
CFG.model_path = '../data/Trained_models/'
epoch = 25
model_dict = torch.load(CFG.model_path+f"{epoch}_final_model.pt",map_location=CFG.device,weights_only=False)
model.load_state_dict(model_dict['model_state_dict'])

<All keys matched successfully>

In [22]:
#Configuration
remove_indels = True
one_mute = True
# Constants
NUMBER_OF_VARIANTS = 128
TENSOR_ROOT = r"mutation_data\tensors" 
MUTATIONS_ROOT = r"mutation_data\mutations"
COORDS = 'coords_tensor.pt'
DELTA_G = 'deltaG.pt'
MASKS = 'mask_tensor.pt'
ONE_HOT = 'one_hot_encodings.pt'
PROTT5_EMBEDDINGS = 'prott5_embeddings'
VAL_RATIO = 0.2
RANDOM_SEED = 42
NANO_TO_ANGSTROM = 0.1
DEBUG = False
TM_PATH = './data/Processed_K50_dG_datasets/TM_proteins.csv'

### Experimenting with a single protein

In [23]:
# Function from AllProteinValidationDataset
def load_embedding_tensor(embeddings_dir):
    embeddings = []
    all_embedding_files = sorted(glob.glob(os.path.join(embeddings_dir, 'prott5_embedding_*.pt')),
                                 key=lambda x: int(os.path.splitext(x)[0].split('_')[-1]))

    for filename in all_embedding_files:
        if filename.endswith('.pt'):
            embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU
            embeddings.append(embedding_tensor)

    return torch.vstack(embeddings)

In [24]:
def find_wt_index(mutations_data):
    """
    Find the index of the wildtype sequence in the list of mutations
    Figure out later which wt is the true wt
    """
    
    return 0

def calculate_dG(protein_graph, protein_unfolded_graph):
    model.eval()
    with torch.no_grad():
        Gf = model(protein_graph.unsqueeze(0))
        Gf = Gf.cpu().numpy()
        Gf = Gf[0]
        Gu = model(protein_unfolded_graph.unsqueeze(0))
        Gu = Gu.cpu().numpy()
        Gu = Gu[0]
        pred_deltaG = Gu - Gf
        return pred_deltaG

In [25]:
#Data preparation
protein_path = os.path.join(TENSOR_ROOT, '1A0N')
mutations_path = os.path.join(MUTATIONS_ROOT, '1A0N.csv')

mutations = pd.read_csv(mutations_path)

if remove_indels:
    mutations = mutations[~mutations['mut_type'].str.contains('ins|del')].reset_index(drop=True)
    
# Load and preprocess the data for each protein
coords_tensor = torch.load(os.path.join(protein_path, COORDS))
delta_g_tensor = torch.load(os.path.join(protein_path, DELTA_G))
mask_tensor = torch.load(os.path.join(protein_path, MASKS))
one_hot_tensor = torch.load(os.path.join(protein_path, ONE_HOT))
embedding_tensor = load_embedding_tensor(os.path.join(protein_path, PROTT5_EMBEDDINGS))

# remove the mutations with more than one mutation
if one_mute:
    one_mut_index = mutations[~mutations['mut_type'].str.contains(':')]
    mutations = mutations.loc[one_mut_index.index]
    delta_g_tensor = delta_g_tensor[one_mut_index.index]
    one_hot_tensor = one_hot_tensor[one_mut_index.index]
    embedding_tensor = embedding_tensor[one_mut_index.index]
    
mutations_data = {
    'name': protein_path,
    'mutations': mutations['mut_type'].to_list(),
    'prott5': embedding_tensor,
    'coords': coords_tensor,
    'one_hot': one_hot_tensor,
    'delta_g': delta_g_tensor,
    'masks': mask_tensor
}


  coords_tensor = torch.load(os.path.join(protein_path, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_path, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_path, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_path, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


In [26]:
#print mutation data information
print(f"Protein: {mutations_data['name']}")
print(f"Number of mutations: {len(mutations_data['mutations'])}")
print(f"Protein T5 embeddings shape: {mutations_data['prott5'].shape}")
print(f"Protein coords shape: {mutations_data['coords'].shape}")
print(f"Protein one hot encodings shape: {mutations_data['one_hot'].shape}")
print(f"Protein delta G shape: {mutations_data['delta_g'].shape}")
print(f"Protein masks shape: {mutations_data['masks'].shape}")


Protein: mutation_data\tensors\1A0N
Number of mutations: 2210
Protein T5 embeddings shape: torch.Size([2210, 58, 1024])
Protein coords shape: torch.Size([58, 4, 3])
Protein one hot encodings shape: torch.Size([2210, 58, 21])
Protein delta G shape: torch.Size([2210])
Protein masks shape: torch.Size([58])


In [27]:
#get wt folded graph
#assuming wt is the first sequence
wt_index = find_wt_index(mutations_data)

seq_one_hot = mutations_data['one_hot'][wt_index]
proT5_emb = mutations_data['prott5'][wt_index]
mask = mutations_data['masks'][wt_index]
coords_tensor = mutations_data['coords']
protein_graph = get_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
print(protein_graph.shape)
#get wt unfolded graph
protein_unfolded_graph = get_unfolded_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
#calculate dG
pred_wt_deltaG = calculate_dG(protein_graph, protein_unfolded_graph)
print(pred_wt_deltaG)


torch.Size([58, 1093])
-1.9128189


In [28]:
pred_mut_dG = []
for i in range(1,NUMBER_OF_VARIANTS):
    if i == wt_index:
        continue
    seq_one_hot = mutations_data['one_hot'][i]
    proT5_emb = mutations_data['prott5'][i]
    mask = mutations_data['masks'][0]
    coords_tensor = mutations_data['coords']
    protein_graph = get_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
    protein_unfolded_graph = get_unfolded_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
    pred_deltaG = calculate_dG(protein_graph, protein_unfolded_graph)
    pred_mut_dG.append(pred_deltaG)

print(pred_mut_dG)
#average dG
print(np.mean(pred_mut_dG))
#avarage ddG
print(np.mean(pred_mut_dG)-pred_wt_deltaG)
#calculate the correlation
true_dG = mutations_data['delta_g'][1:NUMBER_OF_VARIANTS]
print(true_dG)
print(np.corrcoef(pred_mut_dG,true_dG))


[-1.9128189, -1.9128189, -1.9128189, -1.9128189, -2.059597, -1.4475288, 3.6489077, 1.1425676, -0.2574787, -1.3480759, -0.39097214, -1.5175056, -1.7922173, -0.48610306, -0.74300194, -0.023918152, -2.3051147, -1.6167736, -0.79006004, -0.55924225, -1.777729, -2.0512009, -1.8434162, -1.7883263, -1.4741306, -2.4980278, -1.6094913, -1.4427719, -1.4530525, -1.7375412, -1.3019295, -2.1678772, -1.7538815, -1.8636036, -1.700676, -1.706028, -1.7485104, -2.398903, -1.4889164, -1.7892838, -0.92637825, -0.80413246, -1.4868145, -0.7943363, -1.7148018, -0.76965714, -1.8974648, -1.3092804, -1.1796398, -1.4278622, -2.1715374, -2.190075, -1.5623894, -1.6792889, -2.6545181, -2.8301697, -1.2293892, -1.2370129, -1.4861164, -1.9986649, -1.7525349, -2.5444126, -2.1364803, -1.9321976, -1.627039, -3.2045307, -2.6811638, -2.238329, -1.5813675, -2.8107643, -2.7763882, -1.6301899, -2.3790035, -2.3221645, -3.1501884, -3.0566463, -1.1757126, -2.0618057, -1.8428555, -1.3071537, -1.2411041, -1.5733128, -1.541832, -1.5

### Repeating the experiment with many proteins

In [29]:
class MutationDataset(Dataset):

    def __init__(self, tensor_root_dir, mutations_root_dir, remove_TM = False, one_mut = True):
        self.tensor_root_dir = tensor_root_dir
        self.mutations_root_dir = mutations_root_dir
        self.protein_dirs = [protein for i, protein in enumerate(os.listdir(self.tensor_root_dir))]
        self.one_mut = one_mut # remove the mutations with more than one mutation
        # remove TM proteins 
        if remove_TM:
            tm_proteins = pd.read_csv(TM_PATH)
            tm_proteins = tm_proteins['name'].apply(lambda x: x.split(".")[0]).unique().tolist()
            self.protein_dirs = [protein for protein in self.protein_dirs if protein not in tm_proteins]
        if DEBUG:
            self.protein_dirs = self.protein_dirs[:5]
        # # Train test split
        # self.training_protein, self.val_proteins = train_test_split(self.protein_dirs, test_size=VAL_RATIO, random_state=RANDOM_SEED)
        # if train:
        #     self.protein_dirs = self.training_protein
        # else:
        #     self.protein_dirs = self.val_proteins

    def __len__(self):
        return len(self.protein_dirs)

    def __getitem__(self, idx):
        protein_dir = os.path.join(self.tensor_root_dir, self.protein_dirs[idx])
        mutations_path = os.path.join(self.mutations_root_dir, f'{self.protein_dirs[idx]}.csv')
        mutations = pd.read_csv(mutations_path)
        mutations = mutations[~mutations['mut_type'].str.contains('ins|del')].reset_index(drop=True)
        # Load and preprocess the data for each protein
        coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
        delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
        mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
        one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
        embedding_tensor = load_embedding_tensor(os.path.join(protein_dir, PROTT5_EMBEDDINGS))
        
        # remove the mutations with more than one mutation
        if self.one_mut:
            one_mut_index = mutations[~mutations['mut_type'].str.contains(':')]
            mutations = mutations.loc[one_mut_index.index]
            delta_g_tensor = delta_g_tensor[one_mut_index.index]
            one_hot_tensor = one_hot_tensor[one_mut_index.index]
            embedding_tensor = embedding_tensor[one_mut_index.index]
            
        mutations_data = {
            'name': self.protein_dirs[idx],
            'mutations': mutations['mut_type'].to_list(),
            'prott5': embedding_tensor,
            'coords': coords_tensor,
            'one_hot': one_hot_tensor,
            'delta_g': delta_g_tensor,
            'masks': mask_tensor
        }

        return mutations_data

In [30]:
protein_dataset = MutationDataset(TENSOR_ROOT, MUTATIONS_ROOT, one_mut = True)
#create a DataLoader
#protein_dataloader = DataLoader(protein_dataset, batch_size=1, shuffle=True)
print(protein_dataset[0])
avg_dGs = []
avg_ddGs = []
correlations = []
for mutations_data in protein_dataset:
    #calculate wt dG
    print(f"Protein: {mutations_data['name']}")
    wt_index = find_wt_index(mutations_data)
    seq_one_hot = mutations_data['one_hot'][wt_index]
    proT5_emb = mutations_data['prott5'][wt_index]
    mask = mutations_data['masks'][0]
    coords_tensor = mutations_data['coords']
    protein_graph = get_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
    protein_unfolded_graph = get_unfolded_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
    pred_mut_dG = []
    pred_mut_ddG = []
    for i in range(1,NUMBER_OF_VARIANTS):
        if i == wt_index:
            continue
        seq_one_hot = mutations_data['one_hot'][i]
        proT5_emb = mutations_data['prott5'][i]
        mask = mutations_data['masks'][0]
        coords_tensor = mutations_data['coords']
        protein_graph = get_graph( coords_tensor ,seq_one_hot, proT5_emb, mask)
        protein_unfolded_graph = get_unfolded_graph( coords_tensor ,seq_one_hot, proT5_emb,mask)
        pred_deltaG = calculate_dG(protein_graph, protein_unfolded_graph)
        pred_mut_dG.append(pred_deltaG)
        pred_mut_ddG.append(pred_deltaG-pred_wt_deltaG)

    #average dG
    avg_dG = np.mean(pred_mut_dG)
    avg_dGs.append(avg_dG)
    #average ddG
    avg_ddG = avg_dG-pred_wt_deltaG
    avg_ddGs.append(avg_ddG)
    #calculate the correlation 
    true_dG = mutations_data['delta_g'][1:NUMBER_OF_VARIANTS]
    correlation = np.corrcoef(pred_mut_dG,true_dG)
    correlations.append(correlation)
    
print(f"avg deltaG for each protein: {avg_dGs}") 

print(f"avg DDG for each protein: {avg_ddGs}")

# Print the correlation for each protein 
for idx, corr in enumerate(correlations):
    print(f"Protein {idx + 1} correlation: {corr[0, 1]}")
        

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


{'name': '1A0N', 'mutations': ['wt', 'wt', 'wt', 'wt', 'wt', 'V1Q', 'V1E', 'V1N', 'V1H', 'V1D', 'V1R', 'V1K', 'V1T', 'V1S', 'V1A', 'V1G', 'V1M', 'V1L', 'V1I', 'V1W', 'V1Y', 'V1F', 'V1P', 'T2Q', 'T2E', 'T2N', 'T2H', 'T2D', 'T2R', 'T2K', 'T2S', 'T2A', 'T2G', 'T2M', 'T2L', 'T2V', 'T2I', 'T2W', 'T2Y', 'T2F', 'T2P', 'L3Q', 'L3E', 'L3N', 'L3H', 'L3D', 'L3R', 'L3K', 'L3T', 'L3S', 'L3A', 'L3G', 'L3M', 'L3V', 'L3I', 'L3W', 'L3Y', 'L3F', 'L3P', 'L3C', 'F4Q', 'F4E', 'F4N', 'F4H', 'F4D', 'F4R', 'F4K', 'F4T', 'F4S', 'F4A', 'F4G', 'F4M', 'F4L', 'F4V', 'F4I', 'F4W', 'F4Y', 'F4P', 'F4C', 'V5Q', 'V5E', 'V5N', 'V5H', 'V5D', 'V5R', 'V5K', 'V5T', 'V5S', 'V5A', 'V5G', 'V5M', 'V5L', 'V5I', 'V5W', 'V5Y', 'V5F', 'V5P', 'V5C', 'A6Q', 'A6E', 'A6N', 'A6H', 'A6D', 'A6R', 'A6K', 'A6T', 'A6S', 'A6G', 'A6M', 'A6L', 'A6V', 'A6I', 'A6W', 'A6Y', 'A6F', 'A6P', 'A6C', 'S7Q', 'S7E', 'S7N', 'S7H', 'S7D', 'S7R', 'S7K', 'S7T', 'S7A', 'S7G', 'S7M', 'S7L', 'S7V', 'S7I', 'S7W', 'S7Y', 'S7F', 'S7P', 'S7C', 'Y8Q', 'Y8E', 'Y8N', '

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1A0N
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093])
torch.Size([58, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1A32
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093])
torch.Size([63, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1AOY
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093])
torch.Size([69, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1B7J
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093])
torch.Size([64, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1BK2
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093])
torch.Size([56, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1BNZ
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093])
torch.Size([62, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1CSQ
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093])
torch.Size([67, 1093

  coords_tensor = torch.load(os.path.join(protein_dir, COORDS))
  delta_g_tensor = torch.load(os.path.join(protein_dir, DELTA_G))
  mask_tensor = torch.load(os.path.join(protein_dir, MASKS))
  one_hot_tensor = torch.load(os.path.join(protein_dir, ONE_HOT))
  embedding_tensor = torch.load(filename, map_location=torch.device('cpu'))  # Ensure loading on CPU


Protein: 1E0L
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093])
torch.Size([37, 1093

In [31]:
# class ProteinDataset(torch.utils.data.Dataset):
#     def __init__(self, tensor_root_dir, mutations_root_dir):
#         self.tensor_root_dir = tensor_root_dir
#         self.mutations_root_dir = mutations_root_dir
#         self.protein_dirs = os.listdir(self.tensor_root_dir)

#     def __len__(self):
#         return len(self.protein_dirs)

#     def __getitem__(self, idx):
#         protein_dir = os.path.join(self.tensor_root_dir, self.protein_dirs[idx])
#         mutations_path = os.path.join(self.mutations_root_dir, f"{self.protein_dirs[idx]}.csv")
        
#         mutations = pd.read_csv(mutations_path)
#         coords_tensor = torch.load(os.path.join(protein_dir, "coords_tensor.pt"), map_location='cpu')
#         delta_g_tensor = torch.load(os.path.join(protein_dir, "deltaG.pt"), map_location='cpu')
#         mask_tensor = torch.load(os.path.join(protein_dir, "mask_tensor.pt"), map_location='cpu')
#         one_hot_tensor = torch.load(os.path.join(protein_dir, "one_hot_encodings.pt"), map_location='cpu')
#         prott5_tensor = torch.load(os.path.join(protein_dir, "prott5_embeddings/prott5_embedding_0.pt"), map_location='cpu')

#         return {
#             'name': self.protein_dirs[idx],
#             'mutations': mutations.to_dict(orient="records"),  # Convert DataFrame to a list of dicts
#             'coords': coords_tensor,
#             'one_hot': one_hot_tensor,
#             'delta_g': delta_g_tensor,
#             'masks': mask_tensor,
#             'prott5': prott5_tensor
#         }


In [32]:
# import numpy as np
# import torch
# import os
# import pandas as pd
# from torch.utils.data import DataLoader
# from Utils.train_utils import get_graph, get_unfolded_graph
# from model.hydro_net import PEM
# from model.model_cfg import CFG

# # Constants
# NUMBER_OF_VARIANTS = 128
# NANO_TO_ANGSTROM = 0.1
# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# REG_LAMBDA = 0.01  # Same as fine-tuning model

# # Load Model
# # Import the model
# model = PEM(layers=CFG.num_layers,gaussian_coef=CFG.gaussian_coef).to(CFG.device)
# # Upload model weights
# CFG.model_path = '../data/Trained_models/'
# epoch = 25
# model_dict = torch.load(CFG.model_path+f"{epoch}_final_model.pt", map_location=torch.device('cpu'))
# model.load_state_dict(model_dict['model_state_dict'])
# model.eval()

# # Normalize batch function
# def normalize_batch(batch):
#     batch['one_hot'] = batch['one_hot'][:, :, :, :-1]  # Remove last channel
#     batch['coords'] = batch['coords'] * NANO_TO_ANGSTROM  # Scale coordinates
#     return batch


# # Load dataset & dataloader
# protein_dataset = ProteinDataset("mutation_data/tensors", "mutation_data/mutations")
# protein_dataloader = DataLoader(protein_dataset, batch_size=1, shuffle=False, num_workers=0)

# # Function to compute deltaG
# def calculate_dG(protein_graph, protein_unfolded_graph):
#     with torch.no_grad():
#         Gf = model(protein_graph.unsqueeze(0)).cpu().numpy()[0]
#         Gu = model(protein_unfolded_graph.unsqueeze(0)).cpu().numpy()[0]
#         return Gu - Gf

# # Evaluation loop
# avg_dGs, avg_ddGs, correlations = [], [], []
# for batch in protein_dataloader:
#     batch = normalize_batch(batch)
    
#     # Compute wildtype energy
#     wt_index = 0  # Assuming wildtype is first in the list
#     protein_graph = get_graph(batch['coords'], batch['one_hot'][wt_index], batch['prott5'][wt_index], batch['masks'])
#     protein_unfolded_graph = get_unfolded_graph(batch['coords'], batch['one_hot'][wt_index], batch['prott5'][wt_index], batch['masks'])
#     pred_wt_deltaG = calculate_dG(protein_graph, protein_unfolded_graph)

#     # Compute mutation energies in batches
#     pred_mut_dG, pred_mut_ddG = [], []
#     true_dG = batch['delta_g'][1:NUMBER_OF_VARIANTS].cpu().numpy()  # True ΔG from dataset
#     for i in range(1, NUMBER_OF_VARIANTS):
#         protein_graph = get_graph(batch['coords'], batch['one_hot'][i], batch['prott5'][i], batch['masks'])
#         protein_unfolded_graph = get_unfolded_graph(batch['coords'], batch['one_hot'][i], batch['prott5'][i], batch['masks'])
#         pred_deltaG = calculate_dG(protein_graph, protein_unfolded_graph)
#         pred_mut_dG.append(pred_deltaG)
#         pred_mut_ddG.append(pred_deltaG - pred_wt_deltaG)

#     # Compute averages
#     avg_dG = np.mean(pred_mut_dG)
#     avg_ddG = avg_dG - pred_wt_deltaG
#     avg_dGs.append(avg_dG)
#     avg_ddGs.append(avg_ddG)

#     # Compute correlation
#     pred_mut_dG = np.array(pred_mut_dG)
#     mask = ~np.isnan(pred_mut_dG) & ~np.isnan(true_dG)  # Ensure no NaN values
#     correlation = np.corrcoef(pred_mut_dG[mask], true_dG[mask])[0, 1]
#     correlations.append(correlation)

# # Print results
# print(f"Avg deltaG for each protein: {avg_dGs}")
# print(f"Avg DDG for each protein: {avg_ddGs}")
# for idx, corr in enumerate(correlations):
#     print(f"Protein {idx + 1} correlation: {corr:.4f}")
