# peptide-MHC binding affinity regression

In this notebook, we train and test the Attentive Variational Information Bottleneck on peptide+MHC class II data to predict binding affinity. We also do experiments with the baseline and ablation methods.

In [29]:
import pandas as pd
import torch
import numpy as np
import random

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [30]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pandas as pd
import torch

metrics = ['MSE', 'RMSE', 'R2']

def get_scores(y_true, y_pred):
    """
    Compute a df with all regression metrics and respective scores.
    """
    
    scores = [
        mean_squared_error(y_true, y_pred),
        mean_absolute_error(y_true, y_pred),
        r2_score(y_true, y_pred),
    ]
    
    df = pd.DataFrame(data={'score': scores, 'metrics': metrics})
    return df

In [31]:
def set_random_seed(random_seed):
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

In [1]:
import os
import pandas as pd
login = os.getlogin( )
DATA_ROOT = f"/home/{login}/Git/tcr/data/mhc/NetMHCIIpan_train/"
RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/results/"

In [33]:
device = torch.device('cuda:4')

batch_size = 4096
epochs = 1000
lr = 1e-3

z_dim = 150
early_stopper_patience = 30
monitor = 'loss'
lr_scheduler_param = 10
loss = "mse"

beta = 1e-6

# PoE

In [34]:
joint_posterior = "poe"

In [35]:
df = pd.read_csv(DATA_BASE + 'netmhcIIpan4.csv')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.{joint_posterior}.pMHC.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 296 | Best val score 0.026977 | DKL-prior 0.000165 | MSE 0.026813 |:  32%|███▎      | 325/1000 [08:48<18:18,  1.63s/it]
[VAL] Best epoch 402 | Best val score 0.026374 | DKL-prior 0.000178 | MSE 0.026195 |:  43%|████▎     | 431/1000 [11:37<15:20,  1.62s/it]
[VAL] Best epoch 320 | Best val score 0.026815 | DKL-prior 0.000166 | MSE 0.026649 |:  35%|███▍      | 349/1000 [09:25<17:34,  1.62s/it]
[VAL] Best epoch 389 | Best val score 0.026709 | DKL-prior 0.000170 | MSE 0.026539 |:  42%|████▏     | 418/1000 [11:35<16:08,  1.66s/it]
[VAL] Best epoch 302 | Best val score 0.027491 | DKL-prior 0.000165 | MSE 0.027326 |:  33%|███▎      | 331/1000 [09:33<19:19,  1.73s/it]


# AoE

In [36]:
joint_posterior = "aoe"

In [37]:
df = pd.read_csv(DATA_BASE + 'netmhcIIpan4.csv')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.{joint_posterior}.pMHC.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 191 | Best val score 0.026043 | DKL-prior 0.000214 | MSE 0.025829 |:  22%|██▏       | 220/1000 [06:45<23:56,  1.84s/it]
[VAL] Best epoch 282 | Best val score 0.025615 | DKL-prior 0.000205 | MSE 0.025409 |:  31%|███       | 311/1000 [09:20<20:41,  1.80s/it]
[VAL] Best epoch 263 | Best val score 0.025968 | DKL-prior 0.000174 | MSE 0.025795 |:  29%|██▉       | 292/1000 [08:45<21:13,  1.80s/it]
[VAL] Best epoch 233 | Best val score 0.026508 | DKL-prior 0.000201 | MSE 0.026308 |:  26%|██▌       | 262/1000 [07:52<22:11,  1.80s/it]
[VAL] Best epoch 255 | Best val score 0.026357 | DKL-prior 0.000150 | MSE 0.026207 |:  28%|██▊       | 284/1000 [08:32<21:32,  1.81s/it]


# Average Pooling of Experts

In [42]:
joint_posterior = "avg_pool"

In [43]:
df = pd.read_csv(DATA_BASE + 'netmhcIIpan4.csv')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.{joint_posterior}.pMHC.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 262 | Best val score 0.028497 | DKL-prior 0.000207 | MSE 0.028290 |:  29%|██▉       | 291/1000 [07:51<19:07,  1.62s/it]
[VAL] Best epoch 332 | Best val score 0.028016 | DKL-prior 0.000209 | MSE 0.027807 |:  36%|███▌      | 361/1000 [09:46<17:19,  1.63s/it]
[VAL] Best epoch 374 | Best val score 0.027719 | DKL-prior 0.000211 | MSE 0.027508 |:  40%|████      | 403/1000 [10:55<16:11,  1.63s/it]
[VAL] Best epoch 327 | Best val score 0.028083 | DKL-prior 0.000205 | MSE 0.027878 |:  36%|███▌      | 356/1000 [09:34<17:19,  1.61s/it]
[VAL] Best epoch 300 | Best val score 0.028713 | DKL-prior 0.000210 | MSE 0.028503 |:  33%|███▎      | 329/1000 [09:01<18:24,  1.65s/it]


# Max Pooling of Experts

In [44]:
joint_posterior = "max_pool"

In [45]:
df = pd.read_csv(DATA_BASE + 'netmhcIIpan4.csv')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.{joint_posterior}.pMHC.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 149 | Best val score 0.033145 | DKL-prior 0.000225 | MSE 0.032920 |:  18%|█▊        | 178/1000 [04:58<22:56,  1.67s/it]
[VAL] Best epoch 141 | Best val score 0.035133 | DKL-prior 0.000220 | MSE 0.034913 |:  17%|█▋        | 170/1000 [04:52<23:48,  1.72s/it]
[VAL] Best epoch 109 | Best val score 0.034712 | DKL-prior 0.000215 | MSE 0.034498 |:  14%|█▍        | 138/1000 [03:53<24:19,  1.69s/it]
[VAL] Best epoch 181 | Best val score 0.033408 | DKL-prior 0.000206 | MSE 0.033202 |:  21%|██        | 210/1000 [05:57<22:25,  1.70s/it]
[VAL] Best epoch 113 | Best val score 0.035078 | DKL-prior 0.000206 | MSE 0.034872 |:  14%|█▍        | 142/1000 [04:05<24:41,  1.73s/it]


In [46]:
RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/results/"
FIGURES_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/figures/"

predictions_files = [
    ('MVIB', [pd.read_csv(RESULTS_BASE + f"mvib.poe.pMHC.rep-{i}.csv") for i in range(5)]),
    ('AvgPOOLoE', [pd.read_csv(RESULTS_BASE + f"mvib.avg_pool.pMHC.rep-{i}.csv") for i in range(5)]),
    ('MaxPOOLeE', [pd.read_csv(RESULTS_BASE + f"mvib.max_pool.pMHC.rep-{i}.csv") for i in range(5)]),
    ('AVIB', [pd.read_csv(RESULTS_BASE + f"mvib.aoe.pMHC.rep-{i}.csv") for i in range(5)]),
]

In [47]:
results = []

for i in tqdm(range(5)):
    for predictions_file in predictions_files:
        prediction_df = predictions_file[1][i]
        if f'prediction_{i}' in prediction_df.columns:
            scores_df = get_scores(
                y_true=prediction_df['BA'].to_numpy(), 
                y_pred=prediction_df[f'prediction_{i}'].to_numpy(),
            )
            scores_df['Model'] = predictions_file[0]
            results.append(scores_df)
        
results_df = pd.concat(results).rename(columns={'metrics': 'Metrics', 'score': 'Score'})



100%|██████████| 5/5 [00:00<00:00, 118.27it/s]


In [48]:
results_df.groupby(['Metrics', 'Model']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Score
Metrics,Model,Unnamed: 2_level_1
MSE,AVIB,0.029876
MSE,AvgPOOLoE,0.032929
MSE,MVIB,0.03127
MSE,MaxPOOLeE,0.036157
R2,AVIB,0.558945
R2,AvgPOOLoE,0.513889
R2,MVIB,0.538357
R2,MaxPOOLeE,0.46621
RMSE,AVIB,0.133265
RMSE,AvgPOOLoE,0.14031


In [49]:
std_df = results_df.groupby(['Metrics', 'Model']).std()
std_df['Score'] = std_df['Score'].apply(lambda x: x / 5)
std_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Score
Metrics,Model,Unnamed: 2_level_1
MSE,AVIB,7.1e-05
MSE,AvgPOOLoE,0.000245
MSE,MVIB,7.4e-05
MSE,MaxPOOLeE,0.000104
R2,AVIB,0.00112
R2,AvgPOOLoE,0.003546
R2,MVIB,0.001231
R2,MaxPOOLeE,0.00157
RMSE,AVIB,0.000323
RMSE,AvgPOOLoE,0.000449
