In [1]:
import numpy as np
import pandas as pd
import random
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchtext.data import BucketIterator

# Seed every RNG the notebook relies on: `random` drives the patient split,
# torch drives weight initialisation and DataLoader shuffling, and numpy is
# seeded as well so any numpy-based sampling stays repeatable.
SEED = 696
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load and split dataset

In [2]:
# Load the patient table once. Later cells read its 'pid' and 'SepsisLabel'
# columns and reference this frame as a module-level global
# (see get_pos_weight and test_run).
patients = pd.read_csv('patient_data_norm.csv')

In [3]:
def get_split_indices(pos_len, neg_len, ratios):
    """Convert split ratios into per-class patient counts.

    Args:
        pos_len: number of positive (septic) patients.
        neg_len: number of negative patients.
        ratios: sequence ``(train, val, test)`` of fractions. Only the train
            and val fractions size a split; the test split receives whatever
            is left, so the three counts always add up to the class total.

    Returns:
        ``((pos_train, pos_val, pos_test), (neg_train, neg_val, neg_test))``.
    """
    train_ratio, val_ratio = ratios[0], ratios[1]

    def _counts(total):
        # Rounding each share independently can overshoot the total
        # (e.g. 8 patients -> 6 + 2 + 1), so clamp and give test the remainder.
        n_train = min(total, int(round(total * train_ratio)))
        n_val = min(total - n_train, int(round(total * val_ratio)))
        return (n_train, n_val, total - n_train - n_val)

    return (_counts(pos_len), _counts(neg_len))
    
    
def split_dataset(patients, ratios):
    """Split patients into stratified train / val / test dictionaries.

    A patient is positive when at least one of its rows has
    ``SepsisLabel == 1``. Positive and negative patient ids are shuffled
    separately with the global ``random`` state and then cut according to
    ``get_split_indices`` so each split keeps the class balance.

    Args:
        patients: DataFrame with at least the columns ``'pid'`` and
            ``'SepsisLabel'``.
        ratios: ``(train, val, test)`` fractions.

    Returns:
        ``(train_dict, val_dict, test_dict)``, each mapping a patient id to
        the sub-frame holding that patient's rows.
    """
    positive = patients[patients['SepsisLabel'] == 1]['pid'].dropna().unique().tolist()
    positive_set = set(positive)
    # Take the negative ids from the data instead of assuming ids 1..5000.
    # Sorting keeps the pre-shuffle order (and thus the seeded split) the same
    # as before whenever the ids really are 1..5000.
    negative = sorted(
        pid for pid in patients['pid'].dropna().unique().tolist()
        if pid not in positive_set
    )
    random.shuffle(positive)
    random.shuffle(negative)
    pos_idx, neg_idx = get_split_indices(len(positive), len(negative), ratios)

    pos_val_end = pos_idx[0] + pos_idx[1]
    neg_val_end = neg_idx[0] + neg_idx[1]
    train = positive[:pos_idx[0]] + negative[:neg_idx[0]]
    val = positive[pos_idx[0]:pos_val_end] + negative[neg_idx[0]:neg_val_end]
    test = positive[pos_val_end:] + negative[neg_val_end:]

    # One groupby pass instead of one full boolean scan of the frame per patient.
    frames_by_pid = {pid: frame for pid, frame in patients.groupby('pid')}
    train_dict = {pid: frames_by_pid[pid] for pid in train}
    val_dict = {pid: frames_by_pid[pid] for pid in val}
    test_dict = {pid: frames_by_pid[pid] for pid in test}

    return train_dict, val_dict, test_dict

## Choose patient observation windows

In [4]:
def process_patient(patient, max_len, window_marker=70):
    """Cut or pad one patient's record to exactly ``max_len`` rows.

    Septic patients (any row with ``SepsisLabel == 1``) are delegated to
    ``process_longer_obs`` / ``process_shorter_obs`` so the window is placed
    around the first positive row. Non-septic patients keep their first
    ``max_len`` rows, or are padded by repeating their last row.

    Args:
        patient: DataFrame of one patient's rows; must contain ``'SepsisLabel'``.
        max_len: number of rows in the returned window.
        window_marker: percentage forwarded to ``process_longer_obs`` as
            ``marker``.

    Returns:
        DataFrame with ``max_len`` rows. The column added by ``reset_index``
        and the first original column are dropped.
    """
    patient = patient.reset_index()
    obs_len = patient.shape[0]
    onset_mask = patient['SepsisLabel'] == 1

    if onset_mask.any():
        # After reset_index the label of the first positive row is also its position.
        sepsis_idx = int(onset_mask[onset_mask].index[0])
        if obs_len > max_len:
            return process_longer_obs(patient, sepsis_idx, max_len, marker=window_marker)
        if obs_len < max_len:
            return process_shorter_obs(patient, sepsis_idx, max_len)
    else:
        if obs_len > max_len:
            return patient.iloc[0:max_len, 2:]
        if obs_len < max_len:
            # DataFrame.append no longer exists in pandas 2.x; build the
            # padding in a single concat instead of one append per row.
            padding = [patient.iloc[[-1]]] * (max_len - obs_len)
            return pd.concat([patient] + padding).iloc[:, 2:]
    return patient.iloc[:, 2:]

def process_shorter_obs(patient, sepsis_idx, max_len):
    """Pad a record shorter than ``max_len`` by repeating its last row.

    Args:
        patient: DataFrame of one patient's rows.
        sepsis_idx: position of the first septic row. Unused here; kept so the
            signature mirrors ``process_longer_obs``.
        max_len: number of rows in the returned frame.

    Returns:
        DataFrame with a fresh 0..n-1 index. The column inserted by
        ``reset_index`` and the first two columns of ``patient`` are dropped.
    """
    missing = max_len - patient.shape[0]
    pieces = [patient]
    if missing > 0:
        # One concat replaces the per-row DataFrame.append loop
        # (append was removed in pandas 2.0).
        pieces.extend([patient.iloc[[-1]]] * missing)
    padded = pd.concat(pieces)
    return padded.reset_index().iloc[:, 3:]
        
def process_longer_obs(patient, sepsis_idx, max_len, marker=70):
    """Cut a ``max_len``-row window around the first septic row.

    The window holds ``int(max_len * marker / 100) + 1`` rows before the onset
    row, the onset row itself, and enough following rows to reach ``max_len``.
    For the default 30-row / 70 % setup that is 22 + 1 + 7. When the record
    has too few rows before the onset, the window is padded at the front with
    copies of its first row; when it has too few rows after the onset, it is
    padded at the end with copies of its last row. The result therefore always
    has exactly ``max_len`` rows.

    Args:
        patient: DataFrame of one patient's rows, positionally indexed.
        sepsis_idx: position of the first row with ``SepsisLabel == 1``.
        max_len: number of rows in the returned window.
        marker: percentage of the window that should precede the onset.

    Returns:
        DataFrame with a fresh 0..n-1 index. The column inserted by
        ``reset_index`` and the first two columns of ``patient`` are dropped.
    """
    if max_len <= 0:
        return patient.iloc[0:0].reset_index().iloc[:, 3:]

    sepsis_idx = int(sepsis_idx)
    obs_len = patient.shape[0]

    # Target number of rows strictly before / after the onset row.
    want_before = max(0, min(int(max_len * marker / 100) + 1, max_len - 1))
    want_after = max_len - 1 - want_before

    # What the record can actually supply on each side.
    n_before = min(sepsis_idx, want_before)
    n_after = min(obs_len - sepsis_idx - 1, want_after)

    # Explicit non-negative bounds: an onset at row 0 must not turn into a
    # negative slice end that swallows almost the whole record.
    window = patient.iloc[sepsis_idx - n_before:sepsis_idx + n_after + 1]

    front_pad = [window.iloc[[0]]] * (want_before - n_before)
    back_pad = [window.iloc[[-1]]] * (want_after - n_after)
    padded = pd.concat(front_pad + [window] + back_pad)
    return padded.reset_index().iloc[:, 3:]

## Create DataLoaders

In [5]:
class PatientDataset(Dataset):
    """Map-style dataset yielding one fixed-length window per patient.

    Args:
        patient_dict: mapping ``pid -> DataFrame`` of that patient's rows.
        max_obs_len: window length passed to ``process_patient``.
        window_marker: percentage passed to ``process_patient``.
    """

    def __init__(self, patient_dict, max_obs_len, window_marker):
        self.patient_dict = patient_dict
        self.num_patients = len(patient_dict)
        self.pids = list(patient_dict.keys())
        self.max_obs_len = max_obs_len
        self.window_marker = window_marker

    def __len__(self):
        return self.num_patients

    def __getitem__(self, idx):
        """Return ``(features, labels)`` float tensors for the ``idx``-th patient.

        Features are every column of the processed window except the first
        and the last; labels are the ``'SepsisLabel'`` column.
        """
        window = process_patient(
            self.patient_dict[self.pids[idx]], self.max_obs_len, self.window_marker
        )
        # Convert through numpy so the tensors do not depend on the frame's
        # index labels (padded windows can carry duplicate labels).
        patient_features = torch.as_tensor(window.iloc[:, 1:-1].to_numpy(dtype='float32'))
        patient_labels = torch.as_tensor(window['SepsisLabel'].to_numpy(dtype='float32'))
        # Reading an item must not change the dataset: the previous
        # `self.num_patients -= 1` made len() shrink with every access.
        return patient_features, patient_labels

In [6]:
def data_loader(patient_dict, max_obs_len, batch_size, shuffle=True, window_marker=70):
    """Wrap ``patient_dict`` in a ``PatientDataset`` and return a ``DataLoader`` over it.

    ``max_obs_len`` and ``window_marker`` configure the dataset; ``batch_size``
    and ``shuffle`` are handed to the ``DataLoader``.
    """
    dataset = PatientDataset(patient_dict, max_obs_len, window_marker)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

## Model training and evaluation setup

In [34]:
def confusion_matrix(prediction, truth):
    """Count TP / FP / TN / FN for 0/1-valued prediction and truth tensors.

    The element-wise quotient ``prediction / truth`` encodes each outcome:
    1/1 = 1 (true positive), 1/0 = inf (false positive), 0/0 = nan
    (true negative) and 0/1 = 0 (false negative).

    Returns:
        ``(true_positives, false_positives, true_negatives, false_negatives)``
        as Python ints.
    """
    ratio = prediction / truth

    def _tally(mask):
        return torch.sum(mask).item()

    return (
        _tally(ratio == 1),
        _tally(ratio == float('inf')),
        _tally(torch.isnan(ratio)),
        _tally(ratio == 0),
    )

def check_accuracy(model, loader, group):
    """Evaluate ``model`` on ``loader`` and print / return summary metrics.

    Logits are passed through a sigmoid and rounded to obtain 0/1 predictions,
    which are compared element-wise with the labels.

    Args:
        model: module returning logits with the same shape as the labels.
        loader: iterable of ``(x, y)`` batches.
        group: name used in the log line (e.g. ``'train'`` or ``'val'``).

    Returns:
        ``(accuracy, precision, recall, f1)``. Precision, recall and F1 stay 0
        when there are no true positives; accuracy is 0.0 for an empty loader.
    """
    print('Checking ' + group + ' accuracy!')
    num_correct = 0
    num_samples = 0
    tp, fp, tn, fn = 0, 0, 0, 0
    precision, recall, f1 = 0, 0, 0
    model.eval()
    # Evaluation needs no gradients; skipping them avoids building an
    # autograd graph for every batch.
    with torch.no_grad():
        for x, y in loader:
            scores = model(x)
            rounded_preds = torch.round(torch.sigmoid(scores))
            num_correct += (rounded_preds == y).sum().item()
            num_samples += y.numel()
            tp_t, fp_t, tn_t, fn_t = confusion_matrix(rounded_preds, y)
            tp += tp_t
            fp += fp_t
            tn += tn_t
            fn += fn_t

    # An empty loader would otherwise divide by zero.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    print('TP = ', tp, ', FP = ', fp, ', TN = ', tn, ', FN = ', fn)
    if tp != 0:
        precision = tp/(tp + fp)
        recall = tp/(tp + fn)
        f1 = 2 * ((precision * recall)/(precision + recall))
    print('Precision = ', precision, ', Recall = ', recall, ', F1 Score = ', f1)
    print()
    return (acc, precision, recall, f1)

In [35]:
def train(model, optimizer, loss_fn, train_dict, val_dict, max_obs_len, batch_size, epochs=1, print_every=50, window_marker=70):
    """Train ``model`` and record train / val metrics after every epoch.

    Args:
        model: network mapping a feature batch to logits.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(scores, labels)``.
        train_dict, val_dict: ``pid -> DataFrame`` mappings for each split.
        max_obs_len: window length used when building the loaders.
        batch_size: loader batch size.
        epochs: number of passes over the training split.
        print_every: log the loss every this many batches.
        window_marker: percentage forwarded to the loaders' window placement.

    Returns:
        ``(train_history, val_history)``: one ``check_accuracy`` result per
        epoch for each split.
    """
    train_history, val_history = [], []

    def _loader(patient_dict, shuffle):
        # Keyword arguments matter: data_loader's fourth positional parameter
        # is `shuffle`, so passing window_marker positionally lands in the
        # wrong slot and the configured marker is never applied.
        return data_loader(patient_dict, max_obs_len, batch_size,
                           shuffle=shuffle, window_marker=window_marker)

    for e in range(epochs):
        print('Epoch: ', e+1)
        # check_accuracy leaves the model in eval mode, so switch back each epoch.
        model.train()
        for t, (x, y) in enumerate(_loader(train_dict, True)):
            scores = model(x)
            loss = loss_fn(scores, y)
            if (t + 1) % print_every == 0:
                print('t = %d, loss = %.4f' % (t + 1, loss.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Metrics do not depend on batch order, so evaluation is not shuffled.
        train_history.append(check_accuracy(model, _loader(train_dict, False), 'train'))
        val_history.append(check_accuracy(model, _loader(val_dict, False), 'val'))
        print()
    return (train_history, val_history)

In [36]:
def get_pos_weight(patient_dict):
    """Return (number of rows) / (number of positive rows) for the given patients.

    Rows are looked up in the module-level ``patients`` frame by matching
    ``'pid'`` against the keys of ``patient_dict``; a row is positive when
    its ``'SepsisLabel'`` equals 1.
    """
    wanted_pids = list(patient_dict.keys())
    rows = patients[patients['pid'].isin(wanted_pids)]
    positive_rows = rows.loc[rows['SepsisLabel'] == 1, 'pid'].count()
    return len(rows) / positive_rows

## Simple LSTM model

In [45]:
class SimpleLSTM(nn.Module):
    """Single-layer LSTM whose final hidden state feeds a linear read-out.

    Args:
        feature_dim: size of each input time step.
        hidden_dim: LSTM hidden size.
        out_dim: number of logits produced per sequence.
    """

    def __init__(self, feature_dim, hidden_dim, out_dim):
        super().__init__()
        self.rnn = nn.LSTM(feature_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Run the LSTM over ``x`` and project its last hidden state to logits."""
        # Per-step outputs and the cell state are not needed for the read-out.
        _, (last_hidden, _) = self.rnn(x)
        # Drop the leading layer axis before the linear layer.
        return self.fc(last_hidden.squeeze(0))

In [40]:
def test_run(config):
    """Run one experiment end to end: split, build, train.

    ``config`` must provide the keys ``'ratios'``, ``'feature_dim'``,
    ``'hidden_dim'``, ``'output_dim'``, ``'lr_rate'``, ``'pos_weight'``,
    ``'max_obs_len'``, ``'batch_size'``, ``'epochs'`` and ``'window_marker'``.
    The data comes from the module-level ``patients`` frame.

    Returns:
        ``(model, train_hist, val_hist)``.
    """
    # The test split is produced here but not used by this function.
    train_dict, val_dict, test_dict = split_dataset(patients, config['ratios'])

    model = SimpleLSTM(config['feature_dim'], config['hidden_dim'], config['output_dim'])
    optimizer = optim.Adam(model.parameters(), lr=config['lr_rate'])

    # Use the weighted loss only when a positive-class weight is configured.
    pos_weight = config['pos_weight']
    if pos_weight is None:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight]))

    histories = train(
        model, optimizer, criterion,
        train_dict, val_dict, config['max_obs_len'], config['batch_size'],
        epochs=config['epochs'], window_marker=config['window_marker'],
    )
    train_hist, val_hist = histories

    # TODO: checkpoint saving is disabled. The earlier draft called torch.save
    # with a `model_perf` suffix before that string was built from the last
    # train / val precision, recall and F1 values.
    return model, train_hist, val_hist

In [41]:
# New with actual LSTM layer
# Baseline run: unweighted BCE loss ('pos_weight': None), 30 epochs.

model_config = {
    'ratios': (.7, .2, .1),   # train / val / test fractions given to split_dataset
    'feature_dim': 39,        # input size of the recurrent layer
    'hidden_dim': 128,        # hidden size of the recurrent layer
    'output_dim': 30,         # logits per patient; the loss compares them with the max_obs_len labels
    'max_obs_len': 30,        # rows per patient window
    'batch_size': 16,
    'lr_rate': 1e-3,          # Adam learning rate
    'pos_weight': None,       # None -> plain BCEWithLogitsLoss
    'epochs': 30,
    'window_marker': 70,      # percentage used to place the window around sepsis onset
    'model_name': 'simple_lstm'  # only referenced by the commented-out save code in test_run
}

model, train_hist, val_hist = test_run(model_config)

Epoch:  1
t = 50, loss = 0.0146
t = 100, loss = 0.2910
t = 150, loss = 0.2448
t = 200, loss = 0.0207
Checking train accuracy!
Got 101898 / 105000 correct (97.05)
TP =  0 , FP =  0 , TN =  101898 , FN =  3102
Precision =  0 , Recall =  0 , F1 Score =  0

Checking val accuracy!
Got 29197 / 30000 correct (97.32)
TP =  0 , FP =  0 , TN =  29197 , FN =  803
Precision =  0 , Recall =  0 , F1 Score =  0


Epoch:  2
t = 50, loss = 0.2722
t = 100, loss = 0.0712
t = 150, loss = 0.0229
t = 200, loss = 0.0803
Checking train accuracy!
Got 101922 / 105000 correct (97.07)
TP =  26 , FP =  2 , TN =  101896 , FN =  3076
Precision =  0.9285714285714286 , Recall =  0.008381689232753063 , F1 Score =  0.016613418530351438

Checking val accuracy!
Got 29197 / 30000 correct (97.32)
TP =  0 , FP =  0 , TN =  29197 , FN =  803
Precision =  0 , Recall =  0 , F1 Score =  0


Epoch:  3
t = 50, loss = 0.0912
t = 100, loss = 0.0314
t = 150, loss = 0.1357
t = 200, loss = 0.2111
Checking train accuracy!
Got 101817 / 1

Got 29438 / 30000 correct (98.13)
TP =  300 , FP =  59 , TN =  29138 , FN =  503
Precision =  0.8356545961002786 , Recall =  0.37359900373599003 , F1 Score =  0.5163511187607573


Epoch:  18
t = 50, loss = 0.0050
t = 100, loss = 0.0160
t = 150, loss = 0.0040
t = 200, loss = 0.0024
Checking train accuracy!
Got 103821 / 105000 correct (98.88)
TP =  2571 , FP =  648 , TN =  101250 , FN =  531
Precision =  0.798695246971109 , Recall =  0.8288201160541586 , F1 Score =  0.8134788799240626

Checking val accuracy!
Got 29288 / 30000 correct (97.63)
TP =  337 , FP =  246 , TN =  28951 , FN =  466
Precision =  0.5780445969125214 , Recall =  0.41967621419676215 , F1 Score =  0.4862914862914863


Epoch:  19
t = 50, loss = 0.3162
t = 100, loss = 0.0056
t = 150, loss = 0.0023
t = 200, loss = 0.0054
Checking train accuracy!
Got 104285 / 105000 correct (99.32)
TP =  2631 , FP =  244 , TN =  101654 , FN =  471
Precision =  0.9151304347826087 , Recall =  0.8481624758220503 , F1 Score =  0.880374769951480

TypeError: 'float' object is not iterable

In [31]:
# Old one with RNN
# NOTE(review): the saved output of this cell reports an `RNN(39, 128)` layer and
# plain accuracy histories, so it appears to come from an earlier version of the
# model / metric code; re-running it with the current definitions is not expected
# to reproduce those numbers.
# Same settings as the LSTM run except for 20 epochs and no 'model_name'.
model_config = {
    'ratios': (.7, .2, .1),
    'feature_dim': 39,
    'hidden_dim': 128,
    'output_dim': 30,
    'max_obs_len': 30,
    'batch_size': 16,
    'lr_rate': 1e-3,
    'pos_weight': None,  # unweighted loss
    'epochs': 20,
    'window_marker': 70
}

test_run(model_config)

Epoch:  1
t = 50, loss = 0.0211
t = 100, loss = 0.2317
t = 150, loss = 0.1927
t = 200, loss = 0.2605
Checking train accuracy!
Got 101818 / 105000 correct (96.97)
TP =  62 , FP =  180 , TN =  101756 , FN =  3002
Precision =  0.256198347107438 , Recall =  0.020234986945169713 , F1 Score =  0.03750756200846945

Checking val accuracy!
Got 29039 / 30000 correct (96.80)
TP =  0 , FP =  60 , TN =  29039 , FN =  901
Precision =  0 , Recall =  0 , F1 Score =  0


Epoch:  2
t = 50, loss = 0.3444
t = 100, loss = 0.0408
t = 150, loss = 0.2845
t = 200, loss = 0.0282
Checking train accuracy!
Got 102047 / 105000 correct (97.19)
TP =  248 , FP =  137 , TN =  101799 , FN =  2816
Precision =  0.6441558441558441 , Recall =  0.08093994778067885 , F1 Score =  0.14380979994201218

Checking val accuracy!
Got 29108 / 30000 correct (97.03)
TP =  65 , FP =  56 , TN =  29043 , FN =  836
Precision =  0.5371900826446281 , Recall =  0.07214206437291898 , F1 Score =  0.12720156555772996


Epoch:  3
t = 50, loss = 0.

t = 200, loss = 0.0625
Checking train accuracy!
Got 104154 / 105000 correct (99.19)
TP =  2397 , FP =  179 , TN =  101757 , FN =  667
Precision =  0.9305124223602484 , Recall =  0.7823107049608355 , F1 Score =  0.8499999999999999

Checking val accuracy!
Got 29431 / 30000 correct (98.10)
TP =  418 , FP =  86 , TN =  29013 , FN =  483
Precision =  0.8293650793650794 , Recall =  0.46392896781354054 , F1 Score =  0.5950177935943061


Epoch:  18
t = 50, loss = 0.0040
t = 100, loss = 0.0076
t = 150, loss = 0.0190
t = 200, loss = 0.0034
Checking train accuracy!
Got 103830 / 105000 correct (98.89)
TP =  2103 , FP =  209 , TN =  101727 , FN =  961
Precision =  0.9096020761245674 , Recall =  0.6863577023498695 , F1 Score =  0.7823660714285713

Checking val accuracy!
Got 29362 / 30000 correct (97.87)
TP =  374 , FP =  111 , TN =  28988 , FN =  527
Precision =  0.7711340206185567 , Recall =  0.41509433962264153 , F1 Score =  0.5396825396825398


Epoch:  19
t = 50, loss = 0.0640
t = 100, loss = 0.0

(SimpleLSTM(
   (rnn): RNN(39, 128, batch_first=True)
   (fc): Linear(in_features=128, out_features=30, bias=True)
 ),
 [96.9695238095238,
  97.18761904761905,
  97.45714285714286,
  97.7152380952381,
  98.16666666666667,
  98.01809523809524,
  98.27238095238096,
  98.05619047619048,
  97.99333333333334,
  98.51142857142857,
  98.27333333333334,
  98.75047619047619,
  98.79904761904761,
  98.77238095238096,
  98.26380952380951,
  99.01333333333334,
  99.19428571428571,
  98.88571428571429,
  98.4552380952381,
  98.5447619047619],
 [96.79666666666667,
  97.02666666666667,
  97.24666666666667,
  97.43666666666667,
  97.94333333333334,
  97.75333333333333,
  97.93666666666667,
  97.82,
  97.60666666666667,
  98.23666666666668,
  97.45333333333333,
  98.00999999999999,
  98.21333333333332,
  98.32666666666667,
  97.32,
  98.45666666666666,
  98.10333333333332,
  97.87333333333333,
  97.78999999999999,
  97.98])

In [32]:
# Old one with RNN
# NOTE(review): the saved output of this cell reports an `RNN(39, 128)` layer, so
# it appears to come from an earlier version of the model code; re-running it
# with the current definitions is not expected to reproduce those numbers.
# Differs from the previous run only in 'pos_weight'.
model_config = {
    'ratios': (.7, .2, .1),
    'feature_dim': 39,
    'hidden_dim': 128,
    'output_dim': 30,
    'max_obs_len': 30,
    'batch_size': 16,
    'lr_rate': 1e-3,
    'pos_weight': 17,  # passed to BCEWithLogitsLoss(pos_weight=...) by test_run
    'epochs': 20,
    'window_marker': 70
}

test_run(model_config)

Epoch:  1
t = 50, loss = 0.5825
t = 100, loss = 1.9420
t = 150, loss = 2.1549
t = 200, loss = 0.3596
Checking train accuracy!
Got 97113 / 105000 correct (92.49)
TP =  1791 , FP =  6505 , TN =  95322 , FN =  1382
Precision =  0.21588717454194792 , Recall =  0.5644500472738733 , F1 Score =  0.3123201674077949

Checking val accuracy!
Got 27750 / 30000 correct (92.50)
TP =  464 , FP =  1884 , TN =  27286 , FN =  366
Precision =  0.19761499148211242 , Recall =  0.5590361445783133 , F1 Score =  0.29200755191944616


Epoch:  2
t = 50, loss = 0.4268
t = 100, loss = 0.3989
t = 150, loss = 0.3383
t = 200, loss = 2.0334
Checking train accuracy!
Got 97030 / 105000 correct (92.41)
TP =  1705 , FP =  6502 , TN =  95325 , FN =  1468
Precision =  0.20774948214938468 , Recall =  0.5373463599117554 , F1 Score =  0.29964850615114236

Checking val accuracy!
Got 27704 / 30000 correct (92.35)
TP =  483 , FP =  1949 , TN =  27221 , FN =  347
Precision =  0.19860197368421054 , Recall =  0.5819277108433735 , F

t = 50, loss = 0.0424
t = 100, loss = 0.0679
t = 150, loss = 0.1036
t = 200, loss = 0.1651
Checking train accuracy!
Got 99269 / 105000 correct (94.54)
TP =  2801 , FP =  5359 , TN =  96468 , FN =  372
Precision =  0.34325980392156863 , Recall =  0.8827607942010716 , F1 Score =  0.49430865613694525

Checking val accuracy!
Got 28068 / 30000 correct (93.56)
TP =  583 , FP =  1685 , TN =  27485 , FN =  247
Precision =  0.2570546737213404 , Recall =  0.7024096385542169 , F1 Score =  0.3763718528082634


Epoch:  18
t = 50, loss = 0.1391
t = 100, loss = 2.2378
t = 150, loss = 0.0998
t = 200, loss = 0.1663
Checking train accuracy!
Got 101073 / 105000 correct (96.26)
TP =  2920 , FP =  3674 , TN =  98153 , FN =  253
Precision =  0.44282681225356385 , Recall =  0.9202647336905138 , F1 Score =  0.5979318112009829

Checking val accuracy!
Got 28399 / 30000 correct (94.66)
TP =  577 , FP =  1348 , TN =  27822 , FN =  253
Precision =  0.29974025974025975 , Recall =  0.6951807228915663 , F1 Score =  0

(SimpleLSTM(
   (rnn): RNN(39, 128, batch_first=True)
   (fc): Linear(in_features=128, out_features=30, bias=True)
 ),
 [92.48857142857143,
  92.4095238095238,
  89.6752380952381,
  89.79238095238095,
  94.51809523809523,
  92.1952380952381,
  94.03238095238096,
  91.21333333333334,
  96.06095238095239,
  93.05333333333333,
  94.40190476190476,
  93.96190476190476,
  95.79904761904761,
  94.75714285714287,
  92.25428571428571,
  96.7552380952381,
  94.54190476190476,
  96.26,
  97.07142857142857,
  94.80285714285715],
 [92.5,
  92.34666666666666,
  89.79333333333334,
  89.51,
  94.23,
  91.80666666666667,
  93.41333333333334,
  90.02666666666667,
  94.62,
  92.62,
  93.86333333333333,
  92.93,
  94.57666666666667,
  93.54666666666667,
  90.10000000000001,
  95.02000000000001,
  93.56,
  94.66333333333333,
  95.99,
  93.14])