In [1]:
import findspark
findspark.init()
findspark.find()
import pyspark
import numpy as np

findspark.find()
# pyspark==2.4.7
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import to_timestamp
from pyspark.sql import functions as F
from pyspark.sql.functions import sum,avg,max,min,mean
from pyspark.sql.functions import row_number,lit
from pyspark.sql.window import Window

from torch.utils.data import TensorDataset, Dataset
from functools import reduce

conf = pyspark.SparkConf().setAppName('SparkApp').setMaster('local')
sc = pyspark.SparkContext(conf=conf)
spark = SparkSession(sc)

In [2]:
from torch.utils.data import TensorDataset, Dataset
from functools import reduce
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
import os
import time

In [3]:
########## read the raw data ##############
ADMISSIONS    = spark.read.option("header",True).csv('C:\Health data\ADMISSIONS.csv')
DIAGNOSES_ICD = spark.read.option("header",True).csv('C:\Health data\DIAGNOSES_ICD.csv')
PATIENTS      = spark.read.option("header",True).csv('C:\Health data\PATIENTS.csv')

In [4]:
# Select the columns we need, Admissions data, and rename the columns ###############
ADMISSIONS = ADMISSIONS.select('SUBJECT_ID','HADM_ID','ADMITTIME')
ADMISSIONS  = ADMISSIONS.withColumnRenamed("HADM_ID", "ID")\
                        .withColumnRenamed("SUBJECT_ID", "Patient_ID")

# conduct the code map for the ICD9_CODE of DIAGNOSES_ICD ###
codemap = DIAGNOSES_ICD.select('ICD9_CODE').distinct()
w = Window().orderBy(lit('A'))
codemap = codemap.withColumn("rowNum", row_number().over(w))\
                 .withColumnRenamed("ICD9_CODE", "Original_ICD9_CODE")\
                 .withColumnRenamed("rowNum", "New_code")
DIAGNOSES_ICD = DIAGNOSES_ICD.join(codemap,DIAGNOSES_ICD.ICD9_CODE==codemap.Original_ICD9_CODE,"left")\
                             .select('SUBJECT_ID','HADM_ID','New_code')

# Select the columns we need, PATIENTS data, and rename the columns ###### 
mortality = PATIENTS.select('SUBJECT_ID','EXPIRE_FLAG')\
                    .withColumnRenamed("EXPIRE_FLAG", "MORTALITY")\
                    .withColumnRenamed("SUBJECT_ID", "ID")

In [5]:
# join the ADMISSIONS, DIAGNOSES_ICD, and PATIENTS, based on the SUBJECT_ID and HADM_ID ####   
Total_inf1  = DIAGNOSES_ICD.join(mortality, DIAGNOSES_ICD.SUBJECT_ID == mortality.ID,"left")\
                           .select("SUBJECT_ID","HADM_ID","New_code","MORTALITY")
Total_inf = Total_inf1.join(ADMISSIONS, Total_inf1.HADM_ID == ADMISSIONS.ID,"left")\
                      .select('Patient_ID','HADM_ID','New_code','MORTALITY','ADMITTIME')
Total_inf = Total_inf.na.drop()

In [6]:
# create the visit sequence dataset ############
hu = Total_inf.sort(Total_inf.ADMITTIME.asc()).groupBy("Patient_ID","MORTALITY","ADMITTIME")\
    .agg(F.collect_list(F.struct("New_code")).alias("feature_value"))\
    .withColumn("feature_value", F.expr("transform(feature_value, x -> x.New_code)"))

bing = hu.sort(hu.Patient_ID.asc(),hu.ADMITTIME.asc()).groupBy("Patient_ID")\
    .agg(F.collect_list(F.struct("feature_value")).alias("feature_value"))

All_data = bing.join( mortality , bing.Patient_ID == mortality.ID, "left" )

ids    = All_data.select('Patient_ID').toPandas()
labels = All_data.select('MORTALITY').toPandas()
seqs   = All_data.select('feature_value').toPandas()

train_ids    = []
train_labels = []
train_seqs   = []

for i in range(len(ids)):
    train_ids.append(int(ids['Patient_ID'][i]))
    train_labels.append(int(labels['MORTALITY'][i]))
    gggg = np.array(seqs['feature_value'][i]).flatten().tolist()    
    train_seqs.append( gggg )

  gggg = np.array(seqs['feature_value'][i]).flatten().tolist()


In [17]:
num_features = Total_inf.select("New_code").distinct().count() + 2

In [18]:
class VisitSequenceWithLabelDataset(Dataset):
    def __init__(self, seqs, labels, num_features):
        if len(seqs) != len(labels):
            raise ValueError("Seqs and Labels have different lengths")
        self.labels = labels
        answers = []
        for x in seqs:
            a = len(x)
            b = num_features
            mtx = np.zeros((a, b))
            each_line = 0
            for ets in x:
                mtx[each_line, ets] = 1
                each_line = each_line + 1
            answers.append(mtx)
            self.seqs = answers
            
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # returns will be wrapped as List of Tensor(s) by DataLoader
        return self.seqs[index], self.labels[index]

def visit_collate_fn(batch):
    x = 0
    lines = []
    for a, b in batch:
        lines.append((a.shape[0], x))
        x = x + 1
    lines.sort(key = lambda s: s[0], reverse=True)
    line_row = lines[0][0]
    line_col = batch[0][0].shape[1]
    listOne = []
    listTwo = []
    listThree = []
    for i in list(map(lambda s: s[1], lines)):
        patient = batch[i]
        listTwo.append(patient[1])
        listThree.append(patient[0].shape[0])
        d = np.zeros((line_row, line_col))
        d[0:patient[0].shape[0], 0:patient[0].shape[1]] = patient[0]
        listOne.append(d)

    seqs_tensor = torch.FloatTensor(listOne)
    lengths_tensor = torch.LongTensor(listThree)
    labels_tensor = torch.LongTensor(listTwo)
    #print(listTwo)

    return (seqs_tensor, lengths_tensor), labels_tensor

In [20]:
train_dataset = VisitSequenceWithLabelDataset(train_seqs[0:10000], train_labels[0:10000], num_features)
train_loader  = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, collate_fn=visit_collate_fn, num_workers=0)
valid_dataset = VisitSequenceWithLabelDataset(train_seqs[25001:27000], train_labels[25001:27000], num_features)
test_dataset  = VisitSequenceWithLabelDataset(train_seqs[27001:30000], train_labels[27001:30000], num_features)
valid_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=False, collate_fn=visit_collate_fn, num_workers=0)
test_loader  = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=visit_collate_fn, num_workers=0)

In [7]:
### RNN model construction, training, and evaluation ###########
class MyRNN1(nn.Module):
    def __init__(self, dim_input):
        super(MyRNN1, self).__init__()
        self.input_layer1 = nn.Linear(in_features=dim_input, out_features=32)
        self.rnn_model    = nn.GRU(input_size=32, hidden_size=16, num_layers=1, batch_first=True)
        self.input_layer2 = nn.Linear(in_features=16, out_features=2) 

    def forward(self, input_tuple):
        seqs, lengths = input_tuple
        seqs          = torch.tanh(self.input_layer1(seqs)) 
        seqs          = pack_padded_sequence(seqs, lengths, batch_first=True) 
        seqs, h       = self.rnn_model(seqs)
        seqs, _       = pad_packed_sequence(seqs, batch_first=True)
        seqs          = self.input_layer2(seqs[:, -1, :])
        return seqs


class MyRNN2(nn.Module):
    def __init__(self, dim_input):
        super(MyRNN2, self).__init__()

        self.batch_first =True
        self.layer1 = nn.Sequential(nn.Dropout(p=0.8),nn.Linear(dim_input, 128, bias=False),nn.Dropout(p=0.5))
        self.rnn1 = nn.GRU(input_size=128, hidden_size=128, num_layers=1, batch_first=True)
        self.rnnl1 = nn.Linear(in_features=128, out_features=1)
        self.rnnl1.bias.data.zero_()
        self.rnn2 = nn.GRU(input_size=128, hidden_size=128, num_layers=1, batch_first=True)
        self.rnnl2 = nn.Linear(in_features=128, out_features=128)
        self.rnnl2.bias.data.zero_()
        self.rnno = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(in_features=128, out_features=2))
        self.rnno[1].bias.data.zero_()

    def forward(self, input_tuple):
        seqs, lengths = input_tuple
        b1, m1 = seqs.size()[:2]
        x = self.layer1(seqs)
        pi = pack_padded_sequence(x, lengths, batch_first=self.batch_first)
        a, _ = self.rnn1(pi)
        b, _ = pad_packed_sequence(a, batch_first=self.batch_first)
        c = torch.autograd.Variable(torch.FloatTensor([[1.0 if i < lengths[idx] else 0.0 for i in range(m1)] for idx in range(b1)]).unsqueeze(2), requires_grad=False)
        e = self.rnnl1(b)
        def max(x, c):
            exp = torch.exp(x)
            msp = exp * c
            sth = torch.sum(msp, dim=1, keepdim=True)
            return msp / sth
        alpha = max(e, c)
        h, _ = self.rnn2(pi)
        gps, _ = pad_packed_sequence(h, batch_first=self.batch_first)
        out = torch.tanh(self.rnnl2(gps))
        context = torch.bmm(torch.transpose(alpha, 1, 2), out * x).squeeze(1)
        rnno = self.rnno(context)
        return rnno
    
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def compute_batch_accuracy(output, target):
    """Computes the accuracy for a batch"""
    with torch.no_grad():

        batch_size = target.size(0)
        _, pred = output.max(1)
        correct = pred.eq(target).sum()

        return correct * 100.0 / batch_size


def train(model, device, data_loader, criterion, optimizer, epoch, print_freq=10):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if isinstance(input, tuple):
            input = tuple([e.to(device) if type(e) == torch.Tensor else e for e in input])
        else:
            input = input.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        assert not np.isnan(loss.item()), 'Model diverged with loss = NaN'

        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        losses.update(loss.item(), target.size(0))
        accuracy.update(compute_batch_accuracy(output, target).item(), target.size(0))

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                epoch, i, len(data_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, acc=accuracy))

    return losses.avg, accuracy.avg


def evaluate(model, device, data_loader, criterion, print_freq=10):
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracy = AverageMeter()

    results = []

    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(data_loader):

            if isinstance(input, tuple):
                input = tuple([e.to(device) if type(e) == torch.Tensor else e for e in input])
            else:
                input = input.to(device)
            target = target.to(device)

            output = model(input)
            loss = criterion(output, target)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            losses.update(loss.item(), target.size(0))
            accuracy.update(compute_batch_accuracy(output, target).item(), target.size(0))

            y_true = target.detach().to('cpu').numpy().tolist()
            y_pred = output.detach().to('cpu').max(1)[1].numpy().tolist()
            results.extend(list(zip(y_true, y_pred)))

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                    i, len(data_loader), batch_time=batch_time, loss=losses, acc=accuracy))

    return losses.avg, accuracy.avg, results

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() and USE_CUDA else "cpu")
torch.manual_seed(1)
if device.type == "cuda":
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
model = MyRNN2(num_features)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
model.to(device)
criterion.to(device)

CrossEntropyLoss()

In [None]:
best_val_acc = 0.0
train_losses, train_accuracies = [], []
valid_losses, valid_accuracies = [], []
for epoch in range(20):
    train_loss, train_accuracy = train(model, device, train_loader, criterion, optimizer, epoch)
    valid_loss, valid_accuracy, valid_results = evaluate(model, device, valid_loader, criterion)

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    train_accuracies.append(train_accuracy)
    valid_accuracies.append(valid_accuracy)

    is_best = valid_accuracy > best_val_acc  
    if is_best:
        best_val_acc = valid_accuracy
        torch.save(model, os.path.join("MyVariableRNN.pth"))

plot_learning_curves(train_losses, valid_losses, train_accuracies, valid_accuracies)

Epoch: [0][0/313]	Time 1.219 (1.219)	Data 0.786 (0.786)	Loss 0.6932 (0.6932)	Accuracy 56.250 (56.250)
Epoch: [0][10/313]	Time 1.546 (1.205)	Data 1.134 (0.877)	Loss 0.6915 (0.6922)	Accuracy 59.375 (61.080)
Epoch: [0][20/313]	Time 1.726 (1.203)	Data 1.307 (0.882)	Loss 0.6852 (0.6907)	Accuracy 71.875 (63.244)
Epoch: [0][30/313]	Time 1.010 (1.177)	Data 0.786 (0.863)	Loss 0.6862 (0.6896)	Accuracy 62.500 (63.206)
Epoch: [0][40/313]	Time 1.583 (1.171)	Data 1.164 (0.859)	Loss 0.6893 (0.6883)	Accuracy 56.250 (63.643)
Epoch: [0][50/313]	Time 1.017 (1.168)	Data 0.747 (0.858)	Loss 0.6736 (0.6869)	Accuracy 71.875 (64.093)
Epoch: [0][60/313]	Time 1.251 (1.165)	Data 0.927 (0.853)	Loss 0.6642 (0.6855)	Accuracy 75.000 (64.447)
Epoch: [0][70/313]	Time 1.105 (1.172)	Data 0.810 (0.859)	Loss 0.6600 (0.6839)	Accuracy 75.000 (64.877)
Epoch: [0][80/313]	Time 1.006 (1.146)	Data 0.742 (0.840)	Loss 0.6506 (0.6819)	Accuracy 71.875 (65.162)
Epoch: [0][90/313]	Time 1.053 (1.161)	Data 0.759 (0.851)	Loss 0.6468 (0.68

Epoch: [1][230/313]	Time 0.900 (1.191)	Data 0.627 (0.851)	Loss 0.6241 (0.6459)	Accuracy 68.750 (65.720)
Epoch: [1][240/313]	Time 0.972 (1.190)	Data 0.706 (0.850)	Loss 0.5348 (0.6452)	Accuracy 84.375 (65.845)
Epoch: [1][250/313]	Time 0.761 (1.185)	Data 0.538 (0.847)	Loss 0.7254 (0.6453)	Accuracy 53.125 (65.824)
Epoch: [1][260/313]	Time 1.277 (1.183)	Data 0.903 (0.846)	Loss 0.5763 (0.6448)	Accuracy 75.000 (65.912)
Epoch: [1][270/313]	Time 1.207 (1.189)	Data 0.847 (0.850)	Loss 0.6699 (0.6452)	Accuracy 59.375 (65.844)
Epoch: [1][280/313]	Time 1.189 (1.188)	Data 0.859 (0.849)	Loss 0.5808 (0.6452)	Accuracy 75.000 (65.825)
Epoch: [1][290/313]	Time 1.108 (1.195)	Data 0.770 (0.855)	Loss 0.5654 (0.6455)	Accuracy 78.125 (65.765)
Epoch: [1][300/313]	Time 1.401 (1.197)	Data 1.014 (0.856)	Loss 0.6065 (0.6462)	Accuracy 71.875 (65.635)
Epoch: [1][310/313]	Time 2.054 (1.214)	Data 1.413 (0.868)	Loss 0.6272 (0.6455)	Accuracy 68.750 (65.725)
Test: [0/313]	Time 0.912 (0.912)	Loss 0.6800 (0.6800)	Accuracy 5

Test: [170/313]	Time 0.752 (1.081)	Loss 0.6998 (0.6381)	Accuracy 56.250 (66.009)
Test: [180/313]	Time 1.382 (1.080)	Loss 0.5790 (0.6383)	Accuracy 75.000 (65.970)
Test: [190/313]	Time 0.704 (1.075)	Loss 0.7320 (0.6384)	Accuracy 50.000 (65.936)
Test: [200/313]	Time 1.369 (1.077)	Loss 0.5665 (0.6391)	Accuracy 78.125 (65.827)
Test: [210/313]	Time 1.198 (1.080)	Loss 0.6381 (0.6387)	Accuracy 65.625 (65.906)
Test: [220/313]	Time 0.764 (1.082)	Loss 0.7495 (0.6402)	Accuracy 46.875 (65.639)
Test: [230/313]	Time 0.807 (1.090)	Loss 0.6789 (0.6395)	Accuracy 59.375 (65.747)
Test: [240/313]	Time 1.671 (1.096)	Loss 0.6237 (0.6392)	Accuracy 68.750 (65.794)
Test: [250/313]	Time 0.892 (1.106)	Loss 0.5826 (0.6394)	Accuracy 75.000 (65.762)
Test: [260/313]	Time 0.880 (1.103)	Loss 0.5979 (0.6389)	Accuracy 71.875 (65.841)
Test: [270/313]	Time 1.276 (1.104)	Loss 0.5256 (0.6385)	Accuracy 84.375 (65.890)
Test: [280/313]	Time 1.193 (1.105)	Loss 0.6963 (0.6384)	Accuracy 56.250 (65.914)
Test: [290/313]	Time 0.866 (

Epoch: [4][110/313]	Time 1.205 (1.478)	Data 0.870 (1.069)	Loss 0.6298 (0.6309)	Accuracy 68.750 (67.427)
Epoch: [4][120/313]	Time 1.217 (1.479)	Data 0.859 (1.070)	Loss 0.7384 (0.6321)	Accuracy 53.125 (67.252)
Epoch: [4][130/313]	Time 1.085 (1.480)	Data 0.790 (1.071)	Loss 0.5786 (0.6335)	Accuracy 75.000 (67.032)
Epoch: [4][140/313]	Time 1.257 (1.471)	Data 0.912 (1.064)	Loss 0.5356 (0.6332)	Accuracy 81.250 (67.043)
Epoch: [4][150/313]	Time 1.198 (1.454)	Data 0.891 (1.052)	Loss 0.6136 (0.6342)	Accuracy 68.750 (66.929)
Epoch: [4][160/313]	Time 1.852 (1.446)	Data 1.304 (1.046)	Loss 0.6777 (0.6342)	Accuracy 59.375 (66.945)
Epoch: [4][170/313]	Time 1.083 (1.438)	Data 0.800 (1.040)	Loss 0.6258 (0.6334)	Accuracy 65.625 (67.032)
Epoch: [4][180/313]	Time 1.845 (1.434)	Data 1.358 (1.037)	Loss 0.6692 (0.6347)	Accuracy 65.625 (66.851)
Epoch: [4][190/313]	Time 0.933 (1.429)	Data 0.677 (1.034)	Loss 0.6642 (0.6359)	Accuracy 65.625 (66.721)
Epoch: [4][200/313]	Time 1.044 (1.429)	Data 0.780 (1.033)	Loss 0

Test: [20/313]	Time 1.142 (1.042)	Loss 0.6073 (0.6488)	Accuracy 68.750 (63.244)
Test: [30/313]	Time 0.649 (0.999)	Loss 0.7011 (0.6478)	Accuracy 56.250 (63.609)
Test: [40/313]	Time 0.986 (1.043)	Loss 0.6818 (0.6468)	Accuracy 59.375 (63.720)
Test: [50/313]	Time 1.627 (1.044)	Loss 0.6343 (0.6454)	Accuracy 65.625 (64.032)
Test: [60/313]	Time 1.408 (1.077)	Loss 0.6523 (0.6410)	Accuracy 62.500 (64.652)
Test: [70/313]	Time 1.139 (1.080)	Loss 0.6237 (0.6391)	Accuracy 65.625 (64.833)
Test: [80/313]	Time 1.382 (1.082)	Loss 0.7189 (0.6370)	Accuracy 53.125 (65.123)
Test: [90/313]	Time 0.649 (1.065)	Loss 0.6312 (0.6349)	Accuracy 65.625 (65.385)
Test: [100/313]	Time 1.883 (1.086)	Loss 0.5911 (0.6340)	Accuracy 71.875 (65.501)
Test: [110/313]	Time 0.977 (1.115)	Loss 0.6107 (0.6325)	Accuracy 65.625 (65.681)
Test: [120/313]	Time 2.421 (1.136)	Loss 0.5643 (0.6317)	Accuracy 75.000 (65.780)
Test: [130/313]	Time 0.736 (1.150)	Loss 0.6134 (0.6325)	Accuracy 68.750 (65.673)
Test: [140/313]	Time 0.888 (1.152)	L

Test: [310/313]	Time 1.062 (1.200)	Loss 0.6086 (0.6320)	Accuracy 68.750 (65.736)
Epoch: [7][0/313]	Time 1.640 (1.640)	Data 1.174 (1.174)	Loss 0.6419 (0.6419)	Accuracy 68.750 (68.750)
Epoch: [7][10/313]	Time 2.075 (1.571)	Data 1.526 (1.136)	Loss 0.6557 (0.6392)	Accuracy 65.625 (65.341)
Epoch: [7][20/313]	Time 1.588 (1.577)	Data 1.185 (1.147)	Loss 0.7078 (0.6497)	Accuracy 56.250 (63.988)
Epoch: [7][30/313]	Time 1.237 (1.559)	Data 0.860 (1.135)	Loss 0.6304 (0.6506)	Accuracy 68.750 (63.810)
Epoch: [7][40/313]	Time 1.672 (1.553)	Data 1.204 (1.130)	Loss 0.6492 (0.6578)	Accuracy 65.625 (62.957)
Epoch: [7][50/313]	Time 2.121 (1.552)	Data 1.541 (1.130)	Loss 0.5770 (0.6488)	Accuracy 71.875 (64.032)
Epoch: [7][60/313]	Time 1.377 (1.529)	Data 1.021 (1.114)	Loss 0.6755 (0.6465)	Accuracy 59.375 (64.242)
Epoch: [7][70/313]	Time 1.937 (1.519)	Data 1.409 (1.105)	Loss 0.6972 (0.6452)	Accuracy 56.250 (64.481)
Epoch: [7][80/313]	Time 1.156 (1.542)	Data 0.813 (1.123)	Loss 0.6390 (0.6438)	Accuracy 65.625 (6

Epoch: [8][220/313]	Time 1.322 (1.263)	Data 0.945 (0.899)	Loss 0.6229 (0.6338)	Accuracy 65.625 (66.205)
Epoch: [8][230/313]	Time 0.892 (1.259)	Data 0.626 (0.897)	Loss 0.7040 (0.6333)	Accuracy 59.375 (66.261)
Epoch: [8][240/313]	Time 1.274 (1.265)	Data 0.908 (0.901)	Loss 0.6928 (0.6337)	Accuracy 59.375 (66.196)
Epoch: [8][250/313]	Time 1.054 (1.267)	Data 0.718 (0.903)	Loss 0.7051 (0.6338)	Accuracy 56.250 (66.173)
Epoch: [8][260/313]	Time 1.114 (1.266)	Data 0.802 (0.902)	Loss 0.6503 (0.6336)	Accuracy 65.625 (66.212)
Epoch: [8][270/313]	Time 1.115 (1.264)	Data 0.775 (0.901)	Loss 0.6322 (0.6347)	Accuracy 62.500 (66.121)
Epoch: [8][280/313]	Time 0.952 (1.264)	Data 0.677 (0.900)	Loss 0.6500 (0.6344)	Accuracy 65.625 (66.125)
Epoch: [8][290/313]	Time 1.548 (1.261)	Data 1.084 (0.898)	Loss 0.5907 (0.6347)	Accuracy 75.000 (66.076)
Epoch: [8][300/313]	Time 1.176 (1.270)	Data 0.832 (0.905)	Loss 0.6690 (0.6356)	Accuracy 62.500 (65.926)
Epoch: [8][310/313]	Time 1.376 (1.268)	Data 0.961 (0.903)	Loss 0

In [None]:
best_model = torch.load(os.path.join(PATH_OUTPUT, "MyVariableRNN.pth"))
class_names = ['Alive', 'Dead']
test_loss, test_accuracy, test_results = evaluate(best_model, device, valid_loader, criterion)
plot_confusion_matrix(test_results, class_names)