In [1]:
# !pip install torch_geometric rdkit torch

In [2]:
from datetime import datetime
import time
import argparse
import sys
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from sklearn import metrics
import pandas as pd
import numpy as np
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)


In [3]:
# Directory configuration
# data_dir:   folder holding the ddi_training / ddi_validation / ddi_test CSV files.
# model_dir:  folder the best checkpoint is written to.
# model_name: run identifier; names both the checkpoint (<model_name>.pth) and
#             the per-epoch metrics CSV (<model_name>.csv).
data_dir = "data"
model_dir = "models"
model_name = "case3"

# sys.path.append('/content/drive/MyDrive/Colab Notebooks')

In [4]:
####### Tuning parameters #######
# Module-level switches read when the model is built and echoed by
# print_tunning_parameters() at the start of training.

# Number of training epochs
n_epochs = 300

# SAGPooling ratio and minimum score.
# When both are None, SSI_DDI_Block builds SAGPooling with min_score=-1.
sp_ratio = None
sp_min_score = None

# Use the GPU when one is available (CPU is used otherwise)
use_cuda = True

# Apply tanh to the co-attention activations before they are scored
use_activation_fn = False

# Use the ComplEx scorer instead of RESCAL
use_ComplEx = False

# Use CoAttentionLayerImproved instead of CoAttentionLayer
use_improved_CoAttention = False

# Load the explicit-valence preprocessing module instead of the default one
use_explicit_valence = False

# Number of GAT layers
num_GAT_layers = 4

# Number of attention heads in each GAT layer
num_GAT_multiheads = 2

#################################

In [5]:
# Select the preprocessing module that matches the atom-feature set in use.
# Both branches import the same three names, so the rest of the notebook is
# unaffected by the choice.
if use_explicit_valence:
    from data_preprocessing_explicit_valence import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS
else:
    from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [6]:
# Run mode; the model-construction and training cells only execute for "train".
mode = "train"

# Width of the per-atom input feature vector, as exported by the preprocessing module
n_atom_feats = TOTAL_ATOM_FEATS
# Passed to SSI_DDI as hidd_dim; no layer in this notebook is sized from it.
n_atom_hid = 64
# Total number of interaction types listed in Interaction_information.csv
rel_total = 86
# Adam learning rate and weight decay
lr = 1e-2
weight_decay = 5e-4
# Forwarded to the training DrugDataset as neg_ent
neg_samples = 1
# Number of samples (graph instances) per training batch; the validation and
# test loaders use three times this value.
batch_size = 1024
# Forwarded to DrugDataset as ratio
data_size_ratio = 1
# Embedding width shared by the co-attention layer and the KGE scorer
kge_dim = 64

# Fall back to the CPU when CUDA is unavailable or use_cuda is False
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 55


In [7]:
def print_tunning_parameters():
    """Print the notebook-level tuning switches between two banner lines.

    Related settings are printed together; every group is preceded by a blank
    line. The values are read from the module-level configuration variables at
    call time.
    """
    groups = (
        (("n_epochs =", n_epochs), ("use_cuda =", use_cuda)),
        (
            ("num_GAT_layers = ", num_GAT_layers),
            ("num_GAT_multiheads = ", num_GAT_multiheads),
        ),
        (("sp_ratio =", sp_ratio), ("sp_min_score =", sp_min_score)),
        (("use_explicit_valence =", use_explicit_valence),),
        (("use_activation_fn =", use_activation_fn),),
        (("use_ComplEx =", use_ComplEx),),
        (("use_improved_CoAttention =", use_improved_CoAttention),),
    )

    print()
    print("####### Tunning parameters #######")
    for group in groups:
        print()
        for label, value in group:
            print(label, value)
    print()
    print("#################################")
    print()


In [8]:
class CoAttentionLayer(nn.Module):
    """Additive co-attention between two stacks of graph-level embeddings.

    Both inputs are projected to ``n_features // 2`` dimensions, every
    (receiver row, attendant row) pair is combined by addition plus a bias, and
    the result is reduced to one scalar per pair with the vector ``a``.

    Args:
        n_features: size of the last dimension of both inputs.
        use_activation_fn: when True, ``tanh`` is applied to the pairwise
            activations before they are reduced with ``a``.
    """

    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        half = n_features // 2
        self.n_features = n_features
        self.w_q = nn.Parameter(torch.zeros(n_features, half))
        self.w_k = nn.Parameter(torch.zeros(n_features, half))
        self.bias = nn.Parameter(torch.zeros(half))
        self.a = nn.Parameter(torch.zeros(half))
        self.use_activation_fn = use_activation_fn

        # Xavier initialisation needs at least two dimensions, so the two
        # vectors are initialised through a column-shaped view that shares
        # their storage.
        for weight in (
            self.w_q,
            self.w_k,
            self.bias.unsqueeze(-1),
            self.a.unsqueeze(-1),
        ):
            nn.init.xavier_uniform_(weight)

    def forward(self, receiver, attendant):
        """Return the raw (un-normalised) pairwise attention scores.

        Args:
            receiver: tensor of shape ``(..., R, n_features)``,
                e.g. ``(1024, 4, 64)``.
            attendant: tensor of shape ``(..., A, n_features)``.

        Returns:
            Tensor of shape ``(..., R, A)``; entry ``[i, j]`` scores receiver
            row ``i`` against attendant row ``j``.
        """
        projected_keys = torch.matmul(receiver, self.w_k)
        projected_queries = torch.matmul(attendant, self.w_q)

        # Broadcast to every (receiver, attendant) pair: (..., R, A, n_features // 2)
        pairwise = (
            projected_queries.unsqueeze(-3) + projected_keys.unsqueeze(-2) + self.bias
        )
        if self.use_activation_fn:
            pairwise = torch.tanh(pairwise)

        return pairwise @ self.a

class CoAttentionLayerImproved(nn.Module):
    """Multi-head variant of the additive co-attention layer.

    The feature dimension is split into ``n_heads`` chunks of ``head_dim``
    features. All heads share one pair of projections that map each chunk to
    ``head_dim // 2`` features; the per-head projections are concatenated again
    and scored exactly like in ``CoAttentionLayer``.

    Args:
        n_features: size of the last dimension of both inputs; must be
            divisible by ``n_heads``.
        use_activation_fn: when True, ``tanh`` is applied to the pairwise
            activations before they are reduced with ``a``.
        dropout: probability used to build ``self.dropout``. The module is
            constructed but is not applied anywhere in ``forward``.
        n_heads: number of chunks the feature dimension is split into.

    Raises:
        ValueError: if ``n_features`` cannot be split evenly into ``n_heads``
            chunks of at least two features each.
    """

    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        if n_heads < 1 or n_features % n_heads != 0 or n_features // n_heads < 2:
            raise ValueError(
                f"n_features ({n_features}) must be divisible by n_heads ({n_heads}) "
                "with at least 2 features per head"
            )
        self.n_features = n_features
        self.n_heads = n_heads
        self.head_dim = n_features // n_heads

        proj_dim = self.head_dim // 2
        # Width of the concatenated per-head projections. This equals
        # n_features // 2 for the default two-head configuration, but is
        # computed explicitly so that other head counts line up as well.
        self.attn_dim = n_heads * proj_dim

        # Query/key projections, shared by all heads
        self.w_q = nn.Parameter(torch.zeros(self.head_dim, proj_dim))
        self.w_k = nn.Parameter(torch.zeros(self.head_dim, proj_dim))
        self.bias = nn.Parameter(torch.zeros(self.attn_dim))
        self.a = nn.Parameter(torch.zeros(self.attn_dim))
        self.use_activation_fn = use_activation_fn

        # Kept so the constructor signature and module tree stay unchanged;
        # forward() does not call it.
        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        # Xavier init needs >= 2 dims, hence the column-shaped views.
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        """Return the raw pairwise attention scores.

        Args:
            receiver: tensor of shape ``(batch, G, n_features)``,
                e.g. ``(1024, 4, 64)``.
            attendant: tensor of the same shape as ``receiver``.

        Returns:
            Tensor of shape ``(batch, G, G)``.
        """
        batch_size, gat_size, _ = receiver.shape

        # Split the feature axis into heads: (batch, G, n_heads, head_dim)
        receiver = receiver.reshape(batch_size, gat_size, self.n_heads, self.head_dim)
        attendant = attendant.reshape(batch_size, gat_size, self.n_heads, self.head_dim)

        # Project every head: (batch, G, n_heads, head_dim // 2)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q

        # Concatenate the heads again: (batch, G, n_heads * (head_dim // 2)).
        # The flattened width must come from the projection size, not from
        # head_dim; the two coincide only when n_heads == 2.
        keys = keys.reshape(batch_size, gat_size, self.attn_dim)
        queries = queries.reshape(batch_size, gat_size, self.attn_dim)

        # Every (receiver, attendant) pair: (batch, G, G, attn_dim)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_activations = torch.tanh(e_activations)

        # Reduce the feature axis: (batch, G, G)
        return e_activations @ self.a


class RESCAL(nn.Module):
    """Bilinear relation scorer with one full matrix per relation.

    Args:
        n_rels: number of relation types (86 in this notebook).
        n_features: embedding width of heads and tails (kge_dim = 64 here).
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # One flattened (n_features x n_features) matrix per relation,
        # Xavier-initialised.
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores):
        """Score every sample in the batch.

        Args:
            heads: ``(batch, G, n_features)`` head representations.
            tails: ``(batch, G, n_features)`` tail representations.
            rels: ``(batch,)`` integer relation indices.
            alpha_scores: ``(batch, G, G)`` pairwise weights, or None to sum
                the pairwise scores unweighted.

        Returns:
            ``(batch,)`` tensor with one score per sample.
        """
        dim = self.n_features

        # L2-normalise the flattened relation rows, then view them as matrices:
        # (batch, dim * dim) -> (batch, dim, dim)
        rel_matrices = F.normalize(self.rel_emb(rels), dim=-1).view(-1, dim, dim)
        unit_heads = F.normalize(heads, dim=-1)
        unit_tails = F.normalize(tails, dim=-1)

        # (batch, G, dim) @ (batch, dim, dim) @ (batch, dim, G) -> (batch, G, G)
        pair_scores = unit_heads @ rel_matrices @ unit_tails.transpose(-2, -1)

        if alpha_scores is not None:
            pair_scores = alpha_scores * pair_scores

        # Collapse the two pairwise axes into one scalar per sample.
        return pair_scores.sum(dim=(-2, -1))

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    """Relation scorer that treats embeddings as complex-valued.

    The first half of the feature axis is used as the real part and the second
    half as the imaginary part. Each relation owns a real and an imaginary
    ``(n_features // 2) x (n_features // 2)`` matrix.

    Args:
        n_rels: number of relation types.
        n_features: embedding width of heads and tails (split in two halves).
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features

        half = self.n_features // 2
        # Real and imaginary relation matrices, stored flattened
        self.rel_real = nn.Embedding(self.n_rels, half * half)
        self.rel_imag = nn.Embedding(self.n_rels, half * half)

        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        """Score every sample in the batch.

        Args:
            heads: ``(batch, G, n_features)`` head representations.
            tails: ``(batch, G, n_features)`` tail representations.
            rels: ``(batch,)`` integer relation indices.
            alpha_scores: optional ``(batch, G, G)`` pairwise weights.

        Returns:
            ``(batch,)`` tensor with one score per sample.
        """
        half = self.n_features // 2

        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)

        # Normalise the flattened relation rows, then view them as matrices.
        r_real = F.normalize(self.rel_real(rels), dim=-1).view(-1, half, half)
        r_imag = F.normalize(self.rel_imag(rels), dim=-1).view(-1, half, half)

        # Split heads and tails into real / imaginary halves. The imaginary
        # part of the tail must come from `tails`; it was previously sliced
        # from `heads`, so half of every tail embedding was ignored.
        h_real, h_imag = heads[..., :half], heads[..., half:]
        t_real, t_imag = tails[..., :half], tails[..., half:]

        t_real_T = t_real.transpose(-2, -1)
        t_imag_T = t_imag.transpose(-2, -1)

        # NOTE(review): canonical ComplEx, Re(<h, r, conj(t)>), subtracts the
        # (imag, imag, real) term. The all-positive sum is kept as written so
        # existing results stay comparable; confirm whether that is intended.
        scores = (
            h_real @ r_real @ t_real_T
            + h_real @ r_imag @ t_imag_T
            + h_imag @ r_real @ t_imag_T
            + h_imag @ r_imag @ t_real_T
        )

        if alpha_scores is not None:
            scores = alpha_scores * scores

        # Collapse the two pairwise axes into one scalar per sample.
        return scores.sum(dim=(-2, -1))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


In [9]:
class SSI_DDI(nn.Module):
    """Stack of GAT blocks followed by co-attention and a KGE scorer.

    Each block produces one graph-level embedding per input graph. The
    embeddings of all blocks are stacked for the head and the tail drug,
    weighted pairwise by the co-attention layer and scored by RESCAL or
    ComplEx.

    Args:
        in_features: number of input features per node.
        hidd_dim: stored and forwarded to every block as ``final_out_feats``.
        kge_dim: feature width expected by the co-attention layer and scorer.
        rel_total: number of relation types.
        heads_out_feat_params: per-block output width of a single GAT head.
        blocks_params: per-block number of GAT heads.
        sp_ratio: SAGPooling ratio forwarded to every block.
        use_activation_fn: forwarded to the co-attention layer.
        use_ComplEx: score with ComplEx when True, RESCAL otherwise.
        sp_min_score: SAGPooling min_score forwarded to every block.
        use_improved_CoAttention: use CoAttentionLayerImproved when True.
    """

    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        use_improved_CoAttention,
    ):
        super().__init__()
        self.in_features = in_features
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        # Plain list used for iteration; each block is also registered as a
        # submodule under the name "block<i>".
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # One layer normalisation per block output
        self.net_norms = ModuleList()

        block_in = in_features
        for idx, (head_out_feats, n_heads) in enumerate(
            zip(heads_out_feat_params, blocks_params)
        ):
            block_out = head_out_feats * n_heads
            block = SSI_DDI_Block(
                n_heads,
                block_in,
                head_out_feats,
                final_out_feats=self.hidd_dim,
                sp_ratio=sp_ratio,
                sp_min_score=sp_min_score,
            )
            self.add_module(f"block{idx}", block)
            self.blocks.append(block)
            self.net_norms.append(LayerNorm(block_out))
            # The next block consumes the concatenated heads of this one.
            block_in = block_out

        attention_cls = (
            CoAttentionLayerImproved if use_improved_CoAttention else CoAttentionLayer
        )
        self.co_attention = attention_cls(self.kge_dim, self.use_activation_fn)

        scorer_cls = ComplEx if self.use_ComplEx else RESCAL
        self.KGE = scorer_cls(self.rel_total, self.kge_dim)

    def forward(self, triples):
        """Score a batch of (head graphs, tail graphs, relations).

        Args:
            triples: sequence ``(h_data, t_data, rels)``. The ``x`` attribute
                of both graph batches is overwritten in place.

        Returns:
            The scorer output, one value per sample.
        """
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        head_embeddings, tail_embeddings = [], []
        for norm, block in zip(self.net_norms, self.blocks):
            h_data, pooled_h = block(h_data)
            t_data, pooled_t = block(t_data)

            head_embeddings.append(pooled_h)
            tail_embeddings.append(pooled_t)

            # Normalise and activate the node features fed to the next block.
            h_data.x = F.elu(norm(h_data.x, h_data.batch))
            t_data.x = F.elu(norm(t_data.x, t_data.batch))

        # (batch, n_blocks, features)
        kge_heads = torch.stack(head_embeddings, dim=-2)
        kge_tails = torch.stack(tail_embeddings, dim=-2)

        attentions = self.co_attention(kge_heads, kge_tails)
        return self.KGE(kge_heads, kge_tails, rels, attentions)


class SSI_DDI_Block(nn.Module):
    """One GAT convolution followed by a SAGPooling read-out.

    Args:
        n_heads: number of GAT attention heads (2 in this notebook).
        in_features: number of input features per node (55, or 56 with the
            explicit-valence feature, for the first block).
        head_out_feats: output width of a single GAT head.
        final_out_feats: accepted for interface compatibility; not used.
        sp_ratio: SAGPooling ratio, or None.
        sp_min_score: SAGPooling min_score, or None.
    """

    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)

        # SAGPooling ranks nodes by self-attention score. Its input width is
        # the concatenation of all GAT heads.
        pooled_dim = n_heads * head_out_feats
        if sp_ratio is None and sp_min_score is None:
            self.readout = SAGPooling(pooled_dim, min_score=-1)
        elif sp_ratio is not None:
            self.readout = SAGPooling(pooled_dim, min_score=sp_min_score, ratio=sp_ratio)
        else:
            self.readout = SAGPooling(pooled_dim, min_score=sp_min_score)

    def forward(self, data):
        """Convolve the graph batch and pool it into one embedding per graph.

        Args:
            data: graph batch with ``x``, ``edge_index`` and ``batch``
                attributes; ``data.x`` is overwritten with the GAT output.

        Returns:
            Tuple ``(data, global_graph_emb)``: the updated batch and the
            sum-pooled embedding of the nodes kept by SAGPooling.
        """
        data.x = self.conv(data.x, data.edge_index)

        # Only the pooled node features and their graph assignment are needed
        # from the six values SAGPooling returns.
        att_x, _, _, att_batch, _, _ = self.readout(
            data.x, data.edge_index, batch=data.batch
        )

        # Sum the retained node features per graph.
        global_graph_emb = global_add_pool(att_x, att_batch)

        return data, global_graph_emb


In [10]:
class SigmoidLoss(nn.Module):
    """Log-sigmoid loss over positive and negative scores.

    Args:
        adv_temperature: when truthy, negative scores are re-weighted by a
            detached softmax of ``adv_temperature * n_scores`` before the loss
            is computed.
    """

    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        """Return ``(mean of both parts, positive part, negative part)``."""
        if self.adv_temperature:
            # The weights are detached so no gradient flows through them.
            scaled = self.adv_temperature * n_scores
            n_scores = F.softmax(scaled, dim=-1).detach() * n_scores

        positive_part = F.logsigmoid(p_scores).mean().neg()
        negative_part = F.logsigmoid(-n_scores).mean().neg()
        combined = (positive_part + negative_part) / 2

        return combined, positive_part, negative_part


In [11]:
# Read the three pre-split interaction tables.
df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")

# Turn every split into a list of (h, t, r) tuples taken from the
# d1, d2 and type columns.
train_tup = list(zip(df_ddi_train["d1"], df_ddi_train["d2"], df_ddi_train["type"]))
val_tup = list(zip(df_ddi_val["d1"], df_ddi_val["d2"], df_ddi_val["type"]))
test_tup = list(zip(df_ddi_test["d1"], df_ddi_test["d2"], df_ddi_test["type"]))

train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

# Only the training loader shuffles; evaluation loaders use larger batches.
eval_batch_size = batch_size * 3
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=eval_batch_size)
test_data_loader = DrugDataLoader(test_data, batch_size=eval_batch_size)


Training with 115185 samples, validating with 38348, and testing with 38337


In [12]:
def do_compute(model, batch, device, training=True):
    """Run the model on the positive and the negative half of a batch.

    Args:
        model: callable that maps a list of tensors to a 1-D score tensor.
        batch: pair ``(pos_tri, neg_tri)``; each is a sequence
            ``(batch_h, batch_t, batch_r)`` of objects exposing ``.to``.
        device: device every element of both triples is moved to.
        training: accepted for interface compatibility; not used.

    Returns:
        Tuple ``(p_score, n_score, probas_pred, ground_truth)``: the raw score
        tensors, followed by NumPy arrays holding the sigmoid probabilities
        (positives first) and the matching 1/0 labels.
    """
    pos_tri, neg_tri = batch

    def _score(triple):
        # Move every element of the triple to the target device, then score.
        return model([tensor.to(device=device) for tensor in triple])

    p_score = _score(pos_tri)
    n_score = _score(neg_tri)

    # Detached so the metric bookkeeping never holds on to the autograd graph.
    probas_pred = np.concatenate(
        [
            torch.sigmoid(p_score.detach()).cpu(),
            torch.sigmoid(n_score.detach()).cpu(),
        ]
    )
    ground_truth = np.concatenate(
        [np.ones(len(p_score)), np.zeros(len(n_score))]
    )

    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target):
    """Compute accuracy, ROC-AUC and PR-AUC for binary predictions.

    Args:
        probas_pred: NumPy array of predicted probabilities.
        target: array of ground-truth labels, aligned with ``probas_pred``.

    Returns:
        Tuple ``(acc, auc_roc, auc_prc)``. Accuracy uses a 0.5 decision
        threshold; both AUC values are computed from the raw probabilities.
    """
    pred = (probas_pred >= 0.5).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)

    # Area under the precision-recall curve (recall on the x axis). The
    # thresholds returned alongside are not needed. An F1 score used to be
    # computed here and thrown away; it was never part of the return value.
    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(recall, precision)

    return acc, auc_roc, auc_prc

In [13]:
import csv
def export_metrics(train_metrics, val_metrics, epoch):
    """Append one epoch of metrics to ``train_metrics/<model_name>.csv``.

    Args:
        train_metrics: ``(loss, acc, auc_roc, auc_prc)`` of the training pass.
        val_metrics: ``(loss, acc, auc_roc, auc_prc)`` of the validation pass.
        epoch: 1-based epoch number. Epoch 1 starts a fresh file; later
            epochs append to the existing one.

    Raises:
        ValueError: if either metrics tuple does not hold exactly four values.
    """
    import os  # local import: keeps this cell self-contained

    train_metrics_dir = "train_metrics"
    # The output folder may not exist on a fresh checkout; without this the
    # first epoch would finish and then fail on open().
    os.makedirs(train_metrics_dir, exist_ok=True)
    metrics_file = f"{train_metrics_dir}/{model_name}.csv"

    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]
    data = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]

    start_new_file = epoch == 1
    # Also emit the header when appending to a file that does not exist yet,
    # so the CSV is never left without column names.
    write_header = start_new_file or not os.path.exists(metrics_file)

    with open(metrics_file, "w" if start_new_file else "a", newline="") as file:
        writer = csv.writer(file)
        if write_header:
            writer.writerow(header)
        writer.writerow(data)
    
    

In [14]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
    model_path=None,
):
    """Train ``model`` and keep the checkpoint with the best validation PR-AUC.

    Every epoch runs one optimisation pass over ``train_data_loader`` and one
    gradient-free pass over ``val_data_loader``, appends the metrics of both
    to the metrics CSV via ``export_metrics`` and prints a summary.

    Args:
        model: the network to optimise; called through ``do_compute``.
        train_data_loader: iterable of training batches.
        val_data_loader: iterable of validation batches.
        loss_fn: callable ``(p_score, n_score) -> (loss, pos_loss, neg_loss)``.
        optimizer: optimiser stepping the parameters of ``model``.
        n_epochs: number of epochs to run.
        device: device every batch is moved to.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        model_path: where to save the best model. Defaults to the module-level
            ``model_file`` when None.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()
    best_val_auc_prc = 0
    for i in range(1, n_epochs + 1):
        start = time.time()

        # ---- optimisation pass -------------------------------------------
        model.train()
        train_loss = 0.0
        train_samples = 0
        train_probas_pred = []
        train_ground_truth = []

        for batch in train_data_loader:
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, _, _ = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight each batch mean by its size and normalise by the number
            # of samples this loader actually produced, instead of relying on
            # the length of a global dataset object.
            train_loss += loss.item() * len(p_score)
            train_samples += len(p_score)
        train_loss /= max(train_samples, 1)

        train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
            np.concatenate(train_probas_pred), np.concatenate(train_ground_truth)
        )

        # ---- validation pass ---------------------------------------------
        model.eval()
        val_loss = 0.0
        val_samples = 0
        val_probas_pred = []
        val_ground_truth = []

        with torch.no_grad():
            for batch in val_data_loader:
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, _, _ = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)
                val_samples += len(p_score)
        val_loss /= max(val_samples, 1)

        val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
            np.concatenate(val_probas_pred), np.concatenate(val_ground_truth)
        )

        # Save model if this is best result based on val_auc_prc.
        # torch.save(model, ...) pickles the whole module, not just its weights.
        if best_val_auc_prc < val_auc_prc:
            print("Saving model")
            best_val_auc_prc = val_auc_prc
            torch.save(model, model_file if model_path is None else model_path)

        if scheduler is not None:
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)

        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model

In [15]:
def predict(model, test_data_loader, device):
    """Evaluate ``model`` on the test loader and print the resulting metrics.

    Args:
        model: trained network; switched to evaluation mode here.
        test_data_loader: iterable of test batches.
        device: device every batch is moved to.
    """
    print('Starting predicting at', datetime.today())
    print('Device', device)

    # Inference only: evaluation mode, and no gradients are tracked below.
    model.eval()

    batch_probas = []
    batch_truth = []
    with torch.no_grad():
        for batch in test_data_loader:
            # The raw scores are not needed here, only probabilities and labels.
            _, _, probas_pred, ground_truth = do_compute(model, batch, device, training=False)
            batch_probas.append(probas_pred)
            batch_truth.append(ground_truth)

    # Metrics are computed once over the whole test set.
    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(
        np.concatenate(batch_probas), np.concatenate(batch_truth)
    )

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

In [16]:
# Checkpoint written by train() whenever the validation PR-AUC improves.
model_file = f"{model_dir}/{model_name}.pth"

# Every GAT block concatenates its heads, and the co-attention layer and the
# KGE scorer both expect kge_dim features. The per-head width therefore has to
# be kge_dim / num_GAT_multiheads; hard-coding kge_dim // 2 only works for
# exactly two heads.
if num_GAT_multiheads < 1 or kge_dim % num_GAT_multiheads != 0:
    raise ValueError(
        f"kge_dim ({kge_dim}) must be divisible by num_GAT_multiheads ({num_GAT_multiheads})"
    )

heads_out_feat_params = [kge_dim // num_GAT_multiheads] * num_GAT_layers
block_params = [num_GAT_multiheads] * num_GAT_layers

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        use_improved_CoAttention=use_improved_CoAttention,
    )
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Multiply the base learning rate by 0.96 ** epoch.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)
    model.to(device=device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayer()
  (KGE): RESCAL(86, torch.Size([86, 4096]))
)


In [17]:
if mode == "train":
    # Launch training with the objects built in the previous cell.
    train(
        model=model,
        train_data_loader=train_data_loader,
        val_data_loader=val_data_loader,
        loss_fn=loss,
        optimizer=optimizer,
        n_epochs=n_epochs,
        device=device,
        scheduler=scheduler,
    )


Starting training at: 2024-10-23 19:40:25.998659
Device: cuda

####### Tunning parameters #######

n_epochs = 300
use_cuda = True

num_GAT_layers =  4
num_GAT_multiheads =  2

sp_ratio = None
sp_min_score = None

use_explicit_valence = False

use_activation_fn = False

use_ComplEx = False

use_improved_CoAttention = False

#################################





Saving model
Epoch: 1 (61.2847s), train_loss: 0.6834, val_loss: 0.6447, train_acc: 0.5616, val_acc:0.6115
		train_roc: 0.5907, val_roc: 0.6634, train_auprc: 0.5759, val_auprc: 0.6486




Saving model
Epoch: 2 (59.2862s), train_loss: 0.6172, val_loss: 0.5943, train_acc: 0.6490, val_acc:0.6769
		train_roc: 0.7097, val_roc: 0.7430, train_auprc: 0.6891, val_auprc: 0.7172




Saving model
Epoch: 3 (58.8625s), train_loss: 0.5727, val_loss: 0.5619, train_acc: 0.6935, val_acc:0.6985
		train_roc: 0.7647, val_roc: 0.7753, train_auprc: 0.7439, val_auprc: 0.7553




Saving model
Epoch: 4 (59.0640s), train_loss: 0.5464, val_loss: 0.5448, train_acc: 0.7167, val_acc:0.7123
		train_roc: 0.7915, val_roc: 0.7976, train_auprc: 0.7706, val_auprc: 0.7804




Saving model
Epoch: 5 (59.2327s), train_loss: 0.5287, val_loss: 0.5208, train_acc: 0.7302, val_acc:0.7374
		train_roc: 0.8072, val_roc: 0.8148, train_auprc: 0.7862, val_auprc: 0.7968




Saving model
Epoch: 6 (58.8431s), train_loss: 0.5109, val_loss: 0.5086, train_acc: 0.7441, val_acc:0.7480
		train_roc: 0.8224, val_roc: 0.8250, train_auprc: 0.8014, val_auprc: 0.8049




Saving model
Epoch: 7 (59.1233s), train_loss: 0.5008, val_loss: 0.5022, train_acc: 0.7511, val_acc:0.7527
		train_roc: 0.8303, val_roc: 0.8320, train_auprc: 0.8103, val_auprc: 0.8129




Saving model
Epoch: 8 (59.1157s), train_loss: 0.4917, val_loss: 0.4930, train_acc: 0.7585, val_acc:0.7596
		train_roc: 0.8377, val_roc: 0.8381, train_auprc: 0.8174, val_auprc: 0.8164




Saving model
Epoch: 9 (59.3526s), train_loss: 0.4832, val_loss: 0.4888, train_acc: 0.7646, val_acc:0.7622
		train_roc: 0.8442, val_roc: 0.8403, train_auprc: 0.8240, val_auprc: 0.8215




Saving model
Epoch: 10 (59.2044s), train_loss: 0.4760, val_loss: 0.4703, train_acc: 0.7710, val_acc:0.7748
		train_roc: 0.8495, val_roc: 0.8541, train_auprc: 0.8295, val_auprc: 0.8349




Saving model
Epoch: 11 (60.4015s), train_loss: 0.4670, val_loss: 0.4682, train_acc: 0.7765, val_acc:0.7780
		train_roc: 0.8558, val_roc: 0.8564, train_auprc: 0.8363, val_auprc: 0.8361




Saving model
Epoch: 12 (60.3831s), train_loss: 0.4583, val_loss: 0.4624, train_acc: 0.7822, val_acc:0.7790
		train_roc: 0.8617, val_roc: 0.8605, train_auprc: 0.8418, val_auprc: 0.8405




Saving model
Epoch: 13 (60.9692s), train_loss: 0.4537, val_loss: 0.4570, train_acc: 0.7853, val_acc:0.7810
		train_roc: 0.8647, val_roc: 0.8627, train_auprc: 0.8456, val_auprc: 0.8421




Saving model
Epoch: 14 (60.3566s), train_loss: 0.4483, val_loss: 0.4542, train_acc: 0.7893, val_acc:0.7840
		train_roc: 0.8684, val_roc: 0.8648, train_auprc: 0.8492, val_auprc: 0.8473




Saving model
Epoch: 15 (60.9847s), train_loss: 0.4423, val_loss: 0.4400, train_acc: 0.7924, val_acc:0.7933
		train_roc: 0.8723, val_roc: 0.8744, train_auprc: 0.8541, val_auprc: 0.8578




Saving model
Epoch: 16 (60.7989s), train_loss: 0.4366, val_loss: 0.4402, train_acc: 0.7960, val_acc:0.7949
		train_roc: 0.8759, val_roc: 0.8747, train_auprc: 0.8580, val_auprc: 0.8580




Saving model
Epoch: 17 (58.8734s), train_loss: 0.4321, val_loss: 0.4322, train_acc: 0.7995, val_acc:0.8007
		train_roc: 0.8787, val_roc: 0.8791, train_auprc: 0.8600, val_auprc: 0.8631




Epoch: 18 (59.0435s), train_loss: 0.4248, val_loss: 0.4312, train_acc: 0.8035, val_acc:0.8028
		train_roc: 0.8829, val_roc: 0.8807, train_auprc: 0.8651, val_auprc: 0.8624




Saving model
Epoch: 19 (58.9341s), train_loss: 0.4227, val_loss: 0.4258, train_acc: 0.8075, val_acc:0.8027
		train_roc: 0.8843, val_roc: 0.8836, train_auprc: 0.8651, val_auprc: 0.8678




Epoch: 20 (58.8556s), train_loss: 0.4170, val_loss: 0.4212, train_acc: 0.8106, val_acc:0.8080
		train_roc: 0.8880, val_roc: 0.8852, train_auprc: 0.8701, val_auprc: 0.8670




Saving model
Epoch: 21 (58.9545s), train_loss: 0.4130, val_loss: 0.4177, train_acc: 0.8129, val_acc:0.8117
		train_roc: 0.8898, val_roc: 0.8887, train_auprc: 0.8723, val_auprc: 0.8716




Saving model
Epoch: 22 (59.3516s), train_loss: 0.4090, val_loss: 0.4135, train_acc: 0.8151, val_acc:0.8129
		train_roc: 0.8924, val_roc: 0.8901, train_auprc: 0.8743, val_auprc: 0.8733




Saving model
Epoch: 23 (58.9762s), train_loss: 0.4054, val_loss: 0.4105, train_acc: 0.8169, val_acc:0.8155
		train_roc: 0.8941, val_roc: 0.8916, train_auprc: 0.8768, val_auprc: 0.8748




Saving model
Epoch: 24 (58.8790s), train_loss: 0.3993, val_loss: 0.4047, train_acc: 0.8217, val_acc:0.8186
		train_roc: 0.8974, val_roc: 0.8957, train_auprc: 0.8798, val_auprc: 0.8801




Saving model
Epoch: 25 (58.8835s), train_loss: 0.3952, val_loss: 0.3990, train_acc: 0.8237, val_acc:0.8227
		train_roc: 0.8998, val_roc: 0.8985, train_auprc: 0.8824, val_auprc: 0.8825




Epoch: 26 (58.7820s), train_loss: 0.3920, val_loss: 0.4003, train_acc: 0.8249, val_acc:0.8245
		train_roc: 0.9014, val_roc: 0.8986, train_auprc: 0.8851, val_auprc: 0.8808




Saving model
Epoch: 27 (58.9855s), train_loss: 0.3875, val_loss: 0.3972, train_acc: 0.8282, val_acc:0.8241
		train_roc: 0.9041, val_roc: 0.8991, train_auprc: 0.8874, val_auprc: 0.8825




Saving model
Epoch: 28 (58.9819s), train_loss: 0.3822, val_loss: 0.3918, train_acc: 0.8306, val_acc:0.8244
		train_roc: 0.9064, val_roc: 0.9027, train_auprc: 0.8902, val_auprc: 0.8875




Saving model
Epoch: 29 (58.7690s), train_loss: 0.3799, val_loss: 0.3876, train_acc: 0.8328, val_acc:0.8293
		train_roc: 0.9073, val_roc: 0.9046, train_auprc: 0.8906, val_auprc: 0.8892




Saving model
Epoch: 30 (59.0532s), train_loss: 0.3745, val_loss: 0.3861, train_acc: 0.8355, val_acc:0.8288
		train_roc: 0.9101, val_roc: 0.9052, train_auprc: 0.8940, val_auprc: 0.8905




Saving model
Epoch: 31 (58.8678s), train_loss: 0.3718, val_loss: 0.3810, train_acc: 0.8369, val_acc:0.8329
		train_roc: 0.9115, val_roc: 0.9074, train_auprc: 0.8959, val_auprc: 0.8918




Epoch: 32 (59.0130s), train_loss: 0.3690, val_loss: 0.3817, train_acc: 0.8388, val_acc:0.8331
		train_roc: 0.9128, val_roc: 0.9077, train_auprc: 0.8969, val_auprc: 0.8916




Saving model
Epoch: 33 (59.1157s), train_loss: 0.3667, val_loss: 0.3750, train_acc: 0.8398, val_acc:0.8367
		train_roc: 0.9143, val_roc: 0.9109, train_auprc: 0.8980, val_auprc: 0.8954




Epoch: 34 (58.8846s), train_loss: 0.3635, val_loss: 0.3790, train_acc: 0.8419, val_acc:0.8336
		train_roc: 0.9153, val_roc: 0.9097, train_auprc: 0.8998, val_auprc: 0.8953




Saving model
Epoch: 35 (58.9889s), train_loss: 0.3583, val_loss: 0.3706, train_acc: 0.8441, val_acc:0.8380
		train_roc: 0.9178, val_roc: 0.9128, train_auprc: 0.9029, val_auprc: 0.8962




Saving model
Epoch: 36 (58.8492s), train_loss: 0.3580, val_loss: 0.3669, train_acc: 0.8452, val_acc:0.8414
		train_roc: 0.9181, val_roc: 0.9152, train_auprc: 0.9024, val_auprc: 0.9015




Epoch: 37 (58.9307s), train_loss: 0.3532, val_loss: 0.3646, train_acc: 0.8475, val_acc:0.8429
		train_roc: 0.9205, val_roc: 0.9163, train_auprc: 0.9054, val_auprc: 0.9013




Saving model
Epoch: 38 (58.9253s), train_loss: 0.3503, val_loss: 0.3637, train_acc: 0.8493, val_acc:0.8422
		train_roc: 0.9213, val_roc: 0.9161, train_auprc: 0.9059, val_auprc: 0.9017




Saving model
Epoch: 39 (58.7295s), train_loss: 0.3451, val_loss: 0.3631, train_acc: 0.8522, val_acc:0.8448
		train_roc: 0.9240, val_roc: 0.9169, train_auprc: 0.9090, val_auprc: 0.9028




Saving model
Epoch: 40 (58.8466s), train_loss: 0.3438, val_loss: 0.3551, train_acc: 0.8530, val_acc:0.8480
		train_roc: 0.9245, val_roc: 0.9207, train_auprc: 0.9095, val_auprc: 0.9069




Epoch: 41 (58.8583s), train_loss: 0.3416, val_loss: 0.3594, train_acc: 0.8552, val_acc:0.8464
		train_roc: 0.9253, val_roc: 0.9186, train_auprc: 0.9098, val_auprc: 0.9042




Saving model
Epoch: 42 (58.9073s), train_loss: 0.3379, val_loss: 0.3548, train_acc: 0.8562, val_acc:0.8492
		train_roc: 0.9268, val_roc: 0.9209, train_auprc: 0.9122, val_auprc: 0.9069




Saving model
Epoch: 43 (58.8371s), train_loss: 0.3360, val_loss: 0.3470, train_acc: 0.8567, val_acc:0.8525
		train_roc: 0.9277, val_roc: 0.9244, train_auprc: 0.9132, val_auprc: 0.9112




Epoch: 44 (59.0190s), train_loss: 0.3326, val_loss: 0.3486, train_acc: 0.8594, val_acc:0.8523
		train_roc: 0.9292, val_roc: 0.9228, train_auprc: 0.9147, val_auprc: 0.9089




Saving model
Epoch: 45 (58.9349s), train_loss: 0.3282, val_loss: 0.3452, train_acc: 0.8611, val_acc:0.8544
		train_roc: 0.9310, val_roc: 0.9252, train_auprc: 0.9173, val_auprc: 0.9116




Epoch: 46 (58.8723s), train_loss: 0.3271, val_loss: 0.3461, train_acc: 0.8623, val_acc:0.8538
		train_roc: 0.9316, val_roc: 0.9246, train_auprc: 0.9176, val_auprc: 0.9114




Saving model
Epoch: 47 (58.9873s), train_loss: 0.3239, val_loss: 0.3421, train_acc: 0.8639, val_acc:0.8541
		train_roc: 0.9330, val_roc: 0.9263, train_auprc: 0.9194, val_auprc: 0.9151




Epoch: 48 (58.8841s), train_loss: 0.3239, val_loss: 0.3455, train_acc: 0.8638, val_acc:0.8533
		train_roc: 0.9328, val_roc: 0.9246, train_auprc: 0.9185, val_auprc: 0.9119




Epoch: 49 (58.9738s), train_loss: 0.3189, val_loss: 0.3407, train_acc: 0.8671, val_acc:0.8549
		train_roc: 0.9349, val_roc: 0.9266, train_auprc: 0.9212, val_auprc: 0.9138




Saving model
Epoch: 50 (58.7963s), train_loss: 0.3172, val_loss: 0.3383, train_acc: 0.8668, val_acc:0.8563
		train_roc: 0.9354, val_roc: 0.9281, train_auprc: 0.9223, val_auprc: 0.9155




Saving model
Epoch: 51 (59.0559s), train_loss: 0.3135, val_loss: 0.3340, train_acc: 0.8695, val_acc:0.8599
		train_roc: 0.9368, val_roc: 0.9297, train_auprc: 0.9236, val_auprc: 0.9171




Saving model
Epoch: 52 (58.9811s), train_loss: 0.3128, val_loss: 0.3339, train_acc: 0.8699, val_acc:0.8589
		train_roc: 0.9373, val_roc: 0.9307, train_auprc: 0.9240, val_auprc: 0.9186




Saving model
Epoch: 53 (59.0474s), train_loss: 0.3105, val_loss: 0.3280, train_acc: 0.8714, val_acc:0.8619
		train_roc: 0.9383, val_roc: 0.9331, train_auprc: 0.9249, val_auprc: 0.9219




Epoch: 54 (58.7674s), train_loss: 0.3059, val_loss: 0.3282, train_acc: 0.8740, val_acc:0.8633
		train_roc: 0.9400, val_roc: 0.9324, train_auprc: 0.9271, val_auprc: 0.9204




Epoch: 55 (59.0544s), train_loss: 0.3033, val_loss: 0.3260, train_acc: 0.8748, val_acc:0.8651
		train_roc: 0.9409, val_roc: 0.9332, train_auprc: 0.9284, val_auprc: 0.9203




Epoch: 56 (58.9247s), train_loss: 0.3027, val_loss: 0.3333, train_acc: 0.8752, val_acc:0.8601
		train_roc: 0.9410, val_roc: 0.9306, train_auprc: 0.9285, val_auprc: 0.9172




Epoch: 57 (59.0027s), train_loss: 0.3024, val_loss: 0.3290, train_acc: 0.8751, val_acc:0.8627
		train_roc: 0.9412, val_roc: 0.9324, train_auprc: 0.9285, val_auprc: 0.9204




Saving model
Epoch: 58 (58.9668s), train_loss: 0.2988, val_loss: 0.3243, train_acc: 0.8780, val_acc:0.8659
		train_roc: 0.9426, val_roc: 0.9342, train_auprc: 0.9297, val_auprc: 0.9226




Saving model
Epoch: 59 (58.8572s), train_loss: 0.2964, val_loss: 0.3229, train_acc: 0.8788, val_acc:0.8662
		train_roc: 0.9432, val_roc: 0.9353, train_auprc: 0.9309, val_auprc: 0.9242




Epoch: 60 (58.9904s), train_loss: 0.2942, val_loss: 0.3287, train_acc: 0.8801, val_acc:0.8634
		train_roc: 0.9443, val_roc: 0.9342, train_auprc: 0.9316, val_auprc: 0.9221




Epoch: 61 (58.9533s), train_loss: 0.2936, val_loss: 0.3269, train_acc: 0.8801, val_acc:0.8637
		train_roc: 0.9444, val_roc: 0.9343, train_auprc: 0.9323, val_auprc: 0.9232




Saving model
Epoch: 62 (59.0760s), train_loss: 0.2900, val_loss: 0.3163, train_acc: 0.8823, val_acc:0.8702
		train_roc: 0.9458, val_roc: 0.9379, train_auprc: 0.9338, val_auprc: 0.9274




Epoch: 63 (58.8219s), train_loss: 0.2883, val_loss: 0.3244, train_acc: 0.8825, val_acc:0.8656
		train_roc: 0.9463, val_roc: 0.9352, train_auprc: 0.9341, val_auprc: 0.9235




Epoch: 64 (58.8678s), train_loss: 0.2863, val_loss: 0.3171, train_acc: 0.8839, val_acc:0.8696
		train_roc: 0.9471, val_roc: 0.9372, train_auprc: 0.9353, val_auprc: 0.9262




Epoch: 65 (58.9624s), train_loss: 0.2843, val_loss: 0.3178, train_acc: 0.8845, val_acc:0.8702
		train_roc: 0.9475, val_roc: 0.9373, train_auprc: 0.9357, val_auprc: 0.9258




Saving model
Epoch: 66 (58.7122s), train_loss: 0.2824, val_loss: 0.3154, train_acc: 0.8848, val_acc:0.8692
		train_roc: 0.9482, val_roc: 0.9385, train_auprc: 0.9369, val_auprc: 0.9282




Epoch: 67 (58.5850s), train_loss: 0.2809, val_loss: 0.3176, train_acc: 0.8872, val_acc:0.8709
		train_roc: 0.9490, val_roc: 0.9373, train_auprc: 0.9373, val_auprc: 0.9248




Saving model
Epoch: 68 (59.1421s), train_loss: 0.2804, val_loss: 0.3103, train_acc: 0.8871, val_acc:0.8737
		train_roc: 0.9492, val_roc: 0.9405, train_auprc: 0.9375, val_auprc: 0.9305




Epoch: 69 (58.9280s), train_loss: 0.2788, val_loss: 0.3151, train_acc: 0.8877, val_acc:0.8723
		train_roc: 0.9495, val_roc: 0.9381, train_auprc: 0.9379, val_auprc: 0.9262




Epoch: 70 (58.6616s), train_loss: 0.2764, val_loss: 0.3100, train_acc: 0.8893, val_acc:0.8743
		train_roc: 0.9502, val_roc: 0.9402, train_auprc: 0.9386, val_auprc: 0.9295




Epoch: 71 (58.8207s), train_loss: 0.2746, val_loss: 0.3121, train_acc: 0.8896, val_acc:0.8728
		train_roc: 0.9508, val_roc: 0.9399, train_auprc: 0.9396, val_auprc: 0.9291




Saving model
Epoch: 72 (58.7783s), train_loss: 0.2730, val_loss: 0.3086, train_acc: 0.8908, val_acc:0.8753
		train_roc: 0.9516, val_roc: 0.9413, train_auprc: 0.9401, val_auprc: 0.9309




Epoch: 73 (58.7036s), train_loss: 0.2730, val_loss: 0.3091, train_acc: 0.8906, val_acc:0.8758
		train_roc: 0.9514, val_roc: 0.9410, train_auprc: 0.9402, val_auprc: 0.9302




Epoch: 74 (58.9100s), train_loss: 0.2711, val_loss: 0.3104, train_acc: 0.8922, val_acc:0.8746
		train_roc: 0.9520, val_roc: 0.9405, train_auprc: 0.9409, val_auprc: 0.9306




Saving model
Epoch: 75 (58.8486s), train_loss: 0.2699, val_loss: 0.3095, train_acc: 0.8923, val_acc:0.8745
		train_roc: 0.9527, val_roc: 0.9410, train_auprc: 0.9414, val_auprc: 0.9315




Epoch: 76 (58.6548s), train_loss: 0.2694, val_loss: 0.3114, train_acc: 0.8934, val_acc:0.8736
		train_roc: 0.9526, val_roc: 0.9405, train_auprc: 0.9415, val_auprc: 0.9302




Saving model
Epoch: 77 (59.0057s), train_loss: 0.2685, val_loss: 0.3083, train_acc: 0.8934, val_acc:0.8750
		train_roc: 0.9531, val_roc: 0.9417, train_auprc: 0.9418, val_auprc: 0.9318




Epoch: 78 (58.7088s), train_loss: 0.2669, val_loss: 0.3072, train_acc: 0.8935, val_acc:0.8769
		train_roc: 0.9534, val_roc: 0.9421, train_auprc: 0.9427, val_auprc: 0.9316




Epoch: 79 (58.8027s), train_loss: 0.2656, val_loss: 0.3102, train_acc: 0.8946, val_acc:0.8741
		train_roc: 0.9541, val_roc: 0.9412, train_auprc: 0.9432, val_auprc: 0.9316




Epoch: 80 (58.9456s), train_loss: 0.2664, val_loss: 0.3049, train_acc: 0.8940, val_acc:0.8779
		train_roc: 0.9535, val_roc: 0.9424, train_auprc: 0.9425, val_auprc: 0.9318




Saving model
Epoch: 81 (58.7454s), train_loss: 0.2633, val_loss: 0.3057, train_acc: 0.8955, val_acc:0.8755
		train_roc: 0.9545, val_roc: 0.9430, train_auprc: 0.9439, val_auprc: 0.9337




Epoch: 82 (59.0726s), train_loss: 0.2627, val_loss: 0.3054, train_acc: 0.8963, val_acc:0.8782
		train_roc: 0.9547, val_roc: 0.9433, train_auprc: 0.9438, val_auprc: 0.9334




Epoch: 83 (59.2482s), train_loss: 0.2604, val_loss: 0.3059, train_acc: 0.8970, val_acc:0.8775
		train_roc: 0.9557, val_roc: 0.9429, train_auprc: 0.9454, val_auprc: 0.9332




Epoch: 84 (58.8269s), train_loss: 0.2614, val_loss: 0.3078, train_acc: 0.8971, val_acc:0.8774
		train_roc: 0.9551, val_roc: 0.9426, train_auprc: 0.9441, val_auprc: 0.9330




Saving model
Epoch: 85 (58.9089s), train_loss: 0.2595, val_loss: 0.3037, train_acc: 0.8975, val_acc:0.8786
		train_roc: 0.9557, val_roc: 0.9442, train_auprc: 0.9453, val_auprc: 0.9350




Epoch: 86 (58.8964s), train_loss: 0.2568, val_loss: 0.3051, train_acc: 0.8989, val_acc:0.8784
		train_roc: 0.9567, val_roc: 0.9432, train_auprc: 0.9466, val_auprc: 0.9333




Epoch: 87 (58.8223s), train_loss: 0.2567, val_loss: 0.3054, train_acc: 0.8991, val_acc:0.8778
		train_roc: 0.9567, val_roc: 0.9442, train_auprc: 0.9460, val_auprc: 0.9343




Epoch: 88 (58.8868s), train_loss: 0.2564, val_loss: 0.3019, train_acc: 0.8993, val_acc:0.8808
		train_roc: 0.9566, val_roc: 0.9442, train_auprc: 0.9463, val_auprc: 0.9349




Saving model
Epoch: 89 (58.9156s), train_loss: 0.2550, val_loss: 0.3019, train_acc: 0.8996, val_acc:0.8799
		train_roc: 0.9571, val_roc: 0.9446, train_auprc: 0.9468, val_auprc: 0.9355




Saving model
Epoch: 90 (58.7558s), train_loss: 0.2538, val_loss: 0.3015, train_acc: 0.9005, val_acc:0.8797
		train_roc: 0.9575, val_roc: 0.9451, train_auprc: 0.9475, val_auprc: 0.9361




Saving model
Epoch: 91 (58.8582s), train_loss: 0.2534, val_loss: 0.2996, train_acc: 0.9007, val_acc:0.8804
		train_roc: 0.9575, val_roc: 0.9459, train_auprc: 0.9474, val_auprc: 0.9374




Epoch: 92 (58.8550s), train_loss: 0.2550, val_loss: 0.3029, train_acc: 0.8998, val_acc:0.8793
		train_roc: 0.9572, val_roc: 0.9449, train_auprc: 0.9468, val_auprc: 0.9359




Epoch: 93 (58.8653s), train_loss: 0.2522, val_loss: 0.3035, train_acc: 0.9009, val_acc:0.8803
		train_roc: 0.9581, val_roc: 0.9445, train_auprc: 0.9483, val_auprc: 0.9353




Epoch: 94 (58.9334s), train_loss: 0.2520, val_loss: 0.3020, train_acc: 0.9009, val_acc:0.8798
		train_roc: 0.9581, val_roc: 0.9445, train_auprc: 0.9479, val_auprc: 0.9354




Epoch: 95 (58.9414s), train_loss: 0.2499, val_loss: 0.3012, train_acc: 0.9026, val_acc:0.8799
		train_roc: 0.9586, val_roc: 0.9452, train_auprc: 0.9488, val_auprc: 0.9360




Epoch: 96 (58.8339s), train_loss: 0.2509, val_loss: 0.3012, train_acc: 0.9019, val_acc:0.8800
		train_roc: 0.9583, val_roc: 0.9458, train_auprc: 0.9479, val_auprc: 0.9365




Epoch: 97 (58.6978s), train_loss: 0.2487, val_loss: 0.3012, train_acc: 0.9026, val_acc:0.8820
		train_roc: 0.9591, val_roc: 0.9457, train_auprc: 0.9495, val_auprc: 0.9370




Saving model
Epoch: 98 (58.9685s), train_loss: 0.2503, val_loss: 0.2999, train_acc: 0.9028, val_acc:0.8817
		train_roc: 0.9585, val_roc: 0.9463, train_auprc: 0.9483, val_auprc: 0.9378




Epoch: 99 (58.7439s), train_loss: 0.2491, val_loss: 0.3006, train_acc: 0.9030, val_acc:0.8806
		train_roc: 0.9589, val_roc: 0.9457, train_auprc: 0.9491, val_auprc: 0.9374




Epoch: 100 (58.8906s), train_loss: 0.2491, val_loss: 0.2995, train_acc: 0.9025, val_acc:0.8821
		train_roc: 0.9588, val_roc: 0.9460, train_auprc: 0.9491, val_auprc: 0.9366




Epoch: 101 (58.9190s), train_loss: 0.2485, val_loss: 0.3008, train_acc: 0.9030, val_acc:0.8808
		train_roc: 0.9590, val_roc: 0.9456, train_auprc: 0.9491, val_auprc: 0.9366




Epoch: 102 (58.8730s), train_loss: 0.2472, val_loss: 0.3038, train_acc: 0.9038, val_acc:0.8793
		train_roc: 0.9596, val_roc: 0.9445, train_auprc: 0.9501, val_auprc: 0.9354




Saving model
Epoch: 103 (58.9969s), train_loss: 0.2476, val_loss: 0.2963, train_acc: 0.9034, val_acc:0.8832
		train_roc: 0.9592, val_roc: 0.9470, train_auprc: 0.9494, val_auprc: 0.9386




Epoch: 104 (58.8640s), train_loss: 0.2462, val_loss: 0.2989, train_acc: 0.9040, val_acc:0.8815
		train_roc: 0.9597, val_roc: 0.9464, train_auprc: 0.9496, val_auprc: 0.9379




Epoch: 105 (58.8787s), train_loss: 0.2457, val_loss: 0.3003, train_acc: 0.9043, val_acc:0.8811
		train_roc: 0.9599, val_roc: 0.9465, train_auprc: 0.9500, val_auprc: 0.9379




Epoch: 106 (59.0668s), train_loss: 0.2436, val_loss: 0.2996, train_acc: 0.9052, val_acc:0.8833
		train_roc: 0.9605, val_roc: 0.9471, train_auprc: 0.9510, val_auprc: 0.9384




Epoch: 107 (58.9077s), train_loss: 0.2439, val_loss: 0.3024, train_acc: 0.9052, val_acc:0.8811
		train_roc: 0.9604, val_roc: 0.9458, train_auprc: 0.9508, val_auprc: 0.9370




Saving model
Epoch: 108 (59.0345s), train_loss: 0.2445, val_loss: 0.2983, train_acc: 0.9054, val_acc:0.8820
		train_roc: 0.9602, val_roc: 0.9466, train_auprc: 0.9505, val_auprc: 0.9387




Epoch: 109 (58.9757s), train_loss: 0.2437, val_loss: 0.2996, train_acc: 0.9051, val_acc:0.8821
		train_roc: 0.9605, val_roc: 0.9468, train_auprc: 0.9511, val_auprc: 0.9384




Epoch: 110 (58.7844s), train_loss: 0.2420, val_loss: 0.2991, train_acc: 0.9062, val_acc:0.8828
		train_roc: 0.9610, val_roc: 0.9467, train_auprc: 0.9518, val_auprc: 0.9381




Epoch: 111 (58.8549s), train_loss: 0.2420, val_loss: 0.3000, train_acc: 0.9067, val_acc:0.8824
		train_roc: 0.9611, val_roc: 0.9469, train_auprc: 0.9518, val_auprc: 0.9386




Epoch: 112 (59.0536s), train_loss: 0.2409, val_loss: 0.3015, train_acc: 0.9067, val_acc:0.8823
		train_roc: 0.9615, val_roc: 0.9465, train_auprc: 0.9520, val_auprc: 0.9386




Epoch: 113 (59.3680s), train_loss: 0.2411, val_loss: 0.3016, train_acc: 0.9062, val_acc:0.8817
		train_roc: 0.9614, val_roc: 0.9458, train_auprc: 0.9523, val_auprc: 0.9372




Epoch: 114 (58.9275s), train_loss: 0.2396, val_loss: 0.2994, train_acc: 0.9068, val_acc:0.8836
		train_roc: 0.9619, val_roc: 0.9467, train_auprc: 0.9527, val_auprc: 0.9380




Epoch: 115 (59.1273s), train_loss: 0.2423, val_loss: 0.3033, train_acc: 0.9057, val_acc:0.8821
		train_roc: 0.9609, val_roc: 0.9456, train_auprc: 0.9511, val_auprc: 0.9365




Epoch: 116 (58.8837s), train_loss: 0.2405, val_loss: 0.3006, train_acc: 0.9069, val_acc:0.8820
		train_roc: 0.9615, val_roc: 0.9462, train_auprc: 0.9519, val_auprc: 0.9378




Saving model
Epoch: 117 (59.0273s), train_loss: 0.2393, val_loss: 0.2967, train_acc: 0.9077, val_acc:0.8836
		train_roc: 0.9619, val_roc: 0.9481, train_auprc: 0.9526, val_auprc: 0.9406




Epoch: 118 (58.9450s), train_loss: 0.2411, val_loss: 0.3013, train_acc: 0.9066, val_acc:0.8823
		train_roc: 0.9611, val_roc: 0.9466, train_auprc: 0.9514, val_auprc: 0.9374




Epoch: 119 (58.9648s), train_loss: 0.2394, val_loss: 0.2970, train_acc: 0.9068, val_acc:0.8846
		train_roc: 0.9618, val_roc: 0.9478, train_auprc: 0.9529, val_auprc: 0.9399




Epoch: 120 (59.0601s), train_loss: 0.2397, val_loss: 0.2977, train_acc: 0.9075, val_acc:0.8840
		train_roc: 0.9618, val_roc: 0.9476, train_auprc: 0.9524, val_auprc: 0.9396




Epoch: 121 (59.1067s), train_loss: 0.2387, val_loss: 0.3009, train_acc: 0.9073, val_acc:0.8812
		train_roc: 0.9621, val_roc: 0.9469, train_auprc: 0.9531, val_auprc: 0.9383




Epoch: 122 (59.0078s), train_loss: 0.2401, val_loss: 0.3026, train_acc: 0.9069, val_acc:0.8823
		train_roc: 0.9613, val_roc: 0.9462, train_auprc: 0.9518, val_auprc: 0.9376




Epoch: 123 (59.0472s), train_loss: 0.2395, val_loss: 0.3006, train_acc: 0.9067, val_acc:0.8820
		train_roc: 0.9617, val_roc: 0.9466, train_auprc: 0.9525, val_auprc: 0.9379




Epoch: 124 (59.1428s), train_loss: 0.2400, val_loss: 0.3018, train_acc: 0.9077, val_acc:0.8823
		train_roc: 0.9615, val_roc: 0.9463, train_auprc: 0.9518, val_auprc: 0.9377




Epoch: 125 (59.1877s), train_loss: 0.2401, val_loss: 0.3003, train_acc: 0.9066, val_acc:0.8827
		train_roc: 0.9615, val_roc: 0.9466, train_auprc: 0.9520, val_auprc: 0.9383




Epoch: 126 (59.0382s), train_loss: 0.2379, val_loss: 0.3034, train_acc: 0.9083, val_acc:0.8825
		train_roc: 0.9622, val_roc: 0.9462, train_auprc: 0.9531, val_auprc: 0.9371




Epoch: 127 (59.0234s), train_loss: 0.2402, val_loss: 0.2999, train_acc: 0.9074, val_acc:0.8831
		train_roc: 0.9613, val_roc: 0.9470, train_auprc: 0.9515, val_auprc: 0.9387




Epoch: 128 (59.1516s), train_loss: 0.2392, val_loss: 0.3043, train_acc: 0.9079, val_acc:0.8814
		train_roc: 0.9618, val_roc: 0.9455, train_auprc: 0.9522, val_auprc: 0.9366




Epoch: 129 (59.3409s), train_loss: 0.2388, val_loss: 0.2977, train_acc: 0.9081, val_acc:0.8831
		train_roc: 0.9617, val_roc: 0.9479, train_auprc: 0.9520, val_auprc: 0.9403




Epoch: 130 (59.0966s), train_loss: 0.2398, val_loss: 0.3007, train_acc: 0.9072, val_acc:0.8826
		train_roc: 0.9614, val_roc: 0.9467, train_auprc: 0.9514, val_auprc: 0.9386




Epoch: 131 (58.8849s), train_loss: 0.2376, val_loss: 0.3008, train_acc: 0.9081, val_acc:0.8829
		train_roc: 0.9623, val_roc: 0.9471, train_auprc: 0.9531, val_auprc: 0.9387




Saving model
Epoch: 132 (58.9381s), train_loss: 0.2378, val_loss: 0.2971, train_acc: 0.9081, val_acc:0.8836
		train_roc: 0.9621, val_roc: 0.9483, train_auprc: 0.9530, val_auprc: 0.9408




Epoch: 133 (58.7931s), train_loss: 0.2373, val_loss: 0.2976, train_acc: 0.9085, val_acc:0.8833
		train_roc: 0.9622, val_roc: 0.9480, train_auprc: 0.9528, val_auprc: 0.9405




Epoch: 134 (59.1551s), train_loss: 0.2357, val_loss: 0.3011, train_acc: 0.9086, val_acc:0.8830
		train_roc: 0.9628, val_roc: 0.9469, train_auprc: 0.9539, val_auprc: 0.9384




Epoch: 135 (59.1046s), train_loss: 0.2380, val_loss: 0.3006, train_acc: 0.9080, val_acc:0.8827
		train_roc: 0.9621, val_roc: 0.9472, train_auprc: 0.9527, val_auprc: 0.9391




Epoch: 136 (58.9797s), train_loss: 0.2363, val_loss: 0.2980, train_acc: 0.9083, val_acc:0.8842
		train_roc: 0.9627, val_roc: 0.9484, train_auprc: 0.9540, val_auprc: 0.9406




Epoch: 137 (58.8566s), train_loss: 0.2382, val_loss: 0.2994, train_acc: 0.9078, val_acc:0.8842
		train_roc: 0.9619, val_roc: 0.9476, train_auprc: 0.9523, val_auprc: 0.9393




Epoch: 138 (58.9579s), train_loss: 0.2370, val_loss: 0.3029, train_acc: 0.9084, val_acc:0.8822
		train_roc: 0.9624, val_roc: 0.9463, train_auprc: 0.9532, val_auprc: 0.9374




Epoch: 139 (58.9527s), train_loss: 0.2373, val_loss: 0.2996, train_acc: 0.9089, val_acc:0.8831
		train_roc: 0.9623, val_roc: 0.9477, train_auprc: 0.9525, val_auprc: 0.9398




Epoch: 140 (58.9994s), train_loss: 0.2363, val_loss: 0.2991, train_acc: 0.9084, val_acc:0.8841
		train_roc: 0.9626, val_roc: 0.9479, train_auprc: 0.9534, val_auprc: 0.9395




Epoch: 141 (58.8459s), train_loss: 0.2347, val_loss: 0.2999, train_acc: 0.9096, val_acc:0.8838
		train_roc: 0.9631, val_roc: 0.9475, train_auprc: 0.9541, val_auprc: 0.9392




Epoch: 142 (59.0425s), train_loss: 0.2369, val_loss: 0.3028, train_acc: 0.9085, val_acc:0.8826
		train_roc: 0.9623, val_roc: 0.9463, train_auprc: 0.9529, val_auprc: 0.9368




Epoch: 143 (58.8127s), train_loss: 0.2369, val_loss: 0.3009, train_acc: 0.9084, val_acc:0.8829
		train_roc: 0.9624, val_roc: 0.9473, train_auprc: 0.9533, val_auprc: 0.9391




Epoch: 144 (59.1934s), train_loss: 0.2363, val_loss: 0.3028, train_acc: 0.9083, val_acc:0.8815
		train_roc: 0.9627, val_roc: 0.9466, train_auprc: 0.9537, val_auprc: 0.9384




Epoch: 145 (59.2511s), train_loss: 0.2362, val_loss: 0.3037, train_acc: 0.9088, val_acc:0.8820
		train_roc: 0.9626, val_roc: 0.9462, train_auprc: 0.9532, val_auprc: 0.9370




Epoch: 146 (58.9854s), train_loss: 0.2355, val_loss: 0.3006, train_acc: 0.9090, val_acc:0.8837
		train_roc: 0.9628, val_roc: 0.9473, train_auprc: 0.9535, val_auprc: 0.9385




Epoch: 147 (58.8929s), train_loss: 0.2353, val_loss: 0.3002, train_acc: 0.9089, val_acc:0.8820
		train_roc: 0.9630, val_roc: 0.9473, train_auprc: 0.9542, val_auprc: 0.9397




Epoch: 148 (58.9360s), train_loss: 0.2346, val_loss: 0.3004, train_acc: 0.9092, val_acc:0.8835
		train_roc: 0.9631, val_roc: 0.9472, train_auprc: 0.9541, val_auprc: 0.9384




Epoch: 149 (58.9122s), train_loss: 0.2332, val_loss: 0.3027, train_acc: 0.9100, val_acc:0.8828
		train_roc: 0.9637, val_roc: 0.9464, train_auprc: 0.9549, val_auprc: 0.9376




Epoch: 150 (59.0708s), train_loss: 0.2348, val_loss: 0.3025, train_acc: 0.9096, val_acc:0.8823
		train_roc: 0.9630, val_roc: 0.9465, train_auprc: 0.9540, val_auprc: 0.9382




Epoch: 151 (58.8345s), train_loss: 0.2346, val_loss: 0.3047, train_acc: 0.9092, val_acc:0.8813
		train_roc: 0.9632, val_roc: 0.9458, train_auprc: 0.9541, val_auprc: 0.9367




Epoch: 152 (58.9652s), train_loss: 0.2358, val_loss: 0.3006, train_acc: 0.9087, val_acc:0.8838
		train_roc: 0.9627, val_roc: 0.9475, train_auprc: 0.9533, val_auprc: 0.9389




Epoch: 153 (58.9732s), train_loss: 0.2352, val_loss: 0.3018, train_acc: 0.9098, val_acc:0.8823
		train_roc: 0.9628, val_roc: 0.9468, train_auprc: 0.9532, val_auprc: 0.9387




Epoch: 154 (58.8461s), train_loss: 0.2364, val_loss: 0.3028, train_acc: 0.9085, val_acc:0.8821
		train_roc: 0.9624, val_roc: 0.9468, train_auprc: 0.9531, val_auprc: 0.9385




Epoch: 155 (59.0029s), train_loss: 0.2359, val_loss: 0.3015, train_acc: 0.9089, val_acc:0.8823
		train_roc: 0.9627, val_roc: 0.9472, train_auprc: 0.9534, val_auprc: 0.9392




Epoch: 156 (58.9063s), train_loss: 0.2332, val_loss: 0.3026, train_acc: 0.9099, val_acc:0.8817
		train_roc: 0.9637, val_roc: 0.9469, train_auprc: 0.9549, val_auprc: 0.9388




Epoch: 157 (59.2167s), train_loss: 0.2353, val_loss: 0.3041, train_acc: 0.9093, val_acc:0.8810
		train_roc: 0.9629, val_roc: 0.9461, train_auprc: 0.9536, val_auprc: 0.9381




Epoch: 158 (58.9674s), train_loss: 0.2343, val_loss: 0.3018, train_acc: 0.9095, val_acc:0.8825
		train_roc: 0.9632, val_roc: 0.9471, train_auprc: 0.9543, val_auprc: 0.9387




Epoch: 159 (59.0011s), train_loss: 0.2336, val_loss: 0.3000, train_acc: 0.9099, val_acc:0.8831
		train_roc: 0.9635, val_roc: 0.9481, train_auprc: 0.9547, val_auprc: 0.9406




Epoch: 160 (59.0415s), train_loss: 0.2339, val_loss: 0.3035, train_acc: 0.9100, val_acc:0.8820
		train_roc: 0.9634, val_roc: 0.9466, train_auprc: 0.9545, val_auprc: 0.9377




Epoch: 161 (58.9398s), train_loss: 0.2353, val_loss: 0.3009, train_acc: 0.9089, val_acc:0.8840
		train_roc: 0.9628, val_roc: 0.9476, train_auprc: 0.9536, val_auprc: 0.9397




Epoch: 162 (58.9892s), train_loss: 0.2329, val_loss: 0.3038, train_acc: 0.9103, val_acc:0.8816
		train_roc: 0.9636, val_roc: 0.9464, train_auprc: 0.9547, val_auprc: 0.9377




Epoch: 163 (58.8020s), train_loss: 0.2350, val_loss: 0.3040, train_acc: 0.9090, val_acc:0.8822
		train_roc: 0.9629, val_roc: 0.9465, train_auprc: 0.9539, val_auprc: 0.9379




Epoch: 164 (58.8770s), train_loss: 0.2343, val_loss: 0.3024, train_acc: 0.9098, val_acc:0.8822
		train_roc: 0.9631, val_roc: 0.9470, train_auprc: 0.9541, val_auprc: 0.9393




Epoch: 165 (58.9892s), train_loss: 0.2359, val_loss: 0.3028, train_acc: 0.9088, val_acc:0.8829
		train_roc: 0.9626, val_roc: 0.9471, train_auprc: 0.9534, val_auprc: 0.9391




Epoch: 166 (58.9200s), train_loss: 0.2344, val_loss: 0.3034, train_acc: 0.9093, val_acc:0.8819
		train_roc: 0.9631, val_roc: 0.9467, train_auprc: 0.9541, val_auprc: 0.9386




Epoch: 167 (58.9105s), train_loss: 0.2346, val_loss: 0.3013, train_acc: 0.9097, val_acc:0.8825
		train_roc: 0.9630, val_roc: 0.9474, train_auprc: 0.9537, val_auprc: 0.9398




Epoch: 168 (58.8165s), train_loss: 0.2322, val_loss: 0.3019, train_acc: 0.9104, val_acc:0.8836
		train_roc: 0.9639, val_roc: 0.9474, train_auprc: 0.9551, val_auprc: 0.9395




Epoch: 169 (58.9820s), train_loss: 0.2351, val_loss: 0.3025, train_acc: 0.9092, val_acc:0.8830
		train_roc: 0.9629, val_roc: 0.9471, train_auprc: 0.9536, val_auprc: 0.9394




Epoch: 170 (59.0438s), train_loss: 0.2328, val_loss: 0.3034, train_acc: 0.9101, val_acc:0.8824
		train_roc: 0.9637, val_roc: 0.9469, train_auprc: 0.9549, val_auprc: 0.9387




Epoch: 171 (59.0208s), train_loss: 0.2356, val_loss: 0.3032, train_acc: 0.9092, val_acc:0.8827
		train_roc: 0.9627, val_roc: 0.9469, train_auprc: 0.9533, val_auprc: 0.9382




Epoch: 172 (59.0085s), train_loss: 0.2339, val_loss: 0.3008, train_acc: 0.9097, val_acc:0.8839
		train_roc: 0.9634, val_roc: 0.9476, train_auprc: 0.9545, val_auprc: 0.9394




Epoch: 173 (58.9114s), train_loss: 0.2336, val_loss: 0.3014, train_acc: 0.9095, val_acc:0.8834
		train_roc: 0.9635, val_roc: 0.9476, train_auprc: 0.9546, val_auprc: 0.9394




Saving model
Epoch: 174 (58.8769s), train_loss: 0.2351, val_loss: 0.2990, train_acc: 0.9092, val_acc:0.8845
		train_roc: 0.9629, val_roc: 0.9485, train_auprc: 0.9534, val_auprc: 0.9411




Epoch: 175 (58.9302s), train_loss: 0.2336, val_loss: 0.3000, train_acc: 0.9097, val_acc:0.8842
		train_roc: 0.9633, val_roc: 0.9481, train_auprc: 0.9541, val_auprc: 0.9406




Epoch: 176 (58.9105s), train_loss: 0.2331, val_loss: 0.3038, train_acc: 0.9097, val_acc:0.8815
		train_roc: 0.9636, val_roc: 0.9467, train_auprc: 0.9548, val_auprc: 0.9386




Epoch: 177 (58.8250s), train_loss: 0.2321, val_loss: 0.3010, train_acc: 0.9109, val_acc:0.8834
		train_roc: 0.9638, val_roc: 0.9477, train_auprc: 0.9550, val_auprc: 0.9399




Epoch: 178 (58.9311s), train_loss: 0.2335, val_loss: 0.3030, train_acc: 0.9100, val_acc:0.8820
		train_roc: 0.9635, val_roc: 0.9471, train_auprc: 0.9544, val_auprc: 0.9392




Epoch: 179 (59.0359s), train_loss: 0.2333, val_loss: 0.3026, train_acc: 0.9101, val_acc:0.8825
		train_roc: 0.9635, val_roc: 0.9471, train_auprc: 0.9543, val_auprc: 0.9389




Epoch: 180 (58.9765s), train_loss: 0.2339, val_loss: 0.3024, train_acc: 0.9099, val_acc:0.8833
		train_roc: 0.9633, val_roc: 0.9472, train_auprc: 0.9539, val_auprc: 0.9389




Saving model
Epoch: 181 (58.9642s), train_loss: 0.2342, val_loss: 0.2989, train_acc: 0.9098, val_acc:0.8847
		train_roc: 0.9631, val_roc: 0.9488, train_auprc: 0.9537, val_auprc: 0.9413




Epoch: 182 (58.7273s), train_loss: 0.2332, val_loss: 0.3042, train_acc: 0.9101, val_acc:0.8822
		train_roc: 0.9633, val_roc: 0.9465, train_auprc: 0.9542, val_auprc: 0.9380




Epoch: 183 (58.9507s), train_loss: 0.2334, val_loss: 0.3037, train_acc: 0.9102, val_acc:0.8831
		train_roc: 0.9634, val_roc: 0.9467, train_auprc: 0.9542, val_auprc: 0.9378




Epoch: 184 (58.8076s), train_loss: 0.2348, val_loss: 0.3002, train_acc: 0.9096, val_acc:0.8848
		train_roc: 0.9629, val_roc: 0.9481, train_auprc: 0.9536, val_auprc: 0.9399




Epoch: 185 (58.8653s), train_loss: 0.2343, val_loss: 0.3036, train_acc: 0.9097, val_acc:0.8827
		train_roc: 0.9631, val_roc: 0.9467, train_auprc: 0.9539, val_auprc: 0.9387




Epoch: 186 (58.8230s), train_loss: 0.2333, val_loss: 0.3022, train_acc: 0.9098, val_acc:0.8839
		train_roc: 0.9635, val_roc: 0.9473, train_auprc: 0.9544, val_auprc: 0.9387




Epoch: 187 (58.8531s), train_loss: 0.2334, val_loss: 0.3039, train_acc: 0.9101, val_acc:0.8829
		train_roc: 0.9634, val_roc: 0.9467, train_auprc: 0.9544, val_auprc: 0.9382




Epoch: 188 (58.8278s), train_loss: 0.2321, val_loss: 0.3020, train_acc: 0.9106, val_acc:0.8840
		train_roc: 0.9640, val_roc: 0.9472, train_auprc: 0.9551, val_auprc: 0.9389




Epoch: 189 (58.8675s), train_loss: 0.2349, val_loss: 0.3031, train_acc: 0.9095, val_acc:0.8830
		train_roc: 0.9630, val_roc: 0.9470, train_auprc: 0.9534, val_auprc: 0.9390




Epoch: 190 (58.8684s), train_loss: 0.2328, val_loss: 0.3005, train_acc: 0.9102, val_acc:0.8839
		train_roc: 0.9637, val_roc: 0.9481, train_auprc: 0.9549, val_auprc: 0.9399




Epoch: 191 (59.0372s), train_loss: 0.2354, val_loss: 0.3045, train_acc: 0.9093, val_acc:0.8823
		train_roc: 0.9628, val_roc: 0.9466, train_auprc: 0.9531, val_auprc: 0.9384




Epoch: 192 (58.8452s), train_loss: 0.2338, val_loss: 0.3008, train_acc: 0.9105, val_acc:0.8840
		train_roc: 0.9632, val_roc: 0.9478, train_auprc: 0.9539, val_auprc: 0.9395




Epoch: 193 (58.9981s), train_loss: 0.2353, val_loss: 0.3038, train_acc: 0.9094, val_acc:0.8826
		train_roc: 0.9627, val_roc: 0.9467, train_auprc: 0.9530, val_auprc: 0.9381




Epoch: 194 (59.1001s), train_loss: 0.2319, val_loss: 0.3026, train_acc: 0.9109, val_acc:0.8822
		train_roc: 0.9639, val_roc: 0.9470, train_auprc: 0.9549, val_auprc: 0.9389




Epoch: 195 (58.9760s), train_loss: 0.2342, val_loss: 0.3050, train_acc: 0.9093, val_acc:0.8815
		train_roc: 0.9632, val_roc: 0.9462, train_auprc: 0.9542, val_auprc: 0.9379




Epoch: 196 (58.9213s), train_loss: 0.2340, val_loss: 0.3022, train_acc: 0.9103, val_acc:0.8831
		train_roc: 0.9631, val_roc: 0.9473, train_auprc: 0.9539, val_auprc: 0.9394




Epoch: 197 (58.9255s), train_loss: 0.2320, val_loss: 0.3030, train_acc: 0.9106, val_acc:0.8834
		train_roc: 0.9639, val_roc: 0.9469, train_auprc: 0.9552, val_auprc: 0.9382




Epoch: 198 (58.9745s), train_loss: 0.2341, val_loss: 0.3010, train_acc: 0.9101, val_acc:0.8837
		train_roc: 0.9632, val_roc: 0.9479, train_auprc: 0.9539, val_auprc: 0.9397




Epoch: 199 (58.8476s), train_loss: 0.2326, val_loss: 0.3018, train_acc: 0.9104, val_acc:0.8834
		train_roc: 0.9637, val_roc: 0.9474, train_auprc: 0.9546, val_auprc: 0.9392




Epoch: 200 (58.8010s), train_loss: 0.2319, val_loss: 0.3022, train_acc: 0.9108, val_acc:0.8839
		train_roc: 0.9639, val_roc: 0.9474, train_auprc: 0.9552, val_auprc: 0.9390




Epoch: 201 (59.0499s), train_loss: 0.2351, val_loss: 0.3003, train_acc: 0.9093, val_acc:0.8852
		train_roc: 0.9628, val_roc: 0.9482, train_auprc: 0.9535, val_auprc: 0.9399




Epoch: 202 (62.0247s), train_loss: 0.2345, val_loss: 0.3028, train_acc: 0.9096, val_acc:0.8826
		train_roc: 0.9630, val_roc: 0.9473, train_auprc: 0.9538, val_auprc: 0.9388




Epoch: 203 (59.0610s), train_loss: 0.2344, val_loss: 0.3040, train_acc: 0.9095, val_acc:0.8833
		train_roc: 0.9631, val_roc: 0.9468, train_auprc: 0.9539, val_auprc: 0.9382




Epoch: 204 (58.8875s), train_loss: 0.2318, val_loss: 0.3014, train_acc: 0.9102, val_acc:0.8837
		train_roc: 0.9641, val_roc: 0.9478, train_auprc: 0.9554, val_auprc: 0.9398




Epoch: 205 (58.9770s), train_loss: 0.2339, val_loss: 0.2996, train_acc: 0.9100, val_acc:0.8850
		train_roc: 0.9631, val_roc: 0.9485, train_auprc: 0.9538, val_auprc: 0.9408




Epoch: 206 (59.1223s), train_loss: 0.2325, val_loss: 0.3021, train_acc: 0.9104, val_acc:0.8835
		train_roc: 0.9638, val_roc: 0.9474, train_auprc: 0.9546, val_auprc: 0.9395




Epoch: 207 (58.9079s), train_loss: 0.2319, val_loss: 0.3006, train_acc: 0.9108, val_acc:0.8839
		train_roc: 0.9640, val_roc: 0.9480, train_auprc: 0.9552, val_auprc: 0.9403




Epoch: 208 (59.0487s), train_loss: 0.2321, val_loss: 0.3026, train_acc: 0.9102, val_acc:0.8838
		train_roc: 0.9639, val_roc: 0.9473, train_auprc: 0.9550, val_auprc: 0.9387




Epoch: 209 (58.9570s), train_loss: 0.2350, val_loss: 0.3044, train_acc: 0.9096, val_acc:0.8833
		train_roc: 0.9629, val_roc: 0.9466, train_auprc: 0.9535, val_auprc: 0.9377




Epoch: 210 (59.0345s), train_loss: 0.2336, val_loss: 0.3037, train_acc: 0.9097, val_acc:0.8825
		train_roc: 0.9635, val_roc: 0.9469, train_auprc: 0.9544, val_auprc: 0.9383




Epoch: 211 (58.9837s), train_loss: 0.2337, val_loss: 0.3032, train_acc: 0.9094, val_acc:0.8833
		train_roc: 0.9634, val_roc: 0.9470, train_auprc: 0.9544, val_auprc: 0.9384




Epoch: 212 (59.0158s), train_loss: 0.2343, val_loss: 0.3030, train_acc: 0.9093, val_acc:0.8839
		train_roc: 0.9631, val_roc: 0.9471, train_auprc: 0.9539, val_auprc: 0.9383




Epoch: 213 (58.9018s), train_loss: 0.2330, val_loss: 0.3005, train_acc: 0.9101, val_acc:0.8848
		train_roc: 0.9635, val_roc: 0.9481, train_auprc: 0.9545, val_auprc: 0.9400




Epoch: 214 (60.7912s), train_loss: 0.2341, val_loss: 0.3046, train_acc: 0.9100, val_acc:0.8815
		train_roc: 0.9630, val_roc: 0.9464, train_auprc: 0.9538, val_auprc: 0.9385




Saving model
Epoch: 215 (60.0379s), train_loss: 0.2335, val_loss: 0.2989, train_acc: 0.9107, val_acc:0.8845
		train_roc: 0.9632, val_roc: 0.9488, train_auprc: 0.9538, val_auprc: 0.9416




Epoch: 216 (58.7775s), train_loss: 0.2326, val_loss: 0.2995, train_acc: 0.9105, val_acc:0.8855
		train_roc: 0.9637, val_roc: 0.9486, train_auprc: 0.9545, val_auprc: 0.9404




Epoch: 217 (59.0293s), train_loss: 0.2333, val_loss: 0.3031, train_acc: 0.9103, val_acc:0.8830
		train_roc: 0.9633, val_roc: 0.9470, train_auprc: 0.9540, val_auprc: 0.9386




Epoch: 218 (58.9960s), train_loss: 0.2348, val_loss: 0.3035, train_acc: 0.9092, val_acc:0.8829
		train_roc: 0.9628, val_roc: 0.9469, train_auprc: 0.9535, val_auprc: 0.9384




Epoch: 219 (58.9461s), train_loss: 0.2324, val_loss: 0.3047, train_acc: 0.9102, val_acc:0.8825
		train_roc: 0.9636, val_roc: 0.9464, train_auprc: 0.9544, val_auprc: 0.9373




Epoch: 220 (58.9274s), train_loss: 0.2338, val_loss: 0.3032, train_acc: 0.9094, val_acc:0.8831
		train_roc: 0.9632, val_roc: 0.9470, train_auprc: 0.9540, val_auprc: 0.9383




Epoch: 221 (58.8998s), train_loss: 0.2338, val_loss: 0.3016, train_acc: 0.9098, val_acc:0.8835
		train_roc: 0.9633, val_roc: 0.9476, train_auprc: 0.9542, val_auprc: 0.9395




Epoch: 222 (58.8305s), train_loss: 0.2330, val_loss: 0.3005, train_acc: 0.9097, val_acc:0.8844
		train_roc: 0.9635, val_roc: 0.9482, train_auprc: 0.9543, val_auprc: 0.9402




Epoch: 223 (59.0411s), train_loss: 0.2319, val_loss: 0.3039, train_acc: 0.9110, val_acc:0.8826
		train_roc: 0.9639, val_roc: 0.9468, train_auprc: 0.9550, val_auprc: 0.9386




Epoch: 224 (59.1869s), train_loss: 0.2340, val_loss: 0.3004, train_acc: 0.9103, val_acc:0.8836
		train_roc: 0.9630, val_roc: 0.9482, train_auprc: 0.9535, val_auprc: 0.9405




Epoch: 225 (58.9094s), train_loss: 0.2334, val_loss: 0.3009, train_acc: 0.9101, val_acc:0.8840
		train_roc: 0.9635, val_roc: 0.9482, train_auprc: 0.9543, val_auprc: 0.9403




Epoch: 226 (58.7436s), train_loss: 0.2338, val_loss: 0.3043, train_acc: 0.9095, val_acc:0.8820
		train_roc: 0.9632, val_roc: 0.9466, train_auprc: 0.9543, val_auprc: 0.9384




Epoch: 227 (59.0824s), train_loss: 0.2320, val_loss: 0.3010, train_acc: 0.9109, val_acc:0.8843
		train_roc: 0.9639, val_roc: 0.9479, train_auprc: 0.9550, val_auprc: 0.9395




Epoch: 228 (58.9359s), train_loss: 0.2339, val_loss: 0.3031, train_acc: 0.9102, val_acc:0.8829
		train_roc: 0.9632, val_roc: 0.9472, train_auprc: 0.9538, val_auprc: 0.9387




Epoch: 229 (58.8476s), train_loss: 0.2324, val_loss: 0.3030, train_acc: 0.9108, val_acc:0.8838
		train_roc: 0.9637, val_roc: 0.9471, train_auprc: 0.9545, val_auprc: 0.9386




Epoch: 230 (58.8380s), train_loss: 0.2331, val_loss: 0.3050, train_acc: 0.9102, val_acc:0.8819
		train_roc: 0.9635, val_roc: 0.9462, train_auprc: 0.9544, val_auprc: 0.9378




Epoch: 231 (58.8528s), train_loss: 0.2341, val_loss: 0.3017, train_acc: 0.9100, val_acc:0.8835
		train_roc: 0.9631, val_roc: 0.9475, train_auprc: 0.9539, val_auprc: 0.9399




Epoch: 232 (58.8959s), train_loss: 0.2335, val_loss: 0.3012, train_acc: 0.9100, val_acc:0.8841
		train_roc: 0.9633, val_roc: 0.9477, train_auprc: 0.9542, val_auprc: 0.9394




Epoch: 233 (58.9887s), train_loss: 0.2328, val_loss: 0.3033, train_acc: 0.9102, val_acc:0.8831
		train_roc: 0.9636, val_roc: 0.9471, train_auprc: 0.9544, val_auprc: 0.9387




Epoch: 234 (58.8352s), train_loss: 0.2334, val_loss: 0.3036, train_acc: 0.9097, val_acc:0.8834
		train_roc: 0.9634, val_roc: 0.9469, train_auprc: 0.9544, val_auprc: 0.9383




Epoch: 235 (58.8418s), train_loss: 0.2340, val_loss: 0.3050, train_acc: 0.9103, val_acc:0.8825
		train_roc: 0.9631, val_roc: 0.9462, train_auprc: 0.9536, val_auprc: 0.9375




Epoch: 236 (59.2315s), train_loss: 0.2328, val_loss: 0.3013, train_acc: 0.9103, val_acc:0.8842
		train_roc: 0.9637, val_roc: 0.9478, train_auprc: 0.9548, val_auprc: 0.9398




Epoch: 237 (58.7954s), train_loss: 0.2334, val_loss: 0.3040, train_acc: 0.9099, val_acc:0.8829
		train_roc: 0.9634, val_roc: 0.9468, train_auprc: 0.9545, val_auprc: 0.9382




Epoch: 238 (58.9181s), train_loss: 0.2333, val_loss: 0.3013, train_acc: 0.9097, val_acc:0.8833
		train_roc: 0.9635, val_roc: 0.9478, train_auprc: 0.9543, val_auprc: 0.9400




Epoch: 239 (58.8557s), train_loss: 0.2355, val_loss: 0.3020, train_acc: 0.9091, val_acc:0.8839
		train_roc: 0.9625, val_roc: 0.9474, train_auprc: 0.9532, val_auprc: 0.9397




Epoch: 240 (58.9024s), train_loss: 0.2334, val_loss: 0.3028, train_acc: 0.9101, val_acc:0.8831
		train_roc: 0.9634, val_roc: 0.9472, train_auprc: 0.9543, val_auprc: 0.9394




Epoch: 241 (58.8053s), train_loss: 0.2340, val_loss: 0.3003, train_acc: 0.9097, val_acc:0.8846
		train_roc: 0.9632, val_roc: 0.9483, train_auprc: 0.9540, val_auprc: 0.9400




Epoch: 242 (58.8962s), train_loss: 0.2331, val_loss: 0.3021, train_acc: 0.9102, val_acc:0.8839
		train_roc: 0.9636, val_roc: 0.9475, train_auprc: 0.9546, val_auprc: 0.9394




Epoch: 243 (58.8673s), train_loss: 0.2328, val_loss: 0.3024, train_acc: 0.9104, val_acc:0.8836
		train_roc: 0.9636, val_roc: 0.9473, train_auprc: 0.9548, val_auprc: 0.9392




Epoch: 244 (58.7656s), train_loss: 0.2327, val_loss: 0.3014, train_acc: 0.9104, val_acc:0.8835
		train_roc: 0.9636, val_roc: 0.9477, train_auprc: 0.9546, val_auprc: 0.9397




Epoch: 245 (58.8110s), train_loss: 0.2333, val_loss: 0.3017, train_acc: 0.9105, val_acc:0.8845
		train_roc: 0.9635, val_roc: 0.9477, train_auprc: 0.9541, val_auprc: 0.9393




Epoch: 246 (58.9772s), train_loss: 0.2327, val_loss: 0.3026, train_acc: 0.9103, val_acc:0.8832
		train_roc: 0.9636, val_roc: 0.9472, train_auprc: 0.9546, val_auprc: 0.9391




Epoch: 247 (58.8857s), train_loss: 0.2347, val_loss: 0.3034, train_acc: 0.9097, val_acc:0.8829
		train_roc: 0.9629, val_roc: 0.9470, train_auprc: 0.9534, val_auprc: 0.9384




Epoch: 248 (59.0340s), train_loss: 0.2332, val_loss: 0.3035, train_acc: 0.9096, val_acc:0.8833
		train_roc: 0.9633, val_roc: 0.9470, train_auprc: 0.9543, val_auprc: 0.9378




Epoch: 249 (58.9304s), train_loss: 0.2338, val_loss: 0.3046, train_acc: 0.9096, val_acc:0.8823
		train_roc: 0.9632, val_roc: 0.9466, train_auprc: 0.9542, val_auprc: 0.9383




Epoch: 250 (58.7382s), train_loss: 0.2330, val_loss: 0.3002, train_acc: 0.9104, val_acc:0.8840
		train_roc: 0.9636, val_roc: 0.9481, train_auprc: 0.9546, val_auprc: 0.9404




Epoch: 251 (58.8308s), train_loss: 0.2338, val_loss: 0.3027, train_acc: 0.9100, val_acc:0.8830
		train_roc: 0.9632, val_roc: 0.9472, train_auprc: 0.9540, val_auprc: 0.9393




Epoch: 252 (59.2013s), train_loss: 0.2323, val_loss: 0.2989, train_acc: 0.9106, val_acc:0.8848
		train_roc: 0.9637, val_roc: 0.9486, train_auprc: 0.9548, val_auprc: 0.9412




Epoch: 253 (58.9159s), train_loss: 0.2326, val_loss: 0.3023, train_acc: 0.9105, val_acc:0.8834
		train_roc: 0.9636, val_roc: 0.9473, train_auprc: 0.9545, val_auprc: 0.9393




Epoch: 254 (58.9144s), train_loss: 0.2328, val_loss: 0.3032, train_acc: 0.9100, val_acc:0.8834
		train_roc: 0.9637, val_roc: 0.9472, train_auprc: 0.9544, val_auprc: 0.9387




Epoch: 255 (58.8677s), train_loss: 0.2348, val_loss: 0.3015, train_acc: 0.9098, val_acc:0.8839
		train_roc: 0.9627, val_roc: 0.9477, train_auprc: 0.9531, val_auprc: 0.9395




Epoch: 256 (58.8285s), train_loss: 0.2326, val_loss: 0.3029, train_acc: 0.9103, val_acc:0.8832
		train_roc: 0.9638, val_roc: 0.9472, train_auprc: 0.9551, val_auprc: 0.9391




Epoch: 257 (58.6958s), train_loss: 0.2339, val_loss: 0.3026, train_acc: 0.9099, val_acc:0.8834
		train_roc: 0.9632, val_roc: 0.9473, train_auprc: 0.9539, val_auprc: 0.9390




Epoch: 258 (58.8955s), train_loss: 0.2339, val_loss: 0.3002, train_acc: 0.9097, val_acc:0.8844
		train_roc: 0.9633, val_roc: 0.9484, train_auprc: 0.9541, val_auprc: 0.9410




Epoch: 259 (58.8635s), train_loss: 0.2326, val_loss: 0.3002, train_acc: 0.9104, val_acc:0.8843
		train_roc: 0.9636, val_roc: 0.9483, train_auprc: 0.9546, val_auprc: 0.9403




Epoch: 260 (58.8989s), train_loss: 0.2340, val_loss: 0.3028, train_acc: 0.9096, val_acc:0.8830
		train_roc: 0.9632, val_roc: 0.9472, train_auprc: 0.9540, val_auprc: 0.9395




Epoch: 261 (58.8128s), train_loss: 0.2349, val_loss: 0.3030, train_acc: 0.9096, val_acc:0.8831
		train_roc: 0.9630, val_roc: 0.9472, train_auprc: 0.9534, val_auprc: 0.9391




Epoch: 262 (58.9143s), train_loss: 0.2309, val_loss: 0.3048, train_acc: 0.9105, val_acc:0.8828
		train_roc: 0.9643, val_roc: 0.9463, train_auprc: 0.9557, val_auprc: 0.9376




Epoch: 263 (58.8435s), train_loss: 0.2329, val_loss: 0.3023, train_acc: 0.9104, val_acc:0.8842
		train_roc: 0.9636, val_roc: 0.9475, train_auprc: 0.9543, val_auprc: 0.9390




Epoch: 264 (58.8903s), train_loss: 0.2349, val_loss: 0.3000, train_acc: 0.9093, val_acc:0.8847
		train_roc: 0.9628, val_roc: 0.9484, train_auprc: 0.9536, val_auprc: 0.9407




Epoch: 265 (58.8966s), train_loss: 0.2334, val_loss: 0.3019, train_acc: 0.9100, val_acc:0.8833
		train_roc: 0.9634, val_roc: 0.9475, train_auprc: 0.9542, val_auprc: 0.9401




Epoch: 266 (58.7895s), train_loss: 0.2322, val_loss: 0.3006, train_acc: 0.9104, val_acc:0.8844
		train_roc: 0.9639, val_roc: 0.9482, train_auprc: 0.9549, val_auprc: 0.9401




Epoch: 267 (59.0506s), train_loss: 0.2336, val_loss: 0.3039, train_acc: 0.9101, val_acc:0.8826
		train_roc: 0.9634, val_roc: 0.9467, train_auprc: 0.9545, val_auprc: 0.9386




Epoch: 268 (58.9255s), train_loss: 0.2319, val_loss: 0.3005, train_acc: 0.9104, val_acc:0.8843
		train_roc: 0.9640, val_roc: 0.9479, train_auprc: 0.9552, val_auprc: 0.9404




Epoch: 269 (58.7927s), train_loss: 0.2332, val_loss: 0.3022, train_acc: 0.9103, val_acc:0.8834
		train_roc: 0.9633, val_roc: 0.9472, train_auprc: 0.9540, val_auprc: 0.9389




Epoch: 270 (58.9446s), train_loss: 0.2340, val_loss: 0.3026, train_acc: 0.9094, val_acc:0.8823
		train_roc: 0.9633, val_roc: 0.9474, train_auprc: 0.9542, val_auprc: 0.9395




Epoch: 271 (58.8096s), train_loss: 0.2330, val_loss: 0.3028, train_acc: 0.9101, val_acc:0.8823
		train_roc: 0.9635, val_roc: 0.9472, train_auprc: 0.9546, val_auprc: 0.9396




Epoch: 272 (58.7591s), train_loss: 0.2316, val_loss: 0.3017, train_acc: 0.9110, val_acc:0.8841
		train_roc: 0.9640, val_roc: 0.9478, train_auprc: 0.9550, val_auprc: 0.9397




Epoch: 273 (58.7555s), train_loss: 0.2346, val_loss: 0.2994, train_acc: 0.9095, val_acc:0.8846
		train_roc: 0.9630, val_roc: 0.9487, train_auprc: 0.9538, val_auprc: 0.9406




Epoch: 274 (58.8646s), train_loss: 0.2339, val_loss: 0.3011, train_acc: 0.9100, val_acc:0.8846
		train_roc: 0.9631, val_roc: 0.9481, train_auprc: 0.9539, val_auprc: 0.9395




Epoch: 275 (58.8283s), train_loss: 0.2327, val_loss: 0.3005, train_acc: 0.9104, val_acc:0.8845
		train_roc: 0.9636, val_roc: 0.9480, train_auprc: 0.9546, val_auprc: 0.9400




Epoch: 276 (58.9009s), train_loss: 0.2339, val_loss: 0.3023, train_acc: 0.9100, val_acc:0.8833
		train_roc: 0.9631, val_roc: 0.9474, train_auprc: 0.9539, val_auprc: 0.9393




Epoch: 277 (58.9494s), train_loss: 0.2341, val_loss: 0.3035, train_acc: 0.9098, val_acc:0.8831
		train_roc: 0.9631, val_roc: 0.9469, train_auprc: 0.9536, val_auprc: 0.9386




Epoch: 278 (58.8093s), train_loss: 0.2333, val_loss: 0.3011, train_acc: 0.9102, val_acc:0.8842
		train_roc: 0.9633, val_roc: 0.9479, train_auprc: 0.9540, val_auprc: 0.9397




Epoch: 279 (58.9135s), train_loss: 0.2342, val_loss: 0.3026, train_acc: 0.9096, val_acc:0.8832
		train_roc: 0.9631, val_roc: 0.9471, train_auprc: 0.9536, val_auprc: 0.9391




Epoch: 280 (58.7124s), train_loss: 0.2324, val_loss: 0.3022, train_acc: 0.9102, val_acc:0.8830
		train_roc: 0.9637, val_roc: 0.9474, train_auprc: 0.9548, val_auprc: 0.9394




Epoch: 281 (58.7743s), train_loss: 0.2328, val_loss: 0.3031, train_acc: 0.9105, val_acc:0.8831
		train_roc: 0.9636, val_roc: 0.9471, train_auprc: 0.9544, val_auprc: 0.9384




Epoch: 282 (59.0406s), train_loss: 0.2332, val_loss: 0.3043, train_acc: 0.9103, val_acc:0.8830
		train_roc: 0.9634, val_roc: 0.9465, train_auprc: 0.9543, val_auprc: 0.9380




Epoch: 283 (58.7456s), train_loss: 0.2333, val_loss: 0.3021, train_acc: 0.9102, val_acc:0.8833
		train_roc: 0.9633, val_roc: 0.9475, train_auprc: 0.9540, val_auprc: 0.9396




Epoch: 284 (58.9233s), train_loss: 0.2347, val_loss: 0.3016, train_acc: 0.9095, val_acc:0.8836
		train_roc: 0.9629, val_roc: 0.9475, train_auprc: 0.9537, val_auprc: 0.9395




Epoch: 285 (58.8288s), train_loss: 0.2334, val_loss: 0.2996, train_acc: 0.9103, val_acc:0.8847
		train_roc: 0.9635, val_roc: 0.9484, train_auprc: 0.9541, val_auprc: 0.9407




Epoch: 286 (58.8556s), train_loss: 0.2329, val_loss: 0.3031, train_acc: 0.9102, val_acc:0.8829
		train_roc: 0.9635, val_roc: 0.9471, train_auprc: 0.9545, val_auprc: 0.9389




Epoch: 287 (58.8995s), train_loss: 0.2333, val_loss: 0.3031, train_acc: 0.9098, val_acc:0.8829
		train_roc: 0.9636, val_roc: 0.9470, train_auprc: 0.9545, val_auprc: 0.9389




Epoch: 288 (58.8727s), train_loss: 0.2343, val_loss: 0.3043, train_acc: 0.9098, val_acc:0.8828
		train_roc: 0.9630, val_roc: 0.9464, train_auprc: 0.9532, val_auprc: 0.9380




Epoch: 289 (58.8055s), train_loss: 0.2336, val_loss: 0.3032, train_acc: 0.9097, val_acc:0.8826
		train_roc: 0.9634, val_roc: 0.9472, train_auprc: 0.9543, val_auprc: 0.9391




Epoch: 290 (58.8599s), train_loss: 0.2334, val_loss: 0.3004, train_acc: 0.9098, val_acc:0.8840
		train_roc: 0.9634, val_roc: 0.9483, train_auprc: 0.9543, val_auprc: 0.9405




Epoch: 291 (58.7778s), train_loss: 0.2337, val_loss: 0.3027, train_acc: 0.9095, val_acc:0.8831
		train_roc: 0.9633, val_roc: 0.9473, train_auprc: 0.9543, val_auprc: 0.9390




Epoch: 292 (58.8971s), train_loss: 0.2331, val_loss: 0.3024, train_acc: 0.9102, val_acc:0.8831
		train_roc: 0.9635, val_roc: 0.9473, train_auprc: 0.9542, val_auprc: 0.9394




Epoch: 293 (58.8634s), train_loss: 0.2325, val_loss: 0.3012, train_acc: 0.9108, val_acc:0.8848
		train_roc: 0.9636, val_roc: 0.9478, train_auprc: 0.9545, val_auprc: 0.9395




Epoch: 294 (58.8150s), train_loss: 0.2341, val_loss: 0.3014, train_acc: 0.9097, val_acc:0.8833
		train_roc: 0.9631, val_roc: 0.9476, train_auprc: 0.9539, val_auprc: 0.9402




Epoch: 295 (58.9265s), train_loss: 0.2322, val_loss: 0.3029, train_acc: 0.9109, val_acc:0.8836
		train_roc: 0.9637, val_roc: 0.9472, train_auprc: 0.9547, val_auprc: 0.9386




Epoch: 296 (58.8158s), train_loss: 0.2332, val_loss: 0.2997, train_acc: 0.9099, val_acc:0.8849
		train_roc: 0.9636, val_roc: 0.9485, train_auprc: 0.9547, val_auprc: 0.9407




Epoch: 297 (58.7923s), train_loss: 0.2326, val_loss: 0.3049, train_acc: 0.9104, val_acc:0.8821
		train_roc: 0.9638, val_roc: 0.9463, train_auprc: 0.9547, val_auprc: 0.9380




Epoch: 298 (58.8747s), train_loss: 0.2327, val_loss: 0.3030, train_acc: 0.9103, val_acc:0.8831
		train_roc: 0.9636, val_roc: 0.9471, train_auprc: 0.9548, val_auprc: 0.9385




Epoch: 299 (58.8911s), train_loss: 0.2326, val_loss: 0.3020, train_acc: 0.9104, val_acc:0.8836
		train_roc: 0.9637, val_roc: 0.9475, train_auprc: 0.9549, val_auprc: 0.9398




Epoch: 300 (58.7625s), train_loss: 0.2333, val_loss: 0.3036, train_acc: 0.9099, val_acc:0.8821
		train_roc: 0.9634, val_roc: 0.9470, train_auprc: 0.9544, val_auprc: 0.9390


In [18]:
# Predict
# Reload the saved model from `model_file` and evaluate it on `test_data_loader`.
#
# The file holds a whole pickled model object (the cell prints it as a module),
# not a bare state_dict, so it has to be fully unpickled:
#   * weights_only=False is stated explicitly. The bare `torch.load(path)` call
#     emits a FutureWarning on recent PyTorch, and releases where the default is
#     weights_only=True refuse to unpickle a full model object.
#   * map_location=device remaps the stored tensors while loading, so a
#     checkpoint written on a GPU machine still loads when `device` is "cpu".
# NOTE(review): full unpickling can execute arbitrary code embedded in the file;
# only load checkpoints that were produced by this notebook or another trusted source.
model = torch.load(model_file, map_location=device, weights_only=False)
print(model)
# Kept from the original cell; after map_location this is effectively a no-op
# but guarantees every parameter/buffer sits on `device` before prediction.
model.to(device=device)
predict(model, test_data_loader, device)

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayer()
  (KGE): RESCAL(86, torch.Size([86, 4096]))
)
Starting predicting at 2024-10-24 00:35:23.436567
Device cuda


  model = torch.load(model_file)


Test Accuracy: 0.8833
Test ROC AUC: 0.9477
Test PRC AUC: 0.9383
