# Attentive Variational Information Bottleneck

In this notebook, we train and test the Attentive Variational Information Bottleneck (MVIB [1] with Attention of Experts) and MVIB on all datasets.

[1] Microbiome-based disease prediction with multimodal variational information bottlenecks, Grazioli et al., https://www.biorxiv.org/node/2109522.external-links.html

In [1]:
import pandas as pd
import torch
import numpy as np
import random

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, auc
import pandas as pd
import torch

# Display names of the reported metrics.
# The order must stay aligned with the list of scores built in `get_scores`.
metrics = ['auROC', 'Accuracy', 'Recall', 'Precision', 'F1 score', 'auPRC']

def pr_auc(y_true, y_prob):
    """
    Area under the precision-recall curve.

    The curve is obtained from `precision_recall_curve(y_true, y_prob)` and
    integrated with `auc`, using recall as x and precision as y.
    """
    # The third return value (thresholds) is not needed to compute the area.
    prec, rec, _ = precision_recall_curve(y_true, y_prob)
    return auc(rec, prec)

def get_scores(y_true, y_prob, y_pred):
    """
    Build a two-column DataFrame ('score', 'metrics') with one row per
    classification metric.

    `y_prob` is used for the curve-based metrics (auROC, auPRC) and `y_pred`
    for accuracy, recall, precision and F1. Rows follow the order of the
    module-level `metrics` list.
    """
    # Computed in the same order as the names in `metrics`.
    auroc = roc_auc_score(y_true, y_prob)
    accuracy = accuracy_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auprc = pr_auc(y_true, y_prob)

    values = [auroc, accuracy, recall, precision, f1, auprc]
    return pd.DataFrame(data={'score': values, 'metrics': metrics})

In [3]:
def set_random_seed(random_seed):
    """
    Seed every random number generator used in this notebook with the same
    value: Python's `random`, NumPy's global generator, and PyTorch (CPU,
    current CUDA device and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

In [4]:
import os
# Locations of the input data and of the files written by this notebook.
#
# `login` is kept for backward compatibility. os.getlogin() needs a controlling
# terminal and raises OSError without one (e.g. when the kernel runs under a
# Jupyter server, in a container or in CI), so fall back to the environment.
try:
    login = os.getlogin()
except OSError:
    login = os.environ.get('USER') or os.environ.get('USERNAME') or ''

# Paths are relative to this notebook's directory inside a github checkout of
# vibtcr, after `unzip data.zip`. A per-user absolute layout
# (/home/<login>/Git/tcr/...) used to be assigned here as well, but it was
# overwritten immediately and therefore never used.
RESULTS_BASE = os.path.join('.', 'results')
DATA_BASE = os.path.join('..', '..', 'data')

# Checkpoints (.pth) and prediction tables (.csv) are written into
# RESULTS_BASE; make sure the directory exists before training starts.
os.makedirs(RESULTS_BASE, exist_ok=True)

In [5]:
# Training device: first GPU when CUDA is available, otherwise CPU so that the
# notebook still runs (slowly) on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Optimisation settings shared by every training cell below.
batch_size = 4096
epochs = 500  # upper bound; the logged runs stopped earlier via early stopping
lr = 1e-3

z_dim = 150  # passed to MVIB(z_dim=...)
early_stopper_patience = 50  # passed to trainer.train(...)
monitor = 'auROC'  # validation metric passed to trainer.train(...)
lr_scheduler_param = 10  # passed to TrainerMVIB(lr_scheduler_param=...); semantics defined there

# Passed to TrainerMVIB(beta=...); presumably the weight of the KL term in the
# information-bottleneck loss -- TODO confirm in TrainerMVIB.
beta = 1e-6

# AVIB: multimodal pooling via attention of experts (aoe)

In [11]:
#
# NOTE: This notebook runs several choices for "joint_posterior"
#       Class BaseVIB supports:
#         'poe' ~ "product of experts" ~ "MVIB" (multimodal) original paper
#         'aoe' ~ "attention of experts" ~ "AVIB" (attentive)
#         'max_pool'
#         'avg_pool'
#       The value set here is passed to MVIB(joint_posterior=...) by every
#       training cell that runs before it is reassigned, and it is also
#       embedded in the names of the result files.
joint_posterior = "aoe"

# alpha+beta set - peptide+CDR3b (AVIB)

In [6]:
# Bimodal run (peptide + CDR3b) on the alpha+beta dataset, using the currently
# selected `joint_posterior`. For each of 5 seeds: split the data, train,
# save the checkpoint, score the held-out test set and save the predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    # Rebuild the model from the checkpoint on CPU and score the test set;
    # cdr3a=None because the alpha chain is not used in this run.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 115 | Best val score -0.913730 | DKL-prior 0.000494 | BCE 0.811390 | auROC 0.9137:  33%|████████████████████████████████████████████████████████▍                                                                                                                   | 164/500 [07:37<15:38,  2.79s/it]


Saving best model: epoch 115


[VAL] Best epoch 49 | Best val score -0.915488 | DKL-prior 0.000439 | BCE 0.505180 | auROC 0.9155:  20%|██████████████████████████████████                                                                                                                                            | 98/500 [04:54<20:08,  3.01s/it]


Saving best model: epoch 49


[VAL] Best epoch 87 | Best val score -0.918312 | DKL-prior 0.000490 | BCE 0.616983 | auROC 0.9183:  27%|███████████████████████████████████████████████                                                                                                                              | 136/500 [06:58<18:39,  3.08s/it]


Saving best model: epoch 87


[VAL] Best epoch 75 | Best val score -0.913309 | DKL-prior 0.000497 | BCE 0.581592 | auROC 0.9133:  25%|██████████████████████████████████████████▉                                                                                                                                  | 124/500 [06:21<19:15,  3.07s/it]


Saving best model: epoch 75


[VAL] Best epoch 91 | Best val score -0.919976 | DKL-prior 0.000475 | BCE 0.610277 | auROC 0.9200:  28%|████████████████████████████████████████████████▍                                                                                                                            | 140/500 [07:11<18:29,  3.08s/it]


Saving best model: epoch 91


# alpha+beta set - peptide+CDR3b+CDR3a (AVIB)

In [7]:
# Trimodal run (peptide + CDR3b + CDR3a) on the alpha+beta dataset, using the
# currently selected `joint_posterior`. For each of 5 seeds: split the data,
# train, save the checkpoint, score the held-out test set and save predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    # Both chains are read here: `tcrb` as CDR3b and `tcra` as CDR3a.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # mode="trimodal": the trainer is told that three inputs are in use.
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test
    # Rebuild the model from the checkpoint on CPU and score the test set with
    # all three inputs.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 78 | Best val score -0.935460 | DKL-prior 0.000474 | BCE 0.635196 | auROC 0.9355:  25%|███████████████████████████████████████████▍                                                                                                                               | 127/500 [20:39<1:00:41,  9.76s/it]


Saving best model: epoch 78


[VAL] Best epoch 52 | Best val score -0.935776 | DKL-prior 0.000561 | BCE 0.525582 | auROC 0.9358:  20%|██████████████████████████████████▌                                                                                                                                        | 101/500 [15:57<1:03:02,  9.48s/it]


Saving best model: epoch 52


[VAL] Best epoch 66 | Best val score -0.940779 | DKL-prior 0.000573 | BCE 0.502048 | auROC 0.9408:  23%|███████████████████████████████████████▎                                                                                                                                   | 115/500 [17:57<1:00:08,  9.37s/it]


Saving best model: epoch 66


[VAL] Best epoch 78 | Best val score -0.940502 | DKL-prior 0.000384 | BCE 0.578420 | auROC 0.9405:  25%|███████████████████████████████████████████▉                                                                                                                                 | 127/500 [19:27<57:07,  9.19s/it]


Saving best model: epoch 78


[VAL] Best epoch 73 | Best val score -0.943219 | DKL-prior 0.000422 | BCE 0.516536 | auROC 0.9432:  24%|████████████████████████████████████████████▍                                                                                                                                         | 122/500 [18:42<57:59,  9.20s/it]


Saving best model: epoch 73


# alpha+beta set - peptide+CDR3a (AVIB)

In [8]:
# Bimodal run (peptide + CDR3a) on the alpha+beta dataset, using the currently
# selected `joint_posterior`. The alpha chain is fed through the CDR3b slot of
# the dataset and of the model. For each of 5 seeds: split the data, train,
# save the checkpoint, score the held-out test set and save the predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    # `cdr3a_col` is left at TCRDataset's default in this cell.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcra').scaler
    # we pass column `tcra` to `cdr3b_col` because TCRDataset expects to have the CDR3b attribute
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcra', scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcra', scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    # "bimodal-alpha" in the file name marks this as the peptide + CDR3a variant.
    run_name = f"mvib.bimodal-alpha.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test
    # Rebuild the model from the checkpoint on CPU and score the test set;
    # `ds_test.cdr3b` holds the data read from the `tcra` column (see above).
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 84 | Best val score -0.924685 | DKL-prior 0.000489 | BCE 0.561063 | auROC 0.9247:  27%|████████████████████████████████████████████████▍                                                                                                                                     | 133/500 [06:11<17:04,  2.79s/it]


Saving best model: epoch 84


[VAL] Best epoch 52 | Best val score -0.919000 | DKL-prior 0.000524 | BCE 0.511463 | auROC 0.9190:  20%|████████████████████████████████████▊                                                                                                                                                 | 101/500 [04:41<18:33,  2.79s/it]


Saving best model: epoch 52


[VAL] Best epoch 66 | Best val score -0.930090 | DKL-prior 0.000519 | BCE 0.452993 | auROC 0.9301:  23%|█████████████████████████████████████████▊                                                                                                                                            | 115/500 [05:20<17:53,  2.79s/it]


Saving best model: epoch 66


[VAL] Best epoch 160 | Best val score -0.924717 | DKL-prior 0.000513 | BCE 0.894604 | auROC 0.9247:  42%|███████████████████████████████████████████████████████████████████████████▋                                                                                                         | 209/500 [09:40<13:28,  2.78s/it]


Saving best model: epoch 160


[VAL] Best epoch 84 | Best val score -0.929822 | DKL-prior 0.000501 | BCE 0.516446 | auROC 0.9298:  27%|████████████████████████████████████████████████▍                                                                                                                                     | 133/500 [06:11<17:05,  2.79s/it]


Saving best model: epoch 84


# beta set (AVIB)

In [9]:
# Bimodal run (peptide + CDR3b) on the beta dataset (`beta.csv`), using the
# currently selected `joint_posterior`. For each of 5 seeds: split the data,
# train, save the checkpoint, score the held-out test set and save predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.bimodal.{joint_posterior}.beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    # Rebuild the model from the checkpoint on CPU and score the test set;
    # cdr3a=None because the alpha chain is not used in this run.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 174 | Best val score -0.832874 | DKL-prior 0.000496 | BCE 0.800350 | auROC 0.8329:  45%|████████████████████████████████████████████████████████████████████████████████▋                                                                                                    | 223/500 [15:37<19:24,  4.20s/it]


Saving best model: epoch 174


[VAL] Best epoch 288 | Best val score -0.837785 | DKL-prior 0.000621 | BCE 0.872121 | auROC 0.8378:  67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                           | 337/500 [23:14<11:14,  4.14s/it]


Saving best model: epoch 288


[VAL] Best epoch 170 | Best val score -0.834989 | DKL-prior 0.000544 | BCE 0.899107 | auROC 0.8350:  44%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 219/500 [15:14<19:33,  4.18s/it]


Saving best model: epoch 170


[VAL] Best epoch 274 | Best val score -0.846585 | DKL-prior 0.000608 | BCE 0.862442 | auROC 0.8466:  65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                | 323/500 [22:10<12:08,  4.12s/it]


Saving best model: epoch 274


[VAL] Best epoch 172 | Best val score -0.833392 | DKL-prior 0.000653 | BCE 0.797987 | auROC 0.8334:  44%|████████████████████████████████████████████████████████████████████████████████                                                                                                     | 221/500 [15:28<19:31,  4.20s/it]


Saving best model: epoch 172


# full set: alpha+beta set + beta set (AVIB)

In [10]:
# Bimodal run (peptide + CDR3b) on the full dataset, i.e. the rows of
# `beta.csv` followed by the rows of `alpha-beta.csv`, using the currently
# selected `joint_posterior`. For each of 5 seeds: split the data, train,
# save the checkpoint, score the held-out test set and save the predictions.
df1 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))
df2 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
# NOTE(review): reset_index() without drop=True keeps the previous row labels
# as an extra `index` column, which then also ends up in the results CSV.
# Use reset_index(drop=True) if that column is not wanted downstream.
df = pd.concat([df1, df2]).reset_index()

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.bimodal.{joint_posterior}.full.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    # Rebuild the model from the checkpoint on CPU and score the test set;
    # cdr3a=None because the alpha chain is not used in this run.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 180 | Best val score -0.815843 | DKL-prior 0.000677 | BCE 0.755239 | auROC 0.8158:  46%|██████████████████████████████████████████████████████████████████████████████████▉                                                                                                  | 229/500 [27:47<32:53,  7.28s/it]


Saving best model: epoch 180


[VAL] Best epoch 150 | Best val score -0.806158 | DKL-prior 0.000565 | BCE 0.758233 | auROC 0.8062:  40%|████████████████████████████████████████████████████████████████████████                                                                                                             | 199/500 [24:05<36:25,  7.26s/it]


Saving best model: epoch 150


[VAL] Best epoch 215 | Best val score -0.813805 | DKL-prior 0.000642 | BCE 0.753081 | auROC 0.8138:  53%|███████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                     | 264/500 [31:53<28:30,  7.25s/it]


Saving best model: epoch 215


[VAL] Best epoch 162 | Best val score -0.814897 | DKL-prior 0.000554 | BCE 0.699809 | auROC 0.8149:  42%|████████████████████████████████████████████████████████████████████████████▍                                                                                                        | 211/500 [25:29<34:54,  7.25s/it]


Saving best model: epoch 162


[VAL] Best epoch 272 | Best val score -0.816647 | DKL-prior 0.000692 | BCE 0.791097 | auROC 0.8166:  64%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                | 321/500 [38:56<21:42,  7.28s/it]


Saving best model: epoch 272


# MVIB: multimodal pooling via product of experts (poe)

In [11]:
# Switch the joint posterior to "poe" for the training cells that follow.
# The value is passed to MVIB(joint_posterior=...) and embedded in the names of
# the result files, so these runs do not overwrite files written with another
# `joint_posterior` value.
joint_posterior = "poe"

# alpha+beta set - peptide+CDR3b (MVIB)

In [6]:
# Bimodal run (peptide + CDR3b) on the alpha+beta dataset, using the currently
# selected `joint_posterior`. For each of 5 seeds: split the data, train,
# save the checkpoint, score the held-out test set and save the predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))
    
    # test
    # Rebuild the model from the checkpoint on CPU and score the test set;
    # cdr3a=None because the alpha chain is not used in this run.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 115 | Best val score -0.913730 | DKL-prior 0.000494 | BCE 0.811390 | auROC 0.9137:  33%|████████████████████████████████████████████████████████▍                                                                                                                   | 164/500 [07:37<15:38,  2.79s/it]


Saving best model: epoch 115


[VAL] Best epoch 49 | Best val score -0.915488 | DKL-prior 0.000439 | BCE 0.505180 | auROC 0.9155:  20%|██████████████████████████████████                                                                                                                                            | 98/500 [04:54<20:08,  3.01s/it]


Saving best model: epoch 49


[VAL] Best epoch 87 | Best val score -0.918312 | DKL-prior 0.000490 | BCE 0.616983 | auROC 0.9183:  27%|███████████████████████████████████████████████                                                                                                                              | 136/500 [06:58<18:39,  3.08s/it]


Saving best model: epoch 87


[VAL] Best epoch 75 | Best val score -0.913309 | DKL-prior 0.000497 | BCE 0.581592 | auROC 0.9133:  25%|██████████████████████████████████████████▉                                                                                                                                  | 124/500 [06:21<19:15,  3.07s/it]


Saving best model: epoch 75


[VAL] Best epoch 91 | Best val score -0.919976 | DKL-prior 0.000475 | BCE 0.610277 | auROC 0.9200:  28%|████████████████████████████████████████████████▍                                                                                                                            | 140/500 [07:11<18:29,  3.08s/it]


Saving best model: epoch 91


# alpha+beta set - peptide+CDR3b+CDR3a (MVIB)

In [7]:
# Trimodal run (peptide + CDR3b + CDR3a) on the alpha+beta dataset, using the
# currently selected `joint_posterior`. For each of 5 seeds: split the data,
# train, save the checkpoint, score the held-out test set and save predictions.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    # The repetition index doubles as the seed for all RNGs and both splits.
    set_random_seed(i)

    # Hold out 20% of the rows as test set (this first split is not stratified).
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    # Build a dataset on the training portion only to obtain its `.scaler`;
    # that scaler is then reused for the test, train and validation datasets.
    # Both chains are read here: `tcrb` as CDR3b and `tcra` as CDR3a.
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    # The test set stays on CPU because the model is evaluated on CPU below.
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # Carve 20% of the remaining rows off as validation set, stratified on `sign`.
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)
        
    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    # Number of rows with sign == 0 and sign == 1, in that order.
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    # Inverse class frequency: every row is weighted by 1 / (size of its class),
    # so both classes carry the same total weight. `sign` is used as an index
    # into `weight`, hence it must hold the integer labels 0 / 1.
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    # Draws as many rows per epoch as the split contains (with replacement,
    # which is the sampler's default).
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )
    
    # val loader with balanced sampling
    # (same inverse-frequency scheme, so validation batches are resampled too)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # mode="trimodal": the trainer is told that three inputs are in use.
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    # `train` returns the checkpoint that is saved and reloaded below
    # (presumably the best epoch according to `monitor` -- confirm in TrainerMVIB).
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)    
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test
    # Rebuild the model from the checkpoint on CPU and score the test set with
    # all three inputs.
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    # One prediction column per repetition, named after the repetition index.
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 78 | Best val score -0.935460 | DKL-prior 0.000474 | BCE 0.635196 | auROC 0.9355:  25%|███████████████████████████████████████████▍                                                                                                                               | 127/500 [20:39<1:00:41,  9.76s/it]


Saving best model: epoch 78


[VAL] Best epoch 52 | Best val score -0.935776 | DKL-prior 0.000561 | BCE 0.525582 | auROC 0.9358:  20%|██████████████████████████████████▌                                                                                                                                        | 101/500 [15:57<1:03:02,  9.48s/it]


Saving best model: epoch 52


[VAL] Best epoch 66 | Best val score -0.940779 | DKL-prior 0.000573 | BCE 0.502048 | auROC 0.9408:  23%|███████████████████████████████████████▎                                                                                                                                   | 115/500 [17:57<1:00:08,  9.37s/it]


Saving best model: epoch 66


[VAL] Best epoch 78 | Best val score -0.940502 | DKL-prior 0.000384 | BCE 0.578420 | auROC 0.9405:  25%|███████████████████████████████████████████▉                                                                                                                                 | 127/500 [19:27<57:07,  9.19s/it]


Saving best model: epoch 78


[VAL] Best epoch 73 | Best val score -0.943219 | DKL-prior 0.000422 | BCE 0.516536 | auROC 0.9432:  24%|████████████████████████████████████████████▍                                                                                                                                         | 122/500 [18:42<57:59,  9.20s/it]


Saving best model: epoch 73


# alpha+beta set - peptide+CDR3a (MVIB)

In [8]:
# Bimodal experiment on the alpha+beta set: the model only sees the peptide and the CDR3a chain.
alpha_beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv')
df = pd.read_csv(alpha_beta_csv)

cpu_device = torch.device("cpu")
# we pass column `tcra` to `cdr3b_col` because TCRDataset expects to have the CDR3b attribute
# NOTE(review): unlike the other bimodal cells, `cdr3a_col` is left at its TCRDataset default here — confirm intended
tcr_columns = dict(cdr3b_col='tcra')

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.bimodal-alpha.{joint_posterior}.alpha+beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)

[VAL] Best epoch 84 | Best val score -0.924685 | DKL-prior 0.000489 | BCE 0.561063 | auROC 0.9247:  27%|████████████████████████████████████████████████▍                                                                                                                                     | 133/500 [06:11<17:04,  2.79s/it]


Saving best model: epoch 84


[VAL] Best epoch 52 | Best val score -0.919000 | DKL-prior 0.000524 | BCE 0.511463 | auROC 0.9190:  20%|████████████████████████████████████▊                                                                                                                                                 | 101/500 [04:41<18:33,  2.79s/it]


Saving best model: epoch 52


[VAL] Best epoch 66 | Best val score -0.930090 | DKL-prior 0.000519 | BCE 0.452993 | auROC 0.9301:  23%|█████████████████████████████████████████▊                                                                                                                                            | 115/500 [05:20<17:53,  2.79s/it]


Saving best model: epoch 66


[VAL] Best epoch 160 | Best val score -0.924717 | DKL-prior 0.000513 | BCE 0.894604 | auROC 0.9247:  42%|███████████████████████████████████████████████████████████████████████████▋                                                                                                         | 209/500 [09:40<13:28,  2.78s/it]


Saving best model: epoch 160


[VAL] Best epoch 84 | Best val score -0.929822 | DKL-prior 0.000501 | BCE 0.516446 | auROC 0.9298:  27%|████████████████████████████████████████████████▍                                                                                                                                     | 133/500 [06:11<17:05,  2.79s/it]


Saving best model: epoch 84


# beta set (MVIB)

In [9]:
# Bimodal experiment (peptide + CDR3b) on the beta set.
beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv')
df = pd.read_csv(beta_csv)

cpu_device = torch.device("cpu")
# column mapping handed to every TCRDataset in this cell: beta chain only, no alpha chain
tcr_columns = dict(cdr3b_col='tcrb', cdr3a_col=None)

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.bimodal.{joint_posterior}.beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)

[VAL] Best epoch 174 | Best val score -0.832874 | DKL-prior 0.000496 | BCE 0.800350 | auROC 0.8329:  45%|████████████████████████████████████████████████████████████████████████████████▋                                                                                                    | 223/500 [15:37<19:24,  4.20s/it]


Saving best model: epoch 174


[VAL] Best epoch 288 | Best val score -0.837785 | DKL-prior 0.000621 | BCE 0.872121 | auROC 0.8378:  67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                           | 337/500 [23:14<11:14,  4.14s/it]


Saving best model: epoch 288


[VAL] Best epoch 170 | Best val score -0.834989 | DKL-prior 0.000544 | BCE 0.899107 | auROC 0.8350:  44%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 219/500 [15:14<19:33,  4.18s/it]


Saving best model: epoch 170


[VAL] Best epoch 274 | Best val score -0.846585 | DKL-prior 0.000608 | BCE 0.862442 | auROC 0.8466:  65%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                | 323/500 [22:10<12:08,  4.12s/it]


Saving best model: epoch 274


[VAL] Best epoch 172 | Best val score -0.833392 | DKL-prior 0.000653 | BCE 0.797987 | auROC 0.8334:  44%|████████████████████████████████████████████████████████████████████████████████                                                                                                     | 221/500 [15:28<19:31,  4.20s/it]


Saving best model: epoch 172


# full set: alpha+beta set + beta set (MVIB)

In [10]:
# Bimodal experiment (peptide + CDR3b) on the full set, i.e. the beta set stacked on the alpha+beta set.
df1 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))
df2 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
# drop=True: a plain reset_index() would turn the old (duplicated) per-file row labels into an
# `index` column that then ends up in every results CSV written below.
df = pd.concat([df1, df2]).reset_index(drop=True)

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.bimodal.{joint_posterior}.full.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 180 | Best val score -0.815843 | DKL-prior 0.000677 | BCE 0.755239 | auROC 0.8158:  46%|██████████████████████████████████████████████████████████████████████████████████▉                                                                                                  | 229/500 [27:47<32:53,  7.28s/it]


Saving best model: epoch 180


[VAL] Best epoch 150 | Best val score -0.806158 | DKL-prior 0.000565 | BCE 0.758233 | auROC 0.8062:  40%|████████████████████████████████████████████████████████████████████████                                                                                                             | 199/500 [24:05<36:25,  7.26s/it]


Saving best model: epoch 150


[VAL] Best epoch 215 | Best val score -0.813805 | DKL-prior 0.000642 | BCE 0.753081 | auROC 0.8138:  53%|███████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                     | 264/500 [31:53<28:30,  7.25s/it]


Saving best model: epoch 215


[VAL] Best epoch 162 | Best val score -0.814897 | DKL-prior 0.000554 | BCE 0.699809 | auROC 0.8149:  42%|████████████████████████████████████████████████████████████████████████████▍                                                                                                        | 211/500 [25:29<34:54,  7.25s/it]


Saving best model: epoch 162


[VAL] Best epoch 272 | Best val score -0.816647 | DKL-prior 0.000692 | BCE 0.791097 | auROC 0.8166:  64%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                | 321/500 [38:56<21:42,  7.28s/it]


Saving best model: epoch 272


# Max pooling of experts

In [11]:
# Notebook-level setting read by the training cells that follow: it is passed to
# MVIB(joint_posterior=...) and embedded in the run/file names they write.
joint_posterior = "max_pool"

# alpha+beta set - peptide+CDR3b (max pooling of experts)

In [12]:
# Bimodal experiment (peptide + CDR3b) on the alpha+beta set, using the current `joint_posterior` setting.
alpha_beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv')
df = pd.read_csv(alpha_beta_csv)

cpu_device = torch.device("cpu")
# column mapping handed to every TCRDataset in this cell: beta chain only, no alpha chain
tcr_columns = dict(cdr3b_col='tcrb', cdr3a_col=None)

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)

[VAL] Best epoch 63 | Best val score -0.905690 | DKL-prior 0.000380 | BCE 0.420178 | auROC 0.9057:  22%|████████████████████████████████████████▊                                                                                                                                             | 112/500 [04:48<16:39,  2.58s/it]


Saving best model: epoch 63


[VAL] Best epoch 165 | Best val score -0.911779 | DKL-prior 0.000534 | BCE 0.535706 | auROC 0.9118:  43%|█████████████████████████████████████████████████████████████████████████████▍                                                                                                       | 214/500 [09:10<12:15,  2.57s/it]


Saving best model: epoch 165


[VAL] Best epoch 73 | Best val score -0.911577 | DKL-prior 0.000416 | BCE 0.448623 | auROC 0.9116:  24%|████████████████████████████████████████████▍                                                                                                                                         | 122/500 [05:15<16:18,  2.59s/it]


Saving best model: epoch 73


[VAL] Best epoch 125 | Best val score -0.909825 | DKL-prior 0.000397 | BCE 0.544235 | auROC 0.9098:  35%|██████████████████████████████████████████████████████████████▉                                                                                                                      | 174/500 [07:26<13:57,  2.57s/it]


Saving best model: epoch 125


[VAL] Best epoch 137 | Best val score -0.915058 | DKL-prior 0.000491 | BCE 0.476750 | auROC 0.9151:  37%|███████████████████████████████████████████████████████████████████▎                                                                                                                 | 186/500 [07:59<13:29,  2.58s/it]


Saving best model: epoch 137


# alpha+beta set - peptide+CDR3b+CDR3a  (max pooling of experts)

In [13]:
# Trimodal experiment (peptide + CDR3b + CDR3a) on the alpha+beta set, using the current `joint_posterior` setting.
alpha_beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv')
df = pd.read_csv(alpha_beta_csv)

cpu_device = torch.device("cpu")
# column mapping handed to every TCRDataset in this cell: both the beta and the alpha chain
tcr_columns = dict(cdr3b_col='tcrb', cdr3a_col='tcra')

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows with all three inputs
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)

[VAL] Best epoch 80 | Best val score -0.926624 | DKL-prior 0.000519 | BCE 0.410618 | auROC 0.9266:  26%|██████████████████████████████████████████████▉                                                                                                                                       | 129/500 [18:08<52:10,  8.44s/it]


Saving best model: epoch 80


[VAL] Best epoch 82 | Best val score -0.926839 | DKL-prior 0.000544 | BCE 0.441198 | auROC 0.9268:  26%|███████████████████████████████████████████████▋                                                                                                                                      | 131/500 [18:26<51:55,  8.44s/it]


Saving best model: epoch 82


[VAL] Best epoch 94 | Best val score -0.930088 | DKL-prior 0.000574 | BCE 0.436505 | auROC 0.9301:  29%|████████████████████████████████████████████████████                                                                                                                                  | 143/500 [20:09<50:20,  8.46s/it]


Saving best model: epoch 94


[VAL] Best epoch 91 | Best val score -0.933600 | DKL-prior 0.000523 | BCE 0.426499 | auROC 0.9336:  28%|██████████████████████████████████████████████████▉                                                                                                                                   | 140/500 [19:49<50:57,  8.49s/it]


Saving best model: epoch 91


[VAL] Best epoch 137 | Best val score -0.937910 | DKL-prior 0.000714 | BCE 0.433803 | auROC 0.9379:  37%|███████████████████████████████████████████████████████████████████▎                                                                                                                 | 186/500 [26:13<44:16,  8.46s/it]


Saving best model: epoch 137


# Average pooling of experts

In [14]:
# Notebook-level setting read by the training cells that follow: it is passed to
# MVIB(joint_posterior=...) and embedded in the run/file names they write.
joint_posterior = "avg_pool"

# alpha+beta set - peptide+CDR3b (average pooling of experts)

In [15]:
# Bimodal experiment (peptide + CDR3b) on the alpha+beta set, using the current `joint_posterior` setting.
alpha_beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv')
df = pd.read_csv(alpha_beta_csv)

cpu_device = torch.device("cpu")
# column mapping handed to every TCRDataset in this cell: beta chain only, no alpha chain
tcr_columns = dict(cdr3b_col='tcrb', cdr3a_col=None)

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)

SyntaxError: f-string: expecting '}' (3549742955.py, line 59)

# alpha+beta set - peptide+CDR3b+CDR3a  (average pooling of experts)

In [None]:
# Trimodal experiment (peptide + CDR3b + CDR3a) on the alpha+beta set, using the current `joint_posterior` setting.
alpha_beta_csv = os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv')
df = pd.read_csv(alpha_beta_csv)

cpu_device = torch.device("cpu")
# column mapping handed to every TCRDataset in this cell: both the beta and the alpha chain
tcr_columns = dict(cdr3b_col='tcrb', cdr3a_col='tcra')

for i in range(5):  # 5 independent train/test splits; `i` is both the seed and the repetition id
    set_random_seed(i)

    # 80/20 train/test split; the scaler is taken from a dataset built on a copy of the training rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), cpu_device, **tcr_columns).scaler
    ds_test = TCRDataset(df_test, cpu_device, scaler=scaler, **tcr_columns)

    # carve a label-stratified 20% validation set out of the training rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: every row is weighted by 1 / (number of rows in its class)
    ds_train = TCRDataset(df_train, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_train.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=sampler)

    # val loader with balanced sampling, built the same way from the validation rows
    ds_val = TCRDataset(df_val, device, scaler=scaler, **tcr_columns)
    class_count = np.array([(df_val.sign == label).sum() for label in (0, 1)])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, sampler=sampler)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    # hyper-parameters are notebook-level globals defined in an earlier cell
    trainer_config = dict(
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    trainer = TrainerMVIB(model, **trainer_config)
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # persist the checkpoint returned by training under a name that encodes the configuration
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    checkpoint_path = os.path.join(RESULTS_BASE, f"{run_name}.pth")
    trainer.save_checkpoint(checkpoint, folder='./', filename=checkpoint_path)

    # test: rebuild the model on CPU from the checkpoint and score the held-out rows with all three inputs
    model = MVIB.from_checkpoint(checkpoint, cpu_device)
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a).detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    results_csv = os.path.join(RESULTS_BASE, f"{run_name}.csv")
    df_test.to_csv(results_csv, index=False)