In [1]:
# !pip install torch_geometric rdkit torch

In [2]:
import argparse
import csv
import os
import sys
import time
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch import nn
from torch import optim
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)


In [3]:
# Directory configuration.
# data_dir:   folder the DDI training / validation / test CSV files are read from.
# model_dir:  folder part of the checkpoint path "<model_dir>/<model_name>.pth".
# model_name: run identifier; names both the checkpoint file and the per-epoch
#             metrics CSV written under "train_metrics/".
data_dir = "data"
model_dir = "models"
model_name = "case17"

# sys.path.append('/content/drive/MyDrive/Colab Notebooks')

In [4]:
####### Tuning parameters #######

# Number of training epochs
n_epochs = 300

# SAGPooling ratio and minimum score.
# Leaving both as None makes every block build its SAGPooling with min_score=-1.
sp_ratio = None
sp_min_score = None

# Allow running on the GPU (CUDA is selected only if it is also available)
use_cuda = True

# Apply tanh to the co-attention activations before they are reduced to scores
use_activation_fn = False

# Use ComplEx instead of RESCAL as the scoring module
use_ComplEx = False

# Use CoAttentionLayerImproved instead of CoAttentionLayer
use_improved_CoAttention = True

# Import the explicit-valence variant of the data preprocessing module
use_explicit_valence = True

# Number of stacked GAT blocks
num_GAT_layers = 5

# Number of attention heads in each GAT block
num_GAT_multiheads = 2

#################################

In [5]:
# If using explicit valence feature
if use_explicit_valence:
    from data_preprocessing_explicit_valence import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS
else:
    from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS

  return undirected_edge_list.T, features


In [6]:
# The model is only built and trained when mode == "train".
mode = "train"

# Width of the input atom feature vectors, as exported by the preprocessing module
n_atom_feats = TOTAL_ATOM_FEATS
# Passed to the model as hidd_dim / final_out_feats; the blocks accept it but never read it
n_atom_hid = 64
# Number of interaction types (per the author: the total listed in Interaction_information.csv)
rel_total = 86
# Adam learning rate and weight decay
lr = 1e-2
weight_decay = 5e-4
# Forwarded to the training DrugDataset as neg_ent
# NOTE(review): presumably the number of negative samples per positive triple - confirm in DrugDataset
neg_samples = 1
# Number of samples (graph pairs) per training batch; validation and test loaders use 3x this size
batch_size = 1024
# Forwarded to DrugDataset as ratio for the training and validation sets
data_size_ratio = 1
# Width of the graph embeddings consumed by the co-attention and scoring modules
kge_dim = 64

# Use CUDA only when it is both available and enabled via use_cuda
device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 56


In [7]:
def print_tunning_parameters():
    """Print the notebook-level tuning parameters as a framed, grouped listing.

    Reads the module-level configuration values (epochs, device flag, GAT
    layout, SAGPooling settings and the feature toggles) and writes them to
    stdout. Each group is preceded by a blank line. Returns nothing.
    """
    # Labels are passed to print() as separate arguments, so the ones ending in
    # "= " intentionally render with two spaces before the value.
    parameter_groups = (
        (("n_epochs =", n_epochs), ("use_cuda =", use_cuda)),
        (
            ("num_GAT_layers = ", num_GAT_layers),
            ("num_GAT_multiheads = ", num_GAT_multiheads),
        ),
        (("sp_ratio =", sp_ratio), ("sp_min_score =", sp_min_score)),
        (("use_explicit_valence =", use_explicit_valence),),
        (("use_activation_fn =", use_activation_fn),),
        (("use_ComplEx =", use_ComplEx),),
        (("use_improved_CoAttention =", use_improved_CoAttention),),
    )

    print()
    print("####### Tunning parameters #######")
    for group in parameter_groups:
        print()
        for label, value in group:
            print(label, value)
    print()
    print("#################################")
    print()


In [8]:
class CoAttentionLayer(nn.Module):
    """Additive co-attention between two stacks of embeddings.

    Both inputs are projected from ``n_features`` to ``n_features // 2``
    channels. Every (receiver row, attendant row) pair is scored by adding the
    two projections and a bias, optionally applying ``tanh``, and taking the
    dot product with the learned vector ``a``.

    Args:
        n_features: size of the last dimension of both inputs.
        use_activation_fn: when True, ``tanh`` is applied before the reduction
            with ``a``.
    """

    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        projected_dim = n_features // 2
        self.n_features = n_features
        self.w_q = nn.Parameter(torch.zeros(n_features, projected_dim))
        self.w_k = nn.Parameter(torch.zeros(n_features, projected_dim))
        self.bias = nn.Parameter(torch.zeros(projected_dim))
        self.a = nn.Parameter(torch.zeros(projected_dim))
        self.use_activation_fn = use_activation_fn

        for matrix in (self.w_q, self.w_k):
            nn.init.xavier_uniform_(matrix)
        # Xavier init needs at least two dimensions, so the 1-D parameters are
        # initialised through a column view that shares their storage.
        for vector in (self.bias, self.a):
            nn.init.xavier_uniform_(vector.view(*vector.shape, -1))

    def forward(self, receiver, attendant):
        """Score every receiver row against every attendant row.

        Args:
            receiver: tensor of shape (..., R, n_features); projected with ``w_k``.
            attendant: tensor of shape (..., A, n_features); projected with ``w_q``.

        Returns:
            Tensor of shape (..., R, A) holding the raw (un-normalised) scores.
        """
        projected_keys = torch.matmul(receiver, self.w_k)
        projected_queries = torch.matmul(attendant, self.w_q)

        # (..., 1, A, d) + (..., R, 1, d) broadcasts to one vector per (R, A) pair.
        pairwise = (
            projected_queries.unsqueeze(-3)
            + projected_keys.unsqueeze(-2)
            + self.bias
        )
        if self.use_activation_fn:
            pairwise = torch.tanh(pairwise)
        return torch.matmul(pairwise, self.a)

class CoAttentionLayerImproved(nn.Module):
    """Multi-head variant of the additive co-attention layer.

    The feature axis of both inputs is split into ``n_heads`` chunks of
    ``head_dim = n_features // n_heads`` channels. Every chunk is projected
    with the same shared weights to ``head_dim // 2`` channels, the heads are
    concatenated back to ``n_features // 2`` channels, and each
    (receiver row, attendant row) pair is scored additively.

    Args:
        n_features: size of the last dimension of both inputs.
        use_activation_fn: when True, ``tanh`` is applied before the reduction
            with ``a``.
        dropout: probability used to build ``self.dropout``.
        n_heads: number of chunks the feature axis is split into.

    Raises:
        ValueError: if ``n_features`` is not divisible by ``n_heads`` or the
            resulting ``head_dim`` is odd (the concatenated heads would not
            match the ``n_features // 2`` wide bias / ``a`` vectors).
    """

    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        if n_heads < 1 or n_features % n_heads != 0:
            raise ValueError(
                f"n_features ({n_features}) must be divisible by n_heads ({n_heads})"
            )
        if (n_features // n_heads) % 2 != 0:
            raise ValueError(
                f"head_dim ({n_features // n_heads}) must be even so that the "
                f"concatenated heads have n_features // 2 channels"
            )

        self.n_features = n_features
        self.n_heads = n_heads
        self.head_dim = n_features // n_heads

        # Query/key projections are shared by all heads.
        self.w_q = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.w_k = nn.Parameter(torch.zeros(self.head_dim, self.head_dim // 2))
        self.bias = nn.Parameter(torch.zeros(self.n_features // 2))
        self.a = nn.Parameter(torch.zeros(self.n_features // 2))
        self.use_activation_fn = use_activation_fn

        # NOTE(review): this module is registered but never applied in
        # forward(); it is kept so the printed model structure is unchanged.
        self.dropout = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        # Xavier init needs at least two dimensions, so the 1-D parameters are
        # initialised through a column view that shares their storage.
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def _project(self, x, weight):
        """Split the last axis of ``x`` into heads, project each head with ``weight`` and re-concatenate."""
        per_head = x.reshape(*x.shape[:-1], self.n_heads, self.head_dim)
        # (..., n_heads, head_dim // 2) -> (..., n_features // 2)
        return (per_head @ weight).flatten(-2)

    def forward(self, receiver, attendant):
        """Score every receiver row against every attendant row.

        Args:
            receiver: tensor of shape (..., R, n_features); projected with ``w_k``.
            attendant: tensor of shape (..., A, n_features); projected with ``w_q``.

        Returns:
            Tensor of shape (..., R, A) holding the raw (un-normalised) scores.
        """
        keys = self._project(receiver, self.w_k)
        queries = self._project(attendant, self.w_q)

        # The head axis is flattened rather than viewed as ``head_dim``: the
        # latter is only the right size when n_heads == 2.
        # (..., 1, A, d) + (..., R, 1, d) broadcasts to one vector per (R, A) pair.
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_activations = torch.tanh(e_activations)

        return e_activations @ self.a


class RESCAL(nn.Module):
    """Bilinear (RESCAL-style) scorer with one dense matrix per relation.

    Args:
        n_rels: number of distinct relations (rows of the embedding table).
        n_features: width of the head / tail embeddings; every relation is
            stored as a flattened ``n_features x n_features`` matrix.
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # One flattened square matrix per relation, Xavier-initialised.
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores):
        """Return one score per sample.

        Args:
            heads: tensor of shape (B, H, n_features).
            tails: tensor of shape (B, T, n_features).
            rels: integer tensor of relation indices, one per sample.
            alpha_scores: optional (B, H, T) weights multiplied element-wise
                into the pairwise scores; pass None to skip the weighting.

        Returns:
            Tensor of shape (B,): the pairwise scores summed over H and T.
        """
        # The relation is L2-normalised while still flat, i.e. over the whole
        # matrix, and only then reshaped to (B, n_features, n_features).
        flat_relations = F.normalize(self.rel_emb(rels), dim=-1)
        unit_heads = F.normalize(heads, dim=-1)
        unit_tails = F.normalize(tails, dim=-1)
        relation_matrices = flat_relations.view(-1, self.n_features, self.n_features)

        # (B, H, d) @ (B, d, d) @ (B, d, T) -> (B, H, T)
        projected_heads = torch.matmul(unit_heads, relation_matrices)
        pair_scores = torch.matmul(projected_heads, unit_tails.transpose(-2, -1))

        if alpha_scores is not None:
            pair_scores = alpha_scores * pair_scores

        return pair_scores.sum(dim=(-2, -1))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    """ComplEx-style scorer using dense real and imaginary relation matrices.

    The first half of each embedding is treated as its real part and the
    second half as its imaginary part. Every relation owns two flattened
    ``(n_features // 2) x (n_features // 2)`` matrices (real and imaginary).

    Args:
        n_rels: number of distinct relations (rows of the embedding tables).
        n_features: width of the head / tail embeddings.
    """

    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features

        # Relation embeddings are complex too: one table per part.
        half = self.n_features // 2
        self.rel_real = nn.Embedding(self.n_rels, half * half)
        self.rel_imag = nn.Embedding(self.n_rels, half * half)

        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        """Return one score per sample.

        Args:
            heads: tensor of shape (B, H, n_features).
            tails: tensor of shape (B, T, n_features).
            rels: integer tensor of relation indices, one per sample.
            alpha_scores: optional (B, H, T) weights multiplied element-wise
                into the pairwise scores.

        Returns:
            Tensor of shape (B,): the pairwise scores summed over H and T.
        """
        half = self.n_features // 2

        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)

        # Relations are L2-normalised while still flat, then reshaped to matrices.
        r_real = F.normalize(self.rel_real(rels), dim=-1).view(-1, half, half)
        r_imag = F.normalize(self.rel_imag(rels), dim=-1).view(-1, half, half)

        # Split the embeddings into real / imaginary halves. The imaginary part
        # of the tails must come from `tails`; it was previously sliced from
        # `heads`, so the tails' second half never influenced the score.
        h_real, h_imag = heads[..., :half], heads[..., half:]
        t_real, t_imag = tails[..., :half], tails[..., half:]

        t_real_T = t_real.transpose(-2, -1)
        t_imag_T = t_imag.transpose(-2, -1)

        # NOTE(review): the canonical ComplEx score subtracts the last term
        # (Im(h), Im(r), Re(t)). All four terms are added here, as in the
        # original implementation - confirm whether that is intentional.
        scores = (
            h_real @ r_real @ t_real_T
            + h_real @ r_imag @ t_imag_T
            + h_imag @ r_real @ t_imag_T
            + h_imag @ r_imag @ t_real_T
        )

        if alpha_scores is not None:
            scores = alpha_scores * scores

        return scores.sum(dim=(-2, -1))

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


In [9]:
class SSI_DDI(nn.Module):
    """Scores (head graph, tail graph, relation) triples.

    Both graphs are pushed through the same stack of ``SSI_DDI_Block`` layers.
    The graph-level readout of every layer is collected, the per-layer readouts
    of the two graphs are compared by a co-attention layer, and a RESCAL or
    ComplEx module turns them into one score per triple.

    Args:
        in_features: width of the input node features.
        hidd_dim: forwarded to every block as ``final_out_feats``.
        kge_dim: width expected by the co-attention and scoring modules.
        rel_total: number of relation types.
        heads_out_feat_params: per-block output width of a single attention head.
        blocks_params: per-block number of attention heads.
        sp_ratio: SAGPooling ratio forwarded to every block.
        use_activation_fn: forwarded to the co-attention layer.
        use_ComplEx: score with ComplEx when True, otherwise RESCAL.
        sp_min_score: SAGPooling min_score forwarded to every block.
        use_improved_CoAttention: use CoAttentionLayerImproved when True,
            otherwise CoAttentionLayer.
    """

    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        use_improved_CoAttention,
    ):
        super().__init__()
        self.in_features = in_features
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        # Plain list used for iteration; each block is also registered as a
        # submodule under the name "block<i>".
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # One normalisation layer per block, applied to that block's node features.
        self.net_norms = ModuleList()

        block_in_feats = in_features
        layer_specs = zip(heads_out_feat_params, blocks_params)
        for index, (head_out_feats, n_heads) in enumerate(layer_specs):
            block_out_feats = head_out_feats * n_heads
            gat_block = SSI_DDI_Block(
                n_heads,
                block_in_feats,
                head_out_feats,
                final_out_feats=self.hidd_dim,
                sp_ratio=sp_ratio,
                sp_min_score=sp_min_score,
            )
            self.add_module(f"block{index}", gat_block)
            self.blocks.append(gat_block)
            self.net_norms.append(LayerNorm(block_out_feats))
            # The next block consumes the concatenated heads of this one.
            block_in_feats = block_out_feats

        attention_cls = (
            CoAttentionLayerImproved if use_improved_CoAttention else CoAttentionLayer
        )
        self.co_attention = attention_cls(self.kge_dim, self.use_activation_fn)

        scorer_cls = ComplEx if self.use_ComplEx else RESCAL
        self.KGE = scorer_cls(self.rel_total, self.kge_dim)

    def forward(self, triples):
        """Score a batch of triples.

        Args:
            triples: ``(h_data, t_data, rels)`` - the batched head graphs, the
                batched tail graphs and the relation indices. The ``x``
                attribute of both graph batches is overwritten in place.

        Returns:
            The output of the scoring module (one score per triple).
        """
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        head_readouts, tail_readouts = [], []
        for norm, block in zip(self.net_norms, self.blocks):
            h_data, head_emb = block(h_data)
            t_data, tail_emb = block(t_data)

            head_readouts.append(head_emb)
            tail_readouts.append(tail_emb)

            h_data.x = F.elu(norm(h_data.x, h_data.batch))
            t_data.x = F.elu(norm(t_data.x, t_data.batch))

        # Stack the per-block readouts along a new second-to-last axis.
        kge_heads = torch.stack(head_readouts, dim=-2)
        kge_tails = torch.stack(tail_readouts, dim=-2)

        attentions = self.co_attention(kge_heads, kge_tails)
        return self.KGE(kge_heads, kge_tails, rels, attentions)


class SSI_DDI_Block(nn.Module):
    """One GAT convolution followed by a SAGPooling-based graph readout.

    Args:
        n_heads: number of attention heads of the GAT convolution.
        in_features: width of the incoming node features.
        head_out_feats: output width of a single attention head.
        final_out_feats: accepted for interface compatibility; not read here.
        sp_ratio: SAGPooling ``ratio``; None leaves the library default.
        sp_min_score: SAGPooling ``min_score``. When both this and
            ``sp_ratio`` are None the pooling is built with ``min_score=-1``.
    """

    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)

        # SAGPooling ranks nodes by self-attention score; it operates on the
        # concatenated head outputs.
        pooled_width = n_heads * head_out_feats
        if sp_ratio is not None:
            self.readout = SAGPooling(pooled_width, min_score=sp_min_score, ratio=sp_ratio)
        elif sp_min_score is not None:
            self.readout = SAGPooling(pooled_width, min_score=sp_min_score)
        else:
            # Per the original author's note, min_score=-1 filters out no node.
            self.readout = SAGPooling(pooled_width, min_score=-1)

    def forward(self, data):
        """Convolve the node features and compute a graph-level embedding.

        Args:
            data: batched graph object exposing ``x``, ``edge_index`` and
                ``batch``. Its ``x`` is overwritten with the convolved features.

        Returns:
            ``(data, global_graph_emb)``: the same graph object (un-pooled) and
            the sum of the pooled node features of every graph in the batch.
        """
        data.x = self.conv(data.x, data.edge_index)

        # Only the pooled features and their batch assignment are needed.
        kept_x, _, _, kept_batch, _, _ = self.readout(
            data.x, data.edge_index, batch=data.batch
        )
        global_graph_emb = global_add_pool(kept_x, kept_batch)

        return data, global_graph_emb


In [10]:
class SigmoidLoss(nn.Module):
    """Log-sigmoid loss over positive and negative triple scores.

    Args:
        adv_temperature: when truthy, negative scores are re-weighted by a
            detached softmax of ``adv_temperature * n_scores`` before the loss
            is computed.
    """

    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        """Return ``(total, positive_loss, negative_loss)``.

        ``total`` is the average of the two partial losses; each partial loss
        is the mean negative log-sigmoid over its scores (negative scores are
        sign-flipped first).
        """
        if self.adv_temperature:
            # The weights are detached so no gradient flows through them.
            sample_weights = F.softmax(
                self.adv_temperature * n_scores, dim=-1
            ).detach()
            n_scores = sample_weights * n_scores

        positive_loss = F.logsigmoid(p_scores).mean().neg()
        negative_loss = F.logsigmoid(n_scores.neg()).mean().neg()
        total_loss = (positive_loss + negative_loss) / 2

        return total_loss, positive_loss, negative_loss


In [11]:
# Read the three DDI splits from the data directory.
df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")

# Each split becomes a list of (head drug, tail drug, interaction type) tuples.
train_tup = list(zip(df_ddi_train["d1"], df_ddi_train["d2"], df_ddi_train["type"]))
val_tup = list(zip(df_ddi_val["d1"], df_ddi_val["d2"], df_ddi_val["type"]))
test_tup = list(zip(df_ddi_test["d1"], df_ddi_test["d2"], df_ddi_test["type"]))

# Only the training set receives the neg_ent setting; validation and test are
# built with disjoint_split=False.
train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

# Evaluation loaders use three times the training batch size.
eval_batch_size = batch_size * 3
train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=eval_batch_size)
test_data_loader = DrugDataLoader(test_data, batch_size=eval_batch_size)


Training with 115185 samples, validating with 38348, and testing with 38337


In [12]:
def do_compute(model, batch, device, training=True):
    """Run the model on the positive and negative triples of one batch.

    Args:
        model: callable taking a list of tensors and returning one score per triple.
        batch: ``(pos_tri, neg_tri)``; each is an iterable of objects with a
            ``.to(device=...)`` method, e.g. ``(batch_h, batch_t, batch_r)``.
        device: device every element of the triples is moved to.
        training: accepted for interface compatibility; not read here.

    Returns:
        ``(p_score, n_score, probas_pred, ground_truth)``: the raw positive and
        negative scores (still attached to the graph), followed by the
        concatenated sigmoid probabilities and the matching 1 / 0 labels, with
        positives first.
    """
    pos_tri, neg_tri = batch

    raw_scores, probabilities, labels = [], [], []
    # Positives are labelled 1 and processed first, negatives are labelled 0.
    for triple, make_labels in ((pos_tri, np.ones), (neg_tri, np.zeros)):
        on_device = [tensor.to(device=device) for tensor in triple]
        score = model(on_device)
        raw_scores.append(score)
        # Detached copy on the CPU: used for metrics only.
        probabilities.append(torch.sigmoid(score.detach()).cpu())
        labels.append(make_labels(len(score)))

    p_score, n_score = raw_scores
    probas_pred = np.concatenate(probabilities)
    ground_truth = np.concatenate(labels)

    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target):
    """Compute accuracy, ROC-AUC and PR-AUC for binary predictions.

    Args:
        probas_pred: array of predicted probabilities for the positive class.
        target: array of ground-truth labels (1 = positive, 0 = negative).

    Returns:
        ``(acc, auc_roc, auc_prc)`` where accuracy uses a 0.5 threshold and
        ``auc_prc`` is the area under the precision-recall curve.
    """
    # Hard labels are only needed for accuracy. The F1 score that used to be
    # computed here was never returned, so that dead computation was dropped.
    pred = (probas_pred >= 0.5).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)

    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(recall, precision)

    return acc, auc_roc, auc_prc

In [13]:
import csv
def export_metrics(train_metrics, val_metrics, epoch):
    """Append one epoch of metrics to ``train_metrics/<model_name>.csv``.

    Args:
        train_metrics: ``(loss, acc, auc_roc, auc_prc)`` for the training set.
        val_metrics: ``(loss, acc, auc_roc, auc_prc)`` for the validation set.
        epoch: 1-based epoch number. Epoch 1 starts a fresh file (any previous
            content is overwritten); later epochs are appended.

    The output directory is created when missing. The header row is written
    whenever the file is new or empty, so a run that starts appending at an
    epoch other than 1 still produces a labelled CSV.
    """
    train_metrics_dir = "train_metrics"
    metrics_file = f"{train_metrics_dir}/{model_name}.csv"
    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    data = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]
    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]

    # Opening the file fails if the directory does not exist yet.
    os.makedirs(train_metrics_dir, exist_ok=True)

    fresh_run = epoch == 1
    needs_header = (
        fresh_run
        or not os.path.exists(metrics_file)
        or os.path.getsize(metrics_file) == 0
    )

    with open(metrics_file, "w" if fresh_run else "a", newline="") as file:
        writer = csv.writer(file)
        if needs_header:
            writer.writerow(header)
        writer.writerow(data)
    
    

In [14]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
    model_path=None,
):
    """Train ``model`` and checkpoint it whenever validation PR-AUC improves.

    Args:
        model: module called on a list of tensors, returning one score per triple.
        train_data_loader: iterable of ``(pos_tri, neg_tri)`` training batches.
        val_data_loader: iterable of ``(pos_tri, neg_tri)`` validation batches.
        loss_fn: callable ``(p_score, n_score) -> (loss, pos_loss, neg_loss)``.
        optimizer: optimizer stepping the model parameters.
        n_epochs: number of epochs to run.
        device: device the batches are moved to.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        model_path: where the best model is saved (str or path-like). Defaults
            to the notebook-level ``model_file`` so existing calls keep working.

    Returns:
        The trained model (the same object that was passed in).

    Side effects: prints progress, writes the per-epoch metrics CSV via
    ``export_metrics`` and saves the whole model object with ``torch.save``.
    """
    save_path = model_file if model_path is None else model_path
    # torch.save does not create missing directories; do it up front rather
    # than failing after the first (slow) epoch.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()
    best_val_auc_prc = 0
    for i in range(1, n_epochs + 1):
        start = time.time()
        train_loss, train_samples = 0.0, 0
        val_loss, val_samples = 0.0, 0
        train_probas_pred, train_ground_truth = [], []
        val_probas_pred, val_ground_truth = [], []

        model.train()
        for batch in train_data_loader:
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, _, _ = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(p_score)
            train_samples += len(p_score)
        # Normalise by the samples actually seen instead of the length of a
        # notebook-level dataset, so the loss is right for any loader.
        train_loss /= max(train_samples, 1)

        train_probas_pred = np.concatenate(train_probas_pred)
        train_ground_truth = np.concatenate(train_ground_truth)
        train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
            train_probas_pred, train_ground_truth
        )

        model.eval()
        with torch.no_grad():
            for batch in val_data_loader:
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, _, _ = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)
                val_samples += len(p_score)

        val_loss /= max(val_samples, 1)
        val_probas_pred = np.concatenate(val_probas_pred)
        val_ground_truth = np.concatenate(val_ground_truth)
        val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
            val_probas_pred, val_ground_truth
        )

        # Save the model if this is the best result so far based on val_auc_prc
        if best_val_auc_prc < val_auc_prc:
            print("Saving model")
            best_val_auc_prc = val_auc_prc
            torch.save(model, save_path)

        if scheduler:
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)

        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model

In [15]:
def predict(model, test_data_loader, device):
    """Evaluate ``model`` on the test loader and print the resulting metrics.

    Args:
        model: trained module, switched to evaluation mode here.
        test_data_loader: iterable of ``(pos_tri, neg_tri)`` batches.
        device: device the batches are moved to.

    Prints accuracy, ROC-AUC and PR-AUC; returns nothing.
    """
    print('Starting predicting at', datetime.today())
    print('Device', device)

    model.eval()

    batch_probas, batch_truth = [], []
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in test_data_loader:
            _, _, probas_pred, ground_truth = do_compute(model, batch, device, training=False)
            batch_probas.append(probas_pred)
            batch_truth.append(ground_truth)

    # Merge the per-batch arrays before computing dataset-level metrics.
    test_probas_pred = np.concatenate(batch_probas)
    test_ground_truth = np.concatenate(batch_truth)
    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(test_probas_pred, test_ground_truth)

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

In [16]:
# Checkpoint path of the best model.
model_file = f"{model_dir}/{model_name}.pth"

# Every GAT block concatenates its heads, so its output width is
# head_out_feats * n_heads, and the co-attention / scoring modules are built
# for kge_dim-wide embeddings. The per-head width is therefore derived from the
# head count instead of being hard-coded for two heads.
if kge_dim % num_GAT_multiheads != 0:
    raise ValueError(
        f"kge_dim ({kge_dim}) must be divisible by num_GAT_multiheads ({num_GAT_multiheads})"
    )
heads_out_feat_params = [kge_dim // num_GAT_multiheads] * num_GAT_layers
block_params = [num_GAT_multiheads] * num_GAT_layers

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        use_improved_CoAttention=use_improved_CoAttention,
    )
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Multiplies the base learning rate by 0.96 ** epoch.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)
    model.to(device=device)

SSI_DDI(
  (initial_norm): LayerNorm(56, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-4): 5 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(56, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayerImproved(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (KGE): RESCAL(86, torch.Size([86, 4096]))
)


In [17]:
if mode == "train":
    # Run the training loop with the objects built in the previous cell.
    train(
        model=model,
        train_data_loader=train_data_loader,
        val_data_loader=val_data_loader,
        loss_fn=loss,
        optimizer=optimizer,
        n_epochs=n_epochs,
        device=device,
        scheduler=scheduler,
    )


Starting training at: 2024-10-24 13:39:05.515876
Device: cuda

####### Tunning parameters #######

n_epochs = 300
use_cuda = True

num_GAT_layers =  5
num_GAT_multiheads =  2

sp_ratio = None
sp_min_score = None

use_explicit_valence = True

use_activation_fn = False

use_ComplEx = False

use_improved_CoAttention = True

#################################





Saving model
Epoch: 1 (70.3455s), train_loss: 0.6721, val_loss: 0.6300, train_acc: 0.5717, val_acc:0.6309
		train_roc: 0.6083, val_roc: 0.6896, train_auprc: 0.5933, val_auprc: 0.6675




Saving model
Epoch: 2 (71.5757s), train_loss: 0.6046, val_loss: 0.5782, train_acc: 0.6623, val_acc:0.6901
		train_roc: 0.7262, val_roc: 0.7595, train_auprc: 0.7048, val_auprc: 0.7425




Saving model
Epoch: 3 (77.6408s), train_loss: 0.5602, val_loss: 0.5568, train_acc: 0.7059, val_acc:0.7074
		train_roc: 0.7781, val_roc: 0.7814, train_auprc: 0.7575, val_auprc: 0.7618




Saving model
Epoch: 4 (75.1162s), train_loss: 0.5424, val_loss: 0.5338, train_acc: 0.7219, val_acc:0.7294
		train_roc: 0.7960, val_roc: 0.8037, train_auprc: 0.7754, val_auprc: 0.7813




Saving model
Epoch: 5 (75.3001s), train_loss: 0.5286, val_loss: 0.5237, train_acc: 0.7311, val_acc:0.7369
		train_roc: 0.8079, val_roc: 0.8139, train_auprc: 0.7868, val_auprc: 0.7946




Saving model
Epoch: 6 (70.2891s), train_loss: 0.5164, val_loss: 0.5160, train_acc: 0.7420, val_acc:0.7464
		train_roc: 0.8194, val_roc: 0.8216, train_auprc: 0.7989, val_auprc: 0.8009




Saving model
Epoch: 7 (66.9005s), train_loss: 0.5083, val_loss: 0.5029, train_acc: 0.7485, val_acc:0.7520
		train_roc: 0.8258, val_roc: 0.8315, train_auprc: 0.8053, val_auprc: 0.8120




Saving model
Epoch: 8 (66.3992s), train_loss: 0.4982, val_loss: 0.4998, train_acc: 0.7542, val_acc:0.7562
		train_roc: 0.8332, val_roc: 0.8351, train_auprc: 0.8140, val_auprc: 0.8146




Saving model
Epoch: 9 (66.6187s), train_loss: 0.4914, val_loss: 0.4910, train_acc: 0.7604, val_acc:0.7605
		train_roc: 0.8392, val_roc: 0.8403, train_auprc: 0.8181, val_auprc: 0.8194




Saving model
Epoch: 10 (66.7298s), train_loss: 0.4863, val_loss: 0.4842, train_acc: 0.7635, val_acc:0.7636
		train_roc: 0.8425, val_roc: 0.8438, train_auprc: 0.8226, val_auprc: 0.8244




Saving model
Epoch: 11 (66.4306s), train_loss: 0.4788, val_loss: 0.4873, train_acc: 0.7688, val_acc:0.7612
		train_roc: 0.8484, val_roc: 0.8482, train_auprc: 0.8283, val_auprc: 0.8291




Saving model
Epoch: 12 (68.6069s), train_loss: 0.4733, val_loss: 0.4779, train_acc: 0.7732, val_acc:0.7720
		train_roc: 0.8522, val_roc: 0.8501, train_auprc: 0.8321, val_auprc: 0.8302




Saving model
Epoch: 13 (66.8255s), train_loss: 0.4673, val_loss: 0.4705, train_acc: 0.7762, val_acc:0.7762
		train_roc: 0.8559, val_roc: 0.8543, train_auprc: 0.8367, val_auprc: 0.8345




Saving model
Epoch: 14 (69.3813s), train_loss: 0.4604, val_loss: 0.4637, train_acc: 0.7824, val_acc:0.7832
		train_roc: 0.8612, val_roc: 0.8613, train_auprc: 0.8417, val_auprc: 0.8413




Saving model
Epoch: 15 (75.4550s), train_loss: 0.4543, val_loss: 0.4598, train_acc: 0.7856, val_acc:0.7841
		train_roc: 0.8650, val_roc: 0.8635, train_auprc: 0.8456, val_auprc: 0.8452




Saving model
Epoch: 16 (69.2430s), train_loss: 0.4490, val_loss: 0.4541, train_acc: 0.7898, val_acc:0.7907
		train_roc: 0.8683, val_roc: 0.8672, train_auprc: 0.8496, val_auprc: 0.8468




Saving model
Epoch: 17 (83.9970s), train_loss: 0.4446, val_loss: 0.4473, train_acc: 0.7925, val_acc:0.7938
		train_roc: 0.8712, val_roc: 0.8712, train_auprc: 0.8519, val_auprc: 0.8521




Saving model
Epoch: 18 (84.3899s), train_loss: 0.4401, val_loss: 0.4437, train_acc: 0.7949, val_acc:0.7942
		train_roc: 0.8739, val_roc: 0.8729, train_auprc: 0.8558, val_auprc: 0.8534




Saving model
Epoch: 19 (82.6052s), train_loss: 0.4362, val_loss: 0.4370, train_acc: 0.7987, val_acc:0.7975
		train_roc: 0.8767, val_roc: 0.8761, train_auprc: 0.8573, val_auprc: 0.8579




Saving model
Epoch: 20 (85.7812s), train_loss: 0.4314, val_loss: 0.4340, train_acc: 0.8007, val_acc:0.8003
		train_roc: 0.8791, val_roc: 0.8783, train_auprc: 0.8600, val_auprc: 0.8603




Saving model
Epoch: 21 (89.6420s), train_loss: 0.4279, val_loss: 0.4259, train_acc: 0.8024, val_acc:0.8073
		train_roc: 0.8815, val_roc: 0.8834, train_auprc: 0.8633, val_auprc: 0.8652




Epoch: 22 (80.4613s), train_loss: 0.4255, val_loss: 0.4331, train_acc: 0.8052, val_acc:0.7996
		train_roc: 0.8828, val_roc: 0.8791, train_auprc: 0.8637, val_auprc: 0.8624




Saving model
Epoch: 23 (98.9040s), train_loss: 0.4230, val_loss: 0.4259, train_acc: 0.8062, val_acc:0.8062
		train_roc: 0.8842, val_roc: 0.8833, train_auprc: 0.8652, val_auprc: 0.8663




Saving model
Epoch: 24 (93.7244s), train_loss: 0.4186, val_loss: 0.4199, train_acc: 0.8087, val_acc:0.8089
		train_roc: 0.8866, val_roc: 0.8869, train_auprc: 0.8682, val_auprc: 0.8692




Saving model
Epoch: 25 (93.0765s), train_loss: 0.4144, val_loss: 0.4220, train_acc: 0.8112, val_acc:0.8039
		train_roc: 0.8890, val_roc: 0.8871, train_auprc: 0.8707, val_auprc: 0.8695




Saving model
Epoch: 26 (93.0831s), train_loss: 0.4104, val_loss: 0.4151, train_acc: 0.8135, val_acc:0.8130
		train_roc: 0.8913, val_roc: 0.8893, train_auprc: 0.8731, val_auprc: 0.8706




Saving model
Epoch: 27 (80.5362s), train_loss: 0.4064, val_loss: 0.4136, train_acc: 0.8163, val_acc:0.8125
		train_roc: 0.8930, val_roc: 0.8905, train_auprc: 0.8746, val_auprc: 0.8728




Epoch: 28 (97.3113s), train_loss: 0.4055, val_loss: 0.4131, train_acc: 0.8176, val_acc:0.8107
		train_roc: 0.8941, val_roc: 0.8901, train_auprc: 0.8751, val_auprc: 0.8717




Saving model
Epoch: 29 (80.2734s), train_loss: 0.4031, val_loss: 0.4070, train_acc: 0.8184, val_acc:0.8157
		train_roc: 0.8952, val_roc: 0.8931, train_auprc: 0.8769, val_auprc: 0.8758




Saving model
Epoch: 30 (83.8996s), train_loss: 0.3994, val_loss: 0.4018, train_acc: 0.8204, val_acc:0.8176
		train_roc: 0.8971, val_roc: 0.8964, train_auprc: 0.8788, val_auprc: 0.8814




Epoch: 31 (92.7687s), train_loss: 0.3962, val_loss: 0.4035, train_acc: 0.8225, val_acc:0.8170
		train_roc: 0.8991, val_roc: 0.8961, train_auprc: 0.8808, val_auprc: 0.8792




Saving model
Epoch: 32 (92.8661s), train_loss: 0.3932, val_loss: 0.3995, train_acc: 0.8240, val_acc:0.8200
		train_roc: 0.9004, val_roc: 0.8981, train_auprc: 0.8832, val_auprc: 0.8820




Epoch: 33 (82.0967s), train_loss: 0.3908, val_loss: 0.3996, train_acc: 0.8260, val_acc:0.8211
		train_roc: 0.9019, val_roc: 0.8980, train_auprc: 0.8839, val_auprc: 0.8802




Saving model
Epoch: 34 (73.1449s), train_loss: 0.3860, val_loss: 0.3925, train_acc: 0.8292, val_acc:0.8259
		train_roc: 0.9042, val_roc: 0.9019, train_auprc: 0.8862, val_auprc: 0.8859




Saving model
Epoch: 35 (68.6803s), train_loss: 0.3846, val_loss: 0.3897, train_acc: 0.8291, val_acc:0.8257
		train_roc: 0.9048, val_roc: 0.9028, train_auprc: 0.8872, val_auprc: 0.8869




Saving model
Epoch: 36 (68.6490s), train_loss: 0.3828, val_loss: 0.3884, train_acc: 0.8308, val_acc:0.8257
		train_roc: 0.9057, val_roc: 0.9040, train_auprc: 0.8880, val_auprc: 0.8896




Epoch: 37 (68.5476s), train_loss: 0.3775, val_loss: 0.3930, train_acc: 0.8338, val_acc:0.8239
		train_roc: 0.9083, val_roc: 0.9011, train_auprc: 0.8914, val_auprc: 0.8858




Saving model
Epoch: 38 (68.7885s), train_loss: 0.3775, val_loss: 0.3840, train_acc: 0.8336, val_acc:0.8300
		train_roc: 0.9081, val_roc: 0.9065, train_auprc: 0.8912, val_auprc: 0.8909




Epoch: 39 (70.8551s), train_loss: 0.3731, val_loss: 0.3837, train_acc: 0.8358, val_acc:0.8313
		train_roc: 0.9106, val_roc: 0.9062, train_auprc: 0.8939, val_auprc: 0.8899




Saving model
Epoch: 40 (71.2382s), train_loss: 0.3687, val_loss: 0.3820, train_acc: 0.8387, val_acc:0.8325
		train_roc: 0.9127, val_roc: 0.9073, train_auprc: 0.8960, val_auprc: 0.8917




Saving model
Epoch: 41 (70.3351s), train_loss: 0.3670, val_loss: 0.3805, train_acc: 0.8396, val_acc:0.8322
		train_roc: 0.9137, val_roc: 0.9077, train_auprc: 0.8972, val_auprc: 0.8918




Saving model
Epoch: 42 (69.5979s), train_loss: 0.3649, val_loss: 0.3792, train_acc: 0.8413, val_acc:0.8336
		train_roc: 0.9144, val_roc: 0.9082, train_auprc: 0.8973, val_auprc: 0.8925




Saving model
Epoch: 43 (69.5657s), train_loss: 0.3634, val_loss: 0.3742, train_acc: 0.8416, val_acc:0.8366
		train_roc: 0.9149, val_roc: 0.9104, train_auprc: 0.8981, val_auprc: 0.8939




Saving model
Epoch: 44 (69.6266s), train_loss: 0.3586, val_loss: 0.3704, train_acc: 0.8444, val_acc:0.8375
		train_roc: 0.9174, val_roc: 0.9127, train_auprc: 0.9012, val_auprc: 0.8979




Epoch: 45 (69.7049s), train_loss: 0.3587, val_loss: 0.3721, train_acc: 0.8440, val_acc:0.8386
		train_roc: 0.9172, val_roc: 0.9113, train_auprc: 0.9011, val_auprc: 0.8956




Saving model
Epoch: 46 (69.8282s), train_loss: 0.3517, val_loss: 0.3678, train_acc: 0.8484, val_acc:0.8382
		train_roc: 0.9207, val_roc: 0.9139, train_auprc: 0.9053, val_auprc: 0.9001




Epoch: 47 (69.5493s), train_loss: 0.3532, val_loss: 0.3706, train_acc: 0.8477, val_acc:0.8390
		train_roc: 0.9200, val_roc: 0.9127, train_auprc: 0.9036, val_auprc: 0.8980




Epoch: 48 (69.4979s), train_loss: 0.3487, val_loss: 0.3672, train_acc: 0.8496, val_acc:0.8390
		train_roc: 0.9218, val_roc: 0.9145, train_auprc: 0.9064, val_auprc: 0.8996




Saving model
Epoch: 49 (69.4752s), train_loss: 0.3482, val_loss: 0.3661, train_acc: 0.8500, val_acc:0.8412
		train_roc: 0.9220, val_roc: 0.9148, train_auprc: 0.9066, val_auprc: 0.9003




Saving model
Epoch: 50 (69.5611s), train_loss: 0.3477, val_loss: 0.3601, train_acc: 0.8505, val_acc:0.8426
		train_roc: 0.9222, val_roc: 0.9178, train_auprc: 0.9060, val_auprc: 0.9041




Epoch: 51 (69.6181s), train_loss: 0.3422, val_loss: 0.3610, train_acc: 0.8530, val_acc:0.8440
		train_roc: 0.9247, val_roc: 0.9173, train_auprc: 0.9093, val_auprc: 0.9033




Epoch: 52 (69.4856s), train_loss: 0.3402, val_loss: 0.3607, train_acc: 0.8546, val_acc:0.8441
		train_roc: 0.9255, val_roc: 0.9175, train_auprc: 0.9106, val_auprc: 0.9030




Saving model
Epoch: 53 (69.5430s), train_loss: 0.3373, val_loss: 0.3570, train_acc: 0.8565, val_acc:0.8466
		train_roc: 0.9268, val_roc: 0.9189, train_auprc: 0.9114, val_auprc: 0.9043




Saving model
Epoch: 54 (69.3712s), train_loss: 0.3378, val_loss: 0.3549, train_acc: 0.8556, val_acc:0.8470
		train_roc: 0.9265, val_roc: 0.9200, train_auprc: 0.9113, val_auprc: 0.9060




Epoch: 55 (69.4826s), train_loss: 0.3342, val_loss: 0.3573, train_acc: 0.8582, val_acc:0.8437
		train_roc: 0.9279, val_roc: 0.9188, train_auprc: 0.9131, val_auprc: 0.9043




Saving model
Epoch: 56 (69.4921s), train_loss: 0.3318, val_loss: 0.3524, train_acc: 0.8587, val_acc:0.8467
		train_roc: 0.9291, val_roc: 0.9207, train_auprc: 0.9148, val_auprc: 0.9084




Epoch: 57 (69.4780s), train_loss: 0.3298, val_loss: 0.3522, train_acc: 0.8598, val_acc:0.8482
		train_roc: 0.9296, val_roc: 0.9216, train_auprc: 0.9153, val_auprc: 0.9078




Epoch: 58 (69.3897s), train_loss: 0.3285, val_loss: 0.3526, train_acc: 0.8613, val_acc:0.8493
		train_roc: 0.9307, val_roc: 0.9210, train_auprc: 0.9159, val_auprc: 0.9070




Saving model
Epoch: 59 (69.6044s), train_loss: 0.3256, val_loss: 0.3501, train_acc: 0.8620, val_acc:0.8496
		train_roc: 0.9317, val_roc: 0.9233, train_auprc: 0.9177, val_auprc: 0.9090




Epoch: 60 (69.5457s), train_loss: 0.3256, val_loss: 0.3514, train_acc: 0.8627, val_acc:0.8504
		train_roc: 0.9313, val_roc: 0.9218, train_auprc: 0.9170, val_auprc: 0.9079




Saving model
Epoch: 61 (69.5365s), train_loss: 0.3253, val_loss: 0.3474, train_acc: 0.8632, val_acc:0.8521
		train_roc: 0.9320, val_roc: 0.9235, train_auprc: 0.9172, val_auprc: 0.9095




Saving model
Epoch: 62 (69.5265s), train_loss: 0.3238, val_loss: 0.3457, train_acc: 0.8640, val_acc:0.8525
		train_roc: 0.9320, val_roc: 0.9248, train_auprc: 0.9173, val_auprc: 0.9118




Epoch: 63 (69.5549s), train_loss: 0.3205, val_loss: 0.3477, train_acc: 0.8656, val_acc:0.8508
		train_roc: 0.9338, val_roc: 0.9236, train_auprc: 0.9196, val_auprc: 0.9098




Epoch: 64 (69.3384s), train_loss: 0.3195, val_loss: 0.3449, train_acc: 0.8660, val_acc:0.8531
		train_roc: 0.9339, val_roc: 0.9247, train_auprc: 0.9199, val_auprc: 0.9116




Epoch: 65 (69.4542s), train_loss: 0.3170, val_loss: 0.3467, train_acc: 0.8673, val_acc:0.8512
		train_roc: 0.9350, val_roc: 0.9242, train_auprc: 0.9210, val_auprc: 0.9111




Saving model
Epoch: 66 (69.4130s), train_loss: 0.3156, val_loss: 0.3414, train_acc: 0.8677, val_acc:0.8541
		train_roc: 0.9356, val_roc: 0.9266, train_auprc: 0.9222, val_auprc: 0.9135




Epoch: 67 (69.4153s), train_loss: 0.3147, val_loss: 0.3422, train_acc: 0.8686, val_acc:0.8555
		train_roc: 0.9359, val_roc: 0.9260, train_auprc: 0.9220, val_auprc: 0.9131




Epoch: 68 (69.5738s), train_loss: 0.3148, val_loss: 0.3445, train_acc: 0.8695, val_acc:0.8535
		train_roc: 0.9358, val_roc: 0.9253, train_auprc: 0.9214, val_auprc: 0.9119




Saving model
Epoch: 69 (69.5587s), train_loss: 0.3106, val_loss: 0.3413, train_acc: 0.8702, val_acc:0.8551
		train_roc: 0.9373, val_roc: 0.9270, train_auprc: 0.9237, val_auprc: 0.9147




Saving model
Epoch: 70 (69.6881s), train_loss: 0.3100, val_loss: 0.3374, train_acc: 0.8710, val_acc:0.8572
		train_roc: 0.9377, val_roc: 0.9285, train_auprc: 0.9240, val_auprc: 0.9163




Epoch: 71 (69.4788s), train_loss: 0.3098, val_loss: 0.3414, train_acc: 0.8714, val_acc:0.8553
		train_roc: 0.9379, val_roc: 0.9267, train_auprc: 0.9240, val_auprc: 0.9137




Saving model
Epoch: 72 (69.4636s), train_loss: 0.3070, val_loss: 0.3357, train_acc: 0.8732, val_acc:0.8586
		train_roc: 0.9388, val_roc: 0.9290, train_auprc: 0.9252, val_auprc: 0.9170




Epoch: 73 (69.4369s), train_loss: 0.3064, val_loss: 0.3376, train_acc: 0.8732, val_acc:0.8574
		train_roc: 0.9389, val_roc: 0.9282, train_auprc: 0.9253, val_auprc: 0.9156




Epoch: 74 (69.6702s), train_loss: 0.3072, val_loss: 0.3358, train_acc: 0.8735, val_acc:0.8587
		train_roc: 0.9388, val_roc: 0.9291, train_auprc: 0.9249, val_auprc: 0.9169




Epoch: 75 (69.5808s), train_loss: 0.3050, val_loss: 0.3385, train_acc: 0.8739, val_acc:0.8568
		train_roc: 0.9393, val_roc: 0.9271, train_auprc: 0.9258, val_auprc: 0.9139




Saving model
Epoch: 76 (69.4385s), train_loss: 0.3026, val_loss: 0.3363, train_acc: 0.8756, val_acc:0.8568
		train_roc: 0.9405, val_roc: 0.9298, train_auprc: 0.9274, val_auprc: 0.9183




Epoch: 77 (69.4840s), train_loss: 0.3056, val_loss: 0.3360, train_acc: 0.8740, val_acc:0.8579
		train_roc: 0.9392, val_roc: 0.9294, train_auprc: 0.9253, val_auprc: 0.9176




Epoch: 78 (69.7690s), train_loss: 0.3002, val_loss: 0.3331, train_acc: 0.8768, val_acc:0.8601
		train_roc: 0.9414, val_roc: 0.9302, train_auprc: 0.9283, val_auprc: 0.9168




Saving model
Epoch: 79 (69.5762s), train_loss: 0.2990, val_loss: 0.3335, train_acc: 0.8773, val_acc:0.8595
		train_roc: 0.9418, val_roc: 0.9305, train_auprc: 0.9289, val_auprc: 0.9186




Epoch: 80 (69.4536s), train_loss: 0.2996, val_loss: 0.3329, train_acc: 0.8766, val_acc:0.8606
		train_roc: 0.9413, val_roc: 0.9308, train_auprc: 0.9281, val_auprc: 0.9182




Epoch: 81 (69.8542s), train_loss: 0.2983, val_loss: 0.3341, train_acc: 0.8769, val_acc:0.8604
		train_roc: 0.9419, val_roc: 0.9301, train_auprc: 0.9288, val_auprc: 0.9182




Epoch: 82 (69.6708s), train_loss: 0.2952, val_loss: 0.3327, train_acc: 0.8795, val_acc:0.8609
		train_roc: 0.9433, val_roc: 0.9307, train_auprc: 0.9302, val_auprc: 0.9179




Saving model
Epoch: 83 (69.8239s), train_loss: 0.2955, val_loss: 0.3284, train_acc: 0.8784, val_acc:0.8622
		train_roc: 0.9428, val_roc: 0.9326, train_auprc: 0.9301, val_auprc: 0.9217




Epoch: 84 (69.7592s), train_loss: 0.2953, val_loss: 0.3309, train_acc: 0.8794, val_acc:0.8616
		train_roc: 0.9430, val_roc: 0.9315, train_auprc: 0.9301, val_auprc: 0.9195




Epoch: 85 (69.5267s), train_loss: 0.2938, val_loss: 0.3301, train_acc: 0.8802, val_acc:0.8629
		train_roc: 0.9437, val_roc: 0.9320, train_auprc: 0.9308, val_auprc: 0.9200




Epoch: 86 (69.5288s), train_loss: 0.2922, val_loss: 0.3313, train_acc: 0.8810, val_acc:0.8633
		train_roc: 0.9442, val_roc: 0.9316, train_auprc: 0.9314, val_auprc: 0.9196




Epoch: 87 (69.6832s), train_loss: 0.2921, val_loss: 0.3293, train_acc: 0.8812, val_acc:0.8630
		train_roc: 0.9442, val_roc: 0.9320, train_auprc: 0.9311, val_auprc: 0.9204




Epoch: 88 (69.6799s), train_loss: 0.2920, val_loss: 0.3289, train_acc: 0.8810, val_acc:0.8622
		train_roc: 0.9441, val_roc: 0.9328, train_auprc: 0.9315, val_auprc: 0.9214




Saving model
Epoch: 89 (69.3851s), train_loss: 0.2905, val_loss: 0.3281, train_acc: 0.8815, val_acc:0.8626
		train_roc: 0.9446, val_roc: 0.9328, train_auprc: 0.9319, val_auprc: 0.9217




Epoch: 90 (69.5419s), train_loss: 0.2877, val_loss: 0.3316, train_acc: 0.8828, val_acc:0.8627
		train_roc: 0.9458, val_roc: 0.9318, train_auprc: 0.9334, val_auprc: 0.9197




Epoch: 91 (69.5488s), train_loss: 0.2896, val_loss: 0.3284, train_acc: 0.8821, val_acc:0.8630
		train_roc: 0.9450, val_roc: 0.9322, train_auprc: 0.9329, val_auprc: 0.9202




Saving model
Epoch: 92 (69.4012s), train_loss: 0.2898, val_loss: 0.3280, train_acc: 0.8823, val_acc:0.8632
		train_roc: 0.9448, val_roc: 0.9333, train_auprc: 0.9316, val_auprc: 0.9223




Saving model
Epoch: 93 (69.7014s), train_loss: 0.2862, val_loss: 0.3291, train_acc: 0.8837, val_acc:0.8623
		train_roc: 0.9463, val_roc: 0.9334, train_auprc: 0.9341, val_auprc: 0.9226




Epoch: 94 (69.6959s), train_loss: 0.2860, val_loss: 0.3277, train_acc: 0.8841, val_acc:0.8639
		train_roc: 0.9462, val_roc: 0.9335, train_auprc: 0.9339, val_auprc: 0.9223




Saving model
Epoch: 95 (69.7997s), train_loss: 0.2863, val_loss: 0.3260, train_acc: 0.8835, val_acc:0.8650
		train_roc: 0.9463, val_roc: 0.9341, train_auprc: 0.9338, val_auprc: 0.9230




Epoch: 96 (69.7570s), train_loss: 0.2864, val_loss: 0.3292, train_acc: 0.8846, val_acc:0.8629
		train_roc: 0.9462, val_roc: 0.9330, train_auprc: 0.9335, val_auprc: 0.9209




Saving model
Epoch: 97 (69.9400s), train_loss: 0.2867, val_loss: 0.3276, train_acc: 0.8834, val_acc:0.8634
		train_roc: 0.9461, val_roc: 0.9346, train_auprc: 0.9337, val_auprc: 0.9243




Epoch: 98 (69.5111s), train_loss: 0.2851, val_loss: 0.3273, train_acc: 0.8844, val_acc:0.8639
		train_roc: 0.9466, val_roc: 0.9345, train_auprc: 0.9342, val_auprc: 0.9235




Epoch: 99 (69.6387s), train_loss: 0.2834, val_loss: 0.3266, train_acc: 0.8851, val_acc:0.8651
		train_roc: 0.9473, val_roc: 0.9340, train_auprc: 0.9353, val_auprc: 0.9223




Epoch: 100 (69.6312s), train_loss: 0.2816, val_loss: 0.3269, train_acc: 0.8862, val_acc:0.8649
		train_roc: 0.9480, val_roc: 0.9343, train_auprc: 0.9360, val_auprc: 0.9231




Epoch: 101 (69.5359s), train_loss: 0.2839, val_loss: 0.3271, train_acc: 0.8857, val_acc:0.8647
		train_roc: 0.9470, val_roc: 0.9338, train_auprc: 0.9345, val_auprc: 0.9226




Epoch: 102 (69.5431s), train_loss: 0.2819, val_loss: 0.3308, train_acc: 0.8852, val_acc:0.8625
		train_roc: 0.9477, val_roc: 0.9326, train_auprc: 0.9357, val_auprc: 0.9211




Saving model
Epoch: 103 (69.5901s), train_loss: 0.2820, val_loss: 0.3257, train_acc: 0.8864, val_acc:0.8656
		train_roc: 0.9475, val_roc: 0.9351, train_auprc: 0.9350, val_auprc: 0.9245




Epoch: 104 (69.6652s), train_loss: 0.2825, val_loss: 0.3279, train_acc: 0.8854, val_acc:0.8654
		train_roc: 0.9473, val_roc: 0.9346, train_auprc: 0.9350, val_auprc: 0.9227




Saving model
Epoch: 105 (69.6181s), train_loss: 0.2814, val_loss: 0.3237, train_acc: 0.8867, val_acc:0.8668
		train_roc: 0.9479, val_roc: 0.9354, train_auprc: 0.9357, val_auprc: 0.9245




Epoch: 106 (69.5473s), train_loss: 0.2813, val_loss: 0.3239, train_acc: 0.8868, val_acc:0.8675
		train_roc: 0.9477, val_roc: 0.9355, train_auprc: 0.9356, val_auprc: 0.9243




Saving model
Epoch: 107 (69.4618s), train_loss: 0.2792, val_loss: 0.3224, train_acc: 0.8881, val_acc:0.8669
		train_roc: 0.9488, val_roc: 0.9363, train_auprc: 0.9364, val_auprc: 0.9263




Epoch: 108 (69.5172s), train_loss: 0.2780, val_loss: 0.3225, train_acc: 0.8886, val_acc:0.8685
		train_roc: 0.9492, val_roc: 0.9362, train_auprc: 0.9373, val_auprc: 0.9257




Epoch: 109 (69.5926s), train_loss: 0.2782, val_loss: 0.3240, train_acc: 0.8877, val_acc:0.8669
		train_roc: 0.9489, val_roc: 0.9350, train_auprc: 0.9371, val_auprc: 0.9235




Epoch: 110 (69.5562s), train_loss: 0.2776, val_loss: 0.3251, train_acc: 0.8888, val_acc:0.8666
		train_roc: 0.9492, val_roc: 0.9346, train_auprc: 0.9372, val_auprc: 0.9232




Epoch: 111 (69.5265s), train_loss: 0.2789, val_loss: 0.3254, train_acc: 0.8880, val_acc:0.8669
		train_roc: 0.9488, val_roc: 0.9350, train_auprc: 0.9364, val_auprc: 0.9231




Epoch: 112 (69.6218s), train_loss: 0.2779, val_loss: 0.3282, train_acc: 0.8878, val_acc:0.8635
		train_roc: 0.9491, val_roc: 0.9339, train_auprc: 0.9376, val_auprc: 0.9235




Epoch: 113 (69.5920s), train_loss: 0.2773, val_loss: 0.3250, train_acc: 0.8890, val_acc:0.8667
		train_roc: 0.9492, val_roc: 0.9349, train_auprc: 0.9372, val_auprc: 0.9232




Epoch: 114 (69.3297s), train_loss: 0.2774, val_loss: 0.3291, train_acc: 0.8887, val_acc:0.8643
		train_roc: 0.9492, val_roc: 0.9332, train_auprc: 0.9369, val_auprc: 0.9209




Epoch: 115 (69.5352s), train_loss: 0.2771, val_loss: 0.3240, train_acc: 0.8888, val_acc:0.8665
		train_roc: 0.9493, val_roc: 0.9354, train_auprc: 0.9372, val_auprc: 0.9246




Epoch: 116 (69.7430s), train_loss: 0.2759, val_loss: 0.3235, train_acc: 0.8890, val_acc:0.8672
		train_roc: 0.9498, val_roc: 0.9357, train_auprc: 0.9380, val_auprc: 0.9254




Epoch: 117 (69.4425s), train_loss: 0.2774, val_loss: 0.3244, train_acc: 0.8882, val_acc:0.8659
		train_roc: 0.9491, val_roc: 0.9355, train_auprc: 0.9371, val_auprc: 0.9249




Epoch: 118 (69.5769s), train_loss: 0.2762, val_loss: 0.3239, train_acc: 0.8887, val_acc:0.8664
		train_roc: 0.9496, val_roc: 0.9360, train_auprc: 0.9381, val_auprc: 0.9258




Epoch: 119 (69.3901s), train_loss: 0.2751, val_loss: 0.3263, train_acc: 0.8890, val_acc:0.8661
		train_roc: 0.9501, val_roc: 0.9346, train_auprc: 0.9386, val_auprc: 0.9235




Epoch: 120 (69.5206s), train_loss: 0.2769, val_loss: 0.3229, train_acc: 0.8884, val_acc:0.8680
		train_roc: 0.9491, val_roc: 0.9361, train_auprc: 0.9372, val_auprc: 0.9251




Epoch: 121 (69.4581s), train_loss: 0.2754, val_loss: 0.3240, train_acc: 0.8900, val_acc:0.8667
		train_roc: 0.9499, val_roc: 0.9358, train_auprc: 0.9376, val_auprc: 0.9251




Epoch: 122 (69.4437s), train_loss: 0.2744, val_loss: 0.3229, train_acc: 0.8898, val_acc:0.8675
		train_roc: 0.9504, val_roc: 0.9364, train_auprc: 0.9386, val_auprc: 0.9254




Epoch: 123 (69.6628s), train_loss: 0.2780, val_loss: 0.3279, train_acc: 0.8887, val_acc:0.8660
		train_roc: 0.9488, val_roc: 0.9340, train_auprc: 0.9364, val_auprc: 0.9216




Epoch: 124 (69.5917s), train_loss: 0.2726, val_loss: 0.3255, train_acc: 0.8911, val_acc:0.8657
		train_roc: 0.9508, val_roc: 0.9355, train_auprc: 0.9392, val_auprc: 0.9251




Epoch: 125 (69.3872s), train_loss: 0.2731, val_loss: 0.3232, train_acc: 0.8903, val_acc:0.8674
		train_roc: 0.9507, val_roc: 0.9366, train_auprc: 0.9391, val_auprc: 0.9256




Epoch: 126 (69.5879s), train_loss: 0.2753, val_loss: 0.3235, train_acc: 0.8904, val_acc:0.8674
		train_roc: 0.9497, val_roc: 0.9363, train_auprc: 0.9375, val_auprc: 0.9253




Epoch: 127 (69.5533s), train_loss: 0.2723, val_loss: 0.3228, train_acc: 0.8906, val_acc:0.8674
		train_roc: 0.9510, val_roc: 0.9364, train_auprc: 0.9397, val_auprc: 0.9260




Epoch: 128 (69.3971s), train_loss: 0.2738, val_loss: 0.3223, train_acc: 0.8905, val_acc:0.8682
		train_roc: 0.9504, val_roc: 0.9363, train_auprc: 0.9387, val_auprc: 0.9256




Epoch: 129 (69.9081s), train_loss: 0.2732, val_loss: 0.3258, train_acc: 0.8901, val_acc:0.8646
		train_roc: 0.9507, val_roc: 0.9357, train_auprc: 0.9392, val_auprc: 0.9256




Epoch: 130 (69.7111s), train_loss: 0.2735, val_loss: 0.3262, train_acc: 0.8901, val_acc:0.8651
		train_roc: 0.9506, val_roc: 0.9350, train_auprc: 0.9391, val_auprc: 0.9249




Epoch: 131 (69.5250s), train_loss: 0.2717, val_loss: 0.3244, train_acc: 0.8918, val_acc:0.8668
		train_roc: 0.9511, val_roc: 0.9357, train_auprc: 0.9393, val_auprc: 0.9252




Epoch: 132 (69.5231s), train_loss: 0.2723, val_loss: 0.3252, train_acc: 0.8906, val_acc:0.8670
		train_roc: 0.9509, val_roc: 0.9353, train_auprc: 0.9394, val_auprc: 0.9235




Epoch: 133 (69.4932s), train_loss: 0.2723, val_loss: 0.3252, train_acc: 0.8912, val_acc:0.8663
		train_roc: 0.9510, val_roc: 0.9358, train_auprc: 0.9391, val_auprc: 0.9251




Epoch: 134 (70.0307s), train_loss: 0.2736, val_loss: 0.3243, train_acc: 0.8906, val_acc:0.8670
		train_roc: 0.9505, val_roc: 0.9361, train_auprc: 0.9384, val_auprc: 0.9250




Epoch: 135 (69.6710s), train_loss: 0.2726, val_loss: 0.3250, train_acc: 0.8906, val_acc:0.8665
		train_roc: 0.9508, val_roc: 0.9360, train_auprc: 0.9391, val_auprc: 0.9251




Epoch: 136 (69.6435s), train_loss: 0.2724, val_loss: 0.3251, train_acc: 0.8913, val_acc:0.8668
		train_roc: 0.9509, val_roc: 0.9358, train_auprc: 0.9389, val_auprc: 0.9249




Epoch: 137 (69.5907s), train_loss: 0.2726, val_loss: 0.3249, train_acc: 0.8906, val_acc:0.8679
		train_roc: 0.9508, val_roc: 0.9359, train_auprc: 0.9392, val_auprc: 0.9245




Saving model
Epoch: 138 (69.5718s), train_loss: 0.2705, val_loss: 0.3217, train_acc: 0.8911, val_acc:0.8690
		train_roc: 0.9517, val_roc: 0.9370, train_auprc: 0.9406, val_auprc: 0.9269




Epoch: 139 (69.5203s), train_loss: 0.2711, val_loss: 0.3222, train_acc: 0.8918, val_acc:0.8681
		train_roc: 0.9514, val_roc: 0.9370, train_auprc: 0.9397, val_auprc: 0.9267




Epoch: 140 (69.6961s), train_loss: 0.2718, val_loss: 0.3233, train_acc: 0.8915, val_acc:0.8686
		train_roc: 0.9510, val_roc: 0.9368, train_auprc: 0.9389, val_auprc: 0.9264




Epoch: 141 (69.6795s), train_loss: 0.2712, val_loss: 0.3240, train_acc: 0.8909, val_acc:0.8672
		train_roc: 0.9513, val_roc: 0.9364, train_auprc: 0.9401, val_auprc: 0.9259




Epoch: 142 (69.5858s), train_loss: 0.2708, val_loss: 0.3258, train_acc: 0.8917, val_acc:0.8662
		train_roc: 0.9514, val_roc: 0.9354, train_auprc: 0.9398, val_auprc: 0.9243




Epoch: 143 (69.4715s), train_loss: 0.2713, val_loss: 0.3259, train_acc: 0.8918, val_acc:0.8670
		train_roc: 0.9513, val_roc: 0.9354, train_auprc: 0.9396, val_auprc: 0.9242




Epoch: 144 (69.5385s), train_loss: 0.2715, val_loss: 0.3267, train_acc: 0.8917, val_acc:0.8670
		train_roc: 0.9509, val_roc: 0.9348, train_auprc: 0.9388, val_auprc: 0.9239




Epoch: 145 (69.3837s), train_loss: 0.2704, val_loss: 0.3225, train_acc: 0.8919, val_acc:0.8688
		train_roc: 0.9517, val_roc: 0.9371, train_auprc: 0.9401, val_auprc: 0.9268




Epoch: 146 (69.4230s), train_loss: 0.2693, val_loss: 0.3231, train_acc: 0.8925, val_acc:0.8683
		train_roc: 0.9519, val_roc: 0.9367, train_auprc: 0.9405, val_auprc: 0.9262




Epoch: 147 (69.5746s), train_loss: 0.2697, val_loss: 0.3237, train_acc: 0.8927, val_acc:0.8687
		train_roc: 0.9517, val_roc: 0.9367, train_auprc: 0.9399, val_auprc: 0.9261




Epoch: 148 (69.7448s), train_loss: 0.2709, val_loss: 0.3227, train_acc: 0.8920, val_acc:0.8681
		train_roc: 0.9513, val_roc: 0.9373, train_auprc: 0.9395, val_auprc: 0.9268




Epoch: 149 (69.4326s), train_loss: 0.2685, val_loss: 0.3279, train_acc: 0.8929, val_acc:0.8654
		train_roc: 0.9521, val_roc: 0.9350, train_auprc: 0.9412, val_auprc: 0.9239




Epoch: 150 (69.4413s), train_loss: 0.2683, val_loss: 0.3261, train_acc: 0.8933, val_acc:0.8672
		train_roc: 0.9523, val_roc: 0.9357, train_auprc: 0.9411, val_auprc: 0.9246




Epoch: 151 (69.3888s), train_loss: 0.2708, val_loss: 0.3245, train_acc: 0.8912, val_acc:0.8677
		train_roc: 0.9514, val_roc: 0.9366, train_auprc: 0.9399, val_auprc: 0.9259




Epoch: 152 (69.5278s), train_loss: 0.2720, val_loss: 0.3248, train_acc: 0.8913, val_acc:0.8673
		train_roc: 0.9507, val_roc: 0.9366, train_auprc: 0.9387, val_auprc: 0.9262




Epoch: 153 (69.5595s), train_loss: 0.2701, val_loss: 0.3221, train_acc: 0.8925, val_acc:0.8688
		train_roc: 0.9516, val_roc: 0.9375, train_auprc: 0.9400, val_auprc: 0.9269




Epoch: 154 (69.5158s), train_loss: 0.2701, val_loss: 0.3258, train_acc: 0.8915, val_acc:0.8674
		train_roc: 0.9516, val_roc: 0.9358, train_auprc: 0.9401, val_auprc: 0.9252




Epoch: 155 (69.6420s), train_loss: 0.2699, val_loss: 0.3241, train_acc: 0.8928, val_acc:0.8678
		train_roc: 0.9518, val_roc: 0.9369, train_auprc: 0.9400, val_auprc: 0.9261




Saving model
Epoch: 156 (69.5768s), train_loss: 0.2714, val_loss: 0.3237, train_acc: 0.8911, val_acc:0.8670
		train_roc: 0.9510, val_roc: 0.9369, train_auprc: 0.9395, val_auprc: 0.9271




Epoch: 157 (69.5844s), train_loss: 0.2700, val_loss: 0.3266, train_acc: 0.8918, val_acc:0.8667
		train_roc: 0.9517, val_roc: 0.9355, train_auprc: 0.9402, val_auprc: 0.9245




Epoch: 158 (69.5280s), train_loss: 0.2696, val_loss: 0.3234, train_acc: 0.8921, val_acc:0.8686
		train_roc: 0.9520, val_roc: 0.9369, train_auprc: 0.9406, val_auprc: 0.9266




Epoch: 159 (69.5462s), train_loss: 0.2706, val_loss: 0.3248, train_acc: 0.8915, val_acc:0.8677
		train_roc: 0.9515, val_roc: 0.9362, train_auprc: 0.9402, val_auprc: 0.9255




Epoch: 160 (69.5793s), train_loss: 0.2701, val_loss: 0.3253, train_acc: 0.8916, val_acc:0.8673
		train_roc: 0.9515, val_roc: 0.9360, train_auprc: 0.9400, val_auprc: 0.9256




Epoch: 161 (69.4811s), train_loss: 0.2704, val_loss: 0.3234, train_acc: 0.8916, val_acc:0.8687
		train_roc: 0.9515, val_roc: 0.9372, train_auprc: 0.9401, val_auprc: 0.9268




Epoch: 162 (69.6304s), train_loss: 0.2693, val_loss: 0.3278, train_acc: 0.8927, val_acc:0.8664
		train_roc: 0.9522, val_roc: 0.9353, train_auprc: 0.9409, val_auprc: 0.9241




Epoch: 163 (69.5868s), train_loss: 0.2687, val_loss: 0.3236, train_acc: 0.8934, val_acc:0.8699
		train_roc: 0.9522, val_roc: 0.9370, train_auprc: 0.9407, val_auprc: 0.9263




Epoch: 164 (69.4563s), train_loss: 0.2688, val_loss: 0.3261, train_acc: 0.8924, val_acc:0.8672
		train_roc: 0.9524, val_roc: 0.9359, train_auprc: 0.9413, val_auprc: 0.9248




Epoch: 165 (69.5236s), train_loss: 0.2701, val_loss: 0.3237, train_acc: 0.8923, val_acc:0.8681
		train_roc: 0.9515, val_roc: 0.9369, train_auprc: 0.9397, val_auprc: 0.9260




Saving model
Epoch: 166 (69.6220s), train_loss: 0.2703, val_loss: 0.3237, train_acc: 0.8920, val_acc:0.8680
		train_roc: 0.9516, val_roc: 0.9370, train_auprc: 0.9401, val_auprc: 0.9273




Epoch: 167 (69.3909s), train_loss: 0.2696, val_loss: 0.3248, train_acc: 0.8925, val_acc:0.8684
		train_roc: 0.9517, val_roc: 0.9368, train_auprc: 0.9403, val_auprc: 0.9263




Saving model
Epoch: 168 (69.6187s), train_loss: 0.2688, val_loss: 0.3220, train_acc: 0.8924, val_acc:0.8694
		train_roc: 0.9521, val_roc: 0.9376, train_auprc: 0.9406, val_auprc: 0.9274




Epoch: 169 (69.3394s), train_loss: 0.2706, val_loss: 0.3264, train_acc: 0.8922, val_acc:0.8673
		train_roc: 0.9514, val_roc: 0.9360, train_auprc: 0.9393, val_auprc: 0.9243




Epoch: 170 (69.5967s), train_loss: 0.2686, val_loss: 0.3271, train_acc: 0.8928, val_acc:0.8672
		train_roc: 0.9522, val_roc: 0.9355, train_auprc: 0.9406, val_auprc: 0.9246




Epoch: 171 (69.4127s), train_loss: 0.2678, val_loss: 0.3235, train_acc: 0.8930, val_acc:0.8682
		train_roc: 0.9523, val_roc: 0.9370, train_auprc: 0.9413, val_auprc: 0.9267




Epoch: 172 (69.4202s), train_loss: 0.2696, val_loss: 0.3256, train_acc: 0.8926, val_acc:0.8676
		train_roc: 0.9519, val_roc: 0.9361, train_auprc: 0.9400, val_auprc: 0.9255




Epoch: 173 (69.5696s), train_loss: 0.2693, val_loss: 0.3242, train_acc: 0.8929, val_acc:0.8696
		train_roc: 0.9519, val_roc: 0.9369, train_auprc: 0.9398, val_auprc: 0.9260




Epoch: 174 (69.7429s), train_loss: 0.2680, val_loss: 0.3268, train_acc: 0.8930, val_acc:0.8666
		train_roc: 0.9523, val_roc: 0.9357, train_auprc: 0.9409, val_auprc: 0.9249




Epoch: 175 (69.5902s), train_loss: 0.2683, val_loss: 0.3242, train_acc: 0.8928, val_acc:0.8688
		train_roc: 0.9524, val_roc: 0.9370, train_auprc: 0.9412, val_auprc: 0.9262




Epoch: 176 (69.3880s), train_loss: 0.2670, val_loss: 0.3245, train_acc: 0.8937, val_acc:0.8688
		train_roc: 0.9528, val_roc: 0.9365, train_auprc: 0.9415, val_auprc: 0.9254




Epoch: 177 (69.5060s), train_loss: 0.2699, val_loss: 0.3262, train_acc: 0.8920, val_acc:0.8679
		train_roc: 0.9515, val_roc: 0.9358, train_auprc: 0.9398, val_auprc: 0.9249




Epoch: 178 (69.3354s), train_loss: 0.2695, val_loss: 0.3230, train_acc: 0.8923, val_acc:0.8690
		train_roc: 0.9517, val_roc: 0.9373, train_auprc: 0.9400, val_auprc: 0.9271




Epoch: 179 (69.6302s), train_loss: 0.2682, val_loss: 0.3239, train_acc: 0.8927, val_acc:0.8675
		train_roc: 0.9523, val_roc: 0.9369, train_auprc: 0.9409, val_auprc: 0.9261




Epoch: 180 (69.5570s), train_loss: 0.2695, val_loss: 0.3240, train_acc: 0.8922, val_acc:0.8686
		train_roc: 0.9520, val_roc: 0.9371, train_auprc: 0.9405, val_auprc: 0.9266




Epoch: 181 (69.7130s), train_loss: 0.2684, val_loss: 0.3270, train_acc: 0.8928, val_acc:0.8672
		train_roc: 0.9522, val_roc: 0.9356, train_auprc: 0.9408, val_auprc: 0.9248




Epoch: 182 (69.3983s), train_loss: 0.2676, val_loss: 0.3247, train_acc: 0.8925, val_acc:0.8682
		train_roc: 0.9526, val_roc: 0.9366, train_auprc: 0.9417, val_auprc: 0.9260




Epoch: 183 (69.5147s), train_loss: 0.2689, val_loss: 0.3242, train_acc: 0.8930, val_acc:0.8679
		train_roc: 0.9520, val_roc: 0.9367, train_auprc: 0.9403, val_auprc: 0.9269




Epoch: 184 (69.6745s), train_loss: 0.2687, val_loss: 0.3242, train_acc: 0.8929, val_acc:0.8680
		train_roc: 0.9521, val_roc: 0.9368, train_auprc: 0.9403, val_auprc: 0.9265




Epoch: 185 (69.4827s), train_loss: 0.2693, val_loss: 0.3283, train_acc: 0.8926, val_acc:0.8665
		train_roc: 0.9518, val_roc: 0.9351, train_auprc: 0.9401, val_auprc: 0.9240




Epoch: 186 (69.4002s), train_loss: 0.2696, val_loss: 0.3260, train_acc: 0.8924, val_acc:0.8661
		train_roc: 0.9517, val_roc: 0.9359, train_auprc: 0.9401, val_auprc: 0.9254




Epoch: 187 (69.3967s), train_loss: 0.2713, val_loss: 0.3264, train_acc: 0.8914, val_acc:0.8675
		train_roc: 0.9511, val_roc: 0.9358, train_auprc: 0.9395, val_auprc: 0.9252




Epoch: 188 (69.4360s), train_loss: 0.2679, val_loss: 0.3230, train_acc: 0.8930, val_acc:0.8694
		train_roc: 0.9525, val_roc: 0.9372, train_auprc: 0.9411, val_auprc: 0.9268




Epoch: 189 (69.5764s), train_loss: 0.2682, val_loss: 0.3251, train_acc: 0.8927, val_acc:0.8685
		train_roc: 0.9523, val_roc: 0.9362, train_auprc: 0.9411, val_auprc: 0.9255




Epoch: 190 (69.6678s), train_loss: 0.2696, val_loss: 0.3253, train_acc: 0.8923, val_acc:0.8679
		train_roc: 0.9516, val_roc: 0.9364, train_auprc: 0.9400, val_auprc: 0.9257




Epoch: 191 (69.7642s), train_loss: 0.2686, val_loss: 0.3266, train_acc: 0.8923, val_acc:0.8672
		train_roc: 0.9522, val_roc: 0.9358, train_auprc: 0.9410, val_auprc: 0.9250




Epoch: 192 (69.6679s), train_loss: 0.2682, val_loss: 0.3255, train_acc: 0.8934, val_acc:0.8684
		train_roc: 0.9522, val_roc: 0.9362, train_auprc: 0.9406, val_auprc: 0.9249




Epoch: 193 (69.9030s), train_loss: 0.2697, val_loss: 0.3242, train_acc: 0.8916, val_acc:0.8688
		train_roc: 0.9518, val_roc: 0.9369, train_auprc: 0.9402, val_auprc: 0.9261




Epoch: 194 (69.7617s), train_loss: 0.2682, val_loss: 0.3250, train_acc: 0.8927, val_acc:0.8677
		train_roc: 0.9524, val_roc: 0.9365, train_auprc: 0.9410, val_auprc: 0.9259




Epoch: 195 (69.7594s), train_loss: 0.2681, val_loss: 0.3272, train_acc: 0.8928, val_acc:0.8671
		train_roc: 0.9523, val_roc: 0.9353, train_auprc: 0.9406, val_auprc: 0.9241




Epoch: 196 (69.5524s), train_loss: 0.2685, val_loss: 0.3246, train_acc: 0.8930, val_acc:0.8690
		train_roc: 0.9523, val_roc: 0.9367, train_auprc: 0.9409, val_auprc: 0.9253




Epoch: 197 (69.4041s), train_loss: 0.2682, val_loss: 0.3265, train_acc: 0.8928, val_acc:0.8671
		train_roc: 0.9524, val_roc: 0.9357, train_auprc: 0.9408, val_auprc: 0.9246




Epoch: 198 (69.6382s), train_loss: 0.2691, val_loss: 0.3261, train_acc: 0.8926, val_acc:0.8681
		train_roc: 0.9519, val_roc: 0.9358, train_auprc: 0.9404, val_auprc: 0.9247




Epoch: 199 (69.6186s), train_loss: 0.2699, val_loss: 0.3256, train_acc: 0.8921, val_acc:0.8675
		train_roc: 0.9516, val_roc: 0.9363, train_auprc: 0.9398, val_auprc: 0.9259




Epoch: 200 (69.5129s), train_loss: 0.2701, val_loss: 0.3261, train_acc: 0.8923, val_acc:0.8663
		train_roc: 0.9516, val_roc: 0.9360, train_auprc: 0.9395, val_auprc: 0.9257




Epoch: 201 (69.4591s), train_loss: 0.2683, val_loss: 0.3242, train_acc: 0.8930, val_acc:0.8681
		train_roc: 0.9522, val_roc: 0.9369, train_auprc: 0.9407, val_auprc: 0.9266




Epoch: 202 (69.5386s), train_loss: 0.2692, val_loss: 0.3229, train_acc: 0.8927, val_acc:0.8697
		train_roc: 0.9519, val_roc: 0.9373, train_auprc: 0.9402, val_auprc: 0.9267




Epoch: 203 (69.4847s), train_loss: 0.2684, val_loss: 0.3247, train_acc: 0.8926, val_acc:0.8679
		train_roc: 0.9522, val_roc: 0.9365, train_auprc: 0.9409, val_auprc: 0.9255




Epoch: 204 (69.4624s), train_loss: 0.2694, val_loss: 0.3255, train_acc: 0.8922, val_acc:0.8677
		train_roc: 0.9517, val_roc: 0.9365, train_auprc: 0.9398, val_auprc: 0.9258




Epoch: 205 (69.6219s), train_loss: 0.2695, val_loss: 0.3270, train_acc: 0.8926, val_acc:0.8666
		train_roc: 0.9517, val_roc: 0.9355, train_auprc: 0.9402, val_auprc: 0.9249




Epoch: 206 (69.4205s), train_loss: 0.2692, val_loss: 0.3264, train_acc: 0.8922, val_acc:0.8675
		train_roc: 0.9519, val_roc: 0.9359, train_auprc: 0.9406, val_auprc: 0.9250




Epoch: 207 (69.7769s), train_loss: 0.2684, val_loss: 0.3255, train_acc: 0.8933, val_acc:0.8674
		train_roc: 0.9522, val_roc: 0.9363, train_auprc: 0.9406, val_auprc: 0.9259




Epoch: 208 (69.5275s), train_loss: 0.2699, val_loss: 0.3270, train_acc: 0.8920, val_acc:0.8675
		train_roc: 0.9515, val_roc: 0.9357, train_auprc: 0.9399, val_auprc: 0.9242




Epoch: 209 (69.6479s), train_loss: 0.2687, val_loss: 0.3254, train_acc: 0.8928, val_acc:0.8680
		train_roc: 0.9520, val_roc: 0.9362, train_auprc: 0.9403, val_auprc: 0.9257




Epoch: 210 (69.4327s), train_loss: 0.2688, val_loss: 0.3237, train_acc: 0.8927, val_acc:0.8683
		train_roc: 0.9519, val_roc: 0.9371, train_auprc: 0.9401, val_auprc: 0.9269




Epoch: 211 (69.4959s), train_loss: 0.2699, val_loss: 0.3239, train_acc: 0.8925, val_acc:0.8697
		train_roc: 0.9515, val_roc: 0.9372, train_auprc: 0.9397, val_auprc: 0.9260




Epoch: 212 (69.5233s), train_loss: 0.2681, val_loss: 0.3244, train_acc: 0.8929, val_acc:0.8661
		train_roc: 0.9523, val_roc: 0.9367, train_auprc: 0.9410, val_auprc: 0.9272




Epoch: 213 (69.4321s), train_loss: 0.2680, val_loss: 0.3234, train_acc: 0.8935, val_acc:0.8692
		train_roc: 0.9524, val_roc: 0.9372, train_auprc: 0.9408, val_auprc: 0.9267




Epoch: 214 (69.3720s), train_loss: 0.2687, val_loss: 0.3237, train_acc: 0.8926, val_acc:0.8689
		train_roc: 0.9521, val_roc: 0.9369, train_auprc: 0.9405, val_auprc: 0.9266




Epoch: 215 (69.4191s), train_loss: 0.2687, val_loss: 0.3266, train_acc: 0.8929, val_acc:0.8671
		train_roc: 0.9519, val_roc: 0.9357, train_auprc: 0.9405, val_auprc: 0.9250




Epoch: 216 (69.5348s), train_loss: 0.2696, val_loss: 0.3251, train_acc: 0.8925, val_acc:0.8678
		train_roc: 0.9518, val_roc: 0.9365, train_auprc: 0.9404, val_auprc: 0.9259




Epoch: 217 (69.4313s), train_loss: 0.2703, val_loss: 0.3245, train_acc: 0.8921, val_acc:0.8675
		train_roc: 0.9515, val_roc: 0.9366, train_auprc: 0.9397, val_auprc: 0.9264




Epoch: 218 (69.6970s), train_loss: 0.2694, val_loss: 0.3243, train_acc: 0.8925, val_acc:0.8685
		train_roc: 0.9518, val_roc: 0.9368, train_auprc: 0.9400, val_auprc: 0.9263




Epoch: 219 (69.5785s), train_loss: 0.2686, val_loss: 0.3247, train_acc: 0.8931, val_acc:0.8683
		train_roc: 0.9521, val_roc: 0.9367, train_auprc: 0.9408, val_auprc: 0.9262




Epoch: 220 (69.5047s), train_loss: 0.2708, val_loss: 0.3266, train_acc: 0.8919, val_acc:0.8675
		train_roc: 0.9513, val_roc: 0.9359, train_auprc: 0.9393, val_auprc: 0.9250




Epoch: 221 (69.3346s), train_loss: 0.2682, val_loss: 0.3245, train_acc: 0.8928, val_acc:0.8685
		train_roc: 0.9522, val_roc: 0.9369, train_auprc: 0.9408, val_auprc: 0.9261




Epoch: 222 (69.4668s), train_loss: 0.2665, val_loss: 0.3262, train_acc: 0.8942, val_acc:0.8679
		train_roc: 0.9528, val_roc: 0.9358, train_auprc: 0.9418, val_auprc: 0.9245




Epoch: 223 (69.4063s), train_loss: 0.2676, val_loss: 0.3243, train_acc: 0.8937, val_acc:0.8687
		train_roc: 0.9525, val_roc: 0.9367, train_auprc: 0.9410, val_auprc: 0.9259




Epoch: 224 (69.5490s), train_loss: 0.2686, val_loss: 0.3260, train_acc: 0.8931, val_acc:0.8676
		train_roc: 0.9522, val_roc: 0.9362, train_auprc: 0.9402, val_auprc: 0.9253




Epoch: 225 (69.5877s), train_loss: 0.2689, val_loss: 0.3233, train_acc: 0.8925, val_acc:0.8689
		train_roc: 0.9519, val_roc: 0.9374, train_auprc: 0.9405, val_auprc: 0.9268




Epoch: 226 (69.5224s), train_loss: 0.2680, val_loss: 0.3263, train_acc: 0.8931, val_acc:0.8673
		train_roc: 0.9524, val_roc: 0.9358, train_auprc: 0.9414, val_auprc: 0.9248




Epoch: 227 (69.4966s), train_loss: 0.2685, val_loss: 0.3261, train_acc: 0.8928, val_acc:0.8682
		train_roc: 0.9522, val_roc: 0.9359, train_auprc: 0.9407, val_auprc: 0.9248




Epoch: 228 (69.4814s), train_loss: 0.2698, val_loss: 0.3241, train_acc: 0.8923, val_acc:0.8680
		train_roc: 0.9517, val_roc: 0.9370, train_auprc: 0.9397, val_auprc: 0.9267




Epoch: 229 (69.3630s), train_loss: 0.2685, val_loss: 0.3246, train_acc: 0.8932, val_acc:0.8688
		train_roc: 0.9521, val_roc: 0.9369, train_auprc: 0.9405, val_auprc: 0.9262




Epoch: 230 (69.5725s), train_loss: 0.2697, val_loss: 0.3245, train_acc: 0.8924, val_acc:0.8674
		train_roc: 0.9517, val_roc: 0.9368, train_auprc: 0.9401, val_auprc: 0.9269




Epoch: 231 (69.4219s), train_loss: 0.2683, val_loss: 0.3309, train_acc: 0.8931, val_acc:0.8651
		train_roc: 0.9523, val_roc: 0.9340, train_auprc: 0.9410, val_auprc: 0.9222




Epoch: 232 (69.4603s), train_loss: 0.2671, val_loss: 0.3259, train_acc: 0.8941, val_acc:0.8681
		train_roc: 0.9527, val_roc: 0.9362, train_auprc: 0.9415, val_auprc: 0.9249




Epoch: 233 (69.8090s), train_loss: 0.2682, val_loss: 0.3244, train_acc: 0.8931, val_acc:0.8681
		train_roc: 0.9523, val_roc: 0.9370, train_auprc: 0.9407, val_auprc: 0.9265




Epoch: 234 (69.5354s), train_loss: 0.2684, val_loss: 0.3241, train_acc: 0.8932, val_acc:0.8691
		train_roc: 0.9521, val_roc: 0.9368, train_auprc: 0.9405, val_auprc: 0.9257




Epoch: 235 (69.5253s), train_loss: 0.2697, val_loss: 0.3249, train_acc: 0.8923, val_acc:0.8675
		train_roc: 0.9517, val_roc: 0.9366, train_auprc: 0.9400, val_auprc: 0.9266




Epoch: 236 (69.6169s), train_loss: 0.2684, val_loss: 0.3247, train_acc: 0.8929, val_acc:0.8682
		train_roc: 0.9522, val_roc: 0.9369, train_auprc: 0.9407, val_auprc: 0.9260




Epoch: 237 (69.7342s), train_loss: 0.2690, val_loss: 0.3275, train_acc: 0.8922, val_acc:0.8673
		train_roc: 0.9519, val_roc: 0.9355, train_auprc: 0.9404, val_auprc: 0.9244




Epoch: 238 (69.7324s), train_loss: 0.2675, val_loss: 0.3252, train_acc: 0.8937, val_acc:0.8678
		train_roc: 0.9526, val_roc: 0.9365, train_auprc: 0.9408, val_auprc: 0.9261




Epoch: 239 (69.5254s), train_loss: 0.2687, val_loss: 0.3248, train_acc: 0.8926, val_acc:0.8674
		train_roc: 0.9521, val_roc: 0.9367, train_auprc: 0.9406, val_auprc: 0.9264




Epoch: 240 (69.7477s), train_loss: 0.2682, val_loss: 0.3244, train_acc: 0.8932, val_acc:0.8687
		train_roc: 0.9522, val_roc: 0.9367, train_auprc: 0.9405, val_auprc: 0.9256




Epoch: 241 (69.5450s), train_loss: 0.2686, val_loss: 0.3252, train_acc: 0.8926, val_acc:0.8673
		train_roc: 0.9521, val_roc: 0.9365, train_auprc: 0.9405, val_auprc: 0.9261




Epoch: 242 (69.5849s), train_loss: 0.2697, val_loss: 0.3252, train_acc: 0.8921, val_acc:0.8671
		train_roc: 0.9516, val_roc: 0.9363, train_auprc: 0.9397, val_auprc: 0.9262




Epoch: 243 (69.5293s), train_loss: 0.2702, val_loss: 0.3230, train_acc: 0.8919, val_acc:0.8687
		train_roc: 0.9515, val_roc: 0.9375, train_auprc: 0.9395, val_auprc: 0.9271




Epoch: 244 (69.5498s), train_loss: 0.2709, val_loss: 0.3260, train_acc: 0.8919, val_acc:0.8681
		train_roc: 0.9512, val_roc: 0.9362, train_auprc: 0.9392, val_auprc: 0.9253




Epoch: 245 (69.5873s), train_loss: 0.2696, val_loss: 0.3268, train_acc: 0.8918, val_acc:0.8668
		train_roc: 0.9518, val_roc: 0.9357, train_auprc: 0.9402, val_auprc: 0.9247




Epoch: 246 (69.5204s), train_loss: 0.2689, val_loss: 0.3235, train_acc: 0.8927, val_acc:0.8686
		train_roc: 0.9518, val_roc: 0.9373, train_auprc: 0.9404, val_auprc: 0.9267




Saving model
Epoch: 247 (69.5370s), train_loss: 0.2693, val_loss: 0.3229, train_acc: 0.8928, val_acc:0.8686
		train_roc: 0.9517, val_roc: 0.9373, train_auprc: 0.9403, val_auprc: 0.9278




Epoch: 248 (69.5063s), train_loss: 0.2671, val_loss: 0.3255, train_acc: 0.8937, val_acc:0.8674
		train_roc: 0.9528, val_roc: 0.9364, train_auprc: 0.9414, val_auprc: 0.9260




Epoch: 249 (69.4772s), train_loss: 0.2683, val_loss: 0.3263, train_acc: 0.8930, val_acc:0.8679
		train_roc: 0.9522, val_roc: 0.9360, train_auprc: 0.9406, val_auprc: 0.9253




Epoch: 250 (69.4955s), train_loss: 0.2667, val_loss: 0.3279, train_acc: 0.8935, val_acc:0.8662
		train_roc: 0.9529, val_roc: 0.9355, train_auprc: 0.9418, val_auprc: 0.9245




Epoch: 251 (69.6125s), train_loss: 0.2677, val_loss: 0.3233, train_acc: 0.8935, val_acc:0.8686
		train_roc: 0.9526, val_roc: 0.9376, train_auprc: 0.9410, val_auprc: 0.9273




Epoch: 252 (69.4228s), train_loss: 0.2678, val_loss: 0.3259, train_acc: 0.8927, val_acc:0.8674
		train_roc: 0.9524, val_roc: 0.9361, train_auprc: 0.9410, val_auprc: 0.9254




Epoch: 253 (69.4284s), train_loss: 0.2678, val_loss: 0.3253, train_acc: 0.8935, val_acc:0.8674
		train_roc: 0.9525, val_roc: 0.9363, train_auprc: 0.9412, val_auprc: 0.9258




Epoch: 254 (69.3996s), train_loss: 0.2685, val_loss: 0.3253, train_acc: 0.8927, val_acc:0.8681
		train_roc: 0.9522, val_roc: 0.9364, train_auprc: 0.9407, val_auprc: 0.9257




Epoch: 255 (69.5259s), train_loss: 0.2688, val_loss: 0.3233, train_acc: 0.8929, val_acc:0.8687
		train_roc: 0.9520, val_roc: 0.9372, train_auprc: 0.9402, val_auprc: 0.9271




Epoch: 256 (69.6472s), train_loss: 0.2694, val_loss: 0.3263, train_acc: 0.8921, val_acc:0.8675
		train_roc: 0.9519, val_roc: 0.9359, train_auprc: 0.9408, val_auprc: 0.9254




Epoch: 257 (69.3031s), train_loss: 0.2671, val_loss: 0.3248, train_acc: 0.8936, val_acc:0.8675
		train_roc: 0.9527, val_roc: 0.9368, train_auprc: 0.9414, val_auprc: 0.9265




Epoch: 258 (69.5476s), train_loss: 0.2672, val_loss: 0.3248, train_acc: 0.8938, val_acc:0.8683
		train_roc: 0.9527, val_roc: 0.9366, train_auprc: 0.9411, val_auprc: 0.9256




Epoch: 259 (69.5856s), train_loss: 0.2676, val_loss: 0.3255, train_acc: 0.8928, val_acc:0.8681
		train_roc: 0.9525, val_roc: 0.9364, train_auprc: 0.9418, val_auprc: 0.9252




Epoch: 260 (69.5444s), train_loss: 0.2682, val_loss: 0.3245, train_acc: 0.8930, val_acc:0.8685
		train_roc: 0.9524, val_roc: 0.9366, train_auprc: 0.9412, val_auprc: 0.9265




Epoch: 261 (69.4204s), train_loss: 0.2668, val_loss: 0.3239, train_acc: 0.8935, val_acc:0.8681
		train_roc: 0.9527, val_roc: 0.9369, train_auprc: 0.9418, val_auprc: 0.9269




Epoch: 262 (69.5664s), train_loss: 0.2684, val_loss: 0.3267, train_acc: 0.8930, val_acc:0.8675
		train_roc: 0.9521, val_roc: 0.9358, train_auprc: 0.9404, val_auprc: 0.9247




Epoch: 263 (69.4922s), train_loss: 0.2688, val_loss: 0.3260, train_acc: 0.8928, val_acc:0.8685
		train_roc: 0.9522, val_roc: 0.9360, train_auprc: 0.9406, val_auprc: 0.9248




Epoch: 264 (69.3362s), train_loss: 0.2689, val_loss: 0.3259, train_acc: 0.8925, val_acc:0.8678
		train_roc: 0.9520, val_roc: 0.9362, train_auprc: 0.9403, val_auprc: 0.9253




Epoch: 265 (69.5858s), train_loss: 0.2688, val_loss: 0.3241, train_acc: 0.8924, val_acc:0.8679
		train_roc: 0.9519, val_roc: 0.9371, train_auprc: 0.9406, val_auprc: 0.9266




Epoch: 266 (74.4547s), train_loss: 0.2678, val_loss: 0.3254, train_acc: 0.8938, val_acc:0.8672
		train_roc: 0.9524, val_roc: 0.9363, train_auprc: 0.9409, val_auprc: 0.9259




Epoch: 267 (70.2787s), train_loss: 0.2689, val_loss: 0.3243, train_acc: 0.8927, val_acc:0.8685
		train_roc: 0.9520, val_roc: 0.9367, train_auprc: 0.9405, val_auprc: 0.9262




Epoch: 268 (68.7828s), train_loss: 0.2702, val_loss: 0.3252, train_acc: 0.8926, val_acc:0.8688
		train_roc: 0.9515, val_roc: 0.9363, train_auprc: 0.9393, val_auprc: 0.9249




Saving model
Epoch: 269 (68.9024s), train_loss: 0.2696, val_loss: 0.3226, train_acc: 0.8928, val_acc:0.8678
		train_roc: 0.9517, val_roc: 0.9375, train_auprc: 0.9397, val_auprc: 0.9279




Epoch: 270 (68.7281s), train_loss: 0.2688, val_loss: 0.3254, train_acc: 0.8927, val_acc:0.8677
		train_roc: 0.9521, val_roc: 0.9365, train_auprc: 0.9404, val_auprc: 0.9255




Epoch: 271 (69.1361s), train_loss: 0.2703, val_loss: 0.3234, train_acc: 0.8922, val_acc:0.8690
		train_roc: 0.9514, val_roc: 0.9373, train_auprc: 0.9398, val_auprc: 0.9261




Epoch: 272 (69.1027s), train_loss: 0.2696, val_loss: 0.3238, train_acc: 0.8924, val_acc:0.8686
		train_roc: 0.9516, val_roc: 0.9370, train_auprc: 0.9397, val_auprc: 0.9270




Saving model
Epoch: 273 (69.1003s), train_loss: 0.2683, val_loss: 0.3226, train_acc: 0.8929, val_acc:0.8685
		train_roc: 0.9524, val_roc: 0.9376, train_auprc: 0.9409, val_auprc: 0.9280




Epoch: 274 (68.8300s), train_loss: 0.2702, val_loss: 0.3249, train_acc: 0.8921, val_acc:0.8679
		train_roc: 0.9514, val_roc: 0.9365, train_auprc: 0.9396, val_auprc: 0.9260




Epoch: 275 (68.7249s), train_loss: 0.2667, val_loss: 0.3252, train_acc: 0.8942, val_acc:0.8677
		train_roc: 0.9528, val_roc: 0.9366, train_auprc: 0.9415, val_auprc: 0.9261




Epoch: 276 (68.8066s), train_loss: 0.2654, val_loss: 0.3231, train_acc: 0.8940, val_acc:0.8681
		train_roc: 0.9535, val_roc: 0.9374, train_auprc: 0.9428, val_auprc: 0.9274




Epoch: 277 (68.6454s), train_loss: 0.2666, val_loss: 0.3257, train_acc: 0.8938, val_acc:0.8675
		train_roc: 0.9528, val_roc: 0.9363, train_auprc: 0.9416, val_auprc: 0.9255




Epoch: 278 (68.7047s), train_loss: 0.2679, val_loss: 0.3239, train_acc: 0.8931, val_acc:0.8688
		train_roc: 0.9523, val_roc: 0.9371, train_auprc: 0.9410, val_auprc: 0.9265




Epoch: 279 (68.7381s), train_loss: 0.2682, val_loss: 0.3249, train_acc: 0.8931, val_acc:0.8685
		train_roc: 0.9523, val_roc: 0.9367, train_auprc: 0.9409, val_auprc: 0.9257




Epoch: 280 (68.9965s), train_loss: 0.2685, val_loss: 0.3250, train_acc: 0.8930, val_acc:0.8673
		train_roc: 0.9521, val_roc: 0.9366, train_auprc: 0.9406, val_auprc: 0.9263




Epoch: 281 (68.7520s), train_loss: 0.2694, val_loss: 0.3246, train_acc: 0.8925, val_acc:0.8676
		train_roc: 0.9516, val_roc: 0.9367, train_auprc: 0.9400, val_auprc: 0.9263




Epoch: 282 (68.6604s), train_loss: 0.2695, val_loss: 0.3226, train_acc: 0.8922, val_acc:0.8697
		train_roc: 0.9518, val_roc: 0.9377, train_auprc: 0.9402, val_auprc: 0.9270




Epoch: 283 (69.0746s), train_loss: 0.2693, val_loss: 0.3240, train_acc: 0.8928, val_acc:0.8687
		train_roc: 0.9519, val_roc: 0.9372, train_auprc: 0.9403, val_auprc: 0.9263




Epoch: 284 (68.7130s), train_loss: 0.2688, val_loss: 0.3268, train_acc: 0.8929, val_acc:0.8669
		train_roc: 0.9519, val_roc: 0.9358, train_auprc: 0.9401, val_auprc: 0.9252




Epoch: 285 (69.0089s), train_loss: 0.2712, val_loss: 0.3242, train_acc: 0.8915, val_acc:0.8689
		train_roc: 0.9511, val_roc: 0.9369, train_auprc: 0.9391, val_auprc: 0.9261




Epoch: 286 (68.7406s), train_loss: 0.2687, val_loss: 0.3256, train_acc: 0.8927, val_acc:0.8677
		train_roc: 0.9522, val_roc: 0.9362, train_auprc: 0.9406, val_auprc: 0.9251




Epoch: 287 (68.8432s), train_loss: 0.2675, val_loss: 0.3225, train_acc: 0.8939, val_acc:0.8685
		train_roc: 0.9524, val_roc: 0.9379, train_auprc: 0.9407, val_auprc: 0.9278




Epoch: 288 (68.7474s), train_loss: 0.2691, val_loss: 0.3243, train_acc: 0.8925, val_acc:0.8685
		train_roc: 0.9519, val_roc: 0.9369, train_auprc: 0.9401, val_auprc: 0.9264




Epoch: 289 (68.7869s), train_loss: 0.2680, val_loss: 0.3238, train_acc: 0.8928, val_acc:0.8688
		train_roc: 0.9525, val_roc: 0.9372, train_auprc: 0.9411, val_auprc: 0.9266




Epoch: 290 (68.8199s), train_loss: 0.2695, val_loss: 0.3232, train_acc: 0.8924, val_acc:0.8680
		train_roc: 0.9518, val_roc: 0.9373, train_auprc: 0.9400, val_auprc: 0.9272




Epoch: 291 (68.8045s), train_loss: 0.2689, val_loss: 0.3248, train_acc: 0.8926, val_acc:0.8688
		train_roc: 0.9521, val_roc: 0.9367, train_auprc: 0.9405, val_auprc: 0.9254




Epoch: 292 (68.7828s), train_loss: 0.2672, val_loss: 0.3237, train_acc: 0.8937, val_acc:0.8688
		train_roc: 0.9528, val_roc: 0.9373, train_auprc: 0.9414, val_auprc: 0.9262




Epoch: 293 (68.8528s), train_loss: 0.2675, val_loss: 0.3257, train_acc: 0.8936, val_acc:0.8680
		train_roc: 0.9525, val_roc: 0.9361, train_auprc: 0.9410, val_auprc: 0.9252




Saving model
Epoch: 294 (68.9300s), train_loss: 0.2663, val_loss: 0.3225, train_acc: 0.8935, val_acc:0.8686
		train_roc: 0.9530, val_roc: 0.9378, train_auprc: 0.9422, val_auprc: 0.9284




Epoch: 295 (68.7265s), train_loss: 0.2677, val_loss: 0.3246, train_acc: 0.8937, val_acc:0.8677
		train_roc: 0.9525, val_roc: 0.9368, train_auprc: 0.9410, val_auprc: 0.9263




Epoch: 296 (68.6414s), train_loss: 0.2684, val_loss: 0.3225, train_acc: 0.8926, val_acc:0.8693
		train_roc: 0.9524, val_roc: 0.9376, train_auprc: 0.9410, val_auprc: 0.9272




Epoch: 297 (68.7804s), train_loss: 0.2672, val_loss: 0.3244, train_acc: 0.8937, val_acc:0.8685
		train_roc: 0.9526, val_roc: 0.9366, train_auprc: 0.9413, val_auprc: 0.9259




Epoch: 298 (68.7917s), train_loss: 0.2676, val_loss: 0.3237, train_acc: 0.8933, val_acc:0.8691
		train_roc: 0.9526, val_roc: 0.9371, train_auprc: 0.9411, val_auprc: 0.9263




Epoch: 299 (68.6347s), train_loss: 0.2677, val_loss: 0.3269, train_acc: 0.8936, val_acc:0.8673
		train_roc: 0.9524, val_roc: 0.9357, train_auprc: 0.9408, val_auprc: 0.9239




Epoch: 300 (68.8114s), train_loss: 0.2695, val_loss: 0.3244, train_acc: 0.8922, val_acc:0.8675
		train_roc: 0.9518, val_roc: 0.9368, train_auprc: 0.9402, val_auprc: 0.9264


In [18]:
# Predict
# Reload the checkpoint stored at `model_file` and evaluate it on the test loader.
#
# The checkpoint is a fully pickled model object (it prints as a module below), not a
# state_dict, so it has to be unpickled with weights_only=False. Stating this explicitly
# keeps the cell working when PyTorch's default for `weights_only` is True and avoids the
# warning raised by the implicit default.
# SECURITY: unpickling executes arbitrary code -- only load checkpoint files you trust.
#
# map_location=device remaps the stored tensors at load time, so a checkpoint written on
# a CUDA machine can still be loaded when `device` resolves to "cpu".
model = torch.load(model_file, map_location=device, weights_only=False)
print(model)
# Kept for safety; a no-op when map_location already placed the parameters on `device`.
model.to(device=device)
# NOTE(review): `predict` is defined elsewhere in the notebook; the model contains a
# Dropout layer, so confirm that `predict` switches the model to eval mode.
predict(model, test_data_loader, device)

SSI_DDI(
  (initial_norm): LayerNorm(56, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-4): 5 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(56, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block4): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): CoAttentionLayerImproved(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (KGE): RESCAL(86, torch.Size([86, 4096]))
)
Starting predic

  model = torch.load(model_file)


Test Accuracy: 0.8687
Test ROC AUC: 0.9368
Test PRC AUC: 0.9259
