In [1]:
import os
from pathlib import Path
import sys
# Put the node-specific Torch 2.0 virtualenv's site-packages at the front of the
# import search path so its packages take precedence over other installs.
# BB_CPU is read as the node type (e.g. 'icelake' in the recorded output); one
# virtualenv is expected per node type.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
python_tag = f'python{sys.version_info.major}.{sys.version_info.minor}'
venv_site_pkgs = Path(venv_dir).joinpath('lib', python_tag, 'site-packages')

if not venv_site_pkgs.exists():
    # NOTE(review): when BB_CPU is unset the path ends in "-None" and this branch runs.
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")
else:
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")

# Reload edited modules (such as the project's CPRD package) automatically
# before each cell runs, so changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import pytorch_lightning
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from CPRD.data.foundational_loader import FoundationalDataModule
import pickle
from tqdm import tqdm

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from torch.utils.data import TensorDataset, DataLoader

from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from CPRD.src.modules.head_layers.survival.desurv import ODESurvMultiple

# Seed every RNG imported in this cell (stdlib random, numpy, torch) so any
# stochastic step is reproducible under Restart & Run All.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # uncomment to force CPU if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


# Results collated from other notebooks

In [3]:
dataset = "CVD"           # which study's results to plot; must be a key of RESULTS_BY_DATASET
risks = "both SR and CR"  # "cr", "sr", or any other value to plot both result sets together

# All results, collated from other notebooks.
#
# RESULTS_BY_DATASET maps a dataset name to a pair (CR results, SR results).
# Each of those maps a run id to a 5-tuple:
#     (legend label, cohort sample sizes, concordance (time-dependent),
#      integrated Brier score, integrated negative Bernoulli log-likelihood)
# A metric list may be shorter than the sample-size list (no values yet for the
# larger cohorts); the plotting cell pairs each value with the leading sizes.
# Run ids containing "Benchmark" are drawn as dashed reference lines.
SAMPLE_SIZES = [2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000]  # shared by all entries; do not mutate

RESULTS_BY_DATASET = {
    "CVD": (
        # CR results
        {
            "Benchmark:DeSurv": (
                "DeSurv",
                SAMPLE_SIZES,
                [0.5625189849883911,  0.5869489534497859,  0.5993921885610577,  0.6101382510977712,  0.6200633894187596,   0.624818164536866,      0.6355840601226502,    0.6469820170508909,   0.6557505764681258,   0.6582288879213832],
                [0.03435665860871859, 0.03411933684032459, 0.03442101009915632, 0.03387871579676817, 0.033806085396407184, 0.033786555455045275,   0.03364728847608413,   0.033556625283347644, 0.033524955828030355, 0.033605688713984706],
                [0.14990771301641295, 0.14823684700350284, 0.14955773914208445, 0.14660058710195034, 0.1458525721893707,   0.14535719048483275,    0.14414393636167722,   0.14308219809140393,  0.1423563785027365,   0.14273257727425628],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz128": (
                "SurvivEHR-CR: CR: Full training, 128 effective batch size",
                SAMPLE_SIZES,
                [0.6203,   0.62171,  0.62933,  0.62968,  0.64305, 0.64297],
                [0.034028, 0.033998, 0.033902, 0.033906, 0.03379, 0.033784],
                [0.14803,  0.14687,  0.14598,  0.1455,   0.14443, 0.14485],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz640": (
                "SurvivEHR-CR: CR: Full training, 640 effective batch size",
                SAMPLE_SIZES,
                [0.63283,  0.63518,  0.63523,  0.63006,  0.64464,  0.64761],
                [0.033816, 0.033746, 0.033927, 0.033844, 0.033705, 0.033781],
                [0.14525,  0.14487,  0.14644,  0.14527,  0.14395,  0.14466],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz1280": (
                "SurvivEHR-CR: CR: Full training, 1280 effective batch size",
                SAMPLE_SIZES,
                [0.62325,  0.6291,   0.63017,  0.63448,  0.64444,  0.65062],
                [0.033763, 0.033963, 0.033802, 0.033883, 0.033759, 0.033783],
                [0.14508,  0.14732,  0.14538,  0.14572,  0.1449,   0.14478],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A8-EffecBsz640": (
                "SurvivEHR-CR: CR: Adapter (8 hidden), 640 effective batch size",
                SAMPLE_SIZES,
                [0.62137,  0.61988,  0.63025,  0.63053, 0.63697,  0.64247,  0.64748],
                [0.033802, 0.033907, 0.033817, 0.03383, 0.033796, 0.033779, 0.033728],
                [0.14544,  0.14594,  0.14555,  0.14537, 0.14467,  0.14442,  0.14404],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A128-EffecBsz640": (
                "SurvivEHR-CR: CR: Adapter (128 hidden), 640 effective batch size",
                SAMPLE_SIZES,
                [0.61522,  0.60889,  0.62987, 0.63499,  0.63891],
                [0.033899, 0.033901, 0.03393, 0.033873, 0.033802],
                [0.14622,  0.1466,   0.14623, 0.14513,  0.1446],
            ),
            "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A256-EffecBsz640": (
                "SurvivEHR-CR: CR: Adapter (256 hidden), 640 effective batch size",
                SAMPLE_SIZES,
                [0.61946,  0.61588,  0.62564, 0.62541,  0.62419,  0.63924],
                [0.033839, 0.033842, 0.03383, 0.033871, 0.033875, 0.033782],
                [0.14563,  0.14579,  0.14559, 0.14617,  0.14625,  0.14468],
            ),
        },
        # SR results
        {
            "Benchmark: Random Survival Forest": (
                "Survival Random Forest (SR)",
                SAMPLE_SIZES,
                [0.5817231098874919, 0.5726886646549809, 0.5883719280021017, 0.5960278990576237, 0.5967549074819676, 0.6081889934667545, 0.6091873754214131, 0.6068623782034415, 0.6140593631057283, 0.6118567664875652],
                [0.03399845492489739, 0.03396416779973878, 0.03384484456385947, 0.0337225214683957, 0.033762424117441375, 0.03375098553961131, 0.033695391903505144, 0.033734096545360415, 0.03370825445709043, 0.033727277767864286],
                [0.14861047476033662, 0.149062513427203, 0.14750556866346265, 0.14647954529430643, 0.14621266423294518, 0.1457478005690567, 0.1452056774612423, 0.14558305134538546, 0.14507898552247775, 0.14548991693533334],
            ),
        },
    ),
    "Hypertension": (
        # CR results (none collated yet)
        {},
        # SR results
        {
            "Benchmark:DeSurv": (
                # The legend label was missing here, which shifted every metric
                # one slot to the left and broke the plots.
                "DeSurv",
                SAMPLE_SIZES,
                [0.6815094745489083, 0.7124953181062986, 0.7274689831170883, 0.730682386913791, 0.7389774246331557, 0.7411341306621856, 0.7515646895266107, 0.7600540925292503, 0.7620305063798369, 0.761360566628242],
                [0.09055914251601428, 0.08950243267368668, 0.08788381776138605, 0.08713170539614512, 0.08589961516053798, 0.08509218669879787, 0.08443103008855017, 0.08451878180772691, 0.08404512042043363, 0.08371696663855675],
                [0.31029164842950396, 0.2998138831024105, 0.29219335002924224, 0.2915679214433027, 0.2797340688171351, 0.27232015638634294, 0.27128357563201616, 0.2692630746813842, 0.2666228576617026, 0.26584980102348854],
            ),
        },
    ),
}


def validate_results(results_by_dataset):
    """Assert every hand-collated entry has the 5-tuple layout the plots expect.

    Each entry must be ``(label: str, sample_sizes, c_td, ibs, inbll)``, with the
    three metric lists of equal length and no longer than ``sample_sizes``.

    Raises:
        AssertionError: naming the first malformed ``dataset/run_id``.
    """
    for dataset_name, named_result_sets in results_by_dataset.items():
        for named_results in named_result_sets:
            for run_id, entry in named_results.items():
                where = f"{dataset_name}/{run_id}"
                assert len(entry) == 5 and isinstance(entry[0], str), (
                    f"{where}: expected (label, sample sizes, C_td, IBS, INBLL), got {len(entry)} fields"
                )
                metric_lengths = {len(metric) for metric in entry[2:]}
                assert len(metric_lengths) == 1, f"{where}: metric lists differ in length {sorted(metric_lengths)}"
                assert max(metric_lengths) <= len(entry[1]), f"{where}: more metric values than sample sizes"


# Fail here, with a clear message, rather than deep inside the plotting loop.
validate_results(RESULTS_BY_DATASET)

if dataset not in RESULTS_BY_DATASET:
    raise ValueError(f"Unknown dataset {dataset!r}; expected one of {list(RESULTS_BY_DATASET)}")
results_cr, results_sr = RESULTS_BY_DATASET[dataset]
    


if risks == "cr":
    results = results_cr
elif risks == "sr":
    results = results_sr
else:
    results = {**results_sr, **results_cr}

display(results.keys())

dict_keys(['Benchmark: Random Survival Forest', 'Benchmark:DeSurv', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz128', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz640', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz1280', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A8-EffecBsz640', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A128-EffecBsz640', 'SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A256-EffecBsz640'])

# Plot sweep results

In [4]:
def get_sweep_filter(approach):
    
    match approach.lower():
        case "all":
            # look at all runs
            sweep_filter = lambda run_id: True
        case "full_training":
            # look at all full training runs
            sweep_filter = lambda run_id: True if "AFalse" in run_id else False
        case "adapter_size":
            # look at effect of adapter size
            sweep_filter = lambda run_id: True if "AFalse" not in run_id else False
        case "using_adapter":
            # look at effect of using adapter
            sweep_filter = lambda run_id: True if "640" in run_id else False
        case "benchmark":
            # Look at some subset of SurvivEHR models and benchmarks
            sweep_filter = lambda run_id: True if "Benchmark" in run_id \
                                                or "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-A8-EffecBsz640" == run_id \
                                                or "SurvivEHR-cr-small-v1_cvd-fine-tune-cr-AFalse-EffecBsz128" == run_id  else False
        case _:
            raise NotImplementedError
            
    return sweep_filter


In [6]:
# Draw one three-panel figure (C_td, IBS, INBLL against cohort size) per run filter.
# Each figure is saved as "Metrics_{dataset}_{risks}_{filter_approach}";
# matplotlib appends its default image extension.
SWEEP_SAMPLE_SIZES = [2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000]
METRIC_YLABELS = ["Concordance (time-dependent)", "Integrated Brier Score", "Integrated Negative Bernoulli log-likelihood"]


def plot_metric_sweep(results, sweep_filter, title):
    """Plot each selected run's metrics against cohort sample size.

    Args:
        results: mapping run id -> (label, sample sizes, C_td, IBS, INBLL).
        sweep_filter: predicate on run id; only runs it accepts are drawn.
        title: figure super-title.

    Returns:
        The matplotlib Figure, with one log-x panel per metric.
    """
    fig, axs = plt.subplots(1, 3, figsize=(25, 6))
    fig.suptitle(title)

    for metric_idx, (ax, ylabel) in enumerate(zip(axs, METRIC_YLABELS)):
        for key, entry in results.items():
            if not sweep_filter(key):
                continue

            # Entry layout: (label, sample sizes, C_td, IBS, INBLL), so metrics start at index 2.
            y = entry[metric_idx + 2]
            # Shorter metric lists pair with the leading sample sizes.
            x = entry[1][:len(y)]
            is_benchmark = "benchmark" in key.lower()

            if len(y) > 1:
                ax.plot(x, y,
                        label=entry[0],
                        lw=1.5,
                        linestyle='dashed' if is_benchmark else 'solid',
                        markersize=5,
                        marker="*" if is_benchmark else "o")
            else:
                ax.scatter(x, y, label=entry[0])

        ax.set_xlabel("Cohort sample size (log scale)")
        ax.set_ylabel(ylabel)
        ax.set_xscale('log')
        ax.set_xticks(SWEEP_SAMPLE_SIZES, SWEEP_SAMPLE_SIZES, rotation=90)

    # A single legend, anchored to the right of the last panel.
    axs[-1].legend(loc="right", bbox_to_anchor=(1, 0., 1, 1))
    fig.tight_layout()
    return fig


for filter_approach in ["all", "full_training", "adapter_size", "using_adapter", "benchmark"]:
    sweep_filter = get_sweep_filter(filter_approach)
    fig = plot_metric_sweep(results, sweep_filter, f"Metrics for {dataset}: {risks}")
    fig.savefig(f"Metrics_{dataset}_{risks}_{filter_approach}", facecolor='white', transparent=False)
    # Render now, then release the figure so the loop's figures don't accumulate in memory.
    plt.show()
    plt.close(fig)