In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

!pwd

%load_ext autoreload
%autoreload 2

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/data/2_build_pre_training_dataset


In [32]:
import torch
import numpy as np
import logging
from tqdm import tqdm
import pickle
from hydra import compose, initialize
from omegaconf import OmegaConf
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import ScalarFormatter

import matplotlib.ticker as ticker
from matplotlib.ticker import FuncFormatter, StrMethodFormatter
from typing import Tuple, Union
from matplotlib.axes import Axes


from FastEHR.dataloader import FoundationalDataModule

from CPRD.examples.modelling.SurvivEHR.run_experiment import run
from CPRD.src.models.survival.task_heads.causal import SurvStreamGPTForCausalModelling
from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP, EVENT_NAME_LONG_MAP

import time
import os
import polars as pl
pl.Config.set_tbl_rows(10000)
import pandas as pd
pd.options.display.max_rows = 10000

torch.manual_seed(1337)
torch.set_float32_matmul_precision('medium')

logging.basicConfig(level=logging.INFO)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"    # if more informative debugging statements are needed
print(f"Using device: {device}.")

Using device: cuda.


In [3]:
# load the configuration file, override any settings 
with initialize(version_base=None, config_path="../../modelling/SurvivEHR/confs", job_name="testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=[])

print(f"Data config:\n")
print(OmegaConf.to_yaml(cfg.data))


Data config:

batch_size: 64
unk_freq_threshold: 0.0
min_workers: 12
global_diagnoses: false
repeating_events: true
path_to_db: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/cprd.db
path_to_ds: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/
meta_information_path: /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle
subsample_training: null



In [4]:
# Build 
dm = FoundationalDataModule(path_to_db=cfg.data.path_to_db,
                            path_to_ds=cfg.data.path_to_ds,
                            load=True,
                            batch_size=cfg.data.batch_size,
                            max_seq_length=cfg.transformer.block_size,
                            freq_threshold=cfg.data.unk_freq_threshold,
                            min_workers=cfg.data.min_workers,
                           )

vocab_size = dm.train_set.tokenizer.vocab_size
print(f"{vocab_size} vocab elements")

INFO:root:Creating unsupervised collator for DataModule
INFO:root:Using meta information from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information.pickle
INFO:root:Using train file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_train.pickle
INFO:root:Using test file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_test.pickle
INFO:root:Using val file-row count dictionary from /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/file_row_count_dict_val.pickle
INFO:root:Tokenzier created based on 7,555,415,275 tokens
INFO:root:Using tabular tokenizer, created from meta information and containing 265 tokens
INFO:root:Set seed to 42
INFO:root:Loaded /rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain

265 vocab elements


In [5]:
short_names = [j for i,j in EVENT_NAME_LONG_MAP.items()]
# print(len(short_names))
# print(np.unique(short_names))
# print(len(np.unique(short_names)))

print(dm.tokenizer._event_counts.head())

shape: (5, 3)
┌────────────────────────┬───────┬───────────┐
│ EVENT                  ┆ COUNT ┆ FREQUENCY │
│ ---                    ┆ ---   ┆ ---       │
│ str                    ┆ u32   ┆ f64       │
╞════════════════════════╪═══════╪═══════════╡
│ UNK                    ┆ 0     ┆ 0.0       │
│ ADDISONS_DISEASE       ┆ 6691  ┆ 8.8559e-7 │
│ CYSTICFIBROSIS         ┆ 7053  ┆ 9.3350e-7 │
│ SYSTEMIC_SCLEROSIS     ┆ 8772  ┆ 0.000001  │
│ SICKLE_CELL_DISEASE_V2 ┆ 11159 ┆ 0.000001  │
└────────────────────────┴───────┴───────────┘


# Plot top k common events from each event type

In [65]:
def plot_count_histograms(polars_frames: list[pl.DataFrame],
                          frame_names: list[str],
                          top_k: int = 10,
                          bar_color: str = "#4C72B0",
                          y_axis: str = "COUNT"):
    """
    Draw a vertical stack of bar charts, one row per Polars frame.

    Parameters
    ----------
    polars_frames : list[pl.DataFrame]
        Each frame must contain columns 'EVENT' and either 'COUNT' or
        'FREQUENCY' (depending on `y_axis`).
    frame_names : list[str]
        Name of each frame, for annotation.
    top_k : int, default 10
        Number of highest-count events to display in every subplot.
    bar_color : str, default '#4C72B0'
        Single colour applied to all bars in all subplots.
    y_axis : {'COUNT', 'FREQUENCY'}, default 'COUNT'
        Which column to show on the y-axis.
    """
    assert y_axis in {"COUNT", "FREQUENCY"}, "`y_axis` must be 'COUNT' or 'FREQUENCY'"

    n_rows = len(polars_frames)
    fig_height = 2.3 * n_rows
    fig, axs = plt.subplots(n_rows, 1, figsize=(3, fig_height), sharex=False,)

    axs = axs.ravel()

    for idx, (frame, ax) in enumerate(zip(polars_frames, axs)):
        top = (
            frame
            .sort(y_axis, descending=True)
            .head(top_k)
            .to_pandas()
        )

        sns.barplot(
            data=top,
            y="EVENT",
            x=y_axis,
            color=bar_color,
            ax=ax,
            orient="h",
            saturation=1.0,
        )

        # remove y-tick text (it is now inside the bars)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xlabel(None)
        ax.set_ylabel(None)
        ax.set_yticks([])
        # ax.grid(False)
        
        # Annotate event names (to save space)
        # heights = [bar.get_height() * 0.5 for bar in ax.patches]
        for bar, label in zip(ax.patches, top["EVENT"]):
            ax.text(bar.get_width() * 0.02,
                    bar.get_y() + bar.get_height()/2,
                    label,
                    ha="left", va="center",
                    rotation=0,
                    color="white", fontsize="small", wrap=True)

        # Annotate data type
        ax.annotate(
            f"{frame_names[idx].replace('_', ' ')}",
            xy=(0.98, 0.02),
            xycoords="axes fraction",
            ha="right", va="bottom",
            fontsize="medium",
            fontweight="bold"
        )

        # format ticks: show counts in millions
        if y_axis == "COUNT":
            ax.xaxis.set_major_formatter(
                FuncFormatter(lambda x, _: f"{x/1_000_000:,.0f}")
            )
        elif y_axis == "FREQUENCY":
            pass
        else:
            raise NotImplementedError
        
    # # Labels
    fig.supylabel("Events", fontsize="large", x=0.06, y=0.5, rotation="vertical", ha="left", va="center")
    fig.supxlabel("Count (millions)" if y_axis == "COUNT" else "Frequency", y=0.04)
        
    fig.subplots_adjust(hspace=0.15)
    fig.subplots_adjust(top=0.97)#, bottom=0.06)
    
    return fig, axs

In [66]:
# Get the subgroups of tokens we want to plot by
lab_names = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] > 0]["event"].to_list()
medication_names = dm.meta_information["measurement_tables"][dm.meta_information["measurement_tables"]["count_obs"] == 0]["event"].to_list()
diagnosis_names = dm.meta_information["diagnosis_table"]["event"].to_list()
top_k = 10

all_events_of_interest = []
for events_of_interest in [diagnosis_names, medication_names, lab_names]:

    # print(dm.tokenizer._event_counts)

    # Get the rows belonging to events of interest
    event_counts_of_interest = (dm.tokenizer._event_counts
                                # Keep only the events of interest
                                .filter(pl.col("EVENT").is_in(list(events_of_interest)))
                                # Map the coded name to the plotting name (here some events can be combined)
                                .with_columns(
                                    pl.col("EVENT")
                                      .map_dict(EVENT_NAME_SHORT_MAP, return_dtype=pl.Utf8)
                                      .alias("EVENT")
                                )
                               )

    # group on the (now-short) event name, and update the statistics
    event_counts_of_interest = (
        event_counts_of_interest
        .groupby("EVENT")     
        .agg(
            pl.col("COUNT").sum().alias("COUNT"),
            pl.col("FREQUENCY").sum().alias("FREQUENCY")
            )
        .sort("COUNT", descending=True)
        )
    
    # print(event_counts_of_interest)
    all_events_of_interest.append(event_counts_of_interest)
                                
fig, axs = plot_count_histograms(all_events_of_interest,
                                 frame_names=["Diagnoses", "Medications", "Investigations"],
                                 top_k=top_k)
plt.savefig(f"figs/histogram_top{top_k}.png", dpi=300)
plt.close()


## Waiting time histogram

In [8]:
waiting_times = []
max_batches = int(0.035 * len(dm.test_dataloader()))

for idx, batch in tqdm(enumerate(dm.test_dataloader()),
                      total=max_batches):

    ages = batch["ages"]
    mask = batch["attention_mask"].bool()
    delta_mask = mask[:, 1:] & mask[:, :-1]
    delta = ages[:, 1:] - ages[:, :-1]
    
    flat_deltas = delta.masked_select(delta_mask)   # 1-D tensor of valid ages
    flat_pos_deltas = flat_deltas.masked_select(flat_deltas > 0)

    waiting_times.append(flat_pos_deltas)

    if idx > max_batches:
        break

825it [04:25,  3.10it/s]                         


In [9]:
batch["ages"]

tensor([[ 0.0378,  0.1490,  0.1490,  ...,  0.0000,  0.0000,  0.0000],
        [15.5036, 17.1047, 17.5381,  ...,  0.0000,  0.0000,  0.0000],
        [ 8.0252,  8.0252,  8.0252,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 7.3737,  7.3737,  7.3737,  ...,  0.0000,  0.0000,  0.0000],
        [ 6.3375,  6.3375,  6.3375,  ...,  0.0000,  0.0000,  0.0000],
        [ 4.6405,  4.6405,  4.6405,  ...,  0.0000,  0.0000,  0.0000]])

In [10]:
flat_waiting_times = torch.concatenate(waiting_times)

# scale to years
flat_waiting_times *= 5

print(flat_waiting_times.shape)

torch.Size([933022])


In [59]:
def plot_waiting_times_inset_hist(
    waiting_times: torch.Tensor,
    *,
    bins: int = 30,
    inset_max: float = 28,
    inset=True,
):
    """
    """
    
    data = waiting_times.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(2.5*2, 2.3))

    # Main histogram
    sns.histplot(
        data,
        bins=bins,
        ax=ax,
        edgecolor="white",
        color="#4C72B0",
        linewidth=0.5,
        kde=False,
        alpha=1.0,
    )
    ax.set_yscale("log")
    # ax.yaxis.set_major_formatter(ScalarFormatter())  # cleaner ticks
    ax.set_xlabel("Inter-event times (years)")
    ax.set_ylabel("Count")
    ax.grid(False)  # make sure main axis grid is off

    # Inset: zoom on values < inset_max
    if inset:
        data_inset = data * 365
        ax_ins = inset_axes(ax, width="40%", height="50%", loc="upper right")
        sns.histplot(
            data=data_inset[data_inset <= inset_max],
            bins=inset_max // 2,
            ax=ax_ins,
            edgecolor="white",
            color="indianred",
            linewidth=0.5,
            kde=False,
            alpha=1.0,
        )
        # ax_ins.set_xlim(0, inset_max)
        # ax_ins.set_yscale("log")
        # ax_ins.yaxis.set_major_formatter(ScalarFormatter())  # cleaner ticks
        ax_ins.set_xticks([0,7,14,21,28], minor=False)
        ax_ins.set_xlabel(f"First {inset_max} days")
        ax_ins.set_ylabel("")
        # ax_ins.set_title(f"", fontsize=9, pad=4)
        ax_ins.grid(False)  # ensure inset grid is off

        sns.despine(ax=ax_ins, top=True, right=True, left=False)
        
    # Clean up spines for both axes
    sns.despine(ax=ax, top=True, right=True)

    return fig, ax


def plot_waiting_times_hist(
    waiting_times: torch.Tensor,
    *,
    bins: int = 30,
    inset_max: float = 28,
    inset=True,
):
    """
    """
    
    data = waiting_times.detach().cpu().numpy()

    fig, axs = plt.subplots(1, 2, figsize=(2.5*3, 2.3), sharex=False, sharey=False)

    # Main histogram
    sns.histplot(
        data,
        bins=bins,
        ax=axs[0],
        edgecolor="white",
        color="#4C72B0",
        linewidth=0.5,
        kde=False,
        alpha=1.0,
    )
    axs[0].set_yscale("log")
    # ax.yaxis.set_major_formatter(ScalarFormatter())  # cleaner ticks
    axs[0].set_xlabel("Inter-event times (years)")
    axs[0].set_ylabel("Count (log-scale)")
    axs[0].grid(False)  # make sure main axis grid is off

    # Inset: zoom on values < inset_max
    data_inset = data * 365
    sns.histplot(
        data=data_inset[data_inset <= inset_max],
        bins=inset_max // 2,
        ax=axs[1],
        edgecolor="white",
        color="#4C72B0",
        linewidth=0.5,
        kde=False,
        alpha=1.0,
    )
    # ax_ins.set_xlim(0, inset_max)
    scale_y = 1e3
    ticks_y = ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(x/scale_y))
    axs[1].yaxis.set_major_formatter(ticks_y)
    # ax_ins.set_yscale("log")
    # ax_ins.yaxis.set_major_formatter(ScalarFormatter())  # cleaner ticks
    axs[1].set_xlabel(f"First {inset_max} days")
    axs[1].set_ylabel("Count (thousands)")
    # ax_ins.set_title(f"", fontsize=9, pad=4)
    axs[1].grid(False)  # ensure inset grid is off

    # sns.despine(ax=axs[1], top=True, right=True, left=False)
        
    # Clean up spines for both axes
    # sns.despine(ax=ax, top=True, right=True)

    return fig, ax

In [60]:
fig, ax = plot_waiting_times_inset_hist(flat_waiting_times)
fig.savefig(f"figs/Histogram_event_times_inset.png", dpi=300, bbox_inches="tight")

fig, ax = plot_waiting_times_hist(flat_waiting_times)
fig.savefig(f"figs/Histogram_event_times.png", dpi=300, bbox_inches="tight")