In [1]:
# !pip install torch_geometric rdkit torch

In [2]:
import argparse
import os
import random
import sys
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch import nn
from torch import optim
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)


In [3]:
# Directory configuration (paths are relative to the notebook's working directory)
data_dir = "data"      # holds ddi_training.csv, ddi_validation.csv and ddi_test.csv
model_dir = "models"   # checkpoints are written to {model_dir}/{acc,roc,prc}/{model_name}.pth
model_name = "case25"  # run tag used for checkpoint and train_metrics/{model_name}.csv names

In [4]:
####### Tuning parameters #######

# Number of training epochs
n_epochs = 300

# SAGPooling ratio and min score. With both left as None every GAT block
# builds SAGPooling(min_score=-1) instead.
sp_ratio = None
sp_min_score = None

# Run on the GPU when torch reports one is available
use_cuda = True

# Apply tanh to the co-attention activations before scoring
use_activation_fn = False

# Score interactions with ComplEx instead of RESCAL
use_ComplEx = False

# Co-attention variant: "multihead" | "improved"; any other value selects the
# original CoAttentionLayer
co_attention_method = "multihead"

# Import the explicit-valence featurisation module instead of the default one
use_explicit_valence = False

# Number of stacked GAT blocks
num_GAT_layers = 6

# Attention heads per GAT block
num_GAT_multiheads = 2

#################################

In [5]:
# Pick the featurisation module; both variants expose the same three names.
if not use_explicit_valence:
    from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS
else:
    from data_preprocessing_explicit_valence import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [6]:
mode = "train"

n_atom_feats = TOTAL_ATOM_FEATS
# Passed to SSI_DDI as `hidd_dim`; the GAT blocks accept it but never use it.
n_atom_hid = 64
# Total number of interaction types listed in Interaction_information.csv
rel_total = 86
lr = 1e-2
weight_decay = 5e-4
neg_samples = 1
# Number of samples (graph pairs) per training batch; validation/test loaders use 3x this.
batch_size = 1024
data_size_ratio = 1
kge_dim = 64

# Seed every RNG the pipeline may touch: weight init, DataLoader shuffling, and
# presumably the dataset's negative sampling (verify in data_preprocessing).
# This makes a Restart & Run All repeatable. Some CUDA kernels can still be
# nondeterministic.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 55


In [7]:
def print_tunning_parameters():
    """Print the notebook's tuning knobs inside a banner.

    The configuration globals are read at call time, so the report reflects
    whatever values are set when training starts.
    """
    # Each inner list is printed as one paragraph, preceded by a blank line.
    sections = [
        [("n_epochs =", n_epochs), ("use_cuda =", use_cuda)],
        [("num_GAT_layers = ", num_GAT_layers), ("num_GAT_multiheads = ", num_GAT_multiheads)],
        [("sp_ratio =", sp_ratio), ("sp_min_score =", sp_min_score)],
        [("use_explicit_valence =", use_explicit_valence)],
        [("use_activation_fn =", use_activation_fn)],
        [("use_ComplEx =", use_ComplEx)],
        [("co_attention_method =", co_attention_method)],
    ]

    print()
    print("####### Tunning parameters #######")
    for section in sections:
        print()
        for label, value in section:
            print(label, value)
    print()
    print("#################################")
    print()


In [8]:
class CoAttentionLayer(nn.Module):
    """Additive co-attention between two stacks of block embeddings.

    Both inputs are projected to ``n_features // 2`` dims, combined additively for
    every (receiver position, attendant position) pair, and reduced to a scalar
    score with the learned vector ``a``.

    Args:
        n_features: size of the last input dimension.
        use_activation_fn: apply ``tanh`` to the additive activations before scoring.
    """

    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        self.n_features = n_features
        half = n_features // 2
        self.w_q = nn.Parameter(torch.zeros(n_features, half))
        self.w_k = nn.Parameter(torch.zeros(n_features, half))
        self.bias = nn.Parameter(torch.zeros(half))
        self.a = nn.Parameter(torch.zeros(half))
        self.use_activation_fn = use_activation_fn

        for matrix in (self.w_q, self.w_k):
            nn.init.xavier_uniform_(matrix)
        # xavier_uniform_ needs >= 2 dims, so 1-D vectors are filled through a
        # column view that shares their storage.
        for vector in (self.bias, self.a):
            nn.init.xavier_uniform_(vector.view(*vector.shape, -1))

    def forward(self, receiver, attendant):
        """Return pairwise attention scores.

        Assumes inputs shaped (batch, n_positions, n_features) as stacked by SSI_DDI
        (e.g. (1024, 4, 64)). Output is (batch, n_receiver, n_attendant); entry
        [b, i, j] pairs receiver position i with attendant position j.
        """
        queries = attendant @ self.w_q
        keys = receiver @ self.w_k
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        scored = torch.tanh(e_activations) if self.use_activation_fn else e_activations
        return scored @ self.a


class MultiheadCoAttentionLayer(nn.Module):
    """Co-attention with ``n_heads`` independent query/key projections.

    Each head projects the inputs to ``n_features // n_heads`` dims and forms
    additive activations for every (receiver, attendant) position pair. The heads
    are CONCATENATED (not averaged) back to ``n_features`` dims and reduced to one
    score per pair with the learned vector ``a``.

    Args:
        n_features: size of the last input dimension; must be divisible by ``n_heads``.
        use_activation_fn: apply ``tanh`` to the concatenated activations before scoring.
        dropout: stored as ``self.dropout`` but NOT applied in ``forward``.
        n_heads: number of projection heads.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``n_features``.
        Previously such configs built fine and crashed in ``forward``.
    """

    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        if n_heads <= 0 or n_features % n_heads != 0:
            raise ValueError(
                f"n_features ({n_features}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_features = n_features
        self.n_heads = n_heads
        head_dim = n_features // n_heads

        # One projection/bias per head. ParameterLists keep the state_dict keys
        # (W_q.0, W_q.1, ...) unchanged.
        self.W_q = nn.ParameterList(
            [nn.Parameter(torch.zeros(n_features, head_dim)) for _ in range(n_heads)]
        )
        self.W_k = nn.ParameterList(
            [nn.Parameter(torch.zeros(n_features, head_dim)) for _ in range(n_heads)]
        )
        # Scores the concatenation of all heads (n_heads * head_dim == n_features).
        self.a = nn.Parameter(torch.zeros(n_features))
        self.bias = nn.ParameterList(
            [nn.Parameter(torch.zeros(head_dim)) for _ in range(n_heads)]
        )

        self.use_activation_fn = use_activation_fn

        # NOTE(review): kept for interface/state compatibility; forward() never applies it.
        self.dropout = nn.Dropout(dropout)

        # Same initialisation order as before (per head: q, k, bias; then a), so a
        # given seed yields the same weights.
        for w_q, w_k, b in zip(self.W_q, self.W_k, self.bias):
            nn.init.xavier_uniform_(w_q)
            nn.init.xavier_uniform_(w_k)
            nn.init.xavier_uniform_(b.view(*b.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        """Return pairwise attention scores.

        Assumes inputs shaped (batch, n_positions, n_features), e.g. (1024, 4, 64).
        Returns (batch, n_receiver, n_attendant), e.g. (1024, 4, 4).
        """
        # Per head: (batch, n_receiver, n_attendant, head_dim)
        head_activations = [
            (attendant @ w_q).unsqueeze(-3) + (receiver @ w_k).unsqueeze(-2) + b
            for w_q, w_k, b in zip(self.W_q, self.W_k, self.bias)
        ]
        # Concatenate heads along features: (batch, n_receiver, n_attendant, n_features)
        e_activations = torch.cat(head_activations, dim=-1)
        if self.use_activation_fn:
            e_activations = torch.tanh(e_activations)
        return e_activations @ self.a

class CoAttentionLayerImproved(nn.Module):
    """Co-attention that splits each embedding into heads sharing one projection.

    The ``n_features``-dim inputs are reshaped into ``n_heads`` chunks of
    ``head_dim = n_features // n_heads``. The SAME ``w_q``/``w_k``
    (``head_dim x head_dim // 2``) is applied to every chunk. The projected chunks
    are flattened back to ``n_heads * (head_dim // 2)`` dims and scored additively
    with ``a``. That width equals ``n_features // 2`` whenever ``n_features`` is
    divisible by ``2 * n_heads``, which includes the default (64, 2).

    Args:
        n_features: size of the last input dimension; must be divisible by ``n_heads``.
        use_activation_fn: apply ``tanh`` to the activations before scoring.
        dropout: stored as ``self.dropout`` but NOT applied in ``forward``.
        n_heads: number of chunks the features are split into.

    Raises:
        ValueError: if ``n_heads`` is not positive, does not divide ``n_features``,
            or leaves a head too narrow (< 2) to project.
    """

    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        if n_heads <= 0 or n_features % n_heads != 0:
            raise ValueError(
                f"n_features ({n_features}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.n_features = n_features
        self.n_heads = n_heads
        self.head_dim = n_features // n_heads
        if self.head_dim < 2:
            raise ValueError(f"head_dim ({self.head_dim}) must be >= 2 to project to head_dim // 2")
        proj_dim = n_heads * (self.head_dim // 2)

        # One projection shared by all heads (not a separate weight per head).
        self.w_q = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.w_k = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.bias = nn.Parameter(torch.zeros(proj_dim))
        self.a = nn.Parameter(torch.zeros(proj_dim))
        self.use_activation_fn = use_activation_fn

        # NOTE(review): kept for interface/state compatibility; forward() never applies it.
        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        """Return pairwise attention scores.

        Assumes ``receiver`` and ``attendant`` share the shape
        (batch, n_positions, n_features), e.g. (1024, 4, 64).
        Returns (batch, n_receiver, n_attendant).
        """
        batch_size, gat_size, _ = receiver.shape
        proj_dim = self.a.shape[0]

        # (batch, n_positions, n_heads, head_dim)
        receiver = receiver.reshape(batch_size, gat_size, self.n_heads, self.head_dim)
        attendant = attendant.reshape(batch_size, gat_size, self.n_heads, self.head_dim)

        # Project every head with the shared weights, then flatten the heads:
        # (batch, n_positions, n_heads, head_dim // 2) -> (batch, n_positions, proj_dim).
        # The old code flattened to head_dim, which only matched when n_heads == 2.
        keys = (receiver @ self.w_k).reshape(batch_size, gat_size, proj_dim)
        queries = (attendant @ self.w_q).reshape(batch_size, gat_size, proj_dim)

        # (batch, n_receiver, n_attendant, proj_dim)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_activations = torch.tanh(e_activations)
        return e_activations @ self.a


class RESCAL(nn.Module):
    """RESCAL-style bilinear scorer: one learned ``n_features x n_features`` matrix per relation.

    Args:
        n_rels: number of relation (interaction) types, 86 in this notebook.
        n_features: embedding size (``kge_dim``, 64 in this notebook).
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # Each relation matrix is stored flattened as one embedding row.
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        # Xavier-uniform init of the relation matrices.
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        """Score (head, relation, tail) triples.

        Args:
            heads, tails: block embeddings, e.g. (batch, n_blocks, n_features) = (1024, 4, 64).
            rels: relation indices, one per triple.
            alpha_scores: optional weights broadcastable to the (batch, n_blocks, n_blocks)
                pairwise scores (e.g. co-attention). ``None`` sums the pairwise
                scores unweighted; this default matches ``ComplEx.forward``.

        Returns:
            One score per triple (sum over all head-block/tail-block pairs).
        """
        # (batch, n_features * n_features) -> unit norm -> (batch, n_features, n_features)
        rel_mats = F.normalize(self.rel_emb(rels), dim=-1).view(-1, self.n_features, self.n_features)
        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)

        # (batch, n_blocks, F) @ (batch, F, F) @ (batch, F, n_blocks) -> (batch, n_blocks, n_blocks)
        scores = heads @ rel_mats @ tails.transpose(-2, -1)
        if alpha_scores is not None:
            scores = alpha_scores * scores

        # Sum the two block dimensions -> (batch,)
        return scores.sum(dim=(-2, -1))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    """ComplEx-style scorer with a complex-valued relation matrix per relation.

    Embeddings of size ``n_features`` are split in half: the first half is the
    real part, the second half the imaginary part. Each relation stores a real and
    an imaginary ``(n_features // 2) x (n_features // 2)`` matrix. The pairwise
    score is ``Re(h^T R conj(t))``.

    Args:
        n_rels: number of relation (interaction) types.
        n_features: embedding size; must be even.

    Raises:
        ValueError: if ``n_features`` is odd, because the real/imaginary halves
            would have different sizes.
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        if n_features % 2 != 0:
            raise ValueError(f"n_features ({n_features}) must be even for ComplEx")
        self.n_rels = n_rels
        self.n_features = n_features
        half = self.n_features // 2

        # Relation matrices (real and imaginary parts), stored flattened per relation.
        self.rel_real = nn.Embedding(self.n_rels, half * half)
        self.rel_imag = nn.Embedding(self.n_rels, half * half)

        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        """Score (head, relation, tail) triples.

        Args:
            heads, tails: block embeddings, e.g. (batch, n_blocks, n_features).
            rels: relation indices, one per triple.
            alpha_scores: optional weights broadcastable to the
                (batch, n_blocks, n_blocks) pairwise scores; ``None`` means unweighted.

        Returns:
            One score per triple (sum over all head-block/tail-block pairs).
        """
        half = self.n_features // 2

        r_real = F.normalize(self.rel_real(rels), dim=-1).view(-1, half, half)
        r_imag = F.normalize(self.rel_imag(rels), dim=-1).view(-1, half, half)

        h_real = F.normalize(heads[..., :half], dim=-1)
        h_imag = F.normalize(heads[..., half:], dim=-1)
        # The imaginary tail part must come from `tails`; it was previously sliced from `heads`.
        t_real = F.normalize(tails[..., :half], dim=-1)
        t_imag = F.normalize(tails[..., half:], dim=-1)

        t_real_T = t_real.transpose(-2, -1)
        t_imag_T = t_imag.transpose(-2, -1)

        # Re(h R conj(t)) = h_re R_re t_re + h_re R_im t_im + h_im R_re t_im - h_im R_im t_re
        # (the last term is subtracted; it was previously added).
        scores = (
            h_real @ r_real @ t_real_T
            + h_real @ r_imag @ t_imag_T
            + h_imag @ r_real @ t_imag_T
            - h_imag @ r_imag @ t_real_T
        )

        if alpha_scores is not None:
            scores = alpha_scores * scores

        return scores.sum(dim=(-2, -1))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


In [9]:
class SSI_DDI(nn.Module):
    """Drug-drug interaction scorer built from stacked GAT blocks.

    The same blocks (shared weights) encode both drugs of a triple. After every
    block a pooled graph embedding is collected. The per-block embeddings are
    stacked, weighted by a co-attention layer, and scored per relation by a
    knowledge-graph-embedding head (RESCAL or ComplEx).

    Args:
        in_features: atom feature size of the input graphs.
        hidd_dim: stored and forwarded to every block as ``final_out_feats``
            (the blocks do not use it).
        kge_dim: embedding size of the co-attention and KGE layers.
        rel_total: number of relation (interaction) types.
        heads_out_feat_params: per-block output features per GAT head.
        blocks_params: per-block number of GAT heads.
        sp_ratio, sp_min_score: SAGPooling settings forwarded to every block.
        use_activation_fn: whether the co-attention layer applies tanh.
        use_ComplEx: score with ComplEx instead of RESCAL.
        co_attention_method: "multihead" or "improved"; any other value selects
            CoAttentionLayer.
    """

    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        co_attention_method,
    ):
        super().__init__()
        self.in_features = in_features
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        # Plain list: each block is registered individually as "block{i}" below,
        # which fixes the state_dict / repr names.
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # One graph LayerNorm after each block.
        self.net_norms = ModuleList()

        width = in_features
        for idx, (per_head_feats, head_count) in enumerate(zip(heads_out_feat_params, blocks_params)):
            gat_block = SSI_DDI_Block(
                head_count,
                width,
                per_head_feats,
                final_out_feats=self.hidd_dim,
                sp_ratio=sp_ratio,
                sp_min_score=sp_min_score,
            )
            self.add_module(f"block{idx}", gat_block)
            self.blocks.append(gat_block)
            width = per_head_feats * head_count
            self.net_norms.append(LayerNorm(width))

        attention_cls = {
            "multihead": MultiheadCoAttentionLayer,
            "improved": CoAttentionLayerImproved,
        }.get(co_attention_method, CoAttentionLayer)
        self.co_attention = attention_cls(self.kge_dim, self.use_activation_fn)

        kge_cls = ComplEx if self.use_ComplEx else RESCAL
        self.KGE = kge_cls(self.rel_total, self.kge_dim)

    def forward(self, triples):
        """Score a batch of ``(head_graphs, tail_graphs, rels)`` triples.

        Note: overwrites ``.x`` on both input graph batches in place.
        """
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        head_embeddings, tail_embeddings = [], []
        for gat_block, norm in zip(self.blocks, self.net_norms):
            h_data, h_graph_emb = gat_block(h_data)
            t_data, t_graph_emb = gat_block(t_data)
            head_embeddings.append(h_graph_emb)
            tail_embeddings.append(t_graph_emb)

            h_data.x = F.elu(norm(h_data.x, h_data.batch))
            t_data.x = F.elu(norm(t_data.x, t_data.batch))

        # Stack per-block graph embeddings on dim -2: (batch, n_blocks, features)
        kge_heads = torch.stack(head_embeddings, dim=-2)
        kge_tails = torch.stack(tail_embeddings, dim=-2)

        attentions = self.co_attention(kge_heads, kge_tails)
        return self.KGE(kge_heads, kge_tails, rels, attentions)


class SSI_DDI_Block(nn.Module):
    """One GAT layer followed by SAGPooling and a per-graph sum readout.

    Args:
        n_heads: number of GAT attention heads (2 in this notebook).
        in_features: input node feature size (55, or 56 with explicit valence).
        head_out_feats: output features per head; the block outputs
            ``n_heads * head_out_feats`` features.
        final_out_feats: accepted for interface compatibility; unused.
        sp_ratio: SAGPooling ratio, or None.
        sp_min_score: SAGPooling min_score, or None. When both are None,
            SAGPooling is built with ``min_score=-1``.
    """

    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)

        if sp_ratio is None and sp_min_score is None:
            pool_kwargs = {"min_score": -1}
        elif sp_ratio is not None:
            pool_kwargs = {"min_score": sp_min_score, "ratio": sp_ratio}
        else:
            pool_kwargs = {"min_score": sp_min_score}
        # SAGPooling ranks nodes by self-attention scores.
        self.readout = SAGPooling(n_heads * head_out_feats, **pool_kwargs)

    def forward(self, data):
        """Return ``(data, graph_embedding)``.

        ``data.x`` is overwritten with the GAT output. ``graph_embedding`` is the
        per-graph sum of the pooled node features.
        """
        data.x = self.conv(data.x, data.edge_index)
        # With min_score=-1 the original notes no node is filtered out.
        pooled_x, _, _, pooled_batch, _, _ = self.readout(data.x, data.edge_index, batch=data.batch)
        return data, global_add_pool(pooled_x, pooled_batch)


In [10]:
class SigmoidLoss(nn.Module):
    """Logistic loss over positive and negative triple scores.

    Positive scores are pushed up through ``-log(sigmoid(p))`` and negative
    scores down through ``-log(sigmoid(-n))``. The total is the mean of the two
    terms.

    Args:
        adv_temperature: if truthy, negative scores are first rescaled by a
            detached softmax (over their last dimension) of
            ``adv_temperature * n_scores``.
    """

    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        """Return ``(total_loss, positive_loss, negative_loss)``."""
        if self.adv_temperature:
            # Detached: the weights rescale the scores but receive no gradient.
            adv_weights = F.softmax(self.adv_temperature * n_scores, dim=-1).detach()
            n_scores = adv_weights * n_scores
        positive_loss = F.logsigmoid(p_scores).mean().neg()
        negative_loss = F.logsigmoid(n_scores.neg()).mean().neg()
        return (positive_loss + negative_loss) / 2, positive_loss, negative_loss


In [11]:
def _frame_to_triples(frame):
    """Return the (d1, d2, type) rows of a DDI split as a list of tuples."""
    return list(zip(frame["d1"], frame["d2"], frame["type"]))


df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")

# (head drug, tail drug, interaction type) triples per split
train_tup, val_tup, test_tup = (
    _frame_to_triples(frame) for frame in (df_ddi_train, df_ddi_val, df_ddi_test)
)

train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

# Evaluation loaders use 3x the training batch size.
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size * 3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size * 3)


Training with 115185 samples, validating with 38348, and testing with 38337


In [12]:
def do_compute(model, batch, device, training=True):
    """Score one batch of positive and negative triples.

    Args:
        model: callable taking a list of device tensors/batches and returning scores.
        batch: ``(pos_tri, neg_tri)``; each is ``(batch_h, batch_t, batch_r)``.
        device: target device for every element of the triples.
        training: accepted for API compatibility; unused.

    Returns:
        ``(p_score, n_score, probas_pred, ground_truth)``. The scores keep their
        autograd graph. ``probas_pred`` is a NumPy array of sigmoid probabilities
        (positives first). ``ground_truth`` holds the matching 1s and 0s.
    """
    pos_tri, neg_tri = batch

    def _score(triples):
        return model([item.to(device=device) for item in triples])

    p_score = _score(pos_tri)
    n_score = _score(neg_tri)

    probas_pred = np.concatenate(
        [torch.sigmoid(scores.detach()).cpu() for scores in (p_score, n_score)]
    )
    ground_truth = np.concatenate([np.ones(len(p_score)), np.zeros(len(n_score))])
    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target, threshold=0.5):
    """Accuracy, ROC-AUC and PR-AUC for binary predictions.

    Args:
        probas_pred: NumPy array of predicted positive-class probabilities.
        target: 0/1 ground-truth labels aligned with ``probas_pred``.
        threshold: probability at or above which a prediction counts as positive
            for accuracy (default 0.5, as before).

    Returns:
        ``(accuracy, roc_auc, pr_auc)``. PR-AUC is the trapezoidal area under
        sklearn's precision-recall curve.
    """
    pred = (probas_pred >= threshold).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)

    # The unused F1 computation (run every epoch on the full split) was removed.
    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(recall, precision)

    return acc, auc_roc, auc_prc

In [13]:
import csv
def export_metrics(train_metrics, val_metrics, epoch, metrics_dir="train_metrics", run_name=None):
    """Append one epoch's train/validation metrics to ``{metrics_dir}/{run_name}.csv``.

    Args:
        train_metrics: ``(loss, acc, auc_roc, auc_prc)`` for the training split.
        val_metrics: ``(loss, acc, auc_roc, auc_prc)`` for the validation split.
        epoch: 1-based epoch number. Epoch 1 truncates the file and writes the header.
        metrics_dir: output directory; created if it does not exist.
        run_name: file stem; defaults to the notebook-level ``model_name``.
    """
    if run_name is None:
        run_name = model_name
    os.makedirs(metrics_dir, exist_ok=True)
    metrics_file = f"{metrics_dir}/{run_name}.csv"

    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]
    row = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]

    # Epoch 1 starts a fresh file. Later epochs append, but still write a header
    # if the file does not exist yet (e.g. a run resumed at epoch > 1).
    start_fresh = epoch == 1
    write_header = start_fresh or not os.path.exists(metrics_file)
    with open(metrics_file, "w" if start_fresh else "a", newline="") as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)
    
    

In [14]:
model_acc_file = f"{model_dir}/acc/{model_name}.pth"
model_roc_file = f"{model_dir}/roc/{model_name}.pth"
model_prc_file = f"{model_dir}/prc/{model_name}.pth"

def save_model(best, current, met_type, model_to_save=None):
    """Checkpoint the model when a validation metric improves.

    Args:
        best: best value of this metric so far.
        current: value from the current epoch.
        met_type: "acc" or "roc"; any other value uses the PR-AUC checkpoint path.
        model_to_save: module to save; defaults to the notebook-global ``model``
            (the previous implicit behaviour).

    Returns:
        ``current`` if it strictly improved on ``best`` (a checkpoint is written),
        otherwise ``best``.
    """
    model_file = {"acc": model_acc_file, "roc": model_roc_file}.get(met_type, model_prc_file)

    # Same comparison as before, so a NaN metric never triggers a save.
    if not best < current:
        return best

    print(f"Saving model {met_type}")
    if model_to_save is None:
        model_to_save = model
    # models/{acc,roc,prc}/ may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(model_file) or ".", exist_ok=True)
    # torch.save pickles the whole module object, not just its state_dict.
    torch.save(model_to_save, model_file)
    return current

In [15]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
):
    """Train ``model`` for ``n_epochs``, validating, checkpointing and logging each epoch.

    Every epoch:
    - runs one optimisation pass over ``train_data_loader``;
    - evaluates on ``val_data_loader`` without gradients;
    - saves best-so-far checkpoints for val PR-AUC, accuracy and ROC-AUC via ``save_model``;
    - steps ``scheduler`` if given;
    - appends metrics with ``export_metrics`` and prints a summary.

    Losses are averaged per sample of each loader's own dataset. The previous
    code divided by the notebook-global ``train_data`` / ``val_data``.

    Returns:
        The trained ``model`` (same object).
    """
    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()

    best_val_prc = 0
    best_val_acc = 0
    best_val_roc = 0
    n_train = len(train_data_loader.dataset)
    n_val = len(val_data_loader.dataset)

    for i in range(1, n_epochs + 1):
        start = time.time()

        # ---- training pass ----
        model.train()
        train_loss = 0
        train_probas_pred = []
        train_ground_truth = []
        for batch in train_data_loader:
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, _, _ = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by the number of positive samples in the batch.
            train_loss += loss.item() * len(p_score)
        train_loss /= n_train

        train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
            np.concatenate(train_probas_pred), np.concatenate(train_ground_truth)
        )

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        val_probas_pred = []
        val_ground_truth = []
        with torch.no_grad():
            for batch in val_data_loader:
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, _, _ = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)
        val_loss /= n_val

        val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
            np.concatenate(val_probas_pred), np.concatenate(val_ground_truth)
        )

        # Save model if this is the best so far
        best_val_prc = save_model(best_val_prc, val_auc_prc, "prc")
        best_val_acc = save_model(best_val_acc, val_acc, "acc")
        best_val_roc = save_model(best_val_roc, val_auc_roc, "roc")

        if scheduler:
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)

        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model

In [16]:
def predict(model, test_data_loader, device):
    """Evaluate ``model`` on ``test_data_loader`` and report accuracy, ROC-AUC and PR-AUC.

    Returns:
        ``(test_acc, test_auc_roc, test_auc_prc)``. These were previously only
        printed and then discarded.
    """
    print('Starting predicting at', datetime.today())
    print('Device', device)

    test_probas_pred = []
    test_ground_truth = []

    # Switch to evaluation mode
    model.eval()

    with torch.no_grad():  # No need to calculate gradients during testing
        for batch in test_data_loader:
            _, _, probas_pred, ground_truth = do_compute(model, batch, device, training=False)
            test_probas_pred.append(probas_pred)
            test_ground_truth.append(ground_truth)

    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(
        np.concatenate(test_probas_pred), np.concatenate(test_ground_truth)
    )

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

    return test_acc, test_auc_roc, test_auc_prc

In [17]:
# Each GAT block outputs head_out_feats * num_GAT_multiheads features, and the
# co-attention and KGE layers are sized to kge_dim. So kge_dim is split across
# the heads instead of hard-coding kge_dim // 2, which only worked with 2 heads.
if kge_dim % num_GAT_multiheads != 0:
    raise ValueError(
        f"kge_dim ({kge_dim}) must be divisible by num_GAT_multiheads ({num_GAT_multiheads})"
    )
heads_out_feat_params = [kge_dim // num_GAT_multiheads] * num_GAT_layers
block_params = [num_GAT_multiheads] * num_GAT_layers

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        co_attention_method=co_attention_method,
    )
    # Move the model to its device BEFORE building the optimizer, as the
    # torch.optim docs recommend.
    model.to(device=device)
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Exponential decay: the lr at epoch e is lr * 0.96 ** e.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-5): 6 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block5): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention

In [18]:
if mode == "train":
    # Train; checkpoints and per-epoch metrics are written as side effects.
    train(
        model=model,
        train_data_loader=train_data_loader,
        val_data_loader=val_data_loader,
        loss_fn=loss,
        optimizer=optimizer,
        n_epochs=n_epochs,
        device=device,
        scheduler=scheduler,
    )


Starting training at: 2024-10-27 07:46:21.010866
Device: cuda

####### Tunning parameters #######

n_epochs = 300
use_cuda = True

num_GAT_layers =  6
num_GAT_multiheads =  2

sp_ratio = None
sp_min_score = None

use_explicit_valence = False

use_activation_fn = False

use_ComplEx = False

co_attention_method = multihead

#################################





Saving model prc
Saving model acc
Saving model roc
Epoch: 1 (122.2268s), train_loss: 0.7037, val_loss: 0.6557, train_acc: 0.5551, val_acc:0.6048
		train_roc: 0.5786, val_roc: 0.6473, train_auprc: 0.5627, val_auprc: 0.6336




Saving model prc
Saving model acc
Saving model roc
Epoch: 2 (122.5854s), train_loss: 0.6238, val_loss: 0.5947, train_acc: 0.6403, val_acc:0.6705
		train_roc: 0.7002, val_roc: 0.7409, train_auprc: 0.6837, val_auprc: 0.7223




Saving model prc
Saving model acc
Saving model roc
Epoch: 3 (118.4721s), train_loss: 0.5729, val_loss: 0.5549, train_acc: 0.6937, val_acc:0.7105
		train_roc: 0.7658, val_roc: 0.7852, train_auprc: 0.7451, val_auprc: 0.7650




Saving model prc
Saving model acc
Saving model roc
Epoch: 4 (131.7888s), train_loss: 0.5431, val_loss: 0.5313, train_acc: 0.7209, val_acc:0.7320
		train_roc: 0.7961, val_roc: 0.8076, train_auprc: 0.7760, val_auprc: 0.7877




Saving model prc
Saving model acc
Saving model roc
Epoch: 5 (116.8998s), train_loss: 0.5207, val_loss: 0.5090, train_acc: 0.7402, val_acc:0.7487
		train_roc: 0.8166, val_roc: 0.8260, train_auprc: 0.7959, val_auprc: 0.8064




Saving model prc
Saving model acc
Saving model roc
Epoch: 6 (125.1346s), train_loss: 0.5008, val_loss: 0.5016, train_acc: 0.7543, val_acc:0.7574
		train_roc: 0.8327, val_roc: 0.8344, train_auprc: 0.8129, val_auprc: 0.8129




Saving model prc
Saving model acc
Saving model roc
Epoch: 7 (129.0061s), train_loss: 0.4869, val_loss: 0.4809, train_acc: 0.7645, val_acc:0.7708
		train_roc: 0.8436, val_roc: 0.8502, train_auprc: 0.8247, val_auprc: 0.8309




Saving model prc
Saving model roc
Epoch: 8 (120.7930s), train_loss: 0.4745, val_loss: 0.4746, train_acc: 0.7747, val_acc:0.7701
		train_roc: 0.8524, val_roc: 0.8541, train_auprc: 0.8331, val_auprc: 0.8353




Saving model prc
Saving model acc
Saving model roc
Epoch: 9 (114.9930s), train_loss: 0.4628, val_loss: 0.4595, train_acc: 0.7824, val_acc:0.7861
		train_roc: 0.8606, val_roc: 0.8660, train_auprc: 0.8411, val_auprc: 0.8482




Saving model prc
Saving model acc
Saving model roc
Epoch: 10 (112.1057s), train_loss: 0.4516, val_loss: 0.4453, train_acc: 0.7889, val_acc:0.7951
		train_roc: 0.8682, val_roc: 0.8727, train_auprc: 0.8495, val_auprc: 0.8540




Saving model prc
Saving model roc
Epoch: 11 (111.5976s), train_loss: 0.4461, val_loss: 0.4410, train_acc: 0.7929, val_acc:0.7934
		train_roc: 0.8716, val_roc: 0.8749, train_auprc: 0.8524, val_auprc: 0.8585




Saving model acc
Saving model roc
Epoch: 12 (112.3596s), train_loss: 0.4382, val_loss: 0.4406, train_acc: 0.7984, val_acc:0.7963
		train_roc: 0.8764, val_roc: 0.8750, train_auprc: 0.8575, val_auprc: 0.8567




Saving model prc
Saving model acc
Saving model roc
Epoch: 13 (110.5776s), train_loss: 0.4285, val_loss: 0.4323, train_acc: 0.8040, val_acc:0.8035
		train_roc: 0.8820, val_roc: 0.8811, train_auprc: 0.8636, val_auprc: 0.8624




Saving model prc
Saving model acc
Saving model roc
Epoch: 14 (113.2391s), train_loss: 0.4234, val_loss: 0.4203, train_acc: 0.8079, val_acc:0.8096
		train_roc: 0.8853, val_roc: 0.8882, train_auprc: 0.8667, val_auprc: 0.8716




Epoch: 15 (112.1012s), train_loss: 0.4156, val_loss: 0.4223, train_acc: 0.8133, val_acc:0.8063
		train_roc: 0.8902, val_roc: 0.8863, train_auprc: 0.8729, val_auprc: 0.8696




Saving model prc
Saving model acc
Saving model roc
Epoch: 16 (111.4525s), train_loss: 0.4096, val_loss: 0.4141, train_acc: 0.8160, val_acc:0.8155
		train_roc: 0.8929, val_roc: 0.8923, train_auprc: 0.8748, val_auprc: 0.8775




Saving model prc
Saving model acc
Saving model roc
Epoch: 17 (112.6820s), train_loss: 0.4013, val_loss: 0.4085, train_acc: 0.8214, val_acc:0.8157
		train_roc: 0.8974, val_roc: 0.8937, train_auprc: 0.8803, val_auprc: 0.8777




Saving model acc
Saving model roc
Epoch: 18 (112.2459s), train_loss: 0.3970, val_loss: 0.4047, train_acc: 0.8245, val_acc:0.8231
		train_roc: 0.8997, val_roc: 0.8957, train_auprc: 0.8821, val_auprc: 0.8764




Saving model prc
Saving model acc
Saving model roc
Epoch: 19 (112.3073s), train_loss: 0.3926, val_loss: 0.3968, train_acc: 0.8272, val_acc:0.8255
		train_roc: 0.9022, val_roc: 0.9001, train_auprc: 0.8855, val_auprc: 0.8827




Saving model prc
Saving model roc
Epoch: 20 (110.5714s), train_loss: 0.3879, val_loss: 0.3932, train_acc: 0.8290, val_acc:0.8238
		train_roc: 0.9045, val_roc: 0.9022, train_auprc: 0.8875, val_auprc: 0.8871




Saving model prc
Saving model acc
Saving model roc
Epoch: 21 (112.2513s), train_loss: 0.3840, val_loss: 0.3880, train_acc: 0.8325, val_acc:0.8302
		train_roc: 0.9065, val_roc: 0.9047, train_auprc: 0.8891, val_auprc: 0.8885




Saving model prc
Saving model acc
Saving model roc
Epoch: 22 (110.7709s), train_loss: 0.3771, val_loss: 0.3815, train_acc: 0.8353, val_acc:0.8348
		train_roc: 0.9101, val_roc: 0.9087, train_auprc: 0.8933, val_auprc: 0.8931




Saving model prc
Epoch: 23 (110.8033s), train_loss: 0.3715, val_loss: 0.3810, train_acc: 0.8383, val_acc:0.8306
		train_roc: 0.9124, val_roc: 0.9086, train_auprc: 0.8969, val_auprc: 0.8944




Saving model prc
Saving model acc
Saving model roc
Epoch: 24 (110.6215s), train_loss: 0.3666, val_loss: 0.3675, train_acc: 0.8412, val_acc:0.8434
		train_roc: 0.9147, val_roc: 0.9147, train_auprc: 0.8990, val_auprc: 0.8989




Epoch: 25 (111.5864s), train_loss: 0.3641, val_loss: 0.3741, train_acc: 0.8426, val_acc:0.8382
		train_roc: 0.9160, val_roc: 0.9127, train_auprc: 0.8999, val_auprc: 0.8974




Saving model prc
Saving model roc
Epoch: 26 (112.2543s), train_loss: 0.3580, val_loss: 0.3691, train_acc: 0.8462, val_acc:0.8378
		train_roc: 0.9188, val_roc: 0.9151, train_auprc: 0.9034, val_auprc: 0.9013




Saving model prc
Saving model acc
Saving model roc
Epoch: 27 (111.0201s), train_loss: 0.3568, val_loss: 0.3625, train_acc: 0.8471, val_acc:0.8455
		train_roc: 0.9196, val_roc: 0.9168, train_auprc: 0.9034, val_auprc: 0.9015




Epoch: 28 (111.1323s), train_loss: 0.3476, val_loss: 0.3687, train_acc: 0.8511, val_acc:0.8414
		train_roc: 0.9233, val_roc: 0.9149, train_auprc: 0.9084, val_auprc: 0.9003




Saving model prc
Saving model acc
Saving model roc
Epoch: 29 (110.5392s), train_loss: 0.3474, val_loss: 0.3556, train_acc: 0.8520, val_acc:0.8473
		train_roc: 0.9235, val_roc: 0.9207, train_auprc: 0.9085, val_auprc: 0.9064




Saving model acc
Epoch: 30 (111.7972s), train_loss: 0.3427, val_loss: 0.3592, train_acc: 0.8537, val_acc:0.8487
		train_roc: 0.9255, val_roc: 0.9191, train_auprc: 0.9108, val_auprc: 0.9038




Saving model prc
Saving model acc
Saving model roc
Epoch: 31 (112.7801s), train_loss: 0.3407, val_loss: 0.3467, train_acc: 0.8559, val_acc:0.8536
		train_roc: 0.9265, val_roc: 0.9238, train_auprc: 0.9112, val_auprc: 0.9101




Saving model prc
Saving model acc
Saving model roc
Epoch: 32 (111.4543s), train_loss: 0.3347, val_loss: 0.3458, train_acc: 0.8592, val_acc:0.8537
		train_roc: 0.9290, val_roc: 0.9252, train_auprc: 0.9145, val_auprc: 0.9102




Saving model prc
Saving model acc
Saving model roc
Epoch: 33 (112.8873s), train_loss: 0.3331, val_loss: 0.3456, train_acc: 0.8599, val_acc:0.8553
		train_roc: 0.9296, val_roc: 0.9266, train_auprc: 0.9154, val_auprc: 0.9125




Saving model prc
Saving model acc
Saving model roc
Epoch: 34 (110.6848s), train_loss: 0.3276, val_loss: 0.3428, train_acc: 0.8635, val_acc:0.8571
		train_roc: 0.9319, val_roc: 0.9268, train_auprc: 0.9177, val_auprc: 0.9130




Saving model prc
Saving model roc
Epoch: 35 (112.3106s), train_loss: 0.3239, val_loss: 0.3442, train_acc: 0.8651, val_acc:0.8531
		train_roc: 0.9336, val_roc: 0.9276, train_auprc: 0.9194, val_auprc: 0.9145




Saving model prc
Saving model acc
Saving model roc
Epoch: 36 (111.1421s), train_loss: 0.3225, val_loss: 0.3380, train_acc: 0.8652, val_acc:0.8578
		train_roc: 0.9339, val_roc: 0.9289, train_auprc: 0.9195, val_auprc: 0.9155




Saving model prc
Saving model acc
Saving model roc
Epoch: 37 (111.7480s), train_loss: 0.3183, val_loss: 0.3320, train_acc: 0.8678, val_acc:0.8611
		train_roc: 0.9355, val_roc: 0.9309, train_auprc: 0.9222, val_auprc: 0.9177




Saving model prc
Saving model acc
Saving model roc
Epoch: 38 (111.2274s), train_loss: 0.3142, val_loss: 0.3309, train_acc: 0.8696, val_acc:0.8623
		train_roc: 0.9372, val_roc: 0.9312, train_auprc: 0.9237, val_auprc: 0.9181




Saving model prc
Saving model roc
Epoch: 39 (113.4347s), train_loss: 0.3103, val_loss: 0.3285, train_acc: 0.8723, val_acc:0.8622
		train_roc: 0.9388, val_roc: 0.9325, train_auprc: 0.9257, val_auprc: 0.9205




Saving model prc
Saving model acc
Saving model roc
Epoch: 40 (111.2918s), train_loss: 0.3085, val_loss: 0.3262, train_acc: 0.8729, val_acc:0.8649
		train_roc: 0.9393, val_roc: 0.9333, train_auprc: 0.9261, val_auprc: 0.9206




Saving model prc
Saving model roc
Epoch: 41 (112.3467s), train_loss: 0.3040, val_loss: 0.3272, train_acc: 0.8749, val_acc:0.8648
		train_roc: 0.9411, val_roc: 0.9340, train_auprc: 0.9282, val_auprc: 0.9217




Saving model prc
Saving model acc
Saving model roc
Epoch: 42 (112.3722s), train_loss: 0.3038, val_loss: 0.3207, train_acc: 0.8755, val_acc:0.8675
		train_roc: 0.9413, val_roc: 0.9353, train_auprc: 0.9281, val_auprc: 0.9220




Saving model prc
Epoch: 43 (112.2605s), train_loss: 0.2995, val_loss: 0.3222, train_acc: 0.8777, val_acc:0.8667
		train_roc: 0.9428, val_roc: 0.9351, train_auprc: 0.9299, val_auprc: 0.9221




Saving model prc
Saving model acc
Saving model roc
Epoch: 44 (111.7587s), train_loss: 0.2971, val_loss: 0.3154, train_acc: 0.8791, val_acc:0.8715
		train_roc: 0.9438, val_roc: 0.9374, train_auprc: 0.9314, val_auprc: 0.9247




Epoch: 45 (111.8840s), train_loss: 0.2951, val_loss: 0.3187, train_acc: 0.8793, val_acc:0.8680
		train_roc: 0.9445, val_roc: 0.9365, train_auprc: 0.9324, val_auprc: 0.9247




Epoch: 46 (111.0958s), train_loss: 0.2927, val_loss: 0.3181, train_acc: 0.8818, val_acc:0.8698
		train_roc: 0.9454, val_roc: 0.9362, train_auprc: 0.9330, val_auprc: 0.9229




Saving model prc
Saving model acc
Saving model roc
Epoch: 47 (112.4713s), train_loss: 0.2909, val_loss: 0.3102, train_acc: 0.8829, val_acc:0.8730
		train_roc: 0.9458, val_roc: 0.9396, train_auprc: 0.9329, val_auprc: 0.9278




Saving model prc
Epoch: 48 (111.4418s), train_loss: 0.2874, val_loss: 0.3128, train_acc: 0.8847, val_acc:0.8710
		train_roc: 0.9470, val_roc: 0.9392, train_auprc: 0.9344, val_auprc: 0.9278




Saving model prc
Saving model acc
Saving model roc
Epoch: 49 (111.6469s), train_loss: 0.2836, val_loss: 0.3081, train_acc: 0.8855, val_acc:0.8739
		train_roc: 0.9483, val_roc: 0.9405, train_auprc: 0.9366, val_auprc: 0.9291




Saving model prc
Saving model acc
Saving model roc
Epoch: 50 (112.0352s), train_loss: 0.2823, val_loss: 0.3033, train_acc: 0.8864, val_acc:0.8761
		train_roc: 0.9487, val_roc: 0.9422, train_auprc: 0.9367, val_auprc: 0.9315




Saving model prc
Saving model acc
Saving model roc
Epoch: 51 (112.0072s), train_loss: 0.2778, val_loss: 0.2994, train_acc: 0.8884, val_acc:0.8780
		train_roc: 0.9504, val_roc: 0.9438, train_auprc: 0.9390, val_auprc: 0.9336




Saving model acc
Saving model roc
Epoch: 52 (111.0960s), train_loss: 0.2760, val_loss: 0.3008, train_acc: 0.8889, val_acc:0.8782
		train_roc: 0.9508, val_roc: 0.9440, train_auprc: 0.9396, val_auprc: 0.9329




Epoch: 53 (111.7002s), train_loss: 0.2752, val_loss: 0.2992, train_acc: 0.8894, val_acc:0.8780
		train_roc: 0.9511, val_roc: 0.9440, train_auprc: 0.9396, val_auprc: 0.9334




Saving model prc
Saving model acc
Saving model roc
Epoch: 54 (111.8270s), train_loss: 0.2720, val_loss: 0.2989, train_acc: 0.8911, val_acc:0.8789
		train_roc: 0.9523, val_roc: 0.9446, train_auprc: 0.9414, val_auprc: 0.9341




Saving model prc
Saving model acc
Saving model roc
Epoch: 55 (110.7173s), train_loss: 0.2705, val_loss: 0.2968, train_acc: 0.8918, val_acc:0.8789
		train_roc: 0.9526, val_roc: 0.9455, train_auprc: 0.9413, val_auprc: 0.9357




Saving model prc
Saving model acc
Saving model roc
Epoch: 56 (111.5148s), train_loss: 0.2694, val_loss: 0.2937, train_acc: 0.8934, val_acc:0.8806
		train_roc: 0.9532, val_roc: 0.9458, train_auprc: 0.9422, val_auprc: 0.9363




Saving model acc
Epoch: 57 (112.3078s), train_loss: 0.2662, val_loss: 0.2944, train_acc: 0.8942, val_acc:0.8807
		train_roc: 0.9541, val_roc: 0.9456, train_auprc: 0.9433, val_auprc: 0.9358




Saving model prc
Saving model acc
Saving model roc
Epoch: 58 (112.7414s), train_loss: 0.2642, val_loss: 0.2924, train_acc: 0.8958, val_acc:0.8830
		train_roc: 0.9548, val_roc: 0.9467, train_auprc: 0.9436, val_auprc: 0.9364




Saving model prc
Epoch: 59 (112.2147s), train_loss: 0.2599, val_loss: 0.2937, train_acc: 0.8977, val_acc:0.8801
		train_roc: 0.9561, val_roc: 0.9466, train_auprc: 0.9456, val_auprc: 0.9368




Saving model prc
Saving model roc
Epoch: 60 (110.5283s), train_loss: 0.2594, val_loss: 0.2911, train_acc: 0.8977, val_acc:0.8827
		train_roc: 0.9561, val_roc: 0.9470, train_auprc: 0.9452, val_auprc: 0.9373




Saving model prc
Saving model acc
Saving model roc
Epoch: 61 (112.6540s), train_loss: 0.2593, val_loss: 0.2882, train_acc: 0.8980, val_acc:0.8835
		train_roc: 0.9562, val_roc: 0.9478, train_auprc: 0.9453, val_auprc: 0.9383




Saving model acc
Epoch: 62 (111.1737s), train_loss: 0.2556, val_loss: 0.2899, train_acc: 0.8991, val_acc:0.8837
		train_roc: 0.9573, val_roc: 0.9477, train_auprc: 0.9472, val_auprc: 0.9377




Saving model prc
Saving model acc
Saving model roc
Epoch: 63 (112.7872s), train_loss: 0.2541, val_loss: 0.2878, train_acc: 0.9003, val_acc:0.8857
		train_roc: 0.9578, val_roc: 0.9488, train_auprc: 0.9477, val_auprc: 0.9391




Epoch: 64 (111.7802s), train_loss: 0.2527, val_loss: 0.2898, train_acc: 0.9011, val_acc:0.8848
		train_roc: 0.9584, val_roc: 0.9479, train_auprc: 0.9480, val_auprc: 0.9373




Saving model prc
Saving model acc
Saving model roc
Epoch: 65 (111.5092s), train_loss: 0.2507, val_loss: 0.2825, train_acc: 0.9019, val_acc:0.8874
		train_roc: 0.9588, val_roc: 0.9502, train_auprc: 0.9485, val_auprc: 0.9410




Epoch: 66 (112.8885s), train_loss: 0.2488, val_loss: 0.2873, train_acc: 0.9024, val_acc:0.8865
		train_roc: 0.9594, val_roc: 0.9488, train_auprc: 0.9495, val_auprc: 0.9387




Saving model prc
Saving model roc
Epoch: 67 (111.5205s), train_loss: 0.2493, val_loss: 0.2828, train_acc: 0.9023, val_acc:0.8872
		train_roc: 0.9592, val_roc: 0.9504, train_auprc: 0.9494, val_auprc: 0.9414




Epoch: 68 (113.1470s), train_loss: 0.2444, val_loss: 0.2864, train_acc: 0.9044, val_acc:0.8870
		train_roc: 0.9608, val_roc: 0.9499, train_auprc: 0.9513, val_auprc: 0.9395




Saving model prc
Saving model acc
Saving model roc
Epoch: 69 (112.6862s), train_loss: 0.2445, val_loss: 0.2796, train_acc: 0.9053, val_acc:0.8885
		train_roc: 0.9608, val_roc: 0.9513, train_auprc: 0.9512, val_auprc: 0.9426




Saving model roc
Epoch: 70 (111.4074s), train_loss: 0.2413, val_loss: 0.2809, train_acc: 0.9057, val_acc:0.8872
		train_roc: 0.9617, val_roc: 0.9514, train_auprc: 0.9526, val_auprc: 0.9423




Epoch: 71 (110.8725s), train_loss: 0.2403, val_loss: 0.2843, train_acc: 0.9065, val_acc:0.8861
		train_roc: 0.9618, val_roc: 0.9504, train_auprc: 0.9523, val_auprc: 0.9414




Saving model acc
Saving model roc
Epoch: 72 (113.2000s), train_loss: 0.2401, val_loss: 0.2788, train_acc: 0.9076, val_acc:0.8901
		train_roc: 0.9617, val_roc: 0.9523, train_auprc: 0.9522, val_auprc: 0.9424




Saving model prc
Saving model acc
Saving model roc
Epoch: 73 (111.4768s), train_loss: 0.2380, val_loss: 0.2746, train_acc: 0.9077, val_acc:0.8908
		train_roc: 0.9625, val_roc: 0.9533, train_auprc: 0.9530, val_auprc: 0.9445




Epoch: 74 (111.3727s), train_loss: 0.2383, val_loss: 0.2774, train_acc: 0.9084, val_acc:0.8894
		train_roc: 0.9623, val_roc: 0.9526, train_auprc: 0.9522, val_auprc: 0.9434




Saving model prc
Saving model acc
Saving model roc
Epoch: 75 (110.4473s), train_loss: 0.2354, val_loss: 0.2730, train_acc: 0.9095, val_acc:0.8912
		train_roc: 0.9631, val_roc: 0.9540, train_auprc: 0.9538, val_auprc: 0.9458




Epoch: 76 (111.5422s), train_loss: 0.2329, val_loss: 0.2778, train_acc: 0.9101, val_acc:0.8897
		train_roc: 0.9639, val_roc: 0.9533, train_auprc: 0.9551, val_auprc: 0.9447




Saving model prc
Saving model acc
Saving model roc
Epoch: 77 (111.4618s), train_loss: 0.2325, val_loss: 0.2744, train_acc: 0.9107, val_acc:0.8920
		train_roc: 0.9641, val_roc: 0.9546, train_auprc: 0.9549, val_auprc: 0.9468




Saving model acc
Epoch: 78 (112.2534s), train_loss: 0.2308, val_loss: 0.2751, train_acc: 0.9112, val_acc:0.8924
		train_roc: 0.9647, val_roc: 0.9538, train_auprc: 0.9562, val_auprc: 0.9452




Epoch: 79 (111.6319s), train_loss: 0.2296, val_loss: 0.2735, train_acc: 0.9119, val_acc:0.8922
		train_roc: 0.9649, val_roc: 0.9540, train_auprc: 0.9560, val_auprc: 0.9457




Saving model acc
Saving model roc
Epoch: 80 (111.8775s), train_loss: 0.2297, val_loss: 0.2713, train_acc: 0.9117, val_acc:0.8936
		train_roc: 0.9649, val_roc: 0.9547, train_auprc: 0.9560, val_auprc: 0.9465




Epoch: 81 (112.1832s), train_loss: 0.2281, val_loss: 0.2740, train_acc: 0.9122, val_acc:0.8922
		train_roc: 0.9651, val_roc: 0.9542, train_auprc: 0.9563, val_auprc: 0.9463




Epoch: 82 (111.9915s), train_loss: 0.2276, val_loss: 0.2750, train_acc: 0.9130, val_acc:0.8919
		train_roc: 0.9653, val_roc: 0.9533, train_auprc: 0.9564, val_auprc: 0.9445




Epoch: 83 (112.7388s), train_loss: 0.2251, val_loss: 0.2728, train_acc: 0.9140, val_acc:0.8934
		train_roc: 0.9660, val_roc: 0.9542, train_auprc: 0.9575, val_auprc: 0.9460




Saving model prc
Saving model acc
Saving model roc
Epoch: 84 (112.4296s), train_loss: 0.2257, val_loss: 0.2715, train_acc: 0.9144, val_acc:0.8945
		train_roc: 0.9660, val_roc: 0.9552, train_auprc: 0.9571, val_auprc: 0.9470




Saving model prc
Epoch: 85 (110.6802s), train_loss: 0.2228, val_loss: 0.2715, train_acc: 0.9149, val_acc:0.8943
		train_roc: 0.9666, val_roc: 0.9551, train_auprc: 0.9581, val_auprc: 0.9477




Saving model roc
Epoch: 86 (111.5426s), train_loss: 0.2220, val_loss: 0.2708, train_acc: 0.9158, val_acc:0.8945
		train_roc: 0.9668, val_roc: 0.9556, train_auprc: 0.9583, val_auprc: 0.9476




Epoch: 87 (111.7327s), train_loss: 0.2213, val_loss: 0.2733, train_acc: 0.9158, val_acc:0.8941
		train_roc: 0.9667, val_roc: 0.9553, train_auprc: 0.9580, val_auprc: 0.9472




Saving model prc
Saving model acc
Saving model roc
Epoch: 88 (111.9582s), train_loss: 0.2197, val_loss: 0.2686, train_acc: 0.9165, val_acc:0.8960
		train_roc: 0.9674, val_roc: 0.9561, train_auprc: 0.9592, val_auprc: 0.9482




Epoch: 89 (111.8292s), train_loss: 0.2207, val_loss: 0.2738, train_acc: 0.9163, val_acc:0.8921
		train_roc: 0.9671, val_roc: 0.9549, train_auprc: 0.9586, val_auprc: 0.9473




Saving model prc
Saving model roc
Epoch: 90 (111.5872s), train_loss: 0.2191, val_loss: 0.2694, train_acc: 0.9174, val_acc:0.8955
		train_roc: 0.9675, val_roc: 0.9566, train_auprc: 0.9588, val_auprc: 0.9488




Saving model prc
Epoch: 91 (110.9346s), train_loss: 0.2159, val_loss: 0.2685, train_acc: 0.9187, val_acc:0.8947
		train_roc: 0.9683, val_roc: 0.9561, train_auprc: 0.9601, val_auprc: 0.9489




Epoch: 92 (111.7108s), train_loss: 0.2174, val_loss: 0.2732, train_acc: 0.9175, val_acc:0.8939
		train_roc: 0.9679, val_roc: 0.9556, train_auprc: 0.9597, val_auprc: 0.9481




Saving model acc
Epoch: 93 (111.8399s), train_loss: 0.2171, val_loss: 0.2690, train_acc: 0.9183, val_acc:0.8968
		train_roc: 0.9679, val_roc: 0.9557, train_auprc: 0.9595, val_auprc: 0.9474




Saving model prc
Saving model roc
Epoch: 94 (111.0054s), train_loss: 0.2163, val_loss: 0.2672, train_acc: 0.9182, val_acc:0.8967
		train_roc: 0.9683, val_roc: 0.9571, train_auprc: 0.9599, val_auprc: 0.9498




Saving model prc
Saving model acc
Saving model roc
Epoch: 95 (112.1255s), train_loss: 0.2145, val_loss: 0.2660, train_acc: 0.9189, val_acc:0.8976
		train_roc: 0.9689, val_roc: 0.9576, train_auprc: 0.9608, val_auprc: 0.9508




Epoch: 96 (111.6350s), train_loss: 0.2126, val_loss: 0.2679, train_acc: 0.9193, val_acc:0.8970
		train_roc: 0.9694, val_roc: 0.9571, train_auprc: 0.9617, val_auprc: 0.9497




Epoch: 97 (111.8212s), train_loss: 0.2148, val_loss: 0.2726, train_acc: 0.9193, val_acc:0.8950
		train_roc: 0.9684, val_roc: 0.9557, train_auprc: 0.9601, val_auprc: 0.9484




Epoch: 98 (112.0428s), train_loss: 0.2140, val_loss: 0.2689, train_acc: 0.9194, val_acc:0.8976
		train_roc: 0.9688, val_roc: 0.9570, train_auprc: 0.9608, val_auprc: 0.9497




Saving model acc
Epoch: 99 (111.6435s), train_loss: 0.2114, val_loss: 0.2675, train_acc: 0.9205, val_acc:0.8984
		train_roc: 0.9696, val_roc: 0.9571, train_auprc: 0.9617, val_auprc: 0.9497




Saving model roc
Epoch: 100 (111.8009s), train_loss: 0.2149, val_loss: 0.2662, train_acc: 0.9196, val_acc:0.8978
		train_roc: 0.9683, val_roc: 0.9579, train_auprc: 0.9598, val_auprc: 0.9503




Saving model prc
Saving model acc
Saving model roc
Epoch: 101 (111.1392s), train_loss: 0.2104, val_loss: 0.2644, train_acc: 0.9205, val_acc:0.8987
		train_roc: 0.9698, val_roc: 0.9582, train_auprc: 0.9619, val_auprc: 0.9512




Epoch: 102 (112.3526s), train_loss: 0.2112, val_loss: 0.2688, train_acc: 0.9203, val_acc:0.8970
		train_roc: 0.9696, val_roc: 0.9572, train_auprc: 0.9618, val_auprc: 0.9499




Saving model prc
Saving model acc
Saving model roc
Epoch: 103 (110.7741s), train_loss: 0.2086, val_loss: 0.2644, train_acc: 0.9219, val_acc:0.8992
		train_roc: 0.9705, val_roc: 0.9587, train_auprc: 0.9627, val_auprc: 0.9522




Epoch: 104 (112.2357s), train_loss: 0.2107, val_loss: 0.2706, train_acc: 0.9206, val_acc:0.8964
		train_roc: 0.9696, val_roc: 0.9570, train_auprc: 0.9616, val_auprc: 0.9500




Epoch: 105 (111.3878s), train_loss: 0.2073, val_loss: 0.2677, train_acc: 0.9217, val_acc:0.8976
		train_roc: 0.9706, val_roc: 0.9576, train_auprc: 0.9630, val_auprc: 0.9511




Epoch: 106 (112.3291s), train_loss: 0.2076, val_loss: 0.2689, train_acc: 0.9223, val_acc:0.8971
		train_roc: 0.9706, val_roc: 0.9571, train_auprc: 0.9628, val_auprc: 0.9500




Saving model prc
Epoch: 107 (111.4082s), train_loss: 0.2081, val_loss: 0.2660, train_acc: 0.9218, val_acc:0.8983
		train_roc: 0.9702, val_roc: 0.9584, train_auprc: 0.9623, val_auprc: 0.9522




Epoch: 108 (112.2173s), train_loss: 0.2076, val_loss: 0.2659, train_acc: 0.9221, val_acc:0.8984
		train_roc: 0.9703, val_roc: 0.9582, train_auprc: 0.9627, val_auprc: 0.9515




Epoch: 109 (110.7002s), train_loss: 0.2074, val_loss: 0.2678, train_acc: 0.9224, val_acc:0.8975
		train_roc: 0.9704, val_roc: 0.9576, train_auprc: 0.9625, val_auprc: 0.9509




Epoch: 110 (112.8072s), train_loss: 0.2081, val_loss: 0.2678, train_acc: 0.9224, val_acc:0.8987
		train_roc: 0.9701, val_roc: 0.9574, train_auprc: 0.9620, val_auprc: 0.9502




Epoch: 111 (112.5334s), train_loss: 0.2071, val_loss: 0.2692, train_acc: 0.9219, val_acc:0.8975
		train_roc: 0.9706, val_roc: 0.9577, train_auprc: 0.9628, val_auprc: 0.9507




Saving model acc
Epoch: 112 (110.9233s), train_loss: 0.2061, val_loss: 0.2674, train_acc: 0.9225, val_acc:0.8994
		train_roc: 0.9710, val_roc: 0.9575, train_auprc: 0.9634, val_auprc: 0.9500




Epoch: 113 (110.7356s), train_loss: 0.2054, val_loss: 0.2703, train_acc: 0.9228, val_acc:0.8969
		train_roc: 0.9710, val_roc: 0.9579, train_auprc: 0.9634, val_auprc: 0.9510




Epoch: 114 (112.6533s), train_loss: 0.2037, val_loss: 0.2673, train_acc: 0.9240, val_acc:0.8978
		train_roc: 0.9716, val_roc: 0.9581, train_auprc: 0.9641, val_auprc: 0.9511




Epoch: 115 (111.0168s), train_loss: 0.2055, val_loss: 0.2689, train_acc: 0.9233, val_acc:0.8979
		train_roc: 0.9709, val_roc: 0.9579, train_auprc: 0.9630, val_auprc: 0.9507




Epoch: 116 (111.5000s), train_loss: 0.2054, val_loss: 0.2694, train_acc: 0.9233, val_acc:0.8978
		train_roc: 0.9709, val_roc: 0.9579, train_auprc: 0.9632, val_auprc: 0.9511




Epoch: 117 (111.8356s), train_loss: 0.2041, val_loss: 0.2666, train_acc: 0.9237, val_acc:0.8990
		train_roc: 0.9712, val_roc: 0.9583, train_auprc: 0.9638, val_auprc: 0.9516




Saving model prc
Saving model roc
Epoch: 118 (111.7058s), train_loss: 0.2037, val_loss: 0.2658, train_acc: 0.9244, val_acc:0.8981
		train_roc: 0.9713, val_roc: 0.9591, train_auprc: 0.9636, val_auprc: 0.9528




Saving model prc
Epoch: 119 (112.8487s), train_loss: 0.2037, val_loss: 0.2649, train_acc: 0.9241, val_acc:0.8992
		train_roc: 0.9713, val_roc: 0.9591, train_auprc: 0.9636, val_auprc: 0.9528




Epoch: 120 (111.3490s), train_loss: 0.2041, val_loss: 0.2685, train_acc: 0.9243, val_acc:0.8990
		train_roc: 0.9712, val_roc: 0.9578, train_auprc: 0.9634, val_auprc: 0.9502




Epoch: 121 (111.2761s), train_loss: 0.2039, val_loss: 0.2696, train_acc: 0.9234, val_acc:0.8968
		train_roc: 0.9713, val_roc: 0.9581, train_auprc: 0.9638, val_auprc: 0.9513




Epoch: 122 (110.6780s), train_loss: 0.2022, val_loss: 0.2675, train_acc: 0.9250, val_acc:0.8986
		train_roc: 0.9717, val_roc: 0.9584, train_auprc: 0.9641, val_auprc: 0.9514




Saving model prc
Saving model acc
Saving model roc
Epoch: 123 (110.8031s), train_loss: 0.2032, val_loss: 0.2645, train_acc: 0.9243, val_acc:0.9005
		train_roc: 0.9715, val_roc: 0.9595, train_auprc: 0.9636, val_auprc: 0.9529




Epoch: 124 (111.8630s), train_loss: 0.2018, val_loss: 0.2667, train_acc: 0.9250, val_acc:0.8996
		train_roc: 0.9718, val_roc: 0.9582, train_auprc: 0.9641, val_auprc: 0.9511




Epoch: 125 (112.4320s), train_loss: 0.2018, val_loss: 0.2684, train_acc: 0.9246, val_acc:0.8994
		train_roc: 0.9720, val_roc: 0.9584, train_auprc: 0.9646, val_auprc: 0.9511




Epoch: 126 (111.6546s), train_loss: 0.2019, val_loss: 0.2675, train_acc: 0.9252, val_acc:0.8983
		train_roc: 0.9719, val_roc: 0.9583, train_auprc: 0.9642, val_auprc: 0.9516




Epoch: 127 (112.2257s), train_loss: 0.2027, val_loss: 0.2725, train_acc: 0.9242, val_acc:0.8969
		train_roc: 0.9716, val_roc: 0.9573, train_auprc: 0.9641, val_auprc: 0.9499




Epoch: 128 (111.2436s), train_loss: 0.2010, val_loss: 0.2685, train_acc: 0.9256, val_acc:0.8989
		train_roc: 0.9720, val_roc: 0.9582, train_auprc: 0.9643, val_auprc: 0.9508




Epoch: 129 (111.6124s), train_loss: 0.2010, val_loss: 0.2685, train_acc: 0.9253, val_acc:0.8987
		train_roc: 0.9720, val_roc: 0.9586, train_auprc: 0.9645, val_auprc: 0.9523




Saving model prc
Saving model roc
Epoch: 130 (111.6649s), train_loss: 0.2025, val_loss: 0.2636, train_acc: 0.9244, val_acc:0.9001
		train_roc: 0.9716, val_roc: 0.9600, train_auprc: 0.9639, val_auprc: 0.9542




Epoch: 131 (112.3155s), train_loss: 0.2022, val_loss: 0.2689, train_acc: 0.9248, val_acc:0.8990
		train_roc: 0.9716, val_roc: 0.9583, train_auprc: 0.9639, val_auprc: 0.9514




Epoch: 132 (112.5303s), train_loss: 0.1997, val_loss: 0.2696, train_acc: 0.9256, val_acc:0.8979
		train_roc: 0.9723, val_roc: 0.9578, train_auprc: 0.9648, val_auprc: 0.9507




Epoch: 133 (111.6839s), train_loss: 0.2009, val_loss: 0.2677, train_acc: 0.9249, val_acc:0.8990
		train_roc: 0.9721, val_roc: 0.9589, train_auprc: 0.9649, val_auprc: 0.9525




Epoch: 134 (112.1758s), train_loss: 0.2009, val_loss: 0.2652, train_acc: 0.9254, val_acc:0.8994
		train_roc: 0.9721, val_roc: 0.9596, train_auprc: 0.9645, val_auprc: 0.9532




Epoch: 135 (111.1578s), train_loss: 0.1989, val_loss: 0.2686, train_acc: 0.9257, val_acc:0.8991
		train_roc: 0.9727, val_roc: 0.9584, train_auprc: 0.9656, val_auprc: 0.9515




Epoch: 136 (112.5329s), train_loss: 0.1992, val_loss: 0.2668, train_acc: 0.9262, val_acc:0.8989
		train_roc: 0.9725, val_roc: 0.9591, train_auprc: 0.9651, val_auprc: 0.9528




Epoch: 137 (112.1502s), train_loss: 0.1993, val_loss: 0.2668, train_acc: 0.9254, val_acc:0.8997
		train_roc: 0.9724, val_roc: 0.9592, train_auprc: 0.9652, val_auprc: 0.9528




Epoch: 138 (111.7053s), train_loss: 0.2018, val_loss: 0.2688, train_acc: 0.9245, val_acc:0.8985
		train_roc: 0.9718, val_roc: 0.9586, train_auprc: 0.9642, val_auprc: 0.9520




Epoch: 139 (111.5203s), train_loss: 0.1985, val_loss: 0.2660, train_acc: 0.9261, val_acc:0.8990
		train_roc: 0.9727, val_roc: 0.9595, train_auprc: 0.9656, val_auprc: 0.9535




Epoch: 140 (111.0580s), train_loss: 0.1985, val_loss: 0.2689, train_acc: 0.9258, val_acc:0.8986
		train_roc: 0.9727, val_roc: 0.9587, train_auprc: 0.9656, val_auprc: 0.9520




Epoch: 141 (111.0273s), train_loss: 0.1989, val_loss: 0.2697, train_acc: 0.9259, val_acc:0.8981
		train_roc: 0.9726, val_roc: 0.9585, train_auprc: 0.9654, val_auprc: 0.9519




Epoch: 142 (110.7677s), train_loss: 0.1992, val_loss: 0.2711, train_acc: 0.9261, val_acc:0.8974
		train_roc: 0.9724, val_roc: 0.9582, train_auprc: 0.9648, val_auprc: 0.9516




Saving model acc
Epoch: 143 (111.7854s), train_loss: 0.1997, val_loss: 0.2658, train_acc: 0.9256, val_acc:0.9010
		train_roc: 0.9723, val_roc: 0.9592, train_auprc: 0.9648, val_auprc: 0.9522




Epoch: 144 (111.0725s), train_loss: 0.1985, val_loss: 0.2653, train_acc: 0.9264, val_acc:0.8996
		train_roc: 0.9728, val_roc: 0.9597, train_auprc: 0.9653, val_auprc: 0.9532




Epoch: 145 (111.6687s), train_loss: 0.1991, val_loss: 0.2689, train_acc: 0.9260, val_acc:0.8990
		train_roc: 0.9724, val_roc: 0.9583, train_auprc: 0.9649, val_auprc: 0.9517




Epoch: 146 (111.6820s), train_loss: 0.2000, val_loss: 0.2690, train_acc: 0.9255, val_acc:0.8983
		train_roc: 0.9722, val_roc: 0.9586, train_auprc: 0.9646, val_auprc: 0.9521




Epoch: 147 (111.1651s), train_loss: 0.1991, val_loss: 0.2675, train_acc: 0.9263, val_acc:0.9002
		train_roc: 0.9724, val_roc: 0.9594, train_auprc: 0.9649, val_auprc: 0.9529




Epoch: 148 (111.3751s), train_loss: 0.1988, val_loss: 0.2685, train_acc: 0.9263, val_acc:0.8986
		train_roc: 0.9725, val_roc: 0.9589, train_auprc: 0.9650, val_auprc: 0.9525




Epoch: 149 (112.4943s), train_loss: 0.1970, val_loss: 0.2667, train_acc: 0.9266, val_acc:0.9000
		train_roc: 0.9733, val_roc: 0.9594, train_auprc: 0.9662, val_auprc: 0.9525




Epoch: 150 (112.1001s), train_loss: 0.1981, val_loss: 0.2676, train_acc: 0.9270, val_acc:0.8998
		train_roc: 0.9727, val_roc: 0.9590, train_auprc: 0.9651, val_auprc: 0.9525




Epoch: 151 (111.4234s), train_loss: 0.1989, val_loss: 0.2694, train_acc: 0.9264, val_acc:0.8986
		train_roc: 0.9725, val_roc: 0.9584, train_auprc: 0.9649, val_auprc: 0.9515




Epoch: 152 (111.5846s), train_loss: 0.1975, val_loss: 0.2708, train_acc: 0.9267, val_acc:0.8975
		train_roc: 0.9729, val_roc: 0.9580, train_auprc: 0.9656, val_auprc: 0.9513




Epoch: 153 (112.8558s), train_loss: 0.1978, val_loss: 0.2694, train_acc: 0.9265, val_acc:0.8979
		train_roc: 0.9727, val_roc: 0.9585, train_auprc: 0.9655, val_auprc: 0.9521




Epoch: 154 (111.5179s), train_loss: 0.1989, val_loss: 0.2692, train_acc: 0.9261, val_acc:0.8993
		train_roc: 0.9725, val_roc: 0.9585, train_auprc: 0.9652, val_auprc: 0.9517




Epoch: 155 (110.7550s), train_loss: 0.1969, val_loss: 0.2702, train_acc: 0.9268, val_acc:0.8986
		train_roc: 0.9732, val_roc: 0.9584, train_auprc: 0.9661, val_auprc: 0.9516




Epoch: 156 (111.8136s), train_loss: 0.1983, val_loss: 0.2711, train_acc: 0.9263, val_acc:0.8976
		train_roc: 0.9727, val_roc: 0.9581, train_auprc: 0.9652, val_auprc: 0.9514




Epoch: 157 (110.9248s), train_loss: 0.1973, val_loss: 0.2700, train_acc: 0.9271, val_acc:0.8986
		train_roc: 0.9728, val_roc: 0.9583, train_auprc: 0.9654, val_auprc: 0.9512




Epoch: 158 (112.2334s), train_loss: 0.1981, val_loss: 0.2698, train_acc: 0.9265, val_acc:0.8981
		train_roc: 0.9727, val_roc: 0.9587, train_auprc: 0.9652, val_auprc: 0.9522




Epoch: 159 (112.6624s), train_loss: 0.1970, val_loss: 0.2669, train_acc: 0.9268, val_acc:0.8990
		train_roc: 0.9729, val_roc: 0.9596, train_auprc: 0.9655, val_auprc: 0.9536




Epoch: 160 (111.3556s), train_loss: 0.1982, val_loss: 0.2702, train_acc: 0.9268, val_acc:0.8984
		train_roc: 0.9727, val_roc: 0.9583, train_auprc: 0.9651, val_auprc: 0.9518




Epoch: 161 (112.4404s), train_loss: 0.1967, val_loss: 0.2702, train_acc: 0.9270, val_acc:0.8978
		train_roc: 0.9731, val_roc: 0.9584, train_auprc: 0.9657, val_auprc: 0.9520




Epoch: 162 (112.3276s), train_loss: 0.1986, val_loss: 0.2701, train_acc: 0.9262, val_acc:0.8985
		train_roc: 0.9724, val_roc: 0.9584, train_auprc: 0.9648, val_auprc: 0.9515




Epoch: 163 (110.7802s), train_loss: 0.1964, val_loss: 0.2688, train_acc: 0.9275, val_acc:0.8986
		train_roc: 0.9731, val_roc: 0.9589, train_auprc: 0.9658, val_auprc: 0.9525




Epoch: 164 (112.3643s), train_loss: 0.1980, val_loss: 0.2678, train_acc: 0.9265, val_acc:0.8987
		train_roc: 0.9726, val_roc: 0.9591, train_auprc: 0.9651, val_auprc: 0.9532




Epoch: 165 (113.2457s), train_loss: 0.1980, val_loss: 0.2702, train_acc: 0.9264, val_acc:0.8992
		train_roc: 0.9727, val_roc: 0.9583, train_auprc: 0.9654, val_auprc: 0.9512




Epoch: 166 (111.7511s), train_loss: 0.1979, val_loss: 0.2692, train_acc: 0.9262, val_acc:0.8983
		train_roc: 0.9727, val_roc: 0.9588, train_auprc: 0.9655, val_auprc: 0.9522




Epoch: 167 (112.2304s), train_loss: 0.1984, val_loss: 0.2677, train_acc: 0.9261, val_acc:0.8997
		train_roc: 0.9726, val_roc: 0.9593, train_auprc: 0.9653, val_auprc: 0.9529




Epoch: 168 (112.2009s), train_loss: 0.1976, val_loss: 0.2676, train_acc: 0.9264, val_acc:0.9000
		train_roc: 0.9729, val_roc: 0.9593, train_auprc: 0.9656, val_auprc: 0.9526




Epoch: 169 (110.7853s), train_loss: 0.1965, val_loss: 0.2696, train_acc: 0.9272, val_acc:0.8993
		train_roc: 0.9731, val_roc: 0.9584, train_auprc: 0.9660, val_auprc: 0.9515




Epoch: 170 (111.8293s), train_loss: 0.1972, val_loss: 0.2674, train_acc: 0.9273, val_acc:0.8999
		train_roc: 0.9729, val_roc: 0.9592, train_auprc: 0.9653, val_auprc: 0.9527




Epoch: 171 (110.9949s), train_loss: 0.1966, val_loss: 0.2690, train_acc: 0.9270, val_acc:0.8991
		train_roc: 0.9731, val_roc: 0.9589, train_auprc: 0.9660, val_auprc: 0.9526




Epoch: 172 (112.2592s), train_loss: 0.1961, val_loss: 0.2686, train_acc: 0.9275, val_acc:0.9001
		train_roc: 0.9731, val_roc: 0.9588, train_auprc: 0.9658, val_auprc: 0.9521




Epoch: 173 (111.4151s), train_loss: 0.1947, val_loss: 0.2667, train_acc: 0.9276, val_acc:0.9005
		train_roc: 0.9737, val_roc: 0.9596, train_auprc: 0.9667, val_auprc: 0.9532




Epoch: 174 (111.8457s), train_loss: 0.1978, val_loss: 0.2723, train_acc: 0.9272, val_acc:0.8975
		train_roc: 0.9727, val_roc: 0.9580, train_auprc: 0.9651, val_auprc: 0.9508




Epoch: 175 (112.6054s), train_loss: 0.1975, val_loss: 0.2669, train_acc: 0.9267, val_acc:0.9000
		train_roc: 0.9728, val_roc: 0.9596, train_auprc: 0.9654, val_auprc: 0.9532




Epoch: 176 (111.6482s), train_loss: 0.1954, val_loss: 0.2698, train_acc: 0.9276, val_acc:0.8985
		train_roc: 0.9735, val_roc: 0.9585, train_auprc: 0.9664, val_auprc: 0.9521




Epoch: 177 (111.4901s), train_loss: 0.1983, val_loss: 0.2698, train_acc: 0.9266, val_acc:0.8988
		train_roc: 0.9726, val_roc: 0.9587, train_auprc: 0.9651, val_auprc: 0.9519




Epoch: 178 (111.8195s), train_loss: 0.1953, val_loss: 0.2702, train_acc: 0.9276, val_acc:0.8982
		train_roc: 0.9736, val_roc: 0.9587, train_auprc: 0.9667, val_auprc: 0.9521




Epoch: 179 (112.1347s), train_loss: 0.1972, val_loss: 0.2715, train_acc: 0.9268, val_acc:0.8992
		train_roc: 0.9731, val_roc: 0.9580, train_auprc: 0.9657, val_auprc: 0.9508




Epoch: 180 (112.5634s), train_loss: 0.1975, val_loss: 0.2695, train_acc: 0.9268, val_acc:0.8995
		train_roc: 0.9729, val_roc: 0.9586, train_auprc: 0.9655, val_auprc: 0.9519




Epoch: 181 (112.1615s), train_loss: 0.1952, val_loss: 0.2712, train_acc: 0.9276, val_acc:0.8983
		train_roc: 0.9736, val_roc: 0.9581, train_auprc: 0.9663, val_auprc: 0.9512




Epoch: 182 (112.2050s), train_loss: 0.1980, val_loss: 0.2708, train_acc: 0.9259, val_acc:0.8983
		train_roc: 0.9727, val_roc: 0.9586, train_auprc: 0.9656, val_auprc: 0.9515




Epoch: 183 (111.2859s), train_loss: 0.1981, val_loss: 0.2679, train_acc: 0.9266, val_acc:0.9003
		train_roc: 0.9725, val_roc: 0.9594, train_auprc: 0.9648, val_auprc: 0.9527




Epoch: 184 (111.5487s), train_loss: 0.1958, val_loss: 0.2722, train_acc: 0.9272, val_acc:0.8973
		train_roc: 0.9734, val_roc: 0.9578, train_auprc: 0.9662, val_auprc: 0.9514




Epoch: 185 (110.9110s), train_loss: 0.1962, val_loss: 0.2693, train_acc: 0.9276, val_acc:0.9000
		train_roc: 0.9732, val_roc: 0.9588, train_auprc: 0.9658, val_auprc: 0.9516




Epoch: 186 (111.5632s), train_loss: 0.1958, val_loss: 0.2713, train_acc: 0.9274, val_acc:0.8981
		train_roc: 0.9733, val_roc: 0.9582, train_auprc: 0.9660, val_auprc: 0.9513




Epoch: 187 (112.0209s), train_loss: 0.1965, val_loss: 0.2690, train_acc: 0.9274, val_acc:0.8990
		train_roc: 0.9730, val_roc: 0.9590, train_auprc: 0.9655, val_auprc: 0.9528




Saving model acc
Epoch: 188 (111.0439s), train_loss: 0.1944, val_loss: 0.2657, train_acc: 0.9283, val_acc:0.9010
		train_roc: 0.9737, val_roc: 0.9599, train_auprc: 0.9665, val_auprc: 0.9537




Epoch: 189 (111.8633s), train_loss: 0.1973, val_loss: 0.2697, train_acc: 0.9270, val_acc:0.8988
		train_roc: 0.9728, val_roc: 0.9588, train_auprc: 0.9653, val_auprc: 0.9520




Epoch: 190 (112.1388s), train_loss: 0.1968, val_loss: 0.2703, train_acc: 0.9273, val_acc:0.8986
		train_roc: 0.9730, val_roc: 0.9585, train_auprc: 0.9656, val_auprc: 0.9520




Epoch: 191 (110.4284s), train_loss: 0.1956, val_loss: 0.2684, train_acc: 0.9278, val_acc:0.8998
		train_roc: 0.9733, val_roc: 0.9594, train_auprc: 0.9662, val_auprc: 0.9527




Epoch: 192 (110.4013s), train_loss: 0.1966, val_loss: 0.2711, train_acc: 0.9273, val_acc:0.8987
		train_roc: 0.9731, val_roc: 0.9584, train_auprc: 0.9657, val_auprc: 0.9513




Epoch: 193 (111.0730s), train_loss: 0.1967, val_loss: 0.2696, train_acc: 0.9271, val_acc:0.8996
		train_roc: 0.9730, val_roc: 0.9587, train_auprc: 0.9656, val_auprc: 0.9521




Epoch: 194 (111.5257s), train_loss: 0.1969, val_loss: 0.2715, train_acc: 0.9271, val_acc:0.8986
		train_roc: 0.9729, val_roc: 0.9583, train_auprc: 0.9653, val_auprc: 0.9511




Epoch: 195 (111.0455s), train_loss: 0.1959, val_loss: 0.2702, train_acc: 0.9278, val_acc:0.8990
		train_roc: 0.9733, val_roc: 0.9586, train_auprc: 0.9662, val_auprc: 0.9518




Epoch: 196 (112.4578s), train_loss: 0.1974, val_loss: 0.2708, train_acc: 0.9271, val_acc:0.8988
		train_roc: 0.9728, val_roc: 0.9583, train_auprc: 0.9654, val_auprc: 0.9515




Epoch: 197 (110.9587s), train_loss: 0.1964, val_loss: 0.2675, train_acc: 0.9273, val_acc:0.8994
		train_roc: 0.9731, val_roc: 0.9596, train_auprc: 0.9657, val_auprc: 0.9537




Epoch: 198 (112.0211s), train_loss: 0.1949, val_loss: 0.2699, train_acc: 0.9277, val_acc:0.8986
		train_roc: 0.9735, val_roc: 0.9587, train_auprc: 0.9665, val_auprc: 0.9523




Epoch: 199 (111.1981s), train_loss: 0.1949, val_loss: 0.2697, train_acc: 0.9281, val_acc:0.8984
		train_roc: 0.9736, val_roc: 0.9587, train_auprc: 0.9666, val_auprc: 0.9522




Epoch: 200 (112.3203s), train_loss: 0.1962, val_loss: 0.2687, train_acc: 0.9270, val_acc:0.8993
		train_roc: 0.9731, val_roc: 0.9590, train_auprc: 0.9659, val_auprc: 0.9527




Saving model roc
Epoch: 201 (111.0254s), train_loss: 0.1981, val_loss: 0.2665, train_acc: 0.9266, val_acc:0.8996
		train_roc: 0.9725, val_roc: 0.9600, train_auprc: 0.9650, val_auprc: 0.9540




Epoch: 202 (111.5724s), train_loss: 0.1964, val_loss: 0.2685, train_acc: 0.9270, val_acc:0.9000
		train_roc: 0.9731, val_roc: 0.9594, train_auprc: 0.9659, val_auprc: 0.9528




Epoch: 203 (112.3256s), train_loss: 0.1958, val_loss: 0.2692, train_acc: 0.9276, val_acc:0.8989
		train_roc: 0.9733, val_roc: 0.9588, train_auprc: 0.9661, val_auprc: 0.9527




Epoch: 204 (112.5432s), train_loss: 0.1948, val_loss: 0.2700, train_acc: 0.9277, val_acc:0.8992
		train_roc: 0.9737, val_roc: 0.9587, train_auprc: 0.9668, val_auprc: 0.9516




Epoch: 205 (111.8074s), train_loss: 0.1962, val_loss: 0.2702, train_acc: 0.9272, val_acc:0.8988
		train_roc: 0.9733, val_roc: 0.9588, train_auprc: 0.9662, val_auprc: 0.9521




Epoch: 206 (111.8646s), train_loss: 0.1959, val_loss: 0.2692, train_acc: 0.9272, val_acc:0.8991
		train_roc: 0.9733, val_roc: 0.9590, train_auprc: 0.9660, val_auprc: 0.9526




Epoch: 207 (111.2393s), train_loss: 0.1964, val_loss: 0.2678, train_acc: 0.9272, val_acc:0.8999
		train_roc: 0.9731, val_roc: 0.9595, train_auprc: 0.9657, val_auprc: 0.9535




Epoch: 208 (111.8463s), train_loss: 0.1961, val_loss: 0.2728, train_acc: 0.9268, val_acc:0.8976
		train_roc: 0.9733, val_roc: 0.9577, train_auprc: 0.9665, val_auprc: 0.9505




Epoch: 209 (111.3286s), train_loss: 0.1953, val_loss: 0.2700, train_acc: 0.9279, val_acc:0.8986
		train_roc: 0.9734, val_roc: 0.9586, train_auprc: 0.9660, val_auprc: 0.9525




Epoch: 210 (110.4992s), train_loss: 0.1958, val_loss: 0.2700, train_acc: 0.9273, val_acc:0.8985
		train_roc: 0.9734, val_roc: 0.9588, train_auprc: 0.9662, val_auprc: 0.9522




Epoch: 211 (111.0039s), train_loss: 0.1950, val_loss: 0.2700, train_acc: 0.9282, val_acc:0.8991
		train_roc: 0.9734, val_roc: 0.9587, train_auprc: 0.9663, val_auprc: 0.9523




Epoch: 212 (111.3493s), train_loss: 0.1965, val_loss: 0.2698, train_acc: 0.9278, val_acc:0.8986
		train_roc: 0.9729, val_roc: 0.9587, train_auprc: 0.9653, val_auprc: 0.9524




Epoch: 213 (111.3711s), train_loss: 0.1965, val_loss: 0.2693, train_acc: 0.9276, val_acc:0.8996
		train_roc: 0.9731, val_roc: 0.9588, train_auprc: 0.9657, val_auprc: 0.9522




Epoch: 214 (111.0319s), train_loss: 0.1961, val_loss: 0.2717, train_acc: 0.9275, val_acc:0.8983
		train_roc: 0.9731, val_roc: 0.9582, train_auprc: 0.9660, val_auprc: 0.9512




Epoch: 215 (111.9151s), train_loss: 0.1972, val_loss: 0.2696, train_acc: 0.9269, val_acc:0.8985
		train_roc: 0.9727, val_roc: 0.9589, train_auprc: 0.9654, val_auprc: 0.9526




Epoch: 216 (111.0811s), train_loss: 0.1949, val_loss: 0.2727, train_acc: 0.9276, val_acc:0.8986
		train_roc: 0.9736, val_roc: 0.9577, train_auprc: 0.9666, val_auprc: 0.9505




Epoch: 217 (111.0820s), train_loss: 0.1956, val_loss: 0.2690, train_acc: 0.9274, val_acc:0.8992
		train_roc: 0.9734, val_roc: 0.9592, train_auprc: 0.9664, val_auprc: 0.9529




Epoch: 218 (111.1477s), train_loss: 0.1975, val_loss: 0.2715, train_acc: 0.9265, val_acc:0.8988
		train_roc: 0.9728, val_roc: 0.9582, train_auprc: 0.9654, val_auprc: 0.9510




Epoch: 219 (111.9030s), train_loss: 0.1958, val_loss: 0.2712, train_acc: 0.9277, val_acc:0.8976
		train_roc: 0.9732, val_roc: 0.9583, train_auprc: 0.9660, val_auprc: 0.9520




Epoch: 220 (111.9834s), train_loss: 0.1959, val_loss: 0.2695, train_acc: 0.9276, val_acc:0.8983
		train_roc: 0.9732, val_roc: 0.9588, train_auprc: 0.9658, val_auprc: 0.9527




Epoch: 221 (111.9407s), train_loss: 0.1977, val_loss: 0.2711, train_acc: 0.9266, val_acc:0.8977
		train_roc: 0.9726, val_roc: 0.9584, train_auprc: 0.9651, val_auprc: 0.9518




Epoch: 222 (112.1571s), train_loss: 0.1957, val_loss: 0.2678, train_acc: 0.9277, val_acc:0.8996
		train_roc: 0.9733, val_roc: 0.9596, train_auprc: 0.9661, val_auprc: 0.9532




Epoch: 223 (111.3649s), train_loss: 0.1977, val_loss: 0.2682, train_acc: 0.9266, val_acc:0.8990
		train_roc: 0.9727, val_roc: 0.9594, train_auprc: 0.9652, val_auprc: 0.9534




Epoch: 224 (111.0580s), train_loss: 0.1959, val_loss: 0.2696, train_acc: 0.9275, val_acc:0.8986
		train_roc: 0.9732, val_roc: 0.9590, train_auprc: 0.9659, val_auprc: 0.9524




Epoch: 225 (111.7103s), train_loss: 0.1956, val_loss: 0.2692, train_acc: 0.9274, val_acc:0.8995
		train_roc: 0.9733, val_roc: 0.9589, train_auprc: 0.9661, val_auprc: 0.9522




Epoch: 226 (111.9623s), train_loss: 0.1960, val_loss: 0.2681, train_acc: 0.9275, val_acc:0.8991
		train_roc: 0.9732, val_roc: 0.9594, train_auprc: 0.9660, val_auprc: 0.9532




Epoch: 227 (112.4706s), train_loss: 0.1965, val_loss: 0.2702, train_acc: 0.9271, val_acc:0.8994
		train_roc: 0.9731, val_roc: 0.9585, train_auprc: 0.9657, val_auprc: 0.9516




Epoch: 228 (111.2730s), train_loss: 0.1949, val_loss: 0.2694, train_acc: 0.9277, val_acc:0.8997
		train_roc: 0.9737, val_roc: 0.9590, train_auprc: 0.9664, val_auprc: 0.9519




Epoch: 229 (111.4565s), train_loss: 0.1957, val_loss: 0.2717, train_acc: 0.9274, val_acc:0.8983
		train_roc: 0.9734, val_roc: 0.9581, train_auprc: 0.9662, val_auprc: 0.9508




Epoch: 230 (112.1706s), train_loss: 0.1962, val_loss: 0.2692, train_acc: 0.9272, val_acc:0.8990
		train_roc: 0.9732, val_roc: 0.9591, train_auprc: 0.9656, val_auprc: 0.9528




Epoch: 231 (112.1569s), train_loss: 0.1971, val_loss: 0.2726, train_acc: 0.9269, val_acc:0.8976
		train_roc: 0.9729, val_roc: 0.9578, train_auprc: 0.9653, val_auprc: 0.9504




Epoch: 232 (112.5054s), train_loss: 0.1970, val_loss: 0.2689, train_acc: 0.9277, val_acc:0.8993
		train_roc: 0.9728, val_roc: 0.9591, train_auprc: 0.9649, val_auprc: 0.9529




Epoch: 233 (112.0562s), train_loss: 0.1969, val_loss: 0.2689, train_acc: 0.9272, val_acc:0.8993
		train_roc: 0.9729, val_roc: 0.9593, train_auprc: 0.9654, val_auprc: 0.9529




Epoch: 234 (112.5886s), train_loss: 0.1966, val_loss: 0.2711, train_acc: 0.9272, val_acc:0.8978
		train_roc: 0.9730, val_roc: 0.9585, train_auprc: 0.9657, val_auprc: 0.9520




Epoch: 235 (111.4801s), train_loss: 0.1958, val_loss: 0.2690, train_acc: 0.9272, val_acc:0.8987
		train_roc: 0.9733, val_roc: 0.9591, train_auprc: 0.9665, val_auprc: 0.9526




Epoch: 236 (113.2583s), train_loss: 0.1962, val_loss: 0.2717, train_acc: 0.9273, val_acc:0.8982
		train_roc: 0.9733, val_roc: 0.9581, train_auprc: 0.9661, val_auprc: 0.9514




Epoch: 237 (111.4840s), train_loss: 0.1965, val_loss: 0.2671, train_acc: 0.9270, val_acc:0.8994
		train_roc: 0.9731, val_roc: 0.9598, train_auprc: 0.9660, val_auprc: 0.9538




Epoch: 238 (111.1725s), train_loss: 0.1952, val_loss: 0.2691, train_acc: 0.9280, val_acc:0.8993
		train_roc: 0.9735, val_roc: 0.9589, train_auprc: 0.9663, val_auprc: 0.9526




Epoch: 239 (111.2422s), train_loss: 0.1970, val_loss: 0.2697, train_acc: 0.9271, val_acc:0.8995
		train_roc: 0.9729, val_roc: 0.9588, train_auprc: 0.9655, val_auprc: 0.9523




Epoch: 240 (112.4459s), train_loss: 0.1961, val_loss: 0.2714, train_acc: 0.9272, val_acc:0.8975
		train_roc: 0.9732, val_roc: 0.9583, train_auprc: 0.9659, val_auprc: 0.9520




Epoch: 241 (111.1730s), train_loss: 0.1957, val_loss: 0.2686, train_acc: 0.9278, val_acc:0.8993
		train_roc: 0.9732, val_roc: 0.9593, train_auprc: 0.9658, val_auprc: 0.9528




Epoch: 242 (111.1665s), train_loss: 0.1965, val_loss: 0.2694, train_acc: 0.9273, val_acc:0.8993
		train_roc: 0.9730, val_roc: 0.9590, train_auprc: 0.9656, val_auprc: 0.9524




Epoch: 243 (110.6848s), train_loss: 0.1972, val_loss: 0.2698, train_acc: 0.9268, val_acc:0.8991
		train_roc: 0.9728, val_roc: 0.9588, train_auprc: 0.9654, val_auprc: 0.9522




Epoch: 244 (112.4964s), train_loss: 0.1963, val_loss: 0.2717, train_acc: 0.9270, val_acc:0.8989
		train_roc: 0.9731, val_roc: 0.9581, train_auprc: 0.9658, val_auprc: 0.9506




Epoch: 245 (111.1644s), train_loss: 0.1957, val_loss: 0.2687, train_acc: 0.9274, val_acc:0.9000
		train_roc: 0.9734, val_roc: 0.9591, train_auprc: 0.9665, val_auprc: 0.9530




Epoch: 246 (111.7278s), train_loss: 0.1960, val_loss: 0.2684, train_acc: 0.9270, val_acc:0.8995
		train_roc: 0.9732, val_roc: 0.9594, train_auprc: 0.9661, val_auprc: 0.9532




Epoch: 247 (111.0138s), train_loss: 0.1974, val_loss: 0.2690, train_acc: 0.9264, val_acc:0.8988
		train_roc: 0.9729, val_roc: 0.9593, train_auprc: 0.9656, val_auprc: 0.9530




Epoch: 248 (111.2831s), train_loss: 0.1967, val_loss: 0.2682, train_acc: 0.9271, val_acc:0.8992
		train_roc: 0.9731, val_roc: 0.9595, train_auprc: 0.9658, val_auprc: 0.9533




Epoch: 249 (111.3916s), train_loss: 0.1949, val_loss: 0.2699, train_acc: 0.9277, val_acc:0.8994
		train_roc: 0.9736, val_roc: 0.9587, train_auprc: 0.9666, val_auprc: 0.9515




Epoch: 250 (112.5279s), train_loss: 0.1970, val_loss: 0.2698, train_acc: 0.9271, val_acc:0.8996
		train_roc: 0.9729, val_roc: 0.9588, train_auprc: 0.9656, val_auprc: 0.9518




Epoch: 251 (110.6979s), train_loss: 0.1950, val_loss: 0.2704, train_acc: 0.9277, val_acc:0.8984
		train_roc: 0.9735, val_roc: 0.9585, train_auprc: 0.9665, val_auprc: 0.9517




Epoch: 252 (112.4966s), train_loss: 0.1946, val_loss: 0.2707, train_acc: 0.9280, val_acc:0.8987
		train_roc: 0.9738, val_roc: 0.9584, train_auprc: 0.9667, val_auprc: 0.9517




Epoch: 253 (112.3810s), train_loss: 0.1959, val_loss: 0.2720, train_acc: 0.9276, val_acc:0.8984
		train_roc: 0.9732, val_roc: 0.9580, train_auprc: 0.9656, val_auprc: 0.9510




Epoch: 254 (111.1220s), train_loss: 0.1959, val_loss: 0.2698, train_acc: 0.9276, val_acc:0.8986
		train_roc: 0.9732, val_roc: 0.9587, train_auprc: 0.9657, val_auprc: 0.9524




Epoch: 255 (111.3201s), train_loss: 0.1963, val_loss: 0.2694, train_acc: 0.9270, val_acc:0.8993
		train_roc: 0.9732, val_roc: 0.9589, train_auprc: 0.9660, val_auprc: 0.9524




Epoch: 256 (111.1969s), train_loss: 0.1976, val_loss: 0.2700, train_acc: 0.9270, val_acc:0.8983
		train_roc: 0.9727, val_roc: 0.9587, train_auprc: 0.9649, val_auprc: 0.9521




Epoch: 257 (113.0012s), train_loss: 0.1955, val_loss: 0.2712, train_acc: 0.9278, val_acc:0.8980
		train_roc: 0.9735, val_roc: 0.9583, train_auprc: 0.9663, val_auprc: 0.9517




Saving model prc
Saving model acc
Saving model roc
Epoch: 258 (110.8558s), train_loss: 0.1968, val_loss: 0.2655, train_acc: 0.9268, val_acc:0.9011
		train_roc: 0.9730, val_roc: 0.9605, train_auprc: 0.9654, val_auprc: 0.9546




Epoch: 259 (112.0729s), train_loss: 0.1959, val_loss: 0.2684, train_acc: 0.9272, val_acc:0.8992
		train_roc: 0.9734, val_roc: 0.9593, train_auprc: 0.9662, val_auprc: 0.9529




Epoch: 260 (111.7630s), train_loss: 0.1979, val_loss: 0.2688, train_acc: 0.9262, val_acc:0.8994
		train_roc: 0.9726, val_roc: 0.9592, train_auprc: 0.9651, val_auprc: 0.9527




Epoch: 261 (111.4926s), train_loss: 0.1963, val_loss: 0.2698, train_acc: 0.9274, val_acc:0.8984
		train_roc: 0.9731, val_roc: 0.9587, train_auprc: 0.9656, val_auprc: 0.9523




Epoch: 262 (111.1827s), train_loss: 0.1945, val_loss: 0.2713, train_acc: 0.9282, val_acc:0.8987
		train_roc: 0.9737, val_roc: 0.9583, train_auprc: 0.9666, val_auprc: 0.9515




Epoch: 263 (111.8899s), train_loss: 0.1963, val_loss: 0.2692, train_acc: 0.9273, val_acc:0.8984
		train_roc: 0.9732, val_roc: 0.9589, train_auprc: 0.9660, val_auprc: 0.9526




Epoch: 264 (109.8208s), train_loss: 0.1977, val_loss: 0.2717, train_acc: 0.9269, val_acc:0.8982
		train_roc: 0.9727, val_roc: 0.9584, train_auprc: 0.9650, val_auprc: 0.9512




Epoch: 265 (111.6152s), train_loss: 0.1956, val_loss: 0.2686, train_acc: 0.9274, val_acc:0.8996
		train_roc: 0.9734, val_roc: 0.9592, train_auprc: 0.9661, val_auprc: 0.9528




Epoch: 266 (111.7416s), train_loss: 0.1966, val_loss: 0.2704, train_acc: 0.9272, val_acc:0.8990
		train_roc: 0.9732, val_roc: 0.9587, train_auprc: 0.9657, val_auprc: 0.9515




Epoch: 267 (110.7270s), train_loss: 0.1959, val_loss: 0.2689, train_acc: 0.9274, val_acc:0.8993
		train_roc: 0.9732, val_roc: 0.9591, train_auprc: 0.9659, val_auprc: 0.9527




Epoch: 268 (111.3132s), train_loss: 0.1953, val_loss: 0.2701, train_acc: 0.9277, val_acc:0.8984
		train_roc: 0.9734, val_roc: 0.9586, train_auprc: 0.9664, val_auprc: 0.9523




Epoch: 269 (111.9660s), train_loss: 0.1963, val_loss: 0.2706, train_acc: 0.9275, val_acc:0.8984
		train_roc: 0.9731, val_roc: 0.9585, train_auprc: 0.9656, val_auprc: 0.9519




Epoch: 270 (112.2794s), train_loss: 0.1983, val_loss: 0.2674, train_acc: 0.9259, val_acc:0.9004
		train_roc: 0.9726, val_roc: 0.9598, train_auprc: 0.9653, val_auprc: 0.9531




Epoch: 271 (111.6979s), train_loss: 0.1954, val_loss: 0.2717, train_acc: 0.9277, val_acc:0.8978
		train_roc: 0.9733, val_roc: 0.9580, train_auprc: 0.9658, val_auprc: 0.9511




Epoch: 272 (110.3506s), train_loss: 0.1969, val_loss: 0.2692, train_acc: 0.9267, val_acc:0.8987
		train_roc: 0.9730, val_roc: 0.9590, train_auprc: 0.9655, val_auprc: 0.9528




Epoch: 273 (112.4568s), train_loss: 0.1954, val_loss: 0.2694, train_acc: 0.9273, val_acc:0.8987
		train_roc: 0.9734, val_roc: 0.9589, train_auprc: 0.9662, val_auprc: 0.9523




Epoch: 274 (111.8513s), train_loss: 0.1954, val_loss: 0.2714, train_acc: 0.9280, val_acc:0.8985
		train_roc: 0.9733, val_roc: 0.9582, train_auprc: 0.9658, val_auprc: 0.9511




Epoch: 275 (112.0932s), train_loss: 0.1953, val_loss: 0.2690, train_acc: 0.9275, val_acc:0.8986
		train_roc: 0.9734, val_roc: 0.9592, train_auprc: 0.9662, val_auprc: 0.9528




Epoch: 276 (111.1355s), train_loss: 0.1959, val_loss: 0.2719, train_acc: 0.9274, val_acc:0.8978
		train_roc: 0.9733, val_roc: 0.9580, train_auprc: 0.9661, val_auprc: 0.9512




Epoch: 277 (111.3266s), train_loss: 0.1969, val_loss: 0.2670, train_acc: 0.9265, val_acc:0.9004
		train_roc: 0.9731, val_roc: 0.9599, train_auprc: 0.9659, val_auprc: 0.9537




Epoch: 278 (111.9435s), train_loss: 0.1966, val_loss: 0.2698, train_acc: 0.9274, val_acc:0.8990
		train_roc: 0.9730, val_roc: 0.9587, train_auprc: 0.9657, val_auprc: 0.9520




Epoch: 279 (111.3914s), train_loss: 0.1954, val_loss: 0.2712, train_acc: 0.9276, val_acc:0.8986
		train_roc: 0.9734, val_roc: 0.9584, train_auprc: 0.9662, val_auprc: 0.9514




Epoch: 280 (110.8697s), train_loss: 0.1961, val_loss: 0.2692, train_acc: 0.9271, val_acc:0.8986
		train_roc: 0.9733, val_roc: 0.9590, train_auprc: 0.9662, val_auprc: 0.9527




Epoch: 281 (111.8894s), train_loss: 0.1964, val_loss: 0.2702, train_acc: 0.9270, val_acc:0.8984
		train_roc: 0.9732, val_roc: 0.9586, train_auprc: 0.9658, val_auprc: 0.9520




Epoch: 282 (111.0700s), train_loss: 0.1956, val_loss: 0.2698, train_acc: 0.9278, val_acc:0.8987
		train_roc: 0.9733, val_roc: 0.9589, train_auprc: 0.9659, val_auprc: 0.9521




Epoch: 283 (112.4287s), train_loss: 0.1951, val_loss: 0.2736, train_acc: 0.9279, val_acc:0.8977
		train_roc: 0.9734, val_roc: 0.9575, train_auprc: 0.9660, val_auprc: 0.9504




Epoch: 284 (111.8938s), train_loss: 0.1955, val_loss: 0.2706, train_acc: 0.9277, val_acc:0.8992
		train_roc: 0.9734, val_roc: 0.9584, train_auprc: 0.9663, val_auprc: 0.9516




Epoch: 285 (110.7761s), train_loss: 0.1952, val_loss: 0.2677, train_acc: 0.9281, val_acc:0.9008
		train_roc: 0.9734, val_roc: 0.9596, train_auprc: 0.9662, val_auprc: 0.9531




Epoch: 286 (110.4826s), train_loss: 0.1967, val_loss: 0.2695, train_acc: 0.9271, val_acc:0.8995
		train_roc: 0.9730, val_roc: 0.9589, train_auprc: 0.9658, val_auprc: 0.9524




Epoch: 287 (110.9738s), train_loss: 0.1948, val_loss: 0.2716, train_acc: 0.9282, val_acc:0.8979
		train_roc: 0.9735, val_roc: 0.9581, train_auprc: 0.9664, val_auprc: 0.9514




Epoch: 288 (111.3449s), train_loss: 0.1960, val_loss: 0.2716, train_acc: 0.9271, val_acc:0.8982
		train_roc: 0.9733, val_roc: 0.9582, train_auprc: 0.9662, val_auprc: 0.9513




Epoch: 289 (112.9665s), train_loss: 0.1947, val_loss: 0.2699, train_acc: 0.9280, val_acc:0.8992
		train_roc: 0.9737, val_roc: 0.9588, train_auprc: 0.9666, val_auprc: 0.9521




Epoch: 290 (111.5208s), train_loss: 0.1946, val_loss: 0.2704, train_acc: 0.9283, val_acc:0.8984
		train_roc: 0.9736, val_roc: 0.9586, train_auprc: 0.9665, val_auprc: 0.9515




Epoch: 291 (111.8795s), train_loss: 0.1965, val_loss: 0.2714, train_acc: 0.9274, val_acc:0.8977
		train_roc: 0.9731, val_roc: 0.9583, train_auprc: 0.9657, val_auprc: 0.9518




Epoch: 292 (111.2406s), train_loss: 0.1952, val_loss: 0.2718, train_acc: 0.9274, val_acc:0.8977
		train_roc: 0.9735, val_roc: 0.9580, train_auprc: 0.9666, val_auprc: 0.9515




Epoch: 293 (110.7149s), train_loss: 0.1958, val_loss: 0.2701, train_acc: 0.9276, val_acc:0.8987
		train_roc: 0.9732, val_roc: 0.9587, train_auprc: 0.9660, val_auprc: 0.9523




Epoch: 294 (110.7271s), train_loss: 0.1953, val_loss: 0.2687, train_acc: 0.9278, val_acc:0.8998
		train_roc: 0.9735, val_roc: 0.9592, train_auprc: 0.9662, val_auprc: 0.9524




Epoch: 295 (112.0490s), train_loss: 0.1960, val_loss: 0.2708, train_acc: 0.9270, val_acc:0.8986
		train_roc: 0.9734, val_roc: 0.9585, train_auprc: 0.9664, val_auprc: 0.9518




Epoch: 296 (110.3402s), train_loss: 0.1938, val_loss: 0.2693, train_acc: 0.9280, val_acc:0.8992
		train_roc: 0.9739, val_roc: 0.9590, train_auprc: 0.9669, val_auprc: 0.9523




Epoch: 297 (110.3298s), train_loss: 0.1965, val_loss: 0.2684, train_acc: 0.9271, val_acc:0.8990
		train_roc: 0.9729, val_roc: 0.9593, train_auprc: 0.9654, val_auprc: 0.9529




Epoch: 298 (111.3804s), train_loss: 0.1980, val_loss: 0.2724, train_acc: 0.9268, val_acc:0.8973
		train_roc: 0.9727, val_roc: 0.9579, train_auprc: 0.9651, val_auprc: 0.9510




Epoch: 299 (111.8457s), train_loss: 0.1952, val_loss: 0.2684, train_acc: 0.9283, val_acc:0.8994
		train_roc: 0.9734, val_roc: 0.9594, train_auprc: 0.9659, val_auprc: 0.9532




Epoch: 300 (110.1253s), train_loss: 0.1959, val_loss: 0.2714, train_acc: 0.9282, val_acc:0.8976
		train_roc: 0.9731, val_roc: 0.9583, train_auprc: 0.9656, val_auprc: 0.9517


In [21]:
# Predict on the test set with the checkpoint stored at `model_roc_file`
# (presumably the one written on each "Saving model roc" line in the training log above,
# i.e. when validation ROC-AUC improved — confirm against the training loop).
#
# - weights_only=False: the checkpoint is a full pickled nn.Module (SSI_DDI), not a
#   state_dict. Newer PyTorch releases default weights_only=True and would refuse to
#   unpickle it. Unpickling can execute arbitrary code, so only load trusted checkpoints.
# - map_location=device: lets a checkpoint saved on CUDA be loaded when `device` is "cpu"
#   (e.g. use_cuda=False or no GPU available).
model = torch.load(model_roc_file, map_location=device, weights_only=False)
print(model)
model.to(device=device)
predict(model, test_data_loader, device)

  model = torch.load(model_roc_file)


SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-5): 6 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block5): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention