# Create figures for generation

In [1]:
import os
from pathlib import Path
import sys
node_type = os.getenv('BB_CPU')
venv_dir = f'/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-{node_type}'
venv_site_pkgs = Path(venv_dir) / 'lib' / f'python{sys.version_info.major}.{sys.version_info.minor}' / 'site-packages'
if venv_site_pkgs.exists():
    sys.path.insert(0, str(venv_site_pkgs))
    print(f"Added path '{venv_site_pkgs}' at start of search paths.")
else:
    print(f"Path '{venv_site_pkgs}' not found. Check that it exists and/or that it exists for node-type '{node_type}'.")

%load_ext autoreload
%autoreload 2

os.chdir('/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots')
print(os.getcwd())

Added path '/rds/homes/g/gaddcz/Projects/CPRD/virtual-envTorch2.0-icelake/lib/python3.10/site-packages' at start of search paths.
/rds/homes/g/gaddcz/Projects/CPRD/examples/modelling/paper_plots


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import wandb
import polars as pl
import pandas as pd
from hydra import compose, initialize
import seaborn as sns
import json
import io
from CPRD.examples.data.map_to_reduced_names import convert_event_names, EVENT_NAME_SHORT_MAP
from matplotlib.colors import LogNorm, Normalize

from CPRD.examples.modelling.SurvivEHR.run_experiment import run

# Set SLURM_NTASKS_PER_NODE in the kernel's environment.
# NOTE(review): presumably read by the training/data-loading stack started by
# `run` below -- confirm which component needs it and why 28.
%env SLURM_NTASKS_PER_NODE=28   

# Global seaborn theme: "ticks" style with the "notebook" context.
sns.set(style="ticks", context="notebook")


## Initialise the dataloader used for pre-training

In [4]:
# Settings that override the stored Hydra configuration for this notebook.
# NOTE(review): train/test/log are all disabled, presumably so that `run`
# only builds the model and data module -- confirm against run_experiment.
cfg_overrides = [
    # Experiment setup
    "experiment.run_id=SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1",
    "experiment.train=False",
    "experiment.test=False",
    "experiment.log=False",
    "data.meta_information_path=/rds/projects/g/gokhalkm-optimal/OPTIMAL_MASTER_DATASET/data/FoundationalModel/PreTrain/meta_information_QuantJenny.pickle",
    "data.min_workers=12",
]

# Load the configuration file with the overrides applied.
with initialize(version_base=None, config_path="../SurvivEHR/confs", job_name="causal_metric_testing_notebook"):
    cfg = compose(config_name="config_CompetingRisk11M", overrides=cfg_overrides)

# `run` returns the model and the data module (`dm`) used by later cells.
model, dm = run(cfg)
n_parameters = sum(p.numel() for p in model.parameters())
print(f"Loaded model with {n_parameters/1e6} M parameters")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Loaded model with 11.20919 M parameters


/rds/bear-apps/2022a/EL8-ice/software/PyTorch-Lightning/2.1.0-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


# Load the generated data for a dataset

In [5]:
# Which generated dataset to plot, and which pre-trained model produced it.
dataset = "FineTune_CVD"
pre_trained_model = "SurvivEHR-cr-small-debug7_exp1000-v1-v4-v1"

# Directory holding the generated-event files for this model/dataset pair.
# Kept as a string with a trailing slash because later cells build file paths
# by string concatenation (gen_data_path + "next_event_data.csv").
gen_data_path = f'../SurvivEHR/notebooks/CompetingRisk/0_pretraining/figs/generation/{pre_trained_model}/{dataset}_dataset/'

In [None]:
def _get_stratified_next_event_matrix(df, dm, events_of_interest=None):
    """Cross-tabulate transitions into a (next event x previous event) count matrix.

    Event names in both columns are translated with ``EVENT_NAME_SHORT_MAP``
    before counting. Names without an entry in the map keep their original
    spelling (the same fallback the plotting cell applies to
    ``events_of_interest``) instead of becoming NaN and silently vanishing
    from the cross-tabulation.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the columns "Next event" and "Previous event".
        The frame is not modified.
    dm :
        Unused; kept so existing call sites continue to work.
    events_of_interest : sequence of two collections, optional
        ``events_of_interest[0]`` selects the columns (previous events) and
        ``events_of_interest[1]`` selects the rows (next events). Labels that
        do not occur in the matrix are ignored.

    Returns
    -------
    pd.DataFrame
        Transition counts with index "Next event" and columns "Previous event".
    """
    # Build the renamed series separately rather than assigning back into
    # `df`, so the caller's frame is left untouched and repeated calls on the
    # same frame give the same result.
    next_events = df["Next event"].map(EVENT_NAME_SHORT_MAP).fillna(df["Next event"])
    previous_events = df["Previous event"].map(EVENT_NAME_SHORT_MAP).fillna(df["Previous event"])

    counts = pd.crosstab(next_events, previous_events)

    # Optionally restrict to the requested previous (columns) / next (rows) events
    if events_of_interest is not None:
        cols_to_keep = counts.columns.intersection(events_of_interest[0])
        rows_to_keep = counts.index.intersection(events_of_interest[1])
        counts = counts.loc[rows_to_keep, cols_to_keep]

    return counts


## Transition matrices

In [11]:
def plot_stratified_next_event_matrix(df, dm, events_of_interest=None, minimum_threshold=0.01, max_steps=10, save_name="next_event.png"):
    """Save a heatmap of next-event counts stratified by the previous event.

    Parameters
    ----------
    df : pd.DataFrame
        Generated transitions with the columns "Generation step",
        "Previous event" and "Next event". Not modified.
    dm :
        Forwarded to ``_get_stratified_next_event_matrix``.
    events_of_interest : sequence of two collections, optional
        Forwarded to ``_get_stratified_next_event_matrix`` to restrict the
        previous events (element 0) and next events (element 1) shown.
    minimum_threshold : float
        Fraction of the total count in the matrix. Cells whose count is below
        ``int(minimum_threshold * total)`` are blanked.
    max_steps : int
        Only rows with "Generation step" strictly below this value are used.
    save_name : str
        Path the figure is written to.
    """
    # Filter out by how long into the future we generate. The copy keeps the
    # caller's frame isolated from whatever the helper does to its input.
    within_horizon = df[df["Generation step"] < max_steps].copy()

    counts = _get_stratified_next_event_matrix(within_horizon, dm, events_of_interest=events_of_interest)

    # Convert the fractional threshold into an absolute count.
    min_count = int(minimum_threshold * counts.sum().sum())

    # Blank (NaN) every cell that is zero or below the threshold in one
    # vectorised pass, so that those cells do not clutter the plot.
    visible = counts.where(counts >= max(min_count, 1))

    # Remove rows and columns with nothing left to show
    visible = visible.dropna(axis=0, how="all").dropna(axis=1, how="all")

    fig, axis = plt.subplots(1, 1, figsize=(10, 10), constrained_layout=True)

    sns.heatmap(visible, ax=axis, xticklabels=True, yticklabels=True, cmap="YlOrRd", vmin=0,
                cbar_kws={'label': f'Count ({min_count}+ threshold)'})

    axis.set_xlabel("Prior event")
    axis.set_ylabel("Next event")
    axis.grid()

    fig.savefig(save_name)
    # Close explicitly: this function is called in a loop and open figures accumulate.
    plt.close(fig)



In [12]:
# Get the subgroups of tokens we want to plot by.
# NOTE(review): measurement events with count_obs > 0 are treated as labs and
# those with count_obs == 0 as medications -- confirm this is the intended split.
measurement_tables = dm.meta_information["measurement_tables"]
lab_names = measurement_tables[measurement_tables["count_obs"] > 0]["event"].to_list()
medication_names = measurement_tables[measurement_tables["count_obs"] == 0]["event"].to_list()
diagnosis_names = dm.meta_information["diagnosis_table"]["event"].to_list()

# Read the generated transitions once; the plotting function works on a copy,
# so the same frame can be reused for every subgroup and horizon.
df = pd.read_csv(gen_data_path + "next_event_data.csv")

# Each entry is [previous events (columns), next events (rows)].
subgroup_pairs = {
    "diagnosis vs diagnosis": [diagnosis_names, diagnosis_names],
    "diagnosis vs drug": [diagnosis_names, medication_names],
    "diagnosis vs investigation": [diagnosis_names, lab_names],
}

# Plot heatmap
for name, events_of_interest in subgroup_pairs.items():

    # Translate to the short display names, keeping the original name when
    # no short form is defined.
    events_of_interest = [
        [EVENT_NAME_SHORT_MAP.get(col, col) for col in group]
        for group in events_of_interest
    ]

    for max_steps in range(1, 25, 5):
        plot_stratified_next_event_matrix(
            df,
            dm,
            events_of_interest=events_of_interest,
            minimum_threshold=1e-2,
            max_steps=max_steps,
            save_name=f"generation_matrix_{name}_{max_steps}.png"
            )



## Histogram

In [13]:
def plot_stratified_next_event_histogram(df, dm, events_of_interest=None, top_k=10, max_steps=10, save_name="next_event.png"):
    """Save a horizontal histogram of the most frequent "Previous event" values.

    Parameters
    ----------
    df : pd.DataFrame
        Needs the columns "Generation step" and "Previous event". Not modified.
    dm, events_of_interest :
        Accepted but not used by this function.
    top_k : int
        Number of most frequent events to display.
    max_steps : int
        Only rows with "Generation step" strictly below this value are used.
    save_name : str
        Path the figure is written to.
    """
    # Restrict to the requested generation horizon
    within_horizon = df["Generation step"] < max_steps
    events = df[within_horizon].copy()

    events["Previous event"] = events["Previous event"].map(EVENT_NAME_SHORT_MAP)

    # The `top_k` most common events, most frequent first
    top_labels = events["Previous event"].value_counts().head(top_k).index.tolist()

    is_top = events["Previous event"].isin(top_labels)
    filtered_df = events[is_top].copy()
    # An ordered categorical makes the bars follow the frequency ranking.
    filtered_df["Previous event"] = pd.Categorical(filtered_df["Previous event"], top_labels)

    fig, axis = plt.subplots(1, 1, figsize=(8, 5))

    sns.histplot(data=filtered_df, y="Previous event", stat="percent", ax=axis)

    # NOTE(review): this label is hard-coded to T2DM and is applied to the
    # y axis, which holds the event names -- confirm it matches the dataset
    # being plotted and the intended axis.
    axis.set_ylabel("Frequency of events following T2DM diagnosis")
    axis.grid()
    fig.tight_layout()

    fig.savefig(save_name)
    plt.close(fig)

In [14]:
# Plot histogram
df = pd.read_csv(gen_data_path + "next_event_data.csv")

# `events_of_interest` is not used by the histogram function, so pass None
# explicitly instead of relying on a loop variable left over from the
# heatmap cell (which breaks when cells are run out of order).
plot_stratified_next_event_histogram(
    df,
    dm,
    events_of_interest=None,
    top_k=20,
    max_steps=3,
    save_name="next_event_generation_histogram.png",
    )

## Sankey flow diagram

In [51]:

def transition_to_sankey_df(
    counts_df: pd.DataFrame,
    *,
    k_per_token: int = 3,
    min_prob: float = 1e-4,
    normalise: bool = True,
):
    """
    Build Plotly Sankey inputs from a transition-count DataFrame
    (index = "Next event", columns = "Previous event").

    Steps
    -----
    1.  Make the matrix square over the sorted union of row and column
        labels, filling missing rows/columns with zeros.
    2.  Transpose so rows = previous events, cols = next events.
    3.  Optionally row-normalise counts to probabilities (rows that sum to
        zero are left as all zeros).
    4.  For every previous event keep the ``k_per_token`` largest outgoing
        transitions whose weight is strictly positive and ≥ ``min_prob``.

    Parameters
    ----------
    counts_df : pd.DataFrame
    k_per_token : int   – maximum number of outgoing links per previous event
    min_prob    : float – weight cut-off; compared against raw counts when
                          ``normalise`` is False
    normalise   : bool  – if False, link values are raw counts

    Returns
    -------
    dict with the keys "node" and "link", suitable for ``go.Sankey(**dict)``.
    Link sources/targets are positions in ``node["label"]``.

    Raises
    ------
    TypeError
        If ``counts_df`` is not a pandas DataFrame.
    """
    if not isinstance(counts_df, pd.DataFrame):
        raise TypeError("counts_df must be a pandas.DataFrame")

    # 1. Square the matrix
    labels = sorted(set(counts_df.columns).union(counts_df.index))
    square = counts_df.reindex(index=labels, columns=labels, fill_value=0)

    # 2. Transpose → rows = previous, cols = next (same order as `labels`)
    weights = square.T.to_numpy(dtype=float)

    # 3. Row-normalise (optional)
    if normalise:
        totals = weights.sum(axis=1, keepdims=True)
        totals[totals == 0] = 1.0         # avoid divide-by-zero
        weights = weights / totals

    # 4. Build edge lists
    src, tgt, val = [], [], []
    for prev_idx, row in enumerate(weights):
        strongest = np.argsort(row)[::-1][:k_per_token]
        for next_idx in strongest:
            weight = row[next_idx]
            # A zero weight means the transition was never observed; emitting
            # it (possible when min_prob <= 0) would only add empty links.
            if weight > 0 and weight >= min_prob:
                src.append(prev_idx)
                tgt.append(int(next_idx))
                val.append(float(weight))

    # --- RETURN in Plotly's expected structure ------------------------
    return {
        "node": dict(
            label=labels,
            pad=15,
            thickness=8,
            line=dict(color="grey", width=0.5),
        ),
        "link": dict(
            source=src,
            target=tgt,
            value=val,
        ),
    }


def transition_to_circular_sankey_df(
    counts_df: pd.DataFrame,
    *,
    k_per_token: int = 4,
    min_prob: float = 0.01,
    normalise: bool = True,
    link_opacity: float = 0.35,
    node_colour: str = "#4C72B0",
):
    """
    Convert a (possibly non-square) transition-count DataFrame to a Plotly
    Sankey dict whose nodes sit on a circle (circular "flow" diagram).

    Expected DataFrame layout
    -------------------------
    * columns = "Previous event"
    * index   = "Next event"

    Steps
    -----
    1.  Make the matrix square over the sorted union of row and column
        labels, filling missing rows/columns with zeros.
    2.  Transpose so rows = previous, cols = next.
    3.  Optionally row-normalise counts → probabilities.
    4.  Per previous event keep the ``k_per_token`` largest outgoing edges
        that are not self-loops, are strictly positive and are ≥ ``min_prob``.
    5.  Place every node at an equally spaced angle on a circle of radius
        0.45 centred on (0.5, 0.5) and output a Plotly-ready dict.

    Parameters
    ----------
    counts_df : pd.DataFrame
    k_per_token : int   – keep at most this many strongest edges per event
    min_prob    : float – weight cut-off; compared against raw counts when
                          ``normalise`` is False
    normalise   : bool  – if False, edges are raw counts
    link_opacity: float – RGBA alpha for links (0 … 1)
    node_colour : str   – colour applied to the nodes

    Returns
    -------
    dict  # suitable for go.Sankey(**dict)

    Raises
    ------
    TypeError
        If ``counts_df`` is not a pandas DataFrame.
    ValueError
        If no edge passes the filters.
    """
    if not isinstance(counts_df, pd.DataFrame):
        raise TypeError("counts_df must be a pandas.DataFrame")

    # 1 ▸ square matrix (union of row+col labels)
    events = sorted(set(counts_df.columns).union(counts_df.index))
    square = counts_df.reindex(index=events, columns=events, fill_value=0)

    # 2 ▸ rows = previous, cols = next
    weights = square.T.to_numpy(dtype=float)
    n_events = weights.shape[0]

    # 3 ▸ optional row-normalisation
    if normalise:
        totals = weights.sum(axis=1, keepdims=True)
        totals[totals == 0] = 1.0            # avoid divide-by-zero
        weights = weights / totals

    # 4 ▸ build edge lists
    src, tgt, val = [], [], []
    for prev_idx in range(n_events):
        strongest = np.argsort(weights[prev_idx])[::-1][:k_per_token]
        for next_idx in strongest:
            weight = weights[prev_idx, next_idx]
            # Skip self-loops, and skip zero weights: a transition that was
            # never observed would otherwise slip through when min_prob <= 0.
            if next_idx != prev_idx and weight > 0 and weight >= min_prob:
                src.append(prev_idx)
                tgt.append(int(next_idx))
                val.append(float(weight))

    if not src:
        raise ValueError("No edges survived; lower `min_prob` or raise `k_per_token`.")

    # 5 ▸ node positions: equally spaced on a circle (radius 0.45, centre 0.5)
    theta = np.linspace(0, 2 * np.pi, n_events, endpoint=False)
    node_x = 0.5 + 0.45 * np.cos(theta)
    node_y = 0.5 + 0.45 * np.sin(theta)

    return dict(
        arrangement="fixed",                 # keep our positions
        node=dict(
            label=events,
            x=node_x.tolist(),
            y=node_y.tolist(),
            pad=4,
            thickness=8,
            line=dict(color="grey", width=0.5),
            color=node_colour,
        ),
        link=dict(
            source=src,
            target=tgt,
            value=val,
            # NOTE(review): the link RGB is fixed to the default node colour
            # (#4C72B0) and does not follow a custom `node_colour`.
            color=f"rgba(76,114,176,{link_opacity})",
        ),
    )

In [52]:
df = pd.read_csv(gen_data_path + "next_event_data.csv")

# Number of distinct raw event names on each side of the transition
for column in ("Previous event", "Next event"):
    print(len(df[column].unique()))

# Optional filters (currently disabled):
# df = df[df["Generation step"] == 0]
# df = df[df["Previous event"] != "DEATH"]

# From here on `df` holds the transition-count matrix
# (index = next event, columns = previous event).
df = _get_stratified_next_event_matrix(df, dm)

# Report labels that appear on only one axis, i.e. where the matrix is not square.
only_in_index = [event for event in df.index if event not in df.columns]
for event in only_in_index:
    print(f"missing column {event}")

only_in_columns = [event for event in df.columns if event not in df.index]
for event in only_in_columns:
    print(f"missing index {event}")


262
261
missing index Death


In [None]:
import plotly.io as pio

# Renderer used if the figure is displayed inline.
pio.renderers.default = "notebook"   # or "browser" / "png"

sankey_data = transition_to_sankey_df(df, k_per_token=4, min_prob=0.01, normalise=True)

# Quick summary of what survived the top-k / min_prob filtering
links = sankey_data["link"]
print("edges kept:", len(links["source"]))
print("unique nodes:", len(set(links["source"]) | set(links["target"])))

fig = go.Figure(go.Sankey(**sankey_data))
fig.update_layout(title="Event transition flows")

# Written to HTML rather than shown inline.
fig.write_html("sankey.html")

In [57]:
# Keep only the single strongest outgoing transition per event, and only
# when its probability is at least 0.1.
circular_options = {"k_per_token": 1, "min_prob": 0.1, "normalise": True}
sankey_data = transition_to_circular_sankey_df(df, **circular_options)

fig = go.Figure(go.Sankey(**sankey_data))
fig.update_layout(
    title="Circular event-to-event flows",
    title_x=0.5,
    margin={"t": 40, "l": 20, "r": 20, "b": 20},
)

# Written to HTML rather than shown inline.
fig.write_html("circle_flow.html")