In [1]:
!pip install ogb



In [2]:
from torch_scatter import scatter_softmax, scatter_mean
from tqdm import tqdm

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Embedding, ModuleList
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_scatter import scatter, scatter_mean, scatter_add, scatter_sum
from torch_geometric.nn import GINConv, GINEConv


class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features by summing one embedding per column.

    Holds 9 embedding tables with 100 entries each. Column ``i`` of the input
    is looked up in table ``i`` and the lookups are added together.
    """

    def __init__(self, hidden_channels):
        super(AtomEncoder, self).__init__()
        self.embeddings = torch.nn.ModuleList(
            [Embedding(100, hidden_channels) for _ in range(9)])

    def reset_parameters(self):
        """Re-initialise every embedding table."""
        for table in self.embeddings:
            table.reset_parameters()

    def forward(self, x):
        """Return the sum of per-column embeddings of ``x``.

        A 1-D ``x`` is treated as a single feature column.
        """
        columns = x.unsqueeze(1) if x.dim() == 1 else x
        return sum(self.embeddings[col](columns[:, col])
                   for col in range(columns.size(1)))


class BondEncoder(torch.nn.Module):
    """Embed categorical bond features by summing one embedding per column.

    Holds 3 embedding tables with 6 entries each. Column ``i`` of the input
    is looked up in table ``i`` and the lookups are added together.
    """

    def __init__(self, hidden_channels):
        super(BondEncoder, self).__init__()
        self.embeddings = torch.nn.ModuleList(
            [Embedding(6, hidden_channels) for _ in range(3)])

    def reset_parameters(self):
        """Re-initialise every embedding table."""
        for table in self.embeddings:
            table.reset_parameters()

    def forward(self, edge_attr):
        """Return the sum of per-column embeddings of ``edge_attr``.

        A 1-D ``edge_attr`` is treated as a single feature column.
        """
        columns = edge_attr.unsqueeze(1) if edge_attr.dim() == 1 else edge_attr
        return sum(self.embeddings[col](columns[:, col])
                   for col in range(columns.size(1)))

In [16]:
import torch
import torch.nn as nn
from torch_scatter import scatter
import numpy as np
import matplotlib.pyplot as plt
# partial scatter func
from functools import partial


class Diff_Hist(nn.Module):
    """Soft, differentiable histogram built from sigmoid "bump" functions.

    Every input value is scored against ``len(centers)`` bins.

    * Interior bins use the unit-width bump
      ``sigmoid(s*y) - sigmoid(s*(y - 1))``, rescaled so its peak equals 1.
    * The first and last bins are open-ended sigmoids, so values beyond the
      outermost centres are still counted.

    Sign convention: the bin built from centre ``c`` responds to inputs near
    ``-c``. With ascending centres, the first bin therefore captures large
    positive inputs and the last bin captures large negative ones.

    Args:
        centers: ordered bin centres. The first and last are the open-ended
            extremes.
        scale: sigmoid slope ``s``. Larger values give sharper bin edges.
    """

    def __init__(self, centers=[-5, -2.5, 0, 2.5, 5], scale=2):
        super(Diff_Hist, self).__init__()
        # save params
        self.scale = scale
        self.true_centers = centers

        # Normalise the interior bump so its peak is 1. The peak sits at
        # y = 0.5, inside the sampled grid. max_scaler is set to 1 first so
        # that func() returns the raw bump while the peak is being measured.
        # The result is stored as a plain float so it is device-agnostic.
        self.max_scaler = 1
        grid = torch.arange(start=-1, end=1, step=1e-3)
        self.max_scaler = float(torch.max(self.func(grid)))

        # Interior bumps are centred at 0.5, so interior centres are shifted
        # by 0.5. The two open-ended extremes switch at 0 and stay unshifted.
        self.centers = ([centers[0]]
                        + list(0.5 + np.array(centers[1:-1]))
                        + [centers[-1]])
        self.n_centers = len(self.centers)

    def sigmoid_x(self, y):
        """Compute ``sigmoid(scale * y)``.

        torch.sigmoid is used because it stays finite, with finite gradients,
        for large-magnitude inputs, unlike ``1/(1+exp(-z))``. It is defined
        as a method rather than a lambda attribute so the module can be
        pickled.
        """
        return torch.sigmoid(self.scale * y)

    def forward(self, x, reduce="mean"):
        r"""Soft-bin every entry of ``x`` (binning is along a new dim 1).

        Args:
            x: tensor of shape ``(n, *feat)``.
            reduce: ``"sum"``/``"add"`` sums over dim 0; any other truthy
                value averages over dim 0; a falsy value (e.g. ``None``)
                returns the per-row histograms.

        Returns:
            ``(n_centers, *feat)`` when reduced, otherwise
            ``(n, n_centers, *feat)``. The result is on ``x``'s device and,
            for floating inputs, has ``x``'s dtype.
        """
        # Allocate on the input's device. This previously used a notebook
        # global ``device``.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        out = torch.zeros(x.shape[0], self.n_centers, *x.shape[1:],
                          device=x.device, dtype=dtype)
        last = self.n_centers - 1
        for i, center in enumerate(self.centers):
            # The two extremes are open-ended sigmoids. Their right/left
            # directions are mirrored because bins respond to -center.
            if i == 0:
                out[:, i] = self.make_right_extreme(x, center)
            elif i == last:
                out[:, i] = self.make_left_extreme(x, center)
            else:
                out[:, i] = self.make_bin(x, center)

        # Summarise (n, *feat) rows into a single (n_centers, *feat) histogram.
        if reduce == "sum" or reduce == "add":
            return out.sum(dim=0)
        elif reduce:
            return out.mean(dim=0)
        return out

    def make_bin(self, x, center):
        """Interior bin: normalised bump, peaking where ``x + center == 0.5``."""
        return self.func(x + center)

    def make_left_extreme(self, x, center):
        """Open-ended bin that approaches 1 as ``x + center`` goes to -inf."""
        return (1 - self.sigmoid_x(x + center))

    def make_right_extreme(self, x, center):
        """Open-ended bin that approaches 1 as ``x + center`` goes to +inf."""
        return self.sigmoid_x(x + center)

    def func(self, x):
        """Unit-width bump ``sigmoid(s*x) - sigmoid(s*(x-1))``, scaled so its max is 1."""
        return (self.sigmoid_x(x) - self.sigmoid_x(x - 1)) / self.max_scaler

    def __repr__(self):
        return "Simple differentiable histogram layer:" + \
               str({"centers"   : self.true_centers,
                    "exp_scale" : self.scale})
    

class Readout_Hist(nn.Module):
    """Graph-level readout that pools per-node soft histograms.

    Each node's features are expanded by ``Diff_Hist`` into soft-bin
    activations, flattened to one vector per node, and then pooled per graph
    with ``reduce``.

    Args:
        centers: bin centres forwarded to ``Diff_Hist``.
        reduce: callable ``reduce(src, index, dim_size=..., dim=...)`` used to
            pool nodes into graphs. Defaults to ``torch_scatter.scatter`` with
            ``reduce="mean"``.
    """

    def __init__(self, centers, reduce=None):
        super(Readout_Hist, self).__init__()
        # save centers
        self.true_centers = centers
        self.n_centers = len(centers)
        self.diff_hist = Diff_Hist(centers=centers)
        # reduces nodes -> graph by gather/scatter
        if reduce is None:
            reduce = partial(scatter, reduce="mean")
        self.reduce_scheme = reduce

    def forward(self, x, batch=None, bsize=None, dim=0):
        r"""Pool the soft histograms of the rows of ``x`` into per-graph vectors.

        Args:
            x: node features, shape ``(num_nodes, num_features)``.
            batch: graph index of every node. ``None`` treats all nodes as a
                single graph.
            bsize: number of output graphs (``dim_size``). Inferred from
                ``batch`` when ``None``.
            dim: dimension passed to ``reduce``.

        Returns:
            Tensor of shape ``(num_graphs, n_centers * num_features)``.
        """
        if batch is None:
            # Previously crashed inside scatter. Now means one graph holding
            # every node.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        hist = self.diff_hist(x, reduce=None)
        hist = hist.reshape(hist.shape[0], -1)
        return self.reduce_scheme(hist, batch, dim_size=bsize, dim=dim)

    def __repr__(self):
        return "Readout by histogram:" + str({"centers": self.true_centers})

In [36]:
class Net(torch.nn.Module):
    r"""GIN over atoms and junction-tree cliques with a histogram readout.

    Atom features are refined by ``num_layers`` GINE convolutions that use
    bond features as edge attributes.

    When ``inter_message_passing`` is truthy, clique embeddings are refined in
    parallel by GIN convolutions over the junction tree. After every layer,
    atoms and cliques exchange mean-pooled messages.

    Both node sets are summarised per graph with ``Readout_Hist``, and the
    summaries are added before the final linear layer.

    Args:
        hidden_channels: width of every hidden node representation.
        out_channels: number of output logits per graph.
        num_layers: number of message-passing layers.
        dropout: dropout probability used after every layer and readout.
        inter_message_passing: enable the clique branch and the atom<->clique
            message exchange.
    """

    def __init__(self, hidden_channels, out_channels, num_layers, dropout=0.0,
                 inter_message_passing=True):
        super(Net, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.inter_message_passing = inter_message_passing
        print("using inter_message_passing:", inter_message_passing)

        self.atom_encoder = AtomEncoder(hidden_channels)
        # Clique vocabulary ids; Embedding(4, ...) requires ids < 4.
        self.clique_encoder = Embedding(4, hidden_channels)

        self.bond_encoders = ModuleList()
        self.atom_convs = ModuleList()
        self.atom_batch_norms = ModuleList()

        # DIFFERENTIABLE HISTOGRAMS
        # Bin centres -14, -12, ..., 12, 14 (2*n + 1 = 15 bins, spaced `offset`).
        offset, n = 2, 7
        centers = list(range(-offset * n, offset * n + 1, offset))
        self.reader = Readout_Hist(centers=centers)
        # Width of a readout vector: one histogram of len(centers) bins per channel.
        readout_channels = hidden_channels * len(centers)

        for _ in range(num_layers):
            self.bond_encoders.append(BondEncoder(hidden_channels))
            self.atom_convs.append(
                GINEConv(self._make_mlp(hidden_channels), train_eps=True))
            self.atom_batch_norms.append(BatchNorm1d(hidden_channels))

        self.clique_convs = ModuleList()
        self.clique_batch_norms = ModuleList()

        for _ in range(num_layers):
            self.clique_convs.append(
                GINConv(self._make_mlp(hidden_channels), train_eps=True))
            self.clique_batch_norms.append(BatchNorm1d(hidden_channels))

        self.atom2clique_lins = ModuleList()
        self.clique2atom_lins = ModuleList()

        for _ in range(num_layers):
            self.atom2clique_lins.append(
                Linear(hidden_channels, hidden_channels))
            self.clique2atom_lins.append(
                Linear(hidden_channels, hidden_channels))

        self.atom_lin = Linear(readout_channels, readout_channels)
        self.clique_lin = Linear(readout_channels, readout_channels)
        # final layer merges info from readout + clique
        self.lin = Linear(readout_channels, out_channels)

    @staticmethod
    def _make_mlp(hidden_channels):
        """Build the hidden -> 2*hidden -> hidden MLP used inside each GIN/GINE conv."""
        return Sequential(
            Linear(hidden_channels, 2 * hidden_channels),
            BatchNorm1d(2 * hidden_channels),
            ReLU(),
            Linear(2 * hidden_channels, hidden_channels),
        )

    def reset_parameters(self):
        """Re-initialise every learnable submodule (the histogram reader has none)."""
        self.atom_encoder.reset_parameters()
        self.clique_encoder.reset_parameters()

        for emb, conv, batch_norm in zip(self.bond_encoders, self.atom_convs,
                                         self.atom_batch_norms):
            emb.reset_parameters()
            conv.reset_parameters()
            batch_norm.reset_parameters()

        for conv, batch_norm in zip(self.clique_convs,
                                    self.clique_batch_norms):
            conv.reset_parameters()
            batch_norm.reset_parameters()

        for lin1, lin2 in zip(self.atom2clique_lins, self.clique2atom_lins):
            lin1.reset_parameters()
            lin2.reset_parameters()

        self.atom_lin.reset_parameters()
        self.clique_lin.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, data):
        """Compute per-graph logits for a batch of junction-tree graphs.

        Always reads ``data.x``, ``data.edge_index``, ``data.edge_attr`` and
        ``data.batch``. With inter-message passing it also reads
        ``data.x_clique``, ``data.atom2clique_index``,
        ``data.tree_edge_index`` and ``data.num_cliques``.
        """
        x = self.atom_encoder(data.x.squeeze())

        if self.inter_message_passing:
            x_clique = self.clique_encoder(data.x_clique.squeeze())

        for i in range(self.num_layers):
            edge_attr = self.bond_encoders[i](data.edge_attr)
            x = self.atom_convs[i](x, data.edge_index, edge_attr)
            x = self.atom_batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)

            if self.inter_message_passing:
                # row indexes atoms, col indexes cliques (one pair per membership).
                row, col = data.atom2clique_index

                # atoms -> cliques: mean of the member atoms' features.
                x_clique = x_clique + F.relu(self.atom2clique_lins[i](scatter(
                    x[row], col, dim=0, dim_size=x_clique.size(0),
                    reduce='mean')))

                x_clique = self.clique_convs[i](x_clique, data.tree_edge_index)
                x_clique = self.clique_batch_norms[i](x_clique)
                x_clique = F.relu(x_clique)
                x_clique = F.dropout(x_clique, self.dropout,
                                     training=self.training)

                # cliques -> atoms: mean over the cliques each atom belongs to.
                x = x + F.relu(self.clique2atom_lins[i](scatter(
                    x_clique[col], row, dim=0, dim_size=x.size(0),
                    reduce='mean')))

        # Histogram readout: (num_atoms, hidden) -> (num_graphs, hidden * n_bins).
        x = self.reader(x, data.batch)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.atom_lin(x)

        if self.inter_message_passing:
            # Graph id of every clique, expanded from per-graph clique counts.
            tree_batch = torch.repeat_interleave(data.num_cliques)
            # bsize keeps the clique readout row-aligned with the atom readout.
            x_clique = self.reader(x_clique, tree_batch, bsize=x.size(0))
            x_clique = F.dropout(x_clique, self.dropout,
                                 training=self.training)
            x_clique = self.clique_lin(x_clique)
            x = x + x_clique

        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        return self.lin(x)

In [37]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import tree_decomposition

from rdkit import Chem
from rdkit.Chem.rdchem import BondType

# Lookup table used by mol_from_data: a 1-based bond type k (OGB's zero-based
# value after OGBTransform's +1 shift) maps to the RDKit BondType bonds[k - 1].
bonds = [BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC]


def mol_from_data(data):
    """Rebuild an RDKit molecule from a graph's atom and bond features.

    Column 0 of ``data.x`` (or ``data.x`` itself when 1-D) is used as the
    atomic number of each atom. Column 0 of ``data.edge_attr`` is used as a
    1-based bond type in [1, 4], mapped through ``bonds``.

    Only edges with source < target are added, so each bond is created once
    (this assumes both directions of every bond are stored).
    """
    atomic_nums = data.x if data.x.dim() == 1 else data.x[:, 0]

    molecule = Chem.RWMol()
    for atomic_num in atomic_nums.tolist():
        molecule.AddAtom(Chem.Atom(atomic_num))

    src, dst = data.edge_index
    keep = src < dst
    src_list = src[keep].tolist()
    dst_list = dst[keep].tolist()

    edge_types = data.edge_attr
    if edge_types.dim() != 1:
        edge_types = edge_types[:, 0]
    type_list = edge_types[keep].tolist()

    for a, b, bond in zip(src_list, dst_list, type_list):
        assert 1 <= bond <= 4
        molecule.AddBond(a, b, bonds[bond - 1])

    return molecule.GetMol()


class JunctionTreeData(Data):
    """``Data`` subclass that offsets junction-tree indices correctly when batching.

    ``__inc__`` returns the amount by which an attribute's indices are shifted
    when graphs are concatenated into a batch:

    * ``tree_edge_index`` indexes cliques, so it is shifted by the clique count.
    * ``atom2clique_index`` holds atom indices in row 0 and clique indices in
      row 1, so each row is shifted by its own count.
    * Every other key falls back to the default ``Data`` behaviour.
    """

    def __inc__(self, key, item, *args, **kwargs):
        # *args/**kwargs accept the extra arguments that newer PyG releases
        # pass to __inc__, while staying compatible with older two-arg calls.
        if key == 'tree_edge_index':
            return self.x_clique.size(0)
        elif key == 'atom2clique_index':
            return torch.tensor([[self.x.size(0)], [self.x_clique.size(0)]])
        else:
            return super(JunctionTreeData, self).__inc__(key, item, *args, **kwargs)


class JunctionTree(object):
    """Transform that attaches a junction-tree (clique) decomposition to a graph.

    The molecule is rebuilt with ``mol_from_data`` and decomposed with
    ``tree_decomposition``. All of the original graph attributes are copied
    into a ``JunctionTreeData`` object, which also receives the
    ``tree_edge_index``, ``atom2clique_index``, ``num_cliques`` and
    ``x_clique`` attributes.
    """

    def __call__(self, data):
        molecule = mol_from_data(data)
        (tree_edge_index, atom2clique_index,
         num_cliques, x_clique) = tree_decomposition(molecule, return_vocab=True)

        tree_data = JunctionTreeData(**{key: value for key, value in data})

        tree_data.tree_edge_index = tree_edge_index
        tree_data.atom2clique_index = atom2clique_index
        tree_data.num_cliques = num_cliques
        tree_data.x_clique = x_clique

        return tree_data

In [38]:
# Workaround: edit the library function that raises the error (e.g. `tree_decomposition`,
# exercised below) by adding an argument `chem=None` and, inside its body:
#     Chem = chem if chem is not None else Chem

# tree_decomposition(Chem.MolFromSmiles("cicccc1c"), return_vocab=True)
# Once the change is saved, restart the environment, comment out this cell, and run the experiment.

In [39]:
import argparse

import torch
from torch.optim import Adam
import numpy as np
from sklearn.metrics import roc_auc_score

from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.data import DataLoader
from torch_geometric.transforms import Compose

class Argparse_emulate():
    """Notebook stand-in for the script's ``argparse`` namespace.

    Attributes mirror the original command-line options (see the
    commented-out parser definitions after the instantiation).

    Note: the training loop passes ``no_inter_message_passing`` unchanged as
    ``Net(inter_message_passing=...)``. Any truthy value, including the
    default string ``"store_true"``, therefore *enables* inter-message
    passing.
    """

    def __init__(self, device=0, hidden_channels=256, num_layers=2, dropout=0.6,
                 epochs=100, no_inter_message_passing="store_true"):
        self.device = device
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        # The argument was previously ignored in favour of a hard-coded 0.6.
        # 0.6 is now the default, so Argparse_emulate() yields the same value.
        self.dropout = dropout
        self.epochs = epochs
        self.no_inter_message_passing = no_inter_message_passing

# Notebook replacement for `args = parser.parse_args()`. The original argparse
# definitions are kept below, commented out, for reference.
args = Argparse_emulate()
# parser.add_argument('--device', type=int, default=0)
# parser.add_argument('--hidden_channels', type=int, default=256)
# parser.add_argument('--num_layers', type=int, default=2)
# parser.add_argument('--dropout', type=float, default=0.5)
# parser.add_argument('--epochs', type=int, default=100)
# parser.add_argument('--no_inter_message_passing', action='store_true')
# args = parser.parse_args()
# print(args)


class OGBTransform(object):
    """Undo OGB's zero-based encoding of atom and bond types.

    Adds 1, in place, to the first column of ``data.x`` and the first column
    of ``data.edge_attr``. Downstream code (``mol_from_data``) then reads
    these columns as atomic numbers and 1-based bond types. Returns the same
    ``data`` object.
    """

    def __call__(self, data):
        data.x[:, 0].add_(1)
        data.edge_attr[:, 0].add_(1)
        return data


# Shift OGB's zero-based atom/bond types, then attach the junction tree.
transform = Compose([OGBTransform(), JunctionTree()])

name = 'ogbg-molhiv'
dataset = PygGraphPropPredDataset(name, 'data', pre_transform=transform)

# Standard OGB split indices for train / valid / test.
split_idx = dataset.get_idx_split()
train_dataset, val_dataset, test_dataset = (
    dataset[split_idx[part]] for part in ('train', 'valid', 'test'))

# Only the training loader shuffles; all loaders use batches of 128 graphs.
train_loader = DataLoader(train_dataset, 128, shuffle=True)
val_loader, test_loader = (
    DataLoader(subset, 128, shuffle=False) for subset in (val_dataset, test_dataset))


device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

In [None]:
def train(epoch, vals=False):
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``train_loader`` and
    ``device``. Targets that are NaN are excluded from the loss.

    Args:
        epoch: current epoch number. Unused; kept for the caller's signature.
        vals: reserved flag for recording diagnostics. Nothing is recorded
            at present.

    Returns:
        ``(mean_loss, values)``. ``mean_loss`` is each batch's loss weighted
        by its number of graphs, divided by the dataset size. ``values`` is
        the (currently always empty) list of recorded diagnostics.
    """
    values = []
    model.train()
    # Build the loss once rather than on every batch.
    criterion = torch.nn.BCEWithLogitsLoss()

    total_loss = 0
    # Iterating the loader directly lets tqdm show a proper progress bar.
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        mask = ~torch.isnan(data.y)
        out = model(data)[mask]
        y = data.y.to(torch.float)[mask]
        loss = criterion(out, y)
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()

    return total_loss / len(train_loader.dataset), values


@torch.no_grad()
def test(loader):
    """Evaluate the global ``model`` on ``loader``; return the mean ROC-AUC over tasks.

    Uses the notebook globals ``model`` and ``device``. NaN targets are
    ignored. Tasks whose labelled targets lack either class are skipped,
    because ROC-AUC is undefined for them.

    Returns:
        ``{"rocauc": mean ROC-AUC over the scorable tasks}``.

    Raises:
        RuntimeError: if no task has both a positive and a negative label.
    """
    model.eval()

    y_preds, y_trues = [], []
    for data in loader:
        data = data.to(device)
        y_preds.append(model(data))
        y_trues.append(data.y)

    y_pred = torch.cat(y_preds, dim=0).cpu().numpy()
    y_true = torch.cat(y_trues, dim=0).cpu().numpy()

    rocauc_list = []
    for i in range(y_true.shape[1]):
        # NaN != NaN, so this keeps only labelled entries.
        is_labeled = y_true[:, i] == y_true[:, i]
        labels = y_true[is_labeled, i]
        # AUC is only defined when both classes are present.
        if np.any(labels == 1) and np.any(labels == 0):
            rocauc_list.append(roc_auc_score(labels, y_pred[is_labeled, i]))

    if not rocauc_list:
        raise RuntimeError(
            'No task has both positive and negative labels; ROC-AUC is undefined.')
    return {"rocauc": sum(rocauc_list) / len(rocauc_list)}


# Run the full train/evaluate experiment 10 times. For each run, report the
# test ROC-AUC at the best validation epoch, then summarise mean ± std.
values     = []
test_perfs = []
for run in range(10):
    print()
    print(f'Run {run}:')
    print()
    # Seed every run so weight initialisation and batch shuffling are reproducible.
    torch.manual_seed(run)
    model = Net(hidden_channels=args.hidden_channels,
                out_channels=dataset.num_tasks, num_layers=args.num_layers,
                dropout=args.dropout,
                # NOTE: forwarded unchanged, so any truthy value (the default
                # is the string "store_true") ENABLES the clique branch.
                inter_message_passing=args.no_inter_message_passing).to(device)

    model.reset_parameters()
    optimizer = Adam(model.parameters(), lr=0.0001)

    # test_perf starts as a dict so the epoch report can always be formatted.
    best_val_perf = 0
    test_perf = {"rocauc": 0.0}
    for epoch in range(1, args.epochs + 1):
        loss, epoch_values = train(epoch, vals=True)
        train_perf = test(train_loader)
        val_perf = test(val_loader)

        # Model selection: re-evaluate on test only when validation improves.
        if val_perf["rocauc"] > best_val_perf:
            best_val_perf = val_perf["rocauc"]
            test_perf = test(test_loader)

        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_perf["rocauc"]:.4f}, Val: {val_perf["rocauc"]:.4f}, '
              f'Test: {test_perf["rocauc"]:.4f}')
        if epoch % 10 == 0:
            print("Recorded values:", epoch_values)

    test_perfs.append(test_perf["rocauc"])

test_perf = torch.tensor(test_perfs)
print('===========================')
print(f'Final Test: {test_perf.mean():.4f} ± {test_perf.std():.4f}')


Run 0:

using inter_message_passing: store_true


258it [00:37,  6.88it/s]
1it [00:00,  6.85it/s]

Epoch: 001, Loss: 0.1669, Train: 0.6953, Val: 0.6703, Test: 0.6531


258it [00:37,  6.90it/s]
1it [00:00,  6.41it/s]

Epoch: 002, Loss: 0.1507, Train: 0.7446, Val: 0.7558, Test: 0.7680


258it [00:37,  6.90it/s]
1it [00:00,  6.49it/s]

Epoch: 003, Loss: 0.1410, Train: 0.7553, Val: 0.7734, Test: 0.7604


258it [00:37,  6.89it/s]
1it [00:00,  5.99it/s]

Epoch: 004, Loss: 0.1397, Train: 0.7789, Val: 0.7608, Test: 0.7604


258it [00:37,  6.90it/s]
1it [00:00,  6.21it/s]

Epoch: 005, Loss: 0.1357, Train: 0.7790, Val: 0.7915, Test: 0.7836


258it [00:37,  6.89it/s]
1it [00:00,  6.41it/s]

Epoch: 006, Loss: 0.1336, Train: 0.7996, Val: 0.7864, Test: 0.7836


258it [00:37,  6.90it/s]
1it [00:00,  5.99it/s]

Epoch: 007, Loss: 0.1314, Train: 0.7925, Val: 0.7944, Test: 0.7652


258it [00:37,  6.89it/s]
1it [00:00,  5.99it/s]

Epoch: 008, Loss: 0.1304, Train: 0.8008, Val: 0.7988, Test: 0.7715


258it [00:37,  6.89it/s]
1it [00:00,  5.75it/s]

Epoch: 009, Loss: 0.1292, Train: 0.8089, Val: 0.8114, Test: 0.7675


258it [00:37,  6.90it/s]
1it [00:00,  5.71it/s]

Epoch: 010, Loss: 0.1279, Train: 0.8052, Val: 0.8002, Test: 0.7675
Recorded values: []


258it [00:37,  6.89it/s]
1it [00:00,  5.95it/s]

Epoch: 011, Loss: 0.1268, Train: 0.8180, Val: 0.7939, Test: 0.7675


258it [00:37,  6.90it/s]
1it [00:00,  6.29it/s]

Epoch: 012, Loss: 0.1255, Train: 0.8202, Val: 0.8061, Test: 0.7675


258it [00:37,  6.89it/s]
1it [00:00,  6.29it/s]

Epoch: 013, Loss: 0.1246, Train: 0.8265, Val: 0.8016, Test: 0.7675


258it [00:37,  6.89it/s]
1it [00:00,  6.41it/s]

Epoch: 014, Loss: 0.1241, Train: 0.8319, Val: 0.8061, Test: 0.7675


258it [00:37,  6.89it/s]
1it [00:00,  6.10it/s]

Epoch: 015, Loss: 0.1222, Train: 0.8382, Val: 0.7991, Test: 0.7675


258it [00:37,  6.89it/s]
1it [00:00,  5.99it/s]

Epoch: 016, Loss: 0.1214, Train: 0.8424, Val: 0.8142, Test: 0.7858


258it [00:37,  6.91it/s]
1it [00:00,  6.49it/s]

Epoch: 017, Loss: 0.1224, Train: 0.8310, Val: 0.8130, Test: 0.7858


258it [00:37,  6.89it/s]
1it [00:00,  6.37it/s]

Epoch: 018, Loss: 0.1199, Train: 0.8461, Val: 0.8204, Test: 0.7796


258it [00:37,  6.90it/s]
1it [00:00,  6.54it/s]

Epoch: 019, Loss: 0.1195, Train: 0.8465, Val: 0.8256, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.02it/s]

Epoch: 020, Loss: 0.1181, Train: 0.8531, Val: 0.8233, Test: 0.7930
Recorded values: []


258it [00:38,  6.78it/s]
1it [00:00,  6.02it/s]

Epoch: 021, Loss: 0.1175, Train: 0.8524, Val: 0.8080, Test: 0.7930


258it [00:38,  6.77it/s]
1it [00:00,  5.78it/s]

Epoch: 022, Loss: 0.1170, Train: 0.8586, Val: 0.8151, Test: 0.7930


258it [00:37,  6.82it/s]
1it [00:00,  6.54it/s]

Epoch: 023, Loss: 0.1156, Train: 0.8591, Val: 0.8166, Test: 0.7930


258it [00:37,  6.91it/s]
1it [00:00,  6.29it/s]

Epoch: 024, Loss: 0.1169, Train: 0.8531, Val: 0.8237, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  6.17it/s]

Epoch: 025, Loss: 0.1153, Train: 0.8576, Val: 0.8016, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  5.92it/s]

Epoch: 026, Loss: 0.1141, Train: 0.8636, Val: 0.8147, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.45it/s]

Epoch: 027, Loss: 0.1142, Train: 0.8681, Val: 0.8241, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.29it/s]

Epoch: 028, Loss: 0.1131, Train: 0.8680, Val: 0.8055, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.41it/s]

Epoch: 029, Loss: 0.1145, Train: 0.8657, Val: 0.8187, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  5.88it/s]

Epoch: 030, Loss: 0.1138, Train: 0.8584, Val: 0.8062, Test: 0.7930
Recorded values: []


258it [00:37,  6.80it/s]
1it [00:00,  6.17it/s]

Epoch: 031, Loss: 0.1123, Train: 0.8802, Val: 0.8230, Test: 0.7930


258it [00:38,  6.74it/s]
1it [00:00,  6.14it/s]

Epoch: 032, Loss: 0.1115, Train: 0.8781, Val: 0.8205, Test: 0.7930


258it [00:37,  6.86it/s]
1it [00:00,  6.41it/s]

Epoch: 033, Loss: 0.1128, Train: 0.8651, Val: 0.7868, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.45it/s]

Epoch: 034, Loss: 0.1139, Train: 0.8787, Val: 0.8226, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.49it/s]

Epoch: 035, Loss: 0.1116, Train: 0.8821, Val: 0.8207, Test: 0.7930


258it [00:37,  6.89it/s]
1it [00:00,  6.06it/s]

Epoch: 036, Loss: 0.1102, Train: 0.8824, Val: 0.8171, Test: 0.7930


258it [00:37,  6.88it/s]
1it [00:00,  5.99it/s]

Epoch: 037, Loss: 0.1094, Train: 0.8881, Val: 0.8150, Test: 0.7930


258it [00:37,  6.83it/s]
1it [00:00,  6.41it/s]

Epoch: 038, Loss: 0.1089, Train: 0.8838, Val: 0.8161, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  6.21it/s]

Epoch: 039, Loss: 0.1082, Train: 0.8847, Val: 0.8021, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  6.41it/s]

Epoch: 040, Loss: 0.1072, Train: 0.8819, Val: 0.8180, Test: 0.7930
Recorded values: []


258it [00:37,  6.90it/s]
1it [00:00,  6.29it/s]

Epoch: 041, Loss: 0.1090, Train: 0.8911, Val: 0.8244, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  6.02it/s]

Epoch: 042, Loss: 0.1087, Train: 0.8885, Val: 0.8186, Test: 0.7930


258it [00:37,  6.90it/s]
1it [00:00,  5.95it/s]

Epoch: 043, Loss: 0.1081, Train: 0.8954, Val: 0.8333, Test: 0.7540


258it [00:37,  6.89it/s]
1it [00:00,  5.81it/s]

Epoch: 044, Loss: 0.1076, Train: 0.8943, Val: 0.8253, Test: 0.7540


258it [00:37,  6.89it/s]
1it [00:00,  6.25it/s]

Epoch: 045, Loss: 0.1075, Train: 0.8942, Val: 0.8210, Test: 0.7540


258it [00:37,  6.89it/s]
0it [00:00, ?it/s]

Epoch: 046, Loss: 0.1067, Train: 0.8959, Val: 0.8110, Test: 0.7540


258it [00:37,  6.90it/s]
1it [00:00,  6.14it/s]

Epoch: 047, Loss: 0.1060, Train: 0.8936, Val: 0.8191, Test: 0.7540


258it [00:37,  6.90it/s]
1it [00:00,  6.02it/s]

Epoch: 048, Loss: 0.1057, Train: 0.8966, Val: 0.8188, Test: 0.7540


258it [00:37,  6.92it/s]
1it [00:00,  5.78it/s]

Epoch: 049, Loss: 0.1056, Train: 0.8879, Val: 0.8173, Test: 0.7540


258it [00:37,  6.90it/s]
1it [00:00,  6.58it/s]

Epoch: 050, Loss: 0.1064, Train: 0.9003, Val: 0.8308, Test: 0.7540
Recorded values: []


258it [00:37,  6.89it/s]
1it [00:00,  5.92it/s]

Epoch: 051, Loss: 0.1052, Train: 0.8977, Val: 0.8056, Test: 0.7540


258it [00:37,  6.92it/s]
1it [00:00,  6.80it/s]

Epoch: 052, Loss: 0.1048, Train: 0.8944, Val: 0.8123, Test: 0.7540


258it [00:37,  6.90it/s]
1it [00:00,  6.06it/s]

Epoch: 053, Loss: 0.1054, Train: 0.9058, Val: 0.8203, Test: 0.7540


258it [00:37,  6.91it/s]
1it [00:00,  6.58it/s]

Epoch: 054, Loss: 0.1046, Train: 0.8970, Val: 0.8228, Test: 0.7540


258it [00:37,  6.90it/s]
1it [00:00,  6.37it/s]

Epoch: 055, Loss: 0.1056, Train: 0.9063, Val: 0.8118, Test: 0.7540


258it [00:37,  6.89it/s]
1it [00:00,  6.14it/s]

Epoch: 056, Loss: 0.1044, Train: 0.9029, Val: 0.8115, Test: 0.7540


258it [00:37,  6.85it/s]
1it [00:00,  5.92it/s]

Epoch: 057, Loss: 0.1076, Train: 0.9055, Val: 0.8174, Test: 0.7540


258it [00:38,  6.76it/s]
1it [00:00,  6.02it/s]

Epoch: 058, Loss: 0.1035, Train: 0.9001, Val: 0.8113, Test: 0.7540


258it [00:38,  6.63it/s]
1it [00:00,  5.62it/s]

Epoch: 059, Loss: 0.1028, Train: 0.9098, Val: 0.8088, Test: 0.7540


258it [00:37,  6.86it/s]
1it [00:00,  5.68it/s]

Epoch: 060, Loss: 0.1015, Train: 0.9137, Val: 0.8424, Test: 0.7649
Recorded values: []


258it [00:37,  6.89it/s]
1it [00:00,  6.58it/s]

Epoch: 061, Loss: 0.1019, Train: 0.9106, Val: 0.8279, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  5.99it/s]

Epoch: 062, Loss: 0.1024, Train: 0.9115, Val: 0.8211, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.14it/s]

Epoch: 063, Loss: 0.1016, Train: 0.9068, Val: 0.8102, Test: 0.7649


258it [00:37,  6.87it/s]
1it [00:00,  6.62it/s]

Epoch: 064, Loss: 0.1017, Train: 0.9136, Val: 0.8258, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.17it/s]

Epoch: 065, Loss: 0.1037, Train: 0.9078, Val: 0.8179, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  5.81it/s]

Epoch: 066, Loss: 0.1013, Train: 0.9158, Val: 0.8203, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  6.37it/s]

Epoch: 067, Loss: 0.1002, Train: 0.9104, Val: 0.8089, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.49it/s]

Epoch: 068, Loss: 0.1024, Train: 0.9182, Val: 0.8134, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  5.78it/s]

Epoch: 069, Loss: 0.1001, Train: 0.9176, Val: 0.8201, Test: 0.7649


258it [00:37,  6.91it/s]
1it [00:00,  6.58it/s]

Epoch: 070, Loss: 0.0995, Train: 0.9160, Val: 0.8203, Test: 0.7649
Recorded values: []


258it [00:37,  6.88it/s]
1it [00:00,  5.75it/s]

Epoch: 071, Loss: 0.0998, Train: 0.9174, Val: 0.8303, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.33it/s]

Epoch: 072, Loss: 0.1008, Train: 0.9149, Val: 0.8050, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  5.99it/s]

Epoch: 073, Loss: 0.0987, Train: 0.9192, Val: 0.8128, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.41it/s]

Epoch: 074, Loss: 0.0992, Train: 0.9216, Val: 0.8208, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.71it/s]

Epoch: 075, Loss: 0.0977, Train: 0.9222, Val: 0.8069, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.37it/s]

Epoch: 076, Loss: 0.0975, Train: 0.9253, Val: 0.8267, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.67it/s]

Epoch: 077, Loss: 0.0986, Train: 0.9211, Val: 0.8150, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  5.78it/s]

Epoch: 078, Loss: 0.0976, Train: 0.9258, Val: 0.8210, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  5.81it/s]

Epoch: 079, Loss: 0.0975, Train: 0.9253, Val: 0.8148, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.94it/s]

Epoch: 080, Loss: 0.0986, Train: 0.9237, Val: 0.8149, Test: 0.7649
Recorded values: []


258it [00:37,  6.89it/s]
1it [00:00,  6.21it/s]

Epoch: 081, Loss: 0.0969, Train: 0.9297, Val: 0.8105, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.45it/s]

Epoch: 082, Loss: 0.0964, Train: 0.9236, Val: 0.7988, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.37it/s]

Epoch: 083, Loss: 0.0966, Train: 0.9306, Val: 0.8251, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.06it/s]

Epoch: 084, Loss: 0.0964, Train: 0.9318, Val: 0.8197, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.37it/s]

Epoch: 085, Loss: 0.0969, Train: 0.9318, Val: 0.8238, Test: 0.7649


258it [00:37,  6.87it/s]
1it [00:00,  6.14it/s]

Epoch: 086, Loss: 0.0958, Train: 0.9348, Val: 0.8257, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  6.49it/s]

Epoch: 087, Loss: 0.0960, Train: 0.9258, Val: 0.8019, Test: 0.7649


258it [00:37,  6.85it/s]
1it [00:00,  6.67it/s]

Epoch: 088, Loss: 0.0953, Train: 0.9291, Val: 0.8041, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  6.62it/s]

Epoch: 089, Loss: 0.0948, Train: 0.9327, Val: 0.8183, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  6.37it/s]

Epoch: 090, Loss: 0.0954, Train: 0.9345, Val: 0.8218, Test: 0.7649
Recorded values: []


258it [00:37,  6.89it/s]
1it [00:00,  6.67it/s]

Epoch: 091, Loss: 0.0946, Train: 0.9273, Val: 0.8074, Test: 0.7649


258it [00:37,  6.91it/s]
1it [00:00,  6.29it/s]

Epoch: 092, Loss: 0.0952, Train: 0.9353, Val: 0.8335, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  6.58it/s]

Epoch: 093, Loss: 0.0937, Train: 0.9374, Val: 0.8143, Test: 0.7649


258it [00:37,  6.91it/s]
1it [00:00,  5.78it/s]

Epoch: 094, Loss: 0.0946, Train: 0.9325, Val: 0.8092, Test: 0.7649


258it [00:37,  6.90it/s]
1it [00:00,  6.25it/s]

Epoch: 095, Loss: 0.0938, Train: 0.9389, Val: 0.8013, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.21it/s]

Epoch: 096, Loss: 0.0938, Train: 0.9406, Val: 0.8108, Test: 0.7649


258it [00:37,  6.91it/s]
1it [00:00,  6.10it/s]

Epoch: 097, Loss: 0.0941, Train: 0.9419, Val: 0.7960, Test: 0.7649


258it [00:37,  6.88it/s]
1it [00:00,  5.99it/s]

Epoch: 098, Loss: 0.0928, Train: 0.9332, Val: 0.8002, Test: 0.7649


258it [00:37,  6.89it/s]
1it [00:00,  6.85it/s]

Epoch: 099, Loss: 0.0937, Train: 0.9408, Val: 0.7963, Test: 0.7649


258it [00:37,  6.90it/s]


Epoch: 100, Loss: 0.0938, Train: 0.9392, Val: 0.8127, Test: 0.7649
Recorded values: []

Run 1:

using inter_message_passing: store_true


258it [00:37,  6.90it/s]
1it [00:00,  5.75it/s]

Epoch: 001, Loss: 0.1663, Train: 0.7067, Val: 0.7361, Test: 0.7207


258it [00:38,  6.69it/s]
1it [00:00,  6.02it/s]

Epoch: 002, Loss: 0.1477, Train: 0.7573, Val: 0.7414, Test: 0.7698


258it [00:37,  6.83it/s]
1it [00:00,  6.45it/s]

Epoch: 003, Loss: 0.1402, Train: 0.7742, Val: 0.7639, Test: 0.7973


258it [00:37,  6.82it/s]
1it [00:00,  6.14it/s]

Epoch: 004, Loss: 0.1374, Train: 0.7865, Val: 0.7798, Test: 0.7556


258it [00:37,  6.83it/s]
1it [00:00,  6.21it/s]

Epoch: 005, Loss: 0.1362, Train: 0.7918, Val: 0.7923, Test: 0.7504


258it [00:37,  6.82it/s]
1it [00:00,  6.06it/s]

Epoch: 006, Loss: 0.1340, Train: 0.7915, Val: 0.7813, Test: 0.7504


258it [00:37,  6.80it/s]
1it [00:00,  6.06it/s]

Epoch: 007, Loss: 0.1309, Train: 0.8001, Val: 0.8048, Test: 0.7715


258it [00:37,  6.81it/s]
1it [00:00,  6.10it/s]

Epoch: 008, Loss: 0.1304, Train: 0.8113, Val: 0.8050, Test: 0.7877


258it [00:38,  6.75it/s]


In [29]:
# sum. dropout=0.5. readout-> layer -> +cliques -> final

0it [00:00, ?it/s]


Run 0:

using inter_message_passing: store_true


258it [00:28,  8.93it/s]
1it [00:00,  7.69it/s]

Epoch: 001, Loss: 0.1719, Train: 0.7334, Val: 0.7373, Test: 0.7231


258it [00:28,  9.04it/s]
1it [00:00,  7.46it/s]

Epoch: 002, Loss: 0.1459, Train: 0.7473, Val: 0.7580, Test: 0.7565


258it [00:28,  9.01it/s]
1it [00:00,  8.20it/s]

Epoch: 003, Loss: 0.1409, Train: 0.7898, Val: 0.7856, Test: 0.7686


258it [00:28,  9.04it/s]
1it [00:00,  7.09it/s]

Epoch: 004, Loss: 0.1355, Train: 0.7991, Val: 0.7795, Test: 0.7686


258it [00:28,  9.02it/s]
1it [00:00,  7.75it/s]

Epoch: 005, Loss: 0.1336, Train: 0.8015, Val: 0.7810, Test: 0.7686


258it [00:28,  8.99it/s]
1it [00:00,  7.58it/s]

Epoch: 006, Loss: 0.1322, Train: 0.8159, Val: 0.8000, Test: 0.7803


258it [00:28,  9.00it/s]
1it [00:00,  8.26it/s]

Epoch: 007, Loss: 0.1277, Train: 0.8322, Val: 0.7949, Test: 0.7803


258it [00:28,  9.03it/s]
1it [00:00,  8.00it/s]

Epoch: 008, Loss: 0.1274, Train: 0.8346, Val: 0.8032, Test: 0.7800


258it [00:28,  9.03it/s]
1it [00:00,  8.06it/s]

Epoch: 009, Loss: 0.1258, Train: 0.8389, Val: 0.7997, Test: 0.7800


258it [00:28,  9.02it/s]
1it [00:00,  8.26it/s]

Epoch: 010, Loss: 0.1242, Train: 0.8432, Val: 0.8000, Test: 0.7800
Recorded values: []


258it [00:28,  9.00it/s]
1it [00:00,  7.58it/s]

Epoch: 011, Loss: 0.1219, Train: 0.8455, Val: 0.8063, Test: 0.7841


258it [00:28,  9.01it/s]
1it [00:00,  8.33it/s]

Epoch: 012, Loss: 0.1211, Train: 0.8483, Val: 0.7939, Test: 0.7841


258it [00:28,  9.02it/s]
1it [00:00,  8.26it/s]

Epoch: 013, Loss: 0.1210, Train: 0.8551, Val: 0.8093, Test: 0.7736


258it [00:28,  9.01it/s]
1it [00:00,  8.33it/s]

Epoch: 014, Loss: 0.1198, Train: 0.8538, Val: 0.8004, Test: 0.7736


258it [00:28,  9.05it/s]
1it [00:00,  7.75it/s]

Epoch: 015, Loss: 0.1188, Train: 0.8592, Val: 0.8176, Test: 0.7865


258it [00:28,  9.01it/s]
1it [00:00,  7.25it/s]

Epoch: 016, Loss: 0.1199, Train: 0.8679, Val: 0.8163, Test: 0.7865


258it [00:28,  9.03it/s]
1it [00:00,  7.75it/s]

Epoch: 017, Loss: 0.1169, Train: 0.8648, Val: 0.8293, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.47it/s]

Epoch: 018, Loss: 0.1163, Train: 0.8693, Val: 0.8177, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  7.25it/s]

Epoch: 019, Loss: 0.1147, Train: 0.8701, Val: 0.7959, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  7.46it/s]

Epoch: 020, Loss: 0.1151, Train: 0.8779, Val: 0.8076, Test: 0.7781
Recorded values: []


258it [00:28,  9.03it/s]
1it [00:00,  7.81it/s]

Epoch: 021, Loss: 0.1132, Train: 0.8753, Val: 0.8026, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  7.94it/s]

Epoch: 022, Loss: 0.1129, Train: 0.8793, Val: 0.8027, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  8.13it/s]

Epoch: 023, Loss: 0.1124, Train: 0.8751, Val: 0.7938, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  8.26it/s]

Epoch: 024, Loss: 0.1116, Train: 0.8882, Val: 0.8167, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  7.58it/s]

Epoch: 025, Loss: 0.1117, Train: 0.8839, Val: 0.7981, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  8.00it/s]

Epoch: 026, Loss: 0.1103, Train: 0.8881, Val: 0.8041, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  7.81it/s]

Epoch: 027, Loss: 0.1095, Train: 0.8878, Val: 0.7955, Test: 0.7781


258it [00:28,  8.99it/s]
1it [00:00,  8.77it/s]

Epoch: 028, Loss: 0.1095, Train: 0.8981, Val: 0.8004, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  7.52it/s]

Epoch: 029, Loss: 0.1081, Train: 0.8905, Val: 0.8178, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  8.06it/s]

Epoch: 030, Loss: 0.1077, Train: 0.8969, Val: 0.8149, Test: 0.7781
Recorded values: []


258it [00:28,  9.02it/s]
1it [00:00,  7.81it/s]

Epoch: 031, Loss: 0.1073, Train: 0.9023, Val: 0.8137, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  7.52it/s]

Epoch: 032, Loss: 0.1068, Train: 0.8932, Val: 0.8075, Test: 0.7781


258it [00:28,  8.99it/s]
1it [00:00,  8.00it/s]

Epoch: 033, Loss: 0.1052, Train: 0.8912, Val: 0.8033, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.93it/s]

Epoch: 034, Loss: 0.1055, Train: 0.8994, Val: 0.7987, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  7.63it/s]

Epoch: 035, Loss: 0.1063, Train: 0.9094, Val: 0.8057, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  7.87it/s]

Epoch: 036, Loss: 0.1068, Train: 0.9086, Val: 0.8227, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  8.26it/s]

Epoch: 037, Loss: 0.1040, Train: 0.9067, Val: 0.8163, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  7.25it/s]

Epoch: 038, Loss: 0.1032, Train: 0.9077, Val: 0.7986, Test: 0.7781


258it [00:29,  8.79it/s]
1it [00:00,  7.41it/s]

Epoch: 039, Loss: 0.1058, Train: 0.9026, Val: 0.8017, Test: 0.7781


258it [00:29,  8.69it/s]
1it [00:00,  7.58it/s]

Epoch: 040, Loss: 0.1026, Train: 0.9164, Val: 0.8037, Test: 0.7781
Recorded values: []


258it [00:29,  8.89it/s]
1it [00:00,  7.52it/s]

Epoch: 041, Loss: 0.1041, Train: 0.9136, Val: 0.8110, Test: 0.7781


258it [00:29,  8.74it/s]
1it [00:00,  7.87it/s]

Epoch: 042, Loss: 0.1036, Train: 0.9096, Val: 0.8084, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  7.75it/s]

Epoch: 043, Loss: 0.1014, Train: 0.9157, Val: 0.8063, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  8.06it/s]

Epoch: 044, Loss: 0.1018, Train: 0.9187, Val: 0.8186, Test: 0.7781


258it [00:28,  9.05it/s]
1it [00:00,  7.69it/s]

Epoch: 045, Loss: 0.1002, Train: 0.9242, Val: 0.8281, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  8.13it/s]

Epoch: 046, Loss: 0.1004, Train: 0.9219, Val: 0.8091, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.62it/s]

Epoch: 047, Loss: 0.0995, Train: 0.9196, Val: 0.8154, Test: 0.7781


258it [00:28,  8.98it/s]
1it [00:00,  7.75it/s]

Epoch: 048, Loss: 0.0992, Train: 0.9247, Val: 0.8221, Test: 0.7781


258it [00:28,  9.05it/s]
1it [00:00,  8.34it/s]

Epoch: 049, Loss: 0.0997, Train: 0.9251, Val: 0.8053, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  7.94it/s]

Epoch: 050, Loss: 0.0992, Train: 0.9271, Val: 0.8103, Test: 0.7781
Recorded values: []


258it [00:28,  9.06it/s]
1it [00:00,  7.09it/s]

Epoch: 051, Loss: 0.0978, Train: 0.9286, Val: 0.8026, Test: 0.7781


258it [00:28,  9.07it/s]
1it [00:00,  8.00it/s]

Epoch: 052, Loss: 0.0979, Train: 0.9308, Val: 0.8164, Test: 0.7781


258it [00:28,  9.05it/s]
1it [00:00,  8.20it/s]

Epoch: 053, Loss: 0.0999, Train: 0.9259, Val: 0.8121, Test: 0.7781


258it [00:28,  9.05it/s]
1it [00:00,  7.04it/s]

Epoch: 054, Loss: 0.0982, Train: 0.9300, Val: 0.8127, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  8.20it/s]

Epoch: 055, Loss: 0.0970, Train: 0.9322, Val: 0.8083, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.62it/s]

Epoch: 056, Loss: 0.0966, Train: 0.9386, Val: 0.8118, Test: 0.7781


258it [00:28,  9.06it/s]
1it [00:00,  8.13it/s]

Epoch: 057, Loss: 0.0961, Train: 0.9206, Val: 0.8121, Test: 0.7781


258it [00:28,  9.05it/s]
1it [00:00,  8.00it/s]

Epoch: 058, Loss: 0.0968, Train: 0.9340, Val: 0.8235, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.47it/s]

Epoch: 059, Loss: 0.0966, Train: 0.9357, Val: 0.8077, Test: 0.7781


258it [00:28,  9.02it/s]
1it [00:00,  8.20it/s]

Epoch: 060, Loss: 0.0953, Train: 0.9340, Val: 0.7988, Test: 0.7781
Recorded values: []


258it [00:28,  9.01it/s]
1it [00:00,  8.85it/s]

Epoch: 061, Loss: 0.0950, Train: 0.9347, Val: 0.8176, Test: 0.7781


258it [00:28,  9.03it/s]
1it [00:00,  8.20it/s]

Epoch: 062, Loss: 0.0925, Train: 0.9408, Val: 0.8083, Test: 0.7781


258it [00:28,  8.95it/s]
1it [00:00,  7.69it/s]

Epoch: 063, Loss: 0.0937, Train: 0.9436, Val: 0.8076, Test: 0.7781


258it [00:28,  9.01it/s]
1it [00:00,  7.75it/s]

Epoch: 064, Loss: 0.0937, Train: 0.9441, Val: 0.8163, Test: 0.7781


258it [00:29,  8.86it/s]
1it [00:00,  8.20it/s]

Epoch: 065, Loss: 0.0948, Train: 0.9484, Val: 0.8119, Test: 0.7781


258it [00:28,  8.94it/s]
1it [00:00,  8.20it/s]

Epoch: 066, Loss: 0.0936, Train: 0.9467, Val: 0.8042, Test: 0.7781


258it [00:29,  8.84it/s]
1it [00:00,  8.33it/s]

Epoch: 067, Loss: 0.0922, Train: 0.9475, Val: 0.8269, Test: 0.7781


258it [00:28,  9.04it/s]
1it [00:00,  7.94it/s]

Epoch: 068, Loss: 0.0939, Train: 0.9438, Val: 0.8357, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  7.52it/s]

Epoch: 069, Loss: 0.0952, Train: 0.9474, Val: 0.8279, Test: 0.7678


258it [00:28,  9.01it/s]
1it [00:00,  7.69it/s]

Epoch: 070, Loss: 0.0922, Train: 0.9474, Val: 0.8283, Test: 0.7678
Recorded values: []


258it [00:28,  8.99it/s]
1it [00:00,  7.58it/s]

Epoch: 071, Loss: 0.0897, Train: 0.9493, Val: 0.8100, Test: 0.7678


258it [00:28,  9.00it/s]
1it [00:00,  7.94it/s]

Epoch: 072, Loss: 0.0905, Train: 0.9513, Val: 0.8099, Test: 0.7678


258it [00:28,  9.01it/s]
1it [00:00,  7.87it/s]

Epoch: 073, Loss: 0.0906, Train: 0.9499, Val: 0.8123, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  7.58it/s]

Epoch: 074, Loss: 0.0899, Train: 0.9533, Val: 0.8313, Test: 0.7678


258it [00:28,  9.04it/s]
1it [00:00,  7.63it/s]

Epoch: 075, Loss: 0.0914, Train: 0.9534, Val: 0.8099, Test: 0.7678


258it [00:28,  9.01it/s]
1it [00:00,  8.06it/s]

Epoch: 076, Loss: 0.0930, Train: 0.9501, Val: 0.8014, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  8.00it/s]

Epoch: 077, Loss: 0.0906, Train: 0.9576, Val: 0.8148, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  7.87it/s]

Epoch: 078, Loss: 0.0875, Train: 0.9560, Val: 0.8021, Test: 0.7678


258it [00:28,  9.05it/s]
1it [00:00,  8.06it/s]

Epoch: 079, Loss: 0.0879, Train: 0.9541, Val: 0.8011, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  7.63it/s]

Epoch: 080, Loss: 0.0885, Train: 0.9568, Val: 0.7882, Test: 0.7678
Recorded values: []


258it [00:28,  8.99it/s]
1it [00:00,  8.26it/s]

Epoch: 081, Loss: 0.0883, Train: 0.9603, Val: 0.8175, Test: 0.7678


258it [00:28,  9.00it/s]
1it [00:00,  7.63it/s]

Epoch: 082, Loss: 0.0871, Train: 0.9585, Val: 0.8282, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  7.35it/s]

Epoch: 083, Loss: 0.0873, Train: 0.9589, Val: 0.8184, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  7.69it/s]

Epoch: 084, Loss: 0.0881, Train: 0.9576, Val: 0.8169, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  7.75it/s]

Epoch: 085, Loss: 0.0883, Train: 0.9561, Val: 0.8077, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  8.26it/s]

Epoch: 086, Loss: 0.0860, Train: 0.9612, Val: 0.7985, Test: 0.7678


258it [00:28,  9.05it/s]
1it [00:00,  7.63it/s]

Epoch: 087, Loss: 0.0868, Train: 0.9584, Val: 0.8125, Test: 0.7678


258it [00:28,  9.04it/s]
1it [00:00,  7.81it/s]

Epoch: 088, Loss: 0.0859, Train: 0.9619, Val: 0.8192, Test: 0.7678


258it [00:28,  9.00it/s]
1it [00:00,  8.20it/s]

Epoch: 089, Loss: 0.0841, Train: 0.9632, Val: 0.8230, Test: 0.7678


258it [00:28,  9.04it/s]
1it [00:00,  8.26it/s]

Epoch: 090, Loss: 0.0858, Train: 0.9630, Val: 0.8007, Test: 0.7678
Recorded values: []


258it [00:28,  9.01it/s]
1it [00:00,  7.46it/s]

Epoch: 091, Loss: 0.0824, Train: 0.9631, Val: 0.8149, Test: 0.7678


258it [00:28,  9.04it/s]
1it [00:00,  7.41it/s]

Epoch: 092, Loss: 0.0893, Train: 0.9613, Val: 0.8275, Test: 0.7678


258it [00:28,  9.04it/s]
1it [00:00,  8.26it/s]

Epoch: 093, Loss: 0.0849, Train: 0.9649, Val: 0.8270, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  7.63it/s]

Epoch: 094, Loss: 0.0838, Train: 0.9644, Val: 0.8213, Test: 0.7678


258it [00:28,  8.98it/s]
1it [00:00,  7.69it/s]

Epoch: 095, Loss: 0.0824, Train: 0.9617, Val: 0.8222, Test: 0.7678


258it [00:28,  9.02it/s]
1it [00:00,  7.94it/s]

Epoch: 096, Loss: 0.0822, Train: 0.9678, Val: 0.8170, Test: 0.7678


258it [00:28,  9.01it/s]
1it [00:00,  8.06it/s]

Epoch: 097, Loss: 0.0828, Train: 0.9652, Val: 0.8125, Test: 0.7678


258it [00:28,  9.03it/s]
1it [00:00,  8.40it/s]

Epoch: 098, Loss: 0.0822, Train: 0.9680, Val: 0.8276, Test: 0.7678


258it [00:28,  9.00it/s]
1it [00:00,  8.00it/s]

Epoch: 099, Loss: 0.0825, Train: 0.9695, Val: 0.8250, Test: 0.7678


258it [00:28,  9.02it/s]


Epoch: 100, Loss: 0.0796, Train: 0.9701, Val: 0.8207, Test: 0.7678
Recorded values: []


IndexError: list index out of range

### Analyze learned parameters (`p`, `beta`)

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def plot_param_evolution(records, title="Evolution of parameters beta and p during training - Power Mean"):
    """Plot the recorded `p` and `beta` parameter values over training.

    Parameters
    ----------
    records : iterable of dict
        One entry per recording step; each entry must provide "p" and "beta"
        keys. Each value is flattened, so scalars or array-likes both work.
        NOTE(review): the exact value shape depends on the training loop that
        filled `epoch_values` -- confirm.
    title : str
        Figure title.

    Returns
    -------
    matplotlib.axes.Axes or None
        The axes holding the plot, or None when there was nothing to plot.
    """
    records = list(records)
    if not records:
        # The runs above print "Recorded values: []"; without this guard the
        # y-limit computation fails with a cryptic error on a zero-size array.
        print("No recorded parameter values to plot.")
        return None

    p_vals = np.array([rec["p"] for rec in records]).flatten()
    beta_vals = np.array([rec["beta"] for rec in records]).flatten()

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.set_title(title)
    ax.plot(p_vals, "r-", label="p")
    ax.plot(beta_vals, "b-", label="beta")
    ax.set_xlabel("record index")
    ax.set_ylabel("parameter value")

    # Concatenate rather than np.maximum so the two series need not have
    # equal length when computing the axis range.
    all_vals = np.concatenate([p_vals, beta_vals])
    lowest, highest = np.min(all_vals), np.max(all_vals)
    if lowest >= 0 and highest > 0:
        # Same framing as before for non-negative data: anchor at 0, 10% headroom.
        ax.set_ylim(0, 1.1 * highest)
    # Otherwise keep matplotlib's autoscaling: a fixed lower bound of 0 would
    # hide negative values (or give a degenerate axis when everything is 0).

    ax.legend()
    plt.show()
    return ax


plot_param_evolution(epoch_values);