**@author: James V. Talwar**<br>

# Computing DARTH Scores: VADEr Interpretability  

**About**: This notebook serves as a tutorial/template for generating VADEr's paired interpretability metric: **D**irected **A**ttention **R**elevance from **T**ransformer **H**euristics (**DARTH**) scores. Specifically, it can be used to calculate and save DARTH scores for a trained VADEr model. The only changes required are to the paths and parameters specified below.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import yaml
import pandas as pd
import numpy as np
from collections import defaultdict
from torch.utils.data import DataLoader

import sys
sys.path.append('./src/')

from VADErData import SNP_Dataset
from VADErDataUtils import GenerateChromosomePatchMask
from vader import VADEr

import logging
logging.getLogger().setLevel(logging.INFO)

In [2]:
logger = logging.getLogger()
console = logging.StreamHandler()
logger.addHandler(console)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
def LoadConfig(path):
    '''
    About: Load configuration from yaml-config path

    Input(s): path: String corresponding to path to yaml config file for training.
    '''
    #Context manager closes the file handle once the config has been parsed
    with open(path, 'r') as configFile:
        return yaml.safe_load(configFile)
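Later cells index into the config directly for some keys (e.g., `config["model_params"]["model_dim"]`) and use `.get()` (which returns `None` when a key is absent) for others. A minimal sketch of a hypothetical helper, not part of the VADEr source, that lists the directly indexed keys missing from a loaded config so that a mismatched config fails early with a readable message. The key list is transcribed from the cells in this notebook:

```python
#Keys this notebook reads with [] indexing; keys read with .get() are treated as optional
REQUIRED_CONFIG_KEYS = {
    "dataset": [],
    "model_params": ["model_dim", "mlp_scale", "num_transformer_blocks", "num_attention_heads",
                     "prediction_dims", "patch_projection_dropout", "model_dropout",
                     "age_inclusion", "aggregation", "non_linear_patch_projection"],
    "train_and_checkpoint": ["batch_size", "checkpoint_metric"],
}

def Missing_Config_Keys(config, required = REQUIRED_CONFIG_KEYS):
    #Return "section" or "section.key" strings for every required entry absent from config
    missing = []
    for section, keys in required.items():
        block = config.get(section)
        if not isinstance(block, dict):
            #Section absent (or empty in the YAML, which loads as None): report the section itself
            missing.append(section)
            continue
        missing.extend(f"{section}.{key}" for key in keys if key not in block)
    return missing
```

After loading the config, `assert not Missing_Config_Keys(config), Missing_Config_Keys(config)` would stop the notebook before any data is loaded.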

**USER: Update the following paths/parameters:**
 - `modelSummaryPath`: Path to the directory of saved VADEr model checkpoints (named `VADEr_Epoch_<epoch>.pt`). This notebook selects the best-performing checkpoint by validation performance on the `checkpoint_metric` specified in your training config.
 - `trainingSummaryPath`: Path to the corresponding tab-separated VADEr training summary, indexed by epoch (i.e., training and validation set losses, accuracies, and AUCs).
 - `configPath`: Path to the VADEr config used for model training.
 - `DARTH_score_write_path`: Path to an existing directory to which the dataset's DARTH scores will be written.
 - `DARTH_score_file_name`: File name prefix (without file extension) for the DARTH scores of the dataset under investigation. Update as desired. Default: "DARTH_Scores"
 - `batchSizeForEvaluation`: The number of samples per batch (default: half the training batch size). Update this according to your GPU memory resources.
 - `number_workers`: The total number of workers to use for data loading. Update this according to your needs/resources.
 - `featherPath`: Path to the composite genotype feather file for the dataset under investigation.
 - `phenoPath`: Path to the phenotype file containing phenotype information for all individuals in the dataset under investigation.
 - `testSetPath`: Path to the dataset-specific ID file, which contains the IDs of all individuals in the dataset under investigation.

In [5]:
modelSummaryPath = "" #<-- UPDATE WITH RELEVANT PATH 
trainingSummaryPath = "" #<-- UPDATE WITH RELEVANT PATH 
configPath = "" #<-- UPDATE WITH RELEVANT PATH 
DARTH_score_write_path = "" #<-- UPDATE WITH RELEVANT PATH 
DARTH_score_file_name = "DARTH_Scores" #<-- Change if desired; default DARTH score write name

config = LoadConfig(configPath)
datasetParams = config["dataset"]
modelParams = config["model_params"]

batchSizeForEvaluation = config["train_and_checkpoint"]["batch_size"]//2 #<-- SCALE AS NEEDED GIVEN GPU MEMORY RESOURCES
number_workers = 16 #<-- UPDATE AS NEEDED ACCORDING TO AVAILABLE RESOURCES

featherPath = "" #<-- UPDATE WITH RELEVANT PATH 
phenoPath = "" #<-- UPDATE WITH RELEVANT PATH 
testSetPath = "" #<-- UPDATE WITH RELEVANT PATH 
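Since every path above is filled in by hand, a quick check before the (potentially hours-long) run can catch an unset or mistyped path. A minimal sketch of a hypothetical helper, not part of the VADEr source, that reports empty paths, missing files, and paths that should be existing directories:

```python
import os

def Unset_Or_Missing_Paths(paths, directories = ()):
    #paths: dict mapping a variable name to its path value
    #directories: names whose paths must be existing directories rather than files
    problems = []
    for name, path in paths.items():
        if not path:
            problems.append(f"{name}: not set")
        elif name in directories:
            if not os.path.isdir(path):
                problems.append(f"{name}: not an existing directory")
        elif not os.path.exists(path):
            problems.append(f"{name}: does not exist")
    return problems
```

In the notebook, one could pass a dict mapping each path variable name above to its value, with `directories = {"modelSummaryPath", "DARTH_score_write_path"}`, and assert that the returned list is empty.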

Load dataset:

In [6]:
testDataset = SNP_Dataset(featherFilePath = featherPath,
                          phenoFilePath = phenoPath,
                          idFilePath = testSetPath, 
                          snpSubsetPath = datasetParams.get("SNP_set"),
                          validMafSnpsPath = datasetParams.get("consistent_maf_SNPs"),
                          vaderPatchMappingPath = datasetParams.get("patch_mapping_path"),
                          trainingSetAgeStatsPath = datasetParams.get("age_train_stats"), 
                          sparsePatchThreshold = datasetParams.get("sparse_patch_threshold"),
                          enableShifting = False)

Valid age path given: /cellar/users/jtalwar/projects/BetterRiskScores/InSNPtion/Galbatorix/DucksInARow/TrainingSetStatistics/ELLIPSE/TrainingSetAgeStats.pkl
   Returning z-scored ages in loader...
1607 SNPs exist across the full 5e-4 dataset with INF or NULL values. Removing these now... Unremoved SNP set size is 110292
Cleaned SNP set size after removal of invalid SNPs is 108685
Filtering SNP set for MAF and genotype consistent SNPs...
Cleaned SNP set size after filtering incompatible genotype and MAF discrepancy SNPs is 12017
100%|██████████| 1204/1204 [00:00<00:00, 3932.72it/s]


In [7]:
loader = DataLoader(dataset = testDataset, pin_memory = True, shuffle = False, batch_size = batchSizeForEvaluation, num_workers = number_workers)

Identify, instantiate, and load best trained VADEr model:

In [8]:
metric_for_selection = config["train_and_checkpoint"]["checkpoint_metric"]
summary = pd.read_csv(trainingSummaryPath, sep = "\t", index_col = 0)
bestEpoch = summary[metric_for_selection].idxmax()
logger.info(f"Best {metric_for_selection} occurs at epoch {bestEpoch} with value {summary[metric_for_selection].max():.5f}")

Best Val_AUC occurs at epoch 32 with value 0.74864
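The selection cell above uses `idxmax()`, which assumes that larger values of the checkpoint metric are better (true for an AUC such as `Val_AUC`). If a training config instead checkpoints on a loss, the best epoch is the minimum. A small sketch that makes the direction explicit, using a toy summary table indexed by epoch (`Val_Loss` and the toy values are hypothetical):

```python
import pandas as pd

def Select_Best_Epoch(summary, metric, higher_is_better = True):
    #Return the epoch (index label) and value of the best checkpoint for `metric`
    column = summary[metric]
    bestEpoch = column.idxmax() if higher_is_better else column.idxmin()
    return bestEpoch, column.loc[bestEpoch]

#Toy training summary indexed by epoch
toy_summary = pd.DataFrame({"Val_AUC": [0.70, 0.75, 0.72], "Val_Loss": [0.60, 0.58, 0.50]}, index = [1, 2, 3])
```

In the notebook this would be called as `Select_Best_Epoch(summary, metric_for_selection, higher_is_better = ...)`, with the direction set to match the configured `checkpoint_metric`.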


In [9]:
num_registers = modelParams.get("num_registers")

vaderModel = VADEr(patchSizes = testDataset.patchSizes,
                   modelDim = modelParams["model_dim"],
                   mlpDim = modelParams["model_dim"] * modelParams["mlp_scale"],
                   depth = modelParams["num_transformer_blocks"],
                   attnHeads = modelParams["num_attention_heads"],
                   attnHeadDim = modelParams["model_dim"]//modelParams["num_attention_heads"],
                   multitaskOutputs = modelParams["prediction_dims"],
                   clumpProjectionDropout = modelParams["patch_projection_dropout"],
                   dropout = modelParams["model_dropout"], 
                   ageInclusion = modelParams["age_inclusion"],
                   aggr = modelParams["aggregation"],
                   context = modelParams.get("cls_representation"),
                   patchProjectionActivation = modelParams["non_linear_patch_projection"],
                   patchLayerNorm = modelParams.get("patch_layer_norm"),
                   trainingObjective = "cross_entropy",
                   attention = modelParams.get("attention"),
                   numRegisters = num_registers,
                   ffActivation = modelParams.get("mlp_method"),
                   contrastive_projection_net_dim = None) 

Implementing LEARNABLE CLS token representation.
Enabling 8 registers
Implementing SwiGLU... correcting mlpDim to 2048 to keep number params consistent
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True


In [10]:
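modelPath = os.path.join(modelSummaryPath, f"VADEr_Epoch_{bestEpoch}.pt")
#map_location loads tensors directly onto `device` (needed when a GPU-saved checkpoint is opened on a CPU-only machine)
vaderModel.load_state_dict(torch.load(modelPath, map_location = device)["modelStateDict"])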

<All keys matched successfully>

In [11]:
vaderModel.to(device);

Define the main function (and its helper functions) used to generate DARTH scores:
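In equation form (notation introduced here for illustration), the functions below compute the following. Let $z = \sum_i y_i$ be the disease logits summed over the batch, let $A^{(l)} \in \mathbb{R}^{b \times h \times n \times n}$ be the attention map of transformer block $l = 1, \dots, L$, and let $\nabla A^{(l)} = \partial z / \partial A^{(l)}$. With $s = 1$ by default and $s = -1$ when `conjugate = True`, each sample is processed as

$$
\bar{A}^{(l)} = \frac{1}{h} \sum_{k=1}^{h} \left( s \, \nabla A^{(l)}_{k} \odot A^{(l)}_{k} \right)^{+},
\qquad
R^{(0)} = I_n,
\qquad
R^{(l)} = R^{(l-1)} + \bar{A}^{(l)} R^{(l-1)}
$$

where $(\cdot)^{+}$ clamps negative entries to zero. Assuming samples do not interact in the forward pass, the slice of $\nabla A^{(l)}$ for sample $i$ equals $\partial y_i / \partial A^{(l)}_i$. The DARTH score of patch token $j$ is the CLS row of the final relevance matrix, $R^{(L)}_{0,j}$ for $j = 1, \dots, n - r - 1$, where index 0 is the CLS token and the last $r$ positions are registers (the layout implied by the slicing in `Generate_Relevance`). If the commented-out sigmoid objective is used instead of raw logits, each sample's gradient is scaled by $\sigma(y_i)\,(1 - \sigma(y_i))$.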

In [12]:
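def Average_Heads(attention, gradient, conjugate = False, **kwargs):
    #Gradient-weighted attention: elementwise product of the attention map and dz/dA
    attention_map = attention * gradient #[b, h, n, n]
    if conjugate:
        #Flip the sign so that contributions with negative gradient-weighted attention are kept instead
        attention_map *= -1
    #Keep only positive contributions, then average over attention heads
    attention_map = attention_map.clamp(min = 0).mean(dim = 1) #[b, n, n]
    
    return attention_map

def Self_Attention_Rule(attention_map, relevance_map):
    #Propagate relevance through one attention layer: [b, n, n] @ [b, n, n] -> [b, n, n]
    return torch.matmul(attention_map, relevance_map)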

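def Generate_Relevance(model, batch_size, num_tokens, **kwargs):
    #Identify number of registers used in VADEr model (0 if the model has none)
    try:
        num_registers = model.registers.size(1)
    except AttributeError:
        num_registers = 0

    #Create relevance on the model's device
    device = next(model.parameters()).device
    #Initialize relevance as the identity: each token starts out relevant only to itself
    R = torch.eye(num_tokens, device = device).unsqueeze(0).repeat(batch_size, 1, 1) #[b, n, n]
    for block in model.transformer.blocks:
        #block[1] exposes this block's stored attention map and its gradient from the backward pass
        attention = block[1].get_attention_map().detach() #[b, h, n, n]
        attention_gradient = block[1].get_attention_gradients() #[b, h, n, n]
        block[1].reset_attention_attributes() #clear stored tensors before the next batch
        attention_map = Average_Heads(attention = attention, gradient = attention_gradient, **kwargs)
        R += Self_Attention_Rule(attention_map = attention_map, relevance_map = R)
    
    #Row 0 is the CLS token; drop its own column and the trailing register columns
    #(ending the slice at num_tokens - num_registers also covers num_registers == 0)
    return R[:, 0, 1:num_tokens - num_registers].detach().to("cpu")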
    
def Generate_Transformer_Explainability(model, loader, device, mask = None, **kwargs):
    model.eval() #disable dropout; gradients are still required, so this is not wrapped in torch.no_grad()
    
    batch_attributions = [] #per-batch [b, num_patches] tensors, concatenated at the end
    
    if "conjugate" in kwargs:
        assert kwargs["conjugate"] in {True, False}, "invalid option for conjugate. conjugate must be in {True, False}."
        logger.info(f"DARTH score conjugate status: {kwargs['conjugate']}")
        
    for i, (patchBatch, diseaseStatusBatch, ancestryBatch, fHBatch, zAgeBatch) in enumerate(loader):
        model.zero_grad()
        gpuClumpBatch = {k:v.to(device) for k,v in patchBatch.items()} #features
        
        if model.includeAge: #including age - need to pass in more than clump dictionary
            output = model(dictOfClumps = gpuClumpBatch, mask = mask, age_batch = zAgeBatch.to(device), extract_attention = True)
        else:
            output = model(dictOfClumps = gpuClumpBatch, mask = mask, extract_attention = True)
        
        #Sum logits over the batch: assuming samples do not interact in the forward pass, dz/dA for each
        #sample equals the gradient of that sample's own logit; logits are used for cleaner gradients
        z = output["disease"].sum()
        
        #z = torch.sigmoid(output["disease"]).sum() #<-- Sigmoid instead of logits if desired: scales each derivative by g(x)*(1 - g(x))
        
        z.backward()
        
        batch_attribution = Generate_Relevance(model = model, 
                                               batch_size = output['disease'].shape[0], 
                                               num_tokens = model.transformer.blocks[0][1].get_attention_map().size(-1),
                                               **kwargs)
        
        batch_attributions.append(batch_attribution)
        
        # Clean up to prevent memory buildup
        torch.cuda.empty_cache()
        
    return torch.cat(batch_attributions)
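To make the propagation concrete without a trained model, here is a small NumPy sketch of the same steps as `Average_Heads`, `Self_Attention_Rule`, and `Generate_Relevance`, run on toy attention maps and gradients that stand in for what VADEr's attention modules return. The function names are hypothetical, and like the cell above, the sketch assumes the CLS token sits at index 0 and any registers occupy the last positions:

```python
import numpy as np

def Average_Heads_np(attention, gradient, conjugate = False):
    #Gradient-weighted attention [b, h, n, n]; conjugate flips the sign before clamping
    weighted = attention * gradient
    if conjugate:
        weighted = -weighted
    #Keep positive contributions and average over heads -> [b, n, n]
    return np.clip(weighted, 0, None).mean(axis = 1)

def DARTH_Relevance_np(attentions, gradients, num_registers = 0, conjugate = False):
    #attentions/gradients: one [b, h, n, n] array per transformer block, in forward order
    b, _, n, _ = attentions[0].shape
    R = np.repeat(np.eye(n)[None, :, :], b, axis = 0) #[b, n, n] identity per sample
    for A, G in zip(attentions, gradients):
        R = R + Average_Heads_np(A, G, conjugate) @ R #same update as Self_Attention_Rule
    #CLS row, dropping the CLS column and the trailing register columns
    return R[:, 0, 1:n - num_registers]

#Toy example: 1 sample, 2 heads, 4 tokens (CLS, 2 patches, 1 register), 2 blocks
rng = np.random.default_rng(0)
attentions = [rng.random((1, 2, 4, 4)) for _ in range(2)]
gradients = [rng.standard_normal((1, 2, 4, 4)) for _ in range(2)]
scores = DARTH_Relevance_np(attentions, gradients, num_registers = 1)
print(scores.shape) #(1, 2): one score per patch token
```

Passing `conjugate = True` keeps the contributions whose gradient-weighted attention is negative, mirroring the `conjugate` option of `Average_Heads`.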

Get model attributions:

 - **Nota Bene**: Depending on the size of your dataset and your chosen `batchSizeForEvaluation`, the computation in the next cell may take on the order of hours.

In [13]:
vader_patch_attribution = Generate_Transformer_Explainability(model = vaderModel, loader = loader, device = device)

Save attributions:

In [17]:
testIDs = pd.read_csv(testSetPath, header = None, dtype = str)[0].tolist()

In [15]:
numericalOrderedPatches = sorted([int(patch.split("p")[1]) for patch in testDataset.patchSizes])
patch_columns = ["patch" + str(el) for el in numericalOrderedPatches]
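The cell above sorts patches numerically rather than by name. A small sketch of why this matters, using hypothetical patch keys of the form `p<number>` (the parsing rule only requires that the text between the first `p` and the next `p`, or the end of the key, is an integer). Note that the column labels are only correct if VADEr orders its patch tokens the same way, which this notebook assumes:

```python
def Ordered_Patch_Columns(patch_names):
    #Parse the integer after the first "p" (same rule as the cell above) and sort numerically
    patch_numbers = sorted(int(name.split("p")[1]) for name in patch_names)
    return ["patch" + str(number) for number in patch_numbers]

hypothetical_patch_names = ["p10", "p2", "p1"] #hypothetical keys; real keys come from testDataset.patchSizes
print(sorted(hypothetical_patch_names)) #lexicographic: ['p1', 'p10', 'p2']
print(Ordered_Patch_Columns(hypothetical_patch_names)) #numerical: ['patch1', 'patch2', 'patch10']
```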

In [16]:
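#Rows follow the loader order (shuffle = False), which is assumed to match the order of IDs in testSetPath
assert vader_patch_attribution.shape == (len(testIDs), len(patch_columns)), "attribution shape does not match (IDs, patches)"
vader_interpretability_df = pd.DataFrame(vader_patch_attribution.detach().numpy(), index = testIDs, columns = patch_columns)
vader_interpretability_df.to_csv(os.path.join(DARTH_score_write_path, f"{DARTH_score_file_name}.tsv"), sep = "\t")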
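Because `testIDs` are read as strings, reading the saved TSV back with default pandas settings can silently change the IDs (e.g., parsing "007" as the integer 7). A minimal sketch, using toy scores with hypothetical IDs, of a reader that keeps IDs as strings for a file written the same way as above (unnamed index, tab-separated):

```python
import os
import tempfile
import pandas as pd

def Read_DARTH_Scores(path):
    #Read every column as text first so IDs such as "007" keep their leading zeros
    scores = pd.read_csv(path, sep = "\t", dtype = str)
    #The first column holds the (unnamed) ID index written by DataFrame.to_csv
    scores = scores.set_index(scores.columns[0]).astype(float)
    scores.index.name = None
    return scores

#Toy scores with hypothetical IDs and patch columns, written like the cell above
toy_scores = pd.DataFrame([[0.1, 0.2], [0.3, 0.4]], index = ["007", "010"], columns = ["patch1", "patch2"])

with tempfile.TemporaryDirectory() as tmpDir:
    toyPath = os.path.join(tmpDir, "DARTH_Scores.tsv")
    toy_scores.to_csv(toyPath, sep = "\t")
    naive = pd.read_csv(toyPath, sep = "\t", index_col = 0) #IDs inferred as integers
    recovered = Read_DARTH_Scores(toyPath)
```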