# Table of Content

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Configurations and Paths 
- [D](#secd) JAX Interface
- [E](#sece) General Utility Functions


## Clustering

- [1](#sec2) Disease Embeddings Clustering
- [2](#sec3) Subject Embeddings Clustering

<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:


sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

# Assign the folder of the dataset to `DATA_FILE`.

HOME = os.environ.get('HOME')
DATA_STORE = f'{HOME}/Documents/DS211/users/tb1009/DATA'
DATA_FILE = os.path.join(DATA_STORE, 'ICE_TEST_1000.csv')
SOURCE_DIR = os.path.abspath("..")

<a name="secc"></a>

### C Configurations and Paths [^](#outline)

In [3]:
output_dir = 'cprd_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)


In [4]:
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')
   

In [5]:
output_dir = 'cprd_clustering_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)


<a name="secd"></a>

### D JAX Interface [^](#outline)

### Configuration should match the same configuration used in training in `cprd2_dx_training.ipynb`

In [6]:

%load_ext autoreload
%autoreload 2

from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor, FirstOccurrenceOutcomeExtractor
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

%load_ext autoreload
%autoreload 2

code_scheme = {
    #'dx': DxLTC9809FlatMedcodes(), # other options 
    'dx': DxLTC212FlatCodes(),
    #'outcome': OutcomeExtractor('dx_cprd_ltc9809'),
    'outcome': FirstOccurrenceOutcomeExtractor('dx_cprd_ltc212'),
    # Comment above^, and uncomment below, to consider only the first occurrence of codes per subject.
    # 'outcome': FirstOccurrenceOutcomeExtractor('dx_cprd_ltc9809'),
    'eth': EthCPRD5()
}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:

static_info_flags = StaticInfoFlags(
 gender=True,
 age=True,
 idx_deprivation=True,
 ethnicity=EthCPRD5(), # <- include it by the category of interest, not just 'True'.
)
cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme, static_info_flags=static_info_flags)
cprd_all_subjects = list(cprd_interface.keys())
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)

<a name="sec1"></a>

## 1 Disease Embeddings Clustering on CPRD [^](#outline)

In [8]:


# Should be the same one used in JAX interface in the training notebook.
#dx_scheme = DxLTC9809FlatMedcodes()
dx_scheme = DxLTC212FlatCodes()


In [9]:
# scheme indices (textual code -> integer index)
dx_scheme.index

# reverse index (integer index -> textual code)
idx2code = {idx: code for code, idx in dx_scheme.index.items()}

### 1.A GloVe Based Disease Embeddings

Get the coocurrence matrix

In [None]:
# Time-window context coocurrence
cprd_cooc_timewin = cprd_interface.dx_coocurrence(cprd_all_subjects, window_size_days=365)

# Sequence context coocurrence
cprd_cooc_seqwin = cprd_interface.dx_coocurrence(cprd_all_subjects, context_size=20)

from lib.embeddings import train_glove

cprd_glove_timewin = train_glove(cprd_cooc_timewin, embeddings_size=100, iterations=500, prng_seed=0)
cprd_glove_seqwin = train_glove(cprd_cooc_seqwin, embeddings_size=100, iterations=500, prng_seed=0)

cprd_glove_timewin

In [None]:
pd.DataFrame(cprd_cooc_timewin)

In [None]:
pd.DataFrame(cprd_cooc_seqwin)

In [None]:
pd.DataFrame(cprd_cooc_seqwin, index=dx_scheme.codes)

In [15]:
cprd_cooc_seqwin.shape

(212, 212)

In [22]:
cprd_glove_seqwin

array([[-2.79684767e-01, -4.44653353e-01, -3.67274849e-01, ...,
         4.62476977e-01, -7.95394904e-01,  4.33272337e-01],
       [-2.96603468e-01, -3.61163169e-02, -3.38016470e-01, ...,
        -4.86751560e-01, -8.11624934e-01,  4.42831558e-01],
       [-7.28966016e-01, -6.16700163e-01, -5.44311132e-01, ...,
        -2.69560831e-01, -9.91197410e-02,  7.03036940e-01],
       ...,
       [ 8.87873506e-04,  3.02664180e-03,  3.42072664e-03, ...,
         4.32996637e-03,  1.55669944e-03,  1.99661385e-03],
       [-2.67907123e-01, -2.00531299e-01, -4.19550499e-01, ...,
         5.67781934e-01, -5.95530259e-02,  4.87793229e-01],
       [ 2.59708709e-03,  1.36179815e-05, -4.54081004e-03, ...,
         1.42656857e-03,  4.44432416e-03,  6.02877876e-03]])

### 1.B Predictor Based Disease Embeddings

TODO

### 1.C Predictor Based Subject Embeddings

TODO

In [32]:


# def embeddings_dictionary(clf):
#     model, state = cprd_predictors[clf]
#     params = model.get_params(state)
#     # Embeddings Mat
#     dx_G = model.dx_emb.compute_embeddings_mat(params['dx_emb'])

#     embeddings_dict = {}
#     for code, idx in dx_scheme.index.items():
#         in_vec = np.zeros((cprd_interface.dx_dim, ))
#         in_vec[idx] = 1.
#         out_vec = model.dx_emb.encode(dx_G, in_vec)
#         embeddings_dict[code] = out_vec
#     return embeddings_dict

# icenode_emb = embeddings_dictionary('ICE-NODE')
# icenode_uni_emb = embeddings_dictionary('ICE-NODE_UNIFORM')
# retain_emb = embeddings_dictionary('RETAIN')
# gru_emb = embeddings_dictionary('GRU')


# def subject_embeddings_dictionary(clf):
#     model, state = cprd_predictors[clf]
#     # All subjects in the study are passed
#     return model.subject_embeddings(state, cprd_interface.subjects)

# icenode_subj_emb = subject_embeddings_dictionary('ICE-NODE')
# icenode_subj_uni_emb = subject_embeddings_dictionary('ICE-NODE_UNIFORM')
# retain_subj_emb = subject_embeddings_dictionary('RETAIN')
# gru_subj_emb = subject_embeddings_dictionary('GRU')

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
