In [1]:
import os
from pathlib import Path
import sys
def prepend_to_sys_path(site_pkgs: Path) -> bool:
    """Put ``site_pkgs`` at the front of ``sys.path`` if that directory exists.

    Entries earlier on ``sys.path`` win on import, so packages in ``site_pkgs``
    shadow any other installed copies. Any existing occurrence of the entry is
    removed first, so re-running the cell neither duplicates it nor leaves it
    behind other entries.

    Returns:
        True if the directory exists (and is now ``sys.path[0]``), False otherwise
        (``sys.path`` is left untouched).
    """
    if not site_pkgs.exists():
        return False
    entry = str(site_pkgs)
    while entry in sys.path:
        sys.path.remove(entry)
    sys.path.insert(0, entry)
    return True


# BB_CPU presumably names the compute node's CPU type (the saved output shows
# 'icelake'); a separate virtual environment is kept per node type.
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if node_type is None:
    # Without BB_CPU the path above ends in '-None' and can never exist.
    print("Environment variable 'BB_CPU' is not set; cannot locate the node-specific virtual environment.")
if prepend_to_sys_path(venv_site_pkgs):
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.


In [2]:
import pytorch_lightning
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from FastEHR.dataloader.foundational_loader import FoundationalDataModule
import pickle
from tqdm import tqdm
import seaborn as sns

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from torch.utils.data import TensorDataset, DataLoader

from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from CPRD.src.modules.head_layers.survival.desurv import ODESurvMultiple

# Seed every RNG whose module is imported in this notebook (Python `random`,
# NumPy's global RNG, and torch) so that stochastic steps are repeatable.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO)

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cpu.


# Results, collated from other notebooks

In [13]:
# Load the collated metrics table. Per the saved output, each row is one run with
# columns Study, Seed, Samples, Model, Concordance, IBS and INBLL.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
results = pd.read_pickle('results_april.pkl')

# Temporarily remove empty results: rows whose Concordance is exactly 0 are
# treated as placeholders for runs with no metrics yet.
results = results.loc[results["Concordance"].ne(0)]

display(results)

# NOTE(review): in the saved output this selection is empty. No Model in the table
# is labelled 'Pre-trained' (labels seen include 'SurvivEHR SFT' and 'DeepHit');
# confirm the intended label.
display(
    results.loc[results["Model"].eq('Pre-trained') & results["Samples"].eq(2999)]
)


Unnamed: 0,Study,Seed,Samples,Model,Concordance,IBS,INBLL
0,Hypertension,1,2999,SurvivEHR SFT,0.525690,0.094264,0.335790
1,Hypertension,2,2999,SurvivEHR SFT,0.502880,0.094272,0.335820
2,Hypertension,3,2999,SurvivEHR SFT,0.527200,0.094284,0.335910
3,Hypertension,4,2999,SurvivEHR SFT,0.555690,0.094352,0.336390
4,Hypertension,5,2999,SurvivEHR SFT,0.496620,0.094242,0.335610
...,...,...,...,...,...,...,...
545,CVD,1,572096,DeepHit,0.654116,0.033611,0.143443
546,CVD,2,572096,DeepHit,0.658777,0.033520,0.142626
547,CVD,3,572096,DeepHit,0.656578,0.033577,0.143088
548,CVD,4,572096,DeepHit,0.657524,0.033574,0.143139


Unnamed: 0,Study,Seed,Samples,Model,Concordance,IBS,INBLL


In [43]:

# Plot Concordance, IBS and INBLL against cohort sample size for each study, and
# save one three-panel figure per study to "Metrics_<study>.png".

# Metrics drawn left-to-right, one per panel.
METRICS = ["Concordance", "IBS", "INBLL"]

# Seaborn line-plot options shared by every panel.
# One marker style per Model, all lines solid:
lineplot_kwargs = {"style": "Model", "markers": True, "dashes": False}
# Seeds at the same sample size are aggregated; spread shown as error bars of 2 standard errors:
lineplot_kwargs = {**lineplot_kwargs, "err_style": "bars", "errorbar": ("se", 2)}
# Alternative: draw every seed as its own line instead of an aggregate
# lineplot_kwargs = {**lineplot_kwargs, "units": "Seed", "estimator": None, "lw": 1}

# Upper y-limits that clip outlying runs, keyed by lower-cased study name then metric.
# The lower y-limit is that metric's minimum within the study.
METRIC_YLIM_UPPER = {
    "cvd": {"IBS": 0.0345, "INBLL": 0.15},
}

for idx_exp, experiment in enumerate(["Hypertension", "CVD"]):

    df = results.loc[results["Study"] == experiment]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), gridspec_kw={'top': 0.8}, constrained_layout=True)
    ylim_upper = METRIC_YLIM_UPPER.get(experiment.lower(), {})

    for ax, metric in zip(axes, METRICS):
        sns.lineplot(
            data=df, x="Samples", y=metric, hue="Model",
            # Only the first study's right-most (INBLL) panel carries a legend.
            legend=(idx_exp == 0 and metric == "INBLL"),
            ax=ax,
            **lineplot_kwargs,
        )
        if metric in ylim_upper:
            ax.set_ylim((df[metric].min(), ylim_upper[metric]))

    # Shared legend inside plot area, across the top
    # handles, labels = axes[2].get_legend_handles_labels()
    # fig.legend(handles, labels, loc='upper center', ncol=len(labels), frameon=False)

    # Save this figure explicitly: plt.savefig() writes whichever figure pyplot treats
    # as "current", which is not necessarily `fig` (e.g. once `fig` has been closed).
    # Close it afterwards so figures do not accumulate or render as blank inline plots.
    fig.savefig(f"Metrics_{experiment}.png")
    plt.close(fig)

In [9]:
# NOTE(review): this entire cell is commented out and does nothing; it is kept for
# reference only. As written it relies on names not defined in the visible cells
# (`get_sweep_filter`, `dataset`, `risks`) and indexes `results` as a mapping of
# key -> list (`results[key][0]` used as a label, `results[key][1]` as x values),
# which does not match the pandas DataFrame `results` loaded from
# 'results_april.pkl'. The "0 1 2" output saved beneath this cell is left over from
# an earlier run. TODO: port to the DataFrame-based plotting or delete.
# # filter_approach = "using_adapter"


# for filter_approach in ["paper"]:    # all", "full_training", "adapter_size", "using_adapter", "benchmark"

#     sweep_filter = get_sweep_filter(filter_approach)
    
#     fig, axs = plt.subplots(1,3, figsize=(20,6))
#     fig.suptitle(f"Metrics for {dataset}: {risks}")
    
#     ylabels = ["Concordance (time-dependent)", "Integrated Brier Score", "Integrated Negative Bernoulli log-likelihood"]
#     for metric_idx, ax in enumerate(axs):
#         print(metric_idx)
#         for key in results.keys():
            
#             if sweep_filter(key):
                
#                 y = results[key][3*metric_idx + 2]
#                 x = results[key][1][:len(y)]
                
#                 if "benchmark" in key.lower():
#                     linestyle='dashed'
#                     marker="*"
#                 else:
#                     linestyle='solid'
#                     marker="o"
    
#                 if len(y) > 1:
#                     ax.plot(x, y,
#                             label=results[key][0],
#                             lw=1.5,
#                             linestyle=linestyle,
#                             markersize=5,
#                             marker=marker
#                            )
#                 else:
#                     ax.scatter(x, y,
#                                label=results[key][0]
#                               )
    
#         ax.set_xlabel("Cohort sample size (log scale)")
#         ax.set_ylabel(ylabels[metric_idx])
#         ax.set_xscale('log')
#         ax.set_xticks([2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000, 572096],
#                       [2999, 5296, 9351, 16509, 29148, 51461, 90856, 160407, 283203, 500000, 572096], 
#                       rotation=90)
            
            
#     plt.legend(loc="right", bbox_to_anchor=(0.6, 0., 1, 1))
#     plt.tight_layout()
#     plt.savefig(f"Metrics_{dataset}_{risks}_{filter_approach}", facecolor='white', transparent=False)

0
1
2
