# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Configurations and Paths 
- [D](#secd) Patient Interface and Train/Val/Test Partitioning
- [E](#sece) Setup Metrics


## 1. [Load Models: Uninitialised](#models)
## 2. [Snapshot Selection](#snapshot)
## 3. [Disease Embeddings Clustering](#disease-clusters)
## 4. [Subject Embeddings Clustering](#subject-clusters)


<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
from IPython.display import display
import jax

# Run JAX on the CPU backend for this notebook.
jax.config.update('jax_platform_name', 'cpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:
# Make the project's `lib` package importable from the parent of the working directory.
sys.path.append("..")


from lib import utils as U
from lib.ehr.dataset import load_dataset

# Reload edited `lib` modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2


<a name="secc"></a>

### C Configurations and Paths [^](#outline)

In [3]:
# Training artefacts, consumed below by models_from_configs / probe_model_snapshots.
training_dir = 'cprd_artefacts/train'
# Output directory for this notebook's artefacts; created (with parents) if missing.
output_dir = 'cprd_clustering_artefacts'
os.makedirs(output_dir, exist_ok=True)

In [4]:
# Path of the CPRD data file (a CSV file, not a folder), exported as `DATA_FILE`
# while the dataset is loaded below.
# `os.environ.get('HOME')` is None when $HOME is unset, which would silently give a
# path starting with 'None/'; fall back to `Path.home()` in that case.
HOME = os.environ.get('HOME') or str(Path.home())
DATA_FILE = f'{HOME}/GP/ehr-data/cprd-data/DUMMY_DATA.csv'
# Absolute path of the parent directory (the directory added to sys.path above).
SOURCE_DIR = os.path.abspath("..")

# Load the CPRD dataset with DATA_FILE set in the environment for the duration of the
# `with` block (presumably `load_dataset('CPRD')` reads the file location from it — TODO confirm).
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')

<a name="secd"></a>

### D Patient Interface and Train/Val/Test Partitioning [^](#outline)

**These configurations must match those used in the training notebook.**

In [5]:
from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor, SurvivalOutcomeExtractor
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

# NOTE: autoreload is already enabled in the internal-imports cell; loading it again
# here only produced an "already loaded" warning, so it is not repeated.

# These settings must match the configuration used in the training notebook.
code_scheme = {
    'dx': DxLTC9809FlatMedcodes(),  # alternative dx scheme imported above: DxLTC212FlatCodes
    # SurvivalOutcomeExtractor considers only the first occurrence of codes per subject.
    'outcome': SurvivalOutcomeExtractor('dx_cprd_ltc9809'),
    # To use the plain OutcomeExtractor instead (not restricted to the first
    # occurrence — TODO confirm), comment the line above and uncomment below.
    # 'outcome': OutcomeExtractor('dx_cprd_ltc9809'),
    'eth': EthCPRD5()
}

# Static subject attributes to include, passed to the interface below.
static_info_flags = StaticInfoFlags(
    gender=True,
    age=True,
    idx_deprivation=True,
    ethnicity=EthCPRD5(),  # <- include it by the category of interest, not just 'True'.
)

cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme, static_info_flags=static_info_flags)
# split1/split2 look like cumulative cut points (~70% / 15% / 15% train/val/test) — TODO confirm.
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)


<a name="sece"></a>

### E Setup Metrics [^](#outline)


In [6]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, LossMetric, MetricsCollection)
# Partition the outcome codes by frequency over the first split's subjects
# (cprd_splits[0], presumably the training split). percentile_range=20 yields five
# groups, each holding codes that together account for 20% of all code occurrences
# in the visits of those subjects.
code_freq_partitions = cprd_interface.outcome_by_percentiles(percentile_range=20,
                                                             subjects=cprd_splits[0])

# k values at which the top-k alarm accuracy is evaluated.
top_k_list = [3, 5, 10, 15, 20]

metrics = {
    'code_auc': CodeAUC(cprd_interface),
    'code_first_auc': UntilFirstCodeAUC(cprd_interface),
    'admission_auc': AdmissionAUC(cprd_interface),
    'loss': LossMetric(cprd_interface),
    'code_group_acc': CodeGroupTopAlarmAccuracy(cprd_interface,
                                                top_k_list=top_k_list,
                                                code_groups=code_freq_partitions),
}

# Scalar extractors handed to probe_model_snapshots: the mean AUC of each AUC
# metric, followed by the focal-softmax loss value.
metric_extractor = {
    metric_name: metrics[metric_name].aggregate_extractor({'field': 'auc', 'aggregate': 'mean'})
    for metric_name in ('code_auc', 'code_first_auc', 'admission_auc')
}
metric_extractor['loss'] = metrics['loss'].value_extractor({'field': 'focal_softmax'})

<a name="models"></a>

## 1. Loading Models (Uninitialised) [^](#outline)

In [7]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg
from lib.vis import models_from_configs, performance_traces, probe_model_snapshots

# Display name -> model class, handed to models_from_configs together with
# training_dir (presumably to rebuild each model from its saved config).
model_cls = dict([
    ('ICE-NODE', ICENODE),
    ('ICE-NODE_UNIFORM', ICENODE_UNIFORM),
    ('GRU', GRU),
    ('RETAIN', RETAIN),
    ('LogReg', WindowLogReg),
])
cprd_models = models_from_configs(training_dir, model_cls, cprd_interface, cprd_splits)


<a name="snapshot"></a>


## 2. Snapshot Selection [^](#outline)

In [8]:
# For each model, select a training snapshot according to the validation admission AUC.
result = probe_model_snapshots(
    train_dir=training_dir,
    models=cprd_models,
    metric_extractor=metric_extractor,
    selection_metric='admission_auc_val',
)
display(result)

# From here on, cprd_models hold the selected snapshots.

Unnamed: 0,model,code_auc_idx,code_auc_val,code_first_auc_idx,code_first_auc_val,admission_auc_idx,admission_auc_val,loss_idx,loss_val
ICE-NODE_UNIFORM,ICE-NODE_UNIFORM,-1,,-1,,59,0.404644,59,0.000969
RETAIN,RETAIN,-1,,-1,,58,0.709439,35,0.000968
LogReg,LogReg,-1,,-1,,0,0.499683,0,0.000967
ICE-NODE,ICE-NODE,-1,,-1,,59,0.574464,59,0.000966
GRU,GRU,-1,,-1,,2,0.896452,5,0.000935


<a name="disease-clusters"></a>

## 3. Disease Embeddings Clustering on CPRD [^](#outline)

In [9]:


# Diagnosis coding scheme. Reuse the exact instance passed to the interface
# (code_scheme['dx']) rather than constructing a new one, so the code index used
# below cannot drift from the scheme behind cprd_interface.dx_dim.
dx_scheme = code_scheme['dx']


In [10]:
# scheme indices (textual code -> integer index)
dx_scheme.index

# reverse index (integer index -> textual code)
idx2code = {idx: code for code, idx in dx_scheme.index.items()}

### 3.A GloVe-Based Disease Embeddings

Compute the diagnosis-code co-occurrence matrices (time-window and sequence contexts), then train GloVe embeddings on each.

In [12]:
from lib.embeddings import train_glove

# Every subject known to the interface; reused by the embedding cells below.
cprd_all_subjects = cprd_interface.keys()

# Diagnosis co-occurrence under two contexts:
#  - a time window (window_size_days=365),
#  - a sequence window (context_size=20).
cprd_cooc_timewin = cprd_interface.dx_coocurrence(cprd_all_subjects, window_size_days=365)
cprd_cooc_seqwin = cprd_interface.dx_coocurrence(cprd_all_subjects, context_size=20)

# Same GloVe hyperparameters (and PRNG seed) for both co-occurrence matrices.
glove_kwargs = dict(embeddings_size=100, iterations=500, prng_seed=0)
cprd_glove_timewin = train_glove(cprd_cooc_timewin, **glove_kwargs)
cprd_glove_seqwin = train_glove(cprd_cooc_seqwin, **glove_kwargs)

cprd_glove_timewin

array([[-3.24278208e-03,  1.28678871e-03, -2.74318569e-03, ...,
        -3.05468698e-04,  1.12339146e-03,  2.25189050e-03],
       [-2.98757527e-03, -7.03690611e-04, -8.10092430e-05, ...,
        -6.94691734e-04, -7.55079060e-03, -3.95144619e-03],
       [-8.14448813e-03,  1.39618211e-03, -7.77809797e-03, ...,
         6.70216424e-03,  6.35695556e-03, -7.93471806e-03],
       ...,
       [ 8.19550666e-03,  4.26941973e-03,  4.60743972e-03, ...,
        -5.78689374e-04,  6.21998103e-03,  5.35619514e-03],
       [ 2.05566253e-03, -1.13631646e-03, -4.87536546e-03, ...,
        -6.46422558e-03, -3.79417553e-03,  3.80826087e-03],
       [ 1.43140791e-03, -1.40970352e-03,  8.06397555e-03, ...,
        -1.22326184e-03,  7.43590727e-03, -4.97565251e-03]])

### 3.B Predictor-Based Disease Embeddings



In [15]:
def disease_embeddings_dictionary(model, dx_history=None):
    """Return a ``{code: embedding}`` dictionary of one model's disease embeddings.

    Each code of ``dx_scheme`` is turned into a one-hot vector of length
    ``cprd_interface.dx_dim`` and encoded with the model's ``dx_emb`` module.

    Args:
        model: key into ``cprd_models`` (e.g. ``'ICE-NODE'``).
        dx_history: optional precomputed
            ``cprd_interface.dx_batch_history_vec(cprd_all_subjects)``. It does not
            depend on the model, so pass it when embedding several models to avoid
            recomputing it. It is computed here when omitted.

    Returns:
        dict mapping each textual code in ``dx_scheme.index`` to the output of
        ``dx_emb.encode`` for that code.
    """
    predictor = cprd_models[model]
    if dx_history is None:
        # Code history over all subjects, input to the embeddings matrix.
        dx_history = cprd_interface.dx_batch_history_vec(cprd_all_subjects)
    # Embeddings matrix: computed once, then reused for every code below.
    dx_G = predictor.dx_emb.compute_embeddings_mat(dx_history)

    embeddings_dict = {}
    for code, idx in dx_scheme.index.items():
        one_hot = np.zeros((cprd_interface.dx_dim, ))
        one_hot[idx] = 1.
        embeddings_dict[code] = predictor.dx_emb.encode(dx_G, one_hot)
    return embeddings_dict


# The code history is model-independent: compute it once and share it across models.
cprd_dx_history = cprd_interface.dx_batch_history_vec(cprd_all_subjects)

icenode_emb = disease_embeddings_dictionary('ICE-NODE', cprd_dx_history)
icenode_uni_emb = disease_embeddings_dictionary('ICE-NODE_UNIFORM', cprd_dx_history)
retain_emb = disease_embeddings_dictionary('RETAIN', cprd_dx_history)
gru_emb = disease_embeddings_dictionary('GRU', cprd_dx_history)

<a name="subject-clusters"></a>

## 4. Subject Embeddings Clustering on CPRD [^](#outline)

In [18]:
def subject_embeddings_dictionary(model):
    """Subject embeddings for one model, over every subject in ``cprd_interface``.

    ``model`` is a key of ``cprd_models``. The result is whatever that model's
    ``subject_embeddings`` method returns for ``cprd_all_subjects``.
    """
    return cprd_models[model].subject_embeddings(cprd_interface, cprd_all_subjects)


icenode_subj_emb = subject_embeddings_dictionary('ICE-NODE')
icenode_subj_uni_emb = subject_embeddings_dictionary('ICE-NODE_UNIFORM')
retain_subj_emb = subject_embeddings_dictionary('RETAIN')
gru_subj_emb = subject_embeddings_dictionary('GRU')