# Attentive Variational Information Bottleneck

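In this notebook, we train and test the Attentive Variational Information Bottleneck (MVIB [1] with Attention of Experts) and MVIB on all datasets. For comparison, the Attention-of-Experts runs are followed by MVIB variants that fuse the experts by max pooling and by average pooling (alpha+beta set only).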

[1] Microbiome-based disease prediction with multimodal variational information bottlenecks, Grazioli et al., https://www.biorxiv.org/node/2109522.external-links.html

In [1]:
import pandas as pd
import torch
import numpy as np
import random

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, auc
import pandas as pd
import torch

# Row labels of the score table built by `get_scores`. The order here must
# match the order in which `get_scores` computes its scores.
metrics = ['auROC', 'Accuracy', 'Recall', 'Precision', 'F1 score', 'auPRC']

def pr_auc(y_true, y_prob):
    """
    Area under the precision-recall curve.

    The curve comes from `precision_recall_curve` and is integrated with
    `sklearn.metrics.auc`, with recall on the x axis and precision on the y axis.
    """
    prec, rec, _ = precision_recall_curve(y_true, y_prob)
    return auc(rec, prec)

def get_scores(y_true, y_prob, y_pred):
    """
    Build a table with one row per classification metric.

    auROC and auPRC are computed from the scores `y_prob`. Accuracy, recall,
    precision and F1 are computed from the hard labels `y_pred`. Rows follow
    the order of the module-level `metrics` list.

    Returns a DataFrame with columns 'score' and 'metrics'.
    """
    threshold_metrics = (accuracy_score, recall_score, precision_score, f1_score)

    values = [roc_auc_score(y_true, y_prob)]
    values.extend(fn(y_true, y_pred) for fn in threshold_metrics)
    values.append(pr_auc(y_true, y_prob))

    return pd.DataFrame(data={'score': values, 'metrics': metrics})

In [3]:
def set_random_seed(random_seed):
    """
    Seed Python's `random`, NumPy's global RNG and torch (CPU and all CUDA
    devices) with the same value, so that repeated calls reproduce the same
    random streams.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

In [4]:
import getpass
import os

# os.getlogin() raises OSError when the process has no controlling terminal
# (e.g. some Jupyter servers, containers, cron). getpass.getuser() reads
# LOGNAME/USER/LNAME/USERNAME and falls back to the password database.
login = getpass.getuser()

# Base directories. They can be overridden through environment variables; the
# defaults reproduce the original /home/<user>/Git/tcr layout.
DATA_BASE = os.environ.get("TCR_DATA_BASE", f"/home/{login}/Git/tcr/data/")
RESULTS_BASE = os.environ.get(
    "TCR_RESULTS_BASE",
    f"/home/{login}/Git/tcr/notebooks/notebooks.classification/results/",
)

In [5]:
# Training device. Uses the first GPU when CUDA is available; otherwise falls
# back to CPU so the notebook still runs (much more slowly) on CPU-only hosts.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size = 4096
epochs = 500          # upper bound; early stopping usually ends training sooner
lr = 1e-3

z_dim = 150
early_stopper_patience = 50
monitor = 'auROC'     # validation metric passed to trainer.train for early stopping
lr_scheduler_param = 10
joint_posterior = "aoe"   # fusion of the per-modality experts (Attention of Experts)

beta = 1e-6

# alpha+beta set - peptide+CDR3b

In [None]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the global `joint_posterior`, which later cells reassign, so re-running this
# cell out of order still trains and saves the intended model.
posterior = "aoe"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

# alpha+beta set - peptide+CDR3b+CDR3a 

In [18]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the global `joint_posterior`, which later cells reassign, so re-running this
# cell out of order still trains and saves the intended model.
posterior = "aoe"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.trimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 80 | Score -0.935698 | DKL-prior 0.000373 | BCE 0.549172 | auROC 0.9357:  26%|██▌       | 129/500 [09:02<26:00,  4.21s/it]
[VAL] Best epoch 156 | Score -0.935239 | DKL-prior 0.000369 | BCE 0.870948 | auROC 0.9352:  41%|████      | 205/500 [14:21<20:39,  4.20s/it]
[VAL] Best epoch 104 | Score -0.938586 | DKL-prior 0.000363 | BCE 0.602351 | auROC 0.9386:  31%|███       | 153/500 [10:48<24:29,  4.24s/it]
[VAL] Best epoch 120 | Score -0.935417 | DKL-prior 0.000336 | BCE 0.739175 | auROC 0.9354:  34%|███▍      | 169/500 [11:50<23:11,  4.20s/it]
[VAL] Best epoch 137 | Score -0.941222 | DKL-prior 0.000359 | BCE 0.763292 | auROC 0.9412:  37%|███▋      | 186/500 [13:10<22:14,  4.25s/it]


# alpha+beta set - peptide+CDR3a 

In [17]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the global `joint_posterior`, which later cells reassign, so re-running this
# cell out of order still trains and saves the intended model.
posterior = "aoe"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # we pass column `tcra` to `cdr3b_col` because TCRDataset expects to have the CDR3b attribute.
    # The scaler comes from a dataset built on the training portion only and is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcra').scaler
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcra', scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    # `ds_test.cdr3b` holds the CDR3a sequences here (see the dataset setup above).
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal-alpha.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 78 | Best val score -0.924754 | DKL-prior 0.000436 | BCE 0.562571 | auROC 0.9248:  25%|██▌       | 127/500 [03:33<10:25,  1.68s/it]
[VAL] Best epoch 96 | Best val score -0.923685 | DKL-prior 0.000542 | BCE 0.679592 | auROC 0.9237:  29%|██▉       | 145/500 [04:04<09:57,  1.68s/it]
[VAL] Best epoch 84 | Best val score -0.928253 | DKL-prior 0.000504 | BCE 0.545874 | auROC 0.9283:  27%|██▋       | 133/500 [03:42<10:14,  1.67s/it]
[VAL] Best epoch 124 | Best val score -0.925423 | DKL-prior 0.000544 | BCE 0.705228 | auROC 0.9254:  35%|███▍      | 173/500 [04:49<09:07,  1.67s/it]
[VAL] Best epoch 63 | Best val score -0.927728 | DKL-prior 0.000494 | BCE 0.450250 | auROC 0.9277:  22%|██▏       | 112/500 [03:10<10:59,  1.70s/it]


# beta set 

In [14]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the global `joint_posterior`, which later cells reassign, so re-running this
# cell out of order still trains and saves the intended model.
posterior = "aoe"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal.{posterior}.beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 357 | Score -0.831115 | DKL-prior 0.000416 | BCE 0.816526 | auROC 0.8311:  81%|████████  | 406/500 [15:39<03:37,  2.31s/it]
[VAL] Best epoch 216 | Score -0.835735 | DKL-prior 0.000381 | BCE 0.755226 | auROC 0.8357:  53%|█████▎    | 265/500 [10:06<08:58,  2.29s/it]
[VAL] Best epoch 241 | Score -0.837941 | DKL-prior 0.000371 | BCE 0.718900 | auROC 0.8379:  58%|█████▊    | 290/500 [11:16<08:09,  2.33s/it]
[VAL] Best epoch 197 | Score -0.838721 | DKL-prior 0.000365 | BCE 0.695502 | auROC 0.8387:  49%|████▉     | 246/500 [09:24<09:42,  2.30s/it]
[VAL] Best epoch 313 | Score -0.840593 | DKL-prior 0.000391 | BCE 0.760762 | auROC 0.8406:  72%|███████▏  | 362/500 [13:50<05:16,  2.30s/it]


# full set: alpha+beta set + beta set

In [15]:
df1 = pd.read_csv(DATA_BASE + 'alpha-beta-splits/beta.csv')
df2 = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')
# concat keeps each file's own 0..n-1 index, so the index has duplicates; rebuild
# it. drop=True stops the old index from becoming a spurious `index` column that
# would end up in the saved CSVs (the other result files do not have it).
df = pd.concat([df1, df2]).reset_index(drop=True)

# Fusion strategy for this experiment. It is pinned here rather than read from
# the global `joint_posterior`, which later cells reassign, so re-running this
# cell out of order still trains and saves the intended model.
posterior = "aoe"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal.{posterior}.full.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 142 | Score -0.809937 | DKL-prior 0.000366 | BCE 0.643406 | auROC 0.8099:  38%|███▊      | 191/500 [12:56<20:56,  4.07s/it]
[VAL] Best epoch 334 | Score -0.813966 | DKL-prior 0.000441 | BCE 0.724015 | auROC 0.8140:  77%|███████▋  | 383/500 [25:35<07:49,  4.01s/it]
[VAL] Best epoch 279 | Score -0.819329 | DKL-prior 0.000404 | BCE 0.655243 | auROC 0.8193:  66%|██████▌   | 328/500 [22:09<11:37,  4.05s/it]
[VAL] Best epoch 264 | Score -0.817222 | DKL-prior 0.000392 | BCE 0.666876 | auROC 0.8172:  63%|██████▎   | 313/500 [21:12<12:40,  4.06s/it]
[VAL] Best epoch 323 | Score -0.814560 | DKL-prior 0.000439 | BCE 0.730162 | auROC 0.8146:  74%|███████▍  | 372/500 [25:02<08:37,  4.04s/it]


# Max pooling of experts

In [6]:
# Global switch of the expert-fusion strategy to max pooling. The training cells
# below read `joint_posterior` for MVIB(joint_posterior=...) and for the result
# file names. This overwrites the value set in the configuration cell, so any
# cell that reads `joint_posterior` and is re-run after this one uses "max_pool".
joint_posterior = "max_pool"

# alpha+beta set - peptide+CDR3b (max pooling of experts)

In [7]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the mutable global `joint_posterior`, so re-running this cell out of order
# still trains and saves the intended model.
posterior = "max_pool"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 145 | Score -0.910260 | DKL-prior 0.000527 | BCE 0.478257 | auROC 0.9103:  39%|███▉      | 194/500 [05:01<07:55,  1.56s/it]
[VAL] Best epoch 165 | Score -0.911694 | DKL-prior 0.000634 | BCE 0.503740 | auROC 0.9117:  43%|████▎     | 214/500 [05:33<07:25,  1.56s/it]
[VAL] Best epoch 167 | Score -0.911872 | DKL-prior 0.000543 | BCE 0.538788 | auROC 0.9119:  43%|████▎     | 216/500 [05:38<07:25,  1.57s/it]
[VAL] Best epoch 125 | Score -0.907811 | DKL-prior 0.000411 | BCE 0.527814 | auROC 0.9078:  35%|███▍      | 174/500 [04:41<08:47,  1.62s/it]
[VAL] Best epoch 129 | Score -0.909856 | DKL-prior 0.000478 | BCE 0.445229 | auROC 0.9099:  36%|███▌      | 178/500 [04:39<08:25,  1.57s/it]


# alpha+beta set - peptide+CDR3b+CDR3a  (max pooling of experts)

In [8]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the mutable global `joint_posterior`, so re-running this cell out of order
# still trains and saves the intended model.
posterior = "max_pool"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.trimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 63 | Score -0.924964 | DKL-prior 0.000451 | BCE 0.387998 | auROC 0.9250:  22%|██▏       | 112/500 [08:22<29:00,  4.49s/it]
[VAL] Best epoch 101 | Score -0.930486 | DKL-prior 0.000590 | BCE 0.452560 | auROC 0.9305:  30%|███       | 150/500 [11:09<26:01,  4.46s/it]
[VAL] Best epoch 73 | Score -0.932433 | DKL-prior 0.000551 | BCE 0.414454 | auROC 0.9324:  24%|██▍       | 122/500 [09:05<28:09,  4.47s/it]
[VAL] Best epoch 87 | Score -0.934378 | DKL-prior 0.000548 | BCE 0.447056 | auROC 0.9344:  27%|██▋       | 136/500 [09:56<26:37,  4.39s/it]
[VAL] Best epoch 203 | Score -0.935721 | DKL-prior 0.000796 | BCE 0.504529 | auROC 0.9357:  50%|█████     | 252/500 [18:36<18:19,  4.43s/it]


# Average pooling of experts

In [9]:
# Global switch of the expert-fusion strategy to average pooling. The training
# cells below read `joint_posterior` for MVIB(joint_posterior=...) and for the
# result file names. This overwrites any earlier value, so any cell that reads
# `joint_posterior` and is re-run after this one uses "avg_pool".
joint_posterior = "avg_pool"

# alpha+beta set - peptide+CDR3b (average pooling of experts)

In [11]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the mutable global `joint_posterior`, so re-running this cell out of order
# still trains and saves the intended model.
posterior = "avg_pool"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.bimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 80 | Score -0.906532 | DKL-prior 0.000477 | BCE 0.529231 | auROC 0.9065:  26%|██▌       | 129/500 [03:11<09:11,  1.49s/it]
[VAL] Best epoch 82 | Score -0.908556 | DKL-prior 0.000472 | BCE 0.513567 | auROC 0.9086:  26%|██▌       | 131/500 [03:16<09:13,  1.50s/it]
[VAL] Best epoch 99 | Score -0.909954 | DKL-prior 0.000468 | BCE 0.533864 | auROC 0.9100:  30%|██▉       | 148/500 [03:42<08:48,  1.50s/it]
[VAL] Best epoch 125 | Score -0.906920 | DKL-prior 0.000465 | BCE 0.675006 | auROC 0.9069:  35%|███▍      | 174/500 [04:22<08:10,  1.51s/it]
[VAL] Best epoch 141 | Score -0.910486 | DKL-prior 0.000454 | BCE 0.726459 | auROC 0.9105:  38%|███▊      | 190/500 [05:08<08:23,  1.62s/it]


# alpha+beta set - peptide+CDR3b+CDR3a  (average pooling of experts)

In [12]:
df = pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv')

# Fusion strategy for this experiment. It is pinned here rather than read from
# the mutable global `joint_posterior`, so re-running this cell out of order
# still trains and saves the intended model.
posterior = "avg_pool"

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # Hold out 20% of the data as the test set (random split, not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Take the scaler from a dataset built on the training portion only; it is
    # reused for the test/train/val datasets below.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # Carve a stratified 20% validation set (used for early stopping) out of the training portion.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class (`sign` 0/1), so both classes are drawn equally often in expectation.
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling (same inverse-class-frequency weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # Presumably returns the checkpoint selected on `monitor` by early stopping — TODO confirm.
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint on CPU and score the held-out set. no_grad()
    # avoids recording an autograd graph over the entire test set.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(
        RESULTS_BASE + f"mvib.trimodal.{posterior}.alpha+beta-only.rep-{i}.csv",
        index=False
    )

[VAL] Best epoch 148 | Score -0.927377 | DKL-prior 0.000564 | BCE 0.684856 | auROC 0.9274:  39%|███▉      | 197/500 [13:49<21:15,  4.21s/it]
[VAL] Best epoch 115 | Score -0.925632 | DKL-prior 0.000579 | BCE 0.537084 | auROC 0.9256:  33%|███▎      | 164/500 [11:28<23:31,  4.20s/it]
[VAL] Best epoch 195 | Score -0.933327 | DKL-prior 0.000555 | BCE 0.697255 | auROC 0.9333:  49%|████▉     | 244/500 [16:59<17:49,  4.18s/it]
[VAL] Best epoch 120 | Score -0.927332 | DKL-prior 0.000567 | BCE 0.632198 | auROC 0.9273:  34%|███▍      | 169/500 [11:51<23:13,  4.21s/it]
[VAL] Best epoch 128 | Score -0.936024 | DKL-prior 0.000581 | BCE 0.537664 | auROC 0.9360:  35%|███▌      | 177/500 [12:33<22:54,  4.26s/it]
