In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

os.chdir('/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots')
print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots


In [2]:
import pytorch_lightning
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import sqlite3
from dataclasses import dataclass
import logging
from FastEHR.dataloader.foundational_loader import FoundationalDataModule
import pickle
from tqdm import tqdm
import seaborn as sns

from pycox.datasets import support
from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from torch.utils.data import TensorDataset, DataLoader
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from CPRD.src.modules.head_layers.survival.desurv import ODESurvSingle
from CPRD.src.modules.head_layers.survival.desurv import ODESurvMultiple

torch.manual_seed(1337)
logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")


KeyboardInterrupt



In [None]:
results = pd.read_pickle('results_new.pkl')
results = results.dropna()

display(results.head())
display(results[(results["Method"] == 'SurvivEHR (FFT)') & (results["Samples"]== 2999) & (results["Study"]== "CVD")].head())

print(results["Method"].unique())

# Plot table results

In [55]:
sns.set(style="ticks", context="notebook")

results_table  = results[results['Samples'] == results.groupby('Study')['Samples'].transform('max')]
results_zero_shot = results[results["Method"] == "SurvivEHR (zero-shot)"].copy()
results_zero_shot["Method"] = "SurvivEHR (ZS)"
results_table = pd.concat([results_table, results_zero_shot], ignore_index=True)

fig, axes = plt.subplots(1,3,figsize=(15,5), constrained_layout=True)

sns.barplot(data=results_table, x="Study", y="Concordance", hue="Method", ax=axes[0])
sns.barplot(data=results_table, x="Study", y="IBS", hue="Method", ax=axes[1])
sns.barplot(data=results_table, x="Study", y="INBLL", hue="Method", ax=axes[2])

# For each experiment adjust ylims where there are outliers
min_scale, max_scale = 0.75, 1.1
metrics = ["Concordance", "IBS", "INBLL"]
labels = ["Concordance (time-dependent)", "Integrated Brier Score", "Integrated Negative Bernoulli \n Log-likelihood"]
for idx_ax, ax in enumerate(axes):
    axes[idx_ax].set_ylim((results_table[metrics[idx_ax]].min()*min_scale, results_table[metrics[idx_ax]].max()*max_scale))
    if idx_ax != 0:
        ax.set_ylabel("")
    else:
        ax.set_ylabel("Score")
    axes[idx_ax].set_title(labels[idx_ax], fontsize=12, pad=12, fontweight='bold')
    
# Add zoom on IBS axis
target_ax = axes[1]
axins = inset_axes(target_ax, width="35%", height="30%",  loc='upper center',
                   # bbox_to_anchor=(0.17, -0.04, 1.0, 1.0), 
                   # bbox_transform=target_ax.transAxes,
                   # borderpad=0
                  )
sns.barplot(data=results_table, x="Study", y="IBS", hue="Method", ax=axins)
# Set the x and y limits of the zoomed-in view manually
axins.set_xlim(0.5, 1.5)          # Narrow x-axis to focus on subset of bars
axins.set_ylim(0.0331, 0.0338)     # Zoom in on y-axis range (e.g., IBS values)
axins.set_xticklabels([])         # Hide x labels in inset
axins.tick_params(labelsize=8)    # Make y labels smaller in inset
axins.set_xlabel("") 
axins.set_ylabel("") 
axins.tick_params(bottom=False)
axins.legend_.remove()
# Draw lines connecting inset and main plot
mark_inset(target_ax, axins, loc1=3, loc2=4, fc="none", ec="0.5")

# Add zoom on INBLL axis
target_ax = axes[2]
axins_inbll = inset_axes(target_ax, width="35%", height="30%",  loc='upper center',
                         # bbox_to_anchor=(0.17, -0.04, 1.0, 1.0), 
                         # bbox_transform=target_ax.transAxes,
                         # borderpad=0
                        )
sns.barplot(data=results_table, x="Study", y="INBLL", hue="Method", ax=axins_inbll)
# Set the x and y limits of the zoomed-in view manually
axins_inbll.set_xlim(0.5, 1.5)       # Narrow x-axis to focus on subset of bars
axins_inbll.set_ylim(0.137, 0.146)     # Zoom in on y-axis range (e.g., IBS values)
axins_inbll.set_xticklabels([])      # Hide x labels in inset
axins_inbll.tick_params(labelsize=8) # Make y labels smaller in inset
axins_inbll.set_xlabel("") 
axins_inbll.set_ylabel("") 
axins_inbll.tick_params(bottom=False)
axins_inbll.legend_.remove()
# Draw lines connecting inset and main plot
mark_inset(target_ax, axins_inbll, loc1=3, loc2=4, fc="none", ec="0.5")

# Add annotations for whether a score should be higher or lower
kwargs = {"xy":(0.85, 0.95), "xycoords":'axes fraction', "ha":'center', "va":'center', "fontsize":11, "fontweight":'normal'}
axes[0].annotate(" (↑ is better)", **kwargs)
axes[1].annotate("(↓ is better)", **kwargs)
axes[2].annotate("(↓ is better)", **kwargs)

# Legend
# axes[0].legend_.remove()
# axes[1].legend_.remove()
# axes[2].legend(loc='lower right', bbox_to_anchor=(1.55, 0.3),  ncol=1, frameon=False, title="Methods")             # Beside

# Move legend below the plot as a horizontal bar and remove others
for idx_ax, ax in enumerate(axes):
    if idx_exp == 1 and idx_ax == 1:
        # ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02),  ncol=3, frameon=False)             # Above
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, frameon=True, fontsize=10)        # Below
        
    else:
        ax.legend_.remove()

# plt.tight_layout()
plt.savefig(f"Metrics.png")
plt.close(fig)

# Plot ablation

## Lineplot

In [5]:
results_ablation = results[results["Samples"] > 10000].copy()
log10_xscale = False
sns.set(style="ticks", context="notebook")

if log10_xscale:
    group_shift_map = {group: i * 0.025 for i, group in enumerate(results['Method'].unique())}
    results_ablation.loc[:, 'Samples'] = np.exp(np.log(results_ablation['Samples']) + results_ablation['Method'].map(group_shift_map))

fig, axes_all = plt.subplots(2, 3, figsize=(15,10), gridspec_kw={'top': 0.8}, constrained_layout=True)
for idx_exp, experiment in enumerate(["Hypertension", "CVD"]):

    df = results_ablation.query(f"Study == '{experiment}'")
    axes = axes_all[idx_exp,:]

    # Plot with interval
    lineplot_kwargs = {}
    # update to plot with bars
    # lineplot_kwargs = {**lineplot_kwargs, "style":"Method", "markers":True, "dashes":False}
    # update to plot with dot-dash
    # lineplot_kwargs = {**lineplot_kwargs, "err_style":"bars", "errorbar":("se", 2),}
    # update to plot individual samples
    # lineplot_kwargs = {**lineplot_kwargs, "units":"Seed", "estimator":None, "lw":1}

    sns.lineplot(data=df, x="Samples", y="Concordance", hue="Method", ax=axes[0], **lineplot_kwargs)
    sns.lineplot(data=df, x="Samples", y="IBS", hue="Method",ax=axes[1], **lineplot_kwargs)
    sns.lineplot(data=df, x="Samples", y="INBLL", hue="Method", ax=axes[2], **lineplot_kwargs)
    
    # For each experiment adjust ylims where there are outliers
    df_upper = lambda _df, _metric, _scale: _df[_metric].max()  + _scale * (_df[_metric].max()-_df[_metric].min())
    axes[0].set_ylim((df["Concordance"].min(), df_upper(df, "Concordance", 0.1) ))
    axes[1].set_ylim((df["IBS"].min(), df_upper(df, "IBS", 0.1) ))
    axes[2].set_ylim((df["INBLL"].min(), df_upper(df, "INBLL", 0.1) ))

    for ax in axes:
        ax.set_xlabel("Training Samples")
        if log10_xscale:
            ax.set(xscale="log")
        if idx_exp == 0:
            ax.set_xlabel("") 
        ax.set_ylabel("")

    # Column titles
    if idx_exp == 0:
        column_titles = ["Concordance (time-dependent)", 
                         "Integrated Brier Score", 
                         "Integrated Negative Bernoulli \n Log-likelihood"]
        for ax, title in zip(axes, column_titles):
            ax.set_title(title, fontsize=12, pad=12, fontweight='bold')

    # Row titles
    axes[0].set_ylabel(f"Scores", fontsize=12, fontweight='normal')
    row_titles = ["Hypertension Study", "CVD Study"]
    y_positions = [0.75, 0.325]
    for y, title in zip(y_positions, row_titles):
        fig.text(-0.02, y, title, va='center', ha='left', fontsize=12, rotation=90, fontweight='bold')

    # Add annotations for whether a score should be higher or lower
    kwargs = {"xy":(0.19, 0.95), "xycoords":'axes fraction', "ha":'center', "va":'center', "fontsize":11, "fontweight":'normal'}
    axes[0].annotate(" (higher ↑ is better)", **kwargs)
    axes[1].annotate("(lower ↓ is better)", **kwargs)
    axes[2].annotate("(lower ↓ is better)", **kwargs)
    
    # Move legend below the plot as a horizontal bar and remove others
    for idx_ax, ax in enumerate(axes):
        if idx_exp == 1 and idx_ax == 1:
            # ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.02),  ncol=3, frameon=False)             # Above
            ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, frameon=True,)        # Below
            
        else:
            ax.legend_.remove()
            
plt.tight_layout()
plt.savefig(f"Ablation_lineplot.png", bbox_inches='tight')
plt.close(fig)

  plt.tight_layout()


## Histogram plot

In [6]:
results_ablation = results[results["Samples"] > 10000].copy()

for idx_exp, experiment in enumerate(["Hypertension", "CVD"]):

    df = results_ablation.query(f"Study == '{experiment}'")
    fig, axes = plt.subplots(1,3,figsize=(25,5), gridspec_kw={'top': 0.8}, constrained_layout=True)
    
    sns.barplot(data=df, x="Samples", y="Concordance", hue="Method", ax=axes[0])
    sns.barplot(data=df, x="Samples", y="IBS", hue="Method", ax=axes[1])
    sns.barplot(data=df, x="Samples", y="INBLL", hue="Method", ax=axes[2])
    
    # For each experiment adjust ylims where there are outliers
    df_upper = lambda _df, _metric, _scale: _df[_metric].max()  + _scale * (_df[_metric].max()-_df[_metric].min())
    match experiment.lower():
        case "hypertension":
            axes[0].set_ylim((df["Concordance"].min(), df_upper(df, "Concordance", 0.1) ))
            axes[1].set_ylim((df["IBS"].min(), df_upper(df, "IBS", 0.1) ))
            axes[2].set_ylim((df["INBLL"].min(), df_upper(df, "INBLL", 0.5) ))
        case "cvd":
            axes[0].set_ylim((df["Concordance"].min(), df_upper(df, "Concordance", 0.1) ))
            axes[1].set_ylim((df["IBS"].min(), df_upper(df, "IBS", 0.1) ))
            axes[2].set_ylim((df["INBLL"].min(), df_upper(df, "INBLL", 0.1) ))
    
    
    # Move legend below the plot as a horizontal bar and remove others
    # axes[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3, frameon=False)        # Below
    axes[1].legend(loc='lower center', bbox_to_anchor=(0.5, 1.02),  ncol=3, frameon=False)             # Above
    axes[0].legend_.remove()
    # axes[1].legend_.remove()
    axes[-1].legend_.remove()
    
    
    plt.savefig(f"Ablation_barplot_{experiment}.png")
    plt.close(fig)

In [7]:
np.log(results_ablation["Samples"].unique())

array([13.25706209, 13.12236338, 12.55392983, 11.98546961, 11.41703111,
       10.84857952, 10.28014158,  9.90348755,  9.71166097])