In [1]:
# !pip install torch_geometric rdkit torch

In [2]:
from datetime import datetime
import time
import argparse
import sys
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from sklearn import metrics
import pandas as pd
import numpy as np
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)


In [3]:
# Directory configuration
data_dir = "data"
model_dir = "models"
model_name = "case28"

# sys.path.append('/content/drive/MyDrive/Colab Notebooks')

In [4]:
####### Tunning parameters #######

# Number of epochs
n_epochs = 300

# SagPooling ratio & min score. 
# Set sp_ratio to None to disable ratio in SagPooling
sp_ratio = None
sp_min_score = None

# Enable using gpu
use_cuda = True

# Use activation function for CoAttention Layer
use_activation_fn = False

# Use ComplEx instead of RESCAL
use_ComplEx = True

# Use improved CoAttention Layer
# Could be "original" || "improved" || "multihead"
co_attention_method = "multihead"

# Use Explicit Valence
use_explicit_valence = False

# Number of GAT layers
num_GAT_layers = 4

# Number of GAT multiheads
num_GAT_multiheads = 2

#################################

In [5]:
# If using explicit valence feature
if use_explicit_valence:
    from data_preprocessing_explicit_valence import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS
else:
    from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [6]:
mode = "train"

n_atom_feats = TOTAL_ATOM_FEATS
# Not use
n_atom_hid = 64
# Total interactions information in the Interaction_information.csv
rel_total = 86
lr = 1e-2
weight_decay = 5e-4
neg_samples = 1
# Represents the number of samples (or graph instances) loaded in each batch during the training process.
batch_size = 1024
data_size_ratio = 1
kge_dim = 64

device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 55


In [7]:
def print_tunning_parameters():
    print()
    print("####### Tunning parameters #######")
    print()
    
    print("n_epochs =", n_epochs)
    print("use_cuda =", use_cuda)
    print()
    print("num_GAT_layers = ", num_GAT_layers)
    print("num_GAT_multiheads = ", num_GAT_multiheads)
    print()
    print("sp_ratio =", sp_ratio)
    print("sp_min_score =", sp_min_score)
    print()
    print("use_explicit_valence =", use_explicit_valence)
    print()
    print("use_activation_fn =", use_activation_fn)
    print()
    print("use_ComplEx =", use_ComplEx)
    print()
    print("co_attention_method =", co_attention_method)
    
    print()
    print("#################################")
    print()


In [8]:
class CoAttentionLayer(nn.Module):
    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        self.n_features = n_features
        self.w_q = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.w_k = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.bias = nn.Parameter(torch.zeros(n_features // 2))
        self.a = nn.Parameter(torch.zeros(n_features // 2))
        self.use_activation_fn = use_activation_fn

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q
        # values = receiver @ self.w_v
        values = receiver

        # queries.shape = (1024, 4, 32)
        # keys.shape = (1024, 4, 32)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations) @ self.a
        else:
            e_scores = e_activations @ self.a
        attentions = e_scores

        return attentions


class MultiheadCoAttentionLayer(nn.Module):
    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        self.n_features = n_features
        self.n_heads = n_heads

        
        self.W_q = nn.ParameterList([torch.zeros(self.n_features, self.n_features // n_heads) for _ in range(n_heads)])
        self.W_k = nn.ParameterList([torch.zeros(self.n_features, self.n_features // n_heads) for _ in range(n_heads)])
        
        self.a = nn.Parameter(torch.zeros(self.n_features))
        self.bias = nn.ParameterList([torch.zeros(self.n_features // n_heads) for _ in range(n_heads)])
        
        self.use_activation_fn = use_activation_fn

        self.dropout = nn.Dropout(dropout)

        for i in range(n_heads):
            nn.init.xavier_uniform_(self.W_q[i])
            nn.init.xavier_uniform_(self.W_k[i])
            nn.init.xavier_uniform_(self.bias[i].view(*self.bias[i].shape, -1))
        
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)

        # Compute attention score for each head
        head_outputs = []
        for i in range(self.n_heads):
            keys = receiver @ self.W_k[i]
            queries = attendant @ self.W_q[i]
            # print("keys.shape = ", keys.shape)
            e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias[i]
            # print("e_activations.shape = ", e_activations.shape)
            head_outputs.append(e_activations)

        # Average the outputs from all heads
        # e_activations_avg.shape = (1024, 4, 4, 32)
        e_activations_avg = torch.cat(head_outputs, dim=-1)
        # print("e_activations_avg.shape = ", e_activations_avg.shape)
        
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations_avg) @ self.a
        else:
            e_scores = e_activations_avg @ self.a

        # attentions.shape = (1024, 4, 4)
        attentions = e_scores

        return attentions

class CoAttentionLayerImproved(nn.Module):
    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        self.n_features = n_features
        self.n_heads = n_heads
        self.head_dim = n_features // n_heads

        # Projects for queries and keys per head
        self.w_q = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.w_k = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.bias = nn.Parameter(torch.zeros(self.n_features // 2))
        self.a = nn.Parameter(torch.zeros(self.n_features // 2))
        self.use_activation_fn = use_activation_fn

        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)

        # Split reciever and attendant into multiple heads
        batch_size, gat_size, n_features = receiver.shape
        receiver = receiver.view(batch_size, gat_size, self.n_heads, self.head_dim)
        attendant = attendant.view(batch_size, gat_size, self.n_heads, self.head_dim)
        
        # Compute keys and queries per head
        # receiver.shape  = (1024, 4, 2, 32)
        # attendant.shape = (1024, 4, 2, 32)
        
        # self.w_k.shape  = (32, 16)
        # self.w_q.shape  = (32, 16)
        
        # self.keys.shape     = (1024, 4, 2, 16)
        # self.queries.shape  = (1024, 4, 2, 16)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q

        # self.keys.shape     = (1024, 4, 32)
        # self.queries.shape  = (1024, 4, 32)
        keys    = keys.view(batch_size, gat_size, self.head_dim)
        queries = queries.view(batch_size, gat_size, self.head_dim)
        # print("keys.shape", keys.shape)
        # print("queries.shape", queries.shape)

        # e_activations.shape = (1024, 4, 4, 32)
        # self.a.shape = (32,)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations) @ self.a
        else:
            e_scores = e_activations @ self.a

        # attentions.shape = (1024, 4, 4)
        attentions = e_scores

        return attentions


class RESCAL(nn.Module):
    def __init__(self, n_rels, n_features):
        """
        n_rels: number of relations = 86
        n_features: kge_dim = 64
        """
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # Embedding layer
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        #  Initializes the embedding weights with the Xavier uniform distribution, which helps maintain the scale of gradients during training
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores):
        rels = self.rel_emb(rels)
        rels = F.normalize(rels, dim=-1)
        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)
        # print(rels.shape)
        # Convert shape (1024, 4096) to (1024, 64, 64) for dot product
        rels = rels.view(-1, self.n_features, self.n_features)
        # print(rels.shape)
        # (1024, 4, 64) @ (1024, 64, 64) = (1024, 4, 64) @ (1024, 64, 4) = (1024, 4, 4)
        scores = heads @ rels @ tails.transpose(-2, -1)

        # alpha_scores.shape = (1024, 4, 4)
        # scores.shape = (1024, 4, 4)
        if alpha_scores is not None:
            scores = alpha_scores * scores
        # print(scores.shape)
        
        # sum the last 2 dimensions
        scores = scores.sum(dim=(-2, -1))
        
        # print(scores.shape)
        # Shape(1024,)
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
    
        # Relation embeddings are also complex
        self.rel_real = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        self.rel_imag = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        
        # Initialize embeddings
        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        # Preprocess
        # heads = F.normalize(heads, dim=-1)
        # tails = F.normalize(tails, dim=-1)
        
        r_real, r_imag = self.rel_real(rels), self.rel_imag(rels)
        r_real = F.normalize(r_real, dim=-1)
        r_imag = F.normalize(r_imag, dim=-1)
        # print(r_real.shape)
        r_real = r_real.view(-1, self.n_features // 2, self.n_features // 2)
        r_imag = r_imag.view(-1, self.n_features // 2, self.n_features // 2)
        # print(r_real.shape)
        # Split heads and tails to imaginary parts
        h_real, h_imag = heads[..., :self.n_features // 2], heads[..., self.n_features // 2:]
        t_real, t_imag = tails[..., :self.n_features // 2], heads[..., self.n_features // 2:]

        h_real, h_imag = F.normalize(h_real, dim=-1), F.normalize(h_imag, dim=-1)
        t_real, t_imag = F.normalize(t_real, dim=-1), F.normalize(t_imag, dim=-1)

        # ComplEx scoring functionn
        first_part_score = h_real @ r_real @ t_real.transpose(-2, -1)
        second_part_score = h_real @ r_imag @ t_imag.transpose(-2, -1)
        third_part_score = h_imag @ r_real @ t_imag.transpose(-2, -1)
        fourth_part_score = h_imag @ r_imag @ t_real.transpose(-2, -1)

        scores = first_part_score + second_part_score + third_part_score + fourth_part_score
        
        # If alpha_scores is provided, apply it
        if alpha_scores is not None:
            scores = alpha_scores * scores

        scores = scores.sum(dim=(-2, -1))
        
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


In [9]:
class SSI_DDI(nn.Module):
    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        co_attention_method,
    ):
        """
        blocks_params: list of number layers for multi-head attentions
        """
        super().__init__()
        self.in_features = in_features
        # not using this one
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # Layer normalization list
        self.net_norms = ModuleList()
        for i, (head_out_feats, n_heads) in enumerate(
            zip(heads_out_feat_params, blocks_params)
        ):
            block = SSI_DDI_Block(
                n_heads, in_features, head_out_feats, final_out_feats=self.hidd_dim, sp_ratio=sp_ratio, sp_min_score=sp_min_score
            )
            self.add_module(f"block{i}", block)
            self.blocks.append(block)
            self.net_norms.append(LayerNorm(head_out_feats * n_heads))
            in_features = head_out_feats * n_heads

        if co_attention_method == "multihead":
            self.co_attention = MultiheadCoAttentionLayer(self.kge_dim, self.use_activation_fn)
        elif co_attention_method == "improved":
            self.co_attention = CoAttentionLayerImproved(self.kge_dim, self.use_activation_fn)
        else:
            self.co_attention = CoAttentionLayer(self.kge_dim, self.use_activation_fn)
            
        if self.use_ComplEx:
            self.KGE = ComplEx(self.rel_total, self.kge_dim)
        else:
            self.KGE = RESCAL(self.rel_total, self.kge_dim)

    def forward(self, triples):
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        repr_h = []
        repr_t = []

        for i, block in enumerate(self.blocks):
            out1, out2 = block(h_data), block(t_data)

            h_data = out1[0]
            t_data = out2[0]
            r_h = out1[1]
            r_t = out2[1]

            repr_h.append(r_h)
            repr_t.append(r_t)

            h_data.x = F.elu(self.net_norms[i](h_data.x, h_data.batch))
            t_data.x = F.elu(self.net_norms[i](t_data.x, t_data.batch))

        repr_h = torch.stack(repr_h, dim=-2)
        repr_t = torch.stack(repr_t, dim=-2)

        kge_heads = repr_h
        kge_tails = repr_t

        attentions = self.co_attention(kge_heads, kge_tails)
        # attentions = None
        scores = self.KGE(kge_heads, kge_tails, rels, attentions)

        return scores


class SSI_DDI_Block(nn.Module):
    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        """
        n_heades: number of multi-head attentions = 2
        in_features: number of features = 55 . For explicit valence use, number of features = 56.
        head_out_feats: number of out features. For 4 layers: [32, 32, 32, 32]
        sp_ratio: SAGPooling ratio
        """
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)
        # SAGPooling: Ranks nodes based on self-attention scores

        if sp_ratio is None and sp_min_score is None:
            self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)
        else:
            if sp_ratio is not None:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score, ratio=sp_ratio)
            else:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score)

    def forward(self, data):
        data.x = self.conv(data.x, data.edge_index)
        # Call SAGPooling here
        # If min_score = -1 so nodes will not be filtered out, basically redudant for using the SAGPooling.
        att_x, att_edge_index, att_edge_attr, att_batch, att_perm, att_scores = (
            self.readout(data.x, data.edge_index, batch=data.batch)
        )
        # Aggregates the pooled node features (att_x) across the graph to obtain a global representation
        global_graph_emb = global_add_pool(att_x, att_batch)

        # data = max_pool_neighbor_x(data)
        return data, global_graph_emb


In [10]:
class SigmoidLoss(nn.Module):
    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        if self.adv_temperature:
            weights = F.softmax(self.adv_temperature * n_scores, dim=-1).detach()
            n_scores = weights * n_scores
        p_loss = -F.logsigmoid(p_scores).mean()
        n_loss = -F.logsigmoid(-n_scores).mean()

        return (p_loss + n_loss) / 2, p_loss, n_loss


In [11]:
df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")


train_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_train["d1"], df_ddi_train["d2"], df_ddi_train["type"])
]
val_tup = [
    (h, t, r) for h, t, r in zip(df_ddi_val["d1"], df_ddi_val["d2"], df_ddi_val["type"])
]
test_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_test["d1"], df_ddi_test["d2"], df_ddi_test["type"])
]

train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size * 3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size * 3)


Training with 115185 samples, validating with 38348, and testing with 38337


In [12]:
def do_compute(model, batch, device, training=True):
    """
    *batch: (pos_tri, neg_tri)
    *pos/neg_tri: (batch_h, batch_t, batch_r)
    """
    probas_pred, ground_truth = [], []
    pos_tri, neg_tri = batch

    pos_tri = [tensor.to(device=device) for tensor in pos_tri]
    p_score = model(pos_tri)
    probas_pred.append(torch.sigmoid(p_score.detach()).cpu())
    ground_truth.append(np.ones(len(p_score)))

    neg_tri = [tensor.to(device=device) for tensor in neg_tri]
    n_score = model(neg_tri)
    probas_pred.append(torch.sigmoid(n_score.detach()).cpu())
    ground_truth.append(np.zeros(len(n_score)))

    probas_pred = np.concatenate(probas_pred)
    ground_truth = np.concatenate(ground_truth)

    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target):

    pred = (probas_pred >= 0.5).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)

    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(r, p)

    return acc, auc_roc, auc_prc

In [13]:
import csv
def export_metrics(train_metrics, val_metrics, epoch):
    train_metrics_dir = "train_metrics"
    metrics_file = f"{train_metrics_dir}/{model_name}.csv"
    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    data = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]
    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]
    
    if epoch == 1:
        with open(metrics_file, 'w', newline='') as file:
            writer = csv.writer(file)
            # Write the header
            writer.writerow(header)
            # Write the data rows
            writer.writerow(data)
    else:
        with open(metrics_file, 'a', newline='') as file:
            writer = csv.writer(file)
            # Write the data to the file
            writer.writerow(data)
    
    

In [14]:
model_acc_file = f"{model_dir}/acc/{model_name}.pth"
model_roc_file = f"{model_dir}/roc/{model_name}.pth"
model_prc_file = f"{model_dir}/prc/{model_name}.pth"

def save_model(best, current, met_type):
    model_file = model_prc_file
    if met_type == "acc":
        model_file = model_acc_file
    elif met_type == "roc":
        model_file = model_roc_file
        
    if best < current:
        print(f"Saving model {met_type}")
        best = current
        torch.save(model, model_file)
    return best

In [15]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
):
    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()
    best_val_prc = 0
    best_val_acc = 0
    best_val_roc = 0
    for i in range(1, n_epochs + 1):
        start = time.time()
        train_loss = 0
        train_loss_pos = 0
        train_loss_neg = 0
        val_loss = 0
        val_loss_pos = 0
        val_loss_neg = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []

        for batch in train_data_loader:
            # print(len(batch))
            model.train()
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, loss_p, loss_n = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(p_score)
        train_loss /= len(train_data)

        with torch.no_grad():
            train_probas_pred = np.concatenate(train_probas_pred)
            train_ground_truth = np.concatenate(train_ground_truth)

            train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
                train_probas_pred, train_ground_truth
            )

            for batch in val_data_loader:
                model.eval()
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, loss_p, loss_n = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)

            val_loss /= len(val_data)
            val_probas_pred = np.concatenate(val_probas_pred)
            val_ground_truth = np.concatenate(val_ground_truth)
            val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
                val_probas_pred, val_ground_truth
            )
            
            # Save model if this is the best so far
            best_val_prc = save_model(best_val_prc, val_auc_prc, "prc")
            best_val_acc = save_model(best_val_acc, val_acc, "acc")
            best_val_roc = save_model(best_val_roc, val_auc_roc, "roc")

        if scheduler:
            # print('scheduling')
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)
        
        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model

In [16]:
def predict(model, test_data_loader, device):
    print('Starting predicting at', datetime.today())
    print('Device', device)

    test_probas_pred = []
    test_ground_truth = []

    # Switch to evaluation mode
    model.eval()

    with torch.no_grad():  # No need to calculate gradients during testing
        for batch in test_data_loader:
            # Get predictions and ground truth for the batch
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device, training=False)

            # Append the predictions and ground truths
            test_probas_pred.append(probas_pred)
            test_ground_truth.append(ground_truth)


    # Concatenate the results for the entire test dataset
    test_probas_pred = np.concatenate(test_probas_pred)
    test_ground_truth = np.concatenate(test_ground_truth)

    # Calculate the metrics for the test dataset
    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(test_probas_pred, test_ground_truth)

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

In [17]:
heads_out_feat_params = []
block_params = []

for _ in range(num_GAT_layers):
    heads_out_feat_params.append(kge_dim // 2)
    block_params.append(num_GAT_multiheads)

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        co_attention_method=co_attention_method,
    )
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)
    model.to(device=device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): MultiheadCoAttentionLayer(
    (W_q): ParameterList(
        (0): Parameter containing: [torch.float32 of size 64x32]
        (1): Parameter containing: [torch.float32 of size 64x32]
    )
    (W_k): ParameterList(
        (0): Parameter containing: [torch.float32 of size 64

In [18]:
if mode == "train":
  # Train
  train(
      model,
      train_data_loader,
      val_data_loader,
      loss,
      optimizer,
      n_epochs,
      device,
      scheduler,
  )


Starting training at: 2024-10-28 19:11:46.439827
Device: cuda

####### Tunning parameters #######

n_epochs = 300
use_cuda = True

num_GAT_layers =  4
num_GAT_multiheads =  2

sp_ratio = None
sp_min_score = None

use_explicit_valence = False

use_activation_fn = False

use_ComplEx = True

co_attention_method = multihead

#################################





Saving model prc
Saving model acc
Saving model roc
Epoch: 1 (77.5617s), train_loss: 0.7010, val_loss: 0.6408, train_acc: 0.5575, val_acc:0.6200
		train_roc: 0.5841, val_roc: 0.6763, train_auprc: 0.5680, val_auprc: 0.6648




Saving model prc
Saving model acc
Saving model roc
Epoch: 2 (58.1329s), train_loss: 0.5996, val_loss: 0.5728, train_acc: 0.6678, val_acc:0.6977
		train_roc: 0.7354, val_roc: 0.7700, train_auprc: 0.7187, val_auprc: 0.7477




Saving model prc
Saving model acc
Saving model roc
Epoch: 3 (58.1444s), train_loss: 0.5460, val_loss: 0.5307, train_acc: 0.7202, val_acc:0.7335
		train_roc: 0.7945, val_roc: 0.8100, train_auprc: 0.7733, val_auprc: 0.7874




Saving model prc
Saving model acc
Saving model roc
Epoch: 4 (57.9295s), train_loss: 0.5207, val_loss: 0.5085, train_acc: 0.7401, val_acc:0.7491
		train_roc: 0.8172, val_roc: 0.8278, train_auprc: 0.7956, val_auprc: 0.8078




Saving model prc
Saving model acc
Saving model roc
Epoch: 5 (58.1650s), train_loss: 0.5016, val_loss: 0.4977, train_acc: 0.7557, val_acc:0.7572
		train_roc: 0.8331, val_roc: 0.8358, train_auprc: 0.8121, val_auprc: 0.8143




Saving model prc
Saving model acc
Saving model roc
Epoch: 6 (59.9176s), train_loss: 0.4872, val_loss: 0.4903, train_acc: 0.7666, val_acc:0.7647
		train_roc: 0.8443, val_roc: 0.8429, train_auprc: 0.8234, val_auprc: 0.8229




Saving model prc
Saving model acc
Saving model roc
Epoch: 7 (57.9075s), train_loss: 0.4765, val_loss: 0.4760, train_acc: 0.7741, val_acc:0.7705
		train_roc: 0.8520, val_roc: 0.8536, train_auprc: 0.8318, val_auprc: 0.8333




Saving model prc
Saving model acc
Saving model roc
Epoch: 8 (58.0067s), train_loss: 0.4669, val_loss: 0.4715, train_acc: 0.7807, val_acc:0.7758
		train_roc: 0.8585, val_roc: 0.8572, train_auprc: 0.8380, val_auprc: 0.8372




Saving model prc
Saving model acc
Saving model roc
Epoch: 9 (57.9075s), train_loss: 0.4588, val_loss: 0.4645, train_acc: 0.7870, val_acc:0.7829
		train_roc: 0.8643, val_roc: 0.8608, train_auprc: 0.8442, val_auprc: 0.8409




Saving model prc
Saving model acc
Saving model roc
Epoch: 10 (57.7212s), train_loss: 0.4533, val_loss: 0.4548, train_acc: 0.7898, val_acc:0.7920
		train_roc: 0.8675, val_roc: 0.8696, train_auprc: 0.8480, val_auprc: 0.8500




Saving model prc
Saving model acc
Saving model roc
Epoch: 11 (58.1197s), train_loss: 0.4447, val_loss: 0.4404, train_acc: 0.7958, val_acc:0.7988
		train_roc: 0.8732, val_roc: 0.8768, train_auprc: 0.8538, val_auprc: 0.8612




Epoch: 12 (58.0188s), train_loss: 0.4385, val_loss: 0.4429, train_acc: 0.7995, val_acc:0.7938
		train_roc: 0.8770, val_roc: 0.8762, train_auprc: 0.8581, val_auprc: 0.8567




Saving model prc
Saving model acc
Saving model roc
Epoch: 13 (57.9321s), train_loss: 0.4305, val_loss: 0.4306, train_acc: 0.8061, val_acc:0.8051
		train_roc: 0.8822, val_roc: 0.8827, train_auprc: 0.8627, val_auprc: 0.8657




Saving model prc
Saving model acc
Saving model roc
Epoch: 14 (57.9525s), train_loss: 0.4246, val_loss: 0.4239, train_acc: 0.8084, val_acc:0.8077
		train_roc: 0.8854, val_roc: 0.8860, train_auprc: 0.8666, val_auprc: 0.8691




Saving model acc
Epoch: 15 (58.0521s), train_loss: 0.4177, val_loss: 0.4231, train_acc: 0.8127, val_acc:0.8095
		train_roc: 0.8892, val_roc: 0.8851, train_auprc: 0.8708, val_auprc: 0.8666




Saving model acc
Saving model roc
Epoch: 16 (57.9292s), train_loss: 0.4153, val_loss: 0.4204, train_acc: 0.8139, val_acc:0.8124
		train_roc: 0.8904, val_roc: 0.8877, train_auprc: 0.8717, val_auprc: 0.8682




Saving model prc
Saving model acc
Saving model roc
Epoch: 17 (58.0344s), train_loss: 0.4098, val_loss: 0.4095, train_acc: 0.8174, val_acc:0.8185
		train_roc: 0.8935, val_roc: 0.8940, train_auprc: 0.8752, val_auprc: 0.8764




Epoch: 18 (57.9664s), train_loss: 0.4049, val_loss: 0.4153, train_acc: 0.8207, val_acc:0.8168
		train_roc: 0.8964, val_roc: 0.8915, train_auprc: 0.8778, val_auprc: 0.8714




Saving model prc
Saving model acc
Saving model roc
Epoch: 19 (57.8059s), train_loss: 0.3981, val_loss: 0.4042, train_acc: 0.8242, val_acc:0.8199
		train_roc: 0.8996, val_roc: 0.8969, train_auprc: 0.8819, val_auprc: 0.8811




Saving model prc
Saving model acc
Saving model roc
Epoch: 20 (57.7444s), train_loss: 0.3958, val_loss: 0.3970, train_acc: 0.8262, val_acc:0.8257
		train_roc: 0.9010, val_roc: 0.9002, train_auprc: 0.8832, val_auprc: 0.8834




Epoch: 21 (58.0809s), train_loss: 0.3914, val_loss: 0.3992, train_acc: 0.8292, val_acc:0.8248
		train_roc: 0.9033, val_roc: 0.8989, train_auprc: 0.8852, val_auprc: 0.8808




Saving model prc
Saving model acc
Saving model roc
Epoch: 22 (58.1940s), train_loss: 0.3862, val_loss: 0.3941, train_acc: 0.8315, val_acc:0.8268
		train_roc: 0.9060, val_roc: 0.9031, train_auprc: 0.8883, val_auprc: 0.8874




Saving model prc
Saving model acc
Saving model roc
Epoch: 23 (57.8935s), train_loss: 0.3804, val_loss: 0.3893, train_acc: 0.8353, val_acc:0.8304
		train_roc: 0.9087, val_roc: 0.9048, train_auprc: 0.8915, val_auprc: 0.8889




Saving model prc
Saving model acc
Saving model roc
Epoch: 24 (58.0899s), train_loss: 0.3776, val_loss: 0.3831, train_acc: 0.8357, val_acc:0.8333
		train_roc: 0.9103, val_roc: 0.9091, train_auprc: 0.8942, val_auprc: 0.8944




Saving model acc
Epoch: 25 (57.8469s), train_loss: 0.3729, val_loss: 0.3840, train_acc: 0.8386, val_acc:0.8335
		train_roc: 0.9123, val_roc: 0.9074, train_auprc: 0.8962, val_auprc: 0.8919




Saving model prc
Saving model acc
Saving model roc
Epoch: 26 (58.0212s), train_loss: 0.3705, val_loss: 0.3734, train_acc: 0.8403, val_acc:0.8392
		train_roc: 0.9136, val_roc: 0.9123, train_auprc: 0.8972, val_auprc: 0.8975




Saving model prc
Saving model acc
Saving model roc
Epoch: 27 (58.1772s), train_loss: 0.3644, val_loss: 0.3711, train_acc: 0.8435, val_acc:0.8434
		train_roc: 0.9162, val_roc: 0.9141, train_auprc: 0.8999, val_auprc: 0.8992




Epoch: 28 (57.8677s), train_loss: 0.3621, val_loss: 0.3722, train_acc: 0.8455, val_acc:0.8421
		train_roc: 0.9174, val_roc: 0.9135, train_auprc: 0.9013, val_auprc: 0.8980




Saving model prc
Saving model roc
Epoch: 29 (58.0258s), train_loss: 0.3570, val_loss: 0.3675, train_acc: 0.8475, val_acc:0.8418
		train_roc: 0.9196, val_roc: 0.9152, train_auprc: 0.9039, val_auprc: 0.8994




Saving model prc
Epoch: 30 (58.1938s), train_loss: 0.3553, val_loss: 0.3683, train_acc: 0.8494, val_acc:0.8395
		train_roc: 0.9202, val_roc: 0.9146, train_auprc: 0.9038, val_auprc: 0.9008




Saving model acc
Saving model roc
Epoch: 31 (58.0731s), train_loss: 0.3500, val_loss: 0.3631, train_acc: 0.8513, val_acc:0.8447
		train_roc: 0.9227, val_roc: 0.9165, train_auprc: 0.9076, val_auprc: 0.9004




Saving model prc
Saving model acc
Saving model roc
Epoch: 32 (58.0147s), train_loss: 0.3506, val_loss: 0.3602, train_acc: 0.8520, val_acc:0.8478
		train_roc: 0.9224, val_roc: 0.9191, train_auprc: 0.9063, val_auprc: 0.9042




Saving model prc
Saving model acc
Saving model roc
Epoch: 33 (58.2114s), train_loss: 0.3450, val_loss: 0.3564, train_acc: 0.8549, val_acc:0.8494
		train_roc: 0.9247, val_roc: 0.9208, train_auprc: 0.9087, val_auprc: 0.9057




Saving model prc
Saving model roc
Epoch: 34 (58.3451s), train_loss: 0.3423, val_loss: 0.3534, train_acc: 0.8555, val_acc:0.8493
		train_roc: 0.9259, val_roc: 0.9220, train_auprc: 0.9103, val_auprc: 0.9077




Saving model prc
Saving model acc
Saving model roc
Epoch: 35 (57.9669s), train_loss: 0.3362, val_loss: 0.3517, train_acc: 0.8599, val_acc:0.8512
		train_roc: 0.9286, val_roc: 0.9228, train_auprc: 0.9133, val_auprc: 0.9087




Saving model prc
Saving model acc
Saving model roc
Epoch: 36 (58.0443s), train_loss: 0.3342, val_loss: 0.3476, train_acc: 0.8603, val_acc:0.8544
		train_roc: 0.9292, val_roc: 0.9245, train_auprc: 0.9146, val_auprc: 0.9109




Saving model prc
Saving model acc
Saving model roc
Epoch: 37 (58.5350s), train_loss: 0.3311, val_loss: 0.3439, train_acc: 0.8611, val_acc:0.8557
		train_roc: 0.9305, val_roc: 0.9260, train_auprc: 0.9160, val_auprc: 0.9131




Epoch: 38 (57.9492s), train_loss: 0.3286, val_loss: 0.3452, train_acc: 0.8637, val_acc:0.8550
		train_roc: 0.9315, val_roc: 0.9255, train_auprc: 0.9169, val_auprc: 0.9127




Saving model prc
Saving model acc
Saving model roc
Epoch: 39 (63.4974s), train_loss: 0.3248, val_loss: 0.3351, train_acc: 0.8652, val_acc:0.8611
		train_roc: 0.9331, val_roc: 0.9293, train_auprc: 0.9190, val_auprc: 0.9162




Saving model prc
Saving model roc
Epoch: 40 (59.1427s), train_loss: 0.3235, val_loss: 0.3344, train_acc: 0.8667, val_acc:0.8604
		train_roc: 0.9335, val_roc: 0.9299, train_auprc: 0.9188, val_auprc: 0.9170




Epoch: 41 (58.6189s), train_loss: 0.3218, val_loss: 0.3350, train_acc: 0.8673, val_acc:0.8601
		train_roc: 0.9342, val_roc: 0.9299, train_auprc: 0.9196, val_auprc: 0.9165




Saving model prc
Epoch: 42 (58.8438s), train_loss: 0.3168, val_loss: 0.3349, train_acc: 0.8698, val_acc:0.8593
		train_roc: 0.9365, val_roc: 0.9296, train_auprc: 0.9226, val_auprc: 0.9172




Saving model prc
Saving model roc
Epoch: 43 (58.4797s), train_loss: 0.3134, val_loss: 0.3350, train_acc: 0.8708, val_acc:0.8593
		train_roc: 0.9372, val_roc: 0.9300, train_auprc: 0.9237, val_auprc: 0.9175




Saving model acc
Saving model roc
Epoch: 44 (58.1411s), train_loss: 0.3108, val_loss: 0.3324, train_acc: 0.8725, val_acc:0.8626
		train_roc: 0.9384, val_roc: 0.9306, train_auprc: 0.9248, val_auprc: 0.9168




Saving model prc
Saving model acc
Saving model roc
Epoch: 45 (60.3029s), train_loss: 0.3079, val_loss: 0.3232, train_acc: 0.8739, val_acc:0.8669
		train_roc: 0.9394, val_roc: 0.9344, train_auprc: 0.9261, val_auprc: 0.9215




Saving model prc
Epoch: 46 (61.7347s), train_loss: 0.3062, val_loss: 0.3256, train_acc: 0.8743, val_acc:0.8658
		train_roc: 0.9399, val_roc: 0.9342, train_auprc: 0.9264, val_auprc: 0.9216




Epoch: 47 (59.9794s), train_loss: 0.3030, val_loss: 0.3267, train_acc: 0.8760, val_acc:0.8643
		train_roc: 0.9413, val_roc: 0.9334, train_auprc: 0.9286, val_auprc: 0.9203




Saving model prc
Saving model acc
Saving model roc
Epoch: 48 (59.3297s), train_loss: 0.3025, val_loss: 0.3219, train_acc: 0.8770, val_acc:0.8680
		train_roc: 0.9414, val_roc: 0.9349, train_auprc: 0.9282, val_auprc: 0.9223




Saving model prc
Saving model acc
Saving model roc
Epoch: 49 (59.0902s), train_loss: 0.2993, val_loss: 0.3199, train_acc: 0.8784, val_acc:0.8683
		train_roc: 0.9426, val_roc: 0.9364, train_auprc: 0.9294, val_auprc: 0.9247




Saving model roc
Epoch: 50 (59.4015s), train_loss: 0.2956, val_loss: 0.3200, train_acc: 0.8801, val_acc:0.8671
		train_roc: 0.9440, val_roc: 0.9366, train_auprc: 0.9317, val_auprc: 0.9245




Saving model prc
Saving model acc
Saving model roc
Epoch: 51 (61.5825s), train_loss: 0.2960, val_loss: 0.3180, train_acc: 0.8794, val_acc:0.8711
		train_roc: 0.9437, val_roc: 0.9373, train_auprc: 0.9307, val_auprc: 0.9254




Saving model prc
Saving model acc
Saving model roc
Epoch: 52 (60.9351s), train_loss: 0.2919, val_loss: 0.3135, train_acc: 0.8816, val_acc:0.8727
		train_roc: 0.9451, val_roc: 0.9386, train_auprc: 0.9327, val_auprc: 0.9273




Epoch: 53 (60.3161s), train_loss: 0.2895, val_loss: 0.3171, train_acc: 0.8822, val_acc:0.8700
		train_roc: 0.9460, val_roc: 0.9374, train_auprc: 0.9338, val_auprc: 0.9256




Saving model prc
Saving model acc
Saving model roc
Epoch: 54 (58.7905s), train_loss: 0.2885, val_loss: 0.3095, train_acc: 0.8840, val_acc:0.8759
		train_roc: 0.9463, val_roc: 0.9403, train_auprc: 0.9338, val_auprc: 0.9296




Epoch: 55 (58.6731s), train_loss: 0.2870, val_loss: 0.3115, train_acc: 0.8846, val_acc:0.8728
		train_roc: 0.9469, val_roc: 0.9396, train_auprc: 0.9346, val_auprc: 0.9276




Epoch: 56 (62.5846s), train_loss: 0.2839, val_loss: 0.3095, train_acc: 0.8862, val_acc:0.8753
		train_roc: 0.9479, val_roc: 0.9401, train_auprc: 0.9358, val_auprc: 0.9288




Saving model roc
Epoch: 57 (59.9687s), train_loss: 0.2839, val_loss: 0.3113, train_acc: 0.8859, val_acc:0.8729
		train_roc: 0.9479, val_roc: 0.9408, train_auprc: 0.9357, val_auprc: 0.9291




Saving model prc
Saving model acc
Saving model roc
Epoch: 58 (61.9450s), train_loss: 0.2799, val_loss: 0.3032, train_acc: 0.8882, val_acc:0.8769
		train_roc: 0.9492, val_roc: 0.9426, train_auprc: 0.9371, val_auprc: 0.9321




Epoch: 59 (61.0605s), train_loss: 0.2799, val_loss: 0.3078, train_acc: 0.8874, val_acc:0.8740
		train_roc: 0.9491, val_roc: 0.9412, train_auprc: 0.9372, val_auprc: 0.9305




Saving model acc
Saving model roc
Epoch: 60 (61.1180s), train_loss: 0.2770, val_loss: 0.3018, train_acc: 0.8889, val_acc:0.8789
		train_roc: 0.9501, val_roc: 0.9430, train_auprc: 0.9384, val_auprc: 0.9320




Epoch: 61 (61.6841s), train_loss: 0.2740, val_loss: 0.3035, train_acc: 0.8901, val_acc:0.8773
		train_roc: 0.9510, val_roc: 0.9429, train_auprc: 0.9400, val_auprc: 0.9317




Saving model prc
Saving model acc
Saving model roc
Epoch: 62 (60.9217s), train_loss: 0.2746, val_loss: 0.2978, train_acc: 0.8907, val_acc:0.8813
		train_roc: 0.9510, val_roc: 0.9452, train_auprc: 0.9394, val_auprc: 0.9355




Epoch: 63 (60.6680s), train_loss: 0.2731, val_loss: 0.3023, train_acc: 0.8914, val_acc:0.8791
		train_roc: 0.9516, val_roc: 0.9441, train_auprc: 0.9399, val_auprc: 0.9335




Epoch: 64 (59.3734s), train_loss: 0.2719, val_loss: 0.3007, train_acc: 0.8911, val_acc:0.8804
		train_roc: 0.9516, val_roc: 0.9438, train_auprc: 0.9403, val_auprc: 0.9335




Epoch: 65 (59.4276s), train_loss: 0.2683, val_loss: 0.2982, train_acc: 0.8929, val_acc:0.8809
		train_roc: 0.9530, val_roc: 0.9447, train_auprc: 0.9420, val_auprc: 0.9349




Saving model acc
Saving model roc
Epoch: 66 (59.1906s), train_loss: 0.2663, val_loss: 0.2966, train_acc: 0.8945, val_acc:0.8825
		train_roc: 0.9537, val_roc: 0.9454, train_auprc: 0.9426, val_auprc: 0.9351




Saving model acc
Saving model roc
Epoch: 67 (59.2977s), train_loss: 0.2643, val_loss: 0.2951, train_acc: 0.8950, val_acc:0.8827
		train_roc: 0.9542, val_roc: 0.9457, train_auprc: 0.9437, val_auprc: 0.9354




Saving model prc
Saving model roc
Epoch: 68 (59.6402s), train_loss: 0.2634, val_loss: 0.2960, train_acc: 0.8957, val_acc:0.8816
		train_roc: 0.9546, val_roc: 0.9458, train_auprc: 0.9437, val_auprc: 0.9362




Saving model prc
Saving model roc
Epoch: 69 (59.3180s), train_loss: 0.2630, val_loss: 0.2951, train_acc: 0.8959, val_acc:0.8819
		train_roc: 0.9547, val_roc: 0.9473, train_auprc: 0.9441, val_auprc: 0.9380




Saving model acc
Epoch: 70 (59.3563s), train_loss: 0.2612, val_loss: 0.2934, train_acc: 0.8962, val_acc:0.8839
		train_roc: 0.9551, val_roc: 0.9471, train_auprc: 0.9444, val_auprc: 0.9372




Epoch: 71 (59.3468s), train_loss: 0.2607, val_loss: 0.2937, train_acc: 0.8961, val_acc:0.8834
		train_roc: 0.9553, val_roc: 0.9468, train_auprc: 0.9448, val_auprc: 0.9374




Saving model prc
Saving model acc
Saving model roc
Epoch: 72 (59.3879s), train_loss: 0.2585, val_loss: 0.2899, train_acc: 0.8976, val_acc:0.8857
		train_roc: 0.9559, val_roc: 0.9478, train_auprc: 0.9457, val_auprc: 0.9389




Epoch: 73 (59.3495s), train_loss: 0.2584, val_loss: 0.2911, train_acc: 0.8976, val_acc:0.8843
		train_roc: 0.9560, val_roc: 0.9476, train_auprc: 0.9457, val_auprc: 0.9388




Saving model acc
Saving model roc
Epoch: 74 (59.2216s), train_loss: 0.2563, val_loss: 0.2907, train_acc: 0.8988, val_acc:0.8860
		train_roc: 0.9567, val_roc: 0.9480, train_auprc: 0.9462, val_auprc: 0.9384




Epoch: 75 (59.2097s), train_loss: 0.2547, val_loss: 0.2923, train_acc: 0.8990, val_acc:0.8842
		train_roc: 0.9570, val_roc: 0.9469, train_auprc: 0.9469, val_auprc: 0.9371




Epoch: 76 (59.3271s), train_loss: 0.2540, val_loss: 0.2928, train_acc: 0.8998, val_acc:0.8843
		train_roc: 0.9573, val_roc: 0.9473, train_auprc: 0.9470, val_auprc: 0.9378




Saving model prc
Saving model acc
Saving model roc
Epoch: 77 (59.2001s), train_loss: 0.2535, val_loss: 0.2889, train_acc: 0.9001, val_acc:0.8862
		train_roc: 0.9573, val_roc: 0.9488, train_auprc: 0.9472, val_auprc: 0.9399




Saving model prc
Saving model acc
Saving model roc
Epoch: 78 (59.0844s), train_loss: 0.2513, val_loss: 0.2876, train_acc: 0.9013, val_acc:0.8866
		train_roc: 0.9582, val_roc: 0.9491, train_auprc: 0.9480, val_auprc: 0.9403




Saving model acc
Saving model roc
Epoch: 79 (59.0097s), train_loss: 0.2500, val_loss: 0.2876, train_acc: 0.9017, val_acc:0.8872
		train_roc: 0.9584, val_roc: 0.9491, train_auprc: 0.9487, val_auprc: 0.9394




Saving model prc
Saving model acc
Saving model roc
Epoch: 80 (59.0629s), train_loss: 0.2491, val_loss: 0.2852, train_acc: 0.9021, val_acc:0.8884
		train_roc: 0.9588, val_roc: 0.9504, train_auprc: 0.9489, val_auprc: 0.9423




Epoch: 81 (58.9958s), train_loss: 0.2499, val_loss: 0.2873, train_acc: 0.9017, val_acc:0.8876
		train_roc: 0.9587, val_roc: 0.9499, train_auprc: 0.9487, val_auprc: 0.9412




Epoch: 82 (59.0938s), train_loss: 0.2484, val_loss: 0.2868, train_acc: 0.9022, val_acc:0.8869
		train_roc: 0.9590, val_roc: 0.9495, train_auprc: 0.9494, val_auprc: 0.9409




Epoch: 83 (58.7584s), train_loss: 0.2458, val_loss: 0.2862, train_acc: 0.9039, val_acc:0.8883
		train_roc: 0.9598, val_roc: 0.9504, train_auprc: 0.9502, val_auprc: 0.9420




Saving model acc
Epoch: 84 (61.4864s), train_loss: 0.2470, val_loss: 0.2850, train_acc: 0.9027, val_acc:0.8894
		train_roc: 0.9593, val_roc: 0.9501, train_auprc: 0.9495, val_auprc: 0.9409




Epoch: 85 (59.9818s), train_loss: 0.2444, val_loss: 0.2862, train_acc: 0.9043, val_acc:0.8884
		train_roc: 0.9601, val_roc: 0.9499, train_auprc: 0.9506, val_auprc: 0.9411




Epoch: 86 (61.0359s), train_loss: 0.2444, val_loss: 0.2859, train_acc: 0.9042, val_acc:0.8878
		train_roc: 0.9601, val_roc: 0.9500, train_auprc: 0.9504, val_auprc: 0.9416




Epoch: 87 (68.3413s), train_loss: 0.2413, val_loss: 0.2844, train_acc: 0.9058, val_acc:0.8886
		train_roc: 0.9610, val_roc: 0.9503, train_auprc: 0.9519, val_auprc: 0.9419




Epoch: 88 (84.6190s), train_loss: 0.2424, val_loss: 0.2879, train_acc: 0.9059, val_acc:0.8880
		train_roc: 0.9606, val_roc: 0.9500, train_auprc: 0.9511, val_auprc: 0.9409




Saving model prc
Saving model acc
Saving model roc
Epoch: 89 (79.3575s), train_loss: 0.2425, val_loss: 0.2789, train_acc: 0.9055, val_acc:0.8922
		train_roc: 0.9605, val_roc: 0.9521, train_auprc: 0.9508, val_auprc: 0.9441




Epoch: 90 (79.8483s), train_loss: 0.2415, val_loss: 0.2868, train_acc: 0.9056, val_acc:0.8891
		train_roc: 0.9608, val_roc: 0.9500, train_auprc: 0.9514, val_auprc: 0.9409




Epoch: 91 (79.0500s), train_loss: 0.2388, val_loss: 0.2839, train_acc: 0.9063, val_acc:0.8905
		train_roc: 0.9619, val_roc: 0.9517, train_auprc: 0.9530, val_auprc: 0.9431




Epoch: 92 (76.3853s), train_loss: 0.2391, val_loss: 0.2817, train_acc: 0.9061, val_acc:0.8909
		train_roc: 0.9616, val_roc: 0.9515, train_auprc: 0.9522, val_auprc: 0.9432




Epoch: 93 (83.7948s), train_loss: 0.2409, val_loss: 0.2826, train_acc: 0.9059, val_acc:0.8907
		train_roc: 0.9608, val_roc: 0.9516, train_auprc: 0.9514, val_auprc: 0.9429




Epoch: 94 (81.5507s), train_loss: 0.2371, val_loss: 0.2800, train_acc: 0.9081, val_acc:0.8911
		train_roc: 0.9622, val_roc: 0.9518, train_auprc: 0.9526, val_auprc: 0.9432




Epoch: 95 (80.0177s), train_loss: 0.2396, val_loss: 0.2846, train_acc: 0.9072, val_acc:0.8891
		train_roc: 0.9614, val_roc: 0.9512, train_auprc: 0.9517, val_auprc: 0.9423




Epoch: 96 (60.7888s), train_loss: 0.2383, val_loss: 0.2814, train_acc: 0.9070, val_acc:0.8914
		train_roc: 0.9618, val_roc: 0.9519, train_auprc: 0.9523, val_auprc: 0.9432




Epoch: 97 (57.5853s), train_loss: 0.2352, val_loss: 0.2836, train_acc: 0.9084, val_acc:0.8901
		train_roc: 0.9625, val_roc: 0.9515, train_auprc: 0.9534, val_auprc: 0.9427




Epoch: 98 (57.4457s), train_loss: 0.2361, val_loss: 0.2855, train_acc: 0.9082, val_acc:0.8898
		train_roc: 0.9624, val_roc: 0.9508, train_auprc: 0.9530, val_auprc: 0.9421




Epoch: 99 (57.7175s), train_loss: 0.2356, val_loss: 0.2869, train_acc: 0.9084, val_acc:0.8901
		train_roc: 0.9625, val_roc: 0.9507, train_auprc: 0.9534, val_auprc: 0.9414




Epoch: 100 (57.4898s), train_loss: 0.2357, val_loss: 0.2847, train_acc: 0.9079, val_acc:0.8902
		train_roc: 0.9623, val_roc: 0.9515, train_auprc: 0.9530, val_auprc: 0.9426




Saving model prc
Saving model acc
Saving model roc
Epoch: 101 (57.7246s), train_loss: 0.2340, val_loss: 0.2805, train_acc: 0.9090, val_acc:0.8923
		train_roc: 0.9630, val_roc: 0.9523, train_auprc: 0.9540, val_auprc: 0.9442




Saving model prc
Saving model roc
Epoch: 102 (57.5347s), train_loss: 0.2352, val_loss: 0.2803, train_acc: 0.9084, val_acc:0.8923
		train_roc: 0.9625, val_roc: 0.9528, train_auprc: 0.9530, val_auprc: 0.9443




Epoch: 103 (58.9878s), train_loss: 0.2350, val_loss: 0.2828, train_acc: 0.9086, val_acc:0.8910
		train_roc: 0.9626, val_roc: 0.9519, train_auprc: 0.9534, val_auprc: 0.9435




Saving model prc
Epoch: 104 (59.0983s), train_loss: 0.2329, val_loss: 0.2818, train_acc: 0.9089, val_acc:0.8917
		train_roc: 0.9632, val_roc: 0.9524, train_auprc: 0.9546, val_auprc: 0.9446




Saving model prc
Epoch: 105 (58.9492s), train_loss: 0.2325, val_loss: 0.2811, train_acc: 0.9100, val_acc:0.8917
		train_roc: 0.9634, val_roc: 0.9528, train_auprc: 0.9544, val_auprc: 0.9447




Saving model prc
Saving model acc
Saving model roc
Epoch: 106 (58.9733s), train_loss: 0.2331, val_loss: 0.2802, train_acc: 0.9096, val_acc:0.8924
		train_roc: 0.9631, val_roc: 0.9530, train_auprc: 0.9538, val_auprc: 0.9453




Epoch: 107 (58.9022s), train_loss: 0.2334, val_loss: 0.2817, train_acc: 0.9089, val_acc:0.8911
		train_roc: 0.9630, val_roc: 0.9524, train_auprc: 0.9542, val_auprc: 0.9443




Saving model acc
Epoch: 108 (59.1706s), train_loss: 0.2311, val_loss: 0.2813, train_acc: 0.9103, val_acc:0.8925
		train_roc: 0.9638, val_roc: 0.9525, train_auprc: 0.9549, val_auprc: 0.9445




Epoch: 109 (58.9084s), train_loss: 0.2318, val_loss: 0.2814, train_acc: 0.9103, val_acc:0.8913
		train_roc: 0.9634, val_roc: 0.9528, train_auprc: 0.9546, val_auprc: 0.9451




Saving model prc
Saving model acc
Saving model roc
Epoch: 110 (59.0552s), train_loss: 0.2301, val_loss: 0.2790, train_acc: 0.9113, val_acc:0.8936
		train_roc: 0.9638, val_roc: 0.9534, train_auprc: 0.9548, val_auprc: 0.9456




Epoch: 111 (59.4039s), train_loss: 0.2308, val_loss: 0.2854, train_acc: 0.9101, val_acc:0.8901
		train_roc: 0.9637, val_roc: 0.9515, train_auprc: 0.9548, val_auprc: 0.9432




Saving model prc
Saving model roc
Epoch: 112 (58.9565s), train_loss: 0.2292, val_loss: 0.2793, train_acc: 0.9112, val_acc:0.8926
		train_roc: 0.9643, val_roc: 0.9535, train_auprc: 0.9556, val_auprc: 0.9461




Epoch: 113 (59.2510s), train_loss: 0.2298, val_loss: 0.2798, train_acc: 0.9110, val_acc:0.8935
		train_roc: 0.9640, val_roc: 0.9535, train_auprc: 0.9555, val_auprc: 0.9456




Epoch: 114 (59.2802s), train_loss: 0.2308, val_loss: 0.2819, train_acc: 0.9108, val_acc:0.8917
		train_roc: 0.9636, val_roc: 0.9527, train_auprc: 0.9545, val_auprc: 0.9449




Saving model prc
Saving model acc
Saving model roc
Epoch: 115 (59.1128s), train_loss: 0.2292, val_loss: 0.2789, train_acc: 0.9114, val_acc:0.8938
		train_roc: 0.9643, val_roc: 0.9539, train_auprc: 0.9556, val_auprc: 0.9464




Epoch: 116 (58.8795s), train_loss: 0.2295, val_loss: 0.2808, train_acc: 0.9112, val_acc:0.8931
		train_roc: 0.9641, val_roc: 0.9533, train_auprc: 0.9550, val_auprc: 0.9460




Epoch: 117 (64.9698s), train_loss: 0.2283, val_loss: 0.2832, train_acc: 0.9114, val_acc:0.8923
		train_roc: 0.9645, val_roc: 0.9527, train_auprc: 0.9560, val_auprc: 0.9447




Epoch: 118 (70.0841s), train_loss: 0.2283, val_loss: 0.2850, train_acc: 0.9117, val_acc:0.8900
		train_roc: 0.9645, val_roc: 0.9522, train_auprc: 0.9556, val_auprc: 0.9443




Saving model acc
Epoch: 119 (68.6844s), train_loss: 0.2278, val_loss: 0.2789, train_acc: 0.9118, val_acc:0.8944
		train_roc: 0.9648, val_roc: 0.9538, train_auprc: 0.9560, val_auprc: 0.9464




Epoch: 120 (66.6335s), train_loss: 0.2275, val_loss: 0.2800, train_acc: 0.9116, val_acc:0.8942
		train_roc: 0.9648, val_roc: 0.9539, train_auprc: 0.9558, val_auprc: 0.9461




Epoch: 121 (63.9688s), train_loss: 0.2253, val_loss: 0.2836, train_acc: 0.9127, val_acc:0.8921
		train_roc: 0.9653, val_roc: 0.9531, train_auprc: 0.9569, val_auprc: 0.9453




Saving model prc
Saving model roc
Epoch: 122 (63.3194s), train_loss: 0.2278, val_loss: 0.2790, train_acc: 0.9114, val_acc:0.8940
		train_roc: 0.9647, val_roc: 0.9542, train_auprc: 0.9560, val_auprc: 0.9467




Epoch: 123 (58.9853s), train_loss: 0.2273, val_loss: 0.2800, train_acc: 0.9120, val_acc:0.8939
		train_roc: 0.9647, val_roc: 0.9541, train_auprc: 0.9558, val_auprc: 0.9467




Epoch: 124 (59.1190s), train_loss: 0.2274, val_loss: 0.2836, train_acc: 0.9117, val_acc:0.8921
		train_roc: 0.9648, val_roc: 0.9528, train_auprc: 0.9562, val_auprc: 0.9444




Epoch: 125 (59.2348s), train_loss: 0.2263, val_loss: 0.2827, train_acc: 0.9123, val_acc:0.8937
		train_roc: 0.9651, val_roc: 0.9531, train_auprc: 0.9566, val_auprc: 0.9449




Saving model prc
Saving model roc
Epoch: 126 (60.0779s), train_loss: 0.2275, val_loss: 0.2800, train_acc: 0.9125, val_acc:0.8941
		train_roc: 0.9646, val_roc: 0.9543, train_auprc: 0.9559, val_auprc: 0.9474




Epoch: 127 (62.9277s), train_loss: 0.2266, val_loss: 0.2848, train_acc: 0.9124, val_acc:0.8924
		train_roc: 0.9648, val_roc: 0.9523, train_auprc: 0.9561, val_auprc: 0.9438




Epoch: 128 (62.9635s), train_loss: 0.2240, val_loss: 0.2809, train_acc: 0.9132, val_acc:0.8939
		train_roc: 0.9659, val_roc: 0.9536, train_auprc: 0.9575, val_auprc: 0.9459




Epoch: 129 (63.4333s), train_loss: 0.2265, val_loss: 0.2828, train_acc: 0.9125, val_acc:0.8925
		train_roc: 0.9649, val_roc: 0.9532, train_auprc: 0.9561, val_auprc: 0.9450




Epoch: 130 (65.3014s), train_loss: 0.2268, val_loss: 0.2815, train_acc: 0.9120, val_acc:0.8939
		train_roc: 0.9648, val_roc: 0.9538, train_auprc: 0.9559, val_auprc: 0.9462




Saving model acc
Epoch: 131 (64.2370s), train_loss: 0.2267, val_loss: 0.2795, train_acc: 0.9123, val_acc:0.8944
		train_roc: 0.9650, val_roc: 0.9539, train_auprc: 0.9563, val_auprc: 0.9459




Saving model prc
Epoch: 132 (67.2023s), train_loss: 0.2238, val_loss: 0.2784, train_acc: 0.9132, val_acc:0.8934
		train_roc: 0.9657, val_roc: 0.9542, train_auprc: 0.9576, val_auprc: 0.9475




Epoch: 133 (67.5507s), train_loss: 0.2242, val_loss: 0.2826, train_acc: 0.9134, val_acc:0.8928
		train_roc: 0.9657, val_roc: 0.9531, train_auprc: 0.9574, val_auprc: 0.9451




Saving model acc
Epoch: 134 (68.1203s), train_loss: 0.2263, val_loss: 0.2803, train_acc: 0.9129, val_acc:0.8945
		train_roc: 0.9649, val_roc: 0.9539, train_auprc: 0.9558, val_auprc: 0.9460




Epoch: 135 (67.9971s), train_loss: 0.2252, val_loss: 0.2825, train_acc: 0.9133, val_acc:0.8933
		train_roc: 0.9653, val_roc: 0.9530, train_auprc: 0.9563, val_auprc: 0.9450




Epoch: 136 (67.1163s), train_loss: 0.2243, val_loss: 0.2814, train_acc: 0.9136, val_acc:0.8944
		train_roc: 0.9655, val_roc: 0.9536, train_auprc: 0.9568, val_auprc: 0.9458




Saving model acc
Saving model roc
Epoch: 137 (67.4352s), train_loss: 0.2258, val_loss: 0.2786, train_acc: 0.9128, val_acc:0.8954
		train_roc: 0.9651, val_roc: 0.9547, train_auprc: 0.9564, val_auprc: 0.9469




Epoch: 138 (67.6174s), train_loss: 0.2252, val_loss: 0.2816, train_acc: 0.9127, val_acc:0.8936
		train_roc: 0.9653, val_roc: 0.9534, train_auprc: 0.9568, val_auprc: 0.9461




Epoch: 139 (68.2767s), train_loss: 0.2252, val_loss: 0.2815, train_acc: 0.9127, val_acc:0.8941
		train_roc: 0.9652, val_roc: 0.9534, train_auprc: 0.9567, val_auprc: 0.9457




Epoch: 140 (67.1275s), train_loss: 0.2245, val_loss: 0.2832, train_acc: 0.9134, val_acc:0.8932
		train_roc: 0.9655, val_roc: 0.9534, train_auprc: 0.9572, val_auprc: 0.9456




Epoch: 141 (67.4656s), train_loss: 0.2242, val_loss: 0.2827, train_acc: 0.9134, val_acc:0.8939
		train_roc: 0.9656, val_roc: 0.9533, train_auprc: 0.9569, val_auprc: 0.9452




Epoch: 142 (67.5935s), train_loss: 0.2259, val_loss: 0.2823, train_acc: 0.9127, val_acc:0.8928
		train_roc: 0.9650, val_roc: 0.9534, train_auprc: 0.9562, val_auprc: 0.9459




Epoch: 143 (68.2565s), train_loss: 0.2241, val_loss: 0.2790, train_acc: 0.9137, val_acc:0.8943
		train_roc: 0.9656, val_roc: 0.9544, train_auprc: 0.9570, val_auprc: 0.9475




Epoch: 144 (67.7225s), train_loss: 0.2259, val_loss: 0.2810, train_acc: 0.9128, val_acc:0.8947
		train_roc: 0.9650, val_roc: 0.9542, train_auprc: 0.9559, val_auprc: 0.9468




Epoch: 145 (63.5975s), train_loss: 0.2239, val_loss: 0.2813, train_acc: 0.9136, val_acc:0.8945
		train_roc: 0.9658, val_roc: 0.9537, train_auprc: 0.9572, val_auprc: 0.9459




Epoch: 146 (57.2876s), train_loss: 0.2241, val_loss: 0.2815, train_acc: 0.9139, val_acc:0.8938
		train_roc: 0.9655, val_roc: 0.9533, train_auprc: 0.9566, val_auprc: 0.9457




Epoch: 147 (57.0838s), train_loss: 0.2230, val_loss: 0.2821, train_acc: 0.9143, val_acc:0.8940
		train_roc: 0.9659, val_roc: 0.9535, train_auprc: 0.9571, val_auprc: 0.9462




Epoch: 148 (58.4869s), train_loss: 0.2229, val_loss: 0.2819, train_acc: 0.9141, val_acc:0.8939
		train_roc: 0.9660, val_roc: 0.9535, train_auprc: 0.9574, val_auprc: 0.9460




Epoch: 149 (64.9791s), train_loss: 0.2241, val_loss: 0.2840, train_acc: 0.9134, val_acc:0.8934
		train_roc: 0.9656, val_roc: 0.9533, train_auprc: 0.9569, val_auprc: 0.9454




Epoch: 150 (64.3800s), train_loss: 0.2241, val_loss: 0.2813, train_acc: 0.9135, val_acc:0.8951
		train_roc: 0.9655, val_roc: 0.9538, train_auprc: 0.9568, val_auprc: 0.9462




Saving model acc
Epoch: 151 (64.6017s), train_loss: 0.2236, val_loss: 0.2803, train_acc: 0.9141, val_acc:0.8955
		train_roc: 0.9656, val_roc: 0.9541, train_auprc: 0.9570, val_auprc: 0.9461




Epoch: 152 (64.3444s), train_loss: 0.2228, val_loss: 0.2818, train_acc: 0.9140, val_acc:0.8942
		train_roc: 0.9659, val_roc: 0.9539, train_auprc: 0.9579, val_auprc: 0.9464




Epoch: 153 (64.4867s), train_loss: 0.2230, val_loss: 0.2842, train_acc: 0.9139, val_acc:0.8934
		train_roc: 0.9660, val_roc: 0.9529, train_auprc: 0.9575, val_auprc: 0.9451




Epoch: 154 (64.5016s), train_loss: 0.2239, val_loss: 0.2818, train_acc: 0.9132, val_acc:0.8944
		train_roc: 0.9657, val_roc: 0.9539, train_auprc: 0.9571, val_auprc: 0.9462




Epoch: 155 (64.6375s), train_loss: 0.2251, val_loss: 0.2833, train_acc: 0.9130, val_acc:0.8926
		train_roc: 0.9652, val_roc: 0.9532, train_auprc: 0.9563, val_auprc: 0.9456




Epoch: 156 (64.7222s), train_loss: 0.2232, val_loss: 0.2816, train_acc: 0.9142, val_acc:0.8949
		train_roc: 0.9658, val_roc: 0.9541, train_auprc: 0.9570, val_auprc: 0.9461




Epoch: 157 (64.4104s), train_loss: 0.2225, val_loss: 0.2824, train_acc: 0.9144, val_acc:0.8942
		train_roc: 0.9660, val_roc: 0.9535, train_auprc: 0.9575, val_auprc: 0.9462




Epoch: 158 (64.1701s), train_loss: 0.2231, val_loss: 0.2844, train_acc: 0.9144, val_acc:0.8934
		train_roc: 0.9657, val_roc: 0.9530, train_auprc: 0.9569, val_auprc: 0.9451




Epoch: 159 (64.4994s), train_loss: 0.2238, val_loss: 0.2828, train_acc: 0.9136, val_acc:0.8942
		train_roc: 0.9655, val_roc: 0.9534, train_auprc: 0.9566, val_auprc: 0.9452




Epoch: 160 (64.4738s), train_loss: 0.2230, val_loss: 0.2796, train_acc: 0.9138, val_acc:0.8948
		train_roc: 0.9659, val_roc: 0.9545, train_auprc: 0.9575, val_auprc: 0.9475




Epoch: 161 (64.2439s), train_loss: 0.2236, val_loss: 0.2828, train_acc: 0.9136, val_acc:0.8940
		train_roc: 0.9657, val_roc: 0.9534, train_auprc: 0.9572, val_auprc: 0.9460




Epoch: 162 (64.5141s), train_loss: 0.2233, val_loss: 0.2838, train_acc: 0.9141, val_acc:0.8937
		train_roc: 0.9658, val_roc: 0.9532, train_auprc: 0.9568, val_auprc: 0.9453




Epoch: 163 (64.3183s), train_loss: 0.2221, val_loss: 0.2804, train_acc: 0.9144, val_acc:0.8946
		train_roc: 0.9661, val_roc: 0.9543, train_auprc: 0.9576, val_auprc: 0.9470




Epoch: 164 (64.4737s), train_loss: 0.2212, val_loss: 0.2831, train_acc: 0.9145, val_acc:0.8935
		train_roc: 0.9665, val_roc: 0.9535, train_auprc: 0.9582, val_auprc: 0.9455




Epoch: 165 (64.2647s), train_loss: 0.2233, val_loss: 0.2827, train_acc: 0.9140, val_acc:0.8941
		train_roc: 0.9657, val_roc: 0.9538, train_auprc: 0.9571, val_auprc: 0.9459




Saving model acc
Epoch: 166 (64.5657s), train_loss: 0.2220, val_loss: 0.2807, train_acc: 0.9144, val_acc:0.8957
		train_roc: 0.9662, val_roc: 0.9545, train_auprc: 0.9575, val_auprc: 0.9467




Epoch: 167 (64.2954s), train_loss: 0.2233, val_loss: 0.2831, train_acc: 0.9138, val_acc:0.8938
		train_roc: 0.9658, val_roc: 0.9536, train_auprc: 0.9569, val_auprc: 0.9454




Saving model acc
Saving model roc
Epoch: 168 (64.5954s), train_loss: 0.2224, val_loss: 0.2783, train_acc: 0.9148, val_acc:0.8964
		train_roc: 0.9660, val_roc: 0.9551, train_auprc: 0.9569, val_auprc: 0.9474




Epoch: 169 (64.5198s), train_loss: 0.2232, val_loss: 0.2816, train_acc: 0.9142, val_acc:0.8947
		train_roc: 0.9658, val_roc: 0.9540, train_auprc: 0.9570, val_auprc: 0.9463




Epoch: 170 (64.5369s), train_loss: 0.2224, val_loss: 0.2820, train_acc: 0.9144, val_acc:0.8946
		train_roc: 0.9661, val_roc: 0.9538, train_auprc: 0.9574, val_auprc: 0.9457




Epoch: 171 (64.6217s), train_loss: 0.2229, val_loss: 0.2804, train_acc: 0.9139, val_acc:0.8956
		train_roc: 0.9658, val_roc: 0.9546, train_auprc: 0.9572, val_auprc: 0.9468




Epoch: 172 (64.4407s), train_loss: 0.2237, val_loss: 0.2833, train_acc: 0.9139, val_acc:0.8944
		train_roc: 0.9657, val_roc: 0.9535, train_auprc: 0.9570, val_auprc: 0.9455




Epoch: 173 (64.5015s), train_loss: 0.2226, val_loss: 0.2821, train_acc: 0.9143, val_acc:0.8936
		train_roc: 0.9659, val_roc: 0.9539, train_auprc: 0.9575, val_auprc: 0.9466




Epoch: 174 (64.3538s), train_loss: 0.2223, val_loss: 0.2809, train_acc: 0.9140, val_acc:0.8945
		train_roc: 0.9661, val_roc: 0.9544, train_auprc: 0.9575, val_auprc: 0.9470




Epoch: 175 (64.3334s), train_loss: 0.2223, val_loss: 0.2811, train_acc: 0.9144, val_acc:0.8948
		train_roc: 0.9661, val_roc: 0.9543, train_auprc: 0.9575, val_auprc: 0.9465




Saving model prc
Epoch: 176 (64.5550s), train_loss: 0.2215, val_loss: 0.2787, train_acc: 0.9143, val_acc:0.8954
		train_roc: 0.9664, val_roc: 0.9550, train_auprc: 0.9582, val_auprc: 0.9478




Epoch: 177 (64.5430s), train_loss: 0.2215, val_loss: 0.2814, train_acc: 0.9141, val_acc:0.8949
		train_roc: 0.9664, val_roc: 0.9541, train_auprc: 0.9581, val_auprc: 0.9466




Epoch: 178 (64.5310s), train_loss: 0.2238, val_loss: 0.2804, train_acc: 0.9132, val_acc:0.8954
		train_roc: 0.9656, val_roc: 0.9546, train_auprc: 0.9571, val_auprc: 0.9475




Saving model prc
Saving model roc
Epoch: 179 (64.3973s), train_loss: 0.2213, val_loss: 0.2779, train_acc: 0.9147, val_acc:0.8959
		train_roc: 0.9664, val_roc: 0.9553, train_auprc: 0.9579, val_auprc: 0.9487




Epoch: 180 (64.7696s), train_loss: 0.2210, val_loss: 0.2838, train_acc: 0.9148, val_acc:0.8938
		train_roc: 0.9665, val_roc: 0.9531, train_auprc: 0.9579, val_auprc: 0.9454




Epoch: 181 (64.2008s), train_loss: 0.2219, val_loss: 0.2823, train_acc: 0.9148, val_acc:0.8942
		train_roc: 0.9662, val_roc: 0.9538, train_auprc: 0.9574, val_auprc: 0.9461




Epoch: 182 (64.3311s), train_loss: 0.2220, val_loss: 0.2812, train_acc: 0.9146, val_acc:0.8947
		train_roc: 0.9661, val_roc: 0.9544, train_auprc: 0.9577, val_auprc: 0.9467




Epoch: 183 (64.2687s), train_loss: 0.2224, val_loss: 0.2818, train_acc: 0.9144, val_acc:0.8938
		train_roc: 0.9660, val_roc: 0.9539, train_auprc: 0.9573, val_auprc: 0.9465




Epoch: 184 (64.5509s), train_loss: 0.2224, val_loss: 0.2796, train_acc: 0.9147, val_acc:0.8952
		train_roc: 0.9660, val_roc: 0.9549, train_auprc: 0.9571, val_auprc: 0.9478




Epoch: 185 (64.8743s), train_loss: 0.2211, val_loss: 0.2817, train_acc: 0.9149, val_acc:0.8945
		train_roc: 0.9664, val_roc: 0.9541, train_auprc: 0.9578, val_auprc: 0.9462




Epoch: 186 (64.2179s), train_loss: 0.2201, val_loss: 0.2828, train_acc: 0.9151, val_acc:0.8934
		train_roc: 0.9668, val_roc: 0.9535, train_auprc: 0.9586, val_auprc: 0.9461




Epoch: 187 (64.6305s), train_loss: 0.2228, val_loss: 0.2829, train_acc: 0.9141, val_acc:0.8940
		train_roc: 0.9659, val_roc: 0.9540, train_auprc: 0.9573, val_auprc: 0.9460




Epoch: 188 (64.3774s), train_loss: 0.2214, val_loss: 0.2831, train_acc: 0.9144, val_acc:0.8938
		train_roc: 0.9664, val_roc: 0.9536, train_auprc: 0.9580, val_auprc: 0.9459




Epoch: 189 (64.4931s), train_loss: 0.2211, val_loss: 0.2839, train_acc: 0.9149, val_acc:0.8932
		train_roc: 0.9665, val_roc: 0.9531, train_auprc: 0.9582, val_auprc: 0.9455




Epoch: 190 (64.7260s), train_loss: 0.2229, val_loss: 0.2795, train_acc: 0.9137, val_acc:0.8956
		train_roc: 0.9659, val_roc: 0.9552, train_auprc: 0.9574, val_auprc: 0.9477




Epoch: 191 (64.4569s), train_loss: 0.2225, val_loss: 0.2805, train_acc: 0.9145, val_acc:0.8951
		train_roc: 0.9658, val_roc: 0.9546, train_auprc: 0.9574, val_auprc: 0.9467




Epoch: 192 (64.4942s), train_loss: 0.2228, val_loss: 0.2826, train_acc: 0.9143, val_acc:0.8938
		train_roc: 0.9659, val_roc: 0.9536, train_auprc: 0.9572, val_auprc: 0.9461




Epoch: 193 (64.3538s), train_loss: 0.2222, val_loss: 0.2835, train_acc: 0.9146, val_acc:0.8931
		train_roc: 0.9660, val_roc: 0.9535, train_auprc: 0.9574, val_auprc: 0.9461




Epoch: 194 (64.5633s), train_loss: 0.2228, val_loss: 0.2822, train_acc: 0.9141, val_acc:0.8953
		train_roc: 0.9658, val_roc: 0.9540, train_auprc: 0.9572, val_auprc: 0.9459




Epoch: 195 (64.3409s), train_loss: 0.2220, val_loss: 0.2807, train_acc: 0.9144, val_acc:0.8944
		train_roc: 0.9661, val_roc: 0.9545, train_auprc: 0.9575, val_auprc: 0.9472




Epoch: 196 (64.1738s), train_loss: 0.2216, val_loss: 0.2806, train_acc: 0.9141, val_acc:0.8951
		train_roc: 0.9663, val_roc: 0.9544, train_auprc: 0.9579, val_auprc: 0.9467




Epoch: 197 (64.2587s), train_loss: 0.2219, val_loss: 0.2861, train_acc: 0.9146, val_acc:0.8937
		train_roc: 0.9661, val_roc: 0.9524, train_auprc: 0.9574, val_auprc: 0.9438




Epoch: 198 (64.4941s), train_loss: 0.2221, val_loss: 0.2820, train_acc: 0.9142, val_acc:0.8945
		train_roc: 0.9662, val_roc: 0.9541, train_auprc: 0.9578, val_auprc: 0.9471




Epoch: 199 (64.5678s), train_loss: 0.2218, val_loss: 0.2813, train_acc: 0.9142, val_acc:0.8948
		train_roc: 0.9663, val_roc: 0.9543, train_auprc: 0.9580, val_auprc: 0.9465




Epoch: 200 (64.8086s), train_loss: 0.2217, val_loss: 0.2828, train_acc: 0.9143, val_acc:0.8930
		train_roc: 0.9663, val_roc: 0.9539, train_auprc: 0.9578, val_auprc: 0.9471




Epoch: 201 (62.4905s), train_loss: 0.2210, val_loss: 0.2805, train_acc: 0.9152, val_acc:0.8947
		train_roc: 0.9664, val_roc: 0.9546, train_auprc: 0.9577, val_auprc: 0.9475




Epoch: 202 (57.9384s), train_loss: 0.2226, val_loss: 0.2840, train_acc: 0.9147, val_acc:0.8933
		train_roc: 0.9660, val_roc: 0.9534, train_auprc: 0.9570, val_auprc: 0.9458




Epoch: 203 (57.8744s), train_loss: 0.2236, val_loss: 0.2824, train_acc: 0.9136, val_acc:0.8948
		train_roc: 0.9657, val_roc: 0.9539, train_auprc: 0.9568, val_auprc: 0.9459




Epoch: 204 (57.9326s), train_loss: 0.2209, val_loss: 0.2856, train_acc: 0.9153, val_acc:0.8931
		train_roc: 0.9665, val_roc: 0.9528, train_auprc: 0.9578, val_auprc: 0.9443




Epoch: 205 (57.8160s), train_loss: 0.2222, val_loss: 0.2805, train_acc: 0.9142, val_acc:0.8950
		train_roc: 0.9660, val_roc: 0.9546, train_auprc: 0.9574, val_auprc: 0.9473




Epoch: 206 (57.6640s), train_loss: 0.2228, val_loss: 0.2817, train_acc: 0.9140, val_acc:0.8951
		train_roc: 0.9660, val_roc: 0.9543, train_auprc: 0.9575, val_auprc: 0.9467




Epoch: 207 (57.5681s), train_loss: 0.2212, val_loss: 0.2835, train_acc: 0.9146, val_acc:0.8933
		train_roc: 0.9664, val_roc: 0.9534, train_auprc: 0.9580, val_auprc: 0.9462




Epoch: 208 (57.7895s), train_loss: 0.2219, val_loss: 0.2831, train_acc: 0.9147, val_acc:0.8938
		train_roc: 0.9661, val_roc: 0.9536, train_auprc: 0.9577, val_auprc: 0.9463




Epoch: 209 (58.5210s), train_loss: 0.2193, val_loss: 0.2818, train_acc: 0.9155, val_acc:0.8955
		train_roc: 0.9670, val_roc: 0.9540, train_auprc: 0.9589, val_auprc: 0.9458




Epoch: 210 (58.9963s), train_loss: 0.2228, val_loss: 0.2819, train_acc: 0.9142, val_acc:0.8943
		train_roc: 0.9658, val_roc: 0.9541, train_auprc: 0.9570, val_auprc: 0.9462




Epoch: 211 (57.5869s), train_loss: 0.2216, val_loss: 0.2832, train_acc: 0.9146, val_acc:0.8934
		train_roc: 0.9662, val_roc: 0.9538, train_auprc: 0.9576, val_auprc: 0.9460




Epoch: 212 (57.5699s), train_loss: 0.2217, val_loss: 0.2870, train_acc: 0.9145, val_acc:0.8923
		train_roc: 0.9661, val_roc: 0.9523, train_auprc: 0.9579, val_auprc: 0.9440




Epoch: 213 (57.4927s), train_loss: 0.2221, val_loss: 0.2797, train_acc: 0.9148, val_acc:0.8945
		train_roc: 0.9661, val_roc: 0.9550, train_auprc: 0.9574, val_auprc: 0.9481




Epoch: 214 (57.3182s), train_loss: 0.2222, val_loss: 0.2822, train_acc: 0.9144, val_acc:0.8949
		train_roc: 0.9661, val_roc: 0.9540, train_auprc: 0.9575, val_auprc: 0.9459




Epoch: 215 (57.4706s), train_loss: 0.2222, val_loss: 0.2828, train_acc: 0.9147, val_acc:0.8941
		train_roc: 0.9661, val_roc: 0.9537, train_auprc: 0.9573, val_auprc: 0.9460




Epoch: 216 (57.4829s), train_loss: 0.2213, val_loss: 0.2809, train_acc: 0.9148, val_acc:0.8947
		train_roc: 0.9664, val_roc: 0.9544, train_auprc: 0.9580, val_auprc: 0.9470




Epoch: 217 (57.4844s), train_loss: 0.2215, val_loss: 0.2824, train_acc: 0.9145, val_acc:0.8946
		train_roc: 0.9664, val_roc: 0.9539, train_auprc: 0.9580, val_auprc: 0.9458




Epoch: 218 (57.6130s), train_loss: 0.2216, val_loss: 0.2834, train_acc: 0.9149, val_acc:0.8935
		train_roc: 0.9663, val_roc: 0.9535, train_auprc: 0.9577, val_auprc: 0.9455




Epoch: 219 (57.6491s), train_loss: 0.2217, val_loss: 0.2829, train_acc: 0.9145, val_acc:0.8940
		train_roc: 0.9661, val_roc: 0.9537, train_auprc: 0.9578, val_auprc: 0.9465




Epoch: 220 (57.6988s), train_loss: 0.2230, val_loss: 0.2844, train_acc: 0.9146, val_acc:0.8939
		train_roc: 0.9657, val_roc: 0.9530, train_auprc: 0.9569, val_auprc: 0.9451




Epoch: 221 (57.9105s), train_loss: 0.2226, val_loss: 0.2811, train_acc: 0.9142, val_acc:0.8944
		train_roc: 0.9659, val_roc: 0.9543, train_auprc: 0.9574, val_auprc: 0.9470




Epoch: 222 (57.6220s), train_loss: 0.2209, val_loss: 0.2794, train_acc: 0.9143, val_acc:0.8956
		train_roc: 0.9666, val_roc: 0.9547, train_auprc: 0.9585, val_auprc: 0.9477




Epoch: 223 (57.5235s), train_loss: 0.2219, val_loss: 0.2846, train_acc: 0.9145, val_acc:0.8933
		train_roc: 0.9662, val_roc: 0.9529, train_auprc: 0.9576, val_auprc: 0.9454




Epoch: 224 (57.4037s), train_loss: 0.2215, val_loss: 0.2809, train_acc: 0.9148, val_acc:0.8957
		train_roc: 0.9662, val_roc: 0.9544, train_auprc: 0.9576, val_auprc: 0.9469




Epoch: 225 (57.5159s), train_loss: 0.2218, val_loss: 0.2856, train_acc: 0.9145, val_acc:0.8930
		train_roc: 0.9662, val_roc: 0.9527, train_auprc: 0.9575, val_auprc: 0.9444




Epoch: 226 (57.5885s), train_loss: 0.2227, val_loss: 0.2820, train_acc: 0.9142, val_acc:0.8956
		train_roc: 0.9659, val_roc: 0.9540, train_auprc: 0.9571, val_auprc: 0.9457




Epoch: 227 (57.4623s), train_loss: 0.2233, val_loss: 0.2849, train_acc: 0.9137, val_acc:0.8942
		train_roc: 0.9657, val_roc: 0.9530, train_auprc: 0.9570, val_auprc: 0.9447




Epoch: 228 (57.4622s), train_loss: 0.2214, val_loss: 0.2844, train_acc: 0.9150, val_acc:0.8940
		train_roc: 0.9663, val_roc: 0.9531, train_auprc: 0.9580, val_auprc: 0.9448




Epoch: 229 (57.6825s), train_loss: 0.2208, val_loss: 0.2820, train_acc: 0.9149, val_acc:0.8942
		train_roc: 0.9666, val_roc: 0.9540, train_auprc: 0.9582, val_auprc: 0.9464




Epoch: 230 (57.4770s), train_loss: 0.2205, val_loss: 0.2801, train_acc: 0.9153, val_acc:0.8949
		train_roc: 0.9666, val_roc: 0.9548, train_auprc: 0.9580, val_auprc: 0.9477




Epoch: 231 (57.5995s), train_loss: 0.2201, val_loss: 0.2829, train_acc: 0.9154, val_acc:0.8934
		train_roc: 0.9667, val_roc: 0.9537, train_auprc: 0.9583, val_auprc: 0.9464




Epoch: 232 (57.5387s), train_loss: 0.2223, val_loss: 0.2835, train_acc: 0.9143, val_acc:0.8944
		train_roc: 0.9660, val_roc: 0.9536, train_auprc: 0.9574, val_auprc: 0.9448




Epoch: 233 (57.5351s), train_loss: 0.2232, val_loss: 0.2815, train_acc: 0.9138, val_acc:0.8941
		train_roc: 0.9657, val_roc: 0.9544, train_auprc: 0.9569, val_auprc: 0.9474




Epoch: 234 (57.4259s), train_loss: 0.2210, val_loss: 0.2828, train_acc: 0.9152, val_acc:0.8936
		train_roc: 0.9663, val_roc: 0.9538, train_auprc: 0.9578, val_auprc: 0.9461




Epoch: 235 (57.6486s), train_loss: 0.2224, val_loss: 0.2828, train_acc: 0.9145, val_acc:0.8939
		train_roc: 0.9660, val_roc: 0.9537, train_auprc: 0.9575, val_auprc: 0.9461




Epoch: 236 (57.7707s), train_loss: 0.2230, val_loss: 0.2826, train_acc: 0.9139, val_acc:0.8943
		train_roc: 0.9658, val_roc: 0.9539, train_auprc: 0.9571, val_auprc: 0.9460




Epoch: 237 (57.8192s), train_loss: 0.2220, val_loss: 0.2819, train_acc: 0.9147, val_acc:0.8946
		train_roc: 0.9660, val_roc: 0.9541, train_auprc: 0.9573, val_auprc: 0.9468




Epoch: 238 (57.5864s), train_loss: 0.2221, val_loss: 0.2850, train_acc: 0.9147, val_acc:0.8937
		train_roc: 0.9660, val_roc: 0.9529, train_auprc: 0.9573, val_auprc: 0.9445




Saving model acc
Epoch: 239 (57.5982s), train_loss: 0.2212, val_loss: 0.2787, train_acc: 0.9148, val_acc:0.8964
		train_roc: 0.9664, val_roc: 0.9552, train_auprc: 0.9580, val_auprc: 0.9479




Epoch: 240 (57.4929s), train_loss: 0.2219, val_loss: 0.2815, train_acc: 0.9147, val_acc:0.8950
		train_roc: 0.9662, val_roc: 0.9542, train_auprc: 0.9577, val_auprc: 0.9463




Epoch: 241 (57.6149s), train_loss: 0.2214, val_loss: 0.2808, train_acc: 0.9152, val_acc:0.8949
		train_roc: 0.9662, val_roc: 0.9544, train_auprc: 0.9577, val_auprc: 0.9477




Epoch: 242 (57.5049s), train_loss: 0.2206, val_loss: 0.2833, train_acc: 0.9152, val_acc:0.8939
		train_roc: 0.9665, val_roc: 0.9535, train_auprc: 0.9581, val_auprc: 0.9460




Epoch: 243 (57.7011s), train_loss: 0.2228, val_loss: 0.2821, train_acc: 0.9144, val_acc:0.8953
		train_roc: 0.9658, val_roc: 0.9538, train_auprc: 0.9569, val_auprc: 0.9461




Epoch: 244 (57.6360s), train_loss: 0.2222, val_loss: 0.2824, train_acc: 0.9144, val_acc:0.8941
		train_roc: 0.9661, val_roc: 0.9539, train_auprc: 0.9574, val_auprc: 0.9466




Epoch: 245 (57.4817s), train_loss: 0.2203, val_loss: 0.2836, train_acc: 0.9153, val_acc:0.8936
		train_roc: 0.9668, val_roc: 0.9534, train_auprc: 0.9585, val_auprc: 0.9458




Epoch: 246 (57.5483s), train_loss: 0.2224, val_loss: 0.2819, train_acc: 0.9144, val_acc:0.8948
		train_roc: 0.9660, val_roc: 0.9541, train_auprc: 0.9574, val_auprc: 0.9462




Epoch: 247 (57.7172s), train_loss: 0.2217, val_loss: 0.2805, train_acc: 0.9140, val_acc:0.8951
		train_roc: 0.9663, val_roc: 0.9547, train_auprc: 0.9579, val_auprc: 0.9469




Epoch: 248 (57.4628s), train_loss: 0.2231, val_loss: 0.2839, train_acc: 0.9137, val_acc:0.8935
		train_roc: 0.9658, val_roc: 0.9535, train_auprc: 0.9570, val_auprc: 0.9453




Epoch: 249 (57.4496s), train_loss: 0.2212, val_loss: 0.2800, train_acc: 0.9148, val_acc:0.8950
		train_roc: 0.9664, val_roc: 0.9548, train_auprc: 0.9577, val_auprc: 0.9474




Epoch: 250 (57.6303s), train_loss: 0.2223, val_loss: 0.2830, train_acc: 0.9145, val_acc:0.8938
		train_roc: 0.9661, val_roc: 0.9537, train_auprc: 0.9575, val_auprc: 0.9463




Epoch: 251 (57.3928s), train_loss: 0.2234, val_loss: 0.2817, train_acc: 0.9137, val_acc:0.8940
		train_roc: 0.9656, val_roc: 0.9542, train_auprc: 0.9570, val_auprc: 0.9466




Epoch: 252 (57.6047s), train_loss: 0.2230, val_loss: 0.2850, train_acc: 0.9138, val_acc:0.8925
		train_roc: 0.9658, val_roc: 0.9530, train_auprc: 0.9570, val_auprc: 0.9453




Epoch: 253 (57.8500s), train_loss: 0.2231, val_loss: 0.2869, train_acc: 0.9142, val_acc:0.8924
		train_roc: 0.9658, val_roc: 0.9524, train_auprc: 0.9569, val_auprc: 0.9439




Epoch: 254 (57.7245s), train_loss: 0.2205, val_loss: 0.2828, train_acc: 0.9147, val_acc:0.8945
		train_roc: 0.9667, val_roc: 0.9538, train_auprc: 0.9585, val_auprc: 0.9460




Epoch: 255 (57.4226s), train_loss: 0.2211, val_loss: 0.2826, train_acc: 0.9151, val_acc:0.8948
		train_roc: 0.9664, val_roc: 0.9538, train_auprc: 0.9579, val_auprc: 0.9459




Epoch: 256 (57.5681s), train_loss: 0.2201, val_loss: 0.2814, train_acc: 0.9154, val_acc:0.8948
		train_roc: 0.9667, val_roc: 0.9543, train_auprc: 0.9583, val_auprc: 0.9467




Epoch: 257 (57.4465s), train_loss: 0.2200, val_loss: 0.2820, train_acc: 0.9151, val_acc:0.8938
		train_roc: 0.9668, val_roc: 0.9542, train_auprc: 0.9585, val_auprc: 0.9470




Epoch: 258 (57.4177s), train_loss: 0.2215, val_loss: 0.2802, train_acc: 0.9146, val_acc:0.8945
		train_roc: 0.9664, val_roc: 0.9547, train_auprc: 0.9581, val_auprc: 0.9474




Epoch: 259 (57.5348s), train_loss: 0.2210, val_loss: 0.2840, train_acc: 0.9147, val_acc:0.8934
		train_roc: 0.9664, val_roc: 0.9533, train_auprc: 0.9581, val_auprc: 0.9452




Epoch: 260 (57.4564s), train_loss: 0.2226, val_loss: 0.2827, train_acc: 0.9143, val_acc:0.8937
		train_roc: 0.9659, val_roc: 0.9539, train_auprc: 0.9569, val_auprc: 0.9458




Epoch: 261 (57.5570s), train_loss: 0.2211, val_loss: 0.2808, train_acc: 0.9153, val_acc:0.8946
		train_roc: 0.9663, val_roc: 0.9545, train_auprc: 0.9577, val_auprc: 0.9468




Epoch: 262 (57.4246s), train_loss: 0.2221, val_loss: 0.2802, train_acc: 0.9145, val_acc:0.8953
		train_roc: 0.9661, val_roc: 0.9547, train_auprc: 0.9578, val_auprc: 0.9478




Epoch: 263 (57.5113s), train_loss: 0.2208, val_loss: 0.2847, train_acc: 0.9148, val_acc:0.8931
		train_roc: 0.9665, val_roc: 0.9530, train_auprc: 0.9583, val_auprc: 0.9451




Epoch: 264 (59.6963s), train_loss: 0.2222, val_loss: 0.2805, train_acc: 0.9144, val_acc:0.8948
		train_roc: 0.9660, val_roc: 0.9546, train_auprc: 0.9572, val_auprc: 0.9470




Epoch: 265 (59.5717s), train_loss: 0.2223, val_loss: 0.2821, train_acc: 0.9144, val_acc:0.8946
		train_roc: 0.9661, val_roc: 0.9541, train_auprc: 0.9574, val_auprc: 0.9467




Epoch: 266 (57.8386s), train_loss: 0.2238, val_loss: 0.2824, train_acc: 0.9140, val_acc:0.8940
		train_roc: 0.9656, val_roc: 0.9540, train_auprc: 0.9567, val_auprc: 0.9462




Epoch: 267 (57.9370s), train_loss: 0.2210, val_loss: 0.2807, train_acc: 0.9150, val_acc:0.8952
		train_roc: 0.9664, val_roc: 0.9545, train_auprc: 0.9578, val_auprc: 0.9470




Epoch: 268 (57.9085s), train_loss: 0.2216, val_loss: 0.2852, train_acc: 0.9146, val_acc:0.8933
		train_roc: 0.9662, val_roc: 0.9531, train_auprc: 0.9578, val_auprc: 0.9446




Epoch: 269 (57.9419s), train_loss: 0.2234, val_loss: 0.2807, train_acc: 0.9135, val_acc:0.8948
		train_roc: 0.9658, val_roc: 0.9544, train_auprc: 0.9573, val_auprc: 0.9474




Epoch: 270 (57.9310s), train_loss: 0.2208, val_loss: 0.2800, train_acc: 0.9151, val_acc:0.8952
		train_roc: 0.9665, val_roc: 0.9548, train_auprc: 0.9579, val_auprc: 0.9474




Epoch: 271 (58.3481s), train_loss: 0.2220, val_loss: 0.2811, train_acc: 0.9149, val_acc:0.8942
		train_roc: 0.9661, val_roc: 0.9543, train_auprc: 0.9572, val_auprc: 0.9476




Epoch: 272 (58.1537s), train_loss: 0.2206, val_loss: 0.2834, train_acc: 0.9149, val_acc:0.8950
		train_roc: 0.9665, val_roc: 0.9535, train_auprc: 0.9580, val_auprc: 0.9450




Epoch: 273 (60.8051s), train_loss: 0.2224, val_loss: 0.2814, train_acc: 0.9146, val_acc:0.8944
		train_roc: 0.9659, val_roc: 0.9542, train_auprc: 0.9569, val_auprc: 0.9468




Epoch: 274 (60.3599s), train_loss: 0.2217, val_loss: 0.2838, train_acc: 0.9146, val_acc:0.8933
		train_roc: 0.9663, val_roc: 0.9534, train_auprc: 0.9576, val_auprc: 0.9452




Epoch: 275 (60.7814s), train_loss: 0.2221, val_loss: 0.2801, train_acc: 0.9148, val_acc:0.8943
		train_roc: 0.9660, val_roc: 0.9547, train_auprc: 0.9573, val_auprc: 0.9476




Epoch: 276 (60.7370s), train_loss: 0.2207, val_loss: 0.2853, train_acc: 0.9153, val_acc:0.8932
		train_roc: 0.9665, val_roc: 0.9527, train_auprc: 0.9579, val_auprc: 0.9444




Epoch: 277 (58.2851s), train_loss: 0.2215, val_loss: 0.2813, train_acc: 0.9148, val_acc:0.8942
		train_roc: 0.9663, val_roc: 0.9541, train_auprc: 0.9579, val_auprc: 0.9470




Epoch: 278 (58.7486s), train_loss: 0.2218, val_loss: 0.2808, train_acc: 0.9147, val_acc:0.8942
		train_roc: 0.9662, val_roc: 0.9548, train_auprc: 0.9575, val_auprc: 0.9475




Epoch: 279 (58.3949s), train_loss: 0.2219, val_loss: 0.2821, train_acc: 0.9144, val_acc:0.8945
		train_roc: 0.9661, val_roc: 0.9541, train_auprc: 0.9577, val_auprc: 0.9462




Epoch: 280 (59.1735s), train_loss: 0.2232, val_loss: 0.2808, train_acc: 0.9138, val_acc:0.8958
		train_roc: 0.9658, val_roc: 0.9545, train_auprc: 0.9571, val_auprc: 0.9467




Epoch: 281 (58.6489s), train_loss: 0.2206, val_loss: 0.2833, train_acc: 0.9148, val_acc:0.8940
		train_roc: 0.9666, val_roc: 0.9536, train_auprc: 0.9585, val_auprc: 0.9456




Epoch: 282 (58.8112s), train_loss: 0.2212, val_loss: 0.2845, train_acc: 0.9150, val_acc:0.8928
		train_roc: 0.9664, val_roc: 0.9531, train_auprc: 0.9581, val_auprc: 0.9452




Epoch: 283 (58.9224s), train_loss: 0.2213, val_loss: 0.2824, train_acc: 0.9148, val_acc:0.8933
		train_roc: 0.9663, val_roc: 0.9539, train_auprc: 0.9578, val_auprc: 0.9467




Epoch: 284 (59.1985s), train_loss: 0.2216, val_loss: 0.2837, train_acc: 0.9148, val_acc:0.8939
		train_roc: 0.9662, val_roc: 0.9535, train_auprc: 0.9578, val_auprc: 0.9453




Epoch: 285 (58.9355s), train_loss: 0.2219, val_loss: 0.2836, train_acc: 0.9147, val_acc:0.8931
		train_roc: 0.9661, val_roc: 0.9536, train_auprc: 0.9573, val_auprc: 0.9462




Epoch: 286 (59.2554s), train_loss: 0.2226, val_loss: 0.2837, train_acc: 0.9144, val_acc:0.8942
		train_roc: 0.9660, val_roc: 0.9535, train_auprc: 0.9572, val_auprc: 0.9454




Epoch: 287 (58.9573s), train_loss: 0.2217, val_loss: 0.2837, train_acc: 0.9142, val_acc:0.8942
		train_roc: 0.9662, val_roc: 0.9534, train_auprc: 0.9578, val_auprc: 0.9454




Epoch: 288 (58.8259s), train_loss: 0.2214, val_loss: 0.2799, train_acc: 0.9147, val_acc:0.8946
		train_roc: 0.9664, val_roc: 0.9549, train_auprc: 0.9580, val_auprc: 0.9476




Epoch: 289 (58.7080s), train_loss: 0.2224, val_loss: 0.2832, train_acc: 0.9144, val_acc:0.8937
		train_roc: 0.9660, val_roc: 0.9536, train_auprc: 0.9572, val_auprc: 0.9459




Epoch: 290 (57.5099s), train_loss: 0.2210, val_loss: 0.2830, train_acc: 0.9149, val_acc:0.8939
		train_roc: 0.9665, val_roc: 0.9537, train_auprc: 0.9580, val_auprc: 0.9461




Epoch: 291 (64.0338s), train_loss: 0.2221, val_loss: 0.2804, train_acc: 0.9142, val_acc:0.8953
		train_roc: 0.9662, val_roc: 0.9545, train_auprc: 0.9577, val_auprc: 0.9469




Epoch: 292 (64.6030s), train_loss: 0.2215, val_loss: 0.2801, train_acc: 0.9143, val_acc:0.8949
		train_roc: 0.9662, val_roc: 0.9547, train_auprc: 0.9579, val_auprc: 0.9480




Epoch: 293 (66.8570s), train_loss: 0.2214, val_loss: 0.2814, train_acc: 0.9150, val_acc:0.8943
		train_roc: 0.9663, val_roc: 0.9544, train_auprc: 0.9577, val_auprc: 0.9471




Epoch: 294 (57.6101s), train_loss: 0.2222, val_loss: 0.2815, train_acc: 0.9145, val_acc:0.8947
		train_roc: 0.9661, val_roc: 0.9541, train_auprc: 0.9575, val_auprc: 0.9466




Epoch: 295 (57.4861s), train_loss: 0.2230, val_loss: 0.2842, train_acc: 0.9140, val_acc:0.8936
		train_roc: 0.9658, val_roc: 0.9531, train_auprc: 0.9570, val_auprc: 0.9452




Epoch: 296 (58.0620s), train_loss: 0.2220, val_loss: 0.2838, train_acc: 0.9142, val_acc:0.8935
		train_roc: 0.9661, val_roc: 0.9534, train_auprc: 0.9578, val_auprc: 0.9458




Epoch: 297 (58.3877s), train_loss: 0.2207, val_loss: 0.2826, train_acc: 0.9153, val_acc:0.8930
		train_roc: 0.9665, val_roc: 0.9539, train_auprc: 0.9581, val_auprc: 0.9470




Epoch: 298 (58.1608s), train_loss: 0.2219, val_loss: 0.2815, train_acc: 0.9148, val_acc:0.8944
		train_roc: 0.9661, val_roc: 0.9543, train_auprc: 0.9574, val_auprc: 0.9469




Epoch: 299 (58.2980s), train_loss: 0.2198, val_loss: 0.2861, train_acc: 0.9153, val_acc:0.8918
		train_roc: 0.9669, val_roc: 0.9526, train_auprc: 0.9589, val_auprc: 0.9448




Epoch: 300 (58.1237s), train_loss: 0.2201, val_loss: 0.2849, train_acc: 0.9153, val_acc:0.8925
		train_roc: 0.9667, val_roc: 0.9529, train_auprc: 0.9582, val_auprc: 0.9452


In [19]:
# Predict
model = torch.load(model_prc_file)
print(model)
model.to(device=device)
predict(model, test_data_loader, device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): MultiheadCoAttentionLayer(
    (W_q): ParameterList(
        (0): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]
        (1): Parameter containing: [torch.float32 of size 64x32 (cuda:0)]
    )
    (W_k): ParameterList(
        (0): Parameter containing: [torch.

  model = torch.load(model_prc_file)


Test Accuracy: 0.8955
Test ROC AUC: 0.9546
Test PRC AUC: 0.9467
