# Project: Protein stability prediction

In this project you will predict changes in protein stability upon point mutations. 
We will use accumulated data from experimental databases, namely the Megascale dataset. A recent [pre-print paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10402116/) has already preprocessed the dataset and created homology-reduced data splits, which we will reuse. To do so, download the data folder from [here](https://polybox.ethz.ch/index.php/s/txvcb5jKy1A0TbY) and unzip it.  

The data includes measurements of changes in the Gibbs free energy ($\Delta \Delta G$). 
This is the value you will have to predict for a given protein with a point mutation. 
As input data you can use the protein sequence or a protein embedding retrieved from ESM, a state-of-the-art protein language model.  

Here we will use protein embeddings computed by ESM as input. 
We provide precomputed embeddings from the last layer of the smallest ESM model. You can adjust the `Dataset` code to load the embedding of the wild type, of the mutated sequence, or both, and use them however you like. The precomputed files are only there to give you easy access to embeddings. If you want to compute your own embeddings from other layers or models, you can do that, too. 

Below we provide a structure for the project that you can start from.  
Edit the cells to your liking and add more code to create your final model.

## Imports

In [None]:
import os 
import numpy as np
import pandas as pd
import scipy
import sklearn.metrics as skmetrics

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import lightning as L

import torchmetrics
from torchmetrics.regression import PearsonCorrCoef

## Dataloading

We are using the Megascale dataset. The train, validation and test sets are already predefined.  
As mentioned, we provide embeddings from the last layer of ESM as input. You can access either the wild type or the mutated sequence, and you can also further adjust the embeddings. 
Each embedding represents the complete sequence: it was computed by averaging the per-residue embeddings over the sequence. 
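
The averaging step can be illustrated with a small sketch. The tensor below is a made-up stand-in for per-residue embeddings (5 residues, width 4, both chosen for readability); it is not the actual ESM output and not the code used to produce the provided files.

```python
import torch

# Hypothetical per-residue embeddings: 5 residues, embedding width 4.
# Real ESM embeddings are much wider; these numbers are for illustration only.
per_residue = torch.arange(20, dtype=torch.float32).reshape(5, 4)

# Average over the sequence dimension (dim=0) to get one vector per protein.
sequence_embedding = per_residue.mean(dim=0)

print(sequence_embedding.shape)     # torch.Size([4])
print(sequence_embedding.tolist())  # [8.0, 9.0, 10.0, 11.0]
```

Note that averaging discards the position of each residue, so a single point mutation only shifts the pooled vector slightly.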

The ``Dataset`` class returns tuples of ``(embedding, ddg_value)``.

In [None]:
# the dataloaders load the tensors from disk one by one, which could become a bottleneck

class ProtEmbeddingDataset(Dataset):
    """
    Dataset for the embeddings of the mutated sequences
    You can adapt the __getitem__() method to return the data in the format you want
    """
    def __init__(self, tensor_folder, csv_file, id_col="name", label_col="ddG_ML"):
        """
        Initialize the dataset
        input at init: 
            tensor_folder: path to the directory with the embeddings we want to use, e.g. "/home/data/mega_train_embeddings"
            csv_file: path to the csv file corresponding to the data, e.g. "/home/data/mega_train.csv"
        """
        self.tensor_folder = tensor_folder
        self.df = pd.read_csv(csv_file, sep=",")
        # only use the mutation rows
        self.df = self.df[self.df.mut_type!="wt"]
        # get the labels and ids
        self.labels = torch.tensor(self.df[label_col].values)
        self.ids = self.df[id_col].values
        self.wt_names = self.df["WT_name"].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # load embeddings
        # mutation embedding
        tensor_path = os.path.join(self.tensor_folder, self.ids[idx] + ".pt")
        tensor = torch.load(tensor_path)['mean_representations'][6]

        # wildtype embedding, uncomment if you want to use this, too
        #tensor_path_wt = os.path.join(self.tensor_folder, self.wt_names[idx] + ".pt")
        #tensor_wt = torch.load(tensor_path_wt)['mean_representations'][6]

        label = self.labels[idx] # ddG value
        # returns a tuple of the input embedding and the target ddG values
        return tensor, label.float()
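
If you load both the mutant and the wild-type embedding, you need to decide how to merge them into one model input. The sketch below shows two common options, concatenation and difference. `combine_embeddings` is a hypothetical helper, not part of the provided code, and the short vectors are placeholders for real embeddings.

```python
import torch

def combine_embeddings(mut, wt, mode="concat"):
    """Hypothetical helper: merge mutant and wild-type embeddings into one input."""
    if mode == "concat":
        # Keeps both vectors; the input size of the model doubles.
        return torch.cat([mut, wt], dim=-1)
    if mode == "diff":
        # Keeps only the change introduced by the mutation; input size is unchanged.
        return mut - wt
    raise ValueError(f"unknown mode: {mode}")

# Placeholder vectors standing in for real embeddings.
mut = torch.tensor([1.0, 2.0, 3.0])
wt = torch.tensor([0.5, 2.0, 4.0])

print(combine_embeddings(mut, wt, "concat").shape)   # torch.Size([6])
print(combine_embeddings(mut, wt, "diff").tolist())  # [0.5, 0.0, -1.0]
```

Whichever variant you pick, the first layer of your model has to match the resulting input size.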
    


In [None]:
# usage 
# make sure to adjust the paths to where your files are located
dataset_train = ProtEmbeddingDataset('project_data/mega_train_embeddings', 'project_data/mega_train.csv')
dataset_val = ProtEmbeddingDataset('project_data/mega_val_embeddings', 'project_data/mega_val.csv')
dataset_test = ProtEmbeddingDataset('project_data/mega_test_embeddings', 'project_data/mega_test.csv')

dataloader_train = DataLoader(dataset_train, batch_size=1024, shuffle=True, num_workers=16)
dataloader_val = DataLoader(dataset_val, batch_size=512, shuffle=False, num_workers=16)
dataloader_test = DataLoader(dataset_test, batch_size=32, shuffle=False)

## Model architecture and training

Now it's your turn. Create a model trained on the embeddings and the corresponding ddG values.  
Be aware that this is not a classification task, but a regression task. You want to predict a continuous number that is as close to the measured $\Delta \Delta G$ value as possible.
You will need to adjust your architecture and loss accordingly.

Train the model with the predefined dataloaders, then try to improve it. 
Only evaluate on the test set at the very end, when you have finished fine-tuning your model. 
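
To make the regression setup concrete, here is a minimal sketch in plain PyTorch. It assumes synthetic data and arbitrary layer sizes, and it uses a mean squared error loss as one possible choice. It is a starting point to illustrate the output layer and loss, not a recommended architecture.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic stand-in data: 256 "embeddings" of width 16 with a noisy linear target.
# Both sizes are arbitrary and do not match the real ESM embeddings.
x = torch.randn(256, 16)
y = x @ torch.randn(16) + 0.1 * torch.randn(256)

# Small MLP with a single output unit and no final activation,
# so the prediction is an unbounded real number.
toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

losses = []
for step in range(100):
    optimizer.zero_grad()
    y_hat = toy_model(x).squeeze(-1)  # (256, 1) -> (256,), so it matches y
    loss = loss_fn(y_hat, y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

The `squeeze(-1)` matters: comparing a `(batch, 1)` prediction with a `(batch,)` target would broadcast to a `(batch, batch)` matrix and give a misleading loss.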

In [None]:
# your code

## Validation and visualization

To get a good sense of how the model is performing and to compare with the literature, compute the Pearson and Spearman correlations.
You can also plot the predictions in a scatterplot. We have added some code for that. 
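
The two correlations measure different things: Pearson captures linear agreement, while Spearman only compares ranks. The toy example below uses made-up numbers (not project data) in which the relationship is perfectly monotonic but not linear.

```python
import numpy as np
import scipy.stats

# Made-up values: y increases with x, but not linearly.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = x ** 3

# Both functions return a (statistic, p-value) pair; index 0 is the correlation.
pearson_r = scipy.stats.pearsonr(x, y)[0]
spearman_r = scipy.stats.spearmanr(x, y)[0]

print(f"Pearson r:  {pearson_r:.3f}")   # below 1, the relation is not linear
print(f"Spearman r: {spearman_r:.3f}")  # 1.000, the ranking is identical
```

A model can therefore rank mutations well (high Spearman) while still being off in absolute values, which is why the RMSE is reported alongside the correlations.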

In [None]:
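preds = []
all_y = []
# save all predictions
model.eval()  # switch off dropout and similar training-only behaviour
with torch.no_grad():  # no gradients needed for evaluation
    for batch in dataloader_val:
        # adjust this to work with your model
        x, y = batch
        y_hat = model(x)
        # flatten to shape (batch,) and move to the CPU before converting to numpy
        preds.append(y_hat.reshape(-1).cpu().numpy())
        all_y.append(y.cpu().numpy())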

# concatenate and plot
preds = np.concatenate(preds)
all_y = np.concatenate(all_y)

sns.regplot(x=preds,y=all_y)
plt.xlabel("Predicted ddG")
plt.ylabel("Measured ddG")

# get RMSE, Pearson and Spearman correlation 
print("RMSE:", np.sqrt(skmetrics.mean_squared_error(all_y, preds)))
print("Pearson r:", scipy.stats.pearsonr(preds, all_y))
print("Spearman r:", scipy.stats.spearmanr(preds, all_y))