# peptide-MHC binding affinity regression

In this notebook, we train and test the Attentive Variational Information Bottleneck on peptide+MHC class II data to predict binding affinity. We also do experiments with the baseline and ablation methods.

Run dataset.pMHC.ipynb to prepare the input data. *data/mhc/NetMHCIIpan_train/netmhcIIpan4.csv*

In [1]:
import pandas as pd
import torch
import numpy as np
import random

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pandas as pd
import torch

metrics = ['MSE', 'RMSE', 'R2']

def get_scores(y_true, y_pred):
    """
    Compute a df with all regression metrics and respective scores.
    """
    
    scores = [
        mean_squared_error(y_true, y_pred),
        mean_absolute_error(y_true, y_pred),
        r2_score(y_true, y_pred),
    ]
    
    df = pd.DataFrame(data={'score': scores, 'metrics': metrics})
    return df

In [3]:
def set_random_seed(random_seed):
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)


In [4]:
import os
import pandas as pd
login = os.getlogin( )
#DATA_ROOT = f"/home/{login}/Git/tcr/data/mhc/NetMHCIIpan_train/"
#RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/results/"
# To run in github checkout of vibtcr, after `unzip data.zip` ...
RESULTS_BASE = os.path.join('.', 'results')
#FIGURES_BASE = os.path.join('.', 'figures')
DATA_BASE = os.path.join('..', '..', 'data','mhc','NetMHCIIpan_train')

In [5]:
device = torch.device('cuda:1')

batch_size = 8192
epochs = 1000
lr = 1e-3

z_dim = 150
early_stopper_patience = 30
monitor = 'loss'
lr_scheduler_param = 10
loss = "mse"

beta = 1e-6

# PoE

In [6]:
joint_posterior = "poe"
df = pd.read_csv(os.path.join(DATA_BASE, 'netmhcIIpan4.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc',
                        cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc',
                         cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc',
                          cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc',
                        cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.{joint_posterior}.pMHC.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test (why on cpu?)
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        os.path.join(RESULTS_BASE, f"{run_name}.csv"),
        index=False
    )

[VAL] Best epoch 411 | Best val score 0.025509 | DKL-prior 0.000181 | MSE 0.025328 |:  44%|████▍     | 440/1000 [18:44<23:50,  2.56s/it]


Saving best model: epoch 411


[VAL] Best epoch 336 | Best val score 0.025809 | DKL-prior 0.000173 | MSE 0.025636 |:  36%|███▋      | 365/1000 [15:33<27:03,  2.56s/it]


Saving best model: epoch 336


[VAL] Best epoch 410 | Best val score 0.025608 | DKL-prior 0.000181 | MSE 0.025427 |:  44%|████▍     | 439/1000 [18:48<24:01,  2.57s/it]


Saving best model: epoch 410


[VAL] Best epoch 336 | Best val score 0.026140 | DKL-prior 0.000171 | MSE 0.025970 |:  36%|███▋      | 365/1000 [15:38<27:12,  2.57s/it]


Saving best model: epoch 336


[VAL] Best epoch 468 | Best val score 0.026004 | DKL-prior 0.000187 | MSE 0.025817 |:  50%|████▉     | 497/1000 [21:06<21:22,  2.55s/it]


Saving best model: epoch 468


# AoE

In [7]:
joint_posterior = "aoe"

In [8]:
df = pd.read_csv(os.path.join(DATA_BASE, 'netmhcIIpan4.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.{joint_posterior}.pMHC.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        os.path.join(RESULTS_BASE, f"{run_name}.csv"),
        index=False
    )

[VAL] Best epoch 329 | Best val score 0.024496 | DKL-prior 0.000252 | MSE 0.024245 |:  36%|███▌      | 358/1000 [17:22<31:08,  2.91s/it]


Saving best model: epoch 329


[VAL] Best epoch 289 | Best val score 0.024881 | DKL-prior 0.000212 | MSE 0.024669 |:  32%|███▏      | 318/1000 [15:25<33:04,  2.91s/it]


Saving best model: epoch 289


[VAL] Best epoch 184 | Best val score 0.025436 | DKL-prior 0.000238 | MSE 0.025199 |:  21%|██▏       | 213/1000 [10:18<38:06,  2.91s/it]


Saving best model: epoch 184


[VAL] Best epoch 282 | Best val score 0.025344 | DKL-prior 0.000206 | MSE 0.025138 |:  31%|███       | 311/1000 [14:57<33:07,  2.88s/it]


Saving best model: epoch 282


[VAL] Best epoch 240 | Best val score 0.026203 | DKL-prior 0.000255 | MSE 0.025947 |:  27%|██▋       | 269/1000 [12:52<34:59,  2.87s/it]


Saving best model: epoch 240


# Average Pooling of Experts

In [9]:
joint_posterior = "avg_pool"

In [10]:
df = pd.read_csv(os.path.join(DATA_BASE, 'netmhcIIpan4.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    # unfortunately, the trainer did not auto-save the models.
    run_name = f"mvib.{joint_posterior}.pMHC.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        os.path.join(RESULTS_BASE, f"{run_name}.csv"),
        index=False
    )

[VAL] Best epoch 599 | Best val score 0.026412 | DKL-prior 0.000303 | MSE 0.026109 |:  63%|██████▎   | 628/1000 [26:03<15:26,  2.49s/it]


Saving best model: epoch 599


[VAL] Best epoch 377 | Best val score 0.027039 | DKL-prior 0.000259 | MSE 0.026780 |:  41%|████      | 406/1000 [16:50<24:38,  2.49s/it]


Saving best model: epoch 377


[VAL] Best epoch 365 | Best val score 0.027153 | DKL-prior 0.000252 | MSE 0.026901 |:  39%|███▉      | 394/1000 [16:18<25:04,  2.48s/it]


Saving best model: epoch 365


[VAL] Best epoch 487 | Best val score 0.027152 | DKL-prior 0.000279 | MSE 0.026873 |:  52%|█████▏    | 516/1000 [21:22<20:03,  2.49s/it]


Saving best model: epoch 487


[VAL] Best epoch 480 | Best val score 0.027323 | DKL-prior 0.000280 | MSE 0.027043 |:  51%|█████     | 509/1000 [21:05<20:21,  2.49s/it]


Saving best model: epoch 480


# Max Pooling of Experts

In [11]:
joint_posterior = "max_pool"

In [12]:
df = pd.read_csv(os.path.join(DATA_BASE, 'netmhcIIpan4.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)

    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, gt_col='BA').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')

    df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True
    )
    
    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='mhc', cdr3a_col=None, scaler=scaler, gt_col='BA')
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        shuffle=True
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
        loss=loss
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.{joint_posterior}.pMHC.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./',
                            filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        os.path.join(RESULTS_BASE, f"{run_name}.csv"),
        index=False
    )

[VAL] Best epoch 156 | Best val score 0.033577 | DKL-prior 0.000188 | MSE 0.033389 |:  18%|█▊        | 185/1000 [08:13<36:16,  2.67s/it]


Saving best model: epoch 156


[VAL] Best epoch 139 | Best val score 0.035237 | DKL-prior 0.000188 | MSE 0.035050 |:  17%|█▋        | 168/1000 [07:27<36:58,  2.67s/it]


Saving best model: epoch 139


[VAL] Best epoch 111 | Best val score 0.034921 | DKL-prior 0.000184 | MSE 0.034738 |:  14%|█▍        | 140/1000 [06:13<38:14,  2.67s/it]


Saving best model: epoch 111


[VAL] Best epoch 77 | Best val score 0.034794 | DKL-prior 0.000177 | MSE 0.034617 |:  11%|█         | 106/1000 [04:44<39:55,  2.68s/it]


Saving best model: epoch 77


[VAL] Best epoch 149 | Best val score 0.033970 | DKL-prior 0.000198 | MSE 0.033772 |:  18%|█▊        | 178/1000 [07:56<36:40,  2.68s/it]


Saving best model: epoch 149


In [13]:
#RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/results/"
#FIGURES_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.regression/figures/"
# To run in github checkout of vibtcr, after `unzip data.zip` ...
RESULTS_BASE = os.path.join('.', 'results')
FIGURES_BASE = os.path.join('.', 'figures')
#DATA_BASE = os.path.join('..', '..', 'data')
predictions_files = [
    ('MVIB',      [pd.read_csv(os.path.join(RESULTS_BASE, f"mvib.poe.pMHC.rep-{i}.csv")) for i in range(5)]),
    ('AvgPOOLoE', [pd.read_csv(os.path.join(RESULTS_BASE, f"mvib.avg_pool.pMHC.rep-{i}.csv")) for i in range(5)]),
    ('MaxPOOLeE', [pd.read_csv(os.path.join(RESULTS_BASE, f"mvib.max_pool.pMHC.rep-{i}.csv")) for i in range(5)]),
    ('AVIB',      [pd.read_csv(os.path.join(RESULTS_BASE, f"mvib.aoe.pMHC.rep-{i}.csv")) for i in range(5)]),
]

In [14]:
results = []

for i in tqdm(range(5)):
    for predictions_file in predictions_files:
        prediction_df = predictions_file[1][i]
        if f'prediction_{i}' in prediction_df.columns:
            scores_df = get_scores(
                y_true=prediction_df['BA'].to_numpy(), 
                y_pred=prediction_df[f'prediction_{i}'].to_numpy(),
            )
            scores_df['Model'] = predictions_file[0]
            results.append(scores_df)
        
results_df = pd.concat(results).rename(columns={'metrics': 'Metrics', 'score': 'Score'})



100%|██████████| 5/5 [00:00<00:00, 167.85it/s]


In [15]:
results_df.groupby(['Metrics', 'Model']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,Score
Metrics,Model,Unnamed: 2_level_1
MSE,AVIB,0.029553
MSE,AvgPOOLoE,0.031083
MSE,MVIB,0.030397
MSE,MaxPOOLeE,0.035646
R2,AVIB,0.563701
R2,AvgPOOLoE,0.541144
R2,MVIB,0.551252
R2,MaxPOOLeE,0.473784
RMSE,AVIB,0.131189
RMSE,AvgPOOLoE,0.136195


In [16]:
std_df = results_df.groupby(['Metrics', 'Model']).std()
std_df['Score'] = std_df['Score'].apply(lambda x: x / 5) # Really?
std_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Score
Metrics,Model,Unnamed: 2_level_1
MSE,AVIB,6.2e-05
MSE,AvgPOOLoE,8.6e-05
MSE,MVIB,5.8e-05
MSE,MaxPOOLeE,0.000152
R2,AVIB,0.00116
R2,AvgPOOLoE,0.001024
R2,MVIB,0.000814
R2,MaxPOOLeE,0.00197
RMSE,AVIB,0.000165
RMSE,AvgPOOLoE,0.000138


In [17]:
# TODO: matplotlib figures ?