## Imports and GPU setup

In [None]:
import Utils.dbutils as dbutils
import Utils.data_utils as data_utils
from Utils.embedding_utils import train_embedding
import Generators.CohortGenerator as CohortGenerator
import Generators.FeatureGenerator as FeatureGenerator
import Models.LogisticRegression.RegressionGen as lr_models
import Models.Transformer.visit_transformer as visit_transformer
import config

import numpy as np
import pandas as pd
import torch
import time

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from ipywidgets import IntProgress, FloatText
from IPython.display import display

import matplotlib
import matplotlib.pyplot as plt

# Figure defaults for the whole notebook: serif typeface, 13pt text.
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 13,
})

In [None]:
# Reload the two project modules so that edits made to them on disk are picked
# up without restarting the kernel. The module names are rebound to whatever
# `importlib.reload` returns.
import importlib
visit_transformer = importlib.reload(visit_transformer)
data_utils = importlib.reload(data_utils)
# The deep model below is moved to the GPU with `.cuda()`, so stop here if no
# CUDA device is visible; otherwise make device 0 the default device.
assert(torch.cuda.is_available())
torch.cuda.set_device(0)

## Cohort, Outcome and Feature Collection

### 1. Set up a connection to the OMOP CDM database

Parameters for connection to be specified in ./config.py

In [None]:
# database connection: credentials come from ./config.py, and config.DB_NAME
# fills everything after the `@` in the connection URL
username, password, database_name = (
    config.PG_USERNAME,
    config.PG_PASSWORD,
    config.DB_NAME,
)
connection_template = 'postgresql://{username}:{password}@{database_name}'
config_path = connection_template.format(
    username=username,
    password=password,
    database_name=database_name,
)

# schemas
schema_name = 'eol_test'  # every table created by this notebook goes in this schema

# caching
reset_schema = False  # set to True to drop the schema and rebuild all data from scratch

# SQL for the optional reset and for making sure the schema exists
drop_schema_sql = 'drop schema if exists {} cascade'.format(schema_name)
create_schema_sql = 'create schema if not exists {}'.format(schema_name)

# set up database, reset schemas as needed
db = dbutils.Database(config_path, schema_name)
if reset_schema:
    db.execute(drop_schema_sql)
db.execute(create_schema_sql)

### 2. Generate the Cohort as per the given SQL file

In [None]:
cohort_name = 'eol_cohort'
cohort_script_path = config.SQL_PATH_COHORTS + '/gen_EOL_cohort.sql'

# Keyword arguments handed to the cohort-generation SQL script.
# `training_end_date` is reused further down as the reference date when
# converting visit dates into "days before prediction".
params = dict(
    cohort_table_name=cohort_name,
    schema_name=schema_name,
    aux_data_schema=config.CDM_AUX_SCHEMA,
    training_start_date='2016-01-01',
    training_end_date='2017-01-01',
    gap='3 months',
    outcome_window='6 months',
)

cohort = CohortGenerator.Cohort(
    schema_name=schema_name,
    cohort_table_name=cohort_name,
    cohort_generation_script=cohort_script_path,
    cohort_generation_kwargs=params,
    outcome_col_name='y',
)
# replace=False is passed so an existing cohort table can be reused
# (presumably skipping regeneration -- confirm against CohortGenerator).
cohort.build(db, replace=False)

### 3. Generate and build a feature set for each patient in the cohort using some default features

In [None]:
# Names of the built-in feature groups to collect for every cohort member.
default_feature_names = ['drugs', 'conditions', 'procedures', 'specialty']

featureSet = FeatureGenerator.FeatureSet(db)
featureSet.add_default_features(default_feature_names, schema_name, cohort_name)

In [None]:
%%time
# NOTE: `%%time` is a Jupyter cell magic and has to stay on the first line of the cell.
# Build the Feature Set by executing SQL queries and reading into sparse matrices
# `cache_file` names the on-disk cache location; `from_cached=False` presumably
# forces the SQL queries to be re-run instead of reading that cache -- TODO confirm
# against FeatureGenerator.FeatureSet.build.
cache_data_path = '/tmp/cache_data_eol_test'
featureSet.build(cohort, from_cached=False, cache_file=cache_data_path)

### 4. Process the collected data and calculate indices needed for the deep model

In [None]:
# Post-process the raw feature matrix; the four results are unpacked into the
# names used by the rest of the notebook.
(
    outcomes_filt,
    feature_matrix_3d_transpose,
    remap,
    good_feature_names,
) = FeatureGenerator.postprocess_feature_matrix(cohort, featureSet)

In [None]:
# process data for deep model
# The sparse matrix exposes one coordinate array per axis: person, time, code.
person_ixs, time_ixs, code_ixs = feature_matrix_3d_transpose.coords
all_codes_tensor = code_ixs

# Offsets into the coordinate arrays at which each person's entries begin,
# followed by a final sentinel equal to the total number of entries.
# (searchsorted assumes `person_ixs` is sorted ascending.)
people = sorted(np.unique(person_ixs))
person_indices = np.append(
    np.searchsorted(person_ixs, people),
    len(person_ixs)
)
person_chunks = [
    time_ixs[start:stop]
    for start, stop in zip(person_indices[:-1], person_indices[1:])
]

# For each person: the offsets (local to that person's chunk) where each
# distinct visit time starts, plus a closing sentinel, and the visit times.
visit_chunks, visit_times_raw = [], []
for chunk in person_chunks:
    visits = sorted(np.unique(chunk))
    visit_chunks.append(
        np.append(np.searchsorted(chunk, visits), len(chunk))
    )
    visit_times_raw.append(visits)

n_visits = dict(enumerate(map(len, visit_times_raw)))

# Days between each mapped visit time and the end of the training window.
prediction_date = pd.to_datetime(params['training_end_date'])
visit_days_rel = {
    time_ix: (prediction_date - pd.to_datetime(featureSet.time_map[time_ix])).days
    for time_ix in featureSet.time_map
}
vdrel_func = np.vectorize(visit_days_rel.get)
visit_time_rel = [vdrel_func(times) for times in visit_times_raw]

maps = {
    'concept': featureSet.concept_map,
    'id': featureSet.id_map,
    'time': featureSet.time_map
}

dataset_dict = {
    # A tensor of all codes occurring in the dataset
    'all_codes_tensor': all_codes_tensor,
    # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]]
    # are the codes assigned to the ith patient
    'person_indices': person_indices,
    # A list of indices such that
    # all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]]
    # are the codes assigned to the ith patient during their jth visit
    'visit_chunks': visit_chunks,
    # A list of times (as measured in days to the prediction date) for each visit
    'visit_time_rel': visit_time_rel,
    # A dict defined such that n_visits[i] is the number of visits made by the ith patient
    'n_visits': n_visits,
    # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient
    'outcomes_filt': outcomes_filt,
    'remap': remap,
    'maps': maps
}

## Run the windowed regression model on the task defined above

In [None]:
# collect feature names: translate every remapped concept index into its
# entry in the concept map
concept_name_for = np.vectorize(dataset_dict['maps']['concept'].get)
good_feature_names = concept_name_for(dataset_dict['remap']['concept'])

In [None]:
# window the data using the window lengths specified below
window_lengths = [30, 180, 365, 730, 10000]
windowed_counts, feature_names = data_utils.window_data_sorted(
    window_lengths=window_lengths,
    feature_matrix=feature_matrix_3d_transpose,
    all_feature_names=good_feature_names,
    cohort=cohort,
    featureSet=featureSet,
)
# Transpose so the matrix can be split alongside the per-patient outcomes.
feature_matrix_counts = windowed_counts.T

In [None]:
# Split data into train and test sets. The validation set is carved out of the
# test split later on: its first `val_size` rows are used for validation and
# the remainder for the final test.
val_size = 5000
all_outcomes = dataset_dict['outcomes_filt']
indices_all = range(len(all_outcomes))
(
    X_train, X_test,
    y_train, y_test,
    indices_train, indices_test,
) = train_test_split(
    feature_matrix_counts,
    all_outcomes,
    indices_all,
    test_size=0.2,
    random_state=1,
)

In [None]:
# train the regression model over several choices of regularization parameter
reg_lambdas = [2, 0.2, 0.02]
lr_val_aucs = []
X_validation, y_validation = X_test[:val_size], y_test[:val_size]
for reg_lambda in reg_lambdas:
    clf_lr = lr_models.gen_lr_pipeline(reg_lambda)
    clf_lr.fit(X_train, y_train)
    pred_lr = clf_lr.predict_proba(X_validation)[:, 1]
    lr_val_auc = roc_auc_score(y_validation, pred_lr)
    lr_val_aucs.append(lr_val_auc)
    print('Validation AUC: {0:.3f}'.format(lr_val_auc))

# Refit with the regularization strength that scored the best validation AUC,
# then score every patient; these probabilities are the distillation targets
# for the deep model.
best_reg_lambda = reg_lambdas[np.argmax(lr_val_aucs)]
clf_lr = lr_models.gen_lr_pipeline(best_reg_lambda)
clf_lr.fit(X_train, y_train)
pred_lr_all = clf_lr.predict_proba(feature_matrix_counts)[:, 1]

In [None]:
# Score the refit best-regularization linear model on the part of the test
# split that was not used for validation.
pred_lr = clf_lr.predict_proba(X_test[val_size:])[:, 1]
lr_test_auc = roc_auc_score(y_test[val_size:], pred_lr)
print('Linear Model Test AUC: {0:.3f}'.format(lr_test_auc))

### Learn a Word2Vec embedding

In [None]:
embedding_dim = 300  # size of embedding, must be multiple of number of heads
window_days = 90  # number of days in window that defines a "Sentence" when learning the embedding

# Positions of the coordinate entries that belong to training-set patients, so
# the embedding is learned from training data only.
train_coords = np.nonzero(np.isin(person_ixs, indices_train))
embedding_filename = train_embedding(
    featureSet,
    feature_matrix_3d_transpose,
    window_days,
    person_ixs[train_coords],
    time_ixs[train_coords],
    remap['good_time_ixs'],
    embedding_dim,
)

## Run the SARD deep model on the predictive task
### 1. Set Model Parameters and Construct the Model

In [None]:
# using the same split as before, create train/validate/test batches for the deep model
# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used
mbsz = 50
def get_batches(arr, mbsz=mbsz):
    """Cut `arr` into consecutive slices of at most `mbsz` items.

    Returns a list of slices of `arr`, in order. Slices start at offsets
    0, mbsz, 2*mbsz, ... for as long as the offset is below ``len(arr) - 1``.

    NOTE(review): because of that bound, a trailing slice that would hold
    exactly one item is not emitted (and a one-item `arr` yields no batches).
    This looks like it may be deliberate to avoid single-element batches --
    confirm before changing.
    """
    return [
        arr[start:start + mbsz]
        for start in range(0, len(arr) - 1, mbsz)
    ]

# Batch the train and test index lists. The first `val_size // mbsz` test
# batches serve as the validation set; the remaining ones are the test set.
p_ranges_train = get_batches(indices_train)
held_out_batches = get_batches(indices_test)

n_val_batches = val_size // mbsz
p_ranges_val = held_out_batches[:n_val_batches]
p_ranges_test = held_out_batches[n_val_batches:]

In [None]:
# Pick a name for the model (mn_prefix) that will be used when saving checkpoints
# Then, set some parameters for SARD. The values below reflect a good starting point that performed well on several tasks
import os

mn_prefix = 'eol_experiment_prefix'

# `task` selects the sub-directory of SavedModels/ that every checkpoint in
# this notebook is written to and read from. The later training cells use it
# in their checkpoint paths.
# NOTE(review): 'eol' is a placeholder chosen to match the cohort name --
# confirm the directory name you want.
task = 'eol'
# torch.save does not create missing directories, so make sure both the
# pretraining directory and its `ablation` sub-directory exist up front.
os.makedirs(os.path.join('SavedModels', task, 'ablation'), exist_ok=True)

n_heads = 2
# The embedding is divided evenly across the attention heads.
assert embedding_dim % n_heads == 0
model_params = {
    'embedding_dim': embedding_dim // n_heads, # Dimension per head of visit embeddings
    'n_heads': n_heads, # Number of self-attention heads
    'attn_depth': 2, # Number of stacked self-attention layers
    'dropout': 0.05, # Dropout rate for both self-attention and the final prediction layer
    'use_mask': True, # Only allow visits to attend to other actual visits, not to padding visits
    'concept_embedding_path': embedding_filename # if unspecified, uses default Torch embeddings
}

In [None]:
# Set up fixed model parameters, loss functions, and build the model on the GPU
lr = 2e-4
n_epochs_pretrain = 1
ft_epochs = 1

# The training loops call optimizer.step() once every `update_mod`
# mini-batches, i.e. roughly once per `update_every` patients.
update_every = 500
update_mod = update_every // mbsz

base_model = visit_transformer.VisitTransformer(featureSet, **model_params)
clf = visit_transformer.VTClassifer(base_model).cuda()

# Hand the model the flattened code tensor (on the GPU) together with the
# per-person / per-visit index structures built earlier.
codes_on_gpu = torch.LongTensor(dataset_dict['all_codes_tensor']).cuda()
clf.bert.set_data(
    codes_on_gpu,
    dataset_dict['person_indices'],
    dataset_dict['visit_chunks'],
    dataset_dict['visit_time_rel'],
    dataset_dict['n_visits'],
)

# Positive-class weight = (number of patients / sum of outcomes) - 1, which is
# the negative-to-positive ratio assuming outcomes are coded 0/1.
n_patients = len(dataset_dict['outcomes_filt'])
outcome_total = dataset_dict['outcomes_filt'].sum()
distill_pos_weight = torch.FloatTensor([n_patients / outcome_total - 1])
loss_function_distill = torch.nn.BCEWithLogitsLoss(
    pos_weight=distill_pos_weight,
    reduction='sum',
).cuda()

optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)

def eval_curr_model_on(a):
    """Return the ROC AUC of the current global `clf` over the batches in `a`.

    `a` is an iterable of batches, each a collection of positional patient
    indices. Every batch is scored with `clf` under `torch.no_grad()`, and the
    scores are compared against `dataset_dict['outcomes_filt']` at the same
    positions. The caller is responsible for switching `clf` between
    eval and train mode.
    """
    outcomes = dataset_dict['outcomes_filt']
    scores, labels = [], []
    with torch.no_grad():
        for p_range in a:
            scores.extend(clf(p_range).tolist())
            labels.extend(outcomes.iloc[list(p_range)].values)
    return roc_auc_score(labels, scores)

### 2. Fit the SARD model to the best windowed linear model (Reverse Distillation)

In [None]:
# Run `n_epochs_pretrain` of Reverse Distillation pretraining
# In this phase the deep model is trained to reproduce the linear model's
# predicted probabilities (`pred_lr_all`) rather than the true outcomes.
# `task` and `mn_prefix` must be defined before this cell runs: together they
# determine the checkpoint path SavedModels/<task>/<mn_prefix>_pretrain_epochs_<n>.
# NOTE: despite its name, `val_losses` collects one validation AUC per epoch
# (higher is better); the next cell picks the best epoch with np.argmax.
val_losses = []
progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))
batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)
time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)

display(progress_bar)
display(batch_loss_disp)
display(time_disp)

for epoch in range(n_epochs_pretrain):
    t, batch_loss = time.time(), 0
    
    for batch_num, p_range in enumerate(p_ranges_train):
        
        # Every 50 batches: refresh the loss/time widgets and restart both counters.
        if batch_num % 50 == 0:
            batch_loss_disp.value = round(batch_loss / 50, 2)
            time_disp.value = round(time.time() - t, 2)
            t, batch_loss = time.time(), 0
            
        # Distillation target: the linear model's probability for each patient in the batch.
        y_pred = clf(p_range)
        loss_distill = loss_function_distill(
            y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()
        )
        
        batch_loss += loss_distill.item()
        loss_distill.backward()
        
        # Gradients accumulate across batches; the optimizer only steps (and the
        # gradients are only cleared) on every `update_mod`-th batch, which
        # includes batch 0.
        if batch_num % update_mod == 0:
            optimizer_clf.step()
            optimizer_clf.zero_grad()
        
        progress_bar.value = batch_num + epoch * len(p_ranges_train)
        
    # Checkpoint after every epoch so the best one can be reloaded afterwards.
    torch.save(
        clf.state_dict(),
        "SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}".format(
                task=task, mn_prefix = mn_prefix, epochs = epoch + 1
            )
        )
    
    # Validation AUC is measured in eval mode; training mode is restored right after.
    clf.eval()
    ckpt_auc = eval_curr_model_on(p_ranges_val)
    print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))
    val_losses.append(ckpt_auc)
    clf.train()

In [None]:
# Reload the pretraining checkpoint with the best validation AUC and store a
# copy of it under the fixed epoch label 'BEST', which fine-tuning starts from.
pretrain_ckpt_template = "SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}"

# Checkpoints are numbered from 1, while np.argmax returns a 0-based position.
best_pretrain_epoch = np.argmax(val_losses) + 1
best_pretrain_path = pretrain_ckpt_template.format(
    task=task, mn_prefix=mn_prefix, epochs=best_pretrain_epoch
)
clf.load_state_dict(torch.load(best_pretrain_path))

torch.save(
    clf.state_dict(),
    pretrain_ckpt_template.format(task=task, mn_prefix=mn_prefix, epochs='BEST'),
)

### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set

In [None]:
# Set up loss functions for fine-tuning. There are two terms:
#    - `loss_function_distill`, which penalizes differences between the linear model prediction and SARD's prediction
#    - `loss_function_clf`, which penalizes differences between the true outcome and SARD's prediction
# Both use the same positive-class weight, computed once from the outcomes.
weight_outcomes = dataset_dict['outcomes_filt']
pos_weight_value = len(weight_outcomes) / weight_outcomes.sum() - 1

loss_function_distill = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.FloatTensor([pos_weight_value]),
    reduction='sum',
).cuda()

loss_function_clf = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.FloatTensor([pos_weight_value]),
    reduction='sum',
).cuda()

In [None]:
# run `ft_epochs` of fine-tuning training, for each of the values of `alpha` below.
# Note that `alpha` is the relative weight of `loss_function_distill` as compared to `loss_function_clf`

# Maps each saved checkpoint's file name (inside SavedModels/<task>/ablation/)
# to the validation AUC measured right after it was written.
all_pred_models = {}

progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))
batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)
time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)

display(progress_bar)
display(batch_loss_disp)
display(time_disp)

# no_rd=True skips reverse distillation: each run then starts either from a
# freshly initialised model (start_model is None) or from the named checkpoint.
no_rd = False
start_model = None

for alpha in [0, 0.05, 0.1, 0.15, 0.2]:

    progress_bar.value = 0

    if no_rd:
        pretrained_model_fn = mn_prefix + '_None'
        if start_model is None:
            base_model = visit_transformer.VisitTransformer(
                featureSet, **model_params
            )

            clf = visit_transformer.VTClassifer(base_model).cuda()

            clf.bert.set_data(
                torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),
                dataset_dict['person_indices'], dataset_dict['visit_chunks'],
                dataset_dict['visit_time_rel'], dataset_dict['n_visits']
            )
        else:
            pretrained_model_path = "SavedModels/{task}/{start_model}".format(
                task=task, start_model=start_model
            )
            clf.load_state_dict(torch.load(pretrained_model_path))

    else:
        # Every value of alpha starts again from the best pretrained checkpoint.
        pretrained_model_fn = "{mn_prefix}_pretrain_epochs_{epochs}".format(
            mn_prefix=mn_prefix, epochs='BEST'
        )
        pretrained_model_path = "SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}".format(
            task=task, mn_prefix=mn_prefix, epochs='BEST'
        )
        clf.load_state_dict(torch.load(pretrained_model_path))

    clf.train()

    optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=2e-4)
    # The previous run (pretraining, or the previous alpha) leaves gradients
    # accumulated after its last optimizer step on the parameters. Reloading
    # the weights does not clear them, and the first step below happens at
    # batch 0, so clear them here to keep them out of this run's first update.
    optimizer_clf.zero_grad()

    for epoch in range(ft_epochs):

        t, batch_loss = time.time(), 0

        for batch_num, p_range in enumerate(p_ranges_train):

            # Every 50 batches: refresh the loss/time widgets and restart both counters.
            if batch_num % 50 == 0:
                batch_loss_disp.value = round(batch_loss / 50, 2)
                time_disp.value = round(time.time() - t, 2)
                t, batch_loss = time.time(), 0

            y_pred = clf(p_range)

            # Classification term: SARD prediction vs. the true outcome.
            loss = loss_function_clf(
                y_pred,
                torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()
            )

            # Distillation term: SARD prediction vs. the linear model's probability.
            loss_distill = loss_function_distill(
                y_pred,
                torch.FloatTensor(pred_lr_all[p_range]).cuda()
            )

            batch_loss += loss.item() + alpha * loss_distill.item()
            loss_total = loss + alpha * loss_distill
            loss_total.backward()

            # Gradients accumulate; step only on every `update_mod`-th batch.
            if batch_num % update_mod == 0:
                optimizer_clf.step()
                optimizer_clf.zero_grad()

            progress_bar.value = batch_num + epoch * len(p_ranges_train)

        # Checkpoint after every epoch and record its validation AUC under the
        # checkpoint's file name so the best one can be selected afterwards.
        saving_fn = "{pretrain}_alpha_{alpha}_epochs_{epochs}".format(
            pretrain=pretrained_model_fn, alpha=alpha, epochs=epoch + 1
        )
        torch.save(
            clf.state_dict(),
            "SavedModels/{task}/ablation/{saving_fn}".format(
                task=task, saving_fn=saving_fn
            )
        )

        clf.eval()
        val_auc = eval_curr_model_on(p_ranges_val)
        print(val_auc)
        all_pred_models[saving_fn] = val_auc
        clf.train()

### 4. Evaluate the best SARD model, as determined by validation performance

In [None]:
# Pick the fine-tuned checkpoint with the highest validation AUC. The dict's
# bound `get` method is passed as the key function, not called.
best_model = max(all_pred_models, key=all_pred_models.get)

# Fine-tuned checkpoints are written to SavedModels/<task>/ablation/, so they
# have to be loaded from that same directory.
clf.load_state_dict(
    torch.load("SavedModels/{task}/ablation/{model}".format(
        task=task, model=best_model
    ))
)

# Report the test-set AUC in eval mode, then restore training mode.
clf.eval()
print(eval_curr_model_on(p_ranges_test))
clf.train();