# Data preparation for ML pipeline
---
In this script I prepare datasets for my machine learning pipeline. My objective is to train machine learning models for the prediction of variant effects that are either family-agnostic or family-aware. The family-agnostic models will be trained on the wild-type (wt) sequence and the sequence variants alone, while the family-aware models will additionally be trained on a multiple sequence alignment (MSA) for the protein of interest.

The proteins I will be analyzing have been tested using deep mutational scanning experiments and were curated and collected in the [ProteinGym](https://github.com/OATML-Markslab/ProteinGym) by the [OATML](https://oatml.cs.ox.ac.uk/) research group.

Personally, I am interested in predicting the effects of mutations on enzyme activity; therefore I will filter the data for DMS experiments that measured enzyme activity either directly or indirectly.

In [11]:
import pandas as pd
import os
import shutil
import sys
sys.path.append('../../../src/')
import proteusAI.io_tools as io_tools
import proteusAI.data_tools as data_tools
import proteusAI.ml_tools.esm_tools as esm_tools
import proteusAI.ml_tools.torch_tools as torch_tools
from sklearn.preprocessing import MinMaxScaler
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torch import optim

# Folder holding the per-assay ProteinGym substitution CSV files.
protein_gym_path = '../../example_data/ProteinGym_substitutions/'

# Reference table describing the ProteinGym substitution assays.
metadata = pd.read_csv(
    '../../example_data/ProteinGym_reference_file_substitutions.csv'
)

## Collect all studies related to enzyme activity and move them to datasets

In [8]:
# Assays whose selection_assay text mentions "Enzyme function" or "activity"
# (case-insensitive); rows with a missing description count as non-matches.
is_enzyme_assay = metadata['selection_assay'].str.contains(
    'Enzyme function|activity', case=False, na=False
)
# Keep only target sequences shorter than 1024 residues.
is_short_enough = metadata['seq_len'] < 1024

enzyme_data = metadata[is_enzyme_assay & is_short_enough]
enzyme_data

Unnamed: 0,DMS_id,DMS_filename,UniProt_ID,taxon,target_seq,seq_len,includes_multiple_mutants,DMS_total_number_mutants,DMS_number_single_mutants,DMS_number_multiple_mutants,...,MSA_N_eff,MSA_Neff_L,MSA_Neff_L_category,MSA_num_significant,MSA_num_significant_L,raw_DMS_filename,raw_DMS_phenotype_name,raw_DMS_directionality,raw_DMS_mutant_column,weight_file_name
10,AMIE_PSEAE_Wrenbeck_2017,AMIE_PSEAE_Wrenbeck_2017.csv,AMIE_PSEAE,Prokaryote,MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMK...,346,False,6227,6227,0,...,29959.3,119.359761,high,557,2.219124,AMIE_PSEAE_Wrenbeck_2017.csv,isobutyramide_normalized_fitness,1,mutant,AMIE_PSEAE_theta_0.2.npy
21,CCDB_ECOLI_Tripathi_2016,CCDB_ECOLI_Tripathi_2016.csv,CCDB_ECOLI,Prokaryote,MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKV...,101,False,1663,1663,0,...,16821.5,195.598837,high,61,0.709302,CCDB_ECOLI_Tripathi_2016.csv,score,-1,mutant,CCDB_ECOLI_theta_0.2.npy
22,CP2C9_HUMAN_Amorosi_abundance_2021,CP2C9_HUMAN_Amorosi_abundance_2021.csv,CP2C9_HUMAN,Human,MDSLVVLVLCLSCLLLLSLWRQSSGRGKLPPGPTPLPVIGNILQIG...,490,False,6370,6370,0,...,81212.1,187.124654,high,1092,2.516129,CP2C9_HUMAN_Amorosi_2021.csv,abundance_score,1,variant,CP2C9_HUMAN_theta_0.2.npy
23,CP2C9_HUMAN_Amorosi_activity_2021,CP2C9_HUMAN_Amorosi_activity_2021.csv,CP2C9_HUMAN,Human,MDSLVVLVLCLSCLLLLSLWRQSSGRGKLPPGPTPLPVIGNILQIG...,490,False,6142,6142,0,...,81212.1,187.124654,high,1092,2.516129,CP2C9_HUMAN_Amorosi_2021.csv,activity_score,1,variant,CP2C9_HUMAN_theta_0.2.npy
43,MSH2_HUMAN_Jia_2020,MSH2_HUMAN_Jia_2020.csv,MSH2_HUMAN,Human,MAVQPKETLQLESAAEVGFVRFFQGMPEKPTTTVRLFDRGDFYTAH...,934,False,16749,16749,0,...,10716.4,12.727316,medium,1035,1.229216,MSH2_HUMAN_Jia_2020.csv,LOF score,-1,Variant,MSH2_HUMAN_theta_0.2.npy
58,PTEN_HUMAN_Mighell_2018,PTEN_HUMAN_Mighell_2018.csv,PTEN_HUMAN,Human,MTAIIKEIVSRNKRRYQEDGFDLDLTYIYPNIIAMGFPAERLEGVY...,403,False,7260,7260,0,...,1425.3,4.70396,medium,52,0.171617,PTEN_HUMAN_Mighell_2018.csv,Fitness_score,1,mutant,PTEN_HUMAN_theta_0.2.npy
60,Q59976_STRSQ_Romero_2015,Q59976_STRSQ_Romero_2015.csv,Q59976_STRSQ,Prokaryote,MVPAAQQTAMAPDAALTFPEGFLWGSATASYQIEGAAAEDGRTPSI...,501,False,2999,2999,0,...,13981.2,31.631674,medium,850,1.923077,Q59976_STRSQ_Romero_2015.csv,enrichment,1,mutant,Q59976_STRSQ_theta_0.2.npy
66,RL401_YEAST_Roscoe_2014,RL401_YEAST_Roscoe_2014.csv,RL40A_YEAST,Eukaryote,MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFA...,128,False,1380,1380,0,...,3974.4,44.65618,medium,12,0.134831,RL401_YEAST_Roscoe_2014.csv,rel_react,1,mutant,RL401_YEAST_theta_0.2.npy
72,SRC_HUMAN_Ahler_CD_2019,SRC_HUMAN_Ahler_CD_2019.csv,SRC_HUMAN,Human,MGSNKSKPKDASQRRRSLEPAENVHGAGGGAFPASQTPSKPASADG...,536,False,3372,3372,0,...,1405.1,3.245035,medium,86,0.198614,SRC_HUMAN_Ahler_CD_2019.csv,Activity_Score,1,mutant_uniprot_1,SRC_HUMAN_theta_0.2.npy


In [16]:
# Persist the filtered metadata table one directory above the notebook.
# The DataFrame index is written too (pandas default), as an unnamed first column.
enzyme_data.to_csv('../enzyme_metadata.csv')

In [10]:
# Copy the assay CSVs of the selected enzyme studies into the local dataset folder.
source_dir = '../../example_data/ProteinGym_substitutions/'
dataset_dir = '../datasets/'

# Build the lookup once instead of once per directory entry; a set gives
# constant-time membership tests.
enzyme_dms_ids = set(enzyme_data['DMS_id'])

# A file is selected when the part of its name before the first '.' is one of
# the selected DMS ids.
mutant_datasets = [f for f in os.listdir(source_dir) if f.split('.')[0] in enzyme_dms_ids]

# shutil.copy does not create missing directories, so make sure the target exists.
os.makedirs(dataset_dir, exist_ok=True)

for f in mutant_datasets:
    source_file = os.path.join(source_dir, f)
    destination_file = os.path.join(dataset_dir, f)
    shutil.copy(source_file, destination_file)

## normalize scores between 0 and 1

In [15]:
# Rescale every dataset's 'DMS_score' column to the range [0, 1] and store the
# result in a new 'y' column. The scaler is re-fitted for each dataset, so the
# scaling is per assay, and each CSV is overwritten in place.
scaler = MinMaxScaler(feature_range=(0, 1))

for name in mutant_datasets:
    dataset_path = f"../datasets/{name}"
    df = pd.read_csv(dataset_path)

    # Double brackets keep a 2-D frame, which is what the scaler expects.
    df['y'] = scaler.fit_transform(df[['DMS_score']])

    df.to_csv(dataset_path, index=False)

In [3]:
# UniProt identifiers and wild-type target sequences, in matching row order.
# NOTE(review): UniProt_ID is not unique per assay (CP2C9_HUMAN appears twice
# in the metadata table shown above), so `names` can contain duplicates.
names = enzyme_data['UniProt_ID'].to_list()
seqs = enzyme_data['target_seq'].to_list()

## BLAST and MSA
---
The multiple sequence alignment for the family-aware methods will be created from a BLAST search of the individual proteins, using the UniRef100 database as the reference database. The enzymes.fasta file is then uploaded to the [UniProt BLAST](https://www.uniprot.org/blast) tool as input. Running BLAST from Python is too slow, because automated (scripted) queries are discouraged. Sequences found in the BLAST results will be parsed, and an MSA will be created using MUSCLE.

In [4]:
# Write the wild-type sequences to enzymes.fasta, the file that is uploaded to
# the UniProt BLAST tool (see the text above).
# Presumably one FASTA record per (name, sequence) pair -- verify against
# io_tools.write_fasta.
io_tools.write_fasta(names, seqs, dest='../../example_data/ProteinGym_substitutions/enzymes.fasta')

In [5]:
def parse_blast_output(blast_output_file, identity_threshold=100):
    """Collect the subject sequences of the hits in a text BLAST report.

    The report is scanned line by line: a line starting with '>' opens a new
    hit, a line starting with ' Identities = ' supplies that hit's percent
    identity (the last such line wins), and every line starting with 'Sbjct'
    contributes its third whitespace-separated field to the hit's sequence.

    Args:
        blast_output_file: Path of the BLAST text report.
        identity_threshold: Hits whose percent identity is not less than or
            equal to this value are dropped.

    Returns:
        dict mapping "<hit id>_<identity>%" to the hit's sequence, with every
        character outside the 20 standard amino-acid letters removed.
    """
    standard_residues = set('ACDEFGHIKLMNPQRSTVWY')
    hits = {}

    def record_hit(hit_id, identity, fragments):
        # Nothing to store before the first header or for hits above the threshold.
        if hit_id and identity <= identity_threshold:
            raw_sequence = ''.join(fragments)
            cleaned = ''.join(res for res in raw_sequence if res in standard_residues)
            hits[f"{hit_id}_{identity}%"] = cleaned

    hit_id = ''
    identity = 0
    fragments = []
    with open(blast_output_file, 'r') as report:
        for line in report:
            if line.startswith('>'):
                # A new header closes the previous hit.
                record_hit(hit_id, identity, fragments)
                hit_id = line.split(' ')[0][1:]  # drop the leading '>'
                identity = 0
                fragments = []
            elif line.startswith(' Identities = '):
                # e.g. " Identities = 90/100 (90%), ..." -> 90.0
                percent_field = line.split(",")[0].split("(")[1]
                identity = float(percent_field.split("%")[0])
            elif line.startswith('Sbjct'):
                fragments.append(line.split()[2])

    # The last hit has no following header to close it.
    record_hit(hit_id, identity, fragments)

    return hits


In [6]:
# Parse the saved BLAST report of every protein; each report is expected at
# ../../example_data/MSA/<UniProt ID>.txt.
msa_results = {
    name: parse_blast_output(f'../../example_data/MSA/{name}.txt')
    for name in names
}

In [7]:
# Written to align each wild-type sequence with its parsed BLAST hits and save
# the alignment as ../../example_data/MSA/<name>.fasta.
# NOTE(review): the unconditional `break` below leaves the loop during the
# first iteration, so the try/except block is never executed; the only effect
# of this cell is binding `_names` and `_seqs` for the first protein.
# Presumably the alignment was switched off on purpose once the FASTA files
# existed -- confirm before removing the `break`.
for i in range(len(names)):
    # Each alignment starts with the wild-type sequence itself.
    _names = [names[i]]
    _seqs = [seqs[i]]
    break
    try:
        # Append every BLAST hit of this protein, then align the whole set.
        for name, seq in msa_results[names[i]].items():
            _names.append(name)
            _seqs.append(seq)
        data_tools.align_proteins(_names, _seqs, save_fasta=f'../../example_data/MSA/{names[i]}.fasta', plot_results=False)
    except:
        # NOTE(review): a bare `except` also swallows KeyboardInterrupt and
        # SystemExit. The fallback writes only the wild-type sequence and
        # passes no `dest`, unlike the call that writes enzymes.fasta -- verify
        # where write_fasta puts the file by default.
        n = [names[i]]
        seq = [seqs[i]]
        io_tools.write_fasta(n, seq)

## MSA to one hot encoding

In [8]:
# Rebinds `msa_results` (previously the parsed BLAST hits) to the content of
# the FASTA files in the MSA folder. The cells below index the result by file
# name (e.g. 'MSH2_HUMAN.fasta') and read element [0] of each value as the
# labels and element [1] as the sequences.
msa_results = io_tools.load_all_fastas('../../example_data/MSA/')

In [9]:
# NOTE(review): `d` is not read by any later cell visible in this notebook;
# it looks like an exploration leftover.
d = msa_results.values()

In [10]:
# Token dictionary of the ESM alphabet, used as the one-hot vocabulary.
alphabet = esm_tools.alphabet.to_dict()

# One DataFrame per FASTA file: 'label' holds the sequence names and 'x' the
# matching one-hot encoded sequence, one entry per sequence.
encodings = {}
for fasta_name, record in msa_results.items():
    labels, aligned_seqs = record[0], record[1]
    one_hot = torch_tools.one_hot_encoder(aligned_seqs, alphabet)
    # Iterating the encoder output yields one encoding per sequence.
    encodings[fasta_name] = pd.DataFrame({'label': labels, 'x': list(one_hot)})

encodings['MSH2_HUMAN.fasta']

Unnamed: 0,label,x
0,MSH2_HUMAN,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
1,UR100:UniRef100_P43246_100.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
2,UR100:UniRef100_G3QW00_99.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
3,UR100:UniRef100_A0A8I3B3M9_99.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
4,UR100:UniRef100_UPI0015619230_99.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
...,...,...
246,UR100:UniRef100_UPI001C69C888_96.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
247,UR100:UniRef100_A0A8D0V6L8_91.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
248,UR100:UniRef100_F1SQH6_96.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."
249,UR100:UniRef100_A0A8D1RBY7_96.0%,"[[tensor(0.), tensor(0.), tensor(0.), tensor(0..."


## Bayesian VAE
---
Below I build a simple prototype of a Bayesian VAE for all the proteins.

In [11]:
class Encoder(nn.Module):
    """Fully connected encoder mapping an input vector to latent parameters.

    Every hidden stage is Linear -> Dropout -> ReLU. Two linear heads, `mu`
    and `var`, project the last hidden activation to `z_dim` values each.
    All linear weights use Kaiming-normal initialisation.

    Args:
        input_dim: Size of the flattened input vector.
        hidden_dims: Widths of the hidden layers, in order.
        z_dim: Size of the latent space.
        dropout: Dropout probability applied after every hidden linear layer.
    """

    def __init__(self, input_dim, hidden_dims, z_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList()
        # Consecutive pairs of widths give the (in, out) size of each stage.
        widths = [input_dim, *hidden_dims]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            dense = nn.Linear(fan_in, fan_out)
            nn.init.kaiming_normal_(dense.weight)
            self.layers.extend([dense, nn.Dropout(dropout), nn.ReLU()])

        self.mu = nn.Linear(widths[-1], z_dim)
        nn.init.kaiming_normal_(self.mu.weight)
        self.var = nn.Linear(widths[-1], z_dim)
        nn.init.kaiming_normal_(self.var.weight)

    def forward(self, x):
        """Return the outputs of the `mu` and `var` heads for input `x`."""
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)
        return self.mu(hidden), self.var(hidden)


class Decoder(nn.Module):
    """Fully connected decoder mapping a latent vector back to input space.

    The hidden widths are applied in reverse order. Every hidden stage is
    Linear -> Dropout -> ReLU, and the output layer is followed by a sigmoid,
    so every output value lies between 0 and 1. All linear weights use
    Kaiming-normal initialisation.

    Args:
        z_dim: Size of the latent space.
        hidden_dims: Hidden layer widths; used in reverse order.
        output_dim: Size of the reconstructed vector.
        dropout: Dropout probability applied after every hidden linear layer.
    """

    def __init__(self, z_dim, hidden_dims, output_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList()
        # Consecutive pairs of widths give the (in, out) size of each stage.
        widths = [z_dim, *reversed(hidden_dims)]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            dense = nn.Linear(fan_in, fan_out)
            nn.init.kaiming_normal_(dense.weight)
            self.layers.extend([dense, nn.Dropout(dropout), nn.ReLU()])

        self.out = nn.Linear(widths[-1], output_dim)
        nn.init.kaiming_normal_(self.out.weight)

    def forward(self, x):
        """Decode latent vectors `x` into sigmoid-activated reconstructions."""
        hidden = x
        for stage in self.layers:
            hidden = stage(hidden)
        return self.out(hidden).sigmoid()


class VAE(nn.Module):
    """Variational autoencoder built from an Encoder and a mirrored Decoder.

    Args:
        input_dim: Size of the flattened input (and of the reconstruction).
        hidden_dims: Hidden layer widths, shared by encoder and decoder.
        z_dim: Size of the latent space.
        dropout: Dropout probability passed to both sub-networks.
    """

    def __init__(self, input_dim, hidden_dims, z_dim, dropout=0.0):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dims, z_dim, dropout)
        self.decoder = Decoder(z_dim, hidden_dims, input_dim, dropout)

    def forward(self, x):
        """Encode `x`, sample a latent vector and decode it.

        Returns:
            Tuple (reconstruction, z_mu, z_var) where z_mu and z_var are the
            two encoder outputs.
        """
        z_mu, z_var = self.encoder(x)
        # The second encoder output is used as a log-variance:
        # std = exp(log_var / 2).
        std = (z_var / 2).exp()
        # Reparameterisation: z = mu + eps * std with eps drawn from N(0, 1).
        noise = torch.randn_like(std)
        latent = z_mu + noise * std
        return self.decoder(latent), z_mu, z_var

In [12]:
class CustomDataset(Dataset):
    """Dataset serving the encodings stored in a DataFrame's 'x' column.

    Args:
        data: pandas DataFrame with an 'x' column; row i holds sample i.
    """

    def __init__(self, data):
        self.data = data
        # Not used by the dataset itself; kept so code reading `.device`
        # keeps working.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the 'x' entry at positional row `index`.

        Only the encoding is returned, so the 'label' column is no longer
        looked up (and discarded) on every access.
        """
        return self.data['x'].iloc[index]

In [13]:
# One dataset per encoded MSA. The model name is the FASTA file name with its
# last six characters (the '.fasta' suffix) removed.
# Note: this rebinds `names`, which earlier held the UniProt IDs.
names = [str(key)[:-6] for key in encodings]
datasets = [CustomDataset(frame) for frame in encodings.values()]

In [14]:
def train_vae(train_data, val_data, model, optimizer, criterion, scheduler, epochs, device, model_name, verbose=False):
    """Train a VAE and checkpoint the weights with the lowest validation loss.

    Args:
        train_data: Iterable of training batches (tensors). Each batch is
            moved to `device` and flattened to (batch, -1).
        val_data: Iterable of validation batches, handled the same way.
        model: Module returning (reconstruction, mu, logvar) for a batch.
        optimizer: Optimizer over the model's parameters.
        criterion: Callable (recon, batch, mu, logvar) -> loss tensor. The
            epoch loss divides the summed batch losses by the number of
            examples, so a sum-reduced loss is assumed -- confirm for other
            criteria.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        epochs: Number of epochs to run.
        device: Device the batches are moved to.
        model_name: Base name of the checkpoint
            ('../checkpoints/<model_name>.pt') and the loss plot
            ('<model_name>.png').
        verbose: Print losses every 100 epochs and every checkpoint save.

    Returns:
        The same `model` object with the weights of the final epoch. The best
        weights are only available from the saved checkpoint.
    """
    # torch.save does not create missing directories.
    checkpoint_dir = '../checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f'{checkpoint_dir}/{model_name}.pt'

    best_val_loss = float('inf')
    # Bound up front so plotting works even if no epoch improves on the
    # initial best loss (epochs == 0, or a validation loss that is always NaN).
    best_epoch = 0
    train_losses = []
    val_losses = []

    pbar = tqdm(range(epochs), desc='Training')
    for epoch in pbar:
        model.train()
        train_loss = 0
        num_examples = 0
        for batch in train_data:
            # Move the batch to the right device and flatten each example
            batch = batch.to(device)
            batch = batch.view(batch.size(0), -1)

            optimizer.zero_grad()
            recon, mu, logvar = model(batch)

            loss = criterion(recon, batch, mu, logvar)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            num_examples += batch.size(0)

        average_train_loss = train_loss / num_examples
        train_losses.append(average_train_loss)

        if verbose and epoch % 100 == 0:
            print(f"Epoch {epoch+1}, Train Loss: {average_train_loss:.4f}")

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss = 0
            num_examples = 0
            for batch in val_data:
                batch = batch.to(device)
                batch = batch.view(batch.size(0), -1)
                recon, mu, logvar = model(batch)
                loss = criterion(recon, batch, mu, logvar)
                val_loss += loss.item()
                num_examples += batch.size(0)

            average_val_loss = val_loss / num_examples
            val_losses.append(average_val_loss)
            if verbose and epoch % 100 == 0:
                print(f"Epoch {epoch+1}, Val Loss: {average_val_loss:.4f}")

        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix({'Train Loss': average_train_loss, 'Val Loss': average_val_loss, 'LR': current_lr})

        # Save model if it's the best so far
        if average_val_loss < best_val_loss:
            best_val_loss = average_val_loss
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)
            if verbose:
                print(f"Model saved at epoch {epoch+1}, Val Loss: {average_val_loss:.4f}")

        scheduler.step()

    plot_losses(train_losses, val_losses, best_epoch, fname=model_name+'.png')

    return model

def plot_losses(train_losses, val_losses, best_epoch, fname=None):
    """Plot the training and validation loss curves and mark the best epoch.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        best_epoch: x position of the dashed red vertical marker line.
        fname: If not None, the figure is saved to this path before it is
            shown.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.axvline(x=best_epoch, color='r', linestyle='--', label='Best Model')
    ax.set_title('Train and Validation Losses')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    if fname is not None:
        fig.savefig(fname)
    plt.show()

In [16]:
def criterion(recon_x, x, mu, logvar):
    """Return the VAE loss: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: Reconstruction with values in [0, 1].
        x: Target tensor of the same shape as `recon_x`.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor BCE + KLD, both summed over all elements.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, exp(logvar)) and the standard normal.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()
    return reconstruction + divergence

In [18]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
batch_size = 2**10

# Train one VAE per protein MSA; `names` and `datasets` are parallel lists.
for base_name, dat in zip(names, datasets):
    # Name used for the checkpoint and the loss plot
    model_name = base_name + '_VAE'

    # Random 80/20 split into training and validation sets
    train_size = int(0.8 * len(dat))
    val_size = len(dat) - train_size
    train_dataset, val_dataset = random_split(dat, [train_size, val_size])

    train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_data = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Every sample is expected to be a 2-D one-hot tensor of shape
    # (seq_len, alphabet_size); the model input is its flattened form.
    seq_len, alphabet_size = train_dataset[0].shape

    # Fresh model, optimizer and schedule for each protein; the learning rate
    # is multiplied by 0.1 every 100 epochs.
    model = VAE(
        input_dim=seq_len * alphabet_size,
        hidden_dims=[2048, 1024, 256],
        z_dim=64,
        dropout=0.1,
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
    epochs = 200

    print(f"Training {model_name} model...")
    model = train_vae(train_data, val_data, model, optimizer, criterion, scheduler, epochs, device, model_name)

Training MSH2_HUMAN_VAE model...


Training:  11%|█████████▌                                                                             | 22/200 [00:27<03:44,  1.26s/it, Train Loss=330, Val Loss=374, LR=0.001]
[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 