# Table of Content

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Configurations and Paths 
- [D](#secd) JAX Interface
- [E](#sece) General Utility Functions


## Training

- [1](#sec1) Training ICE-NODE and The Baselines


In [5]:
# read in data
import datetime
import pandas as pd
import numpy as np
import time

code_root = "///home//tb1009//Documents//DS211//users/tb1009//CODES//"
tb_data_root = "///home//tb1009//Documents//DS211//users/tb1009//DATA//"
results_root = "///home//tb1009//Documents//DS211//users/tb1009//RESULTS//"
embedding_root = "///home//tb1009//Documents//DS211//users/tb1009//DATA//EMBEDDINGS//"
ICE_root = "///home//tb1009//Documents//DS211//users/tb1009//GIT//ICE-NODE-main//notebooks//"

cohort = 'cohort_2020'
censor_date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')

# select ethnic categories to use (5 or 16)
ethnic_cats = 5

# select patients

# read in patient dataframe from 2015 or 2020 cohort
pat_cols = ['patid','yob','gender','censordate',cohort]
df_pat_all = pd.read_csv(tb_data_root+'all_patients.csv', usecols=pat_cols, dtype=str)
# keep if in cohort
df_pat_all = df_pat_all.loc[df_pat_all[cohort]=='1']
del df_pat_all[cohort]
print(len(df_pat_all))
# remove if less than 1 year follow-up
#df_pat_all['regstartdate'] = pd.to_datetime(df_pat_all['regstartdate'], format = '%Y-%m-%d')
#df_pat_all = df_pat_all.loc[df_pat_all['regstartdate']<=register_date]
#print(len(df_pat_all))
df_pat_id = df_pat_all['patid'].to_frame()
print(len(df_pat_id))

# read in observation file
obs_cols = ['patid','obsdate','enterdate', 'medcodeid']
df = pd.read_csv(tb_data_root+'observations_long.csv', nrows=200000, usecols=obs_cols, dtype='str')

# merge with patient index
df = df_pat_all.merge(df, on='patid', how='left')
print(len(df))
print(df['patid'].nunique())

# remove observations after censordate
df['obsdate'] = pd.to_datetime(df['obsdate'], format = '%Y-%m-%d')
df['enterdate'] = pd.to_datetime(df['enterdate'], format = '%Y-%m-%d')
df['censordate'] = pd.to_datetime(df['censordate'], format = '%Y-%m-%d')
df = df.loc[df['obsdate']<=censor_date]
print(len(df))
df = df.loc[df['enterdate']<=censor_date]
print(len(df))
del df['enterdate']
del df['censordate']

# drop duplicates (note original dataframe retained those with differing enter dates)
df = df.drop_duplicates()
print(len(df))

# sort
df = df.sort_values(by=['patid','obsdate'])
df = df.loc[df['medcodeid'].notna()]
print(len(df))
print(df['patid'].nunique())

# calculate age, rounded, as string
df['yob'] = pd.to_datetime(df['yob'], format = '%Y-%m-%d')
df['age'] = (df['obsdate'] - df['yob']).astype('timedelta64[Y]')
df['age'] = df['age'].astype(int).astype(str)

# Step 2a: group by patid and aggregate diseases to list
codes = df.groupby(['patid'])['medcodeid'].apply(', '.join).reset_index(name='medcode_list')

# Step 2b: group by patid and aggregate obsdates to list
df['obsdate_ym'] = df['obsdate'].dt.to_period('M')
df['obsdate_ym'] = df['obsdate_ym'].astype('str')
obsdates = df.groupby(['patid'])['obsdate_ym'].apply(', '.join).reset_index(name='obsdate_list')

# Step 2c: group by patid and aggregate age to list
age = df.groupby(['patid'])['age'].apply(', '.join).reset_index(name='age_list')

# Merge together
df_all = codes.merge(obsdates, on='patid')
df_all = df_all.merge(age, on='patid')
df_all.columns = ['patid','medcode_list','year_month_list','age_list']

# merge back in gender and encode strings (missing = 9)
df_gen = df[['patid','gender']].drop_duplicates()
df_gen.columns = ['patid','gender_str']
df_gen['gender'] = 9
df_gen.loc[df_gen['gender_str']=='F', 'gender'] = 0
df_gen.loc[df_gen['gender_str']=='M', 'gender'] = 1
df_gen.loc[df_gen['gender_str']=='I', 'gender'] = 2
del df_gen['gender_str']
df_all = df_all.merge(df_gen, on='patid', how='left')

# merge in ethnicity data and encode (missing = 99)
ethnicity = pd.read_csv(tb_data_root+'ethnicity.csv', dtype='str')
if ethnic_cats == 5:
    ethnicity = ethnicity[['patid','eth5']]
elif ethnic_cats == 16:
    ethnicity = ethnicity[['patid','eth16']]
else:
    raise Exception("Ethnicity categories to use not stated")
ethnicity.columns = ['patid','ethnicity']
ethnicity['ethnicity'] = ethnicity['ethnicity'].str.extract('(\d+)').astype(int)
df_all = df_all.merge(ethnicity, on='patid', how='left')
df_all['ethnicity'] = df_all['ethnicity'].replace(np.nan, 99)

# merge in IMD (missing = 99)
imd = pd.read_csv(tb_data_root+'imd.csv', dtype='str')
imd = imd[['patid','pat_imd10']]
imd.columns = ['patid','imd_decile']
df_all = df_all.merge(imd, on=['patid'], how='left')
df_all['imd_decile'] = df_all['imd_decile'].replace(np.nan, 99)
print(str(df_all['patid'].nunique()))

# save out
df_all.to_csv(tb_data_root+'ICE_TEST.csv', index=False, sep="\t")
df_all

10648304
10648304
10796125
10648304
135353
134986
125106
125106
4823
4823


Unnamed: 0,patid,medcode_list,year_month_list,age_list,gender,ethnicity,imd_decile
0,100650520364,"294725013, 294725013, 294725013, 294725013, 35...","2000-08, 2001-01, 2001-01, 2001-04, 2001-05, 2...","54, 55, 55, 55, 55, 55, 55, 55, 56, 58, 58, 58...",1,0.0,7
1,100765120364,"308725015, 308725015, 308725015, 259233017, 15...","1998-12, 2006-11, 2008-09, 2009-11, 2009-11, 2...","44, 52, 54, 55, 55, 58",0,0.0,8
2,100787620364,"84230010, 99042012, 64168014, 99042012, 641680...","1990-10, 1999-05, 2000-01, 2000-01, 2000-05, 2...","33, 42, 43, 43, 43, 43, 43, 43, 43, 44, 44, 44...",1,0.0,9
3,100838620364,"348110010, 398561000006117, 294621000000118, 3...","2011-07, 2016-07, 2017-01, 2017-07","16, 21, 22, 22",1,99.0,8
4,100990720364,"302633015, 353834010, 1494848017, 121741000006...","1973-03, 1973-03, 1997-04, 1997-05, 1997-12, 2...","47, 47, 71, 71, 71, 75, 77, 78, 78, 78, 78, 78...",0,0.0,8
...,...,...,...,...,...,...,...
4818,992764820468,"82343012, 141306010, 216207010, 82343012, 2564...","2010-11, 2010-12, 2010-12, 2016-02, 2016-02, 2...","67, 67, 67, 73, 73, 73",1,99.0,2
4819,992911520468,"121589010, 197761014, 405339016, 405339016, 71...","2011-01, 2011-02, 2012-01, 2012-07, 2012-09, 2...","61, 61, 61, 62, 62, 62, 64, 64, 64, 64, 64, 64...",1,0.0,1
4820,99437920364,"1776213016, 146927011","1998-12, 2009-05","2, 13",1,0.0,4
4821,99612520364,"308725015, 221511000000115, 150921000006118, 2...","1987-01, 1990-01, 1999-08, 2001-01, 2001-01, 2...","49, 53, 62, 64, 64, 64, 64, 65, 66, 66, 66, 68...",1,0.0,4


In [22]:
import pandas as pd
tb_data_root = "///home//tb1009//Documents//DS211//users/tb1009//DATA//"
df = pd.read_csv(tb_data_root+'ICE_TEST.csv', sep="\t", dtype={'medcode_list':str})
df = df[0:1000]
df.to_csv(tb_data_root+'ICE_TEST_1000.csv', index=False, sep="\t")
df

Unnamed: 0,patid,medcode_list,year_month_list,age_list,gender,ethnicity,imd_decile
0,100650520364,"294725013, 294725013, 294725013, 294725013, 35...","2000-08, 2001-01, 2001-01, 2001-04, 2001-05, 2...","54, 55, 55, 55, 55, 55, 55, 55, 56, 58, 58, 58...",1,0.0,7
1,100765120364,"308725015, 308725015, 308725015, 259233017, 15...","1998-12, 2006-11, 2008-09, 2009-11, 2009-11, 2...","44, 52, 54, 55, 55, 58",0,0.0,8
2,100787620364,"84230010, 99042012, 64168014, 99042012, 641680...","1990-10, 1999-05, 2000-01, 2000-01, 2000-05, 2...","33, 42, 43, 43, 43, 43, 43, 43, 43, 44, 44, 44...",1,0.0,9
3,100838620364,"348110010, 398561000006117, 294621000000118, 3...","2011-07, 2016-07, 2017-01, 2017-07","16, 21, 22, 22",1,99.0,8
4,100990720364,"302633015, 353834010, 1494848017, 121741000006...","1973-03, 1973-03, 1997-04, 1997-05, 1997-12, 2...","47, 47, 71, 71, 71, 75, 77, 78, 78, 78, 78, 78...",0,0.0,8
...,...,...,...,...,...,...,...
995,2246022820178,"99042012, 99042012, 259233017, 150921000006118...","1997-01, 2001-04, 2001-10, 2001-10, 2002-05, 2...","58, 62, 62, 62, 63, 64, 64, 64, 64, 65, 65, 66...",1,0.0,5
996,2246058520178,"376691000006116, 99042012, 150921000006118, 99...","1980-01, 2002-02, 2006-05, 2007-08, 2007-08, 2...","42, 65, 69, 70, 70, 72, 73, 73, 73, 73, 75, 76...",0,0.0,6
997,2246309020178,1746951000000114,1997-09,21,0,0.0,5
998,2246327620178,"18666015, 289190019, 18181000006118, 177747801...","1997-02, 1998-08, 2006-08, 2006-11, 2009-10, 2...","63, 64, 72, 72, 75, 75, 78, 79, 79, 79, 79, 79...",0,0.0,2


In [25]:
import numpy as np
np.array(df['medcode_list'])

array(['294725013, 294725013, 294725013, 294725013, 353834010, 119655018, 396340013, 294725013, 215851000000112, 215851000000112, 215851000000112, 215851000000112, 294725013, 215851000000112, 294725013, 215851000000112, 452891000006114, 215851000000112, 294725013, 259232010, 259233017, 145471000006115, 150921000006118, 294913019, 294913019, 294913019, 294913019, 294725013, 294725013, 294913019, 294725013, 294725013, 294725013, 294725013, 294725013, 9225016, 294725013, 9225016, 100716012, 150921000006118, 259233017, 145471000006115, 353143016, 353143016, 259232010, 100716012, 9225016, 100716012, 9225016, 9225016, 9225016, 298083015, 9225016, 1479675015, 217851000006114, 298076010',
       '308725015, 308725015, 308725015, 259233017, 150921000006118, 308725015',
       '84230010, 99042012, 64168014, 99042012, 64168014, 64168014, 99042012, 99042012, 99042012, 99042012, 99042012, 99042012, 99042012, 99042012, 99042012, 99042012, 259233017, 150921000006118, 145471000006115, 99042012, 145471

In [19]:
df.dtypes
# df['medcode_list']

patid                int64
medcode_list        object
year_month_list     object
age_list            object
gender               int64
ethnicity          float64
imd_decile           int64
dtype: object

In [26]:
df['year_month_list'].apply(lambda l: len(set(l))).min()

4

<a name="seca"></a>

### A External Imports [^](#outline)

In [None]:
# pip install --upgrade jax==0.4.1 jaxlib==0.4.1+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [1]:
# set pre-allocated percentage of GPU memory (here using 20% of the 48GB)
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2


In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'gpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:


sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

In [3]:
# Assign the folder of the dataset to `DATA_FILE`.

# HOME and DATA_STORE are arbitrary, change as appropriate.
HOME = os.environ.get('HOME')
DATA_STORE = f'{HOME}/Documents/DS211/users/tb1009/DATA'
DATA_FILE = os.path.join(DATA_STORE, 'ICE_TEST_1000.csv')
SOURCE_DIR = os.path.abspath("..")

#SOURCE_DIR = f'{HOME}/Documents/DS211/users/tb1009/GIT/ICE-NODE'

<a name="secd"></a>

### C Configurations and Paths [^](#outline)

In [4]:
output_dir = 'cprd_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [5]:
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')
   

In [6]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg

"""
predefined hyperparams re: each model.
"""

model_cls = {
    'ICE-NODE': ICENODE,
    'ICE-NODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRU,
    'RETAIN': RETAIN,
    'LogReg': WindowLogReg
}

model_config = {
    'ICE-NODE': f'{SOURCE_DIR}/expt_configs/icenode.json' ,
    'ICE-NODE_UNIFORM': f'{SOURCE_DIR}/expt_configs/icenode.json' ,
    'GRU': f'{SOURCE_DIR}/expt_configs/gru.json' ,
    'RETAIN': f'{SOURCE_DIR}/expt_configs/retain.json',
    'LogReg': f'{SOURCE_DIR}/expt_configs/window_logreg.json'
}

model_config = {clf: U.load_config(file) for clf, file in model_config.items()}

clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg']

In [7]:
cprd_train_output_dir = {clf: f'{output_dir}/train/{clf}' for clf in clfs}

[Path(d).mkdir(parents=True, exist_ok=True) for d in cprd_train_output_dir.values()]

[None, None, None, None, None]

In [8]:
from lib.ml import ConfigDiskWriter, MinibatchLogger, EvaluationDiskWriter, ParamsDiskWriter

The reporter objects are called inside training iterations
Each has its own functionality:
1. **ConfigDiskWriter**: writes the experiment config file as JSON in the same training directory
2. **MinibatchLogger**: writes to the console the training progress details.
3. **EvaluationDiskWriter**: writes the evaluation as csv tables in the same training directory for each step of the 100.
4. **ParamsDiskWriter**: writes the model parameters snapshot at each step out of 100.

In [9]:
make_reporters = lambda output_dir, config: [ConfigDiskWriter(output_dir=output_dir, config=config),
                                             MinibatchLogger(config),
                                             EvaluationDiskWriter(output_dir=output_dir),
                                             ParamsDiskWriter(output_dir=output_dir)]

reporters = {model: make_reporters(cprd_train_output_dir[model], model_config[model]) for model in clfs}

<a name="sece"></a>

### D JAX Interface [^](#outline)

In [10]:
from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor, FirstOccurrenceOutcomeExtractor
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


The dictionary `code_scheme` in the next cell specifies the code spaces of:
- 'dx': diagnostic input (input features) codes. Possible arguments: :
    - `DxLTC9809FlatMedcodes()` for medcodes or
    - `DxLTC212FlatCodes()` for disease nums. 
- 'outcome': diagnostic outcome (target to be predicted) codes. Possible arguments: 
    - `OutcomeExtractor('dx_cprd_ltc9809')` for medcodes prediction or 
    - `OutcomeExtractor('dx_cprd_ltc212')` for disease num predictions, or 
    - `FirstOccurrenceOutcomeExtractor('dx_cprd_ltc9809')` for medcodes prediction (first occurrence per patient) or 
    - `FirstOccurrenceOutcomeExtractor('dx_cprd_ltc212')` for disease num predictions (first occurrence per patient), or 
- 'eth': ethinicity codes. Possible arguments:
    - `EthCPRD16()` to consider the 16 classifications of ethnicity.
    - `EthCPRD5()` to consider the 5 classifications of ethnicity.
    

**Note**: OutcomeExtractor can be provided only a subset of the diagnostic codes. For example,
you can focus the prediction objective on a small subset of interest (e.g. to predict only pulmonary-heart 
diseases codes, etc.).
OutcomeExtractor can also be replaced by FirstOccurrenceOutcomeExtractor to enforce the prediction 
objective to predict only the first occurrence of each code for one patient, and subsequent
redundant occurences will be avoided and not incorporated in the loss function.

In [11]:
code_scheme = {
    #'dx': DxLTC9809FlatMedcodes(), # other options 
    'dx': DxLTC212FlatCodes(),
    #'outcome': OutcomeExtractor('dx_cprd_ltc9809'),
    'outcome': FirstOccurrenceOutcomeExtractor('dx_cprd_ltc212'),
    # Comment above^, and uncomment below, to consider only the first occurrence of codes per subject.
    # 'outcome': FirstOccurrenceOutcomeExtractor('dx_cprd_ltc9809'),
    'eth': EthCPRD5()
}

### Adding Demographic Information in Training

What do you need to include as control features? **Uncomment each line to consider the corresponding static information.**

In [12]:
static_info_flags = StaticInfoFlags(
 gender=True,
 age=True,
 idx_deprivation=True,
 ethnicity=EthCPRD5(), # <- include it by the category of interest, not just 'True'.
)
cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme, static_info_flags=static_info_flags)
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)

In [13]:
import jax.random as jrandom
import lib.ml as ml
%load_ext autoreload
%autoreload 2
key = jrandom.PRNGKey(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In the next cell we load a dictionary for each model specifiying the experiment configuration per model.
The classname of the trainer used is also specified in the experiment configs.
For example, this is the configuration file of ICE-NODE experiment.

```json
{
    "emb": {
        "dx": {
           "decoder_n_layers": 2,
           "classname":  "MatrixEmbeddings",      
           "embeddings_size": 300
        }
    },
    "model": {
        "ode_dyn_label": "mlp3",
        "ode_init_var": 1e-7,
        "state_size": 30,
        "timescale": 30
    },
    "training": {
        "batch_size": 256,
        "decay_rate": [0.25, 0.33],
        "lr": [7e-5,  1e-3],
        "epochs": 60,
        "reg_hyperparams": {
            "L_dyn": 1000.0,
            "L_l1": 0,
            "L_l2": 0
        },
        "opt": "adam",
        "classname": "ODETrainer2LR" <---- "classname, so this class should be available through ml package."
    }
}
```

Since we have a string of the classname, one way to get `ml.ODETrainer2LR` is `getattr(ml, 'ODETrainer2LR')`

In [14]:
cprd_models = {clf: model_cls[clf].from_config(model_config[clf],
                                              cprd_interface,
                                              cprd_splits[0],
                                              key) for clf in clfs}




cprd_trainers_cls = {clf: getattr(ml, model_config[clf]["training"]["classname"]) for clf in clfs}
cprd_trainers = {clf: cprd_trainers_cls[clf](**model_config[clf]["training"]) for clf in clfs}

## Metrics of Interest Specification

In [15]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, MetricsCollection)

## Evaluation Metrics per Model

1. *CodeAUC*: evaluates the prediction AUC per code (aggregating over all subject visits, for all subjects)
2. *UntilFirstCodeAUC*: same as *CodeAUC*, but evaluates the prediction AUC until the first occurrence for each subject, once the code occured, all the subsequent visits are ignored for that code. If the code does not show in a particular subject, all the subject visits are ignored.
3. *AdmissionAUC*: evaluates the prediction AUC per visit (i.e. probability of assigning higher risk values for present codes than the absent ones).
4. *CodeGroupTopAlarmAccuracy*: partition codes into groups according the code frequency (from the most frequent to the least), and for each visit picks the top `k` risks, and the metric evaluates the accuracy of the top `k` riskiest codes by the model for being indeed present.
5. *MetricsCollection*: Groups multiple metrics to be considered at once.

**Note:** you will get different results when calling the method `outcome_by_percentiles` by changing the 'outcome' enty of the `ode_scheme` dictionary as following:
- OutcomeExtractor: the counting will consider the code and its redundant occurrence for each subject, then aggregated over all subjects 
- FirstOccurrenceOutcomeExtractor: the counting will consider the first occurrence only for each subject, then aggregated over all subjects.

In [16]:
# pecentile_range=20 will partition the codes into five gruops, where each group contains 
# codes that overall constitutes 20% of the codes in all visits of specified 'subjects' list.
code_freq_partitions = cprd_interface.outcome_by_percentiles(percentile_range=20, subjects=cprd_splits[0])

# Evaluate for different k values
top_k_list = [3, 5, 10, 15, 20]

metrics = [CodeAUC(cprd_interface),
          UntilFirstCodeAUC(cprd_interface),
          AdmissionAUC(cprd_interface),
          CodeGroupTopAlarmAccuracy(cprd_interface, top_k_list=top_k_list, code_groups=code_freq_partitions)]
all_metrics = MetricsCollection(metrics)

<a name="sec1"></a>

### 1 Training ICE-NODE and The Baselines (#outline)

In [17]:
from lib.ml import MetricsHistory

def train(clf):
    output_dir = cprd_train_output_dir[clf]
    config = model_config[clf]
    model = cprd_models[clf]
    trainer = cprd_trainers[clf]
    reporters = [EvaluationDiskWriter(output_dir), # <- responsible for writing evaluation tables on disk at the given path
                 ParamsDiskWriter(output_dir), # <- responsible for writing model parameters snapshot after each iteration.
                 ConfigDiskWriter(output_dir, config), # writes the config file as JSON
                ]
    
    history = MetricsHistory(all_metrics) # <- empty history
    
    return trainer(model, cprd_interface, cprd_splits, history=history, reporters=reporters, prng_seed=42)

#### ICE-NODE

In [19]:
## TODO: This may take a long time, a pretrained model already exists in (yy).
icenode_results = train('ICE-NODE')

  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_w

  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_w

  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
100%|██████████| 578/578 [5:44:20<00:00, 35.74s/it]


#### ICE-NODE_UNIFORM

In [20]:
icenode_uniform_results = train('ICE-NODE_UNIFORM')































Exception ignored in: <generator object tqdm.__iter__ at 0x7f4844164510>
Traceback (most recent call last):
  File "/data/master/DS211/users/tb1009/venv308/lib/python3.8/site-packages/tqdm/std.py", line 1181, in __iter__
    yield obj
KeyboardInterrupt: 
  0%|          | 2/581 [01:42<8:15:12, 51.32s/it]


KeyboardInterrupt: 

#### GRU

In [21]:
gru_results = train('GRU')

  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_w

  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_write(zinfo, force_zip64=force_zip64)
  return self._open_to_w

#### RETAIN

In [22]:
retain_results = train('RETAIN')

  1%|▏         | 18/1320 [9:37:55<696:43:14, 1926.42s/it] 


KeyboardInterrupt: 

In [19]:
import jax
jax.config.update('jax_log_compiles', True)
logreg_results = train('LogReg')

  0%|          | 0/145 [01:14<?, ?it/s]

KeyboardInterrupt

