# Attentive Variational Information Bottleneck

In this notebook, we train and test the Attentive Variational Information Bottleneck (AVIB: MVIB [1] with an attention-of-experts joint posterior) and the original MVIB (product of experts) on all datasets.

[1] Microbiome-based disease prediction with multimodal variational information bottlenecks, Grazioli et al., https://www.biorxiv.org/node/2109522.external-links.html

In [1]:
import pandas as pd
import torch
import numpy as np
import random

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, auc
import pandas as pd
import torch

# Display names of the rows returned by `get_scores`, in exactly that order.
metrics = ['auROC', 'Accuracy', 'Recall', 'Precision', 'F1 score', 'auPRC']

def pr_auc(y_true, y_prob):
    """
    Area under the precision-recall curve, integrated with the trapezoidal rule.

    Note: trapezoidal interpolation of a PR curve can be optimistic; sklearn's
    `average_precision_score` is the non-interpolated alternative.
    """
    prec, rec, _ = precision_recall_curve(y_true, y_prob)
    return auc(rec, prec)

def get_scores(y_true, y_prob, y_pred):
    """
    Build a DataFrame with columns 'score' and 'metrics', one row per entry of
    the module-level `metrics` list (same order).

    `y_prob` feeds the threshold-free metrics (auROC, auPRC); `y_pred` feeds the
    label-based ones (accuracy, recall, precision, F1).
    """
    roc = roc_auc_score(y_true, y_prob)
    acc = accuracy_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    prc = pr_auc(y_true, y_prob)

    return pd.DataFrame({'score': [roc, acc, rec, prec, f1, prc], 'metrics': metrics})

In [3]:
def set_random_seed(random_seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and every CUDA device)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

In [4]:
import os

# Absolute layout used on the original author's machine, kept for reference:
#   DATA_BASE    = f"/home/{login}/Git/tcr/data/"
#   RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.classification/results/"
# `os.getlogin()` raises OSError when the process has no controlling terminal
# (containers, some Jupyter/cron launches), so fall back to the environment.
try:
    login = os.getlogin()
except OSError:
    login = os.environ.get('USER') or os.environ.get('USERNAME') or ''

# Paths for a GitHub checkout of vibtcr (after `unzip data.zip`), relative to this notebook.
RESULTS_BASE = os.path.join('.', 'results')
DATA_BASE = os.path.join('..', '..', 'data')

# The training cells write checkpoints and prediction CSVs into RESULTS_BASE.
os.makedirs(RESULTS_BASE, exist_ok=True)

In [5]:
# Use the first GPU when available; fall back to CPU so the notebook still runs (slowly).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size = 4096
epochs = 500  # upper bound; early stopping (patience below) may end training earlier
lr = 1e-3

z_dim = 150                  # latent dimensionality passed to MVIB
early_stopper_patience = 50  # passed to `trainer.train` as the early-stopping patience
monitor = 'auROC'            # validation metric passed to `trainer.train`
lr_scheduler_param = 10      # forwarded to TrainerMVIB; exact semantics defined there — TODO confirm

beta = 1e-6  # passed to TrainerMVIB; presumably the weight of the KL (bottleneck) term — verify

# AVIB: multimodal attention of experts (aoe)

In [6]:
#
# NOTE: This notebook runs several choices for "joint_posterior".
#       Class BaseVIB supports:
#         'poe'      ~ "product of experts"   ~ "MVIB" (multimodal), original paper
#         'aoe'      ~ "attention of experts" ~ "AVIB" (attentive)
#         'max_pool'
#         'avg_pool'
joint_posterior = "aoe"

# alpha+beta set - peptide+CDR3b (AVIB)

In [7]:
# Alpha+beta set, peptide + CDR3b ("bimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 86 | Best val score -0.915149 | DKL-prior 0.000509 | BCE 0.706523 | auROC 0.9151:  27%|██▋       | 135/500 [06:15<16:55,  2.78s/it]


Saving best model: epoch 86


[VAL] Best epoch 164 | Best val score -0.916218 | DKL-prior 0.000554 | BCE 0.879744 | auROC 0.9162:  43%|████▎     | 213/500 [10:11<13:43,  2.87s/it]


Saving best model: epoch 164


[VAL] Best epoch 74 | Best val score -0.918045 | DKL-prior 0.000480 | BCE 0.558302 | auROC 0.9180:  25%|██▍       | 123/500 [05:55<18:08,  2.89s/it]


Saving best model: epoch 74


[VAL] Best epoch 101 | Best val score -0.916223 | DKL-prior 0.000413 | BCE 0.649124 | auROC 0.9162:  30%|███       | 150/500 [07:13<16:51,  2.89s/it]


Saving best model: epoch 101


[VAL] Best epoch 78 | Best val score -0.920351 | DKL-prior 0.000545 | BCE 0.558462 | auROC 0.9204:  25%|██▌       | 127/500 [06:07<17:58,  2.89s/it]


Saving best model: epoch 78


# alpha+beta set - peptide+CDR3b+CDR3a (AVIB)

In [8]:
# Alpha+beta set, peptide + CDR3b + CDR3a ("trimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 72 | Best val score -0.933601 | DKL-prior 0.000456 | BCE 0.537745 | auROC 0.9336:  24%|██▍       | 121/500 [18:58<59:24,  9.41s/it]


Saving best model: epoch 72


[VAL] Best epoch 52 | Best val score -0.938146 | DKL-prior 0.000533 | BCE 0.524283 | auROC 0.9381:  20%|██        | 101/500 [15:57<1:03:03,  9.48s/it]


Saving best model: epoch 52


[VAL] Best epoch 66 | Best val score -0.942532 | DKL-prior 0.000533 | BCE 0.480641 | auROC 0.9425:  23%|██▎       | 115/500 [18:09<1:00:47,  9.47s/it]


Saving best model: epoch 66


[VAL] Best epoch 86 | Best val score -0.940843 | DKL-prior 0.000475 | BCE 0.584330 | auROC 0.9408:  27%|██▋       | 135/500 [20:46<56:10,  9.23s/it]


Saving best model: epoch 86


[VAL] Best epoch 60 | Best val score -0.941849 | DKL-prior 0.000463 | BCE 0.440435 | auROC 0.9418:  22%|██▏       | 109/500 [16:43<59:58,  9.20s/it]


Saving best model: epoch 60


# alpha+beta set - peptide+CDR3a (AVIB)

In [9]:
# Alpha+beta set, peptide + CDR3a ("bimodal-alpha"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcra').scaler
    # we pass column `tcra` to `cdr3b_col` because TCRDataset expects to have the CDR3b attribute
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcra', scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal-alpha.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split (`ds_test.cdr3b` holds the CDR3a sequences here);
    # no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 78 | Best val score -0.923301 | DKL-prior 0.000444 | BCE 0.567158 | auROC 0.9233:  25%|██▌       | 127/500 [05:55<17:22,  2.80s/it]


Saving best model: epoch 78


[VAL] Best epoch 118 | Best val score -0.921938 | DKL-prior 0.000565 | BCE 0.776266 | auROC 0.9219:  33%|███▎      | 167/500 [07:55<15:48,  2.85s/it]


Saving best model: epoch 118


[VAL] Best epoch 66 | Best val score -0.930392 | DKL-prior 0.000586 | BCE 0.452523 | auROC 0.9304:  23%|██▎       | 115/500 [05:20<17:53,  2.79s/it]


Saving best model: epoch 66


[VAL] Best epoch 69 | Best val score -0.925276 | DKL-prior 0.000501 | BCE 0.480685 | auROC 0.9253:  24%|██▎       | 118/500 [05:28<17:41,  2.78s/it]


Saving best model: epoch 69


[VAL] Best epoch 74 | Best val score -0.928442 | DKL-prior 0.000483 | BCE 0.543335 | auROC 0.9284:  25%|██▍       | 123/500 [05:43<17:32,  2.79s/it]


Saving best model: epoch 74


# beta set (AVIB)

In [10]:
# Beta set, peptide + CDR3b ("bimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 207 | Best val score -0.831984 | DKL-prior 0.000568 | BCE 0.817977 | auROC 0.8320:  51%|█████     | 256/500 [17:58<17:07,  4.21s/it]


Saving best model: epoch 207


[VAL] Best epoch 311 | Best val score -0.841795 | DKL-prior 0.000570 | BCE 0.937354 | auROC 0.8418:  72%|███████▏  | 360/500 [24:48<09:39,  4.14s/it]


Saving best model: epoch 311


[VAL] Best epoch 228 | Best val score -0.840914 | DKL-prior 0.000551 | BCE 0.898606 | auROC 0.8409:  55%|█████▌    | 277/500 [19:35<15:46,  4.25s/it]


Saving best model: epoch 228


[VAL] Best epoch 247 | Best val score -0.843472 | DKL-prior 0.000572 | BCE 0.880978 | auROC 0.8435:  59%|█████▉    | 296/500 [20:57<14:26,  4.25s/it]


Saving best model: epoch 247


[VAL] Best epoch 127 | Best val score -0.836380 | DKL-prior 0.000609 | BCE 0.759371 | auROC 0.8364:  35%|███▌      | 176/500 [12:29<22:59,  4.26s/it]


Saving best model: epoch 127


# full set: alpha+beta set + beta set (AVIB)

In [11]:
# Full set (beta set + alpha+beta set), peptide + CDR3b ("bimodal"), 5 independent splits.
df1 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))
df2 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
# drop=True: without it the old per-file row labels become an extra `index`
# column that leaks into every saved results CSV
df = pd.concat([df1, df2]).reset_index(drop=True)

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.full.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./',
                            filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 165 | Best val score -0.818070 | DKL-prior 0.000638 | BCE 0.719234 | auROC 0.8181:  43%|████▎     | 214/500 [26:09<34:57,  7.33s/it]


Saving best model: epoch 165


[VAL] Best epoch 166 | Best val score -0.810671 | DKL-prior 0.000614 | BCE 0.737749 | auROC 0.8107:  43%|████▎     | 215/500 [26:27<35:04,  7.38s/it]


Saving best model: epoch 166


[VAL] Best epoch 127 | Best val score -0.813910 | DKL-prior 0.000587 | BCE 0.697707 | auROC 0.8139:  35%|███▌      | 176/500 [21:43<40:00,  7.41s/it]


Saving best model: epoch 127


[VAL] Best epoch 130 | Best val score -0.816822 | DKL-prior 0.000588 | BCE 0.724327 | auROC 0.8168:  36%|███▌      | 179/500 [22:11<39:47,  7.44s/it]


Saving best model: epoch 130


[VAL] Best epoch 137 | Best val score -0.813370 | DKL-prior 0.000594 | BCE 0.761324 | auROC 0.8134:  37%|███▋      | 186/500 [23:01<38:51,  7.42s/it]


Saving best model: epoch 137


# MVIB: multimodal product of experts (poe)

In [12]:
joint_posterior = 'poe'  # product of experts ~ original MVIB

# alpha+beta set - peptide+CDR3b (MVIB)

In [13]:
# Alpha+beta set, peptide + CDR3b ("bimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 87 | Best val score -0.911952 | DKL-prior 0.000372 | BCE 0.673360 | auROC 0.9120:  27%|██▋       | 136/500 [05:36<14:59,  2.47s/it]


Saving best model: epoch 87


[VAL] Best epoch 126 | Best val score -0.914753 | DKL-prior 0.000360 | BCE 0.742800 | auROC 0.9148:  35%|███▌      | 175/500 [07:09<13:17,  2.45s/it]


Saving best model: epoch 126


[VAL] Best epoch 89 | Best val score -0.915633 | DKL-prior 0.000407 | BCE 0.624047 | auROC 0.9156:  28%|██▊       | 138/500 [05:36<14:41,  2.44s/it]


Saving best model: epoch 89


[VAL] Best epoch 87 | Best val score -0.913535 | DKL-prior 0.000386 | BCE 0.684263 | auROC 0.9135:  27%|██▋       | 136/500 [05:31<14:47,  2.44s/it]


Saving best model: epoch 87


[VAL] Best epoch 91 | Best val score -0.916320 | DKL-prior 0.000371 | BCE 0.635772 | auROC 0.9163:  28%|██▊       | 140/500 [05:41<14:38,  2.44s/it]


Saving best model: epoch 91


# alpha+beta set - peptide+CDR3b+CDR3a (MVIB)

In [14]:
# Alpha+beta set, peptide + CDR3b + CDR3a ("trimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra').scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 80 | Best val score -0.933487 | DKL-prior 0.000351 | BCE 0.582936 | auROC 0.9335:  26%|██▌       | 129/500 [17:17<49:43,  8.04s/it]


Saving best model: epoch 80


[VAL] Best epoch 82 | Best val score -0.934239 | DKL-prior 0.000373 | BCE 0.595437 | auROC 0.9342:  26%|██▌       | 131/500 [17:38<49:41,  8.08s/it]


Saving best model: epoch 82


[VAL] Best epoch 67 | Best val score -0.937545 | DKL-prior 0.000356 | BCE 0.468431 | auROC 0.9375:  23%|██▎       | 116/500 [15:32<51:27,  8.04s/it]


Saving best model: epoch 67


[VAL] Best epoch 151 | Best val score -0.936985 | DKL-prior 0.000350 | BCE 0.791356 | auROC 0.9370:  40%|████      | 200/500 [26:37<39:56,  7.99s/it]


Saving best model: epoch 151


[VAL] Best epoch 137 | Best val score -0.940832 | DKL-prior 0.000356 | BCE 0.668724 | auROC 0.9408:  37%|███▋      | 186/500 [24:45<41:47,  7.99s/it]


Saving best model: epoch 137


# alpha+beta set - peptide+CDR3a (MVIB)

In [15]:
# Alpha+beta set, peptide + CDR3a ("bimodal-alpha"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcra').scaler
    # we pass column `tcra` to `cdr3b_col` because TCRDataset expects to have the CDR3b attribute
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcra', scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcra', scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal-alpha.{joint_posterior}.alpha+beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split (`ds_test.cdr3b` holds the CDR3a sequences here);
    # no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 73 | Best val score -0.920042 | DKL-prior 0.000416 | BCE 0.514892 | auROC 0.9200:  24%|██▍       | 122/500 [04:56<15:18,  2.43s/it]


Saving best model: epoch 73


[VAL] Best epoch 119 | Best val score -0.918556 | DKL-prior 0.000368 | BCE 0.769417 | auROC 0.9186:  34%|███▎      | 168/500 [06:48<13:28,  2.43s/it]


Saving best model: epoch 119


[VAL] Best epoch 92 | Best val score -0.925117 | DKL-prior 0.000399 | BCE 0.553487 | auROC 0.9251:  28%|██▊       | 141/500 [05:42<14:32,  2.43s/it]


Saving best model: epoch 92


[VAL] Best epoch 87 | Best val score -0.921203 | DKL-prior 0.000391 | BCE 0.568877 | auROC 0.9212:  27%|██▋       | 136/500 [05:31<14:47,  2.44s/it]


Saving best model: epoch 87


[VAL] Best epoch 137 | Best val score -0.927519 | DKL-prior 0.000346 | BCE 0.647315 | auROC 0.9275:  37%|███▋      | 186/500 [07:34<12:47,  2.45s/it]


Saving best model: epoch 137


# beta set (MVIB)

In [16]:
# Beta set, peptide + CDR3b ("bimodal"), 5 independent train/test splits.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # 80/20 train/test split (not stratified). The scaler comes from a TCRDataset
    # built on the whole training portion, i.e. before the validation split below.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # 80/20 train/validation split, stratified on the binary label `sign`
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each sample is weighted by the inverse
    # size of its class, so both classes carry the same total sampling weight
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling (same inverse-class-size weighting)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.beta-only.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test: rebuild the model from the returned checkpoint on CPU and score the
    # whole test split; no_grad avoids building an autograd graph for inference
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 379 | Best val score -0.834778 | DKL-prior 0.000371 | BCE 0.771687 | auROC 0.8348:  86%|████████▌ | 428/500 [26:20<04:25,  3.69s/it]


Saving best model: epoch 379


[VAL] Best epoch 161 | Best val score -0.829311 | DKL-prior 0.000377 | BCE 0.691996 | auROC 0.8293:  42%|████▏     | 210/500 [12:41<17:31,  3.63s/it]


Saving best model: epoch 161


[VAL] Best epoch 262 | Best val score -0.839996 | DKL-prior 0.000351 | BCE 0.705013 | auROC 0.8400:  62%|██████▏   | 311/500 [18:59<11:32,  3.66s/it]


Saving best model: epoch 262


[VAL] Best epoch 258 | Best val score -0.841148 | DKL-prior 0.000368 | BCE 0.709435 | auROC 0.8411:  61%|██████▏   | 307/500 [18:30<11:38,  3.62s/it]


Saving best model: epoch 258


[VAL] Best epoch 225 | Best val score -0.835524 | DKL-prior 0.000362 | BCE 0.713573 | auROC 0.8355:  55%|█████▍    | 274/500 [16:38<13:43,  3.64s/it]


Saving best model: epoch 225


# full set (beta set + alpha+beta set combined) - peptide+CDR3b (MVIB)

In [17]:
# Full set = beta-only rows + alpha+beta rows, modelled bimodally (peptide +
# CDR3b), so the CDR3a column of the alpha+beta rows is not used in this cell.
df1 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'beta.csv'))
df2 = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
# ignore_index=True gives a fresh 0..N-1 index. A bare `.reset_index()` would
# also insert the old per-file index as an extra 'index' column whose values
# collide between the two files and which would leak into every saved CSV.
df = pd.concat([df1, df2], ignore_index=True)
assert len(df) == len(df1) + len(df2)

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)

    # held-out test split (not stratified); scaler taken from the remaining rows
    # NOTE(review): the scaler is built before the validation rows are split off.
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None).scaler

    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)

    # stratified train/validation split of the remaining rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader with balanced sampling: each row weighted by 1 / (size of its class)
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_train[df_train.sign == 0].shape[0], df_train[df_train.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_train.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=sampler
    )

    # val loader with balanced sampling
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    class_count = np.array([df_val[df_val.sign == 0].shape[0], df_val[df_val.sign == 1].shape[0]])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df_val.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=sampler
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    run_name = f"mvib.bimodal.{joint_posterior}.full.rep-{i}"
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # test on CPU; no_grad avoids building an autograd graph over the whole test set
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    with torch.no_grad():
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_'+str(i)] = pred.squeeze().tolist()

    # save results for further analysis
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 397 | Best val score -0.819986 | DKL-prior 0.000434 | BCE 0.692449 | auROC 0.8200:  89%|████████▉ | 446/500 [47:10<05:42,  6.35s/it]


Saving best model: epoch 397


[VAL] Best epoch 334 | Best val score -0.815213 | DKL-prior 0.000407 | BCE 0.695995 | auROC 0.8152:  77%|███████▋  | 383/500 [40:44<12:26,  6.38s/it]


Saving best model: epoch 334


[VAL] Best epoch 283 | Best val score -0.816175 | DKL-prior 0.000392 | BCE 0.664660 | auROC 0.8162:  66%|██████▋   | 332/500 [35:09<17:47,  6.35s/it]


Saving best model: epoch 283


[VAL] Best epoch 278 | Best val score -0.819015 | DKL-prior 0.000387 | BCE 0.646202 | auROC 0.8190:  65%|██████▌   | 327/500 [34:52<18:27,  6.40s/it]


Saving best model: epoch 278


[VAL] Best epoch 332 | Best val score -0.822531 | DKL-prior 0.000422 | BCE 0.689946 | auROC 0.8225:  76%|███████▌  | 381/500 [40:22<12:36,  6.36s/it]


Saving best model: epoch 332


# Max pooling of experts

In [18]:
# Joint-posterior variant for the cells below: passed to MVIB(...) and embedded in each run_name.
joint_posterior = "max_pool"

# alpha+beta set - peptide+CDR3b (max pooling of experts)

In [19]:
# Bimodal MVIB (peptide + CDR3b) on the alpha+beta set with the joint posterior
# selected above, repeated 5 times. Each repetition reseeds with its index and
# draws its own random 80/20 train/test split (not stratified). The 80% part is
# split again 80/20, stratified on `sign`, into train/validation. Both of those
# loaders sample each row with weight 1 / (size of its class).
# NOTE(review): the scaler comes from a TCRDataset built before the validation
# rows are split off — confirm this is intended.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
tcr_cols = dict(cdr3b_col='tcrb', cdr3a_col=None)

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"

    # held-out test split; the scaler is taken from the remaining rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), **tcr_cols).scaler
    ds_test = TCRDataset(df_test, torch.device("cpu"), **tcr_cols, scaler=scaler)

    # stratified train/validation split of the remaining rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader: inverse-class-frequency row weights -> balanced sampling
    ds_train = TCRDataset(df_train, device, **tcr_cols, scaler=scaler)
    train_inv_freq = 1. / np.array([(df_train.sign == label).sum() for label in (0, 1)])
    train_row_weights = torch.tensor([train_inv_freq[label] for label in df_train.sign])
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(train_row_weights, len(train_row_weights)),
    )

    # validation loader: same balancing scheme
    ds_val = TCRDataset(df_val, device, **tcr_cols, scaler=scaler)
    val_inv_freq = 1. / np.array([(df_val.sign == label).sum() for label in (0, 1)])
    val_row_weights = torch.tensor([val_inv_freq[label] for label in df_val.sign])
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(val_row_weights, len(val_row_weights)),
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # score the test split on CPU with the checkpoint returned by trainer.train,
    # then store the predictions next to the test rows for later analysis
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None).detach().numpy()
    df_test[f'prediction_{i}'] = pred.squeeze().tolist()
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 172 | Best val score -0.909271 | DKL-prior 0.000485 | BCE 0.579490 | auROC 0.9093:  44%|████▍     | 221/500 [09:28<11:58,  2.57s/it]


Saving best model: epoch 172


[VAL] Best epoch 165 | Best val score -0.913073 | DKL-prior 0.000586 | BCE 0.518310 | auROC 0.9131:  43%|████▎     | 214/500 [09:11<12:17,  2.58s/it]


Saving best model: epoch 165


[VAL] Best epoch 151 | Best val score -0.912707 | DKL-prior 0.000554 | BCE 0.525336 | auROC 0.9127:  40%|████      | 200/500 [08:32<12:49,  2.56s/it]


Saving best model: epoch 151


[VAL] Best epoch 125 | Best val score -0.907327 | DKL-prior 0.000439 | BCE 0.481588 | auROC 0.9073:  35%|███▍      | 174/500 [07:26<13:55,  2.56s/it]


Saving best model: epoch 125


[VAL] Best epoch 141 | Best val score -0.913202 | DKL-prior 0.000461 | BCE 0.457095 | auROC 0.9132:  38%|███▊      | 190/500 [08:11<13:21,  2.59s/it]


Saving best model: epoch 141


# alpha+beta set - peptide+CDR3b+CDR3a (max pooling of experts)

In [20]:
# Trimodal MVIB (peptide + CDR3b + CDR3a) on the alpha+beta set with the joint
# posterior selected above, repeated 5 times. Each repetition reseeds with its
# index and draws its own random 80/20 train/test split (not stratified). The
# 80% part is split again 80/20, stratified on `sign`, into train/validation.
# Both of those loaders sample each row with weight 1 / (size of its class).
# NOTE(review): the scaler comes from a TCRDataset built before the validation
# rows are split off — confirm this is intended.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
tcr_cols = dict(cdr3b_col='tcrb', cdr3a_col='tcra')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"

    # held-out test split; the scaler is taken from the remaining rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), **tcr_cols).scaler
    ds_test = TCRDataset(df_test, torch.device("cpu"), **tcr_cols, scaler=scaler)

    # stratified train/validation split of the remaining rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader: inverse-class-frequency row weights -> balanced sampling
    ds_train = TCRDataset(df_train, device, **tcr_cols, scaler=scaler)
    train_inv_freq = 1. / np.array([(df_train.sign == label).sum() for label in (0, 1)])
    train_row_weights = torch.tensor([train_inv_freq[label] for label in df_train.sign])
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(train_row_weights, len(train_row_weights)),
    )

    # validation loader: same balancing scheme
    ds_val = TCRDataset(df_val, device, **tcr_cols, scaler=scaler)
    val_inv_freq = 1. / np.array([(df_val.sign == label).sum() for label in (0, 1)])
    val_row_weights = torch.tensor([val_inv_freq[label] for label in df_val.sign])
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(val_row_weights, len(val_row_weights)),
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # score the test split on CPU (all three modalities) with the checkpoint
    # returned by trainer.train, then store the predictions next to the test rows
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a).detach().numpy()
    df_test[f'prediction_{i}'] = pred.squeeze().tolist()
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 80 | Best val score -0.928553 | DKL-prior 0.000517 | BCE 0.405969 | auROC 0.9286:  26%|██▌       | 129/500 [18:07<52:06,  8.43s/it]


Saving best model: epoch 80


[VAL] Best epoch 82 | Best val score -0.927299 | DKL-prior 0.000539 | BCE 0.450303 | auROC 0.9273:  26%|██▌       | 131/500 [18:29<52:04,  8.47s/it]


Saving best model: epoch 82


[VAL] Best epoch 94 | Best val score -0.930142 | DKL-prior 0.000568 | BCE 0.435805 | auROC 0.9301:  29%|██▊       | 143/500 [20:10<50:23,  8.47s/it]


Saving best model: epoch 94


[VAL] Best epoch 91 | Best val score -0.931659 | DKL-prior 0.000504 | BCE 0.441023 | auROC 0.9317:  28%|██▊       | 140/500 [19:51<51:02,  8.51s/it]


Saving best model: epoch 91


[VAL] Best epoch 135 | Best val score -0.937005 | DKL-prior 0.000676 | BCE 0.394142 | auROC 0.9370:  37%|███▋      | 184/500 [25:55<44:30,  8.45s/it]


Saving best model: epoch 135


# Average pooling of experts

In [21]:
# Joint-posterior variant for the cells below: passed to MVIB(...) and embedded in each run_name.
joint_posterior = "avg_pool"

# alpha+beta set - peptide+CDR3b (average pooling of experts)

In [22]:
# Bimodal MVIB (peptide + CDR3b) on the alpha+beta set with the joint posterior
# selected above, repeated 5 times. Each repetition reseeds with its index and
# draws its own random 80/20 train/test split (not stratified). The 80% part is
# split again 80/20, stratified on `sign`, into train/validation. Both of those
# loaders sample each row with weight 1 / (size of its class).
# NOTE(review): the scaler comes from a TCRDataset built before the validation
# rows are split off — confirm this is intended.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
tcr_cols = dict(cdr3b_col='tcrb', cdr3a_col=None)

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)
    run_name = f"mvib.bimodal.{joint_posterior}.alpha+beta-only.rep-{i}"

    # held-out test split; the scaler is taken from the remaining rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), **tcr_cols).scaler
    ds_test = TCRDataset(df_test, torch.device("cpu"), **tcr_cols, scaler=scaler)

    # stratified train/validation split of the remaining rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader: inverse-class-frequency row weights -> balanced sampling
    ds_train = TCRDataset(df_train, device, **tcr_cols, scaler=scaler)
    train_inv_freq = 1. / np.array([(df_train.sign == label).sum() for label in (0, 1)])
    train_row_weights = torch.tensor([train_inv_freq[label] for label in df_train.sign])
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(train_row_weights, len(train_row_weights)),
    )

    # validation loader: same balancing scheme
    ds_val = TCRDataset(df_val, device, **tcr_cols, scaler=scaler)
    val_inv_freq = 1. / np.array([(df_val.sign == label).sum() for label in (0, 1)])
    val_row_weights = torch.tensor([val_inv_freq[label] for label in df_val.sign])
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(val_row_weights, len(val_row_weights)),
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # score the test split on CPU with the checkpoint returned by trainer.train,
    # then store the predictions next to the test rows for later analysis
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None).detach().numpy()
    df_test[f'prediction_{i}'] = pred.squeeze().tolist()
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 80 | Best val score -0.904815 | DKL-prior 0.000479 | BCE 0.551604 | auROC 0.9048:  26%|██▌       | 129/500 [05:08<14:46,  2.39s/it]


Saving best model: epoch 80


[VAL] Best epoch 204 | Best val score -0.910141 | DKL-prior 0.000432 | BCE 0.775137 | auROC 0.9101:  51%|█████     | 253/500 [10:11<09:56,  2.42s/it]


Saving best model: epoch 204


[VAL] Best epoch 88 | Best val score -0.910858 | DKL-prior 0.000470 | BCE 0.525894 | auROC 0.9109:  27%|██▋       | 137/500 [05:30<14:35,  2.41s/it]


Saving best model: epoch 88


[VAL] Best epoch 120 | Best val score -0.905538 | DKL-prior 0.000456 | BCE 0.652001 | auROC 0.9055:  34%|███▍      | 169/500 [06:46<13:15,  2.40s/it]


Saving best model: epoch 120


[VAL] Best epoch 138 | Best val score -0.909851 | DKL-prior 0.000444 | BCE 0.669654 | auROC 0.9099:  37%|███▋      | 187/500 [07:32<12:37,  2.42s/it]


Saving best model: epoch 138


# alpha+beta set - peptide+CDR3b+CDR3a (average pooling of experts)

In [23]:
# Trimodal MVIB (peptide + CDR3b + CDR3a) on the alpha+beta set with the joint
# posterior selected above, repeated 5 times. Each repetition reseeds with its
# index and draws its own random 80/20 train/test split (not stratified). The
# 80% part is split again 80/20, stratified on `sign`, into train/validation.
# Both of those loaders sample each row with weight 1 / (size of its class).
# NOTE(review): the scaler comes from a TCRDataset built before the validation
# rows are split off — confirm this is intended.
df = pd.read_csv(os.path.join(DATA_BASE, 'alpha-beta-splits', 'alpha-beta.csv'))
tcr_cols = dict(cdr3b_col='tcrb', cdr3a_col='tcra')

for i in range(5):  # 5 independent train/test splits
    set_random_seed(i)
    run_name = f"mvib.trimodal.{joint_posterior}.alpha+beta-only.rep-{i}"

    # held-out test split; the scaler is taken from the remaining rows
    df_train, df_test = train_test_split(df.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), **tcr_cols).scaler
    ds_test = TCRDataset(df_test, torch.device("cpu"), **tcr_cols, scaler=scaler)

    # stratified train/validation split of the remaining rows
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    # train loader: inverse-class-frequency row weights -> balanced sampling
    ds_train = TCRDataset(df_train, device, **tcr_cols, scaler=scaler)
    train_inv_freq = 1. / np.array([(df_train.sign == label).sum() for label in (0, 1)])
    train_row_weights = torch.tensor([train_inv_freq[label] for label in df_train.sign])
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(train_row_weights, len(train_row_weights)),
    )

    # validation loader: same balancing scheme
    ds_val = TCRDataset(df_val, device, **tcr_cols, scaler=scaler)
    val_inv_freq = 1. / np.array([(df_val.sign == label).sum() for label in (0, 1)])
    val_row_weights = torch.tensor([val_inv_freq[label] for label in df_val.sign])
    val_loader = torch.utils.data.DataLoader(
        ds_val,
        batch_size=batch_size,
        sampler=WeightedRandomSampler(val_row_weights, len(val_row_weights)),
    )

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)
    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="trimodal",
        lr_scheduler_param=lr_scheduler_param,
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    trainer.save_checkpoint(checkpoint, folder='./', filename=os.path.join(RESULTS_BASE, f"{run_name}.pth"))

    # score the test split on CPU (all three modalities) with the checkpoint
    # returned by trainer.train, then store the predictions next to the test rows
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=ds_test.cdr3a).detach().numpy()
    df_test[f'prediction_{i}'] = pred.squeeze().tolist()
    df_test.to_csv(os.path.join(RESULTS_BASE, f"{run_name}.csv"), index=False)

[VAL] Best epoch 172 | Best val score -0.930989 | DKL-prior 0.000552 | BCE 0.751971 | auROC 0.9310:  44%|████▍     | 221/500 [29:08<36:46,  7.91s/it]


Saving best model: epoch 172


[VAL] Best epoch 201 | Best val score -0.928251 | DKL-prior 0.000550 | BCE 0.858718 | auROC 0.9283:  50%|█████     | 250/500 [33:05<33:05,  7.94s/it]


Saving best model: epoch 201


[VAL] Best epoch 167 | Best val score -0.934883 | DKL-prior 0.000559 | BCE 0.645326 | auROC 0.9349:  43%|████▎     | 216/500 [28:23<37:19,  7.89s/it]


Saving best model: epoch 167


[VAL] Best epoch 120 | Best val score -0.927328 | DKL-prior 0.000566 | BCE 0.639885 | auROC 0.9273:  34%|███▍      | 169/500 [22:11<43:28,  7.88s/it]


Saving best model: epoch 120


[VAL] Best epoch 137 | Best val score -0.935942 | DKL-prior 0.000577 | BCE 0.619351 | auROC 0.9359:  37%|███▋      | 186/500 [24:34<41:29,  7.93s/it]


Saving best model: epoch 137
