# Demo Notebook:
## DeSurv

In [10]:
import os
from pathlib import Path
import sys
# Put the node-specific virtual environment first on the import path.
# The environment directory is suffixed with the CPU type read from the
# BB_CPU environment variable; if the variable is unset the suffix is "None"
# and the lookup below reports the path as missing.
node_type = os.environ.get('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir).joinpath(
    'lib',
    'python{}.{}'.format(*sys.version_info[:2]),
    'site-packages',
)
if not venv_site_pkgs.exists():
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    # Index 0 so packages from this environment shadow any others on the path.
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Reload imported modules automatically before each cell executes (mode 2 =
# all modules), so edits to project source files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import pytorch_lightning
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
import pickle
from tqdm import tqdm

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from torch.utils.data import TensorDataset, DataLoader

from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from CPRD.src.modules.head_layers.survival.desurv import ODESurvMultiple

# Seed every random number generator the notebook can draw from.
# torch drives DataLoader shuffling and weight initialisation; NumPy's global
# RNG drives pandas `DataFrame.sample`, which is used for the pycox
# train/val/test split; `random` is seeded for completeness since it is imported.
torch.manual_seed(1337)
np.random.seed(1337)
random.seed(1337)
logging.basicConfig(level=logging.INFO)
# Prefer the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Load data

In [12]:
def get_dataloaders(dataset, competing_risk, sample_size=None):

    match dataset.lower():
        case "pycox":
            df_train = support.read_df()
            df_test = df_train.sample(frac=0.2)
            df_train = df_train.drop(df_test.index)
            df_val = df_train.sample(frac=0.2)
            df_train = df_train.drop(df_val.index)
            
            cols_standardize = ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']
            cols_leave = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
            
            standardize = [([col], StandardScaler()) for col in cols_standardize]
            leave = [(col, None) for col in cols_leave]
            
            x_mapper = DataFrameMapper(standardize + leave)
            
            x_train = x_mapper.fit_transform(df_train).astype('float32')
            x_val = x_mapper.transform(df_val).astype('float32')
            x_test = x_mapper.transform(df_test).astype('float32')
            
            get_target = lambda df: (df['duration'].values, df['event'].values)
            y_train = get_target(df_train)
            y_val = get_target(df_val)
            y_test = get_target(df_test)
            
            t_train, e_train = y_train
            t_val, e_val = y_val
            t_test, e_test = y_test
            
            t_train_max = np.amax(t_train)
            t_train = t_train / t_train_max
            t_val = t_val / t_train_max
            t_test = t_test / t_train_max
            
    
        case "hypertension" | "cvd":
    
            # Training samples
            if sample_size is not None:
                save_path =  f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + f"benchmark_data/N={sample_size}.pickle" 
            else:
                save_path = f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + "benchmark_data/all.pickle"
                
            with open(save_path, "rb") as handle:
                print(f"Loading training dataset from {save_path}")
                data_train = pickle.load(handle)
            
            # display(data["X_train"].head())
            # display(data["y_train"])
            # print(data.keys())
            
            data = {}
            data["X_train"] = data_train["X_train"]
            data["y_train"] = data_train["y_train"]
    
            # Test and validation samples
    
            save_path = f"/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_{dataset}/" + "benchmark_data/all.pickle"
            with open(save_path, "rb") as handle:
                print(f"Loading validation/test datasets from {save_path}")
                data_val_test = pickle.load(handle)
                
            data["X_val"] = data_val_test["X_val"]
            data["y_val"] = data_val_test["y_val"]
            data["X_test"] = data_val_test["X_test"]
            data["y_test"] = data_val_test["y_test"]
    
            # Convert to correct formats
            x_train = data["X_train"].to_numpy(dtype=np.float32)
            x_val = data["X_val"].to_numpy(dtype=np.float32)
            x_test = data["X_test"].to_numpy(dtype=np.float32)
            
            t_train = np.asarray([i[1] for i in data["y_train"]])
            t_val = np.asarray([i[1] for i in data["y_val"]])        
            t_test = np.asarray([i[1] for i in data["y_test"]])
    
            if competing_risk is False:
                e_train = np.asarray([0 if i[0] == 0 else 1 for i in data["y_train"]])
                e_val = np.asarray([0 if i[0] == 0 else 1 for i in data["y_val"]])
                e_test = np.asarray([0 if i[0] == 0 else 1 for i in data["y_test"]])
            else:
                e_train = np.asarray([i[0] for i in data["y_train"]])
                e_val = np.asarray([i[0] for i in data["y_val"]])
                e_test = np.asarray([i[0] for i in data["y_test"]])

    # display(x_train.shape)
    # display(type(x_train))
    # display(type(x_train[0,0]))
    # display(e_train.shape)
    # display(type(e_train))
    # display(type(e_train[0]))
    # display(t_train.shape)
    # display(type(t_train))
    # display(type(t_train[0]))
    # print(np.mean(e_train))
    # print(np.mean(t_train))
    # print(np.std(t_train))
    # print(np.mean(x_train))
    # print(t_train.min())
    # print(t_train.max())
    # print(np.unique(e_test, return_counts=True))
    
    
    batch_size = 256
    dataset_train = TensorDataset(*[torch.tensor(u,dtype=dtype_) for u, dtype_ in [(x_train,torch.float32),
                                                                                   (t_train,torch.float32),
                                                                                   (e_train,torch.long)]])
    data_loader_train = DataLoader(dataset_train, batch_size=batch_size, pin_memory=True, shuffle=True, drop_last=True)
    
    dataset_val = TensorDataset(*[torch.tensor(u,dtype=dtype_) for u, dtype_ in [(x_val,torch.float32),
                                                                                   (t_val,torch.float32),
                                                                                   (e_val,torch.long)]])
    data_loader_val = DataLoader(dataset_val, batch_size=batch_size, pin_memory=True, shuffle=True)
    
    dataset_test = TensorDataset(*[torch.tensor(u,dtype=dtype_) for u, dtype_ in [(x_test,torch.float32),
                                                                                   (t_test,torch.float32),
                                                                                   (e_test,torch.long)]])
    data_loader_test = DataLoader(dataset_test, batch_size=batch_size, pin_memory=True, shuffle=True)

    return data_loader_train, data_loader_val, data_loader_test




# Example dataloader function usage

In [13]:
# Single-risk CVD loaders; sample_size=None reads the training split from all.pickle.
data_loader_train, data_loader_val, data_loader_test = get_dataloaders("CVD", False, sample_size=None)

Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle


In [14]:
# Peek at the first batch of the test loader and stop.
# `x_train` is kept as a global (despite coming from the test loader) because
# the configuration cell reads the covariate dimension from it.
for batch in data_loader_test:
    x_train = batch[0]
    print(x_train.shape)   # (batch size, number of covariates)
    print(batch[1][0])     # time value of the first sample
    print(batch[2][1])     # event label of the second sample
    break

torch.Size([256, 279])
tensor(0.9167)
tensor(1)


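# Train model

Reference scores from earlier runs (SR = single risk, CR = competing risk). Each group lists, in order, the time-dependent concordance (Ctd), the integrated Brier score (IBS) and the integrated negative binomial log-likelihood (INBLL).

```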
Hypertension

SR
[0.768850848388541]
[0.08287342061953376]
[0.2629869386534056]

CVD

SR
[0.6660374450461991]
[0.03355264167471281]
[0.14207857084831382]

CR
[0.6622136902603167]
[0.033540062894645686]
[0.142287892974895]
```

In [None]:
# --- Experiment configuration -------------------------------------------------
dataset = "CVD" # "Hypertension"
competing_risk = True

# Ten training-set sizes evenly spaced on a log scale between 3000 and 500000.
# int() truncates, so a size can land one below the nominal value.
# Earlier hand-picked alternatives: [3000, 12500, 30000, 60000, 100000] (and 600, 1200).
sample_sizes = [
    int(np.exp(_log_n))
    for _log_n in np.linspace(np.log(3000), np.log(500000), 10)
]
# sample_sizes = [None]
# Only the largest sample size is run.
sample_sizes = sample_sizes[-1:]

training = True
lr = 1e-3
# Covariate dimension, taken from the batch inspected in an earlier cell.
xdim = x_train.shape[1]


# the time grid which we generate over
t_eval = np.linspace(0, 1, 1000)
# the time grid which we calculate scores over
time_grid = np.linspace(0, 1, 300)

In [15]:

# Build the model to benchmark: a single-risk or a competing-risk DeSurv head,
# chosen by `competing_risk`.
models, model_names = [], []
if competing_risk is False:
    models.append(ODESurvSingle(xdim, [32, 32], device=device))
    model_names.append(f"{dataset}_sr")
else:
    # NOTE(review): unlike the single-risk model, no `device` is passed here;
    # confirm ODESurvMultiple selects the intended device by default.
    models.append(ODESurvMultiple(xdim, [32, 32], num_risks=5))
    model_names.append(f"{dataset}_cr")

# One entry per (model, sample size) pair, in loop order.
all_ctd, all_ibs, all_inbll = [], [], []
for model_name, model in zip(model_names, models):

    print(f"\n\n{model_name} with {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

    # Checkpoint written after training and read back before testing.
    checkpoint_path = model_name + "tst_model"

    for sample_size in sample_sizes:
        # NOTE(review): the same model instance is reused for every sample size,
        # so unless `optimize` re-initialises the weights, each run starts from
        # the weights fitted on the previous sample size.

        data_loader_train, data_loader_val, data_loader_test = get_dataloaders(dataset, competing_risk, sample_size=sample_size)

        if training:
            print("Training")
            model.optimize(data_loader_train, n_epochs=20, logging_freq=1, data_loader_val=data_loader_val, max_wait=2)
            print("finished training")
            torch.save(model.state_dict(), checkpoint_path)
            model.eval()

        print("Testing")
        state_dict = torch.load(checkpoint_path)
        model.load_state_dict(state_dict)
        model.eval()

        with torch.no_grad():

            # Per-batch scores; averaged over batches below.
            ctd = []
            ibs = []
            inbll = []
            n_failed_batches = 0
            for batch in tqdm(data_loader_test, total=(len(data_loader_test)), desc="Testing"):

                x_test = batch[0].numpy()
                t_test = batch[1].numpy()
                e_test = batch[2].numpy()

                # The normalised grid over which to predict: every sample is
                # paired with every time in `t_eval` (sample-major ordering).
                t_test_grid = torch.tensor(np.tile(t_eval, x_test.shape[0]), dtype=torch.float32)
                x_test_grid = torch.tensor(x_test, dtype=torch.float32).repeat_interleave(t_eval.size, 0)

                # Predict in chunks to bound memory use.
                pred_bsz = 51200
                pred = []
                for x_test_batched, t_test_batched in zip(torch.split(x_test_grid, pred_bsz), torch.split(t_test_grid, pred_bsz)):

                    if competing_risk is False:
                        pred_ = model.predict(x_test_batched, t_test_batched)          # shape: (x_test.batched.shape[0],)
                    else:
                        pred_, pi_  = model.predict(x_test_batched, t_test_batched)    # shape: (x_test.batched.shape[0], num_outcomes)
                    pred.append(pred_)

                pred = torch.concat(pred)

                # (samples, grid times, outcomes), then one (samples, grid times) array per outcome.
                pred = pred.reshape((x_test.shape[0], t_eval.size, -1)).cpu().detach().numpy()
                preds = [pred[:, :, _i] for _i in range(pred.shape[-1])]

                # Merge (additively) each outcome risk curve into a single CDF, and update label for if outcome occurred or not.
                # Label 0 marks "no event" and is excluded explicitly, so a batch
                # without any 0-labelled sample does not lose its smallest event label.
                cdf = np.zeros_like(preds[0])
                lbls = np.zeros_like(e_test)
                event_tokens = np.unique(e_test)
                event_tokens = event_tokens[event_tokens != 0]
                for _outcome_token in event_tokens:
                    cdf += preds[_outcome_token - 1]
                    lbls += (e_test == _outcome_token)

                # Survival curves: rows indexed by grid time, one column per sample.
                surv = pd.DataFrame(np.transpose((1 - cdf.reshape((x_test.shape[0],t_eval.size)))), index=t_eval)

                # Evaluate surv curve with unscaled index with unscaled test times to event
                ev = EvalSurv(surv, t_test, lbls, censor_surv='km')
                try:
                    # Same treatment as in SurvivEHR
                    batch_ctd = ev.concordance_td()
                    batch_ibs = ev.integrated_brier_score(time_grid)
                    batch_inbll = ev.integrated_nbll(time_grid)
                except Exception as exc:
                    # Best effort: skip batches whose metrics cannot be computed,
                    # but report them and let KeyboardInterrupt propagate.
                    n_failed_batches += 1
                    logging.warning("Skipping test batch, metric computation failed (%s: %s)", type(exc).__name__, exc)
                else:
                    # Append all three together so every mean covers the same batches.
                    ctd.append(batch_ctd)
                    ibs.append(batch_ibs)
                    inbll.append(batch_inbll)

            if n_failed_batches:
                logging.warning("%d of %d test batches were excluded from the scores", n_failed_batches, len(data_loader_test))

            # NaN when no batch could be scored.
            ctd = np.mean(ctd) if ctd else float("nan")
            ibs = np.mean(ibs) if ibs else float("nan")
            inbll = np.mean(inbll) if inbll else float("nan")

            print(f"{model_name}:".ljust(20) + f"N={sample_size}.".ljust(15) + f"Ctd: {ctd}. IBS: {ibs}. INBLL: {inbll}")

        all_ctd.append(ctd)
        all_ibs.append(ibs)
        all_inbll.append(inbll)




CVD_cr with 20394 parameters
Loading training dataset from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/N=500000.pickle
Loading validation/test datasets from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/FineTune_CVD/benchmark_data/all.pickle
Training
	Epoch:  0. Total loss:   296443.99
best_epoch: 0
	Epoch:  0. Total val loss:    18485.24
	Epoch:  1. Total loss:   286393.12
best_epoch: 1
	Epoch:  1. Total val loss:    18413.58
	Epoch:  2. Total loss:   284411.50
	Epoch:  2. Total val loss:    18414.11
	Epoch:  3. Total loss:   282880.24
	Epoch:  3. Total val loss:    18434.12
	Epoch:  4. Total loss:   281413.23
finished training
Testing


Testing: 100%|██████████| 140/140 [07:51<00:00,  3.36s/it]

CVD_cr:             N=500000.      Ctd: 0.6582288879213832. IBS: 0.033605688713984706. INBLL: 0.14273257727425628





In [None]:
# Inspect the survival table and first per-outcome prediction array left over
# from the last evaluated test batch.
print(surv.head(), surv.shape, preds[0].shape, sep="\n")

In [None]:
# sample_sizes[-1] = 572096
print(sample_sizes)
print(all_ctd)
print(all_ibs)
print(all_inbll)

# Output across different setups

Hypertension Single Risk

Hypertension Competing Risk

In [2]:
# Hypertension_cr:N=2999.        Ctd: 0.6815094745489083. IBS: 0.09055914251601428. INBLL: 0.31029164842950396
# Hypertension_cr:N=5296.        Ctd: 0.7124953181062986. IBS: 0.08950243267368668. INBLL: 0.2998138831024105
# Hypertension_cr:N=9351.        Ctd: 0.7274689831170883. IBS: 0.08788381776138605. INBLL: 0.29219335002924224
# Hypertension_cr:N=16509.       Ctd: 0.730682386913791. IBS: 0.08713170539614512. INBLL: 0.2915679214433027
# Hypertension_cr:N=29148.       Ctd: 0.7389774246331557. IBS: 0.08589961516053798. INBLL: 0.2797340688171351
# Hypertension_cr:N=51461.       Ctd: 0.7411341306621856. IBS: 0.08509218669879787. INBLL: 0.27232015638634294
# Hypertension_cr:N=90856.       Ctd: 0.7515646895266107. IBS: 0.08443103008855017. INBLL: 0.27128357563201616
# Hypertension_cr:N=160407.      Ctd: 0.7600540925292503. IBS: 0.08451878180772691. INBLL: 0.2692630746813842
# Hypertension_cr:N=283203.      Ctd: 0.7620305063798369. IBS: 0.08404512042043363. INBLL: 0.2666228576617026
# Hypertension_cr:N=500000.      Ctd: 0.761360566628242. IBS: 0.08371696663855675. INBLL: 0.26584980102348854

Cardiovascular disease Single Risk

Cardiovascular disease Competing Risk

In [None]:
# CVD_cr:          N=2999.        Ctd: 0.5625189849883911. IBS: 0.03435665860871859. INBLL: 0.14990771301641295
# CVD_cr:          N=5296.        Ctd: 0.5869489534497859. IBS: 0.03411933684032459. INBLL: 0.14823684700350284
# CVD_cr:          N=9351.        Ctd: 0.5993921885610577. IBS: 0.03442101009915632. INBLL: 0.14955773914208445
# CVD_cr:          N=16509.       Ctd: 0.6101382510977712. IBS: 0.03387871579676817. INBLL: 0.14660058710195034
# CVD_cr:          N=29148.       Ctd: 0.6200633894187596. IBS: 0.033806085396407184. INBLL: 0.1458525721893707
# CVD_cr:          N=51461.       Ctd: 0.624818164536866. IBS: 0.033786555455045275. INBLL: 0.14535719048483275
# CVD_cr:          N=90856.       Ctd: 0.6355840601226502. IBS: 0.03364728847608413. INBLL: 0.14414393636167722
# CVD_cr:          N=160407.      Ctd: 0.6469820170508909. IBS: 0.033556625283347644. INBLL: 0.14308219809140393
# CVD_cr:          N=283203.      Ctd: 0.6557505764681258. IBS: 0.033524955828030355. INBLL: 0.1423563785027365
# CVD_cr:          N=500000.      Ctd: 0.6582288879213832. IBS: 0.033605688713984706. INBLL: 0.14273257727425628