In [1]:
!pip install polars numpy torch

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [2]:
import polars as pl
import pandas as pd
import os
import sys
import random
import pickle
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
# set seed
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
DATA_PATH = ""
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [3]:
import sys
obj = pickle.load(open(os.path.join(DATA_PATH, 'allevents_by_episode_48_charts'), 'rb'))
patients = obj.select(pl.col('subject_id')).to_series().to_list()
seqs = obj.select(pl.col('itemidx')).to_series().to_list()
mortality = obj.select(pl.col('mortality_tf')).to_series().to_list()

In [4]:
itemidx = set([each_item for events in seqs for items in events for each_item in items])
len(itemidx)

40119

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
# Custom dataset
from torch.utils.data import Dataset, DataLoader, random_split

class CustomDataset(Dataset):
    def __init__(self, seqs, mortality):
        self.seqs = seqs
        self.labels = mortality
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, index):
        return self.seqs[index], self.labels[index]
dataset = CustomDataset(seqs, mortality)
assert len(dataset) == len(obj) #TODO write test separately 


In [7]:
# Collate function
def collate_fn(data):
    """
    Input data: a tuple of seqs, and label
    Output: x (num_patients, num_events, num_itemids) with *real* items in the events
            rev_x (num_patients, num_events, num_itemids) with *reverse real* items in the events
            masks  (num_patients, num_events, num_itemids) whether a *real* itemid in the events is present
            rev_masks (num_patients, num_events, num_itemids) whether a *reverse real* itemid in the events is present
    """
    seqs, labels = zip(*data)
    num_patients = len(seqs)
    max_num_events = max([len(event) for event in seqs])
    max_num_items = max([len(itemid) for event in seqs for itemid in event])
    tensor_shape = (num_patients, max_num_events, max_num_items)
    x =        torch.zeros(tensor_shape, dtype=torch.long)
    rev_x =    torch.zeros(tensor_shape, dtype=torch.long)
    masks =    torch.zeros(tensor_shape, dtype=torch.bool)
    rev_masks = torch.zeros(tensor_shape, dtype=torch.bool) 
    y =        torch.tensor(labels, dtype=torch.long)
    

    for i_patient, events in enumerate(seqs):
        for i_event, item in enumerate(events):
            padded_item = torch.concat([torch.tensor(item),
                                        torch.zeros(max_num_items - len(item))]).long()
            x[i_patient, i_event, :] = padded_item
            masks[i_patient, i_event, :] = torch.where(padded_item!=0,1,0)  
    for i_patient, events in enumerate(seqs):
        idx_all_real_events = torch.sum(x[i_patient, :, :], dim=(1))!= 0
        idx_padded_events =torch.sum(x[i_patient, :, :], dim=(1))== 0
        fliped = torch.flip(x[i_patient, idx_all_real_events, :].unsqueeze(1), (0,)).squeeze(1)
        rev_x[i_patient, :, :] = torch.concat((fliped, x[i_patient, idx_padded_events, :] ))
        rev_masks[i_patient, :, :] = torch.where(rev_x[i_patient, :, :] != 0, True, False)
    return x, masks, rev_x, rev_masks, y

    

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn, pin_memory=True)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)
#assert x.shape == masks.shape == (10, 8, 313)
assert y.shape == (10,)

In [8]:
train, val, test = int(len(dataset)*0.8), int(len(dataset)*0.1), len(dataset) - int(len(dataset)*0.8) -  int(len(dataset)*0.1)
lengths = [train, val, test]
train_dataset, val_dataset, test_dataset = random_split(dataset=dataset, lengths=lengths)
from pytorch_metric_learning import samplers

def load_data(train_dataset, val_dataset, test_dataset, collate_fn):
    batch_size = 64
    seqs, labels = zip(*train_dataset)
    sampler = samplers.MPerClassSampler(labels, m=len(train_dataset)//2, batch_size=None, length_before_new_iter=1000)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=1, sampler=sampler)
    
    val_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=4, shuffle=False)
    test_loader = DataLoader(train_dataset, batch_size=batch_size, \
                                collate_fn=collate_fn, \
                                num_workers=4, shuffle=False)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset, collate_fn)
print("Length of train dataset:", len(train_dataset))
print("Length of train with mortality:", np.sum([element[-1] for element in train_dataset]))
print("Length of train with survive patients:", (len(train_dataset) - np.sum([element[-1] for element in train_dataset])))
print("Length of val dataset:", len(val_dataset))
print("Length of val with mortality:", np.sum([element[-1] for element in val_dataset]))
print("Length of val with survive patients:", (len(train_dataset) - np.sum([element[-1] for element in val_dataset])))
print("Length of test dataset:", len(test_dataset))
print("Length of test with mortality:", np.sum([element[-1] for element in test_dataset]))
print("Length of test with survive patients:", (len(test_dataset) - np.sum([element[-1] for element in test_dataset])))

Length of train dataset: 27032
Length of train with mortality: 3542
Length of train with survive patients: 23490
Length of val dataset: 3379
Length of val with mortality: 458
Length of val with survive patients: 26574
Length of test dataset: 3379
Length of test with mortality: 471
Length of test with survive patients: 2908


In [9]:
def custom_for_loop(iterable):
    iterator = iter(iterable)
    done_looping = False
    count_mor, count_sur = 0, 0
    while not done_looping:
        try:
            item = next(iterator)
        except StopIteration:
            done_looping = True
        else:
            x, masks, rev_x, rev_masks, y = item
            count_mor += y.sum().item()
            count_sur += len(y) -  y.sum().item()
    return count_mor, count_sur

custom_for_loop(train_loader)# == len(train_dataset) // 2

(13516, 13516)

In [10]:
def sum_embeddings_with_masks(x, masks):
    """
    Input:  x               (batch_size, num_events, num_itemids, embedding_dims)
            mask            (batch_size, num_events, num_itemids)
    Output: sum_embeddings  (batch_size, num_events, embedding_dims)
    The return output compress the num_itemids into embedding vectors
    """
    masks = masks.unsqueeze(-1).expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
    return torch.sum(masks * x, 2)
def get_last_event(hidden_states, masks):
    """
    hidden_states: (batch_size, #item, embedding_dim)
    masks:         (batch_size, #item, embedding_dim)
    return last_hidden_state: (batch_size, embedding_dim)
    """
    idx_last_event = torch.argmin(torch.sum(masks, 2), 1)
    return hidden_states[torch.arange(hidden_states.shape[0]), torch.where(idx_last_event - 1 < 0, max(idx_last_event), idx_last_event - 1),:]

In [11]:
ts = torch.tensor([[
    [-0.8201, 0.3956, 0.8989, -1.3884, -0.1670, 0.2851, -0.6411],
    [-0.8937, 0.9265, -0.5355, -1.1597, -0.4602, 0.7085, 1.0128],
    [ 0.2304, 1.0902, -1.5827, -0.3246, 1.9264, -0.3300, 0.1984]],
   
[[ 0.7821, 1.0391, -0.7245, -0.2093, -0.2153, -1.8157, -0.3452],
    [-2.0615, 0.6741, -1.3233, -1.3598, -0.0835, -0.0235, 0.1744],
    [ 2.2983, 0.9571, -0.6619, -0.8285, -0.6057, -1.4013, 1.2973]],

   [[ 1.6409, -1.0567, -0.2616, -0.2501, 0.5011, 0.2600, -0.1782],
    [    -0.2595, -0.0145, -0.3839, -2.9662, -1.0606, -0.3090, 0.9343],
    [ 1.6243, 0.0016, -0.4375, -2.1085, 1.1450, -0.3822, -0.3553]],

   [[ 0.7542, 0.1332, 0.1825, -0.5146, 0.8005, -0.1259, -0.9578],
    [ 1.7518, 0.9796, 0.4105, 1.7675, -0.0832, 0.5087, -0.8253],
    [ 0.1633, 0.5013, 1.4206, 1.1542, -1.5366, -0.5577, -0.4383]]])
masks = torch.tensor([[[ True,  True,  True,  True, False],
         [ True,  True,  True, False, False],
         [False, False, False, False, False]],

        [[ True, False, False, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True, False, False, False]],

        [[ True,  True, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[ True,  True,  True, False, False],
         [ True,  True,  True, False, False],
         [False, False, False, False, False]]])
assert torch.sum(get_last_event(ts, masks) == torch.tensor([
    [-0.8937, 0.9265, -0.5355, -1.1597, -0.4602, 0.7085, 1.0128],
    [ 2.2983, 0.9571, -0.6619, -0.8285, -0.6057, -1.4013, 1.2973],
    [1.6409, -1.0567, -0.2616, -0.2501, 0.5011, 0.2600, -0.1782],
    [ 1.7518, 0.9796, 0.4105, 1.7675, -0.0832, 0.5087, -0.8253],
    ])) == 7*4

In [12]:
# Naive RNN hehe :P
class NaiveRNN(nn.Module):
    def __init__(self, num_items):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_items, embedding_dim=128)
        self.rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(in_features=128*2, out_features=1)
        self.dropout = nn.Dropout(0.2)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x, masks, rev_x, rev_masks):
        batch_size = x.shape[0]
        # Forward pass with x and masks
        x = self.dropout(self.embedding(x))
        x = sum_embeddings_with_masks(x, masks)
        output, _ = self.rnn(x)
        real_event_h = get_last_event(output, masks)

        # Forward pass with rev_x and rev_masks
        rev_x = self.dropout(self.embedding(rev_x))
        rev_x = sum_embeddings_with_masks(rev_x, rev_masks)
        output_rev, _ = self.rnn(rev_x)
        real_event_h_rev = get_last_event(output_rev, rev_masks)

        # Concat both hidden states
        logits = self.dropout(self.fc(torch.cat([real_event_h, real_event_h_rev], 1)))
        probs = self.sigmoid(logits)
        return probs.view(batch_size)
naive_rnn = NaiveRNN(num_items=len(itemidx))
naive_rnn

NaiveRNN(
  (embedding): Embedding(40119, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (sigmoid): Sigmoid()
)

In [13]:
class CustomLoss(nn.Module):
    def __init__(self, num_majority, num_minority, K=0.01):
        super(CustomLoss, self).__init__()
        
        self.num_majority = num_majority
        self.num_minority = num_minority
        self.K = K
    def forward(self, y_hat, y):
        #y = torch.to('cpu').LongTensor(y)
        delta_maj, delta_min = self.K / self.num_majority**(1/4), self.K / self.num_minority**(1/4)
        zj = (y_hat > 0.5).int()
        masks = torch.where(zj != y)
        zc = y_hat
        sigma_zc_maj = torch.exp(y_hat - delta_maj) / \
                            (torch.exp(y_hat - delta_maj) + torch.sum(torch.exp(zj[masks])))

        sigma_zc_min = torch.exp(y_hat - delta_min) / \
                            (torch.exp(y_hat - delta_min) + torch.sum(torch.exp(zj[masks])))

        loss = (- torch.log(sigma_zc_maj) - torch.log(sigma_zc_min)).mean()
        return loss
minority = np.sum([element[-1] for element in train_dataset]) / len(train_dataset)
majority = 1 - minority
(majority, minority)


(0.8689701094998521, 0.13102989050014796)

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=1e-3, weight_decay=1e-2)
from tqdm import tqdm
from sklearn.metrics import *
def train(model, train_loader, val_loader, n_epochs):
    for epoch in tqdm(range(n_epochs)):
        model.to(device).train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            x, masks, rev_x, rev_masks, y = x.to(device), masks.to(device), rev_x.to(device), rev_masks.to(device), y.to(device)
            y_hat = model(x, masks, rev_x, rev_masks).view(y.shape[0])
            loss = criterion(y_hat, y.float())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
        acc, auc, precision, recall, f1score = eval_model(model, val_loader)
        print('Epoch: {} \t Validation acc: {:.2f}, auc:{:.2f}, precision: {:.2f}, recall: {:.2f}, f1: {:.2f}' 
              .format(epoch+1, acc, auc, precision, recall, f1score))    
def classification_metrics(Y_score, Y_pred, Y_true):
    acc, auc, precision, recall, f1score = accuracy_score(Y_true, Y_pred), \
                                           roc_auc_score(Y_true, Y_score), \
                                           precision_score(Y_true, Y_pred), \
                                           recall_score(Y_true, Y_pred), \
                                           f1_score(Y_true, Y_pred)
    return acc, auc, precision, recall, f1score



#precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
def eval_model(model, val_loader):
    model.to(device).eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    for x, masks, rev_x, rev_masks, y in val_loader:
        x, masks, rev_x, rev_masks, y = x.to(device), masks.to(device), rev_x.to(device), rev_masks.to(device), y.to(device)
        y_true = y.long().detach().to('cpu')
        y_hat = model(x, masks, rev_x, rev_masks).view(y.shape[0])
        y_score = y_hat.detach().to('cpu')
        y_pred = (y_hat > 0.5).int().detach().to('cpu')
    acc, auc, precision, recall, f1score = classification_metrics(y_score, y_pred, y_true)
    return acc, auc, precision, recall, f1score


In [15]:
# number of epochs to train the model
n_epochs = 50
train(naive_rnn, train_loader, val_loader, n_epochs)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 1 	 Training Loss: 3513.991530


  2%|▏         | 1/50 [00:15<12:45, 15.62s/it]

Epoch: 1 	 Validation acc: 0.50, auc:0.65, precision: 0.27, recall: 0.80, f1: 0.40
Epoch: 2 	 Training Loss: 3511.757812


  4%|▍         | 2/50 [00:30<12:04, 15.09s/it]

Epoch: 2 	 Validation acc: 0.42, auc:0.37, precision: 0.15, recall: 0.40, f1: 0.22
Epoch: 3 	 Training Loss: 3511.747199


  6%|▌         | 3/50 [00:45<11:41, 14.92s/it]

Epoch: 3 	 Validation acc: 0.58, auc:0.71, precision: 0.27, recall: 0.60, f1: 0.37
Epoch: 4 	 Training Loss: 3513.216491


  _warn_prf(average, modifier, msg_start, len(result))
  8%|▊         | 4/50 [00:59<11:23, 14.86s/it]

Epoch: 4 	 Validation acc: 0.79, auc:0.44, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 5 	 Training Loss: 3513.216601


  _warn_prf(average, modifier, msg_start, len(result))
 10%|█         | 5/50 [01:14<11:07, 14.83s/it]

Epoch: 5 	 Validation acc: 0.79, auc:0.39, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 6 	 Training Loss: 3511.746359


 12%|█▏        | 6/50 [01:29<10:50, 14.79s/it]

Epoch: 6 	 Validation acc: 0.58, auc:0.34, precision: 0.14, recall: 0.20, f1: 0.17
Epoch: 7 	 Training Loss: 3511.745385


 14%|█▍        | 7/50 [01:44<10:35, 14.77s/it]

Epoch: 7 	 Validation acc: 0.50, auc:0.56, precision: 0.23, recall: 0.60, f1: 0.33
Epoch: 8 	 Training Loss: 3511.745242


 16%|█▌        | 8/50 [01:58<10:20, 14.76s/it]

Epoch: 8 	 Validation acc: 0.50, auc:0.51, precision: 0.18, recall: 0.40, f1: 0.25
Epoch: 9 	 Training Loss: 3511.745248


 18%|█▊        | 9/50 [02:13<10:05, 14.76s/it]

Epoch: 9 	 Validation acc: 0.67, auc:0.25, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 10 	 Training Loss: 3513.216494


 20%|██        | 10/50 [02:28<09:50, 14.77s/it]

Epoch: 10 	 Validation acc: 0.21, auc:0.70, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 11 	 Training Loss: 3511.745518


  _warn_prf(average, modifier, msg_start, len(result))
 22%|██▏       | 11/50 [02:43<09:35, 14.77s/it]

Epoch: 11 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 12 	 Training Loss: 3511.745260


  _warn_prf(average, modifier, msg_start, len(result))
 24%|██▍       | 12/50 [02:57<09:20, 14.76s/it]

Epoch: 12 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 13 	 Training Loss: 3511.745294


 26%|██▌       | 13/50 [03:12<09:05, 14.75s/it]

Epoch: 13 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 14 	 Training Loss: 3513.216494


 28%|██▊       | 14/50 [03:27<08:50, 14.73s/it]

Epoch: 14 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 15 	 Training Loss: 3513.216620


  _warn_prf(average, modifier, msg_start, len(result))
 30%|███       | 15/50 [03:41<08:35, 14.72s/it]

Epoch: 15 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 16 	 Training Loss: 3513.216479


 32%|███▏      | 16/50 [03:56<08:21, 14.74s/it]

Epoch: 16 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 17 	 Training Loss: 3511.745720


  _warn_prf(average, modifier, msg_start, len(result))
 34%|███▍      | 17/50 [04:11<08:06, 14.75s/it]

Epoch: 17 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 18 	 Training Loss: 3511.745300


  _warn_prf(average, modifier, msg_start, len(result))
 36%|███▌      | 18/50 [04:26<07:51, 14.74s/it]

Epoch: 18 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 19 	 Training Loss: 3511.745271


  _warn_prf(average, modifier, msg_start, len(result))
 38%|███▊      | 19/50 [04:40<07:36, 14.73s/it]

Epoch: 19 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 20 	 Training Loss: 3511.745310


 40%|████      | 20/50 [04:55<07:21, 14.72s/it]

Epoch: 20 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 21 	 Training Loss: 3511.745340


  _warn_prf(average, modifier, msg_start, len(result))
 42%|████▏     | 21/50 [05:10<07:06, 14.72s/it]

Epoch: 21 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 22 	 Training Loss: 3511.745290


  _warn_prf(average, modifier, msg_start, len(result))
 44%|████▍     | 22/50 [05:25<06:52, 14.72s/it]

Epoch: 22 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 23 	 Training Loss: 3513.216494


  _warn_prf(average, modifier, msg_start, len(result))
 46%|████▌     | 23/50 [05:39<06:37, 14.74s/it]

Epoch: 23 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 24 	 Training Loss: 3511.745737


  _warn_prf(average, modifier, msg_start, len(result))
 48%|████▊     | 24/50 [05:54<06:23, 14.73s/it]

Epoch: 24 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 25 	 Training Loss: 3511.745250


  _warn_prf(average, modifier, msg_start, len(result))
 50%|█████     | 25/50 [06:09<06:08, 14.72s/it]

Epoch: 25 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 26 	 Training Loss: 3511.745468


  _warn_prf(average, modifier, msg_start, len(result))
 52%|█████▏    | 26/50 [06:23<05:53, 14.72s/it]

Epoch: 26 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 27 	 Training Loss: 3511.745275


  _warn_prf(average, modifier, msg_start, len(result))
 54%|█████▍    | 27/50 [06:38<05:38, 14.71s/it]

Epoch: 27 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 28 	 Training Loss: 3513.216494


 56%|█████▌    | 28/50 [06:53<05:23, 14.71s/it]

Epoch: 28 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 29 	 Training Loss: 3513.216485


  _warn_prf(average, modifier, msg_start, len(result))
 58%|█████▊    | 29/50 [07:08<05:08, 14.70s/it]

Epoch: 29 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 30 	 Training Loss: 3513.216727


 60%|██████    | 30/50 [07:22<04:54, 14.73s/it]

Epoch: 30 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 31 	 Training Loss: 3511.745482


  _warn_prf(average, modifier, msg_start, len(result))
 62%|██████▏   | 31/50 [07:37<04:39, 14.73s/it]

Epoch: 31 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 32 	 Training Loss: 3513.216494


 64%|██████▍   | 32/50 [07:52<04:25, 14.73s/it]

Epoch: 32 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 33 	 Training Loss: 3513.216513


 66%|██████▌   | 33/50 [08:07<04:10, 14.73s/it]

Epoch: 33 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 34 	 Training Loss: 3511.745555


  _warn_prf(average, modifier, msg_start, len(result))
 68%|██████▊   | 34/50 [08:21<03:55, 14.74s/it]

Epoch: 34 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 35 	 Training Loss: 3513.216491


 70%|███████   | 35/50 [08:36<03:40, 14.73s/it]

Epoch: 35 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 36 	 Training Loss: 3513.216819


  _warn_prf(average, modifier, msg_start, len(result))
 72%|███████▏  | 36/50 [08:51<03:26, 14.74s/it]

Epoch: 36 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 37 	 Training Loss: 3513.216500


 74%|███████▍  | 37/50 [09:06<03:11, 14.74s/it]

Epoch: 37 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 38 	 Training Loss: 3511.745451


  _warn_prf(average, modifier, msg_start, len(result))
 76%|███████▌  | 38/50 [09:20<02:56, 14.74s/it]

Epoch: 38 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 39 	 Training Loss: 3513.216494


  _warn_prf(average, modifier, msg_start, len(result))
 78%|███████▊  | 39/50 [09:35<02:42, 14.75s/it]

Epoch: 39 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 40 	 Training Loss: 3513.216557


 80%|████████  | 40/50 [09:50<02:27, 14.75s/it]

Epoch: 40 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 41 	 Training Loss: 3513.216614


  _warn_prf(average, modifier, msg_start, len(result))
 82%|████████▏ | 41/50 [10:05<02:12, 14.74s/it]

Epoch: 41 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 42 	 Training Loss: 3513.216496


 84%|████████▍ | 42/50 [10:19<01:57, 14.74s/it]

Epoch: 42 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 43 	 Training Loss: 3513.216508


 86%|████████▌ | 43/50 [10:34<01:43, 14.75s/it]

Epoch: 43 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 44 	 Training Loss: 3511.745632


  _warn_prf(average, modifier, msg_start, len(result))
 88%|████████▊ | 44/50 [10:49<01:28, 14.75s/it]

Epoch: 44 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 45 	 Training Loss: 3513.216494


  _warn_prf(average, modifier, msg_start, len(result))
 90%|█████████ | 45/50 [11:04<01:13, 14.75s/it]

Epoch: 45 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 46 	 Training Loss: 3513.216658


  _warn_prf(average, modifier, msg_start, len(result))
 92%|█████████▏| 46/50 [11:18<00:59, 14.75s/it]

Epoch: 46 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 47 	 Training Loss: 3511.745255


  _warn_prf(average, modifier, msg_start, len(result))
 94%|█████████▍| 47/50 [11:33<00:44, 14.74s/it]

Epoch: 47 	 Validation acc: 0.79, auc:0.50, precision: 0.00, recall: 0.00, f1: 0.00
Epoch: 48 	 Training Loss: 3513.216494


 96%|█████████▌| 48/50 [11:48<00:29, 14.75s/it]

Epoch: 48 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 49 	 Training Loss: 3513.216542


 98%|█████████▊| 49/50 [12:03<00:14, 14.75s/it]

Epoch: 49 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34
Epoch: 50 	 Training Loss: 3511.745287


100%|██████████| 50/50 [12:17<00:00, 14.76s/it]

Epoch: 50 	 Validation acc: 0.21, auc:0.50, precision: 0.21, recall: 1.00, f1: 0.34





In [16]:
acc, auc, precision, recall, f1score = eval_model(naive_rnn, val_loader)
print(acc, auc, precision, recall, f1_score)

0.20833333333333334 0.5 0.20833333333333334 1.0 <function f1_score at 0x7ff7be66cf70>


In [17]:
acc, auc, precision, recall, f1score = eval_model(naive_rnn, test_loader)
print(acc, auc, precision, recall, f1_score)

0.20833333333333334 0.5 0.20833333333333334 1.0 <function f1_score at 0x7ff7be66cf70>
