<a href="https://colab.research.google.com/github/joanna-regan/CS598_DL4H_StageNet/blob/main/JR_FinalProject_ExtraCredit_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reproducibility Study of StageNet:
**Stage-Aware Neural Networks for Health Risk Prediction**

Original Authors: Junyi Gao, Cao Xiao, Yasha Wang, Wen Tang, Lucas M. Glass, and Jimeng Sun

Reproduced by: Joanna Regan



---



---



### Reproducibility Summary:


This report attempts to reproduce StageNet, a model composed of a stage-aware LSTM module and a stage-adaptive convolution module. StageNet aims to identify the points in time where a patient's health status changes rapidly (i.e., enters a new stage) and to learn the patterns within each stage dynamically, in order to predict the patient's health risk.

We also reproduce two reduced variants of StageNet: StageNet-I replaces the stage-aware LSTM with a vanilla LSTM, and StageNet-II keeps the stage-aware LSTM but removes the stage-adaptive convolution module.

We assess two claims: that StageNet trained on MIMIC-III EHR data achieves 10\% higher AUPRC and min(Re, P+) than baseline models on the decompensation risk prediction task, and that the reduced models StageNet-I and StageNet-II still achieve higher AUPRC, AUROC, and min(Re, P+) than all baseline models on the same task.

We were unable to match the results reported in the original work: our reproduced StageNet reached an AUPRC of 0.228 on the full test set, compared with the 0.323 reported in the paper, although AUROC was closer (0.874 versus 0.903). This study trained on only a small subset of the data, which likely contributed to the lower performance.
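
For readers unfamiliar with min(Re, P+), it is commonly computed as the largest value of min(precision, recall) over all decision thresholds. The sketch below is a minimal pure-Python illustration under that assumption. The function name `min_re_pplus` and the toy labels and scores are hypothetical; the numbers reported in this notebook come from `metrics.print_metrics_binary`, not from this code.

```python
def min_re_pplus(y_true, y_score):
    """Largest min(precision, recall) over all score thresholds (illustrative)."""
    best = 0.0
    for threshold in sorted(set(y_score)):
        predicted = [s >= threshold for s in y_score]
        tp = sum(1 for p, t in zip(predicted, y_true) if p and t == 1)
        fp = sum(1 for p, t in zip(predicted, y_true) if p and t == 0)
        fn = sum(1 for p, t in zip(predicted, y_true) if not p and t == 1)
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        best = max(best, min(precision, recall))
    return best


# Hypothetical toy data, not taken from MIMIC-III.
toy_labels = [0, 0, 1, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8, 0.65, 0.9]
print(min_re_pplus(toy_labels, toy_scores))
```

Because the metric rewards a threshold at which precision and recall are both high, it is less forgiving than accuracy when positives are rare.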



---



---



### Data Statistics



The dataset is not provided with this notebook; it must be obtained from PhysioNet at:

https://physionet.org/content/mimiciii/1.4/

Data must then be preprocessed for the decompensation task according to:

https://github.com/YerevaNN/mimic3-benchmarks/

The preprocessed samples should be placed in ./data/train_subdivided and ./data/test to reproduce the results shown below.
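
Before running the rest of the notebook, it can help to confirm that the data directory has the layout the later cells expect. The sketch below is one way to do that. The list of entries is inferred from the paths used further down in this notebook, and `find_missing_entries` is a hypothetical helper, not part of the StageNet code.

```python
import os

# Entries read by later cells, relative to the data directory.
EXPECTED_ENTRIES = [
    "train_subdivided",
    "test",
    "train_listfile.csv",
    "val_listfile.csv",
    "test_listfile.csv",
    "decomp_normalizer",
]


def find_missing_entries(data_path):
    """Return the expected files/directories that are absent under data_path."""
    return [name for name in EXPECTED_ENTRIES
            if not os.path.exists(os.path.join(data_path, name))]


missing = find_missing_entries("./data/")
if missing:
    print("Missing under ./data/: " + ", ".join(missing))
else:
    print("All expected data entries found.")
```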



We first import the required packages and fix the random seeds:

In [14]:
import numpy as np
import argparse
import os
import re
import pickle
import random
import matplotlib.pyplot as plt
import matplotlib as mpl
from time import perf_counter
import datetime as dt
from datetime import datetime

RANDOM_SEED = 12345
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F

torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic=True

from utils import utils
from utils.readers import DecompensationReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils
from model import StageNet

Confirm we're using GPU runtime:

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("available device: {}".format(device))

available device: cuda:0


Define and load arguments:

In [4]:
def parse_arguments(parser):
    parser.add_argument('--test_mode', type=int, default=0, help='Test SA-CRNN on MIMIC-III dataset')
    parser.add_argument('--data_path', type=str, metavar='<data_path>', help='The path to the MIMIC-III data directory')
    parser.add_argument('--file_name', type=str, metavar='<file_name>', help='File name to save model')
    parser.add_argument('--small_part', type=int, default=0, help='Use part of training data')
    parser.add_argument('--batch_size', type=int, default=128, help='Training batch size')
    parser.add_argument('--epochs', type=int, default=50, help='Training epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')

    parser.add_argument('--input_dim', type=int, default=76, help='Dimension of visit record data')
    parser.add_argument('--rnn_dim', type=int, default=384, help='Dimension of hidden units in RNN')
    parser.add_argument('--output_dim', type=int, default=1, help='Dimension of prediction target')
    parser.add_argument('--dropout_rate', type=float, default=0.5, help='Dropout rate')
    parser.add_argument('--dropconnect_rate', type=float, default=0.5, help='Dropout rate in RNN')
    parser.add_argument('--dropres_rate', type=float, default=0.3, help='Dropout rate in residue connection')
    parser.add_argument('--K', type=int, default=10, help='Value of hyper-parameter K')
    parser.add_argument('--chunk_level', type=int, default=3, help='Value of hyper-parameter chunk_level')

    parser.add_argument('-f')

    args = parser.parse_args()
    return args

parser = argparse.ArgumentParser()
args = parse_arguments(parser)
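
The dummy `-f` argument above is there because a Jupyter or Colab kernel is started with a `-f <connection file>` flag on its own command line, which `argparse` would otherwise reject. An alternative, sketched below, is to pass an explicit argument list so that `sys.argv` is never consulted. `build_parser` is a hypothetical helper that repeats only a few of the options above for brevity.

```python
import argparse


def build_parser():
    # Hypothetical helper: a reduced copy of the notebook's argument list.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50, help='Training epochs')
    parser.add_argument('--small_part', type=int, default=0, help='Use part of training data')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    return parser


# args=[] makes argparse ignore sys.argv, so the kernel's own flags never reach it.
default_args = build_parser().parse_args(args=[])
print(default_args.epochs, default_args.small_part, default_args.lr)

# Overrides can still be given explicitly, as on a command line.
demo_args = build_parser().parse_args(['--epochs', '5', '--small_part', '1000'])
print(demo_args.epochs, demo_args.small_part)
```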

Next we override a few arguments for this run. We use a small sample of the data for illustrative purposes:

In [5]:
#JR add in my paths:
args.data_path = './data/'
args.file_name = 'trained_model'
args.epochs = 5
args.small_part = 1000

We now load the training, validation, and test sets and report some statistics.

The next cell should take only a few minutes to run. If it runs for more than 5 minutes, we recommend stopping the execution and rerunning the cell.

The remaining cells in this section are optional and report additional statistics.

In [8]:
print('Preparing data sets ... ')
start_time = perf_counter()

train_data_loader = common_utils.DeepSupervisionDataLoader(dataset_dir=os.path.join(args.data_path, 'train_subdivided'), 
                                                                   listfile=os.path.join(args.data_path, 'train_listfile.csv'),
                                                                   small_part=args.small_part)
timer1 = perf_counter()
val_data_loader = common_utils.DeepSupervisionDataLoader(dataset_dir=os.path.join(args.data_path, 'train_subdivided'), 
                                                                 listfile=os.path.join(args.data_path, 'val_listfile.csv'),
                                                                 small_part=args.small_part)
timer2 = perf_counter()

test_data_loader = common_utils.DeepSupervisionDataLoader(dataset_dir=os.path.join(args.data_path, 'test'), 
                                                          listfile=os.path.join(args.data_path, 'test_listfile.csv'), 
                                                          small_part=args.small_part)
end_data_loads = perf_counter()

print("Time to load train data: " + str(dt.timedelta(seconds = timer1 - start_time)))
print("Time to load validation data: " + str(dt.timedelta(seconds = timer2 - timer1))) #time to generate the val data loader
print("Time to load test data: " + str(dt.timedelta(seconds = end_data_loads - timer2)))

print("Size of training set: " + str(len(train_data_loader._data["X"])))
print("Size of validation set: " + str(len(val_data_loader._data["X"])))
print("Size of test set: " + str(len(test_data_loader._data["X"])))

Preparing data sets ... 
Generating data...
Generating data...
Generating data...
Time to load train data: 0:00:20.314760
Time to load validation data: 0:00:01.474405
Time to load test data: 0:00:02.427864
Size of training set: 1000
Size of validation set: 112
Size of test set: 278


In [9]:
import os
import numpy as np
from collections import Counter

my_subject_ids = []
my_length_of_stays = []
test_labels = []
val_labels = []
train_labels = []

for i in range(len(test_data_loader._data["X"])):
    my_subject_ids.append(test_data_loader._data["name"][i].split('_')[0])
    my_length_of_stays.append(test_data_loader._data["ts"][i][-1])
    test_labels.append(test_data_loader._data["ys"][i])
for i in range(len(val_data_loader._data["X"])):
    my_subject_ids.append(val_data_loader._data["name"][i].split('_')[0])
    my_length_of_stays.append(val_data_loader._data["ts"][i][-1])
    val_labels.append(val_data_loader._data["ys"][i])
for i in range(len(train_data_loader._data["X"])):
    my_subject_ids.append(train_data_loader._data["name"][i].split('_')[0])
    my_length_of_stays.append(train_data_loader._data["ts"][i][-1])
    train_labels.append(train_data_loader._data["ys"][i])

print("Total number of ICU stays: " + str(len(my_subject_ids)))
print("Total number of patients: " + str(len(set(my_subject_ids))))
print("--- max number of stays per patient: " + str(max(Counter(my_subject_ids).values())))
print("--- max length of stay: " + str(max(my_length_of_stays)))
#print("--- max length of stay in days: " + str(max(my_length_of_stays)/24))
print("--- min number of stays per patient: " + str(min(Counter(my_subject_ids).values())))
print("--- min length of stay: " + str(min(my_length_of_stays)))
print("--- average number of stays per patient: " + str(sum(Counter(my_subject_ids).values()) / len(Counter(my_subject_ids))))
print("--- average length of stay: " + str(sum(my_length_of_stays) / len(my_length_of_stays)))

Total number of ICU stays: 1390
Total number of patients: 1128
--- max number of stays per patient: 16
--- max length of stay: 1031.0
--- min number of stays per patient: 1
--- min length of stay: 5.0
--- average number of stays per patient: 1.2322695035460993
--- average length of stay: 72.2158273381295
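
The statistics above rely on two conventions: the subject ID is the part of each stay name before the first underscore, and the last timestamp of a stay is taken as its length. The sketch below repeats the same bookkeeping on toy data. The stay names and timestamps are hypothetical, and the unit is assumed to be hours, as the commented-out division by 24 in the cell above suggests.

```python
from collections import Counter

# Hypothetical stay names; the real ones come from the listfiles.
toy_names = [
    "10011_episode1_timeseries.csv",
    "10011_episode2_timeseries.csv",
    "10026_episode1_timeseries.csv",
]
toy_last_timestamps = [48.0, 120.0, 72.0]  # assumed to be hours

# Count stays per patient by grouping on the subject-ID prefix.
stays_per_patient = Counter(name.split('_')[0] for name in toy_names)
mean_stays = sum(stays_per_patient.values()) / len(stays_per_patient)

# Convert the mean length of stay from hours to days.
mean_days = sum(toy_last_timestamps) / len(toy_last_timestamps) / 24

print(dict(stays_per_patient))
print(mean_stays, mean_days)
```

Under the same assumption, the average length of stay of about 72.2 printed above corresponds to roughly three days.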


In [10]:
#Breakdowns of train/test/val
total_test_visits = 0
total_pos_test_visits = 0
total_val_visits = 0
total_pos_val_visits = 0
total_train_visits = 0
total_pos_train_visits = 0

for i in range(len(test_labels)):
    test_labels[i] = [int(x) for x in test_labels[i]]
    total_test_visits += len(test_labels[i])
    total_pos_test_visits += sum(test_labels[i])
for i in range(len(val_labels)):
    val_labels[i] = [int(x) for x in val_labels[i]]
    total_val_visits += len(val_labels[i])
    total_pos_val_visits += sum(val_labels[i])
for i in range(len(train_labels)):
    train_labels[i] = [int(x) for x in train_labels[i]]
    total_train_visits += len(train_labels[i])
    total_pos_train_visits += sum(train_labels[i])

total_visits = total_test_visits + total_val_visits + total_train_visits
total_pos_visits = total_pos_test_visits + total_pos_val_visits + total_pos_train_visits

print("Total number of visits (loaded subset): " + str(total_visits) + ", Positive samples: " + str(total_pos_visits))
 
print("--Train--")
print("Number of stays: " + str(len(train_data_loader._data["X"])) + ", Number of visits: " + str(total_train_visits) + ", Number positive visits: " + str(total_pos_train_visits))
print("--Validation--")
print("Number of stays: " + str(len(val_data_loader._data["X"])) + ", Number of visits: " + str(total_val_visits) + ", Number positive visits: " + str(total_pos_val_visits))
print("--Test--")
print("Number of stays: " + str(len(test_data_loader._data["X"])) + ", Number of visits: " + str(total_test_visits) + ", Number positive visits: " + str(total_pos_test_visits))

Total number of visits (loaded subset): 94819, Positive samples: 1946
--Train--
Number of stays: 1000, Number of visits: 67991, Number positive visits: 1607
--Validation--
Number of stays: 112, Number of visits: 7986, Number positive visits: 44
--Test--
Number of stays: 278, Number of visits: 18842, Number positive visits: 295
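
The counts above imply a strong class imbalance. The short calculation below turns them into positive rates per split; the numbers are copied from the output above, and only the variable names are new.

```python
# (total visits, positive visits) per split, as printed above.
split_counts = {
    "train": (67991, 1607),
    "validation": (7986, 44),
    "test": (18842, 295),
}

positive_rates = {name: pos / total for name, (total, pos) in split_counts.items()}
overall_rate = (sum(pos for _, pos in split_counts.values())
                / sum(total for total, _ in split_counts.values()))

for name, rate in positive_rates.items():
    print(f"{name}: {rate:.2%} positive")
print(f"overall: {overall_rate:.2%} positive")
```

Roughly 2% of all visits are positive, and fewer than 1% in the validation subset. With so few positives, a model can score a high accuracy while detecting none of them, so AUPRC and min(Re, P+) are likely to be more informative here.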




---



---



### Methodology Explanation and Examples

In this study we run the following experiments:

1.   reproducing StageNet model
  * evaluating against full test set
  * evaluating against test subset
2.   reproducing StageNet-I model
  * evaluating against full test set
  * evaluating against test subset
3.   reproducing StageNet-II model
  * evaluating against full test set
  * evaluating against test subset

Below we provide code for experiment 1b, in which we train a StageNet model and evaluate it against the test subset loaded earlier. Note that the results will not exactly match those reported in the Results section, because for illustrative purposes we use only n = 1000 samples and train for 5 epochs. To recreate the reported results, modify the earlier arguments as follows:

```
args.epochs = 50
args.small_part = 5000
```


##### Sample code

Instantiate Discretizer and Normalizer objects:

In [11]:
discretizer = Discretizer(timestep=1.0, store_masks=True,
                                impute_strategy='previous', start_time='zero')

discretizer_header = discretizer.transform(train_data_loader._data["X"][0])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)
normalizer_state = 'decomp_normalizer'
normalizer_state = os.path.join(os.path.dirname(args.data_path), normalizer_state)
normalizer.load_params(normalizer_state)
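
The `cont_channels` line above keeps only the columns whose header name does not contain `->`, so the `Normalizer` is applied to continuous channels only. The sketch below shows the same filter on a made-up header. The column names are hypothetical stand-ins; the real header is the second element returned by `discretizer.transform`.

```python
# Hypothetical header fragment, for illustration only.
toy_header = ("Capillary refill rate->0.0,Capillary refill rate->1.0,"
              "Diastolic blood pressure,Glucose,mask->Glucose")

toy_columns = toy_header.split(',')

# Columns without "->" are treated as continuous and would be normalized.
toy_cont_channels = [i for i, name in enumerate(toy_columns) if "->" not in name]

print(toy_cont_channels)
print([toy_columns[i] for i in toy_cont_channels])
```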

Load the Batched Generators:

On a local machine, this took about 7 minutes for the training set and 1 minute for the validation set.
On Google Colab, it should take only about 30 seconds for small_part = 5000.

In [12]:
print("Preparing Batched Data Loaders")

start_time = perf_counter()
train_data_gen = utils.BatchGenDeepSupervision(train_data_loader, 
                                               discretizer, normalizer, 
                                               args.batch_size, 
                                               shuffle=True, 
                                               return_names=True)
timer1 = perf_counter()
val_data_gen = utils.BatchGenDeepSupervision(val_data_loader, 
                                             discretizer, 
                                             normalizer, 
                                             args.batch_size, 
                                             shuffle=False, 
                                             return_names=True)
end_time = perf_counter()
        
print("Time to load train generator: " + str(dt.timedelta(seconds = timer1 - start_time))) #time to generate the train data gen
print("Time to load validation generator: " + str(dt.timedelta(seconds = end_time - timer1))) #time to generate the val data gen

Preparing Batched Data Loaders
Time to load train generator: 0:00:05.019320
Time to load validation generator: 0:00:00.578053


Model Construction:

In [15]:
print('Constructing model ... ')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("available device: {}".format(device))

model = StageNet(args.input_dim+17, args.rnn_dim, args.K, args.output_dim, args.chunk_level, args.dropconnect_rate, args.dropout_rate, args.dropres_rate).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


Constructing model ... 
available device: cuda:0
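
The training cell that follows appends 17 extra features to every timestep. They are derived from the last 17 input columns, which the loop treats as observation indicators (1 = observed). The sketch below mirrors that loop in plain Python for a single indicator column so the arithmetic is easy to follow. `intervals_since_last_observation` is a hypothetical name, and the notebook's torch version operates on whole batches at once.

```python
def intervals_since_last_observation(observed):
    """For one 0/1 indicator sequence, mirror the notebook's interval loop."""
    gap = 0  # plays the role of `tmp` in the notebook
    intervals = []
    for flag in observed:
        if flag == 0:
            gap += 1                  # tmp += (cur_ind == 0)
        intervals.append(flag * gap)  # interval[:, i, :] = cur_ind * tmp
        if flag == 1:
            gap = 0                   # tmp[cur_ind == 1] = 0
    return intervals


toy_indicator = [1, 0, 0, 1, 0, 1]
print(intervals_since_last_observation(toy_indicator))
```

For this toy sequence the result is `[0, 0, 0, 2, 0, 1]`: at an observed step the feature holds the number of unobserved steps since the previous observation, and at unobserved steps it is zero.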


Training:

In [16]:
print('Start training ... ')

train_loss = []
val_loss = []
batch_loss = []
max_auprc = 0
val_auroc = []
val_auprc = []
acc = []
prec0 = []
prec1 = []
rec0 = []
rec1 = []
minpse = []

my_epoch_times = []

os.makedirs('./saved_weights', exist_ok=True)  # torch.save fails if the directory is missing
file_name = './saved_weights/' + args.file_name + datetime.now().strftime("%d%m%Y_%H_%M")
for each_chunk in range(args.epochs):
    epoch_starttime = perf_counter()
    cur_batch_loss = []
    model.train()
    for each_batch in range(train_data_gen.steps):
        batch_data = next(train_data_gen)
        batch_name = batch_data['names']
        batch_data = batch_data['data']

        batch_x = torch.tensor(batch_data[0][0], dtype=torch.float32).to(device)
        batch_mask = torch.tensor(batch_data[0][1], dtype=torch.float32).unsqueeze(-1).to(device)
        batch_y = torch.tensor(batch_data[1], dtype=torch.float32).to(device)
        tmp = torch.zeros(batch_x.size(0),17, dtype=torch.float32).to(device)
        batch_interval = torch.zeros((batch_x.size(0),batch_x.size(1),17), dtype=torch.float32).to(device)
        
        for i in range(batch_x.size(1)):
            cur_ind = batch_x[:,i,-17:]
            tmp+=(cur_ind == 0).float()
            batch_interval[:, i, :] = cur_ind * tmp
            tmp[cur_ind==1] = 0        
        
        if batch_mask.size()[1] > 400:
            batch_x = batch_x[:, :400, :]
            batch_mask = batch_mask[:, :400, :]
            batch_y = batch_y[:, :400, :]
            batch_interval = batch_interval[:, :400, :]

        batch_x = torch.cat((batch_x, batch_interval), dim=-1)
        batch_time = torch.ones((batch_x.size(0), batch_x.size(1)), dtype=torch.float32).to(device)

        optimizer.zero_grad()
        cur_output, _ = model(batch_x, batch_time, device)
        masked_output = cur_output * batch_mask 
        loss = batch_y * torch.log(masked_output + 1e-7) + (1 - batch_y) * torch.log(1 - masked_output + 1e-7)
        loss = torch.sum(loss, dim=1) / torch.sum(batch_mask, dim=1)
        loss = torch.neg(torch.sum(loss))
        cur_batch_loss.append(loss.cpu().detach().numpy())

        loss.backward()
        optimizer.step()
        
        if each_batch % 50 == 0:
            print('Chunk %d, Batch %d: Loss = %.4f'%(each_chunk, each_batch, cur_batch_loss[-1]))

    batch_loss.append(cur_batch_loss)
    train_loss.append(np.mean(np.array(cur_batch_loss)))
    
    epoch_endtime = perf_counter()
    my_epoch_times.append(epoch_endtime - epoch_starttime)

    print('Epoch training time: ' + str(dt.timedelta(seconds = epoch_endtime - epoch_starttime)))
    
    print("\n==>Predicting on validation")
    with torch.no_grad():
        model.eval()
        cur_val_loss = []
        valid_true = []
        valid_pred = []
        for each_batch in range(val_data_gen.steps):
            valid_data = next(val_data_gen)
            valid_name = valid_data['names']
            valid_data = valid_data['data']
            
            valid_x = torch.tensor(valid_data[0][0], dtype=torch.float32).to(device)
            valid_mask = torch.tensor(valid_data[0][1], dtype=torch.float32).unsqueeze(-1).to(device)
            valid_y = torch.tensor(valid_data[1], dtype=torch.float32).to(device)
            tmp = torch.zeros(valid_x.size(0),17, dtype=torch.float32).to(device)
            valid_interval = torch.zeros((valid_x.size(0),valid_x.size(1),17), dtype=torch.float32).to(device)
            
            for i in range(valid_x.size(1)):
                cur_ind = valid_x[:,i,-17:]
                tmp+=(cur_ind == 0).float()
                valid_interval[:, i, :] = cur_ind * tmp
                tmp[cur_ind==1] = 0  
            
            if valid_mask.size()[1] > 400:
                valid_x = valid_x[:, :400, :]
                valid_mask = valid_mask[:, :400, :]
                valid_y = valid_y[:, :400, :]
                valid_interval = valid_interval[:, :400, :]
            
            valid_x = torch.cat((valid_x, valid_interval), dim=-1)
            valid_time = torch.ones((valid_x.size(0), valid_x.size(1)), dtype=torch.float32).to(device)
            
            valid_output, valid_dis = model(valid_x, valid_time, device)
            masked_valid_output = valid_output * valid_mask

            valid_loss = valid_y * torch.log(masked_valid_output + 1e-7) + (1 - valid_y) * torch.log(1 - masked_valid_output + 1e-7)
            valid_loss = torch.sum(valid_loss, dim=1) / torch.sum(valid_mask, dim=1)
            valid_loss = torch.neg(torch.sum(valid_loss))
            cur_val_loss.append(valid_loss.cpu().detach().numpy())

            for m, t, p in zip(valid_mask.cpu().numpy().flatten(), valid_y.cpu().numpy().flatten(), valid_output.cpu().detach().numpy().flatten()):
                if np.equal(m, 1):
                    valid_true.append(t)
                    valid_pred.append(p)

        val_loss.append(np.mean(np.array(cur_val_loss)))
        print('Valid loss = %.4f'%(val_loss[-1]))
        print('\n')
        valid_pred = np.array(valid_pred)
        valid_pred = np.stack([1 - valid_pred, valid_pred], axis=1)
        ret = metrics.print_metrics_binary(valid_true, valid_pred)
        print()

        cur_auprc = ret['auprc']

        val_auprc.append(cur_auprc)
        val_auroc.append(ret['auroc'])
        acc.append(ret['acc'])
        prec0.append(ret['prec0'])
        prec1.append(ret['prec1'])
        rec0.append(ret['rec0'])
        rec1.append(ret['rec1'])
        minpse.append(ret['minpse'])
        
        if cur_auprc > max_auprc:
            max_auprc = cur_auprc
            state = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'chunk': each_chunk
            }
            torch.save(state, file_name)
            print('\n------------ Save best model ------------\n')

Start training ... 
Chunk 0, Batch 0: Loss = 89.4580
Epoch training time: 0:00:21.096546

==>Predicting on validation
Valid loss = 9.1935


confusion matrix:
[[7672    0]
 [  44    0]]


accuracy = 0.9942975640296936
precision class 0 = 0.9942975640296936
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.8024812778462413
AUC of PRC = 0.013002167467328817
min(+P, Se) = 0.018509254627313655


------------ Save best model ------------

Chunk 1, Batch 0: Loss = 27.6849
Epoch training time: 0:00:13.767411

==>Predicting on validation
Valid loss = 6.0051


confusion matrix:
[[7672    0]
 [  44    0]]
accuracy = 0.9942975640296936
precision class 0 = 0.9942975640296936
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.9570071807754289
AUC of PRC = 0.05887201537875261
min(+P, Se) = 0.07494145199063232


------------ Save best model ------------



Chunk 2, Batch 0: Loss = 21.5394
Epoch training time: 0:00:13.991198

==>Predicting on validation
Valid loss = 5.0819


confusion matrix:
[[7669    3]
 [  44    0]]
accuracy = 0.9939087629318237
precision class 0 = 0.9942953586578369
precision class 1 = 0.0
recall class 0 = 0.9996089935302734
recall class 1 = 0.0
AUC of ROC = 0.9691173333965304
AUC of PRC = 0.0787611277519728
min(+P, Se) = 0.09803921568627451


------------ Save best model ------------

Chunk 3, Batch 0: Loss = 22.7054
Epoch training time: 0:00:13.363892

==>Predicting on validation
Valid loss = 4.0893


confusion matrix:
[[7653   19]
 [  44    0]]
accuracy = 0.9918351769447327
precision class 0 = 0.9942834973335266
precision class 1 = 0.0
recall class 0 = 0.9975234866142273
recall class 1 = 0.0
AUC of ROC = 0.9735105223243911
AUC of PRC = 0.09709196786107609
min(+P, Se) = 0.14166666666666666


------------ Save best model ------------

Chunk 4, Batch 0: Loss = 20.0406
Epoch training time: 0:00:13.964875

==>Predicting
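
The loss used in the training loop above is a binary cross-entropy that is averaged over each stay's valid timesteps and then summed over the batch. The sketch below re-expresses the term for a single stay in plain Python. It assumes, as the loop appears to, that padded positions carry mask 0 and label 0, so zeroing the prediction there makes their contribution negligible. `masked_bce_for_stay` and the toy values are hypothetical.

```python
import math

EPS = 1e-7  # same stabiliser as in the notebook's loss


def masked_bce_for_stay(probs, labels, mask):
    """Negative log-likelihood of one stay, averaged over its unmasked timesteps."""
    total = 0.0
    for p, y, m in zip(probs, labels, mask):
        p_masked = p * m  # padded steps are forced to probability 0
        total += y * math.log(p_masked + EPS) + (1 - y) * math.log(1 - p_masked + EPS)
    return -total / sum(mask)


# Three real timesteps followed by two padded ones.
toy_probs = [0.1, 0.2, 0.9, 0.5, 0.5]
toy_labels = [0, 0, 1, 0, 0]
toy_mask = [1, 1, 1, 0, 0]

print(masked_bce_for_stay(toy_probs, toy_labels, toy_mask))
```

The batch loss in the notebook is the sum of this per-stay quantity over all stays in the batch, which is why the printed training losses scale with the batch size.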



---



Evaluate on Test Data

Load the last saved checkpoint (the model with the best validation AUPRC so far):

In [17]:
checkpoint = torch.load(file_name)
save_chunk = checkpoint['chunk']
print("last saved model is in chunk {}".format(save_chunk))
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

last saved model is in chunk 4


StageNet(
  (kernel): Linear(in_features=94, out_features=1542, bias=True)
  (recurrent_kernel): Linear(in_features=385, out_features=1542, bias=True)
  (nn_scale): Linear(in_features=384, out_features=64, bias=True)
  (nn_rescale): Linear(in_features=64, out_features=384, bias=True)
  (nn_conv): Conv1d(384, 384, kernel_size=(10,), stride=(1,))
  (nn_output): Linear(in_features=384, out_features=1, bias=True)
  (nn_dropconnect): Dropout(p=0.5, inplace=False)
  (nn_dropconnect_r): Dropout(p=0.5, inplace=False)
  (nn_dropout): Dropout(p=0.5, inplace=False)
  (nn_dropres): Dropout(p=0.3, inplace=False)
)

Load test data.

In [18]:
start_data_loads = perf_counter()
test_data_gen = utils.BatchGenDeepSupervision(test_data_loader, discretizer, normalizer, args.batch_size, shuffle=False, return_names=True)
end_data_loads = perf_counter()

print("Time to load test generator: " + str(dt.timedelta(seconds = end_data_loads - start_data_loads)))

Time to load test generator: 0:00:01.348980


Test the model:

In [19]:
print('Testing model ... ')
print('Checkpoint to be loaded: ')
print(file_name)

checkpoint = torch.load(file_name)
save_chunk = checkpoint['chunk']
print("last saved model is in chunk {}".format(save_chunk))
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

start_time = perf_counter()

test_data_loader = common_utils.DeepSupervisionDataLoader(dataset_dir=os.path.join(args.data_path, 'test'),
                                                          listfile=os.path.join(args.data_path, 'test_listfile.csv'),
                                                          small_part=args.small_part)


timer1 = perf_counter()
test_data_gen = utils.BatchGenDeepSupervision(test_data_loader, discretizer,
                                            normalizer, args.batch_size,
                                            shuffle=False, return_names=True)
end_time = perf_counter()

print("Time to load test data: " + str(dt.timedelta(seconds = timer1 - start_time)))
print("Time to load test generator: " + str(dt.timedelta(seconds = end_time - timer1)))
print("Size of test set: " + str(len(test_data_loader._data["X"])))

with torch.no_grad():
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

    cur_test_loss = []
    test_true = []
    test_pred = []
    
    for each_batch in range(test_data_gen.steps):
        test_data = next(test_data_gen)
        test_name = test_data['names']
        test_data = test_data['data']

        test_x = torch.tensor(test_data[0][0], dtype=torch.float32).to(device)
        test_mask = torch.tensor(test_data[0][1], dtype=torch.float32).unsqueeze(-1).to(device)
        test_y = torch.tensor(test_data[1], dtype=torch.float32).to(device)
        tmp = torch.zeros(test_x.size(0),17, dtype=torch.float32).to(device)
        test_interval = torch.zeros((test_x.size(0),test_x.size(1),17), dtype=torch.float32).to(device)

        for i in range(test_x.size(1)):
            cur_ind = test_x[:,i,-17:]
            tmp+=(cur_ind == 0).float()
            test_interval[:, i, :] = cur_ind * tmp
            tmp[cur_ind==1] = 0  
        
        if test_mask.size()[1] > 400:
            test_x = test_x[:, :400, :]
            test_mask = test_mask[:, :400, :]
            test_y = test_y[:, :400, :]
            test_interval = test_interval[:, :400, :]
        
        test_x = torch.cat((test_x, test_interval), dim=-1)
        test_time = torch.ones((test_x.size(0), test_x.size(1)), dtype=torch.float32).to(device)
        
        test_output, test_dis = model(test_x, test_time, device)
        masked_test_output = test_output * test_mask

        test_loss = test_y * torch.log(masked_test_output + 1e-7) + (1 - test_y) * torch.log(1 - masked_test_output + 1e-7)
        test_loss = torch.sum(test_loss, dim=1) / torch.sum(test_mask, dim=1)
        test_loss = torch.neg(torch.sum(test_loss))
        cur_test_loss.append(test_loss.cpu().detach().numpy()) 
        
        for m, t, p in zip(test_mask.cpu().numpy().flatten(), test_y.cpu().numpy().flatten(), test_output.cpu().detach().numpy().flatten()):
            if np.equal(m, 1):
                test_true.append(t)
                test_pred.append(p)
    
    print('Test loss = %.4f'%(np.mean(np.array(cur_test_loss))))
    print('\n')
    test_pred = np.array(test_pred)
    test_pred = np.stack([1 - test_pred, test_pred], axis=1)
    test_ret = metrics.print_metrics_binary(test_true, test_pred)

Testing model ... 
Checkpoint to be loaded: 
./saved_weights/trained_model09052023_02_16
last saved model is in chunk 4
Generating data...
Time to load test data: 0:00:02.942288
Time to load test generator: 0:00:01.339803
Size of test set: 278
Test loss = 7.5791


confusion matrix:
[[18307    85]
 [  275    20]]
accuracy = 0.9807352423667908
precision class 0 = 0.9852007031440735
precision class 1 = 0.190476194024086
recall class 0 = 0.9953784346580505
recall class 1 = 0.06779661029577255
AUC of ROC = 0.7428179901357259
AUC of PRC = 0.09318913646989999
min(+P, Se) = 0.1282051282051282




---



---



### Results

First, we tabulate the results of the sample experiment run in this notebook:

| AUPRC  | AUROC  | min(Re, P+) |
|--------|--------|-------------|
| 0.0932 | 0.7428 | 0.1282      |

Next, we provide the results from all experiments. For comparison, we include the values reported in the original paper ("Original"), including those for two of its baseline models, next to our own results on the full test set ("Full") and the test subset ("Subset"):

| Model                  | AUPRC (Original) | AUPRC (Full) | AUPRC (Subset) | AUROC (Original) | AUROC (Full) | AUROC (Subset) | min(Re,P+) (Original) | min(Re,P+) (Full) | min(Re,P+) (Subset) |
|------------------------|:----------------:|:------------:|:--------------:|:----------------:|:------------:|:--------------:|:---------------------:|:-----------------:|:-------------------:|
| Baseline1: ON-LSTM     | 0.304            | ---          | ---            | 0.895            | ---          | ---            | 0.343                 | ---               | ---                 |
| Baseline2: Health-LSTM | 0.291            | ---          | ---            | 0.897            | ---          | ---            | 0.325                 | ---               | ---                 |
| Pre-Trained StageNet   | 0.323            | 0.341        | 0.289          | 0.903            | 0.909        | 0.890          | 0.372                 | 0.390             | 0.347               |
| Reproduced StageNet    | 0.323            | 0.228        | 0.206          | 0.903            | 0.874        | 0.842          | 0.372                 | 0.292             | 0.280               |
| Ablation1: StageNet-I  | 0.313            | 0.226        | 0.209          | 0.899            | 0.850        | 0.838          | 0.360                 | 0.315             | 0.279               |
| Ablation2: StageNet-II | 0.311            | 0.220        | 0.211          | 0.897            | 0.872        | 0.844          | 0.358                 | 0.287             | 0.280               |
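
To put the gap between the reproduced and original StageNet in relative terms, the short calculation below uses the "Original" and "Full" values from the "Reproduced StageNet" row of the table above. Only the variable names are new.

```python
# (original paper, reproduced on the full test set), from the table above.
stagenet_results = {
    "AUPRC": (0.323, 0.228),
    "AUROC": (0.903, 0.874),
    "min(Re,P+)": (0.372, 0.292),
}

relative_gap = {
    metric: (reproduced - original) / original
    for metric, (original, reproduced) in stagenet_results.items()
}

for metric, gap in relative_gap.items():
    print(f"{metric}: {gap:+.1%} relative to the original paper")
```

The relative shortfall is much larger for AUPRC and min(Re, P+) than for AUROC, which fits the observation that the reproduction is closest to the original on AUROC.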

---

---





### References

Inci M. Baytas, Cao Xiao, Xi Zhang, Fei Wang, Anil K. Jain, and Jiayu Zhou. 2017. Patient subtyping via time-aware LSTM networks. Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 65–74.

Junyi Gao, Cao Xiao, Yasha Wang, Wen Tang, Lucas M. Glass, and Jimeng Sun. 2020. StageNet: Stage-aware neural networks for health risk prediction. CoRR, abs/2001.10054.

A. Goldberger, L. Amaral, L. Glass, J. Hausdorff, P. C. Ivanov, R. Mark, ... H. E. Stanley. 2000. PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation, 101(23):e215–e220.

Hrayr Harutyunyan, Hrant Khachatrian, David C. Kale, Greg Ver Steeg, and Aram Galstyan. 2019. Multitask learning and benchmarking with clinical time series data. Scientific Data, 6(1):96.

Nasir Hayat, Krzysztof J. Geras, and Farah E. Shamout. 2022. MedFuse: Multi-modal fusion with clinical time-series data and chest X-ray images.

Alistair E.W. Johnson, Tom J. Pollard, Lu Shen, Li-wei H. Lehman, Mengling Feng, Mohammad Ghassemi, Benjamin Moody, Peter Szolovits, Leo Anthony Celi, and Roger G. Mark. 2016. MIMIC-III, a freely accessible critical care database.

Tengfei Ma, Cao Xiao, and Fei Wang. Health-ATM: A Deep Architecture for Multifaceted Patient Health Record Representation and Risk Prediction, pages 261–269.

Yikang Shen, Shawn Tan, Alessandro Sordoni, and Aaron C. Courville. 2018. Ordered neurons: Integrating tree structures into recurrent neural networks. CoRR, abs/1810.09536.




---



---

Thank you!