In [1]:
#!pip install polars numpy torch pandas seaborn

In [2]:
import polars as pl
import pandas as pd
import os
import sys
import random
import pickle
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
# set seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
DATA_PATH = ""
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [3]:
import sys
obj = pickle.load(open(os.path.join(DATA_PATH, 'allevents_by_episode_48_charts'), 'rb'))
obj.estimated_size()/(1024**2)
patients = obj.select(pl.col('subject_id')).to_series().to_list()
seqs = obj.select(pl.col('itemidx')).to_series().to_list()
mortality = obj.select(pl.col('mortality_tf')).to_series().to_list()

In [4]:
itemidx = set([each_item for events in seqs for items in events for each_item in items])
len(itemidx)

40119

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
# Custom dataset
from torch.utils.data import Dataset, DataLoader, random_split

class CustomDataset(Dataset):
    def __init__(self, seqs, mortality):
        self.seqs = seqs
        self.labels = mortality
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, index):
        return self.seqs[index], self.labels[index]
dataset = CustomDataset(seqs, mortality)
assert len(dataset) == len(obj) #TODO write test separately 


In [7]:
# Collate function
def collate_fn(data):
    """
    Input data: a tuple of seqs, and label
    Output: x (num_patients, num_events, num_itemids) with *real* items in the events
            rev_x (num_patients, num_events, num_itemids) with *reverse real* items in the events
            masks  (num_patients, num_events, num_itemids) whether a *real* itemid in the events is present
            rev_masks (num_patients, num_events, num_itemids) whether a *reverse real* itemid in the events is present
    """
    seqs, labels = zip(*data)
    num_patients = len(seqs)
    max_num_events = max([len(event) for event in seqs])
    max_num_items = max([len(itemid) for event in seqs for itemid in event])
    tensor_shape = (num_patients, max_num_events, max_num_items)
    x =        torch.zeros(tensor_shape, dtype=torch.long)
    rev_x =    torch.zeros(tensor_shape, dtype=torch.long)
    masks =    torch.zeros(tensor_shape, dtype=torch.bool)
    rev_masks = torch.zeros(tensor_shape, dtype=torch.bool) 
    y =        torch.tensor(labels, dtype=torch.long)
    

    for i_patient, events in enumerate(seqs):
        for i_event, item in enumerate(events):
            padded_item = torch.concat([torch.tensor(item),
                                        torch.zeros(max_num_items - len(item))]).long()
            x[i_patient, i_event, :] = padded_item
            masks[i_patient, i_event, :] = torch.where(padded_item!=0,1,0)  
    for i_patient, events in enumerate(seqs):
        idx_all_real_events = torch.sum(x[i_patient, :, :], dim=(1))!= 0
        idx_padded_events =torch.sum(x[i_patient, :, :], dim=(1))== 0
        fliped = torch.flip(x[i_patient, idx_all_real_events, :].unsqueeze(1), (0,)).squeeze(1)
        rev_x[i_patient, :, :] = torch.concat((fliped, x[i_patient, idx_padded_events, :] ))
        rev_masks[i_patient, :, :] = torch.where(rev_x[i_patient, :, :] != 0, True, False)
    return x, masks, rev_x, rev_masks, y

    

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn, pin_memory=True)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)
#assert x.shape == masks.shape == (10, 8, 313)
assert y.shape == (10,)

In [22]:
train, val, test = int(len(dataset)*0.8), int(len(dataset)*0.1), len(dataset) - int(len(dataset)*0.8) -  int(len(dataset)*0.1)
lengths = [train, val, test]
train_dataset, val_dataset, test_dataset = random_split(dataset=dataset, lengths=lengths)
import math
def load_data(train_dataset, val_dataset, test_dataset, collate_fn):
    batch_size = 64
    seqs, labels = zip(*train_dataset)
    arr = np.array(labels)
    #First stage of sampling: find all the positive labels
    idx_true = np.where(arr == True)
    minority, majority= np.sum(arr)/len(arr), 1 - np.sum(arr)/len(arr)
    weights = np.apply_along_axis(lambda x: x * (majority) if x is True else x * minority, 0, np.ones_like(arr))
    #sampler = torch.utils.data.sampler.WeightedRandomSampler(weights=weights, num_samples = batch_size, replacement=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=4, shuffle=True)
    
    val_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=4, shuffle=False)
    test_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=4, shuffle=False)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset, collate_fn)
print("Length of train dataset:", len(train_dataset))
print("Length of train with mortality:", np.sum([element[-1] for element in train_dataset]))
print("Length of train with survive patients:", (len(train_dataset) - np.sum([element[-1] for element in train_dataset])))
print("Length of val dataset:", len(val_dataset))
print("Length of val with mortality:", np.sum([element[-1] for element in val_dataset]))
print("Length of val with survive patients:", (len(train_dataset) - np.sum([element[-1] for element in val_dataset])))
print("Length of test dataset:", len(test_dataset))
print("Length of test with mortality:", np.sum([element[-1] for element in test_dataset]))
print("Length of test with survive patients:", (len(test_dataset) - np.sum([element[-1] for element in test_dataset])))

Length of train dataset: 27032
Length of train with mortality: 3552
Length of train with survive patients: 23480
Length of val dataset: 3379
Length of val with mortality: 451
Length of val with survive patients: 26581
Length of test dataset: 3379
Length of test with mortality: 468
Length of test with survive patients: 2911


In [9]:
def sum_embeddings_with_masks(x, masks):
    """
    Input:  x               (batch_size, num_events, num_itemids, embedding_dims)
            mask            (batch_size, num_events, num_itemids)
    Output: sum_embeddings  (batch_size, num_events, embedding_dims)
    The return output compress the num_itemids into embedding vectors
    """
    masks = masks.unsqueeze(-1).expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
    return torch.sum(masks * x, 2)
def get_last_event(hidden_states, masks):
    """
    hidden_states: (batch_size, #item, embedding_dim)
    masks:         (batch_size, #item, embedding_dim)
    return last_hidden_state: (batch_size, embedding_dim)
    """
    idx_last_event = torch.argmin(torch.sum(masks, 2), 1)
    return hidden_states[torch.arange(hidden_states.shape[0]), torch.where(idx_last_event - 1 < 0, max(idx_last_event), idx_last_event - 1),:]

In [10]:
ts = torch.tensor([[
    [-0.8201, 0.3956, 0.8989, -1.3884, -0.1670, 0.2851, -0.6411],
    [-0.8937, 0.9265, -0.5355, -1.1597, -0.4602, 0.7085, 1.0128],
    [ 0.2304, 1.0902, -1.5827, -0.3246, 1.9264, -0.3300, 0.1984]],
   
[[ 0.7821, 1.0391, -0.7245, -0.2093, -0.2153, -1.8157, -0.3452],
    [-2.0615, 0.6741, -1.3233, -1.3598, -0.0835, -0.0235, 0.1744],
    [ 2.2983, 0.9571, -0.6619, -0.8285, -0.6057, -1.4013, 1.2973]],

   [[ 1.6409, -1.0567, -0.2616, -0.2501, 0.5011, 0.2600, -0.1782],
    [    -0.2595, -0.0145, -0.3839, -2.9662, -1.0606, -0.3090, 0.9343],
    [ 1.6243, 0.0016, -0.4375, -2.1085, 1.1450, -0.3822, -0.3553]],

   [[ 0.7542, 0.1332, 0.1825, -0.5146, 0.8005, -0.1259, -0.9578],
    [ 1.7518, 0.9796, 0.4105, 1.7675, -0.0832, 0.5087, -0.8253],
    [ 0.1633, 0.5013, 1.4206, 1.1542, -1.5366, -0.5577, -0.4383]]])
masks = torch.tensor([[[ True,  True,  True,  True, False],
         [ True,  True,  True, False, False],
         [False, False, False, False, False]],

        [[ True, False, False, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True, False, False, False]],

        [[ True,  True, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[ True,  True,  True, False, False],
         [ True,  True,  True, False, False],
         [False, False, False, False, False]]])
assert torch.sum(get_last_event(ts, masks) == torch.tensor([
    [-0.8937, 0.9265, -0.5355, -1.1597, -0.4602, 0.7085, 1.0128],
    [ 2.2983, 0.9571, -0.6619, -0.8285, -0.6057, -1.4013, 1.2973],
    [1.6409, -1.0567, -0.2616, -0.2501, 0.5011, 0.2600, -0.1782],
    [ 1.7518, 0.9796, 0.4105, 1.7675, -0.0832, 0.5087, -0.8253],
    ])) == 7*4

In [11]:
# Naive RNN hehe :P
class NaiveRNN(nn.Module):
    def __init__(self, num_items):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_items, embedding_dim=128)
        self.rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(in_features=128*2, out_features=1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x, masks, rev_x, rev_masks):
        batch_size = x.shape[0]
        # Forward pass with x and masks
        x = self.dropout(self.embedding(x))
        x = sum_embeddings_with_masks(x, masks)
        output, _ = self.rnn(x)
        real_event_h = get_last_event(output, masks)

        # Forward pass with rev_x and rev_masks
        rev_x = self.dropout(self.embedding(rev_x))
        rev_x = sum_embeddings_with_masks(rev_x, rev_masks)
        output_rev, _ = self.rnn(rev_x)
        real_event_h_rev = get_last_event(output_rev, rev_masks)

        # Concat both hidden states
        logits = self.dropout(self.fc(torch.cat([real_event_h, real_event_h_rev], 1)))
        probs = self.sigmoid(logits)
        return probs.view(batch_size)
naive_rnn = NaiveRNN(num_items=len(itemidx))
naive_rnn

NaiveRNN(
  (embedding): Embedding(40119, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (sigmoid): Sigmoid()
)

In [12]:
class CustomLoss(nn.Module):
    def __init__(self, num_majority, num_minority, K=0.01):
        super(CustomLoss, self).__init__()
        
        self.num_majority = num_majority
        self.num_minority = num_minority
        self.K = K
    def forward(self, y_hat, y):
        #y = torch.to('cpu').LongTensor(y)
        delta_maj, delta_min = self.K / self.num_majority**(1/4), self.K / self.num_minority**(1/4)
        zj = (y_hat > 0.5).int()
        masks = torch.where(zj != y)
        zc = y_hat
        sigma_zc_maj = torch.exp(y_hat - delta_maj) / \
                            (torch.exp(y_hat - delta_maj) + torch.sum(torch.exp(zj[masks])))

        sigma_zc_min = torch.exp(y_hat - delta_min) / \
                            (torch.exp(y_hat - delta_min) + torch.sum(torch.exp(zj[masks])))

        loss = (- torch.log(sigma_zc_maj) - torch.log(sigma_zc_min)).mean()
        cross_loss = nn.CrossEntropyLoss()(y_hat, y)
        return cross_loss
minority = np.sum([element[-1] for element in train_dataset]) / len(train_dataset)
majority = 1 - minority
(majority, minority)


(0.8689701094998521, 0.13102989050014796)

In [13]:
criterion = CustomLoss(num_majority=majority, num_minority=minority)

optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=1e-3, weight_decay=1e-2)
from tqdm import tqdm
from sklearn.metrics import *
def train(model, train_loader, val_loader, n_epochs):
    for epoch in tqdm(range(n_epochs)):
        model.to(device).train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            x, masks, rev_x, rev_masks, y = x.to(device), masks.to(device), rev_x.to(device), rev_masks.to(device), y.to(device)
            y_hat = model(x, masks, rev_x, rev_masks).view(y.shape[0])
            loss = criterion(y_hat, y.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        acc, auc, precision, recall, f1score = eval_model(model, val_loader)
        print('Epoch: {} \t Validation acc: {:.2f}, auc:{:.2f}, precision: {:.2f}, recall: {:.2f}, f1: {:.2f}' 
              .format(epoch+1, acc, auc, precision, recall, f1score))    
def classification_metrics(Y_score, Y_pred, Y_true):
    acc, auc, precision, recall, f1score = accuracy_score(Y_true, Y_pred), \
                                           roc_auc_score(Y_true, Y_score), \
                                           precision_score(Y_true, Y_pred), \
                                           recall_score(Y_true, Y_pred), \
                                           f1_score(Y_true, Y_pred)
    return acc, auc, precision, recall, f1score



#precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
def eval_model(model, val_loader):
    model.to(device).eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    for x, masks, rev_x, rev_masks, y in val_loader:
        x, masks, rev_x, rev_masks, y = x.to(device), masks.to(device), rev_x.to(device), rev_masks.to(device), y.to(device)
        y_true = y.long().detach().to('cpu')
        y_hat = model(x, masks, rev_x, rev_masks).view(y.shape[0])
        y_score = y_hat.detach().to('cpu')
        y_pred = (y_hat > 0.5).int().detach().to('cpu')
    acc, auc, precision, recall, f1score = classification_metrics(y_score, y_pred, y_true)
    return acc, auc, precision, recall, f1score


In [14]:
# number of epochs to train the model
n_epochs = 50
train(naive_rnn, train_loader, val_loader, n_epochs)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 1 	 Training Loss: 34.877946


  2%|▏         | 1/50 [00:14<11:51, 14.53s/it]

Epoch: 1 	 Validation acc: 0.46, auc:0.80, precision: 0.28, recall: 1.00, f1: 0.43
Epoch: 2 	 Training Loss: 34.713399


  4%|▍         | 2/50 [00:28<11:11, 14.00s/it]

Epoch: 2 	 Validation acc: 0.67, auc:0.77, precision: 0.33, recall: 0.60, f1: 0.43
Epoch: 3 	 Training Loss: 34.608937


  6%|▌         | 3/50 [00:41<10:51, 13.87s/it]

Epoch: 3 	 Validation acc: 0.71, auc:0.54, precision: 0.25, recall: 0.20, f1: 0.22
Epoch: 4 	 Training Loss: 34.544980


  8%|▊         | 4/50 [00:55<10:33, 13.78s/it]

Epoch: 4 	 Validation acc: 0.62, auc:0.47, precision: 0.25, recall: 0.40, f1: 0.31
Epoch: 5 	 Training Loss: 34.468515


 10%|█         | 5/50 [01:09<10:17, 13.73s/it]

Epoch: 5 	 Validation acc: 0.58, auc:0.67, precision: 0.33, recall: 1.00, f1: 0.50
Epoch: 6 	 Training Loss: 34.514550


 12%|█▏        | 6/50 [01:22<10:03, 13.71s/it]

Epoch: 6 	 Validation acc: 0.50, auc:0.66, precision: 0.27, recall: 0.80, f1: 0.40
Epoch: 7 	 Training Loss: 34.501683


 14%|█▍        | 7/50 [01:36<09:49, 13.72s/it]

Epoch: 7 	 Validation acc: 0.42, auc:0.79, precision: 0.26, recall: 1.00, f1: 0.42
Epoch: 8 	 Training Loss: 34.478237


 16%|█▌        | 8/50 [01:50<09:36, 13.73s/it]

Epoch: 8 	 Validation acc: 0.46, auc:0.82, precision: 0.28, recall: 1.00, f1: 0.43
Epoch: 9 	 Training Loss: 34.510086


 18%|█▊        | 9/50 [02:04<09:22, 13.72s/it]

Epoch: 9 	 Validation acc: 0.75, auc:0.66, precision: 0.43, recall: 0.60, f1: 0.50
Epoch: 10 	 Training Loss: 34.376851


 20%|██        | 10/50 [02:17<09:07, 13.68s/it]

Epoch: 10 	 Validation acc: 0.54, auc:0.59, precision: 0.25, recall: 0.60, f1: 0.35
Epoch: 11 	 Training Loss: 34.167537


 22%|██▏       | 11/50 [02:31<08:53, 13.68s/it]

Epoch: 11 	 Validation acc: 0.46, auc:0.61, precision: 0.28, recall: 1.00, f1: 0.43
Epoch: 12 	 Training Loss: 33.952574


 24%|██▍       | 12/50 [02:44<08:40, 13.69s/it]

Epoch: 12 	 Validation acc: 0.62, auc:0.65, precision: 0.33, recall: 0.80, f1: 0.47
Epoch: 13 	 Training Loss: 33.691611


 26%|██▌       | 13/50 [02:58<08:26, 13.69s/it]

Epoch: 13 	 Validation acc: 0.62, auc:0.62, precision: 0.30, recall: 0.60, f1: 0.40
Epoch: 14 	 Training Loss: 33.438576


 28%|██▊       | 14/50 [03:12<08:13, 13.70s/it]

Epoch: 14 	 Validation acc: 0.67, auc:0.75, precision: 0.33, recall: 0.60, f1: 0.43
Epoch: 15 	 Training Loss: 33.282245


 30%|███       | 15/50 [03:26<07:59, 13.70s/it]

Epoch: 15 	 Validation acc: 0.79, auc:0.91, precision: 0.50, recall: 1.00, f1: 0.67
Epoch: 16 	 Training Loss: 33.093085


 32%|███▏      | 16/50 [03:39<07:45, 13.69s/it]

Epoch: 16 	 Validation acc: 0.88, auc:0.86, precision: 0.62, recall: 1.00, f1: 0.77
Epoch: 17 	 Training Loss: 32.869023


 34%|███▍      | 17/50 [03:53<07:31, 13.69s/it]

Epoch: 17 	 Validation acc: 0.67, auc:0.75, precision: 0.33, recall: 0.60, f1: 0.43
Epoch: 18 	 Training Loss: 32.888567


 36%|███▌      | 18/50 [04:07<07:18, 13.69s/it]

Epoch: 18 	 Validation acc: 0.62, auc:0.82, precision: 0.33, recall: 0.80, f1: 0.47
Epoch: 19 	 Training Loss: 32.756325


 38%|███▊      | 19/50 [04:20<07:02, 13.64s/it]

Epoch: 19 	 Validation acc: 0.79, auc:0.85, precision: 0.50, recall: 1.00, f1: 0.67
Epoch: 20 	 Training Loss: 32.683367


 40%|████      | 20/50 [04:34<06:48, 13.63s/it]

Epoch: 20 	 Validation acc: 0.67, auc:0.78, precision: 0.36, recall: 0.80, f1: 0.50
Epoch: 21 	 Training Loss: 32.581540


 42%|████▏     | 21/50 [04:47<06:35, 13.63s/it]

Epoch: 21 	 Validation acc: 0.83, auc:0.98, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 22 	 Training Loss: 32.417383


 44%|████▍     | 22/50 [05:01<06:21, 13.63s/it]

Epoch: 22 	 Validation acc: 0.83, auc:0.97, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 23 	 Training Loss: 32.469156


 46%|████▌     | 23/50 [05:15<06:08, 13.65s/it]

Epoch: 23 	 Validation acc: 0.71, auc:0.91, precision: 0.42, recall: 1.00, f1: 0.59
Epoch: 24 	 Training Loss: 32.378585


 48%|████▊     | 24/50 [05:28<05:55, 13.66s/it]

Epoch: 24 	 Validation acc: 0.96, auc:0.95, precision: 0.83, recall: 1.00, f1: 0.91
Epoch: 25 	 Training Loss: 32.399429


 50%|█████     | 25/50 [05:42<05:40, 13.64s/it]

Epoch: 25 	 Validation acc: 0.79, auc:0.91, precision: 0.50, recall: 1.00, f1: 0.67
Epoch: 26 	 Training Loss: 32.295546


 52%|█████▏    | 26/50 [05:56<05:27, 13.63s/it]

Epoch: 26 	 Validation acc: 0.88, auc:0.96, precision: 0.62, recall: 1.00, f1: 0.77
Epoch: 27 	 Training Loss: 32.263115


 54%|█████▍    | 27/50 [06:09<05:13, 13.63s/it]

Epoch: 27 	 Validation acc: 0.75, auc:0.96, precision: 0.45, recall: 1.00, f1: 0.62
Epoch: 28 	 Training Loss: 32.149071


 56%|█████▌    | 28/50 [06:23<05:00, 13.66s/it]

Epoch: 28 	 Validation acc: 0.71, auc:0.96, precision: 0.42, recall: 1.00, f1: 0.59
Epoch: 29 	 Training Loss: 32.203840


 58%|█████▊    | 29/50 [06:37<04:47, 13.69s/it]

Epoch: 29 	 Validation acc: 0.83, auc:0.97, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 30 	 Training Loss: 32.114401


 60%|██████    | 30/50 [06:50<04:33, 13.69s/it]

Epoch: 30 	 Validation acc: 0.75, auc:0.83, precision: 0.43, recall: 0.60, f1: 0.50
Epoch: 31 	 Training Loss: 32.153026


 62%|██████▏   | 31/50 [07:04<04:20, 13.71s/it]

Epoch: 31 	 Validation acc: 0.92, auc:0.97, precision: 0.71, recall: 1.00, f1: 0.83
Epoch: 32 	 Training Loss: 32.114319


 64%|██████▍   | 32/50 [07:18<04:06, 13.68s/it]

Epoch: 32 	 Validation acc: 0.79, auc:0.97, precision: 0.50, recall: 1.00, f1: 0.67
Epoch: 33 	 Training Loss: 32.158764


 66%|██████▌   | 33/50 [07:31<03:52, 13.69s/it]

Epoch: 33 	 Validation acc: 0.83, auc:0.91, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 34 	 Training Loss: 32.030841


 68%|██████▊   | 34/50 [07:45<03:39, 13.69s/it]

Epoch: 34 	 Validation acc: 0.83, auc:0.85, precision: 0.60, recall: 0.60, f1: 0.60
Epoch: 35 	 Training Loss: 32.099155


 70%|███████   | 35/50 [07:59<03:25, 13.68s/it]

Epoch: 35 	 Validation acc: 0.83, auc:0.91, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 36 	 Training Loss: 32.035524


 72%|███████▏  | 36/50 [08:12<03:11, 13.67s/it]

Epoch: 36 	 Validation acc: 0.92, auc:0.85, precision: 0.80, recall: 0.80, f1: 0.80
Epoch: 37 	 Training Loss: 32.045492


 74%|███████▍  | 37/50 [08:26<02:57, 13.67s/it]

Epoch: 37 	 Validation acc: 0.88, auc:0.93, precision: 0.67, recall: 0.80, f1: 0.73
Epoch: 38 	 Training Loss: 31.912011


 76%|███████▌  | 38/50 [08:40<02:44, 13.68s/it]

Epoch: 38 	 Validation acc: 0.92, auc:0.95, precision: 0.71, recall: 1.00, f1: 0.83
Epoch: 39 	 Training Loss: 31.969549


 78%|███████▊  | 39/50 [08:54<02:30, 13.70s/it]

Epoch: 39 	 Validation acc: 0.96, auc:0.96, precision: 0.83, recall: 1.00, f1: 0.91
Epoch: 40 	 Training Loss: 31.941683


 80%|████████  | 40/50 [09:07<02:17, 13.71s/it]

Epoch: 40 	 Validation acc: 0.88, auc:0.95, precision: 0.62, recall: 1.00, f1: 0.77
Epoch: 41 	 Training Loss: 31.943736


 82%|████████▏ | 41/50 [09:21<02:03, 13.68s/it]

Epoch: 41 	 Validation acc: 0.92, auc:0.96, precision: 0.71, recall: 1.00, f1: 0.83
Epoch: 42 	 Training Loss: 31.862036


 84%|████████▍ | 42/50 [09:35<01:49, 13.67s/it]

Epoch: 42 	 Validation acc: 0.92, auc:0.92, precision: 0.71, recall: 1.00, f1: 0.83
Epoch: 43 	 Training Loss: 31.928226


 86%|████████▌ | 43/50 [09:48<01:35, 13.66s/it]

Epoch: 43 	 Validation acc: 0.92, auc:0.94, precision: 0.80, recall: 0.80, f1: 0.80
Epoch: 44 	 Training Loss: 31.890965


 88%|████████▊ | 44/50 [10:02<01:21, 13.66s/it]

Epoch: 44 	 Validation acc: 0.83, auc:0.95, precision: 0.56, recall: 1.00, f1: 0.71
Epoch: 45 	 Training Loss: 31.895241


 90%|█████████ | 45/50 [10:16<01:08, 13.68s/it]

Epoch: 45 	 Validation acc: 0.79, auc:0.84, precision: 0.50, recall: 0.80, f1: 0.62
Epoch: 46 	 Training Loss: 31.921272


 92%|█████████▏| 46/50 [10:29<00:54, 13.69s/it]

Epoch: 46 	 Validation acc: 0.79, auc:0.88, precision: 0.50, recall: 0.80, f1: 0.62
Epoch: 47 	 Training Loss: 31.921439


 94%|█████████▍| 47/50 [10:43<00:41, 13.71s/it]

Epoch: 47 	 Validation acc: 0.83, auc:0.94, precision: 0.57, recall: 0.80, f1: 0.67
Epoch: 48 	 Training Loss: 31.928663


 96%|█████████▌| 48/50 [10:57<00:27, 13.72s/it]

Epoch: 48 	 Validation acc: 0.75, auc:0.92, precision: 0.45, recall: 1.00, f1: 0.62
Epoch: 49 	 Training Loss: 31.799830


 98%|█████████▊| 49/50 [11:11<00:13, 13.71s/it]

Epoch: 49 	 Validation acc: 0.75, auc:0.96, precision: 0.45, recall: 1.00, f1: 0.62
Epoch: 50 	 Training Loss: 31.883884


100%|██████████| 50/50 [11:24<00:00, 13.69s/it]

Epoch: 50 	 Validation acc: 0.88, auc:0.96, precision: 0.62, recall: 1.00, f1: 0.77





In [15]:
acc, auc, precision, recall, f1score = eval_model(naive_rnn, val_loader)
print(acc, auc, precision, recall, f1_score)

0.875 0.9578947368421054 0.625 1.0 <function f1_score at 0x7fe3b2b138b0>


In [16]:
acc, auc, precision, recall, f1score = eval_model(naive_rnn, test_loader)
print(acc, auc, precision, recall, f1_score)

0.875 0.9578947368421054 0.625 1.0 <function f1_score at 0x7fe3b2b138b0>
