# Out-Of-Distribution Detection (OOD) for TCR recognition 

In this notebook, we investigate the effect of the `beta` hyperparameter of the MVIB model on OOD detection.
The in-distribution dataset is the `α+β set`+`β set` (i.e. the human data); the out-of-distribution dataset is the `non-human set` (mouse + macaque).

The goal of this study is to come up with a novel OOD detection method and benchmark it against a set of baseline methods.
Constraint: only the human `α+β set`+`β set` (in-distribution dataset) is available at training time; we have no access to the `non-human set` (out-of-distribution dataset).

In [None]:
import pandas as pd
import torch
import numpy as np
import random
import os

from vibtcr.dataset import TCRDataset
from vibtcr.mvib.mvib import MVIB
from vibtcr.mvib.mvib_trainer import TrainerMVIB

from torch.utils.data.sampler import WeightedRandomSampler
from torch.autograd import Variable

from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, precision_recall_curve, auc
import pandas as pd
import torch

# Row labels of the DataFrame returned by `get_scores`, listed in the same order
# in which the corresponding scores are computed there.
metrics = ['auROC', 'Accuracy', 'Recall', 'Precision', 'F1 score', 'auPRC']

def pr_auc(y_true, y_prob):
    """Area under the precision-recall curve.

    The PR curve points come from `precision_recall_curve`; the area is
    integrated with `auc` (trapezoidal rule, recall on the x-axis).
    """
    curve_precision, curve_recall, _ = precision_recall_curve(y_true, y_prob)
    return auc(curve_recall, curve_precision)

def get_scores(y_true, y_prob, y_pred):
    """
    Build a two-column DataFrame ('score', 'metrics') of classification metrics.

    y_prob feeds the threshold-free metrics (auROC, auPRC); y_pred feeds the
    thresholded ones (accuracy, recall, precision, F1). Rows follow the order
    of the module-level `metrics` list.
    """
    values = [roc_auc_score(y_true, y_prob)]
    for thresholded_metric in (accuracy_score, recall_score, precision_score, f1_score):
        values.append(thresholded_metric(y_true, y_pred))
    values.append(pr_auc(y_true, y_prob))

    return pd.DataFrame(data={'score': values, 'metrics': metrics})

In [3]:
def set_random_seed(random_seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU, current CUDA device, all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)

In [4]:
# os.getlogin() raises OSError when the process has no controlling terminal
# (typical for Jupyter servers, containers and batch jobs); getpass.getuser()
# reads the LOGNAME/USER/LNAME/USERNAME environment variables and falls back to
# the password database, so it also works in those settings.
import getpass

login = getpass.getuser()

# NOTE(review): absolute per-user paths; adjust if the repository lives elsewhere.
DATA_BASE = f"/home/{login}/Git/tcr/data/"
RESULTS_BASE = f"/home/{login}/Git/tcr/notebooks/notebooks.ood/results/"

In [9]:
# Devices: `device` is used for training; `test_device` is declared here but not
# referenced by the cells shown in this notebook.
device, test_device = torch.device('cuda:2'), torch.device('cuda:4')

# Optimisation
batch_size, epochs, lr = 4096, 200, 1e-3

# Model and training schedule
z_dim = 150                 # latent dimension passed to MVIB
early_stopper_patience = 20
monitor = 'auROC'           # validation metric handed to TrainerMVIB.train
lr_scheduler_param = 10
joint_posterior = "aoe"

# `beta` passed to TrainerMVIB; values explored in this study: 0, 1e-6, 1e-4, 1e-2
beta = 1e-4

# Model training

In [10]:
df_in = pd.concat([
    pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv'),
    pd.read_csv(DATA_BASE + 'alpha-beta-splits/beta.csv')
    ])

df_out = pd.read_csv(DATA_BASE + 'vdjdb/vdjdb-2021-09-05/mouse-macaco.csv')

N_REPLICATES = 5  # independent train/test splits


def _checkpoint_path(i):
    """File holding the checkpoint of replicate `i` for the current `beta`."""
    return RESULTS_BASE + f'avib.beta-{beta}.rep-{i}.pth.tar'


def _balanced_loader(df, ds):
    """DataLoader over `ds` drawing both `sign` classes of `df` with equal probability.

    Each row is weighted by 1 / (size of its `sign` class); sampling is with
    replacement and yields len(df) samples per epoch.
    """
    class_count = np.array([(df.sign == 0).sum(), (df.sign == 1).sum()])
    weight = 1. / class_count
    samples_weight = torch.tensor([weight[s] for s in df.sign])
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return torch.utils.data.DataLoader(ds, batch_size=batch_size, sampler=sampler)


def _train_replicate(i):
    """Train one MVIB model on the i-th random split, save its checkpoint and return it.

    The splits use random_state=i; the Mahalanobis cell below re-creates exactly the
    same training set, so the two must stay in sync.
    """
    set_random_seed(i)

    # 80% of the in-distribution data for training/validation; the held-out 20%
    # is only used as in-distribution test data by the Mahalanobis cell.
    df_train, _ = train_test_split(df_in.copy(), test_size=0.2, random_state=i)
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, softmax=True).scaler

    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=i)

    ds_train = TCRDataset(df_train, device, cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler, softmax=True)
    ds_val = TCRDataset(df_val, device, cdr3b_col='tcrb', cdr3a_col=None, softmax=True, scaler=scaler)
    train_loader = _balanced_loader(df_train, ds_train)
    val_loader = _balanced_loader(df_val, ds_val)

    model = MVIB(z_dim=z_dim, device=device, joint_posterior=joint_posterior, softmax=True).to(device)

    trainer = TrainerMVIB(
        model,
        epochs=epochs,
        lr=lr,
        beta=beta,
        checkpoint_dir=".",
        mode="bimodal",
        lr_scheduler_param=lr_scheduler_param
    )
    checkpoint = trainer.train(train_loader, val_loader, early_stopper_patience, monitor)
    trainer.save_checkpoint(checkpoint, folder='./', filename=_checkpoint_path(i))
    return checkpoint


# Resume per replicate: load what is already on disk and train only what is missing,
# so an interrupted run neither crashes on a missing file nor retrains finished replicates.
checkpoints = []
for i in range(N_REPLICATES):
    path = _checkpoint_path(i)
    if os.path.isfile(path):
        # The checkpoints are only consumed on CPU (MVIB.from_checkpoint(..., cpu)),
        # so map them there instead of requiring the GPU they were saved from.
        checkpoints.append(torch.load(path, map_location=torch.device("cpu")))
    else:
        checkpoints.append(_train_replicate(i))

# Mahalanobis distance
Lee et al., *A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks*, NeurIPS 2018

https://arxiv.org/abs/1807.03888

In [11]:
# code adapted from https://github.com/pokaxpoka/deep_Mahalanobis_detector

import sklearn.covariance
from torch.autograd import Variable
from scipy.spatial.distance import pdist, cdist, squareform

def _mahalanobis_features(model, pep, cdr3b, layer_index, cdr3b_only):
    """Representation scored by the detector.

    layer_index 0 -> output [0] of the model (latent mu); any other value -> output [2]
    (classification logits). With cdr3b_only, only the CDR3b encoder's mu is supported.
    """
    if cdr3b_only:
        if layer_index != 0:
            raise NotImplementedError("if cdr3b_only, layer_index must be 0")
        return model.encoder_cdr3b(cdr3b)[0]  # mu
    outputs = model(pep, cdr3b)
    return outputs[0] if layer_index == 0 else outputs[2]  # mu / classification logits


def _class_gaussian_scores(features, class_means, prec, num_classes):
    """(batch, num_classes) tensor of -0.5 * (f - mu_c)^T P (f - mu_c) for classes c < num_classes."""
    columns = []
    for c in range(num_classes):
        centered = features - class_means[c]
        # Row-wise quadratic form: equals diag(A P A^T) without building the batch x batch matrix.
        columns.append(-0.5 * ((centered @ prec) * centered).sum(dim=1))
    return torch.stack(columns, dim=1)


def _signed_step(x, grad, magnitude, grad_norm):
    """Return x - magnitude * sign(grad) / grad_norm.

    sign is +1 where grad >= 0 and -1 elsewhere (never 0). grad_norm is broadcast along
    dim 1 of x (assumes one entry per dim-1 channel -- TODO confirm against TCRDataset).
    """
    direction = (grad >= 0).float() * 2 - 1
    direction = (direction.transpose(1, 2) / grad_norm).transpose(1, 2)
    return x.detach() - magnitude * direction


def get_Mahalanobis_score(model, test_loader, num_classes, sample_mean, precision, layer_index, magnitude, mean, scale, cdr3b_only = False):
    '''
    Compute the Mahalanobis confidence score (Lee et al., 2018) of every test sample.

    For each batch, features are taken from `layer_index`. If `magnitude` != 0 the inputs
    are first moved by a signed-gradient step that increases the Gaussian score of the
    closest class. The score is then max_c -0.5 * (f - mu_c)^T P (f - mu_c) of the
    (possibly perturbed) features.

    Args:
        model: trained model; switched to eval mode here. Its parameters' .grad are not touched.
        test_loader: iterable of (pep, cdr3b, target) batches; target is ignored.
        num_classes: number of class Gaussians (first rows of sample_mean[layer_index]).
        sample_mean: per-layer class-mean tensors, indexed [layer_index][class].
        precision: per-layer shared precision matrices, indexed [layer_index].
        layer_index: 0 = latent mu, otherwise classification logits.
        magnitude: size of the input perturbation; 0 disables it (no gradient pass).
        mean, scale: arrays whose ratio mean/scale normalises the gradient sign.
            NOTE(review): Lee et al. divide by the input std; mean/scale looks unusual -- confirm.
        cdr3b_only: score the CDR3b encoder's mu only (requires layer_index == 0).

    Returns:
        list with one NumPy float per test sample (higher = closer to some class mean).
    '''
    model.eval()
    class_means = sample_mean[layer_index]
    prec = precision[layer_index]
    # Loop-invariant; only needed when inputs are perturbed.
    grad_norm = torch.tensor(mean / scale).float() if magnitude != 0 else None

    Mahalanobis = []
    for pep, cdr3b, target in test_loader:
        if magnitude == 0:
            # x - 0 * step == x, so skip the gradient pass entirely. This also avoids
            # 0 * inf = NaN when mean/scale contains zeros.
            pep_in, cdr3b_in = pep, cdr3b
        else:
            pep_req = pep.detach().requires_grad_(True)
            cdr3b_req = cdr3b.detach().requires_grad_(True)
            features = _mahalanobis_features(model, pep_req, cdr3b_req, layer_index, cdr3b_only)

            # Input pre-processing: push the inputs towards the closest class Gaussian.
            with torch.no_grad():
                closest = _class_gaussian_scores(features, class_means, prec, num_classes).max(dim=1).indices
            centered = features - class_means.index_select(0, closest)
            loss = torch.mean(0.5 * ((centered @ prec) * centered).sum(dim=1))

            # torch.autograd.grad returns input gradients without accumulating into the
            # model parameters' .grad (loss.backward() would, on every batch).
            if cdr3b_only:
                (grad_cdr3b,) = torch.autograd.grad(loss, [cdr3b_req])
                pep_in = pep  # not used by the cdr3b-only encoder
            else:
                grad_pep, grad_cdr3b = torch.autograd.grad(loss, [pep_req, cdr3b_req])
                pep_in = _signed_step(pep, grad_pep, magnitude, grad_norm)
            cdr3b_in = _signed_step(cdr3b, grad_cdr3b, magnitude, grad_norm)

        with torch.no_grad():
            noise_features = _mahalanobis_features(model, pep_in, cdr3b_in, layer_index, cdr3b_only)
            # Average any trailing dims beyond the second (no-op for 2-D features).
            noise_features = noise_features.reshape(noise_features.size(0), noise_features.size(1), -1).mean(dim=2)
            best = _class_gaussian_scores(noise_features, class_means, prec, num_classes).max(dim=1).values

        Mahalanobis.extend(best.cpu().numpy())

    return Mahalanobis


In [None]:
df_in = pd.concat([
    pd.read_csv(DATA_BASE + 'alpha-beta-splits/alpha-beta.csv'),
    pd.read_csv(DATA_BASE + 'alpha-beta-splits/beta.csv')
    ])

df_out = pd.read_csv(DATA_BASE + 'vdjdb/vdjdb-2021-09-05/mouse-macaco.csv')

N_CLASSES = 2  # class-conditional Gaussians, one per value of the binding label


def _fit_class_gaussians(model, ds, num_classes=N_CLASSES):
    """Fit class-conditional Gaussians with one shared covariance on a (training) dataset.

    Samples are grouped by the argmax of `ds.gt` and each group is passed once through
    `model` without building an autograd graph. Returns `(sample_mean, precision)` in the
    layout `get_Mahalanobis_score` expects: index 0 holds the statistics of model output [0]
    (latent mu), index 1 those of output [2] (pre-softmax logits); `sample_mean[k]` has one
    row per class and `precision[k]` is the inverse of the shared empirical covariance.
    """
    class_ids = ds.gt.argmax(dim=1)
    mu_per_class, logits_per_class = [], []
    with torch.no_grad():
        for c in range(num_classes):
            mask = class_ids == c
            outputs = model(pep=ds.pep[mask].cpu(), cdr3b=ds.cdr3b[mask].cpu(), cdr3a=None)
            mu_per_class.append(outputs[0])
            logits_per_class.append(outputs[2])

    sample_mean, precision = [], []
    for per_class in (mu_per_class, logits_per_class):
        class_means = torch.stack([feats.mean(dim=0) for feats in per_class])
        # Centre each class on its own mean, then estimate a single covariance for all classes.
        centered = torch.cat([feats - m for feats, m in zip(per_class, class_means)], dim=0)
        cov_estimator = sklearn.covariance.EmpiricalCovariance(assume_centered=False)
        cov_estimator.fit(centered.cpu().numpy())
        sample_mean.append(class_means.float())
        precision.append(torch.from_numpy(cov_estimator.precision_).float())
    return sample_mean, precision


for experiment in range(5):  # 5 independent train/test splits
    set_random_seed(experiment)

    # Re-create exactly the splits used at training time (same random_state).
    df_train, df_test_in = train_test_split(df_in.copy(), test_size=0.2, random_state=experiment)
    df_test_out = df_out.copy()  # copy: do not overwrite df_out['sign'] below
    df_test_in = df_test_in.sample(n=len(df_test_out))
    scaler = TCRDataset(df_train.copy(), torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, softmax=True).scaler

    # In the OOD test set, 'sign' is the OOD label: 1 = in-distribution, 0 = out-of-distribution.
    df_test_in['sign'] = 1
    df_test_out['sign'] = 0
    df_test = pd.concat([df_test_in, df_test_out])
    ds_test = TCRDataset(df_test, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler, softmax=True)

    # Only the training part is needed here: the class statistics come from the model's training data.
    df_train, _ = train_test_split(df_train, test_size=0.2, stratify=df_train.sign, random_state=experiment)
    ds_train = TCRDataset(df_train, torch.device("cpu"), cdr3b_col='tcrb', cdr3a_col=None, scaler=scaler, softmax=True)

    # test loader for Mahalanobis distance
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=batch_size
    )

    model = MVIB.from_checkpoint(checkpoints[experiment], torch.device("cpu"))
    model.eval()

    # Independent of layer_index / magnitude, so computed once per experiment.
    sample_mean, precision = _fit_class_gaussians(model, ds_train)

    for layer_index in [0]:  # 0 = mu - means of latent Gaussian; 1 = pre-softmax logits
        for magnitude in [0]:  # different magnitude for gradient pre-processing
            # compute Mahalanobis scores for all test samples (peptide+CDR3b joint posterior)
            maha = get_Mahalanobis_score(
                model, test_loader, N_CLASSES, sample_mean, precision, layer_index, magnitude,
                scaler.mean_, scaler.scale_, cdr3b_only=False
            )
            df_test['prediction_'+str(experiment)] = maha

            # save results for further analysis
            df_test.to_csv(
                RESULTS_BASE + f"non-human.mvib.{joint_posterior}.beta-{beta}.maha.layer-{layer_index}.epsilon-{magnitude}.rep-{experiment}.csv",
                index=False
            )
