# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Configurations and Paths 
- [D](#secd) Patient Interface and Train/Val/Test Partitioning
- [E](#sece) Setup Metrics


## 1. [Load Models: Uninitialised](#models)
## 2. [Snapshot Selection](#snapshot)
## 3. [Evaluations: Predictive Performance on CPRD](#eval)


<a name="seca"></a>

### A External Imports [^](#outline)

In [None]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
from IPython.display import display

# Install upsetplot
# !pip install UpSetPlot==0.8.0
from upsetplot import from_contents, plot, UpSet, from_indicators
import jax

# Set JAX's `jax_platform_name` option to 'cpu' for this notebook session.
jax.config.update('jax_platform_name', 'cpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [None]:
# Add the parent directory to the module search path; the `lib` package
# imported below is presumably located there (verify against the repo layout).
sys.path.append("..")


from lib import utils as U
from lib.ehr.dataset import load_dataset

# Load IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


<a name="secc"></a>

### C Configurations and Paths [^](#outline)

In [None]:

# `training_dir` is handed to the model/snapshot loaders further down;
# `output_dir` receives the figures saved by this notebook.
training_dir, output_dir = 'cprd_artefacts/train', 'cprd_analysis_artefacts'

# Create the output folder (and any missing parents) up front; an already
# existing folder is not an error.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Assign the path of the dataset CSV file to `DATA_FILE`.
# `os.path.expanduser('~')` is used instead of `os.environ.get('HOME')`: the
# latter returns None when $HOME is unset, which would silently produce a
# path starting with 'None/'.
HOME = os.path.expanduser('~')
DATA_FILE = f'{HOME}/GP/ehr-data/cprd-data/DUMMY_DATA.csv'
SOURCE_DIR = os.path.abspath("..")

# Load the CPRD dataset while `DATA_FILE` is passed to `U.modified_environ`
# (presumably a context manager that temporarily sets the environment
# variable read by `load_dataset` -- confirm in `lib.utils`).
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')

In [None]:
# Evaluation constants.
# NOTE: `top_k_list` is assigned again (with a shorter list) in section E.
relative_auc_config = dict(pvalue=0.01, min_auc=0.9)
top_k_list = [1, 2, 3, 5, 7, 10, 15, 20]
percentile_range = 20
n_percentiles = 100 // percentile_range


import matplotlib.font_manager as font_manager

# Reset Matplotlib to its default rcParams, then apply the font settings used
# for the figures in this notebook.
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Helvetica',
    'font.weight': 'normal',
})

<a name="secd"></a>

### D Patient Interface and Train/Val/Test Partitioning [^](#outline)

**The configurations below must match those used in the training notebook.**

In [None]:
from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor, SurvivalOutcomeExtractor
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

# The autoreload extension is already loaded and configured in section B, so
# it is not loaded a second time here.

# Coding schemes used to build the patient interface.
code_scheme = {
    # Diagnosis scheme. `DxLTC212FlatCodes` (imported above) is presumably the
    # alternative option -- confirm before switching.
    'dx': DxLTC9809FlatMedcodes(),
    'outcome': SurvivalOutcomeExtractor('dx_cprd_ltc9809'),
    # NOTE(review): the original instructions here said to swap the line above
    # for a commented-out alternative, but that alternative was identical to
    # the active line. `OutcomeExtractor` (imported above) is presumably the
    # intended alternative -- confirm which variant the training notebook used.
    'eth': EthCPRD5()
}


# Static (per-subject) information to expose through the interface.
static_info_flags = StaticInfoFlags(
    gender=True,
    age=True,
    idx_deprivation=True,
    ethnicity=EthCPRD5(),  # <- include it by the category of interest, not just 'True'.
)

cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme, static_info_flags=static_info_flags)
# Fixed seed so the partitioning is reproducible; the split points (0.7, 0.85)
# presumably yield 70% / 15% / 15% train / val / test -- verify in `random_splits`.
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)


<a name="sece"></a>

### E Setup Metrics [^](#outline)


In [None]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, LossMetric, MetricsCollection)

# percentile_range=20 will partition the codes into five groups, where each group contains
# codes that overall constitute 20% of the codes in all visits of the specified 'subjects' list.
# The value is read from `percentile_range` (section C) rather than repeated here as a literal.
# `cprd_splits[0]` is presumably the training partition -- confirm.
code_freq_partitions = cprd_interface.outcome_by_percentiles(percentile_range=percentile_range,
                                                             subjects=cprd_splits[0])


# Evaluate for different k values.
# NOTE: this overrides the longer `top_k_list` assigned in section C.
top_k_list = [3, 5, 10, 15, 20]

# Metric objects, all bound to the same patient interface.
metrics = {'code_auc': CodeAUC(cprd_interface),
           'code_first_auc': UntilFirstCodeAUC(cprd_interface),
           'admission_auc': AdmissionAUC(cprd_interface),
           'loss': LossMetric(cprd_interface),
           'code_group_acc': CodeGroupTopAlarmAccuracy(cprd_interface, top_k_list=top_k_list, code_groups=code_freq_partitions)}

# Extractors passed to the snapshot-selection step: mean AUC for the three AUC
# metrics and the 'focal_softmax' field for the loss.
metric_extractor = {
    'code_auc': metrics['code_auc'].aggregate_extractor({'field': 'auc', 'aggregate': 'mean'}),
    'code_first_auc': metrics['code_first_auc'].aggregate_extractor({'field': 'auc', 'aggregate': 'mean'}),
    'admission_auc': metrics['admission_auc'].aggregate_extractor({'field': 'auc', 'aggregate': 'mean'}),
    'loss': metrics['loss'].value_extractor({'field': 'focal_softmax'}),
}

<a name="models"></a>

## 1. Loading Models (Uninitialised) [^](#outline)

In [None]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg
from lib.vis import models_from_configs, performance_traces, probe_model_snapshots

# Display name -> model class, for every model family handled in this notebook.
model_cls = dict([
    ('ICE-NODE', ICENODE),
    ('ICE-NODE_UNIFORM', ICENODE_UNIFORM),
    ('GRU', GRU),
    ('RETAIN', RETAIN),
    ('LogReg', WindowLogReg),
])

# Build the models from the configurations found under `training_dir`.
cprd_models = models_from_configs(training_dir, model_cls, cprd_interface, cprd_splits)


<a name="snapshot"></a>


## 2. Snapshot Selection [^](#outline)

In [None]:
# Probe the stored snapshots of every model and select by 'admission_auc_val'.
result = probe_model_snapshots(
    models=cprd_models,
    train_dir=training_dir,
    selection_metric='admission_auc_val',
    metric_extractor=metric_extractor,
)
display(result)
# Now cprd_models have the selected snapshots

<a name="eval"></a>

## 3. Predictive Performance on CPRD [^](#outline)

In [None]:
# Run every model in evaluation-only mode on `cprd_splits[2]` (presumably the
# test partition -- confirm) and keep only the 'predictions' entry of each result.
cprd_test_res = {
    name: net(cprd_interface, cprd_splits[2], {'eval_only': True})['predictions']
    for name, net in cprd_models.items()
}


In [None]:
from lib.metric import DeLongTest
from lib.vis import auc_upset

# Build the DeLong-test metric on the patient interface and tabulate it over
# the per-model test predictions collected above.
delong_metric = DeLongTest(cprd_interface)
cprd_auctests = delong_metric.to_df(cprd_test_res)

In [None]:
model_keys = [*cprd_test_res]
# NOTE(review): the thresholds are passed literally here; `relative_auc_config`
# in section C holds different values (0.01 / 0.9) and is not used -- confirm
# which pair is intended.
indicator_df, (nodiff_set, diff_set) = auc_upset(
    delong_metric, cprd_auctests, model_keys, p_value=0.05, min_auc=0.7)


def upset_ctx():
    """Return a seaborn "paper" plotting context (currently not applied below)."""
    return sns.plotting_context(
        "paper",
        font_scale=1.5,
        rc={"font.family": "sans-serif",
            'axes.labelsize': 'medium',
            'ytick.labelsize': 'medium'})


with sns.axes_style("darkgrid"):  # , upset_ctx():
    upset_format = from_indicators(indicator_df)
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)

    g = upset_object.plot()

    # Keep the figure's current height and set its width from the w:h ratio,
    # then save the figure as a PDF in the output folder.
    current_figure = plt.gcf()
    w, h = 2.5, 3
    wi, hi = current_figure.get_size_inches()
    current_figure.set_size_inches(hi * (w / h), hi)
    current_figure.savefig(f"{output_dir}/cprd_auc_upset.pdf", bbox_inches='tight')
    plt.show()

In [13]:
from lib.vis import top_k_tables

# Per-model data frames from the code-group top-k metric, then the tables
# built from them by `top_k_tables`.
group_acc_metric = metrics['code_group_acc']
top_k_results = {
    model_key: group_acc_metric.to_df(model_key, predictions)
    for model_key, predictions in cprd_test_res.items()
}
top_k_dfs = top_k_tables(group_acc_metric, top_k_results)

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


In [14]:
# Display the 'raw' table stored under key 5 (k=5 is one of the values in `top_k_list`).
top_k_dfs[5]['raw']

Unnamed: 0,ACC-P0-k5,ACC-P1-k5,ACC-P2-k5,ACC-P3-k5,ACC-P4-k5
GRU,0.0,,,,
ICE-NODE,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
LogReg,0.0,,,,
RETAIN,0.0,,,,
