In [1]:
# !pip install torch_geometric rdkit torch

In [2]:
from datetime import datetime
import time
import argparse
import sys
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from sklearn import metrics
import pandas as pd
import numpy as np
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)


In [3]:
# Directory configuration
data_dir = "data"
model_dir = "models"
model_name = "case16"

# sys.path.append('/content/drive/MyDrive/Colab Notebooks')

In [4]:
####### Tunning parameters #######

# Number of epochs
n_epochs = 300

# SagPooling ratio & min score. 
# Set sp_ratio to None to disable ratio in SagPooling
sp_ratio = None
sp_min_score = None

# Enable using gpu
use_cuda = True

# Use activation function for CoAttention Layer
use_activation_fn = False

# Use ComplEx instead of RESCAL
use_ComplEx = True

# Use improved CoAttention Layer
use_improved_CoAttention = True

# Use Explicit Valence
use_explicit_valence = False

# Number of GAT layers
num_GAT_layers = 5

# Number of GAT multiheads
num_GAT_multiheads = 2

#################################

In [5]:
# If using explicit valence feature
if use_explicit_valence:
    from data_preprocessing_explicit_valence import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS
else:
    from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [6]:
mode = "train"

n_atom_feats = TOTAL_ATOM_FEATS
# Not use
n_atom_hid = 64
# Total interactions information in the Interaction_information.csv
rel_total = 86
lr = 1e-2
weight_decay = 5e-4
neg_samples = 1
# Represents the number of samples (or graph instances) loaded in each batch during the training process.
batch_size = 1024
data_size_ratio = 1
kge_dim = 64

device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 55


In [7]:
def print_tunning_parameters():
    print()
    print("####### Tunning parameters #######")
    print()
    
    print("n_epochs =", n_epochs)
    print("use_cuda =", use_cuda)
    print()
    print("num_GAT_layers = ", num_GAT_layers)
    print("num_GAT_multiheads = ", num_GAT_multiheads)
    print()
    print("sp_ratio =", sp_ratio)
    print("sp_min_score =", sp_min_score)
    print()
    print("use_explicit_valence =", use_explicit_valence)
    print()
    print("use_activation_fn =", use_activation_fn)
    print()
    print("use_ComplEx =", use_ComplEx)
    print()
    print("use_improved_CoAttention =", use_improved_CoAttention)
    
    print()
    print("#################################")
    print()


In [8]:
class CoAttentionLayer(nn.Module):
    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        self.n_features = n_features
        self.w_q = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.w_k = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.bias = nn.Parameter(torch.zeros(n_features // 2))
        self.a = nn.Parameter(torch.zeros(n_features // 2))
        self.use_activation_fn = use_activation_fn

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q
        # values = receiver @ self.w_v
        values = receiver

        # queries.shape = (1024, 4, 32)
        # keys.shape = (1024, 4, 32)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations) @ self.a
        else:
            e_scores = e_activations @ self.a
        attentions = e_scores

        return attentions

class CoAttentionLayerImproved(nn.Module):
    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        self.n_features = n_features
        self.n_heads = n_heads
        self.head_dim = n_features // n_heads

        # Projects for queries and keys per head
        self.w_q = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.w_k = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.bias = nn.Parameter(torch.zeros(self.n_features // 2))
        self.a = nn.Parameter(torch.zeros(self.n_features // 2))
        self.use_activation_fn = use_activation_fn

        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)

        # Split reciever and attendant into multiple heads
        batch_size, gat_size, n_features = receiver.shape
        receiver = receiver.view(batch_size, gat_size, self.n_heads, self.head_dim)
        attendant = attendant.view(batch_size, gat_size, self.n_heads, self.head_dim)
        
        # Compute keys and queries per head
        # receiver.shape  = (1024, 4, 2, 32)
        # attendant.shape = (1024, 4, 2, 32)
        
        # self.w_k.shape  = (32, 16)
        # self.w_q.shape  = (32, 16)
        
        # self.keys.shape     = (1024, 4, 2, 16)
        # self.queries.shape  = (1024, 4, 2, 16)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q

        # self.keys.shape     = (1024, 4, 32)
        # self.queries.shape  = (1024, 4, 32)
        keys    = keys.view(batch_size, gat_size, self.head_dim)
        queries = queries.view(batch_size, gat_size, self.head_dim)
        # print("keys.shape", keys.shape)
        # print("queries.shape", queries.shape)

        # e_activations.shape = (1024, 4, 4, 32)
        # self.a.shape = (32,)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations) @ self.a
        else:
            e_scores = e_activations @ self.a

        # attentions.shape = (1024, 4, 4)
        attentions = e_scores

        return attentions


class RESCAL(nn.Module):
    def __init__(self, n_rels, n_features):
        """
        n_rels: number of relations = 86
        n_features: kge_dim = 64
        """
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # Embedding layer
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        #  Initializes the embedding weights with the Xavier uniform distribution, which helps maintain the scale of gradients during training
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores):
        rels = self.rel_emb(rels)
        rels = F.normalize(rels, dim=-1)
        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)
        # print(rels.shape)
        # Convert shape (1024, 4096) to (1024, 64, 64) for dot product
        rels = rels.view(-1, self.n_features, self.n_features)
        # print(rels.shape)
        # (1024, 4, 64) @ (1024, 64, 64) = (1024, 4, 64) @ (1024, 64, 4) = (1024, 4, 4)
        scores = heads @ rels @ tails.transpose(-2, -1)

        # alpha_scores.shape = (1024, 4, 4)
        # scores.shape = (1024, 4, 4)
        if alpha_scores is not None:
            scores = alpha_scores * scores
        # print(scores.shape)
        
        # sum the last 2 dimensions
        scores = scores.sum(dim=(-2, -1))
        
        # print(scores.shape)
        # Shape(1024,)
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
    
        # Relation embeddings are also complex
        self.rel_real = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        self.rel_imag = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        
        # Initialize embeddings
        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        # Preprocess
        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)
        
        r_real, r_imag = self.rel_real(rels), self.rel_imag(rels)
        r_real = F.normalize(r_real, dim=-1)
        r_imag = F.normalize(r_imag, dim=-1)
        # print(r_real.shape)
        r_real = r_real.view(-1, self.n_features // 2, self.n_features // 2)
        r_imag = r_imag.view(-1, self.n_features // 2, self.n_features // 2)
        # print(r_real.shape)
        # Split heads and tails to imaginary parts
        h_real, h_imag = heads[..., :self.n_features // 2], heads[..., self.n_features // 2:]
        t_real, t_imag = tails[..., :self.n_features // 2], heads[..., self.n_features // 2:]

        # ComplEx scoring functionn
        first_part_score = h_real @ r_real @ t_real.transpose(-2, -1)
        second_part_score = h_real @ r_imag @ t_imag.transpose(-2, -1)
        third_part_score = h_imag @ r_real @ t_imag.transpose(-2, -1)
        fourth_part_score = h_imag @ r_imag @ t_real.transpose(-2, -1)

        scores = first_part_score + second_part_score + third_part_score + fourth_part_score
        
        # If alpha_scores is provided, apply it
        if alpha_scores is not None:
            scores = alpha_scores * scores

        scores = scores.sum(dim=(-2, -1))
        
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


In [9]:
class SSI_DDI(nn.Module):
    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        use_improved_CoAttention,
    ):
        """
        blocks_params: list of number layers for multi-head attentions
        """
        super().__init__()
        self.in_features = in_features
        # not using this one
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # Layer normalization list
        self.net_norms = ModuleList()
        for i, (head_out_feats, n_heads) in enumerate(
            zip(heads_out_feat_params, blocks_params)
        ):
            block = SSI_DDI_Block(
                n_heads, in_features, head_out_feats, final_out_feats=self.hidd_dim, sp_ratio=sp_ratio, sp_min_score=sp_min_score
            )
            self.add_module(f"block{i}", block)
            self.blocks.append(block)
            self.net_norms.append(LayerNorm(head_out_feats * n_heads))
            in_features = head_out_feats * n_heads

        if use_improved_CoAttention:
            self.co_attention = CoAttentionLayerImproved(self.kge_dim, self.use_activation_fn)
        else:
            self.co_attention = CoAttentionLayer(self.kge_dim, self.use_activation_fn)
            
        if self.use_ComplEx:
            self.KGE = ComplEx(self.rel_total, self.kge_dim)
        else:
            self.KGE = RESCAL(self.rel_total, self.kge_dim)

    def forward(self, triples):
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        repr_h = []
        repr_t = []

        for i, block in enumerate(self.blocks):
            out1, out2 = block(h_data), block(t_data)

            h_data = out1[0]
            t_data = out2[0]
            r_h = out1[1]
            r_t = out2[1]

            repr_h.append(r_h)
            repr_t.append(r_t)

            h_data.x = F.elu(self.net_norms[i](h_data.x, h_data.batch))
            t_data.x = F.elu(self.net_norms[i](t_data.x, t_data.batch))

        repr_h = torch.stack(repr_h, dim=-2)
        repr_t = torch.stack(repr_t, dim=-2)

        kge_heads = repr_h
        kge_tails = repr_t

        attentions = self.co_attention(kge_heads, kge_tails)
        # attentions = None
        scores = self.KGE(kge_heads, kge_tails, rels, attentions)

        return scores


class SSI_DDI_Block(nn.Module):
    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        """
        n_heades: number of multi-head attentions = 2
        in_features: number of features = 55 . For explicit valence use, number of features = 56.
        head_out_feats: number of out features. For 4 layers: [32, 32, 32, 32]
        sp_ratio: SAGPooling ratio
        """
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)
        # SAGPooling: Ranks nodes based on self-attention scores

        if sp_ratio is None and sp_min_score is None:
            self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)
        else:
            if sp_ratio is not None:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score, ratio=sp_ratio)
            else:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score)

    def forward(self, data):
        data.x = self.conv(data.x, data.edge_index)
        # Call SAGPooling here
        # If min_score = -1 so nodes will not be filtered out, basically redudant for using the SAGPooling.
        att_x, att_edge_index, att_edge_attr, att_batch, att_perm, att_scores = (
            self.readout(data.x, data.edge_index, batch=data.batch)
        )
        # Aggregates the pooled node features (att_x) across the graph to obtain a global representation
        global_graph_emb = global_add_pool(att_x, att_batch)

        # data = max_pool_neighbor_x(data)
        return data, global_graph_emb


In [10]:
class SigmoidLoss(nn.Module):
    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        if self.adv_temperature:
            weights = F.softmax(self.adv_temperature * n_scores, dim=-1).detach()
            n_scores = weights * n_scores
        p_loss = -F.logsigmoid(p_scores).mean()
        n_loss = -F.logsigmoid(-n_scores).mean()

        return (p_loss + n_loss) / 2, p_loss, n_loss


In [11]:
df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")


train_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_train["d1"], df_ddi_train["d2"], df_ddi_train["type"])
]
val_tup = [
    (h, t, r) for h, t, r in zip(df_ddi_val["d1"], df_ddi_val["d2"], df_ddi_val["type"])
]
test_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_test["d1"], df_ddi_test["d2"], df_ddi_test["type"])
]

train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size * 3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size * 3)


Training with 115185 samples, validating with 38348, and testing with 38337


In [12]:
def do_compute(model, batch, device, training=True):
    """
    *batch: (pos_tri, neg_tri)
    *pos/neg_tri: (batch_h, batch_t, batch_r)
    """
    probas_pred, ground_truth = [], []
    pos_tri, neg_tri = batch

    pos_tri = [tensor.to(device=device) for tensor in pos_tri]
    p_score = model(pos_tri)
    probas_pred.append(torch.sigmoid(p_score.detach()).cpu())
    ground_truth.append(np.ones(len(p_score)))

    neg_tri = [tensor.to(device=device) for tensor in neg_tri]
    n_score = model(neg_tri)
    probas_pred.append(torch.sigmoid(n_score.detach()).cpu())
    ground_truth.append(np.zeros(len(n_score)))

    probas_pred = np.concatenate(probas_pred)
    ground_truth = np.concatenate(ground_truth)

    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target):

    pred = (probas_pred >= 0.5).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)

    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(r, p)

    return acc, auc_roc, auc_prc

In [13]:
import csv
def export_metrics(train_metrics, val_metrics, epoch):
    train_metrics_dir = "train_metrics"
    metrics_file = f"{train_metrics_dir}/{model_name}.csv"
    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    data = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]
    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]
    
    if epoch == 1:
        with open(metrics_file, 'w', newline='') as file:
            writer = csv.writer(file)
            # Write the header
            writer.writerow(header)
            # Write the data rows
            writer.writerow(data)
    else:
        with open(metrics_file, 'a', newline='') as file:
            writer = csv.writer(file)
            # Write the data to the file
            writer.writerow(data)
    
    

In [14]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
):
    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()
    best_val_auc_prc = 0
    for i in range(1, n_epochs + 1):
        start = time.time()
        train_loss = 0
        train_loss_pos = 0
        train_loss_neg = 0
        val_loss = 0
        val_loss_pos = 0
        val_loss_neg = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []

        for batch in train_data_loader:
            # print(len(batch))
            model.train()
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, loss_p, loss_n = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(p_score)
        train_loss /= len(train_data)

        with torch.no_grad():
            train_probas_pred = np.concatenate(train_probas_pred)
            train_ground_truth = np.concatenate(train_ground_truth)

            train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
                train_probas_pred, train_ground_truth
            )

            for batch in val_data_loader:
                model.eval()
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, loss_p, loss_n = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)

            val_loss /= len(val_data)
            val_probas_pred = np.concatenate(val_probas_pred)
            val_ground_truth = np.concatenate(val_ground_truth)
            val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
                val_probas_pred, val_ground_truth
            )
            
            # Save model if this is best result based on val_auc_prc
            if best_val_auc_prc < val_auc_prc:
                print("Saving model")
                best_val_auc_prc = val_auc_prc
                torch.save(model, model_file)

        if scheduler:
            # print('scheduling')
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)
        
        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model

In [15]:
def predict(model, test_data_loader, device):
    print('Starting predicting at', datetime.today())
    print('Device', device)

    test_probas_pred = []
    test_ground_truth = []

    # Switch to evaluation mode
    model.eval()

    with torch.no_grad():  # No need to calculate gradients during testing
        for batch in test_data_loader:
            # Get predictions and ground truth for the batch
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device, training=False)

            # Append the predictions and ground truths
            test_probas_pred.append(probas_pred)
            test_ground_truth.append(ground_truth)


    # Concatenate the results for the entire test dataset
    test_probas_pred = np.concatenate(test_probas_pred)
    test_ground_truth = np.concatenate(test_ground_truth)

    # Calculate the metrics for the test dataset
    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(test_probas_pred, test_ground_truth)

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

In [16]:
model_file = f"{model_dir}/{model_name}.pth"

heads_out_feat_params = []
block_params = []

for _ in range(num_GAT_layers):
    heads_out_feat_params.append(kge_dim // 2)
    block_params.append(num_GAT_multiheads)

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        use_improved_CoAttention=use_improved_CoAttention,
    )
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)
    model.to(device=device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-4): 5 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayerImproved(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (KGE): ComplEx(86, torch.Size([86, 1024]), torch.Size([86, 

In [17]:
if mode == "train":
  # Train
  train(
      model,
      train_data_loader,
      val_data_loader,
      loss,
      optimizer,
      n_epochs,
      device,
      scheduler,
  )


Starting training at: 2024-10-24 01:33:13.869661
Device: cuda

####### Tunning parameters #######

n_epochs = 300
use_cuda = True

num_GAT_layers =  5
num_GAT_multiheads =  2

sp_ratio = None
sp_min_score = None

use_explicit_valence = False

use_activation_fn = False

use_ComplEx = True

use_improved_CoAttention = True

#################################





Saving model
Epoch: 1 (78.1022s), train_loss: 0.6593, val_loss: 0.6119, train_acc: 0.5939, val_acc:0.6567
		train_roc: 0.6389, val_roc: 0.7197, train_auprc: 0.6253, val_auprc: 0.7007




Saving model
Epoch: 2 (72.1949s), train_loss: 0.5899, val_loss: 0.5706, train_acc: 0.6782, val_acc:0.6971
		train_roc: 0.7466, val_roc: 0.7681, train_auprc: 0.7264, val_auprc: 0.7482




Saving model
Epoch: 3 (71.8835s), train_loss: 0.5546, val_loss: 0.5386, train_acc: 0.7097, val_acc:0.7260
		train_roc: 0.7837, val_roc: 0.8010, train_auprc: 0.7628, val_auprc: 0.7785




Saving model
Epoch: 4 (72.0719s), train_loss: 0.5289, val_loss: 0.5221, train_acc: 0.7329, val_acc:0.7414
		train_roc: 0.8088, val_roc: 0.8179, train_auprc: 0.7869, val_auprc: 0.7972




Saving model
Epoch: 5 (72.3282s), train_loss: 0.5132, val_loss: 0.5048, train_acc: 0.7456, val_acc:0.7517
		train_roc: 0.8224, val_roc: 0.8287, train_auprc: 0.8012, val_auprc: 0.8077




Saving model
Epoch: 6 (72.9791s), train_loss: 0.4971, val_loss: 0.4940, train_acc: 0.7570, val_acc:0.7560
		train_roc: 0.8348, val_roc: 0.8367, train_auprc: 0.8132, val_auprc: 0.8187




Saving model
Epoch: 7 (71.9338s), train_loss: 0.4877, val_loss: 0.4828, train_acc: 0.7647, val_acc:0.7683
		train_roc: 0.8424, val_roc: 0.8462, train_auprc: 0.8212, val_auprc: 0.8229




Saving model
Epoch: 8 (72.0332s), train_loss: 0.4759, val_loss: 0.4735, train_acc: 0.7726, val_acc:0.7748
		train_roc: 0.8508, val_roc: 0.8547, train_auprc: 0.8292, val_auprc: 0.8347




Saving model
Epoch: 9 (71.9238s), train_loss: 0.4673, val_loss: 0.4648, train_acc: 0.7792, val_acc:0.7783
		train_roc: 0.8572, val_roc: 0.8592, train_auprc: 0.8355, val_auprc: 0.8399




Saving model
Epoch: 10 (71.8110s), train_loss: 0.4569, val_loss: 0.4591, train_acc: 0.7864, val_acc:0.7853
		train_roc: 0.8641, val_roc: 0.8627, train_auprc: 0.8432, val_auprc: 0.8417




Saving model
Epoch: 11 (72.2196s), train_loss: 0.4512, val_loss: 0.4469, train_acc: 0.7906, val_acc:0.7941
		train_roc: 0.8678, val_roc: 0.8708, train_auprc: 0.8473, val_auprc: 0.8518




Saving model
Epoch: 12 (72.1230s), train_loss: 0.4429, val_loss: 0.4415, train_acc: 0.7961, val_acc:0.7961
		train_roc: 0.8730, val_roc: 0.8748, train_auprc: 0.8521, val_auprc: 0.8547




Saving model
Epoch: 13 (72.1554s), train_loss: 0.4380, val_loss: 0.4414, train_acc: 0.7995, val_acc:0.7984
		train_roc: 0.8763, val_roc: 0.8747, train_auprc: 0.8552, val_auprc: 0.8550




Saving model
Epoch: 14 (72.4029s), train_loss: 0.4305, val_loss: 0.4288, train_acc: 0.8049, val_acc:0.8047
		train_roc: 0.8807, val_roc: 0.8812, train_auprc: 0.8605, val_auprc: 0.8617




Saving model
Epoch: 15 (72.0613s), train_loss: 0.4262, val_loss: 0.4303, train_acc: 0.8064, val_acc:0.7993
		train_roc: 0.8831, val_roc: 0.8826, train_auprc: 0.8623, val_auprc: 0.8641




Saving model
Epoch: 16 (72.2337s), train_loss: 0.4199, val_loss: 0.4208, train_acc: 0.8107, val_acc:0.8092
		train_roc: 0.8866, val_roc: 0.8860, train_auprc: 0.8663, val_auprc: 0.8670




Saving model
Epoch: 17 (72.0580s), train_loss: 0.4147, val_loss: 0.4109, train_acc: 0.8138, val_acc:0.8176
		train_roc: 0.8900, val_roc: 0.8926, train_auprc: 0.8699, val_auprc: 0.8738




Epoch: 18 (72.0125s), train_loss: 0.4105, val_loss: 0.4154, train_acc: 0.8171, val_acc:0.8152
		train_roc: 0.8923, val_roc: 0.8905, train_auprc: 0.8723, val_auprc: 0.8710




Epoch: 19 (71.8936s), train_loss: 0.4044, val_loss: 0.4134, train_acc: 0.8204, val_acc:0.8149
		train_roc: 0.8954, val_roc: 0.8916, train_auprc: 0.8761, val_auprc: 0.8716




Saving model
Epoch: 20 (71.9487s), train_loss: 0.4005, val_loss: 0.4065, train_acc: 0.8232, val_acc:0.8178
		train_roc: 0.8974, val_roc: 0.8954, train_auprc: 0.8773, val_auprc: 0.8766




Saving model
Epoch: 21 (71.9179s), train_loss: 0.3968, val_loss: 0.4032, train_acc: 0.8259, val_acc:0.8226
		train_roc: 0.8997, val_roc: 0.8966, train_auprc: 0.8804, val_auprc: 0.8788




Saving model
Epoch: 22 (72.0608s), train_loss: 0.3885, val_loss: 0.3927, train_acc: 0.8298, val_acc:0.8237
		train_roc: 0.9036, val_roc: 0.9014, train_auprc: 0.8855, val_auprc: 0.8844




Saving model
Epoch: 23 (71.8269s), train_loss: 0.3887, val_loss: 0.3927, train_acc: 0.8292, val_acc:0.8294
		train_roc: 0.9033, val_roc: 0.9023, train_auprc: 0.8847, val_auprc: 0.8845




Saving model
Epoch: 24 (71.9676s), train_loss: 0.3816, val_loss: 0.3872, train_acc: 0.8341, val_acc:0.8317
		train_roc: 0.9072, val_roc: 0.9045, train_auprc: 0.8891, val_auprc: 0.8870




Saving model
Epoch: 25 (71.9899s), train_loss: 0.3785, val_loss: 0.3798, train_acc: 0.8359, val_acc:0.8355
		train_roc: 0.9086, val_roc: 0.9080, train_auprc: 0.8903, val_auprc: 0.8913




Epoch: 26 (71.8876s), train_loss: 0.3748, val_loss: 0.3835, train_acc: 0.8376, val_acc:0.8317
		train_roc: 0.9105, val_roc: 0.9064, train_auprc: 0.8929, val_auprc: 0.8903




Epoch: 27 (72.2086s), train_loss: 0.3700, val_loss: 0.3822, train_acc: 0.8406, val_acc:0.8340
		train_roc: 0.9125, val_roc: 0.9075, train_auprc: 0.8952, val_auprc: 0.8904




Epoch: 28 (71.9938s), train_loss: 0.3673, val_loss: 0.3811, train_acc: 0.8417, val_acc:0.8340
		train_roc: 0.9141, val_roc: 0.9069, train_auprc: 0.8968, val_auprc: 0.8893




Saving model
Epoch: 29 (72.1582s), train_loss: 0.3621, val_loss: 0.3680, train_acc: 0.8450, val_acc:0.8418
		train_roc: 0.9160, val_roc: 0.9138, train_auprc: 0.8992, val_auprc: 0.8977




Saving model
Epoch: 30 (71.9811s), train_loss: 0.3594, val_loss: 0.3657, train_acc: 0.8473, val_acc:0.8420
		train_roc: 0.9176, val_roc: 0.9149, train_auprc: 0.9001, val_auprc: 0.8983




Epoch: 31 (71.5377s), train_loss: 0.3567, val_loss: 0.3693, train_acc: 0.8489, val_acc:0.8438
		train_roc: 0.9189, val_roc: 0.9144, train_auprc: 0.9017, val_auprc: 0.8975




Epoch: 32 (71.5560s), train_loss: 0.3554, val_loss: 0.3712, train_acc: 0.8488, val_acc:0.8408
		train_roc: 0.9190, val_roc: 0.9123, train_auprc: 0.9018, val_auprc: 0.8943




Saving model
Epoch: 33 (71.7892s), train_loss: 0.3480, val_loss: 0.3574, train_acc: 0.8528, val_acc:0.8486
		train_roc: 0.9225, val_roc: 0.9185, train_auprc: 0.9060, val_auprc: 0.9033




Saving model
Epoch: 34 (71.6823s), train_loss: 0.3457, val_loss: 0.3546, train_acc: 0.8538, val_acc:0.8498
		train_roc: 0.9234, val_roc: 0.9199, train_auprc: 0.9069, val_auprc: 0.9051




Epoch: 35 (71.6114s), train_loss: 0.3427, val_loss: 0.3594, train_acc: 0.8552, val_acc:0.8475
		train_roc: 0.9248, val_roc: 0.9181, train_auprc: 0.9087, val_auprc: 0.9011




Saving model
Epoch: 36 (71.6840s), train_loss: 0.3412, val_loss: 0.3530, train_acc: 0.8566, val_acc:0.8501
		train_roc: 0.9254, val_roc: 0.9213, train_auprc: 0.9092, val_auprc: 0.9068




Epoch: 37 (71.4865s), train_loss: 0.3394, val_loss: 0.3545, train_acc: 0.8569, val_acc:0.8505
		train_roc: 0.9261, val_roc: 0.9204, train_auprc: 0.9099, val_auprc: 0.9045




Saving model
Epoch: 38 (71.4720s), train_loss: 0.3346, val_loss: 0.3470, train_acc: 0.8598, val_acc:0.8544
		train_roc: 0.9282, val_roc: 0.9235, train_auprc: 0.9128, val_auprc: 0.9088




Saving model
Epoch: 39 (71.6123s), train_loss: 0.3329, val_loss: 0.3429, train_acc: 0.8605, val_acc:0.8558
		train_roc: 0.9287, val_roc: 0.9249, train_auprc: 0.9126, val_auprc: 0.9098




Saving model
Epoch: 40 (71.3268s), train_loss: 0.3303, val_loss: 0.3412, train_acc: 0.8619, val_acc:0.8570
		train_roc: 0.9298, val_roc: 0.9263, train_auprc: 0.9140, val_auprc: 0.9120




Epoch: 41 (71.4535s), train_loss: 0.3270, val_loss: 0.3455, train_acc: 0.8632, val_acc:0.8571
		train_roc: 0.9310, val_roc: 0.9248, train_auprc: 0.9155, val_auprc: 0.9096




Epoch: 42 (71.4477s), train_loss: 0.3250, val_loss: 0.3405, train_acc: 0.8653, val_acc:0.8576
		train_roc: 0.9321, val_roc: 0.9268, train_auprc: 0.9167, val_auprc: 0.9119




Saving model
Epoch: 43 (71.7059s), train_loss: 0.3216, val_loss: 0.3370, train_acc: 0.8665, val_acc:0.8591
		train_roc: 0.9334, val_roc: 0.9276, train_auprc: 0.9179, val_auprc: 0.9130




Epoch: 44 (71.3214s), train_loss: 0.3214, val_loss: 0.3390, train_acc: 0.8676, val_acc:0.8583
		train_roc: 0.9334, val_roc: 0.9263, train_auprc: 0.9179, val_auprc: 0.9108




Saving model
Epoch: 45 (71.4018s), train_loss: 0.3166, val_loss: 0.3343, train_acc: 0.8688, val_acc:0.8617
		train_roc: 0.9352, val_roc: 0.9283, train_auprc: 0.9211, val_auprc: 0.9138




Saving model
Epoch: 46 (71.6628s), train_loss: 0.3134, val_loss: 0.3358, train_acc: 0.8709, val_acc:0.8602
		train_roc: 0.9363, val_roc: 0.9284, train_auprc: 0.9215, val_auprc: 0.9145




Saving model
Epoch: 47 (71.5528s), train_loss: 0.3123, val_loss: 0.3274, train_acc: 0.8725, val_acc:0.8641
		train_roc: 0.9370, val_roc: 0.9320, train_auprc: 0.9225, val_auprc: 0.9186




Epoch: 48 (71.4543s), train_loss: 0.3113, val_loss: 0.3283, train_acc: 0.8722, val_acc:0.8633
		train_roc: 0.9372, val_roc: 0.9318, train_auprc: 0.9225, val_auprc: 0.9185




Saving model
Epoch: 49 (71.2853s), train_loss: 0.3079, val_loss: 0.3251, train_acc: 0.8742, val_acc:0.8663
		train_roc: 0.9382, val_roc: 0.9328, train_auprc: 0.9237, val_auprc: 0.9187




Saving model
Epoch: 50 (71.2841s), train_loss: 0.3060, val_loss: 0.3221, train_acc: 0.8744, val_acc:0.8658
		train_roc: 0.9390, val_roc: 0.9340, train_auprc: 0.9252, val_auprc: 0.9213




Epoch: 51 (71.6343s), train_loss: 0.3030, val_loss: 0.3233, train_acc: 0.8763, val_acc:0.8673
		train_roc: 0.9403, val_roc: 0.9332, train_auprc: 0.9263, val_auprc: 0.9195




Epoch: 52 (71.2967s), train_loss: 0.3029, val_loss: 0.3233, train_acc: 0.8769, val_acc:0.8685
		train_roc: 0.9405, val_roc: 0.9342, train_auprc: 0.9267, val_auprc: 0.9208




Epoch: 53 (71.2954s), train_loss: 0.3005, val_loss: 0.3233, train_acc: 0.8780, val_acc:0.8655
		train_roc: 0.9411, val_roc: 0.9335, train_auprc: 0.9273, val_auprc: 0.9201




Saving model
Epoch: 54 (71.4368s), train_loss: 0.2957, val_loss: 0.3195, train_acc: 0.8798, val_acc:0.8687
		train_roc: 0.9429, val_roc: 0.9350, train_auprc: 0.9293, val_auprc: 0.9218




Saving model
Epoch: 55 (71.5702s), train_loss: 0.2947, val_loss: 0.3197, train_acc: 0.8811, val_acc:0.8689
		train_roc: 0.9435, val_roc: 0.9350, train_auprc: 0.9301, val_auprc: 0.9219




Saving model
Epoch: 56 (71.5064s), train_loss: 0.2936, val_loss: 0.3160, train_acc: 0.8809, val_acc:0.8710
		train_roc: 0.9436, val_roc: 0.9368, train_auprc: 0.9301, val_auprc: 0.9251




Epoch: 57 (71.2820s), train_loss: 0.2932, val_loss: 0.3144, train_acc: 0.8822, val_acc:0.8724
		train_roc: 0.9439, val_roc: 0.9373, train_auprc: 0.9303, val_auprc: 0.9241




Epoch: 58 (71.1983s), train_loss: 0.2906, val_loss: 0.3144, train_acc: 0.8819, val_acc:0.8717
		train_roc: 0.9449, val_roc: 0.9372, train_auprc: 0.9318, val_auprc: 0.9250




Saving model
Epoch: 59 (71.7432s), train_loss: 0.2891, val_loss: 0.3141, train_acc: 0.8829, val_acc:0.8723
		train_roc: 0.9451, val_roc: 0.9378, train_auprc: 0.9321, val_auprc: 0.9255




Epoch: 60 (71.5868s), train_loss: 0.2895, val_loss: 0.3117, train_acc: 0.8833, val_acc:0.8738
		train_roc: 0.9454, val_roc: 0.9379, train_auprc: 0.9323, val_auprc: 0.9255




Saving model
Epoch: 61 (71.4241s), train_loss: 0.2848, val_loss: 0.3122, train_acc: 0.8853, val_acc:0.8729
		train_roc: 0.9468, val_roc: 0.9385, train_auprc: 0.9342, val_auprc: 0.9260




Saving model
Epoch: 62 (71.4119s), train_loss: 0.2845, val_loss: 0.3101, train_acc: 0.8859, val_acc:0.8758
		train_roc: 0.9471, val_roc: 0.9392, train_auprc: 0.9343, val_auprc: 0.9264




Epoch: 63 (71.6936s), train_loss: 0.2820, val_loss: 0.3142, train_acc: 0.8865, val_acc:0.8710
		train_roc: 0.9478, val_roc: 0.9380, train_auprc: 0.9352, val_auprc: 0.9259




Saving model
Epoch: 64 (71.3896s), train_loss: 0.2811, val_loss: 0.3094, train_acc: 0.8871, val_acc:0.8747
		train_roc: 0.9480, val_roc: 0.9391, train_auprc: 0.9349, val_auprc: 0.9266




Saving model
Epoch: 65 (71.5081s), train_loss: 0.2801, val_loss: 0.3077, train_acc: 0.8883, val_acc:0.8750
		train_roc: 0.9486, val_roc: 0.9396, train_auprc: 0.9360, val_auprc: 0.9275




Saving model
Epoch: 66 (71.3743s), train_loss: 0.2786, val_loss: 0.3083, train_acc: 0.8882, val_acc:0.8752
		train_roc: 0.9488, val_roc: 0.9401, train_auprc: 0.9362, val_auprc: 0.9281




Epoch: 67 (71.6196s), train_loss: 0.2769, val_loss: 0.3101, train_acc: 0.8892, val_acc:0.8747
		train_roc: 0.9493, val_roc: 0.9391, train_auprc: 0.9372, val_auprc: 0.9268




Saving model
Epoch: 68 (71.3268s), train_loss: 0.2770, val_loss: 0.3009, train_acc: 0.8891, val_acc:0.8789
		train_roc: 0.9493, val_roc: 0.9426, train_auprc: 0.9367, val_auprc: 0.9313




Saving model
Epoch: 69 (71.3546s), train_loss: 0.2741, val_loss: 0.3011, train_acc: 0.8898, val_acc:0.8790
		train_roc: 0.9504, val_roc: 0.9429, train_auprc: 0.9384, val_auprc: 0.9318




Epoch: 70 (71.7627s), train_loss: 0.2744, val_loss: 0.3050, train_acc: 0.8908, val_acc:0.8756
		train_roc: 0.9501, val_roc: 0.9412, train_auprc: 0.9376, val_auprc: 0.9305




Epoch: 71 (71.3640s), train_loss: 0.2725, val_loss: 0.3014, train_acc: 0.8914, val_acc:0.8784
		train_roc: 0.9509, val_roc: 0.9426, train_auprc: 0.9389, val_auprc: 0.9316




Epoch: 72 (71.4125s), train_loss: 0.2718, val_loss: 0.3032, train_acc: 0.8915, val_acc:0.8790
		train_roc: 0.9510, val_roc: 0.9415, train_auprc: 0.9389, val_auprc: 0.9300




Saving model
Epoch: 73 (71.2561s), train_loss: 0.2733, val_loss: 0.3005, train_acc: 0.8914, val_acc:0.8787
		train_roc: 0.9505, val_roc: 0.9433, train_auprc: 0.9379, val_auprc: 0.9331




Epoch: 74 (71.8040s), train_loss: 0.2697, val_loss: 0.3020, train_acc: 0.8922, val_acc:0.8784
		train_roc: 0.9515, val_roc: 0.9423, train_auprc: 0.9398, val_auprc: 0.9312




Epoch: 75 (71.5035s), train_loss: 0.2678, val_loss: 0.3015, train_acc: 0.8928, val_acc:0.8801
		train_roc: 0.9525, val_roc: 0.9425, train_auprc: 0.9406, val_auprc: 0.9308




Epoch: 76 (71.1310s), train_loss: 0.2687, val_loss: 0.3015, train_acc: 0.8928, val_acc:0.8796
		train_roc: 0.9518, val_roc: 0.9426, train_auprc: 0.9399, val_auprc: 0.9315




Epoch: 77 (71.4123s), train_loss: 0.2661, val_loss: 0.3009, train_acc: 0.8934, val_acc:0.8797
		train_roc: 0.9529, val_roc: 0.9429, train_auprc: 0.9418, val_auprc: 0.9321




Epoch: 78 (71.3067s), train_loss: 0.2640, val_loss: 0.2984, train_acc: 0.8950, val_acc:0.8815
		train_roc: 0.9535, val_roc: 0.9436, train_auprc: 0.9420, val_auprc: 0.9325




Saving model
Epoch: 79 (71.2966s), train_loss: 0.2631, val_loss: 0.2994, train_acc: 0.8952, val_acc:0.8795
		train_roc: 0.9538, val_roc: 0.9439, train_auprc: 0.9426, val_auprc: 0.9331




Epoch: 80 (71.4866s), train_loss: 0.2635, val_loss: 0.2989, train_acc: 0.8957, val_acc:0.8801
		train_roc: 0.9535, val_roc: 0.9435, train_auprc: 0.9414, val_auprc: 0.9327




Saving model
Epoch: 81 (71.5584s), train_loss: 0.2635, val_loss: 0.2941, train_acc: 0.8950, val_acc:0.8840
		train_roc: 0.9537, val_roc: 0.9456, train_auprc: 0.9423, val_auprc: 0.9357




Epoch: 82 (71.3352s), train_loss: 0.2625, val_loss: 0.2987, train_acc: 0.8953, val_acc:0.8816
		train_roc: 0.9540, val_roc: 0.9440, train_auprc: 0.9426, val_auprc: 0.9331




Epoch: 83 (71.4096s), train_loss: 0.2615, val_loss: 0.2947, train_acc: 0.8968, val_acc:0.8820
		train_roc: 0.9543, val_roc: 0.9456, train_auprc: 0.9428, val_auprc: 0.9352




Epoch: 84 (71.4664s), train_loss: 0.2603, val_loss: 0.3008, train_acc: 0.8965, val_acc:0.8808
		train_roc: 0.9545, val_roc: 0.9437, train_auprc: 0.9437, val_auprc: 0.9321




Epoch: 85 (71.2155s), train_loss: 0.2599, val_loss: 0.2968, train_acc: 0.8973, val_acc:0.8816
		train_roc: 0.9547, val_roc: 0.9448, train_auprc: 0.9437, val_auprc: 0.9342




Epoch: 86 (71.4535s), train_loss: 0.2589, val_loss: 0.2972, train_acc: 0.8973, val_acc:0.8821
		train_roc: 0.9552, val_roc: 0.9449, train_auprc: 0.9440, val_auprc: 0.9335




Epoch: 87 (71.3406s), train_loss: 0.2582, val_loss: 0.2945, train_acc: 0.8974, val_acc:0.8838
		train_roc: 0.9553, val_roc: 0.9460, train_auprc: 0.9444, val_auprc: 0.9351




Epoch: 88 (71.1875s), train_loss: 0.2572, val_loss: 0.2952, train_acc: 0.8982, val_acc:0.8834
		train_roc: 0.9557, val_roc: 0.9454, train_auprc: 0.9447, val_auprc: 0.9342




Saving model
Epoch: 89 (71.4572s), train_loss: 0.2572, val_loss: 0.2881, train_acc: 0.8987, val_acc:0.8856
		train_roc: 0.9554, val_roc: 0.9480, train_auprc: 0.9443, val_auprc: 0.9389




Epoch: 90 (71.2541s), train_loss: 0.2547, val_loss: 0.2927, train_acc: 0.8995, val_acc:0.8831
		train_roc: 0.9563, val_roc: 0.9464, train_auprc: 0.9455, val_auprc: 0.9365




Epoch: 91 (71.2636s), train_loss: 0.2569, val_loss: 0.2924, train_acc: 0.8987, val_acc:0.8839
		train_roc: 0.9554, val_roc: 0.9470, train_auprc: 0.9442, val_auprc: 0.9371




Epoch: 92 (71.3216s), train_loss: 0.2554, val_loss: 0.2893, train_acc: 0.8996, val_acc:0.8869
		train_roc: 0.9561, val_roc: 0.9481, train_auprc: 0.9450, val_auprc: 0.9382




Epoch: 93 (71.4213s), train_loss: 0.2521, val_loss: 0.2939, train_acc: 0.9007, val_acc:0.8847
		train_roc: 0.9572, val_roc: 0.9458, train_auprc: 0.9466, val_auprc: 0.9350




Epoch: 94 (71.1294s), train_loss: 0.2518, val_loss: 0.2927, train_acc: 0.9015, val_acc:0.8842
		train_roc: 0.9571, val_roc: 0.9462, train_auprc: 0.9459, val_auprc: 0.9362




Epoch: 95 (71.4045s), train_loss: 0.2526, val_loss: 0.2931, train_acc: 0.9005, val_acc:0.8856
		train_roc: 0.9570, val_roc: 0.9463, train_auprc: 0.9462, val_auprc: 0.9356




Epoch: 96 (71.5585s), train_loss: 0.2531, val_loss: 0.2948, train_acc: 0.9001, val_acc:0.8833
		train_roc: 0.9566, val_roc: 0.9462, train_auprc: 0.9460, val_auprc: 0.9361




Epoch: 97 (71.4953s), train_loss: 0.2503, val_loss: 0.2916, train_acc: 0.9015, val_acc:0.8860
		train_roc: 0.9578, val_roc: 0.9471, train_auprc: 0.9475, val_auprc: 0.9371




Epoch: 98 (71.3500s), train_loss: 0.2517, val_loss: 0.2938, train_acc: 0.9007, val_acc:0.8847
		train_roc: 0.9571, val_roc: 0.9460, train_auprc: 0.9461, val_auprc: 0.9352




Epoch: 99 (71.4773s), train_loss: 0.2504, val_loss: 0.2918, train_acc: 0.9015, val_acc:0.8848
		train_roc: 0.9576, val_roc: 0.9472, train_auprc: 0.9468, val_auprc: 0.9376




Epoch: 100 (71.5294s), train_loss: 0.2507, val_loss: 0.2940, train_acc: 0.9018, val_acc:0.8840
		train_roc: 0.9574, val_roc: 0.9465, train_auprc: 0.9467, val_auprc: 0.9360




Epoch: 101 (71.4008s), train_loss: 0.2502, val_loss: 0.2912, train_acc: 0.9019, val_acc:0.8860
		train_roc: 0.9575, val_roc: 0.9470, train_auprc: 0.9468, val_auprc: 0.9368




Epoch: 102 (71.2198s), train_loss: 0.2498, val_loss: 0.2911, train_acc: 0.9018, val_acc:0.8859
		train_roc: 0.9578, val_roc: 0.9473, train_auprc: 0.9473, val_auprc: 0.9372




Epoch: 103 (71.6368s), train_loss: 0.2488, val_loss: 0.2930, train_acc: 0.9026, val_acc:0.8848
		train_roc: 0.9580, val_roc: 0.9468, train_auprc: 0.9474, val_auprc: 0.9368




Epoch: 104 (71.4687s), train_loss: 0.2467, val_loss: 0.2902, train_acc: 0.9035, val_acc:0.8861
		train_roc: 0.9590, val_roc: 0.9477, train_auprc: 0.9486, val_auprc: 0.9374




Epoch: 105 (71.4484s), train_loss: 0.2466, val_loss: 0.2888, train_acc: 0.9029, val_acc:0.8860
		train_roc: 0.9589, val_roc: 0.9482, train_auprc: 0.9488, val_auprc: 0.9389




Epoch: 106 (71.4877s), train_loss: 0.2494, val_loss: 0.2938, train_acc: 0.9027, val_acc:0.8841
		train_roc: 0.9576, val_roc: 0.9465, train_auprc: 0.9469, val_auprc: 0.9364




Epoch: 107 (71.3882s), train_loss: 0.2502, val_loss: 0.2923, train_acc: 0.9018, val_acc:0.8849
		train_roc: 0.9574, val_roc: 0.9474, train_auprc: 0.9466, val_auprc: 0.9374




Saving model
Epoch: 108 (71.3990s), train_loss: 0.2476, val_loss: 0.2886, train_acc: 0.9032, val_acc:0.8870
		train_roc: 0.9584, val_roc: 0.9485, train_auprc: 0.9478, val_auprc: 0.9390




Saving model
Epoch: 109 (71.5464s), train_loss: 0.2465, val_loss: 0.2900, train_acc: 0.9040, val_acc:0.8860
		train_roc: 0.9587, val_roc: 0.9485, train_auprc: 0.9481, val_auprc: 0.9391




Saving model
Epoch: 110 (71.4441s), train_loss: 0.2453, val_loss: 0.2888, train_acc: 0.9042, val_acc:0.8873
		train_roc: 0.9593, val_roc: 0.9487, train_auprc: 0.9491, val_auprc: 0.9394




Epoch: 111 (71.4844s), train_loss: 0.2444, val_loss: 0.2910, train_acc: 0.9048, val_acc:0.8875
		train_roc: 0.9595, val_roc: 0.9474, train_auprc: 0.9490, val_auprc: 0.9373




Epoch: 112 (71.4304s), train_loss: 0.2465, val_loss: 0.2916, train_acc: 0.9036, val_acc:0.8868
		train_roc: 0.9586, val_roc: 0.9476, train_auprc: 0.9481, val_auprc: 0.9373




Epoch: 113 (71.4525s), train_loss: 0.2453, val_loss: 0.2907, train_acc: 0.9043, val_acc:0.8865
		train_roc: 0.9591, val_roc: 0.9481, train_auprc: 0.9485, val_auprc: 0.9387




Epoch: 114 (71.0957s), train_loss: 0.2427, val_loss: 0.2910, train_acc: 0.9048, val_acc:0.8871
		train_roc: 0.9600, val_roc: 0.9478, train_auprc: 0.9501, val_auprc: 0.9387




Epoch: 115 (71.4127s), train_loss: 0.2451, val_loss: 0.2929, train_acc: 0.9040, val_acc:0.8857
		train_roc: 0.9592, val_roc: 0.9472, train_auprc: 0.9488, val_auprc: 0.9369




Epoch: 116 (71.4178s), train_loss: 0.2421, val_loss: 0.2894, train_acc: 0.9056, val_acc:0.8873
		train_roc: 0.9602, val_roc: 0.9483, train_auprc: 0.9502, val_auprc: 0.9387




Epoch: 117 (71.1303s), train_loss: 0.2459, val_loss: 0.2938, train_acc: 0.9038, val_acc:0.8854
		train_roc: 0.9589, val_roc: 0.9471, train_auprc: 0.9483, val_auprc: 0.9372




Saving model
Epoch: 118 (71.5306s), train_loss: 0.2448, val_loss: 0.2889, train_acc: 0.9044, val_acc:0.8880
		train_roc: 0.9592, val_roc: 0.9489, train_auprc: 0.9488, val_auprc: 0.9398




Epoch: 119 (71.2226s), train_loss: 0.2411, val_loss: 0.2917, train_acc: 0.9059, val_acc:0.8863
		train_roc: 0.9606, val_roc: 0.9478, train_auprc: 0.9507, val_auprc: 0.9382




Epoch: 120 (71.1791s), train_loss: 0.2440, val_loss: 0.2891, train_acc: 0.9044, val_acc:0.8882
		train_roc: 0.9593, val_roc: 0.9490, train_auprc: 0.9490, val_auprc: 0.9393




Epoch: 121 (71.5043s), train_loss: 0.2430, val_loss: 0.2891, train_acc: 0.9055, val_acc:0.8880
		train_roc: 0.9599, val_roc: 0.9490, train_auprc: 0.9493, val_auprc: 0.9396




Epoch: 122 (71.4061s), train_loss: 0.2433, val_loss: 0.2899, train_acc: 0.9050, val_acc:0.8874
		train_roc: 0.9597, val_roc: 0.9489, train_auprc: 0.9494, val_auprc: 0.9396




Epoch: 123 (71.4099s), train_loss: 0.2436, val_loss: 0.2906, train_acc: 0.9053, val_acc:0.8873
		train_roc: 0.9595, val_roc: 0.9483, train_auprc: 0.9487, val_auprc: 0.9384




Epoch: 124 (71.2678s), train_loss: 0.2428, val_loss: 0.2933, train_acc: 0.9052, val_acc:0.8864
		train_roc: 0.9597, val_roc: 0.9471, train_auprc: 0.9496, val_auprc: 0.9373




Epoch: 125 (71.3002s), train_loss: 0.2423, val_loss: 0.2914, train_acc: 0.9049, val_acc:0.8860
		train_roc: 0.9601, val_roc: 0.9478, train_auprc: 0.9502, val_auprc: 0.9382




Epoch: 126 (71.4014s), train_loss: 0.2425, val_loss: 0.2923, train_acc: 0.9054, val_acc:0.8868
		train_roc: 0.9598, val_roc: 0.9477, train_auprc: 0.9493, val_auprc: 0.9378




Epoch: 127 (71.3217s), train_loss: 0.2436, val_loss: 0.2898, train_acc: 0.9046, val_acc:0.8878
		train_roc: 0.9597, val_roc: 0.9488, train_auprc: 0.9495, val_auprc: 0.9398




Saving model
Epoch: 128 (71.5267s), train_loss: 0.2407, val_loss: 0.2889, train_acc: 0.9057, val_acc:0.8878
		train_roc: 0.9606, val_roc: 0.9489, train_auprc: 0.9506, val_auprc: 0.9398




Epoch: 129 (71.4454s), train_loss: 0.2422, val_loss: 0.2919, train_acc: 0.9054, val_acc:0.8864
		train_roc: 0.9600, val_roc: 0.9479, train_auprc: 0.9497, val_auprc: 0.9382




Epoch: 130 (71.4606s), train_loss: 0.2430, val_loss: 0.2916, train_acc: 0.9050, val_acc:0.8864
		train_roc: 0.9597, val_roc: 0.9477, train_auprc: 0.9491, val_auprc: 0.9381




Epoch: 131 (71.3282s), train_loss: 0.2422, val_loss: 0.2896, train_acc: 0.9055, val_acc:0.8889
		train_roc: 0.9600, val_roc: 0.9486, train_auprc: 0.9495, val_auprc: 0.9390




Epoch: 132 (71.8465s), train_loss: 0.2415, val_loss: 0.2929, train_acc: 0.9054, val_acc:0.8874
		train_roc: 0.9604, val_roc: 0.9475, train_auprc: 0.9503, val_auprc: 0.9375




Saving model
Epoch: 133 (71.5136s), train_loss: 0.2435, val_loss: 0.2860, train_acc: 0.9054, val_acc:0.8887
		train_roc: 0.9596, val_roc: 0.9503, train_auprc: 0.9492, val_auprc: 0.9420




Epoch: 134 (71.3851s), train_loss: 0.2421, val_loss: 0.2908, train_acc: 0.9052, val_acc:0.8869
		train_roc: 0.9600, val_roc: 0.9485, train_auprc: 0.9499, val_auprc: 0.9389




Epoch: 135 (71.6422s), train_loss: 0.2417, val_loss: 0.2907, train_acc: 0.9055, val_acc:0.8878
		train_roc: 0.9601, val_roc: 0.9482, train_auprc: 0.9499, val_auprc: 0.9387




Epoch: 136 (71.6276s), train_loss: 0.2394, val_loss: 0.2867, train_acc: 0.9069, val_acc:0.8895
		train_roc: 0.9609, val_roc: 0.9496, train_auprc: 0.9512, val_auprc: 0.9407




Epoch: 137 (71.4381s), train_loss: 0.2419, val_loss: 0.2920, train_acc: 0.9062, val_acc:0.8872
		train_roc: 0.9600, val_roc: 0.9477, train_auprc: 0.9494, val_auprc: 0.9376




Epoch: 138 (71.2429s), train_loss: 0.2408, val_loss: 0.2928, train_acc: 0.9064, val_acc:0.8849
		train_roc: 0.9605, val_roc: 0.9478, train_auprc: 0.9501, val_auprc: 0.9382




Epoch: 139 (71.3084s), train_loss: 0.2405, val_loss: 0.2915, train_acc: 0.9062, val_acc:0.8873
		train_roc: 0.9605, val_roc: 0.9484, train_auprc: 0.9504, val_auprc: 0.9387




Epoch: 140 (71.6383s), train_loss: 0.2428, val_loss: 0.2902, train_acc: 0.9053, val_acc:0.8872
		train_roc: 0.9598, val_roc: 0.9486, train_auprc: 0.9492, val_auprc: 0.9395




Epoch: 141 (71.1746s), train_loss: 0.2399, val_loss: 0.2911, train_acc: 0.9064, val_acc:0.8876
		train_roc: 0.9608, val_roc: 0.9481, train_auprc: 0.9510, val_auprc: 0.9382




Epoch: 142 (71.2980s), train_loss: 0.2412, val_loss: 0.2900, train_acc: 0.9061, val_acc:0.8884
		train_roc: 0.9603, val_roc: 0.9485, train_auprc: 0.9500, val_auprc: 0.9390




Epoch: 143 (71.6832s), train_loss: 0.2400, val_loss: 0.2877, train_acc: 0.9067, val_acc:0.8894
		train_roc: 0.9608, val_roc: 0.9499, train_auprc: 0.9507, val_auprc: 0.9411




Epoch: 144 (71.2014s), train_loss: 0.2395, val_loss: 0.2918, train_acc: 0.9067, val_acc:0.8873
		train_roc: 0.9609, val_roc: 0.9483, train_auprc: 0.9508, val_auprc: 0.9383




Epoch: 145 (71.2788s), train_loss: 0.2398, val_loss: 0.2908, train_acc: 0.9070, val_acc:0.8876
		train_roc: 0.9607, val_roc: 0.9484, train_auprc: 0.9504, val_auprc: 0.9388




Epoch: 146 (71.2509s), train_loss: 0.2392, val_loss: 0.2888, train_acc: 0.9070, val_acc:0.8886
		train_roc: 0.9608, val_roc: 0.9492, train_auprc: 0.9505, val_auprc: 0.9394




Epoch: 147 (71.1579s), train_loss: 0.2402, val_loss: 0.2912, train_acc: 0.9065, val_acc:0.8876
		train_roc: 0.9605, val_roc: 0.9482, train_auprc: 0.9501, val_auprc: 0.9390




Epoch: 148 (71.7928s), train_loss: 0.2394, val_loss: 0.2908, train_acc: 0.9062, val_acc:0.8873
		train_roc: 0.9609, val_roc: 0.9484, train_auprc: 0.9511, val_auprc: 0.9390




Epoch: 149 (71.3649s), train_loss: 0.2417, val_loss: 0.2882, train_acc: 0.9055, val_acc:0.8894
		train_roc: 0.9601, val_roc: 0.9493, train_auprc: 0.9499, val_auprc: 0.9395




Epoch: 150 (71.4932s), train_loss: 0.2396, val_loss: 0.2863, train_acc: 0.9069, val_acc:0.8894
		train_roc: 0.9608, val_roc: 0.9503, train_auprc: 0.9503, val_auprc: 0.9414




Epoch: 151 (71.4249s), train_loss: 0.2401, val_loss: 0.2903, train_acc: 0.9064, val_acc:0.8876
		train_roc: 0.9605, val_roc: 0.9490, train_auprc: 0.9503, val_auprc: 0.9393




Epoch: 152 (71.6654s), train_loss: 0.2404, val_loss: 0.2884, train_acc: 0.9069, val_acc:0.8883
		train_roc: 0.9604, val_roc: 0.9495, train_auprc: 0.9502, val_auprc: 0.9405




Epoch: 153 (71.2950s), train_loss: 0.2395, val_loss: 0.2891, train_acc: 0.9069, val_acc:0.8876
		train_roc: 0.9607, val_roc: 0.9491, train_auprc: 0.9504, val_auprc: 0.9400




Epoch: 154 (71.5199s), train_loss: 0.2395, val_loss: 0.2872, train_acc: 0.9067, val_acc:0.8896
		train_roc: 0.9607, val_roc: 0.9502, train_auprc: 0.9503, val_auprc: 0.9411




Epoch: 155 (71.3611s), train_loss: 0.2389, val_loss: 0.2901, train_acc: 0.9072, val_acc:0.8876
		train_roc: 0.9610, val_roc: 0.9487, train_auprc: 0.9508, val_auprc: 0.9398




Epoch: 156 (71.1544s), train_loss: 0.2385, val_loss: 0.2895, train_acc: 0.9074, val_acc:0.8887
		train_roc: 0.9611, val_roc: 0.9491, train_auprc: 0.9511, val_auprc: 0.9391




Epoch: 157 (71.4733s), train_loss: 0.2386, val_loss: 0.2903, train_acc: 0.9073, val_acc:0.8878
		train_roc: 0.9611, val_roc: 0.9488, train_auprc: 0.9511, val_auprc: 0.9395




Epoch: 158 (71.2344s), train_loss: 0.2404, val_loss: 0.2906, train_acc: 0.9064, val_acc:0.8878
		train_roc: 0.9604, val_roc: 0.9485, train_auprc: 0.9502, val_auprc: 0.9392




Epoch: 159 (71.5162s), train_loss: 0.2388, val_loss: 0.2871, train_acc: 0.9072, val_acc:0.8898
		train_roc: 0.9609, val_roc: 0.9497, train_auprc: 0.9507, val_auprc: 0.9400




Epoch: 160 (71.4387s), train_loss: 0.2394, val_loss: 0.2891, train_acc: 0.9067, val_acc:0.8883
		train_roc: 0.9608, val_roc: 0.9494, train_auprc: 0.9507, val_auprc: 0.9405




Epoch: 161 (71.6058s), train_loss: 0.2393, val_loss: 0.2882, train_acc: 0.9068, val_acc:0.8889
		train_roc: 0.9609, val_roc: 0.9494, train_auprc: 0.9506, val_auprc: 0.9402




Epoch: 162 (71.2226s), train_loss: 0.2390, val_loss: 0.2865, train_acc: 0.9072, val_acc:0.8903
		train_roc: 0.9610, val_roc: 0.9501, train_auprc: 0.9510, val_auprc: 0.9410




Epoch: 163 (71.5010s), train_loss: 0.2394, val_loss: 0.2875, train_acc: 0.9065, val_acc:0.8891
		train_roc: 0.9608, val_roc: 0.9498, train_auprc: 0.9507, val_auprc: 0.9406




Epoch: 164 (71.4884s), train_loss: 0.2404, val_loss: 0.2909, train_acc: 0.9062, val_acc:0.8876
		train_roc: 0.9605, val_roc: 0.9487, train_auprc: 0.9501, val_auprc: 0.9391




Epoch: 165 (71.4487s), train_loss: 0.2396, val_loss: 0.2878, train_acc: 0.9069, val_acc:0.8893
		train_roc: 0.9607, val_roc: 0.9495, train_auprc: 0.9503, val_auprc: 0.9405




Epoch: 166 (71.3496s), train_loss: 0.2387, val_loss: 0.2927, train_acc: 0.9073, val_acc:0.8875
		train_roc: 0.9612, val_roc: 0.9477, train_auprc: 0.9509, val_auprc: 0.9374




Epoch: 167 (71.3950s), train_loss: 0.2378, val_loss: 0.2912, train_acc: 0.9078, val_acc:0.8876
		train_roc: 0.9613, val_roc: 0.9484, train_auprc: 0.9511, val_auprc: 0.9391




Saving model
Epoch: 168 (71.6587s), train_loss: 0.2393, val_loss: 0.2856, train_acc: 0.9069, val_acc:0.8900
		train_roc: 0.9608, val_roc: 0.9507, train_auprc: 0.9508, val_auprc: 0.9422




Epoch: 169 (71.6155s), train_loss: 0.2410, val_loss: 0.2914, train_acc: 0.9062, val_acc:0.8872
		train_roc: 0.9601, val_roc: 0.9485, train_auprc: 0.9499, val_auprc: 0.9390




Epoch: 170 (71.3263s), train_loss: 0.2399, val_loss: 0.2881, train_acc: 0.9064, val_acc:0.8885
		train_roc: 0.9606, val_roc: 0.9496, train_auprc: 0.9507, val_auprc: 0.9413




Epoch: 171 (71.3624s), train_loss: 0.2402, val_loss: 0.2896, train_acc: 0.9068, val_acc:0.8878
		train_roc: 0.9605, val_roc: 0.9492, train_auprc: 0.9498, val_auprc: 0.9401




Epoch: 172 (71.3972s), train_loss: 0.2383, val_loss: 0.2907, train_acc: 0.9072, val_acc:0.8875
		train_roc: 0.9612, val_roc: 0.9487, train_auprc: 0.9514, val_auprc: 0.9394




Epoch: 173 (71.3954s), train_loss: 0.2389, val_loss: 0.2906, train_acc: 0.9072, val_acc:0.8881
		train_roc: 0.9609, val_roc: 0.9487, train_auprc: 0.9508, val_auprc: 0.9389




Epoch: 174 (71.3388s), train_loss: 0.2390, val_loss: 0.2885, train_acc: 0.9073, val_acc:0.8882
		train_roc: 0.9608, val_roc: 0.9496, train_auprc: 0.9503, val_auprc: 0.9410




Epoch: 175 (71.2734s), train_loss: 0.2380, val_loss: 0.2893, train_acc: 0.9074, val_acc:0.8878
		train_roc: 0.9612, val_roc: 0.9493, train_auprc: 0.9512, val_auprc: 0.9401




Epoch: 176 (71.3827s), train_loss: 0.2403, val_loss: 0.2884, train_acc: 0.9068, val_acc:0.8884
		train_roc: 0.9604, val_roc: 0.9497, train_auprc: 0.9501, val_auprc: 0.9405




Epoch: 177 (71.3191s), train_loss: 0.2391, val_loss: 0.2878, train_acc: 0.9067, val_acc:0.8891
		train_roc: 0.9609, val_roc: 0.9499, train_auprc: 0.9507, val_auprc: 0.9410




Epoch: 178 (71.4691s), train_loss: 0.2389, val_loss: 0.2898, train_acc: 0.9070, val_acc:0.8891
		train_roc: 0.9611, val_roc: 0.9488, train_auprc: 0.9511, val_auprc: 0.9385




Epoch: 179 (71.1623s), train_loss: 0.2379, val_loss: 0.2905, train_acc: 0.9076, val_acc:0.8876
		train_roc: 0.9615, val_roc: 0.9488, train_auprc: 0.9516, val_auprc: 0.9391




Epoch: 180 (71.0179s), train_loss: 0.2372, val_loss: 0.2888, train_acc: 0.9074, val_acc:0.8891
		train_roc: 0.9616, val_roc: 0.9494, train_auprc: 0.9517, val_auprc: 0.9405




Epoch: 181 (71.3121s), train_loss: 0.2394, val_loss: 0.2935, train_acc: 0.9068, val_acc:0.8859
		train_roc: 0.9608, val_roc: 0.9476, train_auprc: 0.9507, val_auprc: 0.9378




Epoch: 182 (71.3156s), train_loss: 0.2374, val_loss: 0.2884, train_acc: 0.9081, val_acc:0.8887
		train_roc: 0.9615, val_roc: 0.9496, train_auprc: 0.9514, val_auprc: 0.9409




Epoch: 183 (71.2601s), train_loss: 0.2386, val_loss: 0.2862, train_acc: 0.9074, val_acc:0.8899
		train_roc: 0.9611, val_roc: 0.9506, train_auprc: 0.9510, val_auprc: 0.9418




Epoch: 184 (71.1811s), train_loss: 0.2387, val_loss: 0.2896, train_acc: 0.9074, val_acc:0.8890
		train_roc: 0.9610, val_roc: 0.9490, train_auprc: 0.9508, val_auprc: 0.9395




Epoch: 185 (71.3435s), train_loss: 0.2386, val_loss: 0.2886, train_acc: 0.9067, val_acc:0.8889
		train_roc: 0.9612, val_roc: 0.9496, train_auprc: 0.9513, val_auprc: 0.9400




Epoch: 186 (71.2743s), train_loss: 0.2396, val_loss: 0.2944, train_acc: 0.9063, val_acc:0.8861
		train_roc: 0.9608, val_roc: 0.9471, train_auprc: 0.9507, val_auprc: 0.9369




Epoch: 187 (71.1764s), train_loss: 0.2398, val_loss: 0.2898, train_acc: 0.9063, val_acc:0.8889
		train_roc: 0.9608, val_roc: 0.9491, train_auprc: 0.9508, val_auprc: 0.9391




Epoch: 188 (71.1289s), train_loss: 0.2396, val_loss: 0.2910, train_acc: 0.9064, val_acc:0.8871
		train_roc: 0.9606, val_roc: 0.9487, train_auprc: 0.9506, val_auprc: 0.9396




Epoch: 189 (71.4586s), train_loss: 0.2379, val_loss: 0.2924, train_acc: 0.9074, val_acc:0.8876
		train_roc: 0.9613, val_roc: 0.9480, train_auprc: 0.9516, val_auprc: 0.9379




Epoch: 190 (71.3538s), train_loss: 0.2400, val_loss: 0.2926, train_acc: 0.9068, val_acc:0.8868
		train_roc: 0.9605, val_roc: 0.9480, train_auprc: 0.9501, val_auprc: 0.9381




Epoch: 191 (71.3627s), train_loss: 0.2399, val_loss: 0.2923, train_acc: 0.9061, val_acc:0.8870
		train_roc: 0.9607, val_roc: 0.9480, train_auprc: 0.9507, val_auprc: 0.9379




Epoch: 192 (71.3553s), train_loss: 0.2388, val_loss: 0.2948, train_acc: 0.9068, val_acc:0.8860
		train_roc: 0.9610, val_roc: 0.9470, train_auprc: 0.9509, val_auprc: 0.9368




Epoch: 193 (71.4080s), train_loss: 0.2386, val_loss: 0.2883, train_acc: 0.9068, val_acc:0.8881
		train_roc: 0.9611, val_roc: 0.9496, train_auprc: 0.9513, val_auprc: 0.9408




Epoch: 194 (71.2039s), train_loss: 0.2375, val_loss: 0.2858, train_acc: 0.9076, val_acc:0.8903
		train_roc: 0.9615, val_roc: 0.9506, train_auprc: 0.9516, val_auprc: 0.9418




Epoch: 195 (71.2705s), train_loss: 0.2382, val_loss: 0.2931, train_acc: 0.9072, val_acc:0.8861
		train_roc: 0.9612, val_roc: 0.9479, train_auprc: 0.9513, val_auprc: 0.9389




Epoch: 196 (71.3220s), train_loss: 0.2376, val_loss: 0.2885, train_acc: 0.9073, val_acc:0.8885
		train_roc: 0.9614, val_roc: 0.9495, train_auprc: 0.9515, val_auprc: 0.9401




Epoch: 197 (71.4646s), train_loss: 0.2379, val_loss: 0.2903, train_acc: 0.9076, val_acc:0.8873
		train_roc: 0.9612, val_roc: 0.9489, train_auprc: 0.9512, val_auprc: 0.9391




Epoch: 198 (71.1341s), train_loss: 0.2400, val_loss: 0.2903, train_acc: 0.9061, val_acc:0.8880
		train_roc: 0.9606, val_roc: 0.9491, train_auprc: 0.9504, val_auprc: 0.9399




Epoch: 199 (71.4790s), train_loss: 0.2383, val_loss: 0.2900, train_acc: 0.9072, val_acc:0.8882
		train_roc: 0.9612, val_roc: 0.9490, train_auprc: 0.9512, val_auprc: 0.9396




Epoch: 200 (71.2927s), train_loss: 0.2395, val_loss: 0.2901, train_acc: 0.9065, val_acc:0.8891
		train_roc: 0.9607, val_roc: 0.9489, train_auprc: 0.9505, val_auprc: 0.9391




Epoch: 201 (71.3299s), train_loss: 0.2385, val_loss: 0.2878, train_acc: 0.9076, val_acc:0.8894
		train_roc: 0.9609, val_roc: 0.9498, train_auprc: 0.9506, val_auprc: 0.9413




Epoch: 202 (71.5115s), train_loss: 0.2382, val_loss: 0.2910, train_acc: 0.9066, val_acc:0.8884
		train_roc: 0.9611, val_roc: 0.9488, train_auprc: 0.9511, val_auprc: 0.9390




Epoch: 203 (71.4844s), train_loss: 0.2407, val_loss: 0.2883, train_acc: 0.9064, val_acc:0.8890
		train_roc: 0.9602, val_roc: 0.9498, train_auprc: 0.9495, val_auprc: 0.9408




Epoch: 204 (71.1554s), train_loss: 0.2386, val_loss: 0.2892, train_acc: 0.9074, val_acc:0.8888
		train_roc: 0.9610, val_roc: 0.9492, train_auprc: 0.9509, val_auprc: 0.9403




Epoch: 205 (71.1807s), train_loss: 0.2387, val_loss: 0.2916, train_acc: 0.9069, val_acc:0.8873
		train_roc: 0.9611, val_roc: 0.9484, train_auprc: 0.9511, val_auprc: 0.9393




Epoch: 206 (71.4191s), train_loss: 0.2384, val_loss: 0.2919, train_acc: 0.9071, val_acc:0.8868
		train_roc: 0.9613, val_roc: 0.9484, train_auprc: 0.9513, val_auprc: 0.9392




Epoch: 207 (71.6374s), train_loss: 0.2377, val_loss: 0.2891, train_acc: 0.9069, val_acc:0.8887
		train_roc: 0.9615, val_roc: 0.9494, train_auprc: 0.9517, val_auprc: 0.9400




Epoch: 208 (71.4603s), train_loss: 0.2380, val_loss: 0.2876, train_acc: 0.9074, val_acc:0.8900
		train_roc: 0.9613, val_roc: 0.9501, train_auprc: 0.9511, val_auprc: 0.9409




Epoch: 209 (71.4576s), train_loss: 0.2376, val_loss: 0.2893, train_acc: 0.9074, val_acc:0.8890
		train_roc: 0.9615, val_roc: 0.9492, train_auprc: 0.9517, val_auprc: 0.9395




Epoch: 210 (71.5232s), train_loss: 0.2380, val_loss: 0.2902, train_acc: 0.9075, val_acc:0.8891
		train_roc: 0.9614, val_roc: 0.9488, train_auprc: 0.9513, val_auprc: 0.9393




Epoch: 211 (71.0985s), train_loss: 0.2378, val_loss: 0.2883, train_acc: 0.9076, val_acc:0.8892
		train_roc: 0.9614, val_roc: 0.9497, train_auprc: 0.9513, val_auprc: 0.9409




Epoch: 212 (71.3481s), train_loss: 0.2384, val_loss: 0.2899, train_acc: 0.9072, val_acc:0.8884
		train_roc: 0.9613, val_roc: 0.9491, train_auprc: 0.9515, val_auprc: 0.9393




Epoch: 213 (71.3747s), train_loss: 0.2388, val_loss: 0.2860, train_acc: 0.9068, val_acc:0.8900
		train_roc: 0.9608, val_roc: 0.9508, train_auprc: 0.9507, val_auprc: 0.9419




Epoch: 214 (71.3391s), train_loss: 0.2370, val_loss: 0.2876, train_acc: 0.9082, val_acc:0.8898
		train_roc: 0.9614, val_roc: 0.9499, train_auprc: 0.9515, val_auprc: 0.9405




Epoch: 215 (71.0910s), train_loss: 0.2393, val_loss: 0.2895, train_acc: 0.9066, val_acc:0.8889
		train_roc: 0.9608, val_roc: 0.9492, train_auprc: 0.9508, val_auprc: 0.9398




Epoch: 216 (71.2390s), train_loss: 0.2375, val_loss: 0.2928, train_acc: 0.9081, val_acc:0.8876
		train_roc: 0.9614, val_roc: 0.9479, train_auprc: 0.9514, val_auprc: 0.9379




Epoch: 217 (71.4527s), train_loss: 0.2383, val_loss: 0.2891, train_acc: 0.9073, val_acc:0.8902
		train_roc: 0.9613, val_roc: 0.9494, train_auprc: 0.9512, val_auprc: 0.9391




Epoch: 218 (71.1925s), train_loss: 0.2404, val_loss: 0.2871, train_acc: 0.9056, val_acc:0.8897
		train_roc: 0.9605, val_roc: 0.9503, train_auprc: 0.9505, val_auprc: 0.9412




Epoch: 219 (71.3508s), train_loss: 0.2381, val_loss: 0.2895, train_acc: 0.9075, val_acc:0.8890
		train_roc: 0.9611, val_roc: 0.9493, train_auprc: 0.9512, val_auprc: 0.9395




Epoch: 220 (71.5711s), train_loss: 0.2378, val_loss: 0.2887, train_acc: 0.9074, val_acc:0.8887
		train_roc: 0.9613, val_roc: 0.9495, train_auprc: 0.9514, val_auprc: 0.9402




Epoch: 221 (71.3098s), train_loss: 0.2371, val_loss: 0.2888, train_acc: 0.9077, val_acc:0.8884
		train_roc: 0.9617, val_roc: 0.9496, train_auprc: 0.9520, val_auprc: 0.9406




Epoch: 222 (71.4445s), train_loss: 0.2384, val_loss: 0.2898, train_acc: 0.9071, val_acc:0.8878
		train_roc: 0.9612, val_roc: 0.9492, train_auprc: 0.9512, val_auprc: 0.9400




Epoch: 223 (71.4058s), train_loss: 0.2393, val_loss: 0.2896, train_acc: 0.9066, val_acc:0.8875
		train_roc: 0.9609, val_roc: 0.9493, train_auprc: 0.9509, val_auprc: 0.9407




Epoch: 224 (71.2926s), train_loss: 0.2382, val_loss: 0.2881, train_acc: 0.9071, val_acc:0.8890
		train_roc: 0.9613, val_roc: 0.9499, train_auprc: 0.9512, val_auprc: 0.9411




Epoch: 225 (71.2650s), train_loss: 0.2377, val_loss: 0.2890, train_acc: 0.9072, val_acc:0.8888
		train_roc: 0.9614, val_roc: 0.9494, train_auprc: 0.9514, val_auprc: 0.9403




Epoch: 226 (71.5994s), train_loss: 0.2377, val_loss: 0.2911, train_acc: 0.9076, val_acc:0.8874
		train_roc: 0.9615, val_roc: 0.9489, train_auprc: 0.9514, val_auprc: 0.9389




Epoch: 227 (71.4929s), train_loss: 0.2384, val_loss: 0.2890, train_acc: 0.9072, val_acc:0.8876
		train_roc: 0.9610, val_roc: 0.9493, train_auprc: 0.9510, val_auprc: 0.9411




Epoch: 228 (71.2805s), train_loss: 0.2383, val_loss: 0.2896, train_acc: 0.9071, val_acc:0.8874
		train_roc: 0.9612, val_roc: 0.9493, train_auprc: 0.9512, val_auprc: 0.9407




Epoch: 229 (71.6433s), train_loss: 0.2379, val_loss: 0.2925, train_acc: 0.9079, val_acc:0.8867
		train_roc: 0.9612, val_roc: 0.9481, train_auprc: 0.9513, val_auprc: 0.9385




Epoch: 230 (72.5615s), train_loss: 0.2381, val_loss: 0.2898, train_acc: 0.9073, val_acc:0.8884
		train_roc: 0.9612, val_roc: 0.9490, train_auprc: 0.9516, val_auprc: 0.9396




Epoch: 231 (71.7617s), train_loss: 0.2384, val_loss: 0.2892, train_acc: 0.9072, val_acc:0.8879
		train_roc: 0.9611, val_roc: 0.9496, train_auprc: 0.9510, val_auprc: 0.9401




Epoch: 232 (72.0409s), train_loss: 0.2401, val_loss: 0.2883, train_acc: 0.9065, val_acc:0.8889
		train_roc: 0.9605, val_roc: 0.9500, train_auprc: 0.9501, val_auprc: 0.9407




Epoch: 233 (71.7652s), train_loss: 0.2398, val_loss: 0.2901, train_acc: 0.9070, val_acc:0.8880
		train_roc: 0.9605, val_roc: 0.9490, train_auprc: 0.9501, val_auprc: 0.9397




Epoch: 234 (72.3534s), train_loss: 0.2381, val_loss: 0.2873, train_acc: 0.9075, val_acc:0.8893
		train_roc: 0.9612, val_roc: 0.9503, train_auprc: 0.9509, val_auprc: 0.9418




Epoch: 235 (71.8336s), train_loss: 0.2396, val_loss: 0.2891, train_acc: 0.9063, val_acc:0.8885
		train_roc: 0.9608, val_roc: 0.9495, train_auprc: 0.9507, val_auprc: 0.9407




Epoch: 236 (71.8958s), train_loss: 0.2385, val_loss: 0.2896, train_acc: 0.9071, val_acc:0.8878
		train_roc: 0.9611, val_roc: 0.9494, train_auprc: 0.9510, val_auprc: 0.9398




Epoch: 237 (72.2793s), train_loss: 0.2369, val_loss: 0.2900, train_acc: 0.9079, val_acc:0.8879
		train_roc: 0.9617, val_roc: 0.9490, train_auprc: 0.9519, val_auprc: 0.9403




Epoch: 238 (72.3640s), train_loss: 0.2367, val_loss: 0.2905, train_acc: 0.9079, val_acc:0.8874
		train_roc: 0.9618, val_roc: 0.9489, train_auprc: 0.9521, val_auprc: 0.9403




Epoch: 239 (71.9941s), train_loss: 0.2382, val_loss: 0.2922, train_acc: 0.9076, val_acc:0.8873
		train_roc: 0.9611, val_roc: 0.9481, train_auprc: 0.9511, val_auprc: 0.9382




Epoch: 240 (71.6727s), train_loss: 0.2368, val_loss: 0.2876, train_acc: 0.9079, val_acc:0.8890
		train_roc: 0.9618, val_roc: 0.9499, train_auprc: 0.9520, val_auprc: 0.9410




Epoch: 241 (71.9043s), train_loss: 0.2380, val_loss: 0.2901, train_acc: 0.9071, val_acc:0.8888
		train_roc: 0.9613, val_roc: 0.9492, train_auprc: 0.9515, val_auprc: 0.9393




Epoch: 242 (72.0967s), train_loss: 0.2359, val_loss: 0.2911, train_acc: 0.9077, val_acc:0.8874
		train_roc: 0.9621, val_roc: 0.9487, train_auprc: 0.9526, val_auprc: 0.9396




Epoch: 243 (71.9395s), train_loss: 0.2386, val_loss: 0.2899, train_acc: 0.9073, val_acc:0.8891
		train_roc: 0.9610, val_roc: 0.9491, train_auprc: 0.9506, val_auprc: 0.9388




Epoch: 244 (71.7515s), train_loss: 0.2395, val_loss: 0.2908, train_acc: 0.9065, val_acc:0.8871
		train_roc: 0.9606, val_roc: 0.9488, train_auprc: 0.9503, val_auprc: 0.9395




Epoch: 245 (72.0584s), train_loss: 0.2374, val_loss: 0.2899, train_acc: 0.9081, val_acc:0.8885
		train_roc: 0.9613, val_roc: 0.9491, train_auprc: 0.9514, val_auprc: 0.9395




Epoch: 246 (71.9365s), train_loss: 0.2386, val_loss: 0.2887, train_acc: 0.9072, val_acc:0.8884
		train_roc: 0.9610, val_roc: 0.9496, train_auprc: 0.9509, val_auprc: 0.9399




Epoch: 247 (71.9003s), train_loss: 0.2394, val_loss: 0.2893, train_acc: 0.9064, val_acc:0.8882
		train_roc: 0.9607, val_roc: 0.9494, train_auprc: 0.9504, val_auprc: 0.9401




Epoch: 248 (71.8758s), train_loss: 0.2373, val_loss: 0.2909, train_acc: 0.9079, val_acc:0.8885
		train_roc: 0.9615, val_roc: 0.9488, train_auprc: 0.9518, val_auprc: 0.9388




Epoch: 249 (71.8860s), train_loss: 0.2385, val_loss: 0.2904, train_acc: 0.9070, val_acc:0.8886
		train_roc: 0.9611, val_roc: 0.9487, train_auprc: 0.9511, val_auprc: 0.9391




Epoch: 250 (71.9925s), train_loss: 0.2374, val_loss: 0.2908, train_acc: 0.9074, val_acc:0.8886
		train_roc: 0.9615, val_roc: 0.9488, train_auprc: 0.9514, val_auprc: 0.9388




Epoch: 251 (71.7069s), train_loss: 0.2377, val_loss: 0.2906, train_acc: 0.9071, val_acc:0.8885
		train_roc: 0.9613, val_roc: 0.9488, train_auprc: 0.9516, val_auprc: 0.9392




Epoch: 252 (72.0007s), train_loss: 0.2374, val_loss: 0.2910, train_acc: 0.9079, val_acc:0.8868
		train_roc: 0.9614, val_roc: 0.9486, train_auprc: 0.9512, val_auprc: 0.9400




Epoch: 253 (71.8285s), train_loss: 0.2378, val_loss: 0.2909, train_acc: 0.9076, val_acc:0.8874
		train_roc: 0.9612, val_roc: 0.9486, train_auprc: 0.9510, val_auprc: 0.9390




Epoch: 254 (72.0037s), train_loss: 0.2383, val_loss: 0.2898, train_acc: 0.9068, val_acc:0.8879
		train_roc: 0.9612, val_roc: 0.9492, train_auprc: 0.9514, val_auprc: 0.9403




Epoch: 255 (72.0689s), train_loss: 0.2384, val_loss: 0.2886, train_acc: 0.9073, val_acc:0.8889
		train_roc: 0.9611, val_roc: 0.9495, train_auprc: 0.9509, val_auprc: 0.9403




Epoch: 256 (72.0028s), train_loss: 0.2373, val_loss: 0.2898, train_acc: 0.9077, val_acc:0.8874
		train_roc: 0.9615, val_roc: 0.9492, train_auprc: 0.9519, val_auprc: 0.9402




Epoch: 257 (72.1300s), train_loss: 0.2375, val_loss: 0.2876, train_acc: 0.9077, val_acc:0.8884
		train_roc: 0.9614, val_roc: 0.9500, train_auprc: 0.9512, val_auprc: 0.9412




Epoch: 258 (71.7936s), train_loss: 0.2380, val_loss: 0.2881, train_acc: 0.9075, val_acc:0.8893
		train_roc: 0.9612, val_roc: 0.9497, train_auprc: 0.9511, val_auprc: 0.9405




Epoch: 259 (72.4941s), train_loss: 0.2405, val_loss: 0.2920, train_acc: 0.9067, val_acc:0.8877
		train_roc: 0.9603, val_roc: 0.9482, train_auprc: 0.9497, val_auprc: 0.9384




Epoch: 260 (73.7869s), train_loss: 0.2374, val_loss: 0.2885, train_acc: 0.9077, val_acc:0.8881
		train_roc: 0.9615, val_roc: 0.9499, train_auprc: 0.9517, val_auprc: 0.9410




Epoch: 261 (83.2777s), train_loss: 0.2380, val_loss: 0.2910, train_acc: 0.9072, val_acc:0.8876
		train_roc: 0.9614, val_roc: 0.9486, train_auprc: 0.9510, val_auprc: 0.9393




Epoch: 262 (90.9462s), train_loss: 0.2385, val_loss: 0.2879, train_acc: 0.9073, val_acc:0.8896
		train_roc: 0.9611, val_roc: 0.9498, train_auprc: 0.9513, val_auprc: 0.9402




Epoch: 263 (90.2823s), train_loss: 0.2382, val_loss: 0.2903, train_acc: 0.9077, val_acc:0.8884
		train_roc: 0.9611, val_roc: 0.9492, train_auprc: 0.9506, val_auprc: 0.9393




Epoch: 264 (82.2504s), train_loss: 0.2388, val_loss: 0.2876, train_acc: 0.9065, val_acc:0.8888
		train_roc: 0.9610, val_roc: 0.9499, train_auprc: 0.9511, val_auprc: 0.9412




Epoch: 265 (82.2457s), train_loss: 0.2392, val_loss: 0.2900, train_acc: 0.9068, val_acc:0.8885
		train_roc: 0.9608, val_roc: 0.9492, train_auprc: 0.9505, val_auprc: 0.9394




Epoch: 266 (82.6260s), train_loss: 0.2387, val_loss: 0.2889, train_acc: 0.9072, val_acc:0.8889
		train_roc: 0.9609, val_roc: 0.9496, train_auprc: 0.9508, val_auprc: 0.9402




Epoch: 267 (81.5236s), train_loss: 0.2385, val_loss: 0.2899, train_acc: 0.9071, val_acc:0.8879
		train_roc: 0.9610, val_roc: 0.9491, train_auprc: 0.9509, val_auprc: 0.9394




Epoch: 268 (83.8269s), train_loss: 0.2365, val_loss: 0.2885, train_acc: 0.9081, val_acc:0.8886
		train_roc: 0.9616, val_roc: 0.9494, train_auprc: 0.9519, val_auprc: 0.9406




Epoch: 269 (82.0664s), train_loss: 0.2396, val_loss: 0.2918, train_acc: 0.9069, val_acc:0.8876
		train_roc: 0.9606, val_roc: 0.9483, train_auprc: 0.9501, val_auprc: 0.9387




Epoch: 270 (81.5482s), train_loss: 0.2362, val_loss: 0.2894, train_acc: 0.9079, val_acc:0.8887
		train_roc: 0.9619, val_roc: 0.9492, train_auprc: 0.9521, val_auprc: 0.9399




Epoch: 271 (81.7139s), train_loss: 0.2374, val_loss: 0.2884, train_acc: 0.9079, val_acc:0.8886
		train_roc: 0.9614, val_roc: 0.9496, train_auprc: 0.9515, val_auprc: 0.9409




Epoch: 272 (81.5896s), train_loss: 0.2398, val_loss: 0.2893, train_acc: 0.9067, val_acc:0.8888
		train_roc: 0.9607, val_roc: 0.9492, train_auprc: 0.9500, val_auprc: 0.9402




Epoch: 273 (81.6038s), train_loss: 0.2380, val_loss: 0.2903, train_acc: 0.9073, val_acc:0.8891
		train_roc: 0.9612, val_roc: 0.9488, train_auprc: 0.9514, val_auprc: 0.9391




Epoch: 274 (81.7537s), train_loss: 0.2376, val_loss: 0.2900, train_acc: 0.9073, val_acc:0.8874
		train_roc: 0.9613, val_roc: 0.9491, train_auprc: 0.9516, val_auprc: 0.9398




Epoch: 275 (81.7206s), train_loss: 0.2391, val_loss: 0.2887, train_acc: 0.9071, val_acc:0.8893
		train_roc: 0.9606, val_roc: 0.9495, train_auprc: 0.9503, val_auprc: 0.9406




Epoch: 276 (81.7909s), train_loss: 0.2390, val_loss: 0.2888, train_acc: 0.9072, val_acc:0.8889
		train_roc: 0.9608, val_roc: 0.9494, train_auprc: 0.9503, val_auprc: 0.9405




Epoch: 277 (81.4467s), train_loss: 0.2391, val_loss: 0.2881, train_acc: 0.9073, val_acc:0.8891
		train_roc: 0.9608, val_roc: 0.9498, train_auprc: 0.9504, val_auprc: 0.9406




Epoch: 278 (81.4577s), train_loss: 0.2377, val_loss: 0.2928, train_acc: 0.9076, val_acc:0.8869
		train_roc: 0.9614, val_roc: 0.9478, train_auprc: 0.9514, val_auprc: 0.9380




Epoch: 279 (81.7475s), train_loss: 0.2381, val_loss: 0.2886, train_acc: 0.9069, val_acc:0.8890
		train_roc: 0.9613, val_roc: 0.9497, train_auprc: 0.9514, val_auprc: 0.9405




Epoch: 280 (81.6725s), train_loss: 0.2390, val_loss: 0.2906, train_acc: 0.9069, val_acc:0.8871
		train_roc: 0.9609, val_roc: 0.9488, train_auprc: 0.9504, val_auprc: 0.9394




Epoch: 281 (81.6380s), train_loss: 0.2392, val_loss: 0.2921, train_acc: 0.9072, val_acc:0.8876
		train_roc: 0.9607, val_roc: 0.9482, train_auprc: 0.9504, val_auprc: 0.9380




Epoch: 282 (81.6562s), train_loss: 0.2369, val_loss: 0.2905, train_acc: 0.9080, val_acc:0.8880
		train_roc: 0.9615, val_roc: 0.9489, train_auprc: 0.9515, val_auprc: 0.9389




Epoch: 283 (81.5183s), train_loss: 0.2390, val_loss: 0.2899, train_acc: 0.9070, val_acc:0.8880
		train_roc: 0.9610, val_roc: 0.9490, train_auprc: 0.9505, val_auprc: 0.9395




Epoch: 284 (81.6804s), train_loss: 0.2394, val_loss: 0.2886, train_acc: 0.9069, val_acc:0.8890
		train_roc: 0.9608, val_roc: 0.9497, train_auprc: 0.9507, val_auprc: 0.9406




Epoch: 285 (81.6565s), train_loss: 0.2384, val_loss: 0.2891, train_acc: 0.9071, val_acc:0.8883
		train_roc: 0.9610, val_roc: 0.9496, train_auprc: 0.9507, val_auprc: 0.9409




Epoch: 286 (81.6339s), train_loss: 0.2375, val_loss: 0.2899, train_acc: 0.9075, val_acc:0.8883
		train_roc: 0.9614, val_roc: 0.9492, train_auprc: 0.9514, val_auprc: 0.9399




Epoch: 287 (81.4972s), train_loss: 0.2391, val_loss: 0.2900, train_acc: 0.9071, val_acc:0.8885
		train_roc: 0.9607, val_roc: 0.9492, train_auprc: 0.9503, val_auprc: 0.9393




Epoch: 288 (81.4305s), train_loss: 0.2381, val_loss: 0.2935, train_acc: 0.9070, val_acc:0.8865
		train_roc: 0.9614, val_roc: 0.9478, train_auprc: 0.9515, val_auprc: 0.9382




Epoch: 289 (81.8210s), train_loss: 0.2388, val_loss: 0.2888, train_acc: 0.9068, val_acc:0.8886
		train_roc: 0.9609, val_roc: 0.9496, train_auprc: 0.9507, val_auprc: 0.9405




Epoch: 290 (81.7463s), train_loss: 0.2381, val_loss: 0.2900, train_acc: 0.9072, val_acc:0.8884
		train_roc: 0.9613, val_roc: 0.9492, train_auprc: 0.9512, val_auprc: 0.9399




Epoch: 291 (81.5727s), train_loss: 0.2384, val_loss: 0.2868, train_acc: 0.9073, val_acc:0.8892
		train_roc: 0.9611, val_roc: 0.9502, train_auprc: 0.9511, val_auprc: 0.9415




Epoch: 292 (81.6946s), train_loss: 0.2402, val_loss: 0.2906, train_acc: 0.9063, val_acc:0.8879
		train_roc: 0.9605, val_roc: 0.9489, train_auprc: 0.9499, val_auprc: 0.9392




Epoch: 293 (81.1347s), train_loss: 0.2381, val_loss: 0.2905, train_acc: 0.9066, val_acc:0.8880
		train_roc: 0.9613, val_roc: 0.9489, train_auprc: 0.9517, val_auprc: 0.9397




Epoch: 294 (81.5643s), train_loss: 0.2361, val_loss: 0.2903, train_acc: 0.9083, val_acc:0.8879
		train_roc: 0.9619, val_roc: 0.9489, train_auprc: 0.9521, val_auprc: 0.9395




Epoch: 295 (81.7091s), train_loss: 0.2399, val_loss: 0.2926, train_acc: 0.9063, val_acc:0.8872
		train_roc: 0.9605, val_roc: 0.9481, train_auprc: 0.9502, val_auprc: 0.9380




Epoch: 296 (81.7442s), train_loss: 0.2386, val_loss: 0.2911, train_acc: 0.9072, val_acc:0.8877
		train_roc: 0.9610, val_roc: 0.9486, train_auprc: 0.9508, val_auprc: 0.9391




Epoch: 297 (81.5639s), train_loss: 0.2390, val_loss: 0.2909, train_acc: 0.9068, val_acc:0.8880
		train_roc: 0.9609, val_roc: 0.9488, train_auprc: 0.9506, val_auprc: 0.9395




Epoch: 298 (81.3571s), train_loss: 0.2378, val_loss: 0.2891, train_acc: 0.9077, val_acc:0.8893
		train_roc: 0.9614, val_roc: 0.9494, train_auprc: 0.9513, val_auprc: 0.9399




Epoch: 299 (81.6694s), train_loss: 0.2385, val_loss: 0.2914, train_acc: 0.9071, val_acc:0.8880
		train_roc: 0.9611, val_roc: 0.9485, train_auprc: 0.9509, val_auprc: 0.9386




Epoch: 300 (81.4430s), train_loss: 0.2396, val_loss: 0.2886, train_acc: 0.9066, val_acc:0.8889
		train_roc: 0.9607, val_roc: 0.9496, train_auprc: 0.9503, val_auprc: 0.9408


In [18]:
# Predict
model = torch.load(model_file)
print(model)
model.to(device=device)
predict(model, test_data_loader, device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-4): 5 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayerImproved(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (KGE): ComplEx(86, torch.Size([86, 1024]), torch.Size([86, 

  model = torch.load(model_file)


Test Accuracy: 0.8875
Test ROC AUC: 0.9477
Test PRC AUC: 0.9380
