# BiteNet reproduction

When using RNN on BiteNet, the forgetfulness of RNN tends to lose memory of long term data.

Attention mechanism have some good performance, but not quite fit for EHR data, as for timestamps and hierarchical data.

BiteNet claims to have longer memory and have improvements on timestamps and hierarchical data

The project reproduce BiteNet model. Use MIMIC III data as input. With diagnosis codes and time stamp of each visit, predict re-admission.
Also predict with baseline models RNN and RETAIN.

Result shows that BiteNet performs better than baseline modes.
AUC score of this project is not exactly same as original paper. This could be from some parameters are not specified in original paper. 
So settings here is not same as original paper.

BiteNet paper: Xueping Peng, Guodong Long, Tao Shen, Sen Wang, Jing Jiang, and Chengqi Zhang. 2020. Bitenet: bidirectional temporal encoder network to predict medical outcomes. In 2020 IEEE International Conference on Data Mining (ICDM), pages 412–421. IEEE.

## Part 1: Load Data

In [1]:
import os
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
with open('./data/train.json', 'r') as f:
    train = json.load(f)
with open('./data/dev.json', 'r') as f:
    val = json.load(f)
with open('./data/test.json', 'r') as f:
    test = json.load(f)

In [3]:
print(len(test['context_codes']))
print(len(train['context_codes']))


749
5992


In [4]:
train.keys()
train_small_size = 640
train_small = {}
test_small = {}
for k in train:
    train_small[k] = train[k][:int(len(train['context_codes'])/32)*32]
    test_small[k] = test[k][:int(len(test['context_codes'])/32)*32]

In [5]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    
    def __init__(self, seqs, time, label):
        self.x = torch.tensor(seqs)
        self.t = torch.tensor(time)
        self.y = torch.tensor(label)
    
    def __len__(self):
        
        
        return len(self.x)
        
    
    def __getitem__(self, index):
        
        
        return self.x[index], self.t[index], self.y[index]
        


In [6]:
train_dataset = CustomDataset(train_small['context_codes'], train_small['intervals'], train_small['labels_2'])
val_dataset = CustomDataset(val['context_codes'], val['intervals'], val['labels_2'])
test_dataset = CustomDataset(test_small['context_codes'], test_small['intervals'], test_small['labels_2'])

'''train_dataset = CustomDataset(train['context_codes'], train['intervals'], train['labels_2'])
val_dataset = CustomDataset(val['context_codes'], val['intervals'], val['labels_2'])
test_dataset = CustomDataset(test['context_codes'], test['intervals'], test['labels_2'])'''

"train_dataset = CustomDataset(train['context_codes'], train['intervals'], train['labels_2'])\nval_dataset = CustomDataset(val['context_codes'], val['intervals'], val['labels_2'])\ntest_dataset = CustomDataset(test['context_codes'], test['intervals'], test['labels_2'])"

In [7]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, test_dataset):
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader, test_loader


train_loader, val_loader, test_loader = load_data(train_dataset, val_dataset, test_dataset)    

## Part 2: BiteNet construction

The MasEnc block is constructed in EncoderStack.

BiteNet put inputs of codes through embedding, MasEnc block, then attention pooling. Input of time interval goes through embedding.
Add results together and goes through two MasEnc separately with different mask. Do attention pool again. Concat results and feed foward to output.

In [8]:
class ResidualConnection(torch.nn.Module):
    def __init__(self, layer, input_dim):
        super().__init__()
        
        self.layer = layer
        self.layer_norm = nn.LayerNorm(input_dim)
    
    def forward(self, x, mask):
        #print("x shape: ", x.shape)
        y = self.layer_norm(x)
        #print("y shape: ", y.shape)
        y = self.layer(y, mask)
        #print(self.layer)
        #print("y shape: ", y.shape)
        return x + y

In [9]:
class MasEnc(torch.nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, attn_mask):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.atten_mask = attn_mask

        
        self.q_linear = nn.Linear(input_dim, embed_dim, bias = False)
        self.k_linear = nn.Linear(input_dim, embed_dim, bias = False)
        self.v_linear = nn.Linear(input_dim, embed_dim, bias = False)
        
        self.mulatt = nn.MultiheadAttention(embed_dim = self.embed_dim, num_heads = self.num_heads, batch_first = True)
        
    
    def forward(self, x, input_mask):
        queries = x
        keys = x
        
        q = self.q_linear(queries)  # (N, L_q, d)
        k = self.k_linear(keys)  # (N, L_k, d)
        v = self.v_linear(keys)  # (N, L_k, d)
        
        if self.atten_mask == 'diag':
            attn_mask = torch.diag(torch.ones(x.shape[1]).bool())
        elif self.atten_mask == 'forward':
            attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], dtype=torch.bool), diagonal=1)
        elif self.atten_mask == 'backward':
            attn_mask = torch.tril(torch.ones(x.shape[1], x.shape[1], dtype=torch.bool), diagonal=-1)
        else:
            attn_mask = torch.zeros(x.shape[1], x.shape[1], dtype=torch.bool)
              
        
        #data = self.mulatt(q, k, v, key_padding_mask = ~input_mask, attn_mask = attn_mask, need_weights = False)
        data = self.mulatt(q, k, v, attn_mask = attn_mask, need_weights = False)
        '''print("MasEnc q: ", q.shape, q.min(), q.max())
        print("MasEnc k: ", k.shape, k.min(), k.max())
        print("MasEnc v: ", v.shape, v.min(), v.max())
        print("MasEnc im: ", input_mask.shape, input_mask.min(), input_mask.max())
        print("MasEnc m: ", attn_mask.shape, attn_mask.min(), attn_mask.max(), attn_mask)
        print("MasEnc out: ", data[0].shape, data[0].min(), data[0].max())'''
        #print("call masenc")
        #print("masenc out shape: ", len(data), data[0].shape, data[1].shape)
        
        return data[0]

In [10]:
class FeedForwardNetwork(torch.nn.Module):
    def __init__(self, input_size, hidden_size, filter_size, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.filter_size = filter_size
        self.dropout = dropout
        
        self.filter_layer = nn.Linear(input_size, filter_size)
        self.relu = nn.ReLU()
        self.drop_layer = nn.Dropout(dropout)
        self.output_layer = nn.Linear(filter_size, hidden_size)
        
    def forward(self, x, input_size):
        out = self.filter_layer(x)
        #print("FFN0: ", out.shape, out.min(), out.max())
        out = self.relu(out)
        #print("FFN1: ", out.shape, out.min(), out.max())
        out = self.drop_layer(out)
        #print("FFN2: ", out.shape, out.min(), out.max())
        out = self.output_layer(out)
        #print("FFN3: ", out.shape, out.min(), out.max())
        return out

In [11]:
class EncoderStack(torch.nn.Module):
    
    def __init__(self, input_dim, embed_dim, num_heads, attn_mask, n_hidden_layers, n_batch, n_visit, hidden_size, filter_size, dropout):
        super().__init__()
        self.embed_dim = embed_dim
        self.output_normalization = nn.LayerNorm([n_batch, n_visit, embed_dim])
        self.layers = []
        
        for i in range(n_hidden_layers):
            masked_encoder_layer = MasEnc(input_dim, embed_dim, num_heads, attn_mask)
            feed_forward_network = FeedForwardNetwork(embed_dim, hidden_size, filter_size, dropout)
            
            self.layers.append([
                                ResidualConnection(masked_encoder_layer, [n_batch, n_visit, embed_dim]), 
                                ResidualConnection(feed_forward_network, [n_batch, n_visit, embed_dim])
                                ])
            

    def forward(self, x, input_mask):
        
        for layer in self.layers:
            masked_encoder_layer = layer[0]
            feed_forward_network = layer[1]
            
            #print("EncoderStack0: ", x.shape, x.min(), x.max())
            x = masked_encoder_layer(x, input_mask)
            #print("EncoderStack1: ", x.shape, x.min(), x.max())
            x = feed_forward_network(x, input_mask)
            #print("EncoderStack2: ", x.shape, x.min(), x.max())
        
        out = self.output_normalization(x)
        return out

In [12]:
class AttentionPooling(torch.nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        
        self.fc1 = nn.Linear(embedding_size, embedding_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(embedding_size, embedding_size)
        
    def forward(self,x, input_mask):
        #print("AttentionPooling0: ", x.shape, x.min(), x.max())
        x = self.fc2(self.relu(self.fc1(x)))
        #print("AttentionPooling1: ", x.shape, x.min(), x.max())
        _mask = torch.zeros_like(input_mask)
        _mask[~input_mask] = float('-inf')
        x = x + _mask.unsqueeze(-1)
        #print("AttentionPooling2: ", x.shape, x.min(), x.max())
        soft = torch.nn.functional.softmax(x,1)
        #print("AttentionPooling3: ", soft.shape, soft.min(), soft.max())
        out = (soft*x).sum(dim = 1)
        return out

In [13]:
class BiteNet(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.lr = 0.0001
        self.dropout_rate = 0.1
        self.n_intervals = 12 * 365 + 1
        self.n_visits = 10
        self.n_codes = 39
        self.vocabulary_size = 2438
        self.digit3_size = 2438
        self.pos_encoding = None
        self.embedding_size = 50
        self.num_hidden_layers = 2
        self.num_heads = 2
        self.batch = 32
        
        self.hidden_size = self.embedding_size
        self.filter_size = self.embedding_size
        
        self.code_embedding_layer = torch.nn.Embedding(self.vocabulary_size, self.embedding_size) 
        self.interval_embedding_layer = torch.nn.Embedding(self.n_intervals, self.embedding_size) 
        self.common_layer1 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'diag', 
                                          self.num_hidden_layers, self.batch * self.n_visits, self.n_codes, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)
        self.attn_pool_layer1 = AttentionPooling(self.embedding_size)
        self.common_layer2 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'forward', 
                                          self.num_hidden_layers, self.batch, self.n_visits, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)
        self.common_layer3 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'backward', 
                                          self.num_hidden_layers, self.batch, self.n_visits, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)
        self.attn_pool_layer2 = AttentionPooling(self.embedding_size)
        self.attn_pool_layer3 = AttentionPooling(self.embedding_size)
        self.fc1 = nn.Linear(2*self.embedding_size, 2*self.embedding_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(self.dropout_rate)
        self.fc2 = nn.Linear(2*self.embedding_size,1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, code_input, interval_input):
        inputs_mask = (code_input != 0)
        visit_mask = (code_input.sum(-1) != 0)
        
        # shape [batch_size, n_visits, n_codes, embedding_size]
        code_embed = self.code_embedding_layer(code_input)
        # reshape to (batch*n_visit, n_codes, embedding_size)
        e = code_embed.reshape(code_embed.shape[0]*code_embed.shape[1],code_embed.shape[2],code_embed.shape[3])
        
        #print("e: ", e.shape, e.min(), e.max())
        # reshape to (batch*n_visit, n_codes)
        e_mask = inputs_mask.reshape(inputs_mask.shape[0]*inputs_mask.shape[1],inputs_mask.shape[2])
        
        h = self.common_layer1(e, e_mask)
        
        #print("h: ", h.shape, h.min(), h.max())
        
        v = self.attn_pool_layer1(h, e_mask)
        #print("v: ", v.shape, v.min(), v.max())
        
        # reshape to (batch, n_visit, embedding_size)
        v = v.reshape(code_input.shape[0],code_input.shape[1],self.embedding_size)
        
        e_p = self.interval_embedding_layer(interval_input)
        
        
        v = v + e_p
        #print("v: ", v.shape, v.min(), v.max())
        
        
        o_fw = self.common_layer2(v, visit_mask)
        #print("o_fw: ", o_fw.shape, o_fw.min(), o_fw.max())
        u_fw = self.attn_pool_layer2(o_fw, visit_mask)
        #print("u_fw: ", u_fw.shape, u_fw.min(), u_fw.max())
        o_bw = self.common_layer3(v, visit_mask)
        #print("o_bw: ", o_bw.shape, o_bw.min(), o_bw.max())
        u_bw = self.attn_pool_layer3(o_bw, visit_mask)
        #print("u_bw: ", u_bw.shape, u_bw.min(), u_bw.max())
        
        b_bi = torch.cat((u_fw, u_bw), 1)
        #print("b_bi: ", b_bi.shape, b_bi.min(), b_bi.max())
        out = self.sigmoid(self.fc2(self.dropout(self.relu(self.fc1(b_bi)))))
        
        return out

## Part 3: Test for BiteNet

run for 10 rounds. Each round has 10 epochs. Record last epoch auc score.
Take average of 10 rounds for final result

In [14]:
bitenet = BiteNet()
count_para = sum(p.numel() for p in bitenet.parameters() if p.requires_grad)
count_para

1678451

In [16]:
from sklearn.metrics import precision_recall_curve, auc, precision_score
import time

optimizer = torch.optim.Adam(bitenet.parameters(), lr = 0.001)
criterion = torch.nn.BCELoss()

def bitenet_train(train_loader):
    bitenet.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        x,t,y = data
        pred_y = bitenet(x, t)
        #print(pred_y.shape, pred_y.min(), pred_y.max())
        loss = criterion(pred_y, y.float())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def bitenet_test(loader):
    bitenet.eval()
    pred = []
    y_true = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        x,t,y = data
        out = bitenet(x, t)
        pred +=  out.squeeze().tolist() 
        y_true += y.squeeze().tolist()
    precision, recall, thresholds = precision_recall_curve(y_true, pred)
    auc_score = auc(recall, precision)
    y_pred = [p>0.5 for p in pred]
    p_score = precision_score(y_true, y_pred)
    return auc_score, p_score
 


ts = time.time()
auc_score = []
for times in range(10):
    bitenet.__init__()
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(bitenet.parameters(), lr=0.001)
    for epoch in range(10):
        bitenet_train(train_loader)
        bitenet_train_acc = bitenet_test(train_loader)
        bitenet_test_acc = bitenet_test(test_loader)
        print(f'Epoch: {epoch + 1:03d}, Train Score: {bitenet_train_acc[0]:.4f}, {bitenet_train_acc[1]:.4f}, Test Score: {bitenet_test_acc[0]:.4f}, {bitenet_test_acc[1]:.4f}')
    auc_score.append(bitenet_test_acc[0])
    print("round {}, auc {}".format(times+1, bitenet_test_acc[0]))
        
te = time.time()
print(te-ts)
auc_ = np.array(auc_score)
print("mean: {}, std: {}".format(np.mean(auc_), np.std(auc_)))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2207, 0.0000, Test Score: 0.1766, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2331, 0.0000, Test Score: 0.1794, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2472, 0.0000, Test Score: 0.1785, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.2764, 0.0000, Test Score: 0.1854, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 005, Train Score: 0.3201, 1.0000, Test Score: 0.2068, 0.0000
Epoch: 006, Train Score: 0.3251, 1.0000, Test Score: 0.2302, 0.0000
Epoch: 007, Train Score: 0.3782, 0.9697, Test Score: 0.2105, 0.3333
Epoch: 008, Train Score: 0.5010, 0.9096, Test Score: 0.3107, 0.6667
Epoch: 009, Train Score: 0.5516, 0.6054, Test Score: 0.3093, 0.3059
Epoch: 010, Train Score: 0.6000, 0.8514, Test Score: 0.2701, 0.4250
round 1, auc 0.27009182567630896


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2237, 0.0000, Test Score: 0.1798, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2302, 0.0000, Test Score: 0.1779, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2689, 0.0000, Test Score: 0.1891, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.3000, 0.0000, Test Score: 0.1938, 0.0000
Epoch: 005, Train Score: 0.4346, 0.9487, Test Score: 0.2952, 1.0000
Epoch: 006, Train Score: 0.4872, 0.8564, Test Score: 0.2900, 0.8333
Epoch: 007, Train Score: 0.5489, 0.8176, Test Score: 0.2946, 0.4500
Epoch: 008, Train Score: 0.5868, 0.7875, Test Score: 0.2923, 0.3469
Epoch: 009, Train Score: 0.6185, 0.8225, Test Score: 0.3119, 0.4571
Epoch: 010, Train Score: 0.6534, 0.8267, Test Score: 0.2777, 0.3091
round 2, auc 0.27765832230032417


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2255, 0.0000, Test Score: 0.1841, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2304, 0.0000, Test Score: 0.1810, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2509, 0.0000, Test Score: 0.1849, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.3100, 0.0000, Test Score: 0.1855, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 005, Train Score: 0.3339, 1.0000, Test Score: 0.1804, 0.0000
Epoch: 006, Train Score: 0.4380, 0.9487, Test Score: 0.2269, 0.5000
Epoch: 007, Train Score: 0.5294, 0.9186, Test Score: 0.2799, 0.7143
Epoch: 008, Train Score: 0.5626, 0.8109, Test Score: 0.2862, 0.4333
Epoch: 009, Train Score: 0.6126, 0.8337, Test Score: 0.2609, 0.4103
Epoch: 010, Train Score: 0.6420, 0.6948, Test Score: 0.2591, 0.2958
round 3, auc 0.25914291330787853


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2229, 0.0000, Test Score: 0.1902, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2292, 0.0000, Test Score: 0.1852, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2332, 0.0000, Test Score: 0.1856, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.2865, 0.0000, Test Score: 0.1926, 0.0000
Epoch: 005, Train Score: 0.3896, 0.6809, Test Score: 0.2877, 0.5833
Epoch: 006, Train Score: 0.4875, 0.8757, Test Score: 0.3058, 0.7500
Epoch: 007, Train Score: 0.5346, 0.8367, Test Score: 0.3020, 0.6500
Epoch: 008, Train Score: 0.5807, 0.8515, Test Score: 0.3101, 0.4848
Epoch: 009, Train Score: 0.6275, 0.7766, Test Score: 0.3060, 0.3878
Epoch: 010, Train Score: 0.6552, 0.8177, Test Score: 0.3067, 0.3390
round 4, auc 0.3067466033543461


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2194, 0.0000, Test Score: 0.1835, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2238, 0.0000, Test Score: 0.1833, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2503, 0.0000, Test Score: 0.1750, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.2858, 0.0000, Test Score: 0.2002, 0.0000
Epoch: 005, Train Score: 0.3055, 0.9167, Test Score: 0.1890, 0.3333
Epoch: 006, Train Score: 0.3830, 0.9032, Test Score: 0.2028, 0.3333
Epoch: 007, Train Score: 0.4542, 0.7847, Test Score: 0.2134, 0.3333
Epoch: 008, Train Score: 0.5770, 0.9010, Test Score: 0.2995, 0.7059
Epoch: 009, Train Score: 0.6242, 0.8800, Test Score: 0.3205, 0.6957
Epoch: 010, Train Score: 0.6446, 0.8491, Test Score: 0.3186, 0.5000
round 5, auc 0.31861385610417825


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2234, 0.0000, Test Score: 0.1885, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2347, 0.0000, Test Score: 0.1833, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2836, 0.0000, Test Score: 0.2010, 0.0000
Epoch: 004, Train Score: 0.2923, 1.0000, Test Score: 0.1982, 0.0000
Epoch: 005, Train Score: 0.3364, 1.0000, Test Score: 0.1930, 0.0000
Epoch: 006, Train Score: 0.3543, 0.9565, Test Score: 0.1927, 0.2000
Epoch: 007, Train Score: 0.4776, 0.9388, Test Score: 0.2569, 0.6667
Epoch: 008, Train Score: 0.5247, 0.8955, Test Score: 0.2903, 0.6500
Epoch: 009, Train Score: 0.5666, 0.8613, Test Score: 0.2957, 0.4595
Epoch: 010, Train Score: 0.5973, 0.9350, Test Score: 0.3051, 0.7000
round 6, auc 0.30514202541598456


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2227, 0.0000, Test Score: 0.1829, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2242, 0.0000, Test Score: 0.1900, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2430, 0.0000, Test Score: 0.1820, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.2518, 1.0000, Test Score: 0.1945, 0.0000
Epoch: 005, Train Score: 0.3113, 0.8261, Test Score: 0.2123, 0.3333
Epoch: 006, Train Score: 0.4597, 0.8433, Test Score: 0.3101, 0.7500
Epoch: 007, Train Score: 0.4999, 0.8806, Test Score: 0.2827, 0.7143
Epoch: 008, Train Score: 0.5461, 0.8671, Test Score: 0.2885, 0.5200
Epoch: 009, Train Score: 0.5786, 0.8930, Test Score: 0.2942, 0.4800
Epoch: 010, Train Score: 0.6109, 0.7681, Test Score: 0.3033, 0.3585
round 7, auc 0.3032742045667778


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2208, 0.0000, Test Score: 0.1836, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2318, 0.0000, Test Score: 0.1815, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2310, 0.0000, Test Score: 0.1824, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.3064, 0.0000, Test Score: 0.1945, 0.0000
Epoch: 005, Train Score: 0.4638, 0.7423, Test Score: 0.3065, 0.6071
Epoch: 006, Train Score: 0.5227, 0.7612, Test Score: 0.3093, 0.4750
Epoch: 007, Train Score: 0.5772, 0.8987, Test Score: 0.3174, 0.7143
Epoch: 008, Train Score: 0.6132, 0.8536, Test Score: 0.3149, 0.5455
Epoch: 009, Train Score: 0.6450, 0.8072, Test Score: 0.2884, 0.3878
Epoch: 010, Train Score: 0.6756, 0.7340, Test Score: 0.3081, 0.2737
round 8, auc 0.30813978253430285


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2237, 0.0000, Test Score: 0.1876, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2273, 0.0000, Test Score: 0.1976, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2324, 0.0000, Test Score: 0.1933, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.2752, 1.0000, Test Score: 0.2122, 0.0000
Epoch: 005, Train Score: 0.4261, 0.8951, Test Score: 0.3289, 0.8000
Epoch: 006, Train Score: 0.4949, 0.9145, Test Score: 0.3111, 0.9167
Epoch: 007, Train Score: 0.5427, 0.8900, Test Score: 0.3012, 0.8125
Epoch: 008, Train Score: 0.5833, 0.7691, Test Score: 0.3026, 0.4884
Epoch: 009, Train Score: 0.6248, 0.8288, Test Score: 0.3136, 0.5806
Epoch: 010, Train Score: 0.6687, 0.8705, Test Score: 0.3178, 0.4865
round 9, auc 0.31781531752428416


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 001, Train Score: 0.2210, 0.0000, Test Score: 0.1841, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 002, Train Score: 0.2342, 0.0000, Test Score: 0.1848, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 003, Train Score: 0.2550, 0.0000, Test Score: 0.1879, 0.0000


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 004, Train Score: 0.3148, 1.0000, Test Score: 0.1996, 0.0000
Epoch: 005, Train Score: 0.4237, 0.9262, Test Score: 0.3112, 0.8182
Epoch: 006, Train Score: 0.5028, 0.8857, Test Score: 0.3045, 0.9000
Epoch: 007, Train Score: 0.5594, 0.8423, Test Score: 0.3041, 0.6000
Epoch: 008, Train Score: 0.5843, 0.7615, Test Score: 0.3001, 0.4324
Epoch: 009, Train Score: 0.6268, 0.7922, Test Score: 0.2916, 0.4706
Epoch: 010, Train Score: 0.6570, 0.7479, Test Score: 0.3005, 0.3667
round 10, auc 0.3004758268098901
4184.115975618362
mean: 0.29671006775942754, std: 0.01941065581443842


## Part 4: Test for baseline models

Use same data to predict with RNN and RETAIN.

In [51]:
class RNN(torch.nn.Module):
    
    def __init__(self, num_codes):
        super().__init__()
        
        #num_codes += 1
        num_embedding = 50
        
        self.embedding = nn.Embedding(num_codes, num_embedding)
        self.rnn = nn.GRU(num_embedding, hidden_size=num_embedding, batch_first=True, num_layers=5)
        self.fc = nn.Linear(in_features=num_embedding, out_features=1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, code_input, interval_input):
        batch_size = code_input.shape[0]
        input_mask = (code_input != 0)
        #x = torch.cat((code_input, interval_input[:,:,None]), -1)
        x = self.embedding(code_input)
        mask_ = input_mask.unsqueeze(3).repeat(1, 1, 1, x.shape[3])
        sum_embeddings = torch.sum(x*mask_, 2)
        output, _ = self.rnn(sum_embeddings)
        
        
        visit_length = torch.sum(torch.sum(input_mask,2)>0, 1)
        last_hidden_state = output[torch.arange(output.size(0)),visit_length-1,: ]
        logits = self.fc(last_hidden_state)
        probs = self.sigmoid(logits)
        #return probs.view(batch_size)
        return probs
        
        

In [52]:
baseline_rnn = RNN(num_codes = 2438)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(baseline_rnn.parameters(), lr=0.001)

def rnn_train(train_loader):
    baseline_rnn.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        x,t,y = data
        pred_y = baseline_rnn(x, t)
        #print(pred_y.shape, pred_y.min(), pred_y.max())
        loss = criterion(pred_y, y.float())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def rnn_test(loader):
    baseline_rnn.eval()
    pred = []
    y_true = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        x,t,y = data
        out = baseline_rnn(x, t)
        pred +=  out.squeeze().tolist() 
        y_true += y.squeeze().tolist()
    precision, recall, thresholds = precision_recall_curve(y_true, pred)
    auc_score = auc(recall, precision)
    y_pred = [p>0.5 for p in pred]
    #p_score = precision_score(y_true, y_pred)
    p_score = 0
    return auc_score, p_score
 

ts = time.time()
auc_score = []
for times in range(10):
    baseline_rnn.__init__(num_codes = 2438)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(baseline_rnn.parameters(), lr=0.001)
    for epoch in range(10):
        rnn_train(train_loader)
        rnn_train_acc = rnn_test(train_loader)
        rnn_test_acc = rnn_test(test_loader)
        print(f'Epoch: {epoch + 1:03d}, Train Score: {rnn_train_acc[0]:.4f}, {rnn_train_acc[1]:.4f}, Test Score: {rnn_test_acc[0]:.4f}, {rnn_test_acc[1]:.4f}')
    auc_score.append(rnn_test_acc[0])
    print("round {}, auc {}".format(times+1, rnn_test_acc[0]))
        
te = time.time()
print(te-ts)
auc_ = np.array(auc_score)
print("mean: {}, std: {}".format(np.mean(auc_), np.std(auc_)))

Epoch: 001, Train Score: 0.4286, 0.0000, Test Score: 0.3248, 0.0000
Epoch: 002, Train Score: 0.5302, 0.0000, Test Score: 0.3314, 0.0000
Epoch: 003, Train Score: 0.6566, 0.0000, Test Score: 0.2917, 0.0000
Epoch: 004, Train Score: 0.7594, 0.0000, Test Score: 0.3019, 0.0000
Epoch: 005, Train Score: 0.8565, 0.0000, Test Score: 0.2925, 0.0000
Epoch: 006, Train Score: 0.9177, 0.0000, Test Score: 0.2835, 0.0000
Epoch: 007, Train Score: 0.9522, 0.0000, Test Score: 0.2802, 0.0000
Epoch: 008, Train Score: 0.9701, 0.0000, Test Score: 0.3087, 0.0000
Epoch: 009, Train Score: 0.9806, 0.0000, Test Score: 0.2843, 0.0000
Epoch: 010, Train Score: 0.9843, 0.0000, Test Score: 0.2842, 0.0000
round 1, auc 0.28419003445132524
Epoch: 001, Train Score: 0.3930, 0.0000, Test Score: 0.2786, 0.0000
Epoch: 002, Train Score: 0.4672, 0.0000, Test Score: 0.2834, 0.0000
Epoch: 003, Train Score: 0.5763, 0.0000, Test Score: 0.2956, 0.0000
Epoch: 004, Train Score: 0.6507, 0.0000, Test Score: 0.2763, 0.0000
Epoch: 005, Tra

In [55]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()
        
        self.a_att = nn.Linear(embedding_dim, 1)

    def forward(self, g, rev_masks):

        score = self.a_att(g)
        mask = (rev_masks.sum(2)>0).unsqueeze(2)
        score = score * mask + (~mask)*(-1e9)
        att_value = torch.nn.functional.softmax(score, dim=1)
        #print(att_value.shape)
        return att_value
    
class BetaAttention(torch.nn.Module):

    def __init__(self, embedding_dim):
        super().__init__()

        self.b_att = nn.Linear(embedding_dim, embedding_dim)


    def forward(self, h):

        score = self.b_att(h)
        beta = torch.tanh(score)
        #print(beta.shape)
        return beta

def attention_sum(alpha, beta, rev_v, rev_masks):
    
    mask = (rev_masks.sum(2)>0).unsqueeze(2).repeat(1,1,beta.shape[-1])
    a = alpha.repeat(1,1,beta.shape[-1])
    #print(beta.shape, a.shape, mask.shape, rev_v.shape)
    c = (rev_v * mask * a * beta).sum(1)
    return c

class retain(torch.nn.Module):
    def __init__(self, num_codes, embedding_dim=50):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True, num_layers=2)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True, num_layers=2)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, t):
        masks = (x != 0)
        # 1. Pass the reversed sequence through the embedding layer;
        x = self.embedding(x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        
        x = x * masks.unsqueeze(-1)
        x = torch.sum(x, dim = -2)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(x)
        h, _ = self.rnn_b(x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g, masks)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, x, masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs

In [56]:
from sklearn.metrics import precision_recall_curve, auc, precision_score
baseline_retain = retain(num_codes = 2438)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(baseline_retain.parameters(), lr=0.001)

def retain_train(train_loader):
    baseline_retain.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        x,t,y = data
        pred_y = baseline_retain(x, t)
        #print(pred_y.shape, pred_y.min(), pred_y.max())
        loss = criterion(pred_y, y.float())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def retain_test(loader):
    baseline_retain.eval()
    pred = []
    y_true = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        x,t,y = data
        out = baseline_retain(x, t)
        pred +=  out.squeeze().tolist() 
        y_true += y.squeeze().tolist()
    precision, recall, thresholds = precision_recall_curve(y_true, pred)
    auc_score = auc(recall, precision)
    y_pred = [p>0.5 for p in pred]
    p_score = precision_score(y_true, y_pred)
    return auc_score, p_score
 

ts = time.time()
auc_score = []
for times in range(10):
    baseline_retain.__init__(num_codes = 2438)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(baseline_retain.parameters(), lr=0.001)
    for epoch in range(10):
        retain_train(train_loader)
        retain_train_acc = retain_test(train_loader)
        retain_test_acc = retain_test(test_loader)
        print(f'Epoch: {epoch + 1:03d}, Train Score: {retain_train_acc[0]:.4f}, {retain_train_acc[1]:.4f}, Test Score: {retain_test_acc[0]:.4f}, {retain_test_acc[1]:.4f}')
    auc_score.append(retain_test_acc[0])
    print("round {}, auc {}".format(times+1, retain_test_acc[0]))
        
te = time.time()
print(te-ts)
auc_ = np.array(auc_score)
print("mean: {}, std: {}".format(np.mean(auc_), np.std(auc_)))

Epoch: 001, Train Score: 0.4437, 0.9785, Test Score: 0.2930, 1.0000
Epoch: 002, Train Score: 0.6482, 0.9669, Test Score: 0.2944, 0.6250
Epoch: 003, Train Score: 0.8000, 0.9769, Test Score: 0.2886, 0.5500
Epoch: 004, Train Score: 0.9180, 0.9650, Test Score: 0.2830, 0.3269
Epoch: 005, Train Score: 0.9635, 0.9701, Test Score: 0.2640, 0.2436
Epoch: 006, Train Score: 0.9872, 0.9807, Test Score: 0.2581, 0.2453
Epoch: 007, Train Score: 0.9964, 0.9983, Test Score: 0.2408, 0.2474
Epoch: 008, Train Score: 0.9990, 0.9983, Test Score: 0.2542, 0.2444
Epoch: 009, Train Score: 0.9997, 1.0000, Test Score: 0.2571, 0.2667
Epoch: 010, Train Score: 0.9999, 1.0000, Test Score: 0.2508, 0.2593
round 1, auc 0.2508003516919891
Epoch: 001, Train Score: 0.4103, 1.0000, Test Score: 0.2875, 0.7500
Epoch: 002, Train Score: 0.6376, 0.9617, Test Score: 0.2936, 0.6875
Epoch: 003, Train Score: 0.8174, 0.9261, Test Score: 0.2699, 0.3019
Epoch: 004, Train Score: 0.9252, 0.9681, Test Score: 0.2696, 0.2973
Epoch: 005, Trai

### Score for BiteNet, RNN, Retain are 0.2967, 0.2665, 0.2744. BiteNet performs best.

## Ablation study

Replace MasEnc blocks with simple muti-head attention layer.

In [61]:
class multi_attn(torch.nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        #self.atten_mask = attn_mask

        
        self.q_linear = nn.Linear(input_dim, embed_dim, bias = False)
        self.k_linear = nn.Linear(input_dim, embed_dim, bias = False)
        self.v_linear = nn.Linear(input_dim, embed_dim, bias = False)
        
        self.mulatt = nn.MultiheadAttention(embed_dim = self.embed_dim, num_heads = self.num_heads, batch_first = True)
        
    
    def forward(self, x, input_mask):
        queries = x
        keys = x
        
        q = self.q_linear(queries)  # (N, L_q, d)
        k = self.k_linear(keys)  # (N, L_k, d)
        v = self.v_linear(keys)  # (N, L_k, d)
              
        
        #data = self.mulatt(q, k, v, key_padding_mask = ~input_mask, attn_mask = attn_mask, need_weights = False)
        data = self.mulatt(q, k, v, need_weights = False)
        
        
        return data[0]

In [62]:
class BiteNet_ablation(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.lr = 0.0001
        self.dropout_rate = 0.1
        self.n_intervals = 12 * 365 + 1
        self.n_visits = 10
        self.n_codes = 39
        self.vocabulary_size = 2438
        self.digit3_size = 2438
        self.pos_encoding = None
        self.embedding_size = 50
        self.num_hidden_layers = 2
        self.num_heads = 2
        self.batch = 32
        
        self.hidden_size = self.embedding_size
        self.filter_size = self.embedding_size
        
        self.code_embedding_layer = torch.nn.Embedding(self.vocabulary_size, self.embedding_size) 
        self.interval_embedding_layer = torch.nn.Embedding(self.n_intervals, self.embedding_size) 
        '''self.common_layer1 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'diag', 
                                          self.num_hidden_layers, self.batch * self.n_visits, self.n_codes, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)'''
        self.common_layer1 = multi_attn(self.embedding_size, self.embedding_size, self.num_heads)
        self.attn_pool_layer1 = AttentionPooling(self.embedding_size)
        '''self.common_layer2 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'forward', 
                                          self.num_hidden_layers, self.batch, self.n_visits, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)
        self.common_layer3 = EncoderStack(self.embedding_size, self.embedding_size, self.num_heads, 'backward', 
                                          self.num_hidden_layers, self.batch, self.n_visits, 
                                          self.hidden_size, self.filter_size, self.dropout_rate)'''
        self.common_layer2 = multi_attn(self.embedding_size, self.embedding_size, self.num_heads)
        self.common_layer3 = multi_attn(self.embedding_size, self.embedding_size, self.num_heads)
        
        
        self.attn_pool_layer2 = AttentionPooling(self.embedding_size)
        self.attn_pool_layer3 = AttentionPooling(self.embedding_size)
        self.fc1 = nn.Linear(2*self.embedding_size, 2*self.embedding_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(self.dropout_rate)
        self.fc2 = nn.Linear(2*self.embedding_size,1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, code_input, interval_input):
        inputs_mask = (code_input != 0)
        visit_mask = (code_input.sum(-1) != 0)
        
        # shape [batch_size, n_visits, n_codes, embedding_size]
        code_embed = self.code_embedding_layer(code_input)
        # reshape to (batch*n_visit, n_codes, embedding_size)
        e = code_embed.reshape(code_embed.shape[0]*code_embed.shape[1],code_embed.shape[2],code_embed.shape[3])
        
        #print("e: ", e.shape, e.min(), e.max())
        # reshape to (batch*n_visit, n_codes)
        e_mask = inputs_mask.reshape(inputs_mask.shape[0]*inputs_mask.shape[1],inputs_mask.shape[2])
        
        h = self.common_layer1(e, e_mask)
        
        #print("h: ", h.shape, h.min(), h.max())
        
        v = self.attn_pool_layer1(h, e_mask)
        #print("v: ", v.shape, v.min(), v.max())
        
        # reshape to (batch, n_visit, embedding_size)
        v = v.reshape(code_input.shape[0],code_input.shape[1],self.embedding_size)
        
        e_p = self.interval_embedding_layer(interval_input)
        
        
        v = v + e_p
        #print("v: ", v.shape, v.min(), v.max())
        
        
        o_fw = self.common_layer2(v, visit_mask)
        #print("o_fw: ", o_fw.shape, o_fw.min(), o_fw.max())
        u_fw = self.attn_pool_layer2(o_fw, visit_mask)
        #print("u_fw: ", u_fw.shape, u_fw.min(), u_fw.max())
        o_bw = self.common_layer3(v, visit_mask)
        #print("o_bw: ", o_bw.shape, o_bw.min(), o_bw.max())
        u_bw = self.attn_pool_layer3(o_bw, visit_mask)
        #print("u_bw: ", u_bw.shape, u_bw.min(), u_bw.max())
        
        b_bi = torch.cat((u_fw, u_bw), 1)
        #print("b_bi: ", b_bi.shape, b_bi.min(), b_bi.max())
        out = self.sigmoid(self.fc2(self.dropout(self.relu(self.fc1(b_bi)))))
        
        return out

In [63]:
bitenet_a = BiteNet_ablation()
count_para = sum(p.numel() for p in bitenet_a.parameters() if p.requires_grad)
count_para

419551

In [64]:
optimizer = torch.optim.Adam(bitenet_a.parameters(), lr = 0.001)
criterion = torch.nn.BCELoss()

def bitenet_a_train(train_loader):
    bitenet_a.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        x,t,y = data
        pred_y = bitenet_a(x, t)
        #print(pred_y.shape, pred_y.min(), pred_y.max())
        loss = criterion(pred_y, y.float())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def bitenet_a_test(loader):
    bitenet_a.eval()
    pred = []
    y_true = []
    for data in loader:  # Iterate in batches over the training/test dataset.
        x,t,y = data
        out = bitenet_a(x, t)
        pred +=  out.squeeze().tolist() 
        y_true += y.squeeze().tolist()
    precision, recall, thresholds = precision_recall_curve(y_true, pred)
    auc_score = auc(recall, precision)
    y_pred = [p>0.5 for p in pred]
    #p_score = precision_score(y_true, y_pred)
    p_score = 0
    return auc_score, p_score
 


ts = time.time()
auc_score = []
for times in range(10):
    bitenet_a.__init__()
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(bitenet_a.parameters(), lr=0.001)
    for epoch in range(10):
        bitenet_a_train(train_loader)
        bitenet_a_train_acc = bitenet_a_test(train_loader)
        bitenet_a_test_acc = bitenet_a_test(test_loader)
        print(f'Epoch: {epoch + 1:03d}, Train Score: {bitenet_a_train_acc[0]:.4f}, {bitenet_a_train_acc[1]:.4f}, Test Score: {bitenet_a_test_acc[0]:.4f}, {bitenet_a_test_acc[1]:.4f}')
    auc_score.append(bitenet_a_test_acc[0])
    print("round {}, auc {}".format(times+1, bitenet_a_test_acc[0]))
        
te = time.time()
print(te-ts)
auc_ = np.array(auc_score)
print("mean: {}, std: {}".format(np.mean(auc_), np.std(auc_)))

Epoch: 001, Train Score: 0.2287, 0.0000, Test Score: 0.1945, 0.0000
Epoch: 002, Train Score: 0.2399, 0.0000, Test Score: 0.1926, 0.0000
Epoch: 003, Train Score: 0.3330, 0.0000, Test Score: 0.2674, 0.0000
Epoch: 004, Train Score: 0.3613, 0.0000, Test Score: 0.2709, 0.0000
Epoch: 005, Train Score: 0.3995, 0.0000, Test Score: 0.2678, 0.0000
Epoch: 006, Train Score: 0.4286, 0.0000, Test Score: 0.2939, 0.0000
Epoch: 007, Train Score: 0.3969, 0.0000, Test Score: 0.2542, 0.0000
Epoch: 008, Train Score: 0.4413, 0.0000, Test Score: 0.2623, 0.0000
Epoch: 009, Train Score: 0.4706, 0.0000, Test Score: 0.2589, 0.0000
Epoch: 010, Train Score: 0.5014, 0.0000, Test Score: 0.2678, 0.0000
round 1, auc 0.2677961594968218
Epoch: 001, Train Score: 0.2139, 0.0000, Test Score: 0.1866, 0.0000
Epoch: 002, Train Score: 0.2322, 0.0000, Test Score: 0.1904, 0.0000
Epoch: 003, Train Score: 0.3676, 0.0000, Test Score: 0.2867, 0.0000
Epoch: 004, Train Score: 0.3687, 0.0000, Test Score: 0.2844, 0.0000
Epoch: 005, Trai

### Score decrease. Shows importance of MasEnc blocks.