**@author: James V. Talwar**

# VADEr Inference: Computing Polygenic Risk Scores with VADEr

**About:** This notebook provides a detailed walkthrough of how to generate polygenic risk scores (PRSs) with VADEr, and can be used either in standalone format (i.e., as is) or as a template/tutorial for generating a VADEr inference script. In particular, this notebook covers the following:
 - VADEr model instantiation
 - Loading the optimal (i.e., best validation set performance) trained-checkpointed model
 - Generating PRSs with VADEr for a dataset of interest
 - Calculating performance metrics 

In [1]:
import os
import yaml
import pandas as pd
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import sys 
sys.path.append('./src/') #path to VADEr, SNP_Dataset, utility functions etc...

from VADErData import SNP_Dataset
from VADErDataUtils import GenerateChromosomePatchMask
from vader import VADEr
from MetricUtils import BinaryClassAccuracy, MultiClassAccuracy, Calc_ROC_AUC 

import logging
logging.getLogger().setLevel(logging.INFO)

In [2]:
logger = logging.getLogger()
console = logging.StreamHandler()
logger.addHandler(console)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
!nvidia-smi

Sat Jun  7 17:59:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A30                     On  |   00000000:81:00.0 Off |                    0 |
| N/A   34C    P0             29W /  165W |       4MiB /  24576MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
def LoadConfig(path):
    '''
    About: Load configuration from yaml-config path

    Input(s): path: String corresponding to path to yaml config file for training.
    '''
    with open(path, 'r') as config_file: #context manager closes the file once parsing is done
        return yaml.load(config_file, Loader = yaml.SafeLoader)

**USER: Update the following paths/parameters below:**
 - `config_path`: Path to config employed for VADEr training.
 - `prediction_write_path`: Path to an existing directory to which you want to write the dataset's VADEr predictions/PRSs.
 - `number_workers`: The total number of workers to use for dataloading. Update this according to your needs/resources.
 - `prediction_file_name`: (Optional) File name for the VADEr predictions/PRSs. The name is used exactly as given (no extension is appended), so include one (e.g., `.tsv`) if you override it. By default this file name matches the base name of your config's `training_summary_path`.
 - `metric`: (Optional) String corresponding to the metric by which you want to load the best performing k-checkpointed VADEr model. By default `metric` matches your config-specified `checkpoint_metric`, which ensures that the best performing model by your defined metric is loaded (as this was the metric by which k-checkpointing was conducted). Valid options: `{Val_Disease_Accuracy, Val_AUC, Val_Loss, Train_Disease_Accuracy, Train_AUC, Train_Loss}`.
 - `feather_path`: File path to [composite-level](https://github.com/jvtalwar/DARTH_VADEr/wiki/Enabling-Dataloading:-Data-Processing-and-Expected-Formatting#which-should-i-choose---composite-level-vs-individual-level-feathers-important-considerations) test set genotype feather file.
   - Writing/reading individual-level feathers can also be employed here; you will just need to define a dataset-specific (e.g., test set) `cached_feather_path` below and pass it to `cache_write_path` when initializing the SNP_Dataset object (variable name `dataset` below).
 - `pheno_path`: File path to all dataset (e.g., test set) [phenotypes](https://github.com/jvtalwar/DARTH_VADEr/wiki/Enabling-Dataloading:-Data-Processing-and-Expected-Formatting#formatting-phenotypes).
 - `test_ids_path`: File path to all dataset (e.g., test set) IDs.
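Several of these paths are only used cells later (the dataset paths when `SNP_Dataset` is built, `prediction_write_path` only when predictions are written), so a typo can surface late in the run. The helper below, `check_inference_paths`, is a hypothetical stdlib-only sketch (not part of VADEr) that checks existence up front; it does not validate file contents or formats.

```python
import os


def check_inference_paths(config_path, prediction_write_path, feather_path, pheno_path, test_ids_path):
    """Hypothetical helper: return a list of path problems (an empty list means all paths exist)."""
    problems = []
    # Input files must already exist; None is checked first because os.path.isfile(None) raises TypeError
    for name, path in [("config_path", config_path), ("feather_path", feather_path),
                       ("pheno_path", pheno_path), ("test_ids_path", test_ids_path)]:
        if path is None or not os.path.isfile(path):
            problems.append(f"{name} is not an existing file: {path!r}")
    # Predictions are written into this directory; the notebook does not create it
    if prediction_write_path is None or not os.path.isdir(prediction_write_path):
        problems.append(f"prediction_write_path is not an existing directory: {prediction_write_path!r}")
    return problems
```

Calling it right after the paths are defined and raising on a non-empty result keeps failures close to where the paths are set.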

In [6]:
config_path = "<PATH_TO_TRAINING_CONFIG_YAML>" #<-- UPDATE WITH RELEVANT PATH 
prediction_write_path = "../Predictions" #<-- UPDATE WITH RELEVANT PATH (ensure directory exists) 
number_workers = 16 #<-- UPDATE AS NEEDED ACCORDING TO AVAILABLE RESOURCES

config = LoadConfig(config_path)

dataset_params = config["dataset"]
model_params = config["model_params"]
train_and_checkpoint_params = config["train_and_checkpoint"]

In [7]:
prediction_file_name = os.path.basename(train_and_checkpoint_params["training_summary_path"])
metric = train_and_checkpoint_params["checkpoint_metric"]

valid_metric = {"Val_Disease_Accuracy", "Val_AUC", "Val_Loss", "Train_Disease_Accuracy", "Train_AUC", "Train_Loss"}

assert metric in valid_metric
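The notebook reads some configuration entries by direct indexing (e.g., `model_params["model_dim"]`), which raises `KeyError` if absent, and others with `.get()`, which silently returns `None`. The sketch below lists the directly indexed keys used in this notebook and reports any that are missing; `REQUIRED_CONFIG_KEYS` and `missing_config_keys` are hypothetical names, and the sketch assumes `config` is the plain dictionary returned by `LoadConfig`.

```python
# Keys indexed directly (config[section][key]) in this notebook; keys read via .get() are optional.
REQUIRED_CONFIG_KEYS = {
    "dataset": [],  # every dataset entry is read with .get(), but the section itself is indexed
    "model_params": ["model_dim", "mlp_scale", "num_transformer_blocks", "num_attention_heads",
                     "prediction_dims", "patch_projection_dropout", "model_dropout",
                     "age_inclusion", "aggregation", "non_linear_patch_projection"],
    "train_and_checkpoint": ["training_summary_path", "checkpoint_metric",
                             "model_checkpoint_path", "batch_size"],
}


def missing_config_keys(config):
    """Return 'section' or 'section.key' strings for every required entry absent from config."""
    missing = []
    for section, keys in REQUIRED_CONFIG_KEYS.items():
        if section not in config:
            missing.append(section)
            continue
        missing.extend(f"{section}.{key}" for key in keys if key not in config[section])
    return missing
```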

In [8]:
#Define all dataset (e.g., test set) paths:
feather_path = "<PATH_TO_TEST_SET_FEATHER>" #<-- UPDATE WITH RELEVANT PATH 
pheno_path = "<PATH_TO_PHENOTYPES>" #<-- UPDATE WITH RELEVANT PATH 
test_ids_path = "<PATH_TO_TEST_SET_IDS>" #<-- UPDATE WITH RELEVANT PATH 

#cached_feather_path = ... <-- Define desired write directory if during PRS generation want to write/read from individual-level feathers

Identify the best-performing VADEr model according to the defined `metric`:

In [9]:
metric_alignment = metric == train_and_checkpoint_params["checkpoint_metric"]

performance_summary = pd.read_csv(train_and_checkpoint_params["training_summary_path"], sep = "\t", index_col = 0)

if not metric_alignment:
    k_checkpointed_models = [int(model.split(".")[0].split("_")[-1]) for model in os.listdir(train_and_checkpoint_params["model_checkpoint_path"])]
    performance_summary = performance_summary.loc[k_checkpointed_models, :]
    
if metric in {"Val_Loss", "Train_Loss"}: #minimization metric
    best_performing_model = performance_summary[metric].idxmin()
    
else:
    best_performing_model = performance_summary[metric].idxmax()
    
logger.info(f"Best trained VADEr model by defined metric {metric} occurs at epoch {best_performing_model} with the following training/validation set performance:")
pd.DataFrame(performance_summary.loc[best_performing_model, :]).T

Best trained VADEr model by defined metric Val_AUC occurs at epoch 32 with the following training/validation set performance:


| Epoch | Train_Disease_Accuracy | Train_Ancestry_Accuracy | Train_AUC | Train_Loss | Val_Disease_Accuracy | Val_Ancestry_Accuracy | Val_AUC | Val_Loss |
|---|---|---|---|---|---|---|---|---|
| 32 | 0.7077 | NaN | 0.765795 | 0.563552 | 0.680015 | NaN | 0.74864 | 0.625232 |
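When `metric` differs from `checkpoint_metric`, the selection cell restricts the search to epochs whose checkpoints still exist on disk, because only the top-k checkpoints by `checkpoint_metric` were kept. The stdlib sketch below restates that rule with hypothetical helpers: `checkpointed_epochs` assumes the `VADEr_Epoch_{epoch}.pt` naming used when loading the model and, unlike the `split`-based parsing in the cell, skips unrelated files; `best_epoch` takes the summary as a plain `{epoch: {metric: value}}` mapping.

```python
import re

MINIMIZED_METRICS = {"Val_Loss", "Train_Loss"}  # every other valid metric is maximized


def checkpointed_epochs(filenames):
    """Parse epochs from names like VADEr_Epoch_32.pt, ignoring any other files."""
    epochs = []
    for name in filenames:
        match = re.fullmatch(r"VADEr_Epoch_(\d+)\.pt", name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


def best_epoch(summary, metric, allowed_epochs=None):
    """Return the epoch with the best value of metric, optionally restricted to allowed_epochs."""
    candidates = {epoch: row[metric] for epoch, row in summary.items()
                  if allowed_epochs is None or epoch in allowed_epochs}
    pick = min if metric in MINIMIZED_METRICS else max
    # Like idxmin/idxmax, ties resolve to the first epoch in iteration order
    return pick(candidates, key=candidates.get)
```

In the notebook, the equivalent inputs would be `os.listdir(train_and_checkpoint_params["model_checkpoint_path"])` and `performance_summary.to_dict(orient = "index")`.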


Initialize the SNP_Dataset object and dataloader:

In [10]:
dataset = SNP_Dataset(featherFilePath = feather_path,
                      phenoFilePath = pheno_path,
                      idFilePath = test_ids_path, 
                      snpSubsetPath = dataset_params.get("SNP_set"),
                      validMafSnpsPath = dataset_params.get("consistent_maf_SNPs"),
                      vaderPatchMappingPath = dataset_params.get("patch_mapping_path"),
                      trainingSetAgeStatsPath = dataset_params.get("age_train_stats"), 
                      sparsePatchThreshold = dataset_params.get("sparse_patch_threshold"),
                      enableShifting = False)

Valid age path given: /cellar/users/jtalwar/projects/BetterRiskScores/InSNPtion/Galbatorix/DucksInARow/TrainingSetStatistics/ELLIPSE/TrainingSetAgeStats.pkl
   Returning z-scored ages in loader...
1607 SNPs exist across the full 5e-4 dataset with INF or NULL values. Removing these now... Unremoved SNP set size is 110292
Cleaned SNP set size after removal of invalid SNPs is 108685
Filtering SNP set for MAF and genotype consistent SNPs...
Cleaned SNP set size after filtering incompatible genotype and MAF discrepancy SNPs is 12017
100%|██████████| 1204/1204 [00:00<00:00, 3603.55it/s]


In [11]:
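dataloader = DataLoader(dataset = dataset, 
                        pin_memory = True, 
                        shuffle = False, #preserve dataset order so predictions align with dataset.datasetIDs
                        batch_size = train_and_checkpoint_params["batch_size"] * max(torch.cuda.device_count(), 1) * 2, #max(..., 1) avoids a batch size of 0 on CPU-only machines
                        num_workers = number_workers)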
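`Eval_Model` pairs the concatenated predictions with `dataloader.dataset.datasetIDs` purely by position, which is valid only because the loader uses `shuffle = False`: batches arrive in dataset order and are concatenated in that order. The stdlib sketch below illustrates that invariant; `batched` stands in for the DataLoader and `toy_model` for VADEr (both hypothetical), while `inference_batch_size` mirrors the notebook's batch-size expression with at least one device, since `torch.cuda.device_count()` returns 0 on a CPU-only machine.

```python
def batched(items, batch_size):
    """Yield consecutive batches in their original order, like a DataLoader with shuffle=False."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def inference_batch_size(train_batch_size, n_gpus, multiplier=2):
    """Training batch size x number of devices (at least 1) x multiplier."""
    return train_batch_size * max(n_gpus, 1) * multiplier


def toy_model(id_batch):
    """Stand-in for VADEr: deterministically 'scores' each ID by its numeric suffix."""
    return [int(sample_id[2:]) for sample_id in id_batch]


ids = [f"ID{i}" for i in range(10)]
scores = []
for id_batch in batched(ids, inference_batch_size(train_batch_size=3, n_gpus=0)):
    scores.extend(toy_model(id_batch))  # concatenated in loader order, like torch.cat in Eval_Model

# Position i of the concatenated scores belongs to ids[i] only because batch order was preserved
aligned = dict(zip(ids, scores))
```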

Instantiate and load the best trained VADEr model:

In [12]:
masking = train_and_checkpoint_params.get("masking") 
mask = None
if masking == "chrom":
    mask = GenerateChromosomePatchMask(patch_to_chrom_mapping_path = dataset_params.get("patch_to_chrom_mapping"),
                                       feature_patches = dataset.patchSizes)
    mask = mask.to(device)

attention = model_params.get("attention")
patch_layer_norm = model_params.get("patch_layer_norm")
num_registers = model_params.get("num_registers")

if attention is None:
    attention = "MHA"

if patch_layer_norm is None:
    patch_layer_norm = attention == "LSA"

if num_registers is None:
    num_registers = 0

vader_model = VADEr(patchSizes = dataset.patchSizes,
                   modelDim = model_params["model_dim"],
                   mlpDim = model_params["model_dim"] * model_params["mlp_scale"],
                   depth = model_params["num_transformer_blocks"],
                   attnHeads = model_params["num_attention_heads"],
                   attnHeadDim = model_params["model_dim"]//model_params["num_attention_heads"],
                   multitaskOutputs = model_params["prediction_dims"],
                   clumpProjectionDropout = model_params["patch_projection_dropout"],
                   dropout = model_params["model_dropout"], 
                   ageInclusion = model_params["age_inclusion"],
                   aggr = model_params["aggregation"],
                   context = model_params.get("cls_representation"),
                   patchProjectionActivation = model_params["non_linear_patch_projection"],
                   patchLayerNorm = patch_layer_norm,
                   trainingObjective = "cross_entropy",
                   attention = attention,
                   numRegisters = num_registers,
                   ffActivation = model_params.get("mlp_method"),
                   contrastive_projection_net_dim = None)

Implementing LEARNABLE CLS token representation.
Enabling 8 registers
Implementing SwiGLU... correcting mlpDim to 2048 to keep number params consistent
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
Implementing learnable temperature for MHA: True
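The instantiation cell derives `attnHeadDim = model_dim // num_attention_heads` and `mlpDim = model_dim * mlp_scale`. The logged correction to 2048 is consistent with the common SwiGLU convention of shrinking the hidden width by 2/3, since SwiGLU uses three weight matrices where a standard feed-forward block uses two. The sketch below reproduces that arithmetic for the values suggested by this run's file name (`dim768`, `he12`, `mlp4`); the helper names and the exact rounding VADEr applies are assumptions.

```python
def derived_dims(model_dim, num_heads, mlp_scale, swiglu=False):
    """Return (attn_head_dim, mlp_dim) derived from config values."""
    if model_dim % num_heads != 0:
        # model_dim // num_heads would silently floor, so heads * head_dim != model_dim
        raise ValueError("model_dim must be divisible by num_heads")
    attn_head_dim = model_dim // num_heads
    mlp_dim = model_dim * mlp_scale
    if swiglu:
        # Assumed correction: shrink the hidden width by 2/3 so three SwiGLU matrices hold
        # about as many weights as the two matrices of a standard feed-forward block
        mlp_dim = (2 * mlp_dim) // 3
    return attn_head_dim, mlp_dim


def ff_weight_count(model_dim, mlp_dim, swiglu=False):
    """Number of weights in the feed-forward block, ignoring biases."""
    n_matrices = 3 if swiglu else 2
    return n_matrices * model_dim * mlp_dim


head_dim, plain_mlp = derived_dims(768, 12, 4)
_, swiglu_mlp = derived_dims(768, 12, 4, swiglu=True)
```

With these values the two feed-forward variants hold exactly the same number of weights (4,718,592), which matches the log's stated goal of keeping the parameter count consistent.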


In [13]:
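model_path = os.path.join(train_and_checkpoint_params["model_checkpoint_path"], f"VADEr_Epoch_{best_performing_model}.pt")
checkpoint = torch.load(model_path, map_location = device) #map_location lets GPU-saved checkpoints load on CPU-only machines
vader_model.load_state_dict(checkpoint["modelStateDict"], strict = False)

if torch.cuda.device_count() > 1:
    include_age = vader_model.includeAge
    #nn.DataParallel returns a new wrapper module, so it must be reassigned to take effect.
    #Note: DataParallel also splits tensor keyword arguments (e.g., a non-None mask) along dim 0.
    vader_model = torch.nn.DataParallel(vader_model, device_ids = list(range(torch.cuda.device_count())))
    vader_model.includeAge = include_age #expose the attribute Eval_Model reads on the wrapper
    
vader_model.to(device);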

**Define method for evaluation and prediction file generation:**

In [14]:
'''
Method to evaluate a trained VADEr model and write predictions. Returns a dictionary of metrics.
'''
@torch.no_grad()
def Eval_Model(model, loader, device, IDs, write_path, file_name, mask):
    model.eval()
    
    raw = torch.Tensor()
    labels = torch.Tensor()
    metrics = defaultdict(float)
    lossFx = nn.BCEWithLogitsLoss(reduction = "sum")
    
    for i, (patchBatch, diseaseStatusBatch, ancestryBatch, fHBatch, zAgeBatch) in enumerate(loader):
        gpuClumpBatch = {k:v.to(device) for k,v in patchBatch.items()} #features
        diseaseStatusBatch = diseaseStatusBatch.to(device) #labels

        if model.includeAge: #including age - need to pass in more than clump dictionary
            output = model(dictOfClumps = gpuClumpBatch, mask = mask, age_batch = zAgeBatch.to(device))
        else:
            output = model(dictOfClumps = gpuClumpBatch, mask = mask)
        
        metrics["Loss"] += lossFx(output["disease"], diseaseStatusBatch.float()) #<-- if model is sup_con pretrained, will need to change output["disease"] to output - can check if self.projectionnetwork in model? --> getattr(model, "projectionNetwork")
        
        raw = torch.cat([raw, output["disease"].to("cpu")], dim = 0)
        labels = torch.cat([labels, diseaseStatusBatch.to("cpu")], dim = 0)
        

    predictions = torch.sigmoid(raw)
    
    #Return metrics:
    metrics["Loss"] = metrics["Loss"].item()/len(loader.dataset)
    metrics["Disease_Accuracy"] = BinaryClassAccuracy(preds = raw, labels = labels)
    
    metrics["AUC"] = Calc_ROC_AUC(preds = predictions, labels = labels)
    
    logger.info(f"Loss {metrics['Loss']:.5f}")
    logger.info(f"ACCURACY {metrics['Disease_Accuracy']:.5f}")
    logger.info(f"AUC {metrics['AUC']:.5f}")
    
    #Write predictions:
    predictionFile = pd.DataFrame([[el[0].item() for el in predictions], [el[0].item() for el in labels]], index = ["Predictions", "Labels"]).T
    predictionFile.index = IDs 
    predictionFile.to_csv(os.path.join(write_path, file_name), sep = "\t")
    
    logger.info(f"VADEr predictions written to: {os.path.join(write_path, file_name)}")
        
    return metrics
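`Eval_Model` sums `BCEWithLogitsLoss` over the dataset and divides by the number of individuals, then delegates accuracy and AUC to `BinaryClassAccuracy` and `Calc_ROC_AUC` from `MetricUtils`, whose internals are not shown here. The pure-Python sketch below shows what these metrics compute under stated assumptions: the loss uses the standard numerically stable form of binary cross-entropy on logits, accuracy assumes a 0.5 probability threshold (equivalently, logit >= 0), and AUC uses the rank (Mann-Whitney) formulation, which is O(cases x controls) and meant for illustration only.

```python
import math


def bce_with_logits_mean(logits, labels):
    """Stable BCE on logits, summed then divided by N as in Eval_Model."""
    total = 0.0
    for x, y in zip(logits, labels):
        # max(x, 0) - x*y + log(1 + exp(-|x|)) equals -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def accuracy_from_logits(logits, labels):
    """Assumed 0.5 probability threshold, i.e. a logit >= 0 predicts the positive class."""
    correct = sum((x >= 0) == (y == 1) for x, y in zip(logits, labels))
    return correct / len(logits)


def roc_auc(scores, labels):
    """Probability that a random case outscores a random control (ties count 1/2)."""
    cases = [s for s, y in zip(scores, labels) if y == 1]
    controls = [s for s, y in zip(scores, labels) if y == 0]
    if not cases or not controls:
        raise ValueError("AUC needs at least one case and one control")
    wins = sum(1.0 if c > n else 0.5 if c == n else 0.0 for c in cases for n in controls)
    return wins / (len(cases) * len(controls))
```

Because the sigmoid is monotonic, AUC computed on raw logits equals AUC computed on the sigmoid-transformed predictions.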

**Generate predictions and obtain dataset level performance across metrics:**

In [15]:
dataset_metrics = Eval_Model(model = vader_model, 
                            loader = dataloader, 
                            device = device, 
                            IDs = dataloader.dataset.datasetIDs,
                            write_path = prediction_write_path,
                            file_name = prediction_file_name,
                            mask = mask)

Loss 0.65670
ACCURACY 0.66400
AUC 0.72665
VADEr predictions written to: ../Predictions/67_VADEr_bs4096_lr8e-05_sc-cosine_with_warmup_wu95_clip10_dim768_mlp4_bl12_he12_pd0.2_md0.2_ag-cls_cls-learnable_act-True_at-MHA-LT_nr-8_ffAct-SwiGLU_mask-None_spt-False_pln-False_wd0.1_b2-0.99_nm-z_clumped.tsv
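The written file has one row per individual, indexed by dataset ID, with the sigmoid-transformed VADEr score in `Predictions` and the phenotype in `Labels`; because the index is unnamed, pandas writes an empty first header field. Within pandas, `pd.read_csv(path, sep = "\t", index_col = 0)` reads it back directly. The hypothetical `read_vader_predictions` helper below is a stdlib-only sketch of the same, e.g., for downstream tools that do not depend on pandas.

```python
import csv


def read_vader_predictions(path):
    """Read the tab-separated prediction file into {ID: (PRS, label)}."""
    with open(path, newline="") as handle:
        reader = csv.reader(handle, delimiter="\t")
        header = next(reader)  # first field is the unnamed ID index (an empty string)
        pred_col, label_col = header.index("Predictions"), header.index("Labels")
        return {row[0]: (float(row[pred_col]), float(row[label_col])) for row in reader}
```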
