In [272]:
import os 
# Run from the repository root so the relative paths in the config
# (e.g. "demo/data/demo_metadata.csv") resolve. Set the LOMAR_ROOT environment
# variable to point elsewhere instead of editing this hardcoded default.
os.chdir(os.environ.get("LOMAR_ROOT", "/home/bkk4001/repos/LoMaR"))

In [273]:
def get_example_config():
    """Build the default configuration dict for the LoMaR demo.

    Returns a fresh dict on every call, with keys in this order:
      - exp_id: experiment identifier (used when reporting the log file name).
      - n_past_visits: number of visit slots per sample (present + 4 prior years).
      - n_future_dx: number of yearly follow-up horizons to predict.
      - input_embedding_dim, model_embedding_dim, n_heads, global_do_rate:
        model hyperparameters (presumably consumed by LoMaR; see lomar/model.py).
      - model_weight_dir: checkpoint path; an empty string means no checkpoint
        is loaded by demo().
      - path_to_csv: metadata CSV read by BreastDataset.
      - n_pseudo: number of pseudo test sets used during evaluation.
      - results_dir: directory where inference logs are written.
    """
    return {
        'exp_id': 0,
        # --- model ---
        'n_past_visits': 5,          # present visit + 4 years of history
        'n_future_dx': 5,            # 5 follow-up years
        'input_embedding_dim': 512,
        'model_embedding_dim': 128,
        'n_heads': 4,
        'global_do_rate': 0.1,
        'model_weight_dir': "",      # set to a .pth path to load trained weights
        # --- data ---
        'path_to_csv': "demo/data/demo_metadata.csv",
        'n_pseudo': 1,               # number of pseudo test sets for evaluation
        # --- results ---
        'results_dir': "demo/results/",
    }

In [274]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import copy

class BreastDataset(Dataset):
    """Per-patient dataset of precomputed per-visit embeddings plus future labels.

    Each row of the CSV at ``config['path_to_csv']`` is one sample. Columns read:
      - ``npy_<year>`` for year in -(n_past_visits-1)..0: path to a ``.npy``
        embedding file for that visit; NaN means the visit is missing. Each file
        is expected to hold a 1-D vector of length ``input_embedding_dim``
        (required for ``np.stack`` with the zero placeholders).
      - ``dx_1`` .. ``dx_K`` (K = ``config['n_future_dx']``): labels among
        "Not Malignant", "Malignant", "Unknown"; NaN is treated as "Unknown".

    Args:
        config: dict with ``path_to_csv``, ``n_past_visits`` and ``n_future_dx``.
            ``input_embedding_dim`` (default 512) sizes the zero vector used
            for missing/masked visits.
        history_masking_id: optional key into ``MASKING_PATTERNS``. When given,
            visit slots whose pattern entry is 1 are hidden even if present.

    Raises:
        KeyError: if ``history_masking_id`` is not a known pattern id.
        ValueError: if the selected pattern length differs from ``n_past_visits``.
    """

    # Synthetic history-masking patterns, ordered oldest -> current visit.
    # 1 => hide this visit slot, 0 => keep it if it exists in the CSV.
    MASKING_PATTERNS = {
        0: [1, 1, 1, 1, 0],
        1: [1, 1, 1, 0, 0],
        2: [1, 1, 0, 0, 0],
        3: [1, 0, 0, 0, 0],
        4: [0, 0, 0, 0, 0],
        5: [0, 1, 0, 1, 0],
        6: [1, 1, 0, 1, 0],
    }

    def __init__(self, config, history_masking_id=None):
        self.config = config

        past_visits = pd.read_csv(config['path_to_csv'], low_memory=False)

        # Relative year codes, e.g. n_past_visits=5 -> [-4, -3, -2, -1, 0].
        n_past_visits = int(config['n_past_visits'])
        self.history_years = list(range(-(n_past_visits - 1), 1))

        # Placeholder size for missing visits; must match the saved embeddings.
        self.embedding_dim = int(config.get('input_embedding_dim', 512))

        history_masking_id_dict = None
        if history_masking_id is not None:
            pattern = self.MASKING_PATTERNS[history_masking_id]
            # zip() would silently truncate and misalign slots on a length mismatch.
            if len(pattern) != len(self.history_years):
                raise ValueError(
                    f"history masking pattern {history_masking_id} has {len(pattern)} "
                    f"entries but n_past_visits is {n_past_visits}"
                )
            # Map each year code (e.g. -4..0) to its {0, 1} masking decision.
            history_masking_id_dict = dict(zip(self.history_years, pattern))

        self.history_masking_id = history_masking_id
        self.history_masking_id_dict = history_masking_id_dict

        self.past_visits = past_visits.copy()
        self.len_data = len(self.past_visits)

        # String labels -> numeric. Unknown maps to -1, which downstream metric
        # code (e.g. compute_metrics) filters out.
        self.label_dict = {
            "Not Malignant": 0,
            "Malignant": 1,
            "Unknown": -1,
        }

    def __len__(self):
        """Number of CSV rows (samples) in the dataset."""
        return int(self.len_data)

    def __getitem__(self, idx):
        """Return one sample as a dict of NumPy arrays.

        Keys:
          - visit_embeddings: [T, D] stacked embeddings (zeros for missing/masked).
          - visit_mask: [T], 1 if the visit is missing OR synthetically masked.
          - original_visit_mask: [T], 1 only if the visit is missing in the CSV.
          - viscodes: [T] relative year codes.
          - label: [K] float labels (0, 1, or -1 for unknown).
        """
        patient_data = self.past_visits.iloc[idx]
        masking = self.history_masking_id_dict

        embeddings, visit_mask, original_visit_mask = [], [], []
        for history_year in self.history_years:
            npy_path = patient_data['npy_' + str(history_year)]
            has_file = not pd.isna(npy_path)
            is_masked = masking is not None and masking[history_year] == 1

            if has_file and not is_masked:
                embeddings.append(np.load(npy_path))
                visit_mask.append(0)
            else:
                embeddings.append(np.zeros(self.embedding_dim))
                visit_mask.append(1)

            # True missingness, independent of synthetic masking.
            original_visit_mask.append(0 if has_file else 1)

        n_future_dx = int(self.config['n_future_dx'])
        surv_cols = ['dx_' + str(k) for k in range(1, n_future_dx + 1)]
        label = np.array(
            [float(self.label_dict[z]) for z in patient_data[surv_cols].fillna('Unknown')]
        )

        return {
            'visit_embeddings': np.stack(embeddings),
            'visit_mask': np.array(visit_mask),
            'original_visit_mask': np.array(original_visit_mask),
            'viscodes': np.array(self.history_years),
            'label': label,
        }


In [275]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


import time

import torch
import copy 
import numpy as np 
import pandas as pd 
from sklearn.metrics import roc_auc_score

# from lomar.model import *
# from demo.dataset import *
# Load the LoMaR model definitions into this namespace (stand-in for the
# commented-out `from lomar.model import *`). compile() with the real path keeps
# tracebacks pointing at model.py, and `with` closes the file handle.
_LOMAR_MODEL_PATH = "/home/bkk4001/repos/LoMaR/lomar/model.py"
with open(_LOMAR_MODEL_PATH) as _model_file:
    exec(compile(_model_file.read(), _LOMAR_MODEL_PATH, "exec"))


def compute_metrics(outputs, labels):
    """Per-horizon ROC-AUC for multi-horizon predictions.

    Parameters
    ----------
    outputs : np.ndarray, shape [B, K]
        Scores (probabilities or logits) for each of the K horizons.
    labels : np.ndarray, shape [B, K]
        Binary targets in {0, 1}; entries equal to -1 are unknown and skipped.

    Returns
    -------
    list
        ``[mean_auc, auc_horizon_1, ..., auc_horizon_K]``. A horizon whose known
        labels do not contain exactly two distinct values gets NaN, and
        ``mean_auc`` is the NaN-ignoring mean of the per-horizon values.
    """
    def _auc_for(column):
        known = labels[:, column] != -1
        y_true = labels[known, column]
        y_score = outputs[known, column]
        # ROC-AUC is undefined unless both classes are present.
        if np.unique(y_true).size != 2:
            return float('nan')
        return roc_auc_score(y_true, y_score)

    per_horizon = [_auc_for(column) for column in range(labels.shape[1])]
    return [np.nanmean(per_horizon)] + per_horizon

def evalute(model, test_dataset, test_loader, config):
    """Run inference over ``test_loader`` and compute pseudo-group ROC-AUC metrics.

    The (misspelled) name is kept for backward compatibility with callers.

    Args:
        model: callable taking a collated batch dict and returning a [B, K]
            tensor of per-horizon scores (K = number of future horizons).
        test_dataset: dataset whose ``past_visits`` DataFrame rows line up, in
            order, with the samples yielded by ``test_loader``. Optional columns
            ``pseudo_<i>`` (truthy = include) select the rows of pseudo test
            set i; if a column is absent, all rows are used. The dataset is not
            modified.
        test_loader: DataLoader over ``test_dataset``. It must NOT shuffle,
            because predictions are matched to ``past_visits`` rows by position.
        config: dict providing ``torch_device`` and ``n_pseudo``.

    Returns:
        dict with:
          - 'labels': NumPy array [N, K] of labels for all samples.
          - 'outputs': NumPy array [N, K] of model outputs for all samples.
          - 'average_pseudo_results': one-row DataFrame indexed by
            ``test_dataset.history_masking_id`` with the mean per-horizon
            ROC-AUC (columns '1_year', '2_year', ...) across pseudo groups.

    Raises:
        ValueError: if the loader yields no batches, or the number of
            predictions does not match the number of ``past_visits`` rows.
    """
    model.eval()
    batch_outputs, batch_labels = [], []
    with torch.no_grad():
        # Evaluate every batch, not just the first, so results stay aligned with
        # the full-length pseudo masks regardless of batch size.
        for batch in test_loader:
            batch_labels.append(batch['label'].to(config['torch_device']))
            batch_outputs.append(model(batch))
    if not batch_outputs:
        raise ValueError("test_loader yielded no batches; nothing to evaluate.")

    labels_np = torch.cat(batch_labels).detach().cpu().numpy()
    outputs_np = torch.cat(batch_outputs).detach().cpu().numpy()

    pseudo_table = test_dataset.past_visits
    n_rows = len(pseudo_table)
    if len(outputs_np) != n_rows:
        raise ValueError(
            f"got {len(outputs_np)} predictions but past_visits has {n_rows} rows"
        )

    hm_id = test_dataset.history_masking_id
    horizon_cols = [f'{k}_year' for k in range(1, labels_np.shape[1] + 1)]

    records = []
    for i_pseudo in range(config['n_pseudo']):
        col = 'pseudo_' + str(i_pseudo)
        # If a subject appears multiple times in the metadata CSV, the pseudo
        # columns must select each subject at most once per pseudo test set;
        # otherwise the evaluation is biased (see the paper).
        if col in pseudo_table.columns:
            # Explicit bool mask: an integer 0/1 column would otherwise be
            # interpreted as positional indices instead of a row selection.
            mask = pseudo_table[col].to_numpy(dtype=bool)
        else:
            mask = np.ones(n_rows, dtype=bool)

        rocauc = compute_metrics(outputs_np[mask], labels_np[mask])
        record = {'history_masking_id': float('nan') if hm_id is None else hm_id}
        record.update(zip(horizon_cols, rocauc[1:]))
        records.append(record)

    pseudo_results = pd.DataFrame(records, columns=['history_masking_id'] + horizon_cols)
    # Average across pseudo groups; index by masking id for readable printing.
    average_pseudo_results = pd.DataFrame([pseudo_results.mean()], index=[hm_id])

    return {
        'labels': labels_np,
        'outputs': outputs_np,
        'average_pseudo_results': average_pseudo_results,
    }

def demo(config):
    """Run the LoMaR demo inference end to end.

    Steps:
      - Picks the torch device and stores it in ``config['torch_device']``
        (mutates the caller's dict).
      - Builds LoMaR and, if ``config['model_weight_dir']`` is non-empty, loads
        that checkpoint strictly.
      - For each history masking id 0..6, builds a BreastDataset, evaluates it
        in a single full-size batch and prints the averaged pseudo ROC-AUCs.
      - Saves config and all per-id results to
        ``<results_dir>/log_inference.npz``. The values are dicts, so NumPy
        stores them as pickled object arrays; load with ``allow_pickle=True``.
    """
    # Same dict object as `config`, so the torch_device added below is logged too.
    the_log = {'config': config}

    config['torch_device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("config:")
    print(config)

    model = LoMaR(config)
    if config['model_weight_dir']:
        # map_location lets checkpoints saved on GPU load on CPU-only machines.
        state_dict = torch.load(config['model_weight_dir'], map_location=config['torch_device'])
        model.load_state_dict(state_dict, strict=True)
    model.eval()
    print("Loaded the model, moving to inference.")

    inference_log = {}
    # Ids 0..6 correspond to the masking patterns defined in BreastDataset.
    for history_masking_id in range(7):
        test_dataset = BreastDataset(config, history_masking_id)
        # Full-batch, unshuffled loader: evaluation aligns rows by position.
        test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

        print("Starting evaluation for history_masking_id:", history_masking_id)
        start_time = time.time()
        test_evaluation_results = evalute(model, test_dataset, test_loader, config)
        end_time = time.time()
        print(f"Execution time: {end_time - start_time:.2f} seconds")

        print(test_evaluation_results['average_pseudo_results'])

        inference_log[test_dataset.history_masking_id] = test_evaluation_results

    the_log['inference'] = inference_log
    log_dir = config['results_dir']
    os.makedirs(log_dir, exist_ok=True)

    # Build the path once so the printed location matches the file written.
    log_path = os.path.join(log_dir, 'log_inference.npz')
    np.savez(log_path, **the_log)
    print("Saved results to:")
    print(log_path)
    print('Done!')


In [276]:
# past_visits = pd.read_csv('/midtier/sablab/scratch/bkk4001/karolinska/splits_0305_aicluster/rt_'+str(1)+'_rv_'+str(0)+'/'+'past_visits_0305.csv', low_memory=False)
# past_visits = past_visits.loc[past_visits['trainvaltest'] == 'test'].reset_index(drop=True)
# past_visits = past_visits.drop_duplicates(subset=['subject'], keep='last').reset_index(drop=True)
# past_visits = past_visits.iloc[:100]

# meta_cols = ['subject']
# dx_cols = [z for z in past_visits.columns if "dx_" in z]
# imageuid_cols = ["imageuid_"+str(z) for z in [-4, -3, -2, -1, 0]]
# dx_cols.sort()
# # rename: imageuid_*  -> npy_*
# rename_map = {c: c.replace("imageuid_", "npy_") for c in imageuid_cols}
# past_visits = past_visits.rename(columns=rename_map)

# # after renaming, update the list of cols to keep
# npy_cols = [rename_map[c] for c in imageuid_cols]


# for i in range(len(past_visits)):
#     npy_cols = ["npy_"+str(z) for z in [-4, -3, -2, -1, 0]]
#     for npy_col in npy_cols:
#         row = past_visits.loc[i, npy_col]
#         if pd.isna(row) == False:
#             npy = np.load(row)
#             # print(row)
#             row = 'demo/data/npy_files/' + row.split('/')[-1]
#             # print(row)
#             np.save(row, npy[:512])
#             past_visits.loc[i, npy_col] = row    
            
# past_visits = past_visits[meta_cols + dx_cols + npy_cols]
# past_visits.to_csv("/home/bkk4001/repos/LoMaR/demo/data/demo_metadata.csv", index=False)


# past_visits


In [277]:
# Build the default demo configuration, then run inference for every history
# masking setting; demo() prints per-setting ROC-AUCs and saves an .npz log.
config = get_example_config()
demo(config)

config:
{'exp_id': 0, 'n_past_visits': 5, 'n_future_dx': 5, 'input_embedding_dim': 512, 'model_embedding_dim': 128, 'n_heads': 4, 'global_do_rate': 0.1, 'model_weight_dir': '', 'path_to_csv': 'demo/data/demo_metadata.csv', 'n_pseudo': 1, 'results_dir': 'demo/results/', 'torch_device': device(type='cuda')}
Loaded the model, moving to inference.
Starting evaluation for history_masking_id: 0
Execution time: 0.08 seconds
   history_masking_id   1_year    2_year    3_year    4_year    5_year
0                 0.0  0.23435  0.297203  0.356061  0.541667  0.583333
Starting evaluation for history_masking_id: 1
Execution time: 0.09 seconds
   history_masking_id    1_year    2_year    3_year    4_year    5_year
1                 1.0  0.210273  0.291958  0.348485  0.520833  0.583333
Starting evaluation for history_masking_id: 2
Execution time: 0.11 seconds
   history_masking_id    1_year    2_year    3_year    4_year    5_year
2                 2.0  0.396469  0.382867  0.409091  0.583333  0.666667