In [1]:
import numpy as np
import argparse
import os
import imp
import re
import pickle
import datetime
import random
import math
import copy


import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F


from utils import utils
from utils.readers import InHospitalMortalityReader
from utils.preprocessing import Discretizer, Normalizer
from utils import metrics
from utils import common_utils

### Prepare

In [2]:
data_path = './data/'
file_name = './model/concare0'
small_part = False
arg_timestep = 1.0
batch_size = 256
epochs = 100

In [3]:
# Build readers, discretizers, normalizers
train_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                         listfile=os.path.join(data_path, 'train_listfile.csv'),
                                         period_length=48.0)

val_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'train'),
                                       listfile=os.path.join(data_path, 'val_listfile.csv'),
                                       period_length=48.0)

discretizer = Discretizer(timestep=arg_timestep,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

In [4]:
discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'ihm_normalizer'
normalizer_state = os.path.join(os.path.dirname(data_path), normalizer_state)
normalizer.load_params(normalizer_state)

In [5]:
n_trained_chunks = 0
train_raw = utils.load_data(train_reader, discretizer, normalizer, small_part, return_names=True)
val_raw = utils.load_data(val_reader, discretizer, normalizer, small_part, return_names=True)

In [6]:
demographic_data = []
diagnosis_data = []
idx_list = []

demo_path = data_path + 'demographic/'
for cur_name in os.listdir(demo_path):
    cur_id, cur_episode = cur_name.split('_', 1)
    cur_episode = cur_episode[:-4]
    cur_file = demo_path + cur_name

    with open(cur_file, "r") as tsfile:
        header = tsfile.readline().strip().split(',')
        if header[0] != "Icustay":
            continue
        cur_data = tsfile.readline().strip().split(',')
        
    if len(cur_data) == 1:
        cur_demo = np.zeros(12)
        cur_diag = np.zeros(128)
    else:
        if cur_data[3] == '':
            cur_data[3] = 60.0
        if cur_data[4] == '':
            cur_data[4] = 160
        if cur_data[5] == '':
            cur_data[5] = 60

        cur_demo = np.zeros(12)
        cur_demo[int(cur_data[1])] = 1
        cur_demo[5 + int(cur_data[2])] = 1
        cur_demo[9:] = cur_data[3:6]
        cur_diag = np.array(cur_data[8:], dtype=np.int)

    demographic_data.append(cur_demo)
    diagnosis_data.append(cur_diag)
    idx_list.append(cur_id+'_'+cur_episode)

for each_idx in range(9,12):
    cur_val = []
    for i in range(len(demographic_data)):
        cur_val.append(demographic_data[i][each_idx])
    cur_val = np.array(cur_val)
    _mean = np.mean(cur_val)
    _std = np.std(cur_val)
    _std = _std if _std > 1e-7 else 1e-7
    for i in range(len(demographic_data)):
        demographic_data[i][each_idx] = (demographic_data[i][each_idx] - _mean) / _std

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() == True else 'cpu')
#device = torch.device('cpu')
print("available device: {}".format(device))

available device: cuda:0


### model

In [8]:
class SingleAttention(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()
        
        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)
        
        if attention_type == 'add':
            if self.time_aware == True:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            if self.time_aware == True:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim+1, attention_hidden_dim))
            else:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))
            
            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
            
        elif attention_type == 'new':
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))

            self.rate = nn.Parameter(torch.zeros(1)+0.8)
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            
        else:
            raise RuntimeError('Wrong attention type.')
        
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
    
    def forward(self, input, demo=None):
 
        batch_size, time_step, input_dim = input.size() # batch_size * time_step * hidden_dim(i)
        #assert(input_dim == self.input_dim)

        # time_decays = torch.zeros((time_step,time_step)).to(device)# t*t
        # for this_time in range(time_step):
        #     for pre_time in range(time_step):
        #         if pre_time > this_time:
        #             break
        #         time_decays[this_time][pre_time] = torch.tensor(this_time - pre_time, dtype=torch.float32).to(device)
        # b_time_decays = tile(time_decays, 0, batch_size).view(batch_size,time_step,time_step).unsqueeze(-1).to(device)# b t t 1

        time_decays = torch.tensor(range(47,-1,-1), dtype=torch.float32).unsqueeze(-1).unsqueeze(0).to(device)# 1*t*1
        b_time_decays = time_decays.repeat(batch_size,1,1)+1# b t 1
        
        if self.attention_type == 'add': #B*T*I  @ H*I
            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            if self.time_aware == True:
                # k_input = torch.cat((input, time), dim=-1)
                k = torch.matmul(input, self.Wx)#b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)#  b t h
            else:
                k = torch.matmul(input, self.Wx)# b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) #B*1*T*H
            if self.use_demographic == True:
                d = torch.matmul(demo, self.Wd) #B*H
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim)) # b 1 h
            h = q + k + self.bh # b t h
            if self.time_aware == True:
                h += time_hidden
            h = self.tanh(h) #B*T*H
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step))# b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:,-1,:], self.Wa)#b i
            e = torch.matmul(e.unsqueeze(1), input.permute(0,2,1)).squeeze() + self.ba #b t
        elif self.attention_type == 'concat':
            q = input[:,-1,:].unsqueeze(1).repeat(1,time_step,1)# b t i
            k = input
            c = torch.cat((q, k), dim=-1) #B*T*2I
            if self.time_aware == True:
                c = torch.cat((c, b_time_decays), dim=-1) #B*T*2I+1
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t 
            
        elif self.attention_type == 'new':
            
            q = torch.matmul(input[:,-1,:], self.Wt)# b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            k = torch.matmul(input, self.Wx)#b t h
            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze() # b t
            denominator =  self.sigmoid(self.rate) * (torch.log(2.72 +  (1-self.sigmoid(dot_product)))* (b_time_decays.squeeze()))
            e = self.relu(self.sigmoid(dot_product)/(denominator)) # b * t
#          * (b_time_decays.squeeze())
        # e = torch.exp(e - torch.max(e, dim=-1, keepdim=True).values)
        
        # if self.attention_width is not None:
        #     if self.history_only:
        #         lower = torch.arange(0, time_step).to(device) - (self.attention_width - 1)
        #     else:
        #         lower = torch.arange(0, time_step).to(device) - self.attention_width // 2
        #     lower = lower.unsqueeze(-1)
        #     upper = lower + self.attention_width
        #     indices = torch.arange(0, time_step).unsqueeze(0).to(device)
        #     e = e * (lower <= indices).float() * (indices < upper).float()
        
        # s = torch.sum(e, dim=-1, keepdim=True)
        # mask = subsequent_mask(time_step).to(device) # 1 t t 下三角
        # scores = e.masked_fill(mask == 0, -1e9)# b t t 下三角
        a = self.softmax(e) #B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze() #B*I

        return v, a

class FinalAttentionQKV(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()
        
        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim


        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))
        
        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        
        self.dropout = nn.Dropout(p=dropout)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, input):
 
        batch_size, time_step, input_dim = input.size() # batch_size * input_dim + 1 * hidden_dim(i)
        input_q = self.W_q(input[:, -1, :]) # b h
        input_k = self.W_k(input)# b t h
        input_v = self.W_v(input)# b t h

        if self.attention_type == 'add': #B*T*I  @ H*I

            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim)) #B*1*H
            h = q + input_k + self.b_in # b t h
            h = self.tanh(h) #B*T*H
            e = self.W_out(h) # b t 1
            e = torch.reshape(e, (batch_size, time_step))# b t

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1)) #B*h 1
            e = torch.matmul(input_k, q).squeeze()#b t
            
        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1,time_step,1)# b t h
            k = input_k
            c = torch.cat((q, k), dim=-1) #B*T*2I
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba #B*T*1
            e = torch.reshape(e, (batch_size, time_step)) # b t 
        
        a = self.softmax(e) #B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze() #B*I

        return v, a

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
    return torch.index_select(a, dim, order_index).to(device)

class PositionwiseFeedForward(nn.Module): # new added
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None

# class PositionwiseFeedForwardConv(nn.Module):

#     def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
#         super(PositionalWiseFeedForward, self).__init__()
#         self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.w2 = nn.Conv1d(model_dim, ffn_dim, 1)
#         self.dropout = nn.Dropout(dropout)
#         self.layer_norm = nn.LayerNorm(model_dim)

#     def forward(self, x):
#         output = x.transpose(1, 2)
#         output = self.w2(F.relu(self.w1(output)))
#         output = self.dropout(output.transpose(1, 2))

#         # add residual and norm layer
#         output = self.layer_norm(x + output)
#         return output

class PositionalEncoding(nn.Module): # new added / not use anymore
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0 # 下三角矩阵

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)# b h t d_k
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k) # b h t t
    if mask is not None:# 1 1 t t
        scores = scores.masked_fill(mask == 0, -1e9)# b h t t 下三角
    p_attn = F.softmax(scores, dim = -1)# b h t t
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn # b h t v (d_k) 
    
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1) # 1 1 t t

        nbatches = query.size(0)# b
        input_dim = query.size(1)# i+1
        feature_dim = query.size(-1)# i+1

        #input size -> # batch_size * d_input * hidden_dim
        
        # d_model => h * d_k 
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))] # b num_head d_input d_k
        
       
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)# b num_head d_input d_v (d_k) 

      
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)# batch_size * d_input * hidden_dim

        #DeCov 
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2) # I+1 H B
        Covs = cov(DeCov_contexts[0,:,:])
        DeCov_loss = 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 ) 
        for i in range(feature_dim -1 + 1):
            Covs = cov(DeCov_contexts[i+1,:,:])
            DeCov_loss += 0.5 * (torch.norm(Covs, p = 'fro')**2 - torch.norm(torch.diag(Covs))**2 ) 


        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

def cov(m, y=None):
    if y is not None:
        m = torch.cat((m, y), dim=0)
    m_exp = torch.mean(m, dim=1)
    x = m - m_exp[:, None]
    cov = 1 / (x.size(1) - 1) * x.mm(x.t())
    return cov

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]) , returned_value[1]

class ConCare(nn.Module):
    def __init__(self, input_dim, hidden_dim, d_model,  MHD_num_head, d_ff, output_dim, keep_prob=0.5):
        super(ConCare, self).__init__()

        # hyperparameters
        self.input_dim = input_dim  
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        self.output_dim = output_dim
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout = 0, max_len = 400)

        self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first = True), self.input_dim)
        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='new', demographic_dim=12, time_aware=True, use_demographic=False),self.input_dim)
        
        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul',dropout = 1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model,dropout = 1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout = 1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(12, self.hidden_dim)
        self.demo_proj = nn.Linear(12, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p = 1 - self.keep_prob)
        self.tanh=nn.Tanh()
        self.softmax = nn.Softmax()
        self.sigmoid = nn.Sigmoid()
        self.relu=nn.ReLU()

    def forward(self, input, demo_input):
        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(1)# b hidden_dim
        
        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert(feature_dim == self.input_dim)# input Tensor : 256 * 48 * 76
        assert(self.d_model % self.MHD_num_head == 0)

        # Initialization
        #cur_hs = Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0))

        # forward
        GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b t h
        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input)[0].unsqueeze(1)# b 1 h
        for i in range(feature_dim-1):
            embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0] # b 1 h
            embeded_input = self.LastStepAttentions[i+1](embeded_input)[0].unsqueeze(1)# b 1 h
            Attention_embeded_input = torch.cat((Attention_embeded_input, embeded_input), 1)# b i h

        Attention_embeded_input = torch.cat((Attention_embeded_input, demo_main), 1)# b i+1 h
        posi_input = self.dropout(Attention_embeded_input) # batch_size * d_input+1 * hidden_dim

#         GRU_embeded_input = self.GRUs[0](input[:,:,0].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
#         for i in range(feature_dim-1):
#             embeded_input = self.GRUs[i+1](input[:,:,i+1].unsqueeze(-1), Variable(torch.zeros(batch_size, self.hidden_dim).unsqueeze(0)).to(device))[0][:,-1,:].unsqueeze(1) # b 1 h
#             GRU_embeded_input = torch.cat((GRU_embeded_input, embeded_input), 1)

#         GRU_embeded_input = torch.cat((GRU_embeded_input, demo_main), 1)# b i+1 h
#         posi_input = self.dropout(GRU_embeded_input) # batch_size * d_input * hidden_dim


        #mask = subsequent_mask(time_step).to(device) # 1 t t 下三角 N to 1任务不用mask
        contexts = self.SublayerConnection(posi_input, lambda x: self.MultiHeadedAttention(posi_input, posi_input, posi_input, None))# # batch_size * d_input * hidden_dim
    
        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(contexts, lambda x: self.PositionwiseFeedForward(contexts))[0]# # batch_size * d_input * hidden_dim
        #contexts = contexts.view(batch_size, feature_dim * self.hidden_dim)#
        # contexts = torch.matmul(self.Wproj, contexts) + self.bproj
        # contexts = contexts.squeeze()
        # demo_key = self.demo_proj(demo_input)# b hidden_dim
        # demo_key = self.relu(demo_key)
        # input_dim_scores = torch.matmul(contexts, demo_key.unsqueeze(-1)).squeeze() # b i
        # input_dim_scores = self.dropout(self.sigmoid(input_dim_scores)).unsqueeze(1)# b i
        
        # weighted_contexts = torch.matmul(input_dim_scores, contexts).squeeze()

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        output = self.output1(self.relu(self.output0(weighted_contexts)))# b 1
        output = self.sigmoid(output)
          
        return output, DeCov_loss
    #, self.MultiHeadedAttention.attn



In [9]:
def get_loss(y_pred, y_true):
    loss = torch.nn.BCELoss()
    return loss(y_pred, y_true)

In [10]:
class Dataset(data.Dataset):
    def __init__(self, x, y, name):
        self.x = x
        self.y = y
        self.name = name

    def __getitem__(self, index):#返回的是tensor
        return self.x[index], self.y[index], self.name[index]

    def __len__(self):
        return len(self.x)

In [11]:
train_dataset = Dataset(train_raw['data'][0], train_raw['data'][1], train_raw['names'])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = Dataset(val_raw['data'][0], val_raw['data'][1], val_raw['names'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

### Run

In [12]:
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED) #numpy
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED) # cpu
torch.cuda.manual_seed(RANDOM_SEED) #gpu
torch.backends.cudnn.deterministic=True # cudnn

model = ConCare(input_dim = 76, hidden_dim = 64, d_model = 64,  MHD_num_head = 4 , d_ff = 256, output_dim = 1).to(device)
# input_dim, d_model, d_k, d_v, MHD_num_head, d_ff, output_dim
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_roc = 0
max_prc = 0
train_loss = []
train_model_loss = []
train_decov_loss = []
valid_loss = []
valid_model_loss = []
valid_decov_loss = []
history = []
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)

for each_epoch in range(100):
    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []

    model.train()
 
    for step, (batch_x, batch_y, batch_name) in enumerate(train_loader):   
        optimizer.zero_grad()
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)

        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)
        
        batch_demo = torch.stack(batch_demo).to(device)
        output, decov_loss = model(batch_x, batch_demo)
        
        
        model_loss = get_loss(output, batch_y.unsqueeze(-1))
        loss = model_loss + 800* decov_loss
        
        batch_loss.append(loss.cpu().detach().numpy())
        model_batch_loss.append(model_loss.cpu().detach().numpy())
        decov_batch_loss.append(decov_loss.cpu().detach().numpy())
        loss.backward()
        optimizer.step()
        
        if step % 30 == 0:
            print('Epoch %d Batch %d: Train Loss = %.4f'%(each_epoch, step, np.mean(np.array(batch_loss))))
            print('Model Loss = %.4f, Decov Loss = %.4f'%(np.mean(np.array(model_batch_loss)), np.mean(np.array(decov_batch_loss))))
    train_loss.append(np.mean(np.array(batch_loss)))
    train_model_loss.append(np.mean(np.array(model_batch_loss)))
    train_decov_loss.append(np.mean(np.array(decov_batch_loss)))
    
    batch_loss = []
    model_batch_loss = []
    decov_batch_loss = []
    
    y_true = []
    y_pred = []
    with torch.no_grad():
        model.eval()
        for step, (batch_x, batch_y, batch_name) in enumerate(valid_loader):
            batch_x = batch_x.float().to(device)
            batch_y = batch_y.float().to(device)
            batch_demo = []
            for i in range(len(batch_name)):
                cur_id, cur_ep, _ = batch_name[i].split('_', 2)
                cur_idx = cur_id + '_' + cur_ep
                cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
                batch_demo.append(cur_demo)

            batch_demo = torch.stack(batch_demo).to(device)
            output,decov_loss = model(batch_x, batch_demo)
            
            model_loss = get_loss(output, batch_y.unsqueeze(-1))

            loss = model_loss + 10* decov_loss
            batch_loss.append(loss.cpu().detach().numpy())
            model_batch_loss.append(model_loss.cpu().detach().numpy())
            decov_batch_loss.append(decov_loss.cpu().detach().numpy())
            y_pred += list(output.cpu().detach().numpy().flatten())
            y_true += list(batch_y.cpu().numpy().flatten())
            
    valid_loss.append(np.mean(np.array(batch_loss)))
    valid_model_loss.append(np.mean(np.array(model_batch_loss)))
    valid_decov_loss.append(np.mean(np.array(decov_batch_loss)))
    
    print("\n==>Predicting on validation")
    print('Valid Loss = %.4f'%(valid_loss[-1]))
    print('valid_model Loss = %.4f'%(valid_model_loss[-1]))
    print('valid_decov Loss = %.4f'%(valid_decov_loss[-1]))
    y_pred = np.array(y_pred)
    y_pred = np.stack([1 - y_pred, y_pred], axis=1)
    ret = metrics.print_metrics_binary(y_true, y_pred)
    history.append(ret)
    print()

    cur_auroc = ret['auroc']
    
    if cur_auroc > max_roc:
        max_roc = cur_auroc
        state = {
            'net': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': each_epoch
        }
        torch.save(state, file_name)
        print('\n------------ Save best model ------------\n')



Epoch 0 Batch 0: Train Loss = 0.9343
Model Loss = 0.7072, Decov Loss = 0.0003
Epoch 0 Batch 30: Train Loss = 0.5345
Model Loss = 0.4976, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3934
valid_model Loss = 0.3934
valid_decov Loss = 0.0000
confusion matrix:
[[2786    0]
 [ 436    0]]
accuracy = 0.8646803498268127
precision class 0 = 0.8646803498268127
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7543364759577705
AUC of PRC = 0.33113012079018
min(+P, Se) = 0.3775280898876405
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1]) # TP/TP+FP = PPV = precision


Epoch 1 Batch 0: Train Loss = 0.4347
Model Loss = 0.4305, Decov Loss = 0.0000
Epoch 1 Batch 30: Train Loss = 0.4032
Model Loss = 0.4000, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3907
valid_model Loss = 0.3907
valid_decov Loss = 0.0000
confusion matrix:
[[2786    0]
 [ 436    0]]
accuracy = 0.8646803498268127
precision class 0 = 0.8646803498268127
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7708039707054276
AUC of PRC = 0.38223636256297006
min(+P, Se) = 0.4105504587155963
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1]) # TP/TP+FP = PPV = precision


Epoch 2 Batch 0: Train Loss = 0.4378
Model Loss = 0.4354, Decov Loss = 0.0000
Epoch 2 Batch 30: Train Loss = 0.3942
Model Loss = 0.3917, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3520
valid_model Loss = 0.3520
valid_decov Loss = 0.0000
confusion matrix:
[[2786    0]
 [ 436    0]]
accuracy = 0.8646803498268127
precision class 0 = 0.8646803498268127
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.7692031586503947
AUC of PRC = 0.3623708077025312
min(+P, Se) = 0.39635535307517084
f1_score = nan



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1]) # TP/TP+FP = PPV = precision


Epoch 3 Batch 0: Train Loss = 0.3457
Model Loss = 0.3436, Decov Loss = 0.0000
Epoch 3 Batch 30: Train Loss = 0.3671
Model Loss = 0.3650, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3267
valid_model Loss = 0.3267
valid_decov Loss = 0.0000
confusion matrix:
[[2786    0]
 [ 436    0]]
accuracy = 0.8646803498268127
precision class 0 = 0.8646803498268127
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.798300150819629
AUC of PRC = 0.40593964609665545
min(+P, Se) = 0.4429223744292237
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1]) # TP/TP+FP = PPV = precision


Epoch 4 Batch 0: Train Loss = 0.3183
Model Loss = 0.3166, Decov Loss = 0.0000
Epoch 4 Batch 30: Train Loss = 0.3453
Model Loss = 0.3435, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3241
valid_model Loss = 0.3241
valid_decov Loss = 0.0000
confusion matrix:
[[2786    0]
 [ 436    0]]
accuracy = 0.8646803498268127
precision class 0 = 0.8646803498268127
precision class 1 = nan
recall class 0 = 1.0
recall class 1 = 0.0
AUC of ROC = 0.8121562926032522
AUC of PRC = 0.43459985119226
min(+P, Se) = 0.46788990825688076
f1_score = nan


------------ Save best model ------------



  prec1 = cf[1][1] / (cf[1][1] + cf[0][1]) # TP/TP+FP = PPV = precision


Epoch 5 Batch 0: Train Loss = 0.3866
Model Loss = 0.3851, Decov Loss = 0.0000
Epoch 5 Batch 30: Train Loss = 0.3339
Model Loss = 0.3323, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3141
valid_model Loss = 0.3141
valid_decov Loss = 0.0000
confusion matrix:
[[2709   77]
 [ 340   96]]
accuracy = 0.8705772757530212
precision class 0 = 0.8884880542755127
precision class 1 = 0.5549132823944092
recall class 0 = 0.9723618030548096
recall class 1 = 0.22018349170684814
AUC of ROC = 0.8197886549391782
AUC of PRC = 0.42995385674405207
min(+P, Se) = 0.4541284403669725
f1_score = 0.31527093956437374


------------ Save best model ------------





Epoch 6 Batch 0: Train Loss = 0.3610
Model Loss = 0.3593, Decov Loss = 0.0000
Epoch 6 Batch 30: Train Loss = 0.3289
Model Loss = 0.3272, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3109
valid_model Loss = 0.3109
valid_decov Loss = 0.0000
confusion matrix:
[[2763   23]
 [ 384   52]]
accuracy = 0.8736809492111206
precision class 0 = 0.877979040145874
precision class 1 = 0.6933333277702332
recall class 0 = 0.9917444586753845
recall class 1 = 0.11926605552434921
AUC of ROC = 0.8251035650072118
AUC of PRC = 0.45035242588487673
min(+P, Se) = 0.45642201834862384
f1_score = 0.2035225109475395


------------ Save best model ------------





Epoch 7 Batch 0: Train Loss = 0.4361
Model Loss = 0.4344, Decov Loss = 0.0000
Epoch 7 Batch 30: Train Loss = 0.3357
Model Loss = 0.3251, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3100
valid_model Loss = 0.3100
valid_decov Loss = 0.0000
confusion matrix:
[[2752   34]
 [ 395   41]]
accuracy = 0.866852879524231
precision class 0 = 0.8744836449623108
precision class 1 = 0.54666668176651
recall class 0 = 0.9877961277961731
recall class 1 = 0.09403669834136963
AUC of ROC = 0.8239765340463787
AUC of PRC = 0.41764399450944756
min(+P, Se) = 0.44469525959367945
f1_score = 0.160469669561909





Epoch 8 Batch 0: Train Loss = 0.2864
Model Loss = 0.2845, Decov Loss = 0.0000
Epoch 8 Batch 30: Train Loss = 0.3303
Model Loss = 0.3283, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3109
valid_model Loss = 0.3109
valid_decov Loss = 0.0000
confusion matrix:
[[2785    1]
 [ 428    8]]
accuracy = 0.866852879524231
precision class 0 = 0.8667911887168884
precision class 1 = 0.8888888955116272
recall class 0 = 0.9996410608291626
recall class 1 = 0.01834862306714058
AUC of ROC = 0.8305724230589381
AUC of PRC = 0.4560914696666905
min(+P, Se) = 0.46788990825688076
f1_score = 0.03595505423308656


------------ Save best model ------------





Epoch 9 Batch 0: Train Loss = 0.3725
Model Loss = 0.3704, Decov Loss = 0.0000
Epoch 9 Batch 30: Train Loss = 0.3254
Model Loss = 0.3233, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3025
valid_model Loss = 0.3025
valid_decov Loss = 0.0000
confusion matrix:
[[2723   63]
 [ 344   92]]
accuracy = 0.8736809492111206
precision class 0 = 0.8878383040428162
precision class 1 = 0.5935483574867249
recall class 0 = 0.9773869514465332
recall class 1 = 0.2110091745853424
AUC of ROC = 0.8361071412106403
AUC of PRC = 0.459163713500029
min(+P, Se) = 0.4793577981651376
f1_score = 0.31133670211979964


------------ Save best model ------------





Epoch 10 Batch 0: Train Loss = 0.2598
Model Loss = 0.2579, Decov Loss = 0.0000
Epoch 10 Batch 30: Train Loss = 0.3166
Model Loss = 0.3143, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3093
valid_model Loss = 0.3093
valid_decov Loss = 0.0000
confusion matrix:
[[2764   22]
 [ 395   41]]
accuracy = 0.8705772757530212
precision class 0 = 0.8749604225158691
precision class 1 = 0.6507936716079712
recall class 0 = 0.9921033978462219
recall class 1 = 0.09403669834136963
AUC of ROC = 0.838248417711098
AUC of PRC = 0.4579020666907008
min(+P, Se) = 0.47477064220183485
f1_score = 0.16432865964805793


------------ Save best model ------------





Epoch 11 Batch 0: Train Loss = 0.3306
Model Loss = 0.3288, Decov Loss = 0.0000
Epoch 11 Batch 30: Train Loss = 0.3122
Model Loss = 0.3101, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3006
valid_model Loss = 0.3006
valid_decov Loss = 0.0000
confusion matrix:
[[2764   22]
 [ 403   33]]
accuracy = 0.8680943250656128
precision class 0 = 0.8727502226829529
precision class 1 = 0.6000000238418579
recall class 0 = 0.9921033978462219
recall class 1 = 0.07568807154893875
AUC of ROC = 0.8414689765999065
AUC of PRC = 0.4541971027753259
min(+P, Se) = 0.48394495412844035
f1_score = 0.1344195511048452


------------ Save best model ------------





Epoch 12 Batch 0: Train Loss = 0.3655
Model Loss = 0.3632, Decov Loss = 0.0000
Epoch 12 Batch 30: Train Loss = 0.3105
Model Loss = 0.3082, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3227
valid_model Loss = 0.3227
valid_decov Loss = 0.0000
confusion matrix:
[[2736   50]
 [ 348   88]]
accuracy = 0.8764742612838745
precision class 0 = 0.887159526348114
precision class 1 = 0.6376811861991882
recall class 0 = 0.9820531010627747
recall class 1 = 0.20183485746383667
AUC of ROC = 0.84243218056205
AUC of PRC = 0.4761557863542007
min(+P, Se) = 0.4782608695652174
f1_score = 0.3066202064758756


------------ Save best model ------------





Epoch 13 Batch 0: Train Loss = 0.4319
Model Loss = 0.4297, Decov Loss = 0.0000
Epoch 13 Batch 30: Train Loss = 0.3247
Model Loss = 0.3227, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2986
valid_model Loss = 0.2985
valid_decov Loss = 0.0000
confusion matrix:
[[2720   66]
 [ 333  103]]
accuracy = 0.876163899898529
precision class 0 = 0.8909269571304321
precision class 1 = 0.6094674468040466
recall class 0 = 0.976310133934021
recall class 1 = 0.2362385392189026
AUC of ROC = 0.8407082924451879
AUC of PRC = 0.47458905149732705
min(+P, Se) = 0.481651376146789
f1_score = 0.3404958737765281





Epoch 14 Batch 0: Train Loss = 0.2870
Model Loss = 0.2850, Decov Loss = 0.0000
Epoch 14 Batch 30: Train Loss = 0.3103
Model Loss = 0.3079, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2969
valid_model Loss = 0.2969
valid_decov Loss = 0.0000
confusion matrix:
[[2750   36]
 [ 382   54]]
accuracy = 0.8702669143676758
precision class 0 = 0.8780332207679749
precision class 1 = 0.6000000238418579
recall class 0 = 0.9870782494544983
recall class 1 = 0.12385321408510208
AUC of ROC = 0.8449472131298694
AUC of PRC = 0.47019391551178036
min(+P, Se) = 0.47139588100686497
f1_score = 0.20532320165248447


------------ Save best model ------------





Epoch 15 Batch 0: Train Loss = 0.2775
Model Loss = 0.2753, Decov Loss = 0.0000
Epoch 15 Batch 30: Train Loss = 0.3093
Model Loss = 0.3072, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2927
valid_model Loss = 0.2927
valid_decov Loss = 0.0000
confusion matrix:
[[2728   58]
 [ 344   92]]
accuracy = 0.8752327561378479
precision class 0 = 0.8880208134651184
precision class 1 = 0.6133333444595337
recall class 0 = 0.9791816473007202
recall class 1 = 0.2110091745853424
AUC of ROC = 0.8465879528705124
AUC of PRC = 0.49092985851655246
min(+P, Se) = 0.49311926605504586
f1_score = 0.31399318717391406


------------ Save best model ------------





Epoch 16 Batch 0: Train Loss = 0.3747
Model Loss = 0.3725, Decov Loss = 0.0000
Epoch 16 Batch 30: Train Loss = 0.3151
Model Loss = 0.3130, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2973
valid_model Loss = 0.2972
valid_decov Loss = 0.0000
confusion matrix:
[[2713   73]
 [ 327  109]]
accuracy = 0.8758534789085388
precision class 0 = 0.8924342393875122
precision class 1 = 0.598901093006134
recall class 0 = 0.9737975597381592
recall class 1 = 0.25
AUC of ROC = 0.8421045265646714
AUC of PRC = 0.47995820445807813
min(+P, Se) = 0.4724770642201835
f1_score = 0.35275080803895636





Epoch 17 Batch 0: Train Loss = 0.2437
Model Loss = 0.2413, Decov Loss = 0.0000
Epoch 17 Batch 30: Train Loss = 0.3032
Model Loss = 0.2970, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2957
valid_model Loss = 0.2957
valid_decov Loss = 0.0000
confusion matrix:
[[2712   74]
 [ 343   93]]
accuracy = 0.8705772757530212
precision class 0 = 0.8877250552177429
precision class 1 = 0.5568862557411194
recall class 0 = 0.9734386205673218
recall class 1 = 0.21330274641513824
AUC of ROC = 0.8445265317412751
AUC of PRC = 0.46143091364364314
min(+P, Se) = 0.481651376146789
f1_score = 0.30845771558943513





Epoch 18 Batch 0: Train Loss = 0.2566
Model Loss = 0.2543, Decov Loss = 0.0000
Epoch 18 Batch 30: Train Loss = 0.2969
Model Loss = 0.2947, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3197
valid_model Loss = 0.3197
valid_decov Loss = 0.0000
confusion matrix:
[[2703   83]
 [ 321  115]]
accuracy = 0.874612033367157
precision class 0 = 0.8938491940498352
precision class 1 = 0.5808081030845642
recall class 0 = 0.9702081680297852
recall class 1 = 0.2637614607810974
AUC of ROC = 0.8407403992439261
AUC of PRC = 0.47970657938353695
min(+P, Se) = 0.4919908466819222
f1_score = 0.3627760228580774





Epoch 19 Batch 0: Train Loss = 0.3574
Model Loss = 0.3553, Decov Loss = 0.0000
Epoch 19 Batch 30: Train Loss = 0.3072
Model Loss = 0.3050, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2943
valid_model Loss = 0.2942
valid_decov Loss = 0.0000
confusion matrix:
[[2706   80]
 [ 327  109]]
accuracy = 0.8736809492111206
precision class 0 = 0.8921859264373779
precision class 1 = 0.5767195820808411
recall class 0 = 0.9712849855422974
recall class 1 = 0.25
AUC of ROC = 0.8462841731593749
AUC of PRC = 0.47446254447550984
min(+P, Se) = 0.48654708520179374
f1_score = 0.34880000098052977





Epoch 20 Batch 0: Train Loss = 0.3490
Model Loss = 0.3469, Decov Loss = 0.0000
Epoch 20 Batch 30: Train Loss = 0.2951
Model Loss = 0.2922, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2890
valid_model Loss = 0.2890
valid_decov Loss = 0.0000
confusion matrix:
[[2712   74]
 [ 322  114]]
accuracy = 0.8770949840545654
precision class 0 = 0.8938694596290588
precision class 1 = 0.6063829660415649
recall class 0 = 0.9734386205673218
recall class 1 = 0.26146790385246277
AUC of ROC = 0.8516558875636375
AUC of PRC = 0.4919539819502474
min(+P, Se) = 0.5080091533180778
f1_score = 0.3653846141501993


------------ Save best model ------------





Epoch 21 Batch 0: Train Loss = 0.3242
Model Loss = 0.3220, Decov Loss = 0.0000
Epoch 21 Batch 30: Train Loss = 0.3071
Model Loss = 0.3049, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3110
valid_model Loss = 0.3109
valid_decov Loss = 0.0000
confusion matrix:
[[2722   64]
 [ 331  105]]
accuracy = 0.8774053454399109
precision class 0 = 0.8915820717811584
precision class 1 = 0.6213017702102661
recall class 0 = 0.9770280122756958
recall class 1 = 0.24082568287849426
AUC of ROC = 0.8541330505739707
AUC of PRC = 0.5006839765218047
min(+P, Se) = 0.5045871559633027
f1_score = 0.3471074438489176


------------ Save best model ------------





Epoch 22 Batch 0: Train Loss = 0.3176
Model Loss = 0.3153, Decov Loss = 0.0000
Epoch 22 Batch 30: Train Loss = 0.3046
Model Loss = 0.2984, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.3344
valid_model Loss = 0.3344
valid_decov Loss = 0.0000
confusion matrix:
[[2739   47]
 [ 359   77]]
accuracy = 0.8739913105964661
precision class 0 = 0.8841187953948975
precision class 1 = 0.6209677457809448
recall class 0 = 0.9831299185752869
recall class 1 = 0.17660550773143768
AUC of ROC = 0.8546039502887965
AUC of PRC = 0.5011049453850902
min(+P, Se) = 0.5125858123569794
f1_score = 0.27499999905119143


------------ Save best model ------------





Epoch 23 Batch 0: Train Loss = 0.3507
Model Loss = 0.3486, Decov Loss = 0.0000
Epoch 23 Batch 30: Train Loss = 0.3139
Model Loss = 0.3119, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2912
valid_model Loss = 0.2912
valid_decov Loss = 0.0000
confusion matrix:
[[2716   70]
 [ 330  106]]
accuracy = 0.8758534789085388
precision class 0 = 0.8916611671447754
precision class 1 = 0.6022727489471436
recall class 0 = 0.9748743772506714
recall class 1 = 0.2431192696094513
AUC of ROC = 0.8533600176505068
AUC of PRC = 0.4984131920024868
min(+P, Se) = 0.5160550458715596
f1_score = 0.3464052481629717





Epoch 24 Batch 0: Train Loss = 0.2913
Model Loss = 0.2891, Decov Loss = 0.0000
Epoch 24 Batch 30: Train Loss = 0.3024
Model Loss = 0.3001, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2876
valid_model Loss = 0.2876
valid_decov Loss = 0.0000
confusion matrix:
[[2693   93]
 [ 301  135]]
accuracy = 0.8777157068252563
precision class 0 = 0.8994656205177307
precision class 1 = 0.5921052694320679
recall class 0 = 0.9666188359260559
recall class 1 = 0.3096330165863037
AUC of ROC = 0.8537510619941121
AUC of PRC = 0.4907950237510284
min(+P, Se) = 0.5068807339449541
f1_score = 0.40662649807277257





Epoch 25 Batch 0: Train Loss = 0.3048
Model Loss = 0.3028, Decov Loss = 0.0000
Epoch 25 Batch 30: Train Loss = 0.3026
Model Loss = 0.3004, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2872
valid_model Loss = 0.2872
valid_decov Loss = 0.0000
confusion matrix:
[[2713   73]
 [ 324  112]]
accuracy = 0.87678462266922
precision class 0 = 0.8933157920837402
precision class 1 = 0.6054053902626038
recall class 0 = 0.9737975597381592
recall class 1 = 0.2568807303905487
AUC of ROC = 0.8550732035011228
AUC of PRC = 0.49640585701128365
min(+P, Se) = 0.5124716553287982
f1_score = 0.3607085408964025


------------ Save best model ------------





Epoch 26 Batch 0: Train Loss = 0.3252
Model Loss = 0.3238, Decov Loss = 0.0000
Epoch 26 Batch 30: Train Loss = 0.2901
Model Loss = 0.2882, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2950
valid_model Loss = 0.2950
valid_decov Loss = 0.0000
confusion matrix:
[[2734   52]
 [ 351   85]]
accuracy = 0.8749223947525024
precision class 0 = 0.8862236738204956
precision class 1 = 0.6204379796981812
recall class 0 = 0.9813352227210999
recall class 1 = 0.19495412707328796
AUC of ROC = 0.8586304721510567
AUC of PRC = 0.5082022773131127
min(+P, Se) = 0.5252293577981652
f1_score = 0.29668410893297076


------------ Save best model ------------





Epoch 27 Batch 0: Train Loss = 0.2168
Model Loss = 0.2149, Decov Loss = 0.0000
Epoch 27 Batch 30: Train Loss = 0.2916
Model Loss = 0.2897, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2835
valid_model Loss = 0.2835
valid_decov Loss = 0.0000
confusion matrix:
[[2703   83]
 [ 300  136]]
accuracy = 0.8811297416687012
precision class 0 = 0.9000998735427856
precision class 1 = 0.621004581451416
recall class 0 = 0.9702081680297852
recall class 1 = 0.31192660331726074
AUC of ROC = 0.8585168634785986
AUC of PRC = 0.5127662965019711
min(+P, Se) = 0.5252293577981652
f1_score = 0.4152671770418476





Epoch 28 Batch 0: Train Loss = 0.2316
Model Loss = 0.2302, Decov Loss = 0.0000
Epoch 28 Batch 30: Train Loss = 0.2933
Model Loss = 0.2918, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2862
valid_model Loss = 0.2861
valid_decov Loss = 0.0000
confusion matrix:
[[2682  104]
 [ 286  150]]
accuracy = 0.8789571523666382
precision class 0 = 0.9036388397216797
precision class 1 = 0.5905511975288391
recall class 0 = 0.9626705050468445
recall class 1 = 0.34403669834136963
AUC of ROC = 0.8576417474001725
AUC of PRC = 0.5055776172323836
min(+P, Se) = 0.5194508009153318
f1_score = 0.4347826140208846





Epoch 29 Batch 0: Train Loss = 0.3453
Model Loss = 0.3429, Decov Loss = 0.0000
Epoch 29 Batch 30: Train Loss = 0.2852
Model Loss = 0.2828, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2891
valid_model Loss = 0.2891
valid_decov Loss = 0.0000
confusion matrix:
[[2685  101]
 [ 290  146]]
accuracy = 0.8786467909812927
precision class 0 = 0.902521014213562
precision class 1 = 0.591093122959137
recall class 0 = 0.9637473225593567
recall class 1 = 0.3348623812198639
AUC of ROC = 0.8560306447045186
AUC of PRC = 0.4982876785443739
min(+P, Se) = 0.5160550458715596
f1_score = 0.42752560660385835





Epoch 30 Batch 0: Train Loss = 0.3085
Model Loss = 0.3066, Decov Loss = 0.0000
Epoch 30 Batch 30: Train Loss = 0.2942
Model Loss = 0.2922, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2803
valid_model Loss = 0.2802
valid_decov Loss = 0.0000
confusion matrix:
[[2684  102]
 [ 279  157]]
accuracy = 0.8817504644393921
precision class 0 = 0.905838668346405
precision class 1 = 0.6061776280403137
recall class 0 = 0.9633883833885193
recall class 1 = 0.3600917458534241
AUC of ROC = 0.8623820280959186
AUC of PRC = 0.5190009787123826
min(+P, Se) = 0.528604118993135
f1_score = 0.4517985693756007


------------ Save best model ------------





Epoch 31 Batch 0: Train Loss = 0.2575
Model Loss = 0.2546, Decov Loss = 0.0000
Epoch 31 Batch 30: Train Loss = 0.2891
Model Loss = 0.2867, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2898
valid_model Loss = 0.2898
valid_decov Loss = 0.0000
confusion matrix:
[[2728   58]
 [ 333  103]]
accuracy = 0.8786467909812927
precision class 0 = 0.8912120461463928
precision class 1 = 0.6397515535354614
recall class 0 = 0.9791816473007202
recall class 1 = 0.2362385392189026
AUC of ROC = 0.8634448454592755
AUC of PRC = 0.5228610373480376
min(+P, Se) = 0.5160550458715596
f1_score = 0.3450586341565403


------------ Save best model ------------





Epoch 32 Batch 0: Train Loss = 0.2851
Model Loss = 0.2808, Decov Loss = 0.0000
Epoch 32 Batch 30: Train Loss = 0.2882
Model Loss = 0.2861, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2808
valid_model Loss = 0.2808
valid_decov Loss = 0.0000
confusion matrix:
[[2675  111]
 [ 262  174]]
accuracy = 0.8842334151268005
precision class 0 = 0.9107933044433594
precision class 1 = 0.6105263233184814
recall class 0 = 0.9601579308509827
recall class 1 = 0.39908257126808167
AUC of ROC = 0.864524127847626
AUC of PRC = 0.5294497115195314
min(+P, Se) = 0.5354691075514875
f1_score = 0.4826629864999398


------------ Save best model ------------





Epoch 33 Batch 0: Train Loss = 0.3127
Model Loss = 0.3107, Decov Loss = 0.0000
Epoch 33 Batch 30: Train Loss = 0.2842
Model Loss = 0.2826, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2847
valid_model Loss = 0.2847
valid_decov Loss = 0.0000
confusion matrix:
[[2662  124]
 [ 245  191]]
accuracy = 0.8854748606681824
precision class 0 = 0.9157207012176514
precision class 1 = 0.6063492298126221
recall class 0 = 0.9554917216300964
recall class 1 = 0.43807339668273926
AUC of ROC = 0.8671066670179206
AUC of PRC = 0.5491014374215155
min(+P, Se) = 0.5321100917431193
f1_score = 0.5086551362283378


------------ Save best model ------------





Epoch 34 Batch 0: Train Loss = 0.2921
Model Loss = 0.2901, Decov Loss = 0.0000
Epoch 34 Batch 30: Train Loss = 0.2935
Model Loss = 0.2918, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2952
valid_model Loss = 0.2952
valid_decov Loss = 0.0000
confusion matrix:
[[2715   71]
 [ 301  135]]
accuracy = 0.884543776512146
precision class 0 = 0.9001989364624023
precision class 1 = 0.655339777469635
recall class 0 = 0.974515438079834
recall class 1 = 0.3096330165863037
AUC of ROC = 0.8631340681125153
AUC of PRC = 0.517725010774221
min(+P, Se) = 0.5206422018348624
f1_score = 0.42056073173635455





Epoch 35 Batch 0: Train Loss = 0.3524
Model Loss = 0.3495, Decov Loss = 0.0000
Epoch 35 Batch 30: Train Loss = 0.2799
Model Loss = 0.2777, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2777
valid_model Loss = 0.2777
valid_decov Loss = 0.0000
confusion matrix:
[[2709   77]
 [ 294  142]]
accuracy = 0.8848541378974915
precision class 0 = 0.9020978808403015
precision class 1 = 0.6484017968177795
recall class 0 = 0.9723618030548096
recall class 1 = 0.32568806409835815
AUC of ROC = 0.8651728498323861
AUC of PRC = 0.5367678361927966
min(+P, Se) = 0.5344036697247706
f1_score = 0.43358777138870236





Epoch 36 Batch 0: Train Loss = 0.3396
Model Loss = 0.3374, Decov Loss = 0.0000
Epoch 36 Batch 30: Train Loss = 0.2817
Model Loss = 0.2801, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2768
valid_model Loss = 0.2768
valid_decov Loss = 0.0000
confusion matrix:
[[2713   73]
 [ 293  143]]
accuracy = 0.8864059448242188
precision class 0 = 0.9025282859802246
precision class 1 = 0.6620370149612427
recall class 0 = 0.9737975597381592
recall class 1 = 0.3279816508293152
AUC of ROC = 0.8665172191231386
AUC of PRC = 0.5248587161693585
min(+P, Se) = 0.5275229357798165
f1_score = 0.43865030141368333





Epoch 37 Batch 0: Train Loss = 0.3141
Model Loss = 0.3125, Decov Loss = 0.0000
Epoch 37 Batch 30: Train Loss = 0.2796
Model Loss = 0.2780, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2889
valid_model Loss = 0.2889
valid_decov Loss = 0.0000
confusion matrix:
[[2694   92]
 [ 278  158]]
accuracy = 0.8851644992828369
precision class 0 = 0.9064602851867676
precision class 1 = 0.6320000290870667
recall class 0 = 0.9669777750968933
recall class 1 = 0.3623853325843811
AUC of ROC = 0.860902645600216
AUC of PRC = 0.5224896510326549
min(+P, Se) = 0.5273972602739726
f1_score = 0.4606414164204633





Epoch 38 Batch 0: Train Loss = 0.2967
Model Loss = 0.2949, Decov Loss = 0.0000
Epoch 38 Batch 30: Train Loss = 0.2974
Model Loss = 0.2955, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2868
valid_model Loss = 0.2868
valid_decov Loss = 0.0000
confusion matrix:
[[2646  140]
 [ 238  198]]
accuracy = 0.8826815485954285
precision class 0 = 0.917475700378418
precision class 1 = 0.5857987999916077
recall class 0 = 0.9497487545013428
recall class 1 = 0.4541284441947937
AUC of ROC = 0.8640491118765519
AUC of PRC = 0.5451494979401328
min(+P, Se) = 0.5321100917431193
f1_score = 0.5116279030837312





Epoch 39 Batch 0: Train Loss = 0.2999
Model Loss = 0.2976, Decov Loss = 0.0000
Epoch 39 Batch 30: Train Loss = 0.2897
Model Loss = 0.2878, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2873
valid_model Loss = 0.2873
valid_decov Loss = 0.0000
confusion matrix:
[[2725   61]
 [ 315  121]]
accuracy = 0.8833022713661194
precision class 0 = 0.8963815569877625
precision class 1 = 0.6648351550102234
recall class 0 = 0.978104829788208
recall class 1 = 0.2775229215621948
AUC of ROC = 0.8565435302330788
AUC of PRC = 0.5145086108023558
min(+P, Se) = 0.5148741418764302
f1_score = 0.391585744660433





Epoch 40 Batch 0: Train Loss = 0.2129
Model Loss = 0.2111, Decov Loss = 0.0000
Epoch 40 Batch 30: Train Loss = 0.2863
Model Loss = 0.2846, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2850
valid_model Loss = 0.2850
valid_decov Loss = 0.0000
confusion matrix:
[[2726   60]
 [ 313  123]]
accuracy = 0.8842334151268005
precision class 0 = 0.8970056176185608
precision class 1 = 0.6721311211585999
recall class 0 = 0.9784637689590454
recall class 1 = 0.2821100950241089
AUC of ROC = 0.8627784235726469
AUC of PRC = 0.5373171731832738
min(+P, Se) = 0.530751708428246
f1_score = 0.3974151844273468





Epoch 41 Batch 0: Train Loss = 0.2554
Model Loss = 0.2539, Decov Loss = 0.0000
Epoch 41 Batch 30: Train Loss = 0.2874
Model Loss = 0.2859, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2903
valid_model Loss = 0.2903
valid_decov Loss = 0.0000
confusion matrix:
[[2748   38]
 [ 335  101]]
accuracy = 0.8842334151268005
precision class 0 = 0.8913396000862122
precision class 1 = 0.7266187071800232
recall class 0 = 0.9863603711128235
recall class 1 = 0.23165138065814972
AUC of ROC = 0.8626841613045568
AUC of PRC = 0.546134597755174
min(+P, Se) = 0.5275229357798165
f1_score = 0.3513043587271794





Epoch 42 Batch 0: Train Loss = 0.2387
Model Loss = 0.2365, Decov Loss = 0.0000
Epoch 42 Batch 30: Train Loss = 0.2828
Model Loss = 0.2812, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2767
valid_model Loss = 0.2767
valid_decov Loss = 0.0000
confusion matrix:
[[2676  110]
 [ 251  185]]
accuracy = 0.8879578113555908
precision class 0 = 0.914246678352356
precision class 1 = 0.6271186470985413
recall class 0 = 0.9605168700218201
recall class 1 = 0.42431193590164185
AUC of ROC = 0.865926947977107
AUC of PRC = 0.5544980518492554
min(+P, Se) = 0.5351473922902494
f1_score = 0.5061559583536562





Epoch 43 Batch 0: Train Loss = 0.2645
Model Loss = 0.2627, Decov Loss = 0.0000
Epoch 43 Batch 30: Train Loss = 0.2804
Model Loss = 0.2787, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2810
valid_model Loss = 0.2810
valid_decov Loss = 0.0000
confusion matrix:
[[2738   48]
 [ 316  120]]
accuracy = 0.8870266675949097
precision class 0 = 0.8965291380882263
precision class 1 = 0.7142857313156128
recall class 0 = 0.9827709794044495
recall class 1 = 0.2752293646335602
AUC of ROC = 0.8647529916950414
AUC of PRC = 0.5428998832012538
min(+P, Se) = 0.5252293577981652
f1_score = 0.39735101510346965





Epoch 44 Batch 0: Train Loss = 0.2341
Model Loss = 0.2327, Decov Loss = 0.0000
Epoch 44 Batch 30: Train Loss = 0.2797
Model Loss = 0.2781, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2753
valid_model Loss = 0.2753
valid_decov Loss = 0.0000
confusion matrix:
[[2718   68]
 [ 296  140]]
accuracy = 0.8870266675949097
precision class 0 = 0.9017916321754456
precision class 1 = 0.6730769276618958
recall class 0 = 0.9755922555923462
recall class 1 = 0.3211009204387665
AUC of ROC = 0.8657244281696818
AUC of PRC = 0.5529917544263591
min(+P, Se) = 0.5423340961098398
f1_score = 0.4347826254427159





Epoch 45 Batch 0: Train Loss = 0.2048
Model Loss = 0.2035, Decov Loss = 0.0000
Epoch 45 Batch 30: Train Loss = 0.2742
Model Loss = 0.2727, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2775
valid_model Loss = 0.2775
valid_decov Loss = 0.0000
confusion matrix:
[[2679  107]
 [ 250  186]]
accuracy = 0.8891992568969727
precision class 0 = 0.9146466255187988
precision class 1 = 0.6348122954368591
recall class 0 = 0.9615936875343323
recall class 1 = 0.4266054928302765
AUC of ROC = 0.8679315647701153
AUC of PRC = 0.5573810693197566
min(+P, Se) = 0.5303370786516854
f1_score = 0.5102880459311049


------------ Save best model ------------





Epoch 46 Batch 0: Train Loss = 0.2528
Model Loss = 0.2520, Decov Loss = 0.0000
Epoch 46 Batch 30: Train Loss = 0.2906
Model Loss = 0.2892, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2824
valid_model Loss = 0.2824
valid_decov Loss = 0.0000
confusion matrix:
[[2730   56]
 [ 309  127]]
accuracy = 0.8867163062095642
precision class 0 = 0.898321807384491
precision class 1 = 0.693989098072052
recall class 0 = 0.979899525642395
recall class 1 = 0.2912844121456146
AUC of ROC = 0.8706400613816131
AUC of PRC = 0.56308239676943
min(+P, Se) = 0.536697247706422
f1_score = 0.41033928241358314


------------ Save best model ------------





Epoch 47 Batch 0: Train Loss = 0.2574
Model Loss = 0.2557, Decov Loss = 0.0000
Epoch 47 Batch 30: Train Loss = 0.2825
Model Loss = 0.2809, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2780
valid_model Loss = 0.2780
valid_decov Loss = 0.0000
confusion matrix:
[[2678  108]
 [ 263  173]]
accuracy = 0.8848541378974915
precision class 0 = 0.9105746150016785
precision class 1 = 0.6156583428382874
recall class 0 = 0.9612347483634949
recall class 1 = 0.39678898453712463
AUC of ROC = 0.8686428538498522
AUC of PRC = 0.5568178046785367
min(+P, Se) = 0.540045766590389
f1_score = 0.4825662232110082





Epoch 48 Batch 0: Train Loss = 0.2907
Model Loss = 0.2894, Decov Loss = 0.0000
Epoch 48 Batch 30: Train Loss = 0.2784
Model Loss = 0.2768, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2750
valid_model Loss = 0.2750
valid_decov Loss = 0.0000
confusion matrix:
[[2720   66]
 [ 286  150]]
accuracy = 0.8907510638237
precision class 0 = 0.9048569798469543
precision class 1 = 0.6944444179534912
recall class 0 = 0.976310133934021
recall class 1 = 0.34403669834136963
AUC of ROC = 0.866302350546968
AUC of PRC = 0.5619400544020202
min(+P, Se) = 0.5354691075514875
f1_score = 0.46012269454975274





Epoch 49 Batch 0: Train Loss = 0.2297
Model Loss = 0.2274, Decov Loss = 0.0000
Epoch 49 Batch 30: Train Loss = 0.2741
Model Loss = 0.2726, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2744
valid_model Loss = 0.2744
valid_decov Loss = 0.0000
confusion matrix:
[[2710   76]
 [ 284  152]]
accuracy = 0.8882681727409363
precision class 0 = 0.9051436185836792
precision class 1 = 0.6666666865348816
recall class 0 = 0.972720742225647
recall class 1 = 0.3486238420009613
AUC of ROC = 0.8677175194451946
AUC of PRC = 0.5691226629097711
min(+P, Se) = 0.5354691075514875
f1_score = 0.45783133375867885





Epoch 50 Batch 0: Train Loss = 0.3205
Model Loss = 0.3192, Decov Loss = 0.0000
Epoch 50 Batch 30: Train Loss = 0.2731
Model Loss = 0.2719, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2844
valid_model Loss = 0.2844
valid_decov Loss = 0.0000
confusion matrix:
[[2748   38]
 [ 322  114]]
accuracy = 0.8882681727409363
precision class 0 = 0.895114004611969
precision class 1 = 0.75
recall class 0 = 0.9863603711128235
recall class 1 = 0.26146790385246277
AUC of ROC = 0.8631056659444009
AUC of PRC = 0.5559904967322553
min(+P, Se) = 0.5377574370709383
f1_score = 0.3877551059493627





Epoch 51 Batch 0: Train Loss = 0.3130
Model Loss = 0.3117, Decov Loss = 0.0000
Epoch 51 Batch 30: Train Loss = 0.2847
Model Loss = 0.2833, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2748
valid_model Loss = 0.2748
valid_decov Loss = 0.0000
confusion matrix:
[[2710   76]
 [ 279  157]]
accuracy = 0.8898199796676636
precision class 0 = 0.9066577553749084
precision class 1 = 0.6738197207450867
recall class 0 = 0.972720742225647
recall class 1 = 0.3600917458534241
AUC of ROC = 0.8668930333186246
AUC of PRC = 0.5677710849907878
min(+P, Se) = 0.54337899543379
f1_score = 0.4693572466737829





Epoch 52 Batch 0: Train Loss = 0.2760
Model Loss = 0.2749, Decov Loss = 0.0000
Epoch 52 Batch 30: Train Loss = 0.2804
Model Loss = 0.2792, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2876
valid_model Loss = 0.2876
valid_decov Loss = 0.0000
confusion matrix:
[[2753   33]
 [ 329  107]]
accuracy = 0.8876474499702454
precision class 0 = 0.8932511210441589
precision class 1 = 0.7642857432365417
recall class 0 = 0.9881550669670105
recall class 1 = 0.24541284143924713
AUC of ROC = 0.8674260885028023
AUC of PRC = 0.563767099089874
min(+P, Se) = 0.5481651376146789
f1_score = 0.37152776177282787





Epoch 53 Batch 0: Train Loss = 0.2608
Model Loss = 0.2599, Decov Loss = 0.0000
Epoch 53 Batch 30: Train Loss = 0.2764
Model Loss = 0.2751, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2843
valid_model Loss = 0.2843
valid_decov Loss = 0.0000
confusion matrix:
[[2726   60]
 [ 300  136]]
accuracy = 0.8882681727409363
precision class 0 = 0.9008592367172241
precision class 1 = 0.6938775777816772
recall class 0 = 0.9784637689590454
recall class 1 = 0.31192660331726074
AUC of ROC = 0.8673264751015892
AUC of PRC = 0.556316584647635
min(+P, Se) = 0.536697247706422
f1_score = 0.430379749901152





Epoch 54 Batch 0: Train Loss = 0.2704
Model Loss = 0.2694, Decov Loss = 0.0000
Epoch 54 Batch 30: Train Loss = 0.2710
Model Loss = 0.2698, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2717
valid_model Loss = 0.2717
valid_decov Loss = 0.0000
confusion matrix:
[[2724   62]
 [ 292  144]]
accuracy = 0.890130341053009
precision class 0 = 0.9031830430030823
precision class 1 = 0.6990291476249695
recall class 0 = 0.9777458906173706
recall class 1 = 0.3302752375602722
AUC of ROC = 0.8709059715352647
AUC of PRC = 0.5600323254747174
min(+P, Se) = 0.5377574370709383
f1_score = 0.4485981428163096


------------ Save best model ------------





Epoch 55 Batch 0: Train Loss = 0.2766
Model Loss = 0.2759, Decov Loss = 0.0000
Epoch 55 Batch 30: Train Loss = 0.2784
Model Loss = 0.2775, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2790
valid_model Loss = 0.2790
valid_decov Loss = 0.0000
confusion matrix:
[[2737   49]
 [ 302  134]]
accuracy = 0.8910614252090454
precision class 0 = 0.9006252288818359
precision class 1 = 0.7322404384613037
recall class 0 = 0.9824120402336121
recall class 1 = 0.30733945965766907
AUC of ROC = 0.8696974387007119
AUC of PRC = 0.5675578258278792
min(+P, Se) = 0.5377574370709383
f1_score = 0.43295640393767965





Epoch 56 Batch 0: Train Loss = 0.3025
Model Loss = 0.3016, Decov Loss = 0.0000
Epoch 56 Batch 30: Train Loss = 0.2714
Model Loss = 0.2705, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2744
valid_model Loss = 0.2744
valid_decov Loss = 0.0000
confusion matrix:
[[2713   73]
 [ 283  153]]
accuracy = 0.8895096182823181
precision class 0 = 0.9055407047271729
precision class 1 = 0.6769911646842957
recall class 0 = 0.9737975597381592
recall class 1 = 0.35091742873191833
AUC of ROC = 0.8668341708542713
AUC of PRC = 0.5609776131435793
min(+P, Se) = 0.5344036697247706
f1_score = 0.4622356641334029





Epoch 57 Batch 0: Train Loss = 0.2312
Model Loss = 0.2301, Decov Loss = 0.0000
Epoch 57 Batch 30: Train Loss = 0.2726
Model Loss = 0.2716, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2783
valid_model Loss = 0.2783
valid_decov Loss = 0.0000
confusion matrix:
[[2750   36]
 [ 332  104]]
accuracy = 0.8857852220535278
precision class 0 = 0.892277717590332
precision class 1 = 0.7428571581840515
recall class 0 = 0.9870782494544983
recall class 1 = 0.23853211104869843
AUC of ROC = 0.8709882966602343
AUC of PRC = 0.568956764583817
min(+P, Se) = 0.5435779816513762
f1_score = 0.36111110853560185


------------ Save best model ------------





Epoch 58 Batch 0: Train Loss = 0.1771
Model Loss = 0.1761, Decov Loss = 0.0000
Epoch 58 Batch 30: Train Loss = 0.2789
Model Loss = 0.2778, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2861
valid_model Loss = 0.2861
valid_decov Loss = 0.0000
confusion matrix:
[[2727   59]
 [ 301  135]]
accuracy = 0.8882681727409363
precision class 0 = 0.9005944728851318
precision class 1 = 0.6958763003349304
recall class 0 = 0.9788227081298828
recall class 1 = 0.3096330165863037
AUC of ROC = 0.8705017551716644
AUC of PRC = 0.5694823098471811
min(+P, Se) = 0.5412844036697247
f1_score = 0.42857139490449236





Epoch 59 Batch 0: Train Loss = 0.2290
Model Loss = 0.2279, Decov Loss = 0.0000
Epoch 59 Batch 30: Train Loss = 0.2696
Model Loss = 0.2687, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2960
valid_model Loss = 0.2960
valid_decov Loss = 0.0000
confusion matrix:
[[2760   26]
 [ 343   93]]
accuracy = 0.8854748606681824
precision class 0 = 0.8894618153572083
precision class 1 = 0.7815126180648804
recall class 0 = 0.9906676411628723
recall class 1 = 0.21330274641513824
AUC of ROC = 0.8684987848811554
AUC of PRC = 0.5671718399220259
min(+P, Se) = 0.5389908256880734
f1_score = 0.33513513409678247





Epoch 60 Batch 0: Train Loss = 0.2187
Model Loss = 0.2174, Decov Loss = 0.0000
Epoch 60 Batch 30: Train Loss = 0.2738
Model Loss = 0.2728, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2722
valid_model Loss = 0.2722
valid_decov Loss = 0.0000
confusion matrix:
[[2703   83]
 [ 285  151]]
accuracy = 0.8857852220535278
precision class 0 = 0.904618501663208
precision class 1 = 0.6452991366386414
recall class 0 = 0.9702081680297852
recall class 1 = 0.34633028507232666
AUC of ROC = 0.8683925854699447
AUC of PRC = 0.5646270116516828
min(+P, Se) = 0.5423340961098398
f1_score = 0.45074627488034935





Epoch 61 Batch 0: Train Loss = 0.2295
Model Loss = 0.2285, Decov Loss = 0.0000
Epoch 61 Batch 30: Train Loss = 0.2753
Model Loss = 0.2742, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2857
valid_model Loss = 0.2857
valid_decov Loss = 0.0000
confusion matrix:
[[2737   49]
 [ 309  127]]
accuracy = 0.8888888955116272
precision class 0 = 0.8985554575920105
precision class 1 = 0.7215909361839294
recall class 0 = 0.9824120402336121
recall class 1 = 0.2912844121456146
AUC of ROC = 0.8692108972121421
AUC of PRC = 0.5702676906669268
min(+P, Se) = 0.5431818181818182
f1_score = 0.41503270503533485





Epoch 62 Batch 0: Train Loss = 0.2206
Model Loss = 0.2193, Decov Loss = 0.0000
Epoch 62 Batch 30: Train Loss = 0.2664
Model Loss = 0.2653, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2725
valid_model Loss = 0.2725
valid_decov Loss = 0.0000
confusion matrix:
[[2685  101]
 [ 264  172]]
accuracy = 0.8867163062095642
precision class 0 = 0.9104781150817871
precision class 1 = 0.6300366520881653
recall class 0 = 0.9637473225593567
recall class 1 = 0.39449542760849
AUC of ROC = 0.8708137673952989
AUC of PRC = 0.5674967908578493
min(+P, Se) = 0.551487414187643
f1_score = 0.48519042673243024





Epoch 63 Batch 0: Train Loss = 0.3268
Model Loss = 0.3258, Decov Loss = 0.0000
Epoch 63 Batch 30: Train Loss = 0.2739
Model Loss = 0.2729, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2716
valid_model Loss = 0.2716
valid_decov Loss = 0.0000
confusion matrix:
[[2703   83]
 [ 285  151]]
accuracy = 0.8857852220535278
precision class 0 = 0.904618501663208
precision class 1 = 0.6452991366386414
recall class 0 = 0.9702081680297852
recall class 1 = 0.34633028507232666
AUC of ROC = 0.8700917760493161
AUC of PRC = 0.5709251279828101
min(+P, Se) = 0.5389908256880734
f1_score = 0.45074627488034935





Epoch 64 Batch 0: Train Loss = 0.2567
Model Loss = 0.2556, Decov Loss = 0.0000
Epoch 64 Batch 30: Train Loss = 0.2715
Model Loss = 0.2693, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2763
valid_model Loss = 0.2762
valid_decov Loss = 0.0000
confusion matrix:
[[2695   91]
 [ 266  170]]
accuracy = 0.8891992568969727
precision class 0 = 0.9101654887199402
precision class 1 = 0.6513410210609436
recall class 0 = 0.9673366546630859
recall class 1 = 0.3899082541465759
AUC of ROC = 0.8678986347201276
AUC of PRC = 0.5697336043343034
min(+P, Se) = 0.5298165137614679
f1_score = 0.4878048828899423





Epoch 65 Batch 0: Train Loss = 0.2776
Model Loss = 0.2765, Decov Loss = 0.0000
Epoch 65 Batch 30: Train Loss = 0.2653
Model Loss = 0.2641, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2745
valid_model Loss = 0.2745
valid_decov Loss = 0.0000
confusion matrix:
[[2716   70]
 [ 295  141]]
accuracy = 0.8867163062095642
precision class 0 = 0.9020258784294128
precision class 1 = 0.6682464480400085
recall class 0 = 0.9748743772506714
recall class 1 = 0.3233945071697235
AUC of ROC = 0.8697534197856913
AUC of PRC = 0.567507618677515
min(+P, Se) = 0.5321100917431193
f1_score = 0.4358578295727972





Epoch 66 Batch 0: Train Loss = 0.2760
Model Loss = 0.2752, Decov Loss = 0.0000
Epoch 66 Batch 30: Train Loss = 0.2710
Model Loss = 0.2703, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2708
valid_model Loss = 0.2708
valid_decov Loss = 0.0000
confusion matrix:
[[2702   84]
 [ 281  155]]
accuracy = 0.8867163062095642
precision class 0 = 0.9057995080947876
precision class 1 = 0.6485355496406555
recall class 0 = 0.9698492288589478
recall class 1 = 0.35550457239151
AUC of ROC = 0.8711743514426655
AUC of PRC = 0.5722998119268496
min(+P, Se) = 0.5504587155963303
f1_score = 0.4592592431247643


------------ Save best model ------------





Epoch 67 Batch 0: Train Loss = 0.2395
Model Loss = 0.2384, Decov Loss = 0.0000
Epoch 67 Batch 30: Train Loss = 0.2662
Model Loss = 0.2653, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2732
valid_model Loss = 0.2732
valid_decov Loss = 0.0000
confusion matrix:
[[2707   79]
 [ 284  152]]
accuracy = 0.8873370289802551
precision class 0 = 0.9050484895706177
precision class 1 = 0.6580086350440979
recall class 0 = 0.9716439247131348
recall class 1 = 0.3486238420009613
AUC of ROC = 0.8711998722314062
AUC of PRC = 0.5755274127385928
min(+P, Se) = 0.5504587155963303
f1_score = 0.4557721123479054


------------ Save best model ------------





Epoch 68 Batch 0: Train Loss = 0.2085
Model Loss = 0.2077, Decov Loss = 0.0000
Epoch 68 Batch 30: Train Loss = 0.2664
Model Loss = 0.2656, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2708
valid_model Loss = 0.2708
valid_decov Loss = 0.0000
confusion matrix:
[[2701   85]
 [ 270  166]]
accuracy = 0.8898199796676636
precision class 0 = 0.9091215133666992
precision class 1 = 0.6613546013832092
recall class 0 = 0.9694902896881104
recall class 1 = 0.3807339370250702
AUC of ROC = 0.8722165875247798
AUC of PRC = 0.5782008893923805
min(+P, Se) = 0.5469107551487414
f1_score = 0.483260565824909


------------ Save best model ------------





Epoch 69 Batch 0: Train Loss = 0.2925
Model Loss = 0.2919, Decov Loss = 0.0000
Epoch 69 Batch 30: Train Loss = 0.2671
Model Loss = 0.2664, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2741
valid_model Loss = 0.2741
valid_decov Loss = 0.0000
confusion matrix:
[[2715   71]
 [ 295  141]]
accuracy = 0.8864059448242188
precision class 0 = 0.9019933342933655
precision class 1 = 0.6650943160057068
recall class 0 = 0.974515438079834
recall class 1 = 0.3233945071697235
AUC of ROC = 0.8723491309759808
AUC of PRC = 0.5696607469756911
min(+P, Se) = 0.5389908256880734
f1_score = 0.4351852038951025


------------ Save best model ------------





Epoch 70 Batch 0: Train Loss = 0.2603
Model Loss = 0.2598, Decov Loss = 0.0000
Epoch 70 Batch 30: Train Loss = 0.2656
Model Loss = 0.2648, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2795
valid_model Loss = 0.2795
valid_decov Loss = 0.0000
confusion matrix:
[[2738   48]
 [ 321  115]]
accuracy = 0.8854748606681824
precision class 0 = 0.8950637578964233
precision class 1 = 0.7055214643478394
recall class 0 = 0.9827709794044495
recall class 1 = 0.2637614607810974
AUC of ROC = 0.8740063357416177
AUC of PRC = 0.5788736695656503
min(+P, Se) = 0.5423340961098398
f1_score = 0.3839732800906421


------------ Save best model ------------





Epoch 71 Batch 0: Train Loss = 0.3247
Model Loss = 0.3238, Decov Loss = 0.0000
Epoch 71 Batch 30: Train Loss = 0.2629
Model Loss = 0.2621, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2718
valid_model Loss = 0.2718
valid_decov Loss = 0.0000
confusion matrix:
[[2698   88]
 [ 266  170]]
accuracy = 0.890130341053009
precision class 0 = 0.9102563858032227
precision class 1 = 0.6589147448539734
recall class 0 = 0.9684134721755981
recall class 1 = 0.3899082541465759
AUC of ROC = 0.8717119345087166
AUC of PRC = 0.5765052007017571
min(+P, Se) = 0.5358744394618834
f1_score = 0.4899135469803227





Epoch 72 Batch 0: Train Loss = 0.2789
Model Loss = 0.2779, Decov Loss = 0.0000
Epoch 72 Batch 30: Train Loss = 0.2690
Model Loss = 0.2680, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2852
valid_model Loss = 0.2852
valid_decov Loss = 0.0000
confusion matrix:
[[2749   37]
 [ 326  110]]
accuracy = 0.8873370289802551
precision class 0 = 0.8939837217330933
precision class 1 = 0.7482993006706238
recall class 0 = 0.9867193102836609
recall class 1 = 0.25229358673095703
AUC of ROC = 0.8701247060993038
AUC of PRC = 0.5720156062649483
min(+P, Se) = 0.5298165137614679
f1_score = 0.37735847545062473





Epoch 73 Batch 0: Train Loss = 0.3004
Model Loss = 0.2995, Decov Loss = 0.0000
Epoch 73 Batch 30: Train Loss = 0.2698
Model Loss = 0.2681, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2724
valid_model Loss = 0.2724
valid_decov Loss = 0.0000
confusion matrix:
[[2735   51]
 [ 299  137]]
accuracy = 0.8913718461990356
precision class 0 = 0.9014502167701721
precision class 1 = 0.728723406791687
recall class 0 = 0.9816941618919373
recall class 1 = 0.3142201900482178
AUC of ROC = 0.8709454875952501
AUC of PRC = 0.5847361246931813
min(+P, Se) = 0.5412844036697247
f1_score = 0.43910257097022615





Epoch 74 Batch 0: Train Loss = 0.2340
Model Loss = 0.2332, Decov Loss = 0.0000
Epoch 74 Batch 30: Train Loss = 0.2644
Model Loss = 0.2634, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2778
valid_model Loss = 0.2778
valid_decov Loss = 0.0000
confusion matrix:
[[2737   49]
 [ 314  122]]
accuracy = 0.8873370289802551
precision class 0 = 0.8970829248428345
precision class 1 = 0.7134503126144409
recall class 0 = 0.9824120402336121
recall class 1 = 0.27981650829315186
AUC of ROC = 0.869099758293433
AUC of PRC = 0.5703810934345213
min(+P, Se) = 0.5342465753424658
f1_score = 0.4019769333158935





Epoch 75 Batch 0: Train Loss = 0.2331
Model Loss = 0.2324, Decov Loss = 0.0000
Epoch 75 Batch 30: Train Loss = 0.2698
Model Loss = 0.2691, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2806
valid_model Loss = 0.2806
valid_decov Loss = 0.0000
confusion matrix:
[[2751   35]
 [ 320  116]]
accuracy = 0.8898199796676636
precision class 0 = 0.8957993984222412
precision class 1 = 0.7682119011878967
recall class 0 = 0.9874371886253357
recall class 1 = 0.26605504751205444
AUC of ROC = 0.8727368823145875
AUC of PRC = 0.5854820261444942
min(+P, Se) = 0.5560640732265446
f1_score = 0.39522998221451566





Epoch 76 Batch 0: Train Loss = 0.3647
Model Loss = 0.3641, Decov Loss = 0.0000
Epoch 76 Batch 30: Train Loss = 0.2678
Model Loss = 0.2663, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2737
valid_model Loss = 0.2737
valid_decov Loss = 0.0000
confusion matrix:
[[2748   38]
 [ 322  114]]
accuracy = 0.8882681727409363
precision class 0 = 0.895114004611969
precision class 1 = 0.75
recall class 0 = 0.9863603711128235
recall class 1 = 0.26146790385246277
AUC of ROC = 0.8692240692321372
AUC of PRC = 0.5795074782818777
min(+P, Se) = 0.5570776255707762
f1_score = 0.3877551059493627





Epoch 77 Batch 0: Train Loss = 0.2476
Model Loss = 0.2464, Decov Loss = 0.0000
Epoch 77 Batch 30: Train Loss = 0.2645
Model Loss = 0.2636, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2733
valid_model Loss = 0.2733
valid_decov Loss = 0.0000
confusion matrix:
[[2734   52]
 [ 295  141]]
accuracy = 0.892302930355072
precision class 0 = 0.9026080965995789
precision class 1 = 0.7305699586868286
recall class 0 = 0.9813352227210999
recall class 1 = 0.3233945071697235
AUC of ROC = 0.8715316424850332
AUC of PRC = 0.5881243627120508
min(+P, Se) = 0.5444191343963554
f1_score = 0.448330684222594





Epoch 78 Batch 0: Train Loss = 0.2468
Model Loss = 0.2460, Decov Loss = 0.0000
Epoch 78 Batch 30: Train Loss = 0.2593
Model Loss = 0.2585, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2717
valid_model Loss = 0.2717
valid_decov Loss = 0.0000
confusion matrix:
[[2673  113]
 [ 236  200]]
accuracy = 0.8916822075843811
precision class 0 = 0.9188724756240845
precision class 1 = 0.6389776468276978
recall class 0 = 0.9594400525093079
recall class 1 = 0.4587155878543854
AUC of ROC = 0.8726273898983778
AUC of PRC = 0.5887653166638621
min(+P, Se) = 0.5510204081632653
f1_score = 0.534045406471266





Epoch 79 Batch 0: Train Loss = 0.2493
Model Loss = 0.2487, Decov Loss = 0.0000
Epoch 79 Batch 30: Train Loss = 0.2670
Model Loss = 0.2663, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2767
valid_model Loss = 0.2767
valid_decov Loss = 0.0000
confusion matrix:
[[2746   40]
 [ 327  109]]
accuracy = 0.8860955834388733
precision class 0 = 0.8935893177986145
precision class 1 = 0.7315436005592346
recall class 0 = 0.9856424927711487
recall class 1 = 0.25
AUC of ROC = 0.871256676567635
AUC of PRC = 0.5821260947357688
min(+P, Se) = 0.5527522935779816
f1_score = 0.37264956958735074





Epoch 80 Batch 0: Train Loss = 0.2785
Model Loss = 0.2779, Decov Loss = 0.0000
Epoch 80 Batch 30: Train Loss = 0.2664
Model Loss = 0.2656, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2717
valid_model Loss = 0.2717
valid_decov Loss = 0.0000
confusion matrix:
[[2681  105]
 [ 254  182]]
accuracy = 0.8885785341262817
precision class 0 = 0.9134582877159119
precision class 1 = 0.6341463327407837
recall class 0 = 0.9623115658760071
recall class 1 = 0.41743120551109314
AUC of ROC = 0.8718271896836739
AUC of PRC = 0.5791413253525117
min(+P, Se) = 0.536697247706422
f1_score = 0.5034578069903839





Epoch 81 Batch 0: Train Loss = 0.2706
Model Loss = 0.2700, Decov Loss = 0.0000
Epoch 81 Batch 30: Train Loss = 0.2623
Model Loss = 0.2616, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2724
valid_model Loss = 0.2724
valid_decov Loss = 0.0000
confusion matrix:
[[2693   93]
 [ 258  178]]
accuracy = 0.8910614252090454
precision class 0 = 0.9125720262527466
precision class 1 = 0.6568265557289124
recall class 0 = 0.9666188359260559
recall class 1 = 0.4082568883895874
AUC of ROC = 0.8707174469990845
AUC of PRC = 0.5788240265516924
min(+P, Se) = 0.5560640732265446
f1_score = 0.5035360418524312





Epoch 82 Batch 0: Train Loss = 0.2514
Model Loss = 0.2503, Decov Loss = 0.0000
Epoch 82 Batch 30: Train Loss = 0.2731
Model Loss = 0.2723, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2739
valid_model Loss = 0.2738
valid_decov Loss = 0.0000
confusion matrix:
[[2718   68]
 [ 290  146]]
accuracy = 0.8888888955116272
precision class 0 = 0.9035904407501221
precision class 1 = 0.6822429895401001
recall class 0 = 0.9755922555923462
recall class 1 = 0.3348623812198639
AUC of ROC = 0.8683234323649704
AUC of PRC = 0.5701168077920388
min(+P, Se) = 0.528604118993135
f1_score = 0.44923077846165976





Epoch 83 Batch 0: Train Loss = 0.2692
Model Loss = 0.2686, Decov Loss = 0.0000
Epoch 83 Batch 30: Train Loss = 0.2661
Model Loss = 0.2651, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2856
valid_model Loss = 0.2856
valid_decov Loss = 0.0000
confusion matrix:
[[2766   20]
 [ 360   76]]
accuracy = 0.8820608258247375
precision class 0 = 0.8848368525505066
precision class 1 = 0.7916666865348816
recall class 0 = 0.9928212761878967
recall class 1 = 0.17431192100048065
AUC of ROC = 0.8662669507432312
AUC of PRC = 0.565788692112613
min(+P, Se) = 0.5252293577981652
f1_score = 0.28571427507144165





Epoch 84 Batch 0: Train Loss = 0.3637
Model Loss = 0.3629, Decov Loss = 0.0000
Epoch 84 Batch 30: Train Loss = 0.2657
Model Loss = 0.2641, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2748
valid_model Loss = 0.2748
valid_decov Loss = 0.0000
confusion matrix:
[[2692   94]
 [ 267  169]]
accuracy = 0.8879578113555908
precision class 0 = 0.9097667932510376
precision class 1 = 0.6425855755805969
recall class 0 = 0.9662598967552185
recall class 1 = 0.3876146674156189
AUC of ROC = 0.8669370772604834
AUC of PRC = 0.5685055823502823
min(+P, Se) = 0.5435779816513762
f1_score = 0.4835479235383213





Epoch 85 Batch 0: Train Loss = 0.2587
Model Loss = 0.2580, Decov Loss = 0.0000
Epoch 85 Batch 30: Train Loss = 0.2589
Model Loss = 0.2581, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2843
valid_model Loss = 0.2843
valid_decov Loss = 0.0000
confusion matrix:
[[2663  123]
 [ 247  189]]
accuracy = 0.8851644992828369
precision class 0 = 0.9151203036308289
precision class 1 = 0.6057692170143127
recall class 0 = 0.9558506608009338
recall class 1 = 0.4334862530231476
AUC of ROC = 0.8692397110058814
AUC of PRC = 0.5638352838794094
min(+P, Se) = 0.5527522935779816
f1_score = 0.5053475841518916





Epoch 86 Batch 0: Train Loss = 0.3108
Model Loss = 0.3102, Decov Loss = 0.0000
Epoch 86 Batch 30: Train Loss = 0.2774
Model Loss = 0.2765, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2813
valid_model Loss = 0.2813
valid_decov Loss = 0.0000
confusion matrix:
[[2744   42]
 [ 326  110]]
accuracy = 0.8857852220535278
precision class 0 = 0.893811047077179
precision class 1 = 0.7236841917037964
recall class 0 = 0.9849246144294739
recall class 1 = 0.25229358673095703
AUC of ROC = 0.8679241555088679
AUC of PRC = 0.5720922002269385
min(+P, Se) = 0.5435779816513762
f1_score = 0.3741496669693905





Epoch 87 Batch 0: Train Loss = 0.2466
Model Loss = 0.2459, Decov Loss = 0.0000
Epoch 87 Batch 30: Train Loss = 0.2682
Model Loss = 0.2675, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2756
valid_model Loss = 0.2756
valid_decov Loss = 0.0000
confusion matrix:
[[2674  112]
 [ 242  194]]
accuracy = 0.890130341053009
precision class 0 = 0.9170095920562744
precision class 1 = 0.6339869499206543
recall class 0 = 0.9597989916801453
recall class 1 = 0.44495412707328796
AUC of ROC = 0.8683382508874647
AUC of PRC = 0.5714530547224652
min(+P, Se) = 0.5458715596330275
f1_score = 0.5229110432457901





Epoch 88 Batch 0: Train Loss = 0.3091
Model Loss = 0.3085, Decov Loss = 0.0000
Epoch 88 Batch 30: Train Loss = 0.2662
Model Loss = 0.2654, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2747
valid_model Loss = 0.2746
valid_decov Loss = 0.0000
confusion matrix:
[[2721   65]
 [ 281  155]]
accuracy = 0.8926132917404175
precision class 0 = 0.9063957333564758
precision class 1 = 0.7045454382896423
recall class 0 = 0.9766690731048584
recall class 1 = 0.35550457239151
AUC of ROC = 0.869564895249511
AUC of PRC = 0.5781623345780929
min(+P, Se) = 0.5504587155963303
f1_score = 0.472560958909109





Epoch 89 Batch 0: Train Loss = 0.2916
Model Loss = 0.2905, Decov Loss = 0.0000
Epoch 89 Batch 30: Train Loss = 0.2575
Model Loss = 0.2566, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2819
valid_model Loss = 0.2819
valid_decov Loss = 0.0000
confusion matrix:
[[2730   56]
 [ 301  135]]
accuracy = 0.8891992568969727
precision class 0 = 0.9006928205490112
precision class 1 = 0.7068063020706177
recall class 0 = 0.979899525642395
recall class 1 = 0.3096330165863037
AUC of ROC = 0.8677282217114406
AUC of PRC = 0.5729928275835335
min(+P, Se) = 0.5481651376146789
f1_score = 0.430622002583519





Epoch 90 Batch 0: Train Loss = 0.2620
Model Loss = 0.2613, Decov Loss = 0.0000
Epoch 90 Batch 30: Train Loss = 0.2754
Model Loss = 0.2748, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2837
valid_model Loss = 0.2837
valid_decov Loss = 0.0000
confusion matrix:
[[2734   52]
 [ 305  131]]
accuracy = 0.8891992568969727
precision class 0 = 0.8996380567550659
precision class 1 = 0.7158470153808594
recall class 0 = 0.9813352227210999
recall class 1 = 0.30045872926712036
AUC of ROC = 0.86643242424442
AUC of PRC = 0.5667000320729504
min(+P, Se) = 0.5435779816513762
f1_score = 0.4232633699807017





Epoch 91 Batch 0: Train Loss = 0.2826
Model Loss = 0.2819, Decov Loss = 0.0000
Epoch 91 Batch 30: Train Loss = 0.2607
Model Loss = 0.2600, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2827
valid_model Loss = 0.2827
valid_decov Loss = 0.0000
confusion matrix:
[[2724   62]
 [ 295  141]]
accuracy = 0.8891992568969727
precision class 0 = 0.9022855162620544
precision class 1 = 0.6945812702178955
recall class 0 = 0.9777458906173706
recall class 1 = 0.3233945071697235
AUC of ROC = 0.8674894788490288
AUC of PRC = 0.5747025739164172
min(+P, Se) = 0.5504587155963303
f1_score = 0.4413145498840424





Epoch 92 Batch 0: Train Loss = 0.1917
Model Loss = 0.1904, Decov Loss = 0.0000
Epoch 92 Batch 30: Train Loss = 0.2555
Model Loss = 0.2547, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2723
valid_model Loss = 0.2723
valid_decov Loss = 0.0000
confusion matrix:
[[2738   48]
 [ 323  113]]
accuracy = 0.8848541378974915
precision class 0 = 0.8944789171218872
precision class 1 = 0.7018633484840393
recall class 0 = 0.9827709794044495
recall class 1 = 0.25917431712150574
AUC of ROC = 0.8705355084729018
AUC of PRC = 0.5732813545855722
min(+P, Se) = 0.536697247706422
f1_score = 0.37855948045974397





Epoch 93 Batch 0: Train Loss = 0.2639
Model Loss = 0.2631, Decov Loss = 0.0000
Epoch 93 Batch 30: Train Loss = 0.2622
Model Loss = 0.2614, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2846
valid_model Loss = 0.2846
valid_decov Loss = 0.0000
confusion matrix:
[[2753   33]
 [ 342   94]]
accuracy = 0.8836126923561096
precision class 0 = 0.8894991874694824
precision class 1 = 0.7401574850082397
recall class 0 = 0.9881550669670105
recall class 1 = 0.21559633314609528
AUC of ROC = 0.87232443343849
AUC of PRC = 0.5821896260378459
min(+P, Se) = 0.5527522935779816
f1_score = 0.33392540877212057





Epoch 94 Batch 0: Train Loss = 0.2822
Model Loss = 0.2815, Decov Loss = 0.0000
Epoch 94 Batch 30: Train Loss = 0.2527
Model Loss = 0.2520, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2791
valid_model Loss = 0.2791
valid_decov Loss = 0.0000
confusion matrix:
[[2707   79]
 [ 273  163]]
accuracy = 0.8907510638237
precision class 0 = 0.9083892703056335
precision class 1 = 0.6735537052154541
recall class 0 = 0.9716439247131348
recall class 1 = 0.3738532066345215
AUC of ROC = 0.8623293400159382
AUC of PRC = 0.5689260918408467
min(+P, Se) = 0.5435779816513762
f1_score = 0.4808259515694862





Epoch 95 Batch 0: Train Loss = 0.2011
Model Loss = 0.2004, Decov Loss = 0.0000
Epoch 95 Batch 30: Train Loss = 0.2614
Model Loss = 0.2607, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2822
valid_model Loss = 0.2822
valid_decov Loss = 0.0000
confusion matrix:
[[2743   43]
 [ 324  112]]
accuracy = 0.8860955834388733
precision class 0 = 0.894359290599823
precision class 1 = 0.7225806713104248
recall class 0 = 0.9845656752586365
recall class 1 = 0.2568807303905487
AUC of ROC = 0.8699411210706218
AUC of PRC = 0.5735698228132337
min(+P, Se) = 0.5481651376146789
f1_score = 0.3790186007169805





Epoch 96 Batch 0: Train Loss = 0.2291
Model Loss = 0.2286, Decov Loss = 0.0000
Epoch 96 Batch 30: Train Loss = 0.2645
Model Loss = 0.2639, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2716
valid_model Loss = 0.2716
valid_decov Loss = 0.0000
confusion matrix:
[[2692   94]
 [ 272  164]]
accuracy = 0.8864059448242188
precision class 0 = 0.9082320928573608
precision class 1 = 0.6356589198112488
recall class 0 = 0.9662598967552185
recall class 1 = 0.3761467933654785
AUC of ROC = 0.8712879601151234
AUC of PRC = 0.5764946601991509
min(+P, Se) = 0.555045871559633
f1_score = 0.4726224554024744





Epoch 97 Batch 0: Train Loss = 0.2846
Model Loss = 0.2839, Decov Loss = 0.0000
Epoch 97 Batch 30: Train Loss = 0.2587
Model Loss = 0.2580, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2738
valid_model Loss = 0.2738
valid_decov Loss = 0.0000
confusion matrix:
[[2725   61]
 [ 295  141]]
accuracy = 0.8895096182823181
precision class 0 = 0.9023178815841675
precision class 1 = 0.698019802570343
recall class 0 = 0.978104829788208
recall class 1 = 0.3233945071697235
AUC of ROC = 0.8721688389522976
AUC of PRC = 0.5772050252345092
min(+P, Se) = 0.5527522935779816
f1_score = 0.4420062935887337





Epoch 98 Batch 0: Train Loss = 0.2062
Model Loss = 0.2057, Decov Loss = 0.0000
Epoch 98 Batch 30: Train Loss = 0.2518
Model Loss = 0.2511, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2773
valid_model Loss = 0.2773
valid_decov Loss = 0.0000
confusion matrix:
[[2664  122]
 [ 239  197]]
accuracy = 0.8879578113555908
precision class 0 = 0.9176713824272156
precision class 1 = 0.6175548434257507
recall class 0 = 0.9562095999717712
recall class 1 = 0.45183485746383667
AUC of ROC = 0.8735568405592841
AUC of PRC = 0.58053449343027
min(+P, Se) = 0.551487414187643
f1_score = 0.521854295816125





Epoch 99 Batch 0: Train Loss = 0.2797
Model Loss = 0.2792, Decov Loss = 0.0000
Epoch 99 Batch 30: Train Loss = 0.2595
Model Loss = 0.2587, Decov Loss = 0.0000

==>Predicting on validation
Valid Loss = 0.2928
valid_model Loss = 0.2928
valid_decov Loss = 0.0000
confusion matrix:
[[2755   31]
 [ 346   90]]
accuracy = 0.8829919099807739
precision class 0 = 0.8884230852127075
precision class 1 = 0.7438016533851624
recall class 0 = 0.9888729453086853
recall class 1 = 0.20642201602458954
AUC of ROC = 0.8709323155752551
AUC of PRC = 0.5720834642725382
min(+P, Se) = 0.5568181818181818
f1_score = 0.3231597766909604



### Run for test

In [13]:
checkpoint = torch.load(file_name)
save_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

test_reader = InHospitalMortalityReader(dataset_dir=os.path.join(data_path, 'test'),
                                            listfile=os.path.join(data_path, 'test_listfile.csv'),
                                            period_length=48.0)
test_raw = utils.load_data(test_reader, discretizer, normalizer, small_part, return_names=True)
test_dataset = Dataset(test_raw['data'][0], test_raw['data'][1], test_raw['names'])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [14]:
batch_loss = []
y_true = []
y_pred = []
with torch.no_grad():
    model.eval()
    for step, (batch_x, batch_y, batch_name) in enumerate(test_loader):
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        batch_demo = []
        for i in range(len(batch_name)):
            cur_id, cur_ep, _ = batch_name[i].split('_', 2)
            cur_idx = cur_id + '_' + cur_ep
            cur_demo = torch.tensor(demographic_data[idx_list.index(cur_idx)], dtype=torch.float32)
            batch_demo.append(cur_demo)

        batch_demo = torch.stack(batch_demo).to(device)
        output = model(batch_x, batch_demo)[0]

        loss = get_loss(output, batch_y.unsqueeze(-1))
        batch_loss.append(loss.cpu().detach().numpy())
        y_pred += list(output.cpu().detach().numpy().flatten())
        y_true += list(batch_y.cpu().numpy().flatten())

print("\n==>Predicting on test")
print('Test Loss = %.4f'%(np.mean(np.array(batch_loss))))
y_pred = np.array(y_pred)
y_pred = np.stack([1 - y_pred, y_pred], axis=1)
test_res = metrics.print_metrics_binary(y_true, y_pred)




==>Predicting on test
Test Loss = 0.2532
confusion matrix:
[[2811   51]
 [ 270  104]]
accuracy = 0.9008034467697144
precision class 0 = 0.9123660922050476
precision class 1 = 0.6709677577018738
recall class 0 = 0.9821802973747253
recall class 1 = 0.27807486057281494
AUC of ROC = 0.8719613822277529
AUC of PRC = 0.5344983146118463
min(+P, Se) = 0.5106951871657754
f1_score = 0.3931947039659704


In [15]:
# Bootstrap
N = len(y_true)
N_idx = np.arange(N)
K = 1000

auroc = []
auprc = []
minpse = []
for i in range(K):
    boot_idx = np.random.choice(N_idx, N, replace=True)
    boot_true = np.array(y_true)[boot_idx]
    boot_pred = y_pred[boot_idx, :]
    test_ret = metrics.print_metrics_binary(boot_true, boot_pred, verbose=0)
    auroc.append(test_ret['auroc'])
    auprc.append(test_ret['auprc'])
    minpse.append(test_ret['minpse'])
    print('%d/%d'%(i+1,K))
    
print('auroc %.4f(%.4f)'%(np.mean(auroc), np.std(auroc)))
print('auprc %.4f(%.4f)'%(np.mean(auprc), np.std(auprc)))
print('minpse %.4f(%.4f)'%(np.mean(minpse), np.std(minpse)))

1/1000
2/1000
3/1000
4/1000
5/1000
6/1000
7/1000
8/1000
9/1000
10/1000
11/1000
12/1000
13/1000
14/1000
15/1000
16/1000
17/1000
18/1000
19/1000
20/1000
21/1000
22/1000
23/1000
24/1000
25/1000
26/1000
27/1000
28/1000
29/1000
30/1000
31/1000
32/1000
33/1000
34/1000
35/1000
36/1000
37/1000
38/1000
39/1000
40/1000
41/1000
42/1000
43/1000
44/1000
45/1000
46/1000
47/1000
48/1000
49/1000
50/1000
51/1000
52/1000
53/1000
54/1000
55/1000
56/1000
57/1000
58/1000
59/1000
60/1000
61/1000
62/1000
63/1000
64/1000
65/1000
66/1000
67/1000
68/1000
69/1000
70/1000
71/1000
72/1000
73/1000
74/1000
75/1000
76/1000
77/1000
78/1000
79/1000
80/1000
81/1000
82/1000
83/1000
84/1000
85/1000
86/1000
87/1000
88/1000
89/1000
90/1000
91/1000
92/1000
93/1000
94/1000
95/1000
96/1000
97/1000
98/1000
99/1000
100/1000
101/1000
102/1000
103/1000
104/1000
105/1000
106/1000
107/1000
108/1000
109/1000
110/1000
111/1000
112/1000
113/1000
114/1000
115/1000
116/1000
117/1000
118/1000
119/1000
120/1000
121/1000
122/1000
123/1000
1

927/1000
928/1000
929/1000
930/1000
931/1000
932/1000
933/1000
934/1000
935/1000
936/1000
937/1000
938/1000
939/1000
940/1000
941/1000
942/1000
943/1000
944/1000
945/1000
946/1000
947/1000
948/1000
949/1000
950/1000
951/1000
952/1000
953/1000
954/1000
955/1000
956/1000
957/1000
958/1000
959/1000
960/1000
961/1000
962/1000
963/1000
964/1000
965/1000
966/1000
967/1000
968/1000
969/1000
970/1000
971/1000
972/1000
973/1000
974/1000
975/1000
976/1000
977/1000
978/1000
979/1000
980/1000
981/1000
982/1000
983/1000
984/1000
985/1000
986/1000
987/1000
988/1000
989/1000
990/1000
991/1000
992/1000
993/1000
994/1000
995/1000
996/1000
997/1000
998/1000
999/1000
1000/1000
auroc 0.8724(0.0088)
auprc 0.5363(0.0266)
minpse 0.5121(0.0222)
