# Attentive Variational Information Bottleneck

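In this notebook, we train the Attentive Variational Information Bottleneck on the `α+β` set and test it on the `β` set. Training is repeated over 5 independent stratified train/validation splits. Each split's test-set predictions are saved to `mvib.ab2b.csv` for further analysis.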

In [1]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, auc
import pandas as pd
import torch

# Display names of the scores produced by `get_scores`, listed in the same
# order in which that function computes them.
metrics = ['auROC', 'Accuracy', 'Recall', 'Precision', 'F1 score', 'auPRC']

def pr_auc(y_true, y_prob):
    """Return the area under the precision-recall curve.

    The curve comes from sklearn's ``precision_recall_curve`` and is integrated
    with sklearn's ``auc`` (recall on the x-axis, precision on the y-axis).
    NOTE(review): sklearn suggests ``average_precision_score`` as a less
    optimistic PR summary; kept as-is to preserve reported numbers.
    """
    prec, rec, _thresholds = precision_recall_curve(y_true, y_prob)
    return auc(rec, prec)

def get_scores(y_true, y_prob, y_pred):
    """
    Compute every classification metric and return them as a DataFrame.

    auROC and auPRC are computed from the scores ``y_prob``; accuracy,
    recall, precision and F1 from the hard labels ``y_pred``. The result has
    columns 'score' and 'metrics', one row per entry of the module-level
    ``metrics`` list (the scorer table below follows that list's order).
    """
    scorers = (
        (roc_auc_score, y_prob),
        (accuracy_score, y_pred),
        (recall_score, y_pred),
        (precision_score, y_pred),
        (f1_score, y_pred),
        (pr_auc, y_prob),
    )
    scores = [score_fn(y_true, y_input) for score_fn, y_input in scorers]
    return pd.DataFrame(data={'score': scores, 'metrics': metrics})

In [2]:
import getpass
import os

# os.getlogin() needs a controlling terminal and raises OSError under many
# Jupyter/container setups; getpass.getuser() reads LOGNAME/USER/LNAME/USERNAME
# and falls back to the password database instead.
login = getpass.getuser()

# Repository root; set the TCR_REPO environment variable to use another checkout.
REPO_BASE = os.environ.get("TCR_REPO", f"/home/{login}/Git/tcr/")
# Both paths keep a trailing separator because later cells append file names
# to them with plain string concatenation.
DATA_BASE = os.path.join(REPO_BASE, "data", "")
RESULTS_BASE = os.path.join(REPO_BASE, "notebooks", "notebooks.classification", "results", "")

In [3]:
# Training device: the second GPU when it exists (original setup), otherwise
# the first GPU, otherwise the CPU, so the notebook still runs on machines
# with fewer GPUs instead of failing on a hard-coded 'cuda:1'.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

batch_size = 4096
epochs = 200                   # upper bound; early stopping usually ends sooner
lr = 1e-3
z_dim = 150                    # latent dimension passed to MVIB
beta = 1e-6                    # weight passed to TrainerMVIB as `beta`
early_stopper_patience = 20    # epochs without improvement on `monitor`
monitor = 'auROC'              # validation metric used for model selection
lr_scheduler_param = 10        # forwarded to TrainerMVIB unchanged
joint_posterior = "aoe"        # forwarded to MVIB unchanged

In [5]:
import pandas as pd
import torch
import numpy as np

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm


# Training pool: the α+β set. The scaler is taken from a CPU TCRDataset built
# on a copy of the full pool and is reused for every split and for the test set.
# NOTE(review): because it is built on the full pool, each validation split
# below also contributes to whatever the scaler learns; confirm this is intended.
df = pd.read_csv(f"{DATA_BASE}alpha-beta-splits/alpha-beta.csv")
scaler = TCRDataset(
    df.copy(),
    torch.device("cpu"),
    cdr3a_col=None,
    cdr3b_col='tcrb',
).scaler

# Held-out test set: the β set, built with the scaler obtained above.
df_test = pd.read_csv(f"{DATA_BASE}alpha-beta-splits/beta.csv")
ds_test = TCRDataset(
    df_test,
    torch.device("cpu"),
    cdr3a_col=None,
    cdr3b_col='tcrb',
    scaler=scaler,
)

def _balanced_loader(frame, dataset):
    """Return a DataLoader over ``dataset`` that samples both classes equally.

    Every row of ``frame`` is weighted by 1 / (number of rows of its class,
    ``sign == 0`` or ``sign == 1``), and a WeightedRandomSampler draws
    ``len(frame)`` rows per epoch (with replacement, its default), so batches
    are class-balanced in expectation.
    Assumes ``dataset`` rows are in the same order as ``frame`` rows — TODO
    confirm against TCRDataset.
    """
    # Cast to int so the labels are used as positions into `weight`
    # (a bool/float label array would otherwise mask or fail to index).
    labels = frame.sign.to_numpy().astype(np.int64)
    class_count = np.array([(labels == 0).sum(), (labels == 1).sum()])
    weight = 1. / class_count
    samples_weight = torch.from_numpy(weight[labels])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler
    )


for i in range(5):  # 5 independent train/val splits
    # Seed the global RNGs per split so every run is reproducible:
    # WeightedRandomSampler is given no generator, so it draws from torch's
    # global RNG, as (presumably) does MVIB's parameter initialisation.
    np.random.seed(i)
    torch.manual_seed(i)

    df_train, df_val = train_test_split(df, test_size=0.2, stratify=df.sign, random_state=i)

    # train loader with balanced sampling
    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    train_loader = _balanced_loader(df_train, ds_train)

    # val loader with balanced sampling
    # NOTE(review): sampling validation rows with replacement makes the monitored
    # score vary from epoch to epoch and may skip some rows; auROC does not need
    # class balancing, so a plain sequential loader may be preferable.
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler)
    val_loader = _balanced_loader(df_val, ds_val)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)

    # test: reload the checkpoint returned by the trainer on CPU and score the β set
    model = MVIB.from_checkpoint(checkpoint, torch.device("cpu"))
    model.eval()  # inference mode; a no-op if from_checkpoint already set it
    with torch.no_grad():  # no autograd graph is needed for prediction
        pred = model.classify(pep=ds_test.pep, cdr3b=ds_test.cdr3b, cdr3a=None)
    pred = pred.detach().numpy()
    df_test['prediction_' + str(i)] = pred.squeeze().tolist()

# Save the per-split test predictions for further analysis. Create the results
# directory first: DataFrame.to_csv does not create missing parent directories.
os.makedirs(RESULTS_BASE, exist_ok=True)
df_test.to_csv(
    RESULTS_BASE + "mvib.ab2b.csv",
    index=False
)

[VAL] Best epoch 72 | Score -0.920575 | DKL-prior 0.000536 | BCE 0.482500 | auROC 0.9206:  46%|████▌     | 91/200 [03:14<03:53,  2.14s/it]
[VAL] Best epoch 54 | Score -0.915860 | DKL-prior 0.000438 | BCE 0.435013 | auROC 0.9159:  36%|███▋      | 73/200 [02:35<04:30,  2.13s/it]
[VAL] Best epoch 66 | Score -0.915766 | DKL-prior 0.000543 | BCE 0.442076 | auROC 0.9158:  42%|████▎     | 85/200 [03:01<04:05,  2.13s/it]
[VAL] Best epoch 95 | Score -0.920989 | DKL-prior 0.000535 | BCE 0.522592 | auROC 0.9210:  57%|█████▋    | 114/200 [04:02<03:02,  2.13s/it]
[VAL] Best epoch 75 | Score -0.919583 | DKL-prior 0.000529 | BCE 0.465375 | auROC 0.9196:  47%|████▋     | 94/200 [03:21<03:46,  2.14s/it]
