<a href="https://colab.research.google.com/github/clinicalml/omop-learn/blob/master/examples/eol/sard_eol.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Run End of Life prediction task on Synthetic Patient Data in OMOP

This notebook runs the end-of-life (EOL) prediction task on synthetic patient data in OMOP using a linear baseline model and the SARD architecture [Kodialam et al. 2021].

Data is sourced from the publicly available Medicare Claims Synthetic Public Use Files (SynPUF), released by the Centers for Medicare and Medicaid Services (CMS) and available in Google BigQuery (`bigquery-public-data.cms_synthetic_patient_data_omop`). The synthetic set contains 2008–2010 Medicare insurance claims for development and demonstration purposes and was converted from its original CSV form to the Observational Medical Outcomes Partnership (OMOP) Common Data Model.

## Imports and GPU setup

In [None]:
import numpy as np
import pandas as pd
import torch
import time
import os

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from ipywidgets import IntProgress, FloatText
from IPython.display import display

import matplotlib
import matplotlib.pyplot as plt

# Use a 13pt serif font for every figure produced in this notebook
plt.rcParams.update({"font.family": "serif", "font.size": 13})

In [None]:
from omop_learn.backends.bigquery import BigQueryBackend
from omop_learn.data.cohort import Cohort
from omop_learn.data.feature import Feature
from omop_learn.utils.config import Config
from omop_learn.omop import OMOPDataset
from omop_learn.utils import date_utils, embedding_utils
from omop_learn.sparse.models import OMOPLogisticRegression
from omop_learn.models import transformer, visit_transformer

## Cohort, Outcome and Feature Collection

### 1–3. Connect to the OMOP CDM database, define the cohort, and specify the features

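Connection parameters (project name, CDM schema, prefix schema for generated tables, and local data/model directories) are passed to `Config` in the cell below; the cohort and feature SQL files are read from `./bigquery_sql`.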

In [None]:
# 1. Connect to the OMOP CDM database (BigQuery); connection settings are given inline.
config = Config({
    "project_name": "project",
    "cdm_schema": "bigquery-public-data.cms_synthetic_patient_data_omop",
    "prefix_schema": "username",
    "datasets_dir": "data_dir",
    "models_dir": "model_dir"
})

# Set up database, reset schemas as needed
backend = BigQueryBackend(config)
backend.reset_schema(config.prefix_schema) # Rebuild schema from scratch
backend.create_schema(config.prefix_schema) # Create schema if not exists

# 2. Define the end-of-life cohort from a SQL template plus these parameters.
# NOTE(review): `aux_cdm_schema` is not among the keys passed to Config above; this relies on
# Config providing it (e.g. as a default) — confirm.
cohort_params = {
    "cohort_table_name": "synpuf_eol_cohort",
    "schema_name": config.prefix_schema,
    "cdm_schema": config.cdm_schema,
    "aux_data_schema": config.aux_cdm_schema,
    "training_start_date": "2009-01-01",
    "training_end_date": "2009-12-31",
    "gap": "3 month",
    "outcome_window": "6 month",
}
sql_dir = "./bigquery_sql"
# Use a context manager so the SQL file handle is closed once the cohort has been built
# (assumes Cohort.from_sql_file reads the file during the call — TODO confirm).
with open(f"{sql_dir}/gen_EOL_cohort.sql", 'r') as sql_file:
    cohort = Cohort.from_sql_file(sql_file, backend, params=cohort_params)

# 3. Features: one SQL file per feature family, located in `sql_dir`.
feature_names = ["drugs", "conditions", "procedures"]
feature_paths = [f"{sql_dir}/{feature_name}.sql" for feature_name in feature_names]
features = [Feature(n, p) for n, p in zip(feature_names, feature_paths)]

init_args = {
    "config" : config,
    "name" : "synpuf_eol",
    "cohort" : cohort,
    "features": features,
    "backend": backend,
}

dataset = OMOPDataset(**init_args)

### 4. Process the collected data and calculate indices needed for the deep model

In [None]:
# Look-back windows (in days) over which each patient's features are aggregated
window_days = [30, 180, 365, 730, 1000]
windowed_dataset = dataset.to_windowed(
    window_days
)

In [None]:
# Unpack the three coordinate arrays of the windowed feature tensor (named person / time / concept indices).
# NOTE(review): the next cell performs exactly the same unpacking, so this cell is redundant.
person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords

In [None]:
# process data for deep model
# Coordinates of the non-zero entries of the windowed feature tensor, named by the
# (person, time, concept) index each array presumably holds.
person_ixs, time_ixs, code_ixs = windowed_dataset.feature_tensor.coords
outcomes_filt = windowed_dataset.outcomes
time_to_idx = windowed_dataset.times_map
# Invert the time -> index map so each time index can be turned back into a date.
idx_to_datetime = {idx: date_utils.from_unixtime([time])[0] for time, idx in time_to_idx.items()}

all_codes_tensor = code_ixs

# Split the entries into one contiguous chunk per patient.
# NOTE(review): the searchsorted-based chunking assumes person_ixs is sorted ascending and, within
# each patient, time_ixs is sorted ascending (a canonically ordered COO tensor) — TODO confirm.
people = sorted(np.unique(person_ixs))
person_indices = np.searchsorted(person_ixs, people)
person_indices = np.append(person_indices, len(person_ixs))
person_chunks = [
    time_ixs[start:stop]
    for start, stop in zip(person_indices[:-1], person_indices[1:])
]

# Within each patient, a "visit" is a distinct time index; record where each visit's entries
# start (relative to the patient's chunk) plus a final end offset, and the visit time indices.
visit_chunks = []
visit_times_raw = []

for chunk in person_chunks:
    visits = sorted(np.unique(chunk))
    visit_indices_local = np.append(np.searchsorted(chunk, visits), len(chunk))
    visit_chunks.append(visit_indices_local)
    visit_times_raw.append(visits)

n_visits = {i: len(visit_times) for i, visit_times in enumerate(visit_times_raw)}

# Days between the end of the training window (the prediction date) and each time index.
# Keyed by time index so every visit time can be looked up. The previous version keyed all
# entries by a leftover loop variable, which left a single entry and mapped almost every visit to None.
training_end = pd.to_datetime(cohort_params['training_end_date'])
visit_days_rel = {
    t_idx: (training_end - pd.to_datetime(idx_to_datetime[t_idx])).days
    for t_idx in np.unique(time_ixs)
}
# A missing time index now raises KeyError instead of silently producing None.
visit_time_rel = [
    np.array([visit_days_rel[t_idx] for t_idx in visit_times])
    for visit_times in visit_times_raw
]

remap = {
    'id': people,
    'time': sorted(np.unique(time_ixs)),
    'concept': sorted(np.unique(code_ixs))
}

dataset_dict = {
    'all_codes_tensor': all_codes_tensor, # A tensor of all codes occurring in the dataset
    'person_indices': person_indices, # A list of indices such that all_codes_tensor[person_indices[i]: person_indices[i+1]] are the codes assigned to the ith patient
    'visit_chunks': visit_chunks, # A list of indices such that all_codes_tensor[person_indices[i]+visit_chunks[j]:person_indices[i]+visit_chunks[j+1]] are the codes assigned to the ith patient during their jth visit
    'visit_time_rel': visit_time_rel, # A list of times (as measured in days to the prediction date) for each visit
    'n_visits': n_visits, # A dict defined such that n_visits[i] is the number of visits made by the ith patient
    'outcomes_filt': outcomes_filt, # A pandas Series defined such that outcomes_filt.iloc[i] is the outcome of the ith patient
    'remap': remap,
}

## Run the windowed regression model on the task defined above

In [None]:
# split data into train, validate and test sets
# The cells below read windowed_dataset.val / windowed_dataset.test (['X'], ['y']) afterwards,
# so this call presumably populates them.
# NOTE(review): split proportions and random seed are whatever `split()` defaults to — confirm
# they are fixed so the notebook is reproducible.
windowed_dataset.split()

In [None]:
# Fit the logistic-regression pipeline once per candidate regularization value and record
# the validation-set ROC AUC of each fit in `lr_val_aucs` (same order as `reg_lambdas`).
# NOTE(review): the value is printed as "C"; confirm whether gen_pipeline expects sklearn's
# inverse regularization strength C or a penalty weight, as the name `reg_lambda` suggests.
reg_lambdas = [2, 0.2, 0.02]
model = OMOPLogisticRegression("eol_new_50", windowed_dataset)
lr_val_aucs = []

for reg_lambda in reg_lambdas:
    model.gen_pipeline(reg_lambda)  # rebuild the pipeline for this regularization value
    model.fit()
    # Column 1 of predict_proba is the positive-class probability (sklearn convention)
    pred_lr = model._pipeline.predict_proba(windowed_dataset.val['X'])[:, 1]
    lr_val_auc = roc_auc_score(windowed_dataset.val['y'], pred_lr)
    lr_val_aucs.append(lr_val_auc)
    print("C: %.4f, Val AUC: %.2f" % (reg_lambda, lr_val_auc))

In [None]:
# Refit the pipeline with the regularization value that scored best on validation,
# then report its ROC AUC on the held-out test split.
best_reg_lambda = reg_lambdas[int(np.argmax(lr_val_aucs))]
model.gen_pipeline(best_reg_lambda)
model.fit()
pred_lr = model._pipeline.predict_proba(windowed_dataset.test['X'])[:, 1]
score = roc_auc_score(windowed_dataset.test['y'], pred_lr)
print("C: %.4f, Test AUC: %.2f" % (best_reg_lambda, score))

### Learn a Word2Vec embedding

In [None]:
%%time
embedding_dim = 300 # size of embedding, must be multiple of number of heads
window_days = 90 # number of days in window that defines a "Sentence" when learning the embedding
train_coords = np.nonzero(np.where(np.isin(person_ixs, indices_train), 1, 0))
embedding_filename = embedding_utils.train_embedding(featureSet, feature_matrix_3d_transpose, window_days, \
                                     person_ixs[train_coords], time_ixs[train_coords], \
                                     remap['time'], embedding_dim)

## Run the SARD deep model on the predictive task
### 1. Set Model Parameters and Construct the Model

In [None]:
# using the same split as before, create train/validate/test batches for the deep model
# `mbsz` might need to be decreased based on the GPU's memory and the number of features being used
mbsz = 50
def get_batches(arr, mbsz=mbsz):
    """Split `arr` into consecutive slices of at most `mbsz` elements.

    Args:
        arr: any sliceable sequence (list, numpy array, ...) of patient indices.
        mbsz: maximum batch size; must be positive.

    Returns:
        A list of slices of `arr` covering every element in order; the last slice may be
        shorter than `mbsz`. An empty input yields an empty list.

    Raises:
        ValueError: if `mbsz` is 0 (from `range`).
    """
    # Stepping over every start offset keeps a trailing single-element batch, which the
    # previous `curr < len(arr) - 1` loop condition silently dropped.
    return [arr[start:start + mbsz] for start in range(0, len(arr), mbsz)]

# Batch the train and test patient indices. The first `val_size // mbsz` test batches form
# the validation set and the remaining ones the test set.
# NOTE(review): `indices_train`, `indices_test` and `val_size` are not defined anywhere in this
# notebook (they appear to come from an older version of the pipeline) — TODO define them.
p_ranges_train = get_batches(indices_train)
p_ranges_test = get_batches(indices_test)
p_ranges_val, p_ranges_test = (
    p_ranges_test[:val_size // mbsz],
    p_ranges_test[val_size // mbsz:],
)

In [None]:
# `mn_prefix` prefixes every saved checkpoint name. The SARD settings below are a starting
# point that performed well on several tasks.
mn_prefix = 'eol_experiment_prefix'
n_heads = 2
assert embedding_dim % n_heads == 0
model_params = dict(
    embedding_dim=embedding_dim // n_heads,  # dimension per head of visit embeddings
    n_heads=n_heads,  # number of self-attention heads
    attn_depth=2,  # number of stacked self-attention layers
    dropout=0.05,  # dropout for both self-attention and the final prediction layer
    use_mask=True,  # let visits attend only to real visits, not padding
    concept_embedding_path=embedding_filename,  # if unspecified, uses default Torch embeddings
)

In [None]:
# Set up fixed model parameters, loss functions, and build the model on the GPU
lr = 2e-4  # Adam learning rate
n_epochs_pretrain = 1  # epochs of reverse-distillation pretraining
ft_epochs = 1  # fine-tuning epochs per alpha value

# The training loops step the optimizer only when `batch_num % update_mod == 0`, so gradients
# from about `update_every` patients (update_every // mbsz mini-batches) accumulate per step.
# Clamp to 1: if mbsz > update_every the integer division gives 0 and the modulo in the
# training loops would raise ZeroDivisionError.
update_every = 500
update_mod = max(1, update_every // mbsz)

# NOTE(review): `featureSet` is not defined anywhere in this notebook (it appears to be a leftover
# from an older omop-learn API) — TODO supply the object VisitTransformer expects.
base_model = visit_transformer.VisitTransformer(
    featureSet, **model_params
)

clf = visit_transformer.VTClassifer(
    base_model, **model_params
).cuda()

# Hand the flattened code / visit index structures built above to the model.
clf.bert.set_data(
    torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),
    dataset_dict['person_indices'], dataset_dict['visit_chunks'],
    dataset_dict['visit_time_rel'], dataset_dict['n_visits']
)

# Summed BCE-with-logits loss. pos_weight = N / #positives - 1 equals #negatives / #positives
# when outcomes are 0/1, up-weighting the positive class to counter imbalance.
loss_function_distill = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.FloatTensor([
        len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1
    ]), reduction='sum'
).cuda()

optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)

def eval_curr_model_on(a):
    """Score the current global `clf` on a collection of patient batches.

    Args:
        a: iterable of patient-index batches (as produced by `get_batches`).

    Returns:
        ROC AUC of `clf`'s raw outputs against `dataset_dict['outcomes_filt']` for those
        patients. Raw outputs are fine because ROC AUC depends only on their ranking.
    """
    outcomes = dataset_dict['outcomes_filt']
    scores, labels = [], []
    with torch.no_grad():  # inference only: no autograd graph is built
        for p_range in a:
            scores.extend(clf(p_range).tolist())
            labels.extend(outcomes.iloc[list(p_range)].values)
    return roc_auc_score(labels, scores)

### 2. Fit the SARD model to the best windowed linear model (Reverse Distillation)

The following code saves model checkpoints under the directory `SavedModels/{task}/`, relative to the notebook's working directory (not the filesystem root); the directory must exist before training starts.

In [None]:
task = 'eol'
# torch.save does not create missing directories, so create the checkpoint folder up front.
os.makedirs(os.path.join("SavedModels", task), exist_ok=True)

In [None]:
# Run `n_epochs_pretrain` of Reverse Distillation pretraining: SARD is trained to match the
# linear model's predictions (`pred_lr_all`) rather than the true outcomes.
# NOTE(review): `pred_lr_all` is never defined in this notebook; it must hold the linear model's
# prediction for every patient, indexable by the patient indices in each batch — TODO.
# `val_losses` collects one validation ROC AUC per epoch (higher is better); the next cell
# reloads the epoch with the largest value.
val_losses = []
progress_bar = IntProgress(min=0, max=int(n_epochs_pretrain * len(p_ranges_train)))
batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)
time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)

display(progress_bar, batch_loss_disp, time_disp)

# torch.save does not create missing directories, so make sure the checkpoint folder exists.
os.makedirs("SavedModels/{task}".format(task=task), exist_ok=True)

for epoch in range(n_epochs_pretrain):
    t, batch_loss = time.time(), 0

    for batch_num, p_range in enumerate(p_ranges_train):

        # Every 50 batches, show the mean loss and wall time of the previous 50, then reset.
        if batch_num % 50 == 0:
            batch_loss_disp.value = round(batch_loss / 50, 2)
            time_disp.value = round(time.time() - t, 2)
            t, batch_loss = time.time(), 0

        y_pred = clf(p_range)
        loss_distill = loss_function_distill(
            y_pred, torch.FloatTensor(pred_lr_all[p_range]).cuda()
        )

        batch_loss += loss_distill.item()
        loss_distill.backward()

        # Gradients from successive batches accumulate until the next optimizer step.
        if batch_num % update_mod == 0:
            optimizer_clf.step()
            optimizer_clf.zero_grad()

        # Count the batch just processed so the bar reaches its maximum at the end.
        progress_bar.value = batch_num + 1 + epoch * len(p_ranges_train)

    torch.save(
        clf.state_dict(),
        "SavedModels/{task}/{mn_prefix}_pretrain_epochs_{epochs}".format(
            task=task, mn_prefix=mn_prefix, epochs=epoch + 1
        )
    )

    # Evaluate on the validation batches in eval mode, then switch back to training mode.
    clf.eval()
    ckpt_auc = eval_curr_model_on(p_ranges_val)
    print('Epochs: {} | Val AUC: {}'.format(epoch + 1, ckpt_auc))
    val_losses.append(ckpt_auc)
    clf.train()

In [None]:
# Reload the pretraining checkpoint with the highest validation AUC and re-save it under the
# `BEST` suffix, so fine-tuning can start from a fixed filename.
clf.load_state_dict(
    torch.load(f"SavedModels/{task}/{mn_prefix}_pretrain_epochs_{np.argmax(val_losses) + 1}")
)
torch.save(clf.state_dict(), f"SavedModels/{task}/{mn_prefix}_pretrain_epochs_BEST")

### 3. Fine-tune the SARD model by training to match the actual outcomes on the training set

In [None]:
# Fine-tuning uses two loss terms, built identically:
#    - `loss_function_distill` penalizes differences between the linear model's prediction and SARD's
#    - `loss_function_clf` penalizes differences between the true outcome and SARD's prediction
# Each is a summed BCE-with-logits loss whose pos_weight (N / #positives - 1) re-balances the classes.
loss_function_distill, loss_function_clf = (
    torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.FloatTensor([
            len(dataset_dict['outcomes_filt']) / dataset_dict['outcomes_filt'].sum() - 1
        ]),
        reduction='sum',
    ).cuda()
    for _ in range(2)
)

In [None]:
# Run `ft_epochs` of fine-tuning for each value of `alpha` below, starting every run from the
# best reverse-distillation checkpoint (or from a fresh model when `no_rd` is True).
# `alpha` is the relative weight of `loss_function_distill` compared to `loss_function_clf`:
#     loss = loss_clf + alpha * loss_distill
# Each epoch's checkpoint is saved and its validation AUC is recorded in `all_pred_models`.
# NOTE(review): `pred_lr_all` and `featureSet` are not defined anywhere in this notebook — TODO.

all_pred_models = {}
alphas = [0, 0.05, 0.1, 0.15, 0.2]
no_rd = False  # True: skip reverse distillation and fine-tune a freshly initialized model

progress_bar = IntProgress(min=0, max=int(ft_epochs * len(p_ranges_train)))
batch_loss_disp = FloatText(value=0.0, description='Avg. Batch Loss for Last 50 Batches', disabled=True)
time_disp = FloatText(value=0.0, description='Time for Last 50 Batches', disabled=True)

display(progress_bar, batch_loss_disp, time_disp)

# torch.save does not create missing directories, so make sure the checkpoint folder exists.
os.makedirs("SavedModels/{task}".format(task=task), exist_ok=True)

for alpha in alphas:

    progress_bar.value = 0

    if no_rd:
        # Start from scratch with a freshly initialized transformer.
        pretrained_model_fn = mn_prefix + '_None'
        base_model = visit_transformer.VisitTransformer(
            featureSet, **model_params
        )
    else:
        pretrained_model_fn = "{mn_prefix}_pretrain_epochs_{epochs}".format(
            mn_prefix=mn_prefix, epochs='BEST'
        )

    clf = visit_transformer.VTClassifer(base_model, **model_params).cuda()
    clf.bert.set_data(
        torch.LongTensor(dataset_dict['all_codes_tensor']).cuda(),
        dataset_dict['person_indices'], dataset_dict['visit_chunks'],
        dataset_dict['visit_time_rel'], dataset_dict['n_visits']
    )

    if not no_rd:
        # Overwrite all of clf's weights (presumably including the reused `base_model`) with the
        # best pretrained checkpoint, so every alpha starts from the same state.
        clf.load_state_dict(torch.load(
            "SavedModels/{task}/{fn}".format(task=task, fn=pretrained_model_fn)
        ))

    clf.train()

    optimizer_clf = torch.optim.Adam(params=clf.parameters(), lr=lr)

    for epoch in range(ft_epochs):

        t, batch_loss = time.time(), 0

        for batch_num, p_range in enumerate(p_ranges_train):

            # Every 50 batches, show the mean loss and wall time of the previous 50, then reset.
            if batch_num % 50 == 0:
                batch_loss_disp.value = round(batch_loss / 50, 2)
                time_disp.value = round(time.time() - t, 2)
                t, batch_loss = time.time(), 0

            y_pred = clf(p_range)

            loss = loss_function_clf(
                y_pred,
                torch.FloatTensor(dataset_dict['outcomes_filt'].values[p_range]).cuda()
            )
            loss_distill = loss_function_distill(
                y_pred,
                torch.FloatTensor(pred_lr_all[p_range]).cuda()
            )

            loss_total = loss + alpha * loss_distill
            batch_loss += loss_total.item()
            loss_total.backward()

            # Gradients from successive batches accumulate until the next optimizer step.
            if batch_num % update_mod == 0:
                optimizer_clf.step()
                optimizer_clf.zero_grad()

            # Count the batch just processed so the bar reaches its maximum at the end.
            progress_bar.value = batch_num + 1 + epoch * len(p_ranges_train)

        saving_fn = "{pretrain}_alpha_{alpha}_epochs_{epochs}".format(
            pretrain=pretrained_model_fn, alpha=alpha, epochs=epoch + 1
        )
        torch.save(
            clf.state_dict(),
            "SavedModels/{task}/{saving_fn}".format(task=task, saving_fn=saving_fn)
        )

        clf.eval()
        val_auc = eval_curr_model_on(p_ranges_val)
        print('alpha: {} | Epochs: {} | Val AUC: {}'.format(alpha, epoch + 1, val_auc))
        all_pred_models[saving_fn] = val_auc
        clf.train()
        clf.train()

### 4. Evaluate the best SARD model, as determined by validation performance

In [None]:
# Reload the fine-tuned checkpoint with the highest validation AUC and report its test-set ROC AUC.
best_model = max(all_pred_models, key=all_pred_models.get)
clf.load_state_dict(torch.load(f"SavedModels/{task}/{best_model}"))
clf.eval()
print(eval_curr_model_on(p_ranges_test))
clf.train();