In [318]:
import numpy as np
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [319]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data
import torch.nn.functional as F
import torch_sparse
from torchdiffeq import odeint
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import add_remaining_self_loops
from torch_scatter import scatter_add
# from torchdiffeq import odeint_adjoint as odeint # Might be more stable according to docs of torchdiffeq
import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

import time

from torch_geometric.datasets import Planetoid
from pathlib import Path


In [320]:
opt = {
    'cora_defaults': False, 'dataset': 'Cora', 'data_norm': 'rw', 'self_loop_weight': 1.0, 'use_labels': False,
    'geom_gcn_splits': False, 'num_splits': 2, 'label_rate': 0.5, 'planetoid_split': False, 'hidden_dim': 80,
    'fc_out': False, 'input_dropout': 0.5, 'dropout': 0.046878964627763316, 'batch_norm': False, 'optimizer': 'adamax',
    'lr': 0.022924849756740397, 'decay': 0.00507685443154266, 'epoch': 100, 'alpha': 1.0, 'alpha_dim': 'sc',
    'no_alpha_sigmoid': False, 'beta_dim': 'sc', 'block': 'constant', 'function': 'laplacian', 'use_mlp': False,
    'add_source': True, 'cgnn': False, 'time': 18.294754260552843, 'augment': False, 'method': 'euler', 'step_size': 1,
    'max_iters': 100, 'adjoint_method': 'adaptive_heun', 'adjoint': False, 'adjoint_step_size': 1,
    'tol_scale': 821.9773048827274, 'tol_scale_adjoint': 1.0, 'ode_blocks': 1, 'max_nfe': 2000, 'no_early': True,
    'earlystopxT': 3, 'max_test_steps': 100, 'leaky_relu_slope': 0.2, 'attention_dropout': 0.0, 'heads': 8,
    'attention_norm_idx': 1, 'attention_dim': 128, 'mix_features': False, 'reweight_attention': False,
    'attention_type': 'scaled_dot', 'square_plus': True, 'jacobian_norm2': None, 'total_deriv': None,
    'kinetic_energy': None, 'directional_penalty': None, 'not_lcc': True, 'rewiring': None, 'gdc_method': 'ppr',
    'gdc_sparsification': 'topk', 'gdc_k': 64, 'gdc_threshold': 0.01, 'gdc_avg_degree': 64, 'ppr_alpha': 0.05,
    'heat_time': 3.0, 'att_samp_pct': 1, 'use_flux': False, 'exact': True, 'M_nodes': 64, 'new_edges': 'random',
    'sparsify': 'S_hat', 'threshold_type': 'addD_rvR', 'rw_addD': 0.02, 'rw_rmvR': 0.02, 'rewire_KNN': False,
    'rewire_KNN_T': 'T0', 'rewire_KNN_epoch': 10, 'rewire_KNN_k': 64, 'rewire_KNN_sym': False, 'KNN_online': False,
    'KNN_online_reps': 4, 'KNN_space': 'pos_distance', 'beltrami': False, 'fa_layer': False, 'pos_enc_type': 'GDC',
    'pos_enc_orientation': 'row', 'feat_hidden_dim': 64, 'pos_enc_hidden_dim': 16, 'edge_sampling': False,
    'edge_sampling_T': 'T0', 'edge_sampling_epoch': 5, 'edge_sampling_add': 0.64,
    'edge_sampling_add_type': 'importance', 'edge_sampling_rmv': 0.32, 'edge_sampling_sym': False,
    'edge_sampling_online': False, 'edge_sampling_online_reps': 4, 'edge_sampling_space': 'attention',
    'symmetric_attention': False, 'fa_layer_edge_sampling_rmv': 0.8, 'gpu': 0, 'pos_enc_csv': False,
    'pos_dist_quantile': 0.001, 'adaptive': False, 'attention_rewiring': False, 'baseline': False, 'cpus': 1,
    'dt': 0.001, 'dt_min': 1e-05, 'gpus': 0.5, 'grace_period': 20, 'max_epochs': 1000, 'metric': 'accuracy',
    'name': 'cora_beltrami_splits', 'num_init': 1, 'num_samples': 1000, 'patience': 100, 'reduction_factor': 10,
    'regularise': False, 'use_lcc': True
}

In [321]:
# Set torch device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

## Get Dataset

In [322]:
dataset_dir = Path('data')
dataset_dir = dataset_dir.absolute()
if not dataset_dir.exists():
    dataset_dir.mkdir(parents=True)

dataset = Planetoid(dataset_dir, 'Cora')

# Some info
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset.data
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Number of validation nodes: {data.val_mask.sum()}')
print(f'Number of test nodes: {data.test_mask.sum()}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Number of graphs: 1
Number of features: 1433
Number of classes: 7
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Number of validation nodes: 500
Number of test nodes: 1000
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


In [323]:
# # Plot the graph
# G = nx.Graph()
#
# G.add_nodes_from(list(range(data.num_nodes)))
# G.add_edges_from([tuple(x) for x in data.edge_index.T.tolist()])
# nx.draw(G)
# plt.gca().set_facecolor('white')

In [324]:
# Train the given model on the given graph for num_epochs
def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    x = data.x
    y = data.y.squeeze()

    # Set up the loss and the optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    out = model(x)

    loss = loss_fn(out[data.train_mask], y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [325]:
class LaplacianODEFunc(MessagePassing):

  # currently requires in_features = out_features
  def __init__(self, in_features, out_features, opt, data, device):
    super(MessagePassing, self).__init__()
    self.opt = opt
    self.device = device
    self.edge_index = None
    self.edge_weight = None
    self.attention_weights = None
    self.alpha_train = nn.Parameter(torch.tensor(0.0))
    self.beta_train = nn.Parameter(torch.tensor(0.0))
    self.x0 = None
    self.nfe = 0
    self.alpha_sc = nn.Parameter(torch.ones(1))
    self.beta_sc = nn.Parameter(torch.ones(1))
    self.in_features = in_features
    self.out_features = out_features
    self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
    self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
    self.alpha_sc = nn.Parameter(torch.ones(1))
    self.beta_sc = nn.Parameter(torch.ones(1))

  def sparse_multiply(self, x):
    if self.opt['block'] in ['attention']:  # adj is a multihead attention
      mean_attention = self.attention_weights.mean(dim=1)
      ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
    elif self.opt['block'] in ['mixed', 'hard_attention']:  # adj is a torch sparse matrix
      ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x)
    else:  # adj is a torch sparse matrix
      ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
    return ax

  def forward(self, t, x):  # the t param is needed by the ODE solver.
    self.nfe += 1
    ax = self.sparse_multiply(x)
    if not self.opt['no_alpha_sigmoid']:
      alpha = torch.sigmoid(self.alpha_train)
    else:
      alpha = self.alpha_train

    f = alpha * (ax - x)
    if self.opt['add_source']:
      f = f + self.beta_train * self.x0
    return f

In [326]:
def get_rw_adj(edge_index, edge_weight=None, norm_dim=1, fill_value=0., num_nodes=None, dtype=None):
  num_nodes = maybe_num_nodes(edge_index, num_nodes)

  if edge_weight is None:
    edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                             device=edge_index.device)

  if not fill_value == 0:
    edge_index, tmp_edge_weight = add_remaining_self_loops(
      edge_index, edge_weight, fill_value, num_nodes)
    assert tmp_edge_weight is not None
    edge_weight = tmp_edge_weight

  row, col = edge_index[0], edge_index[1]
  indices = row if norm_dim == 0 else col
  deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes)
  deg_inv_sqrt = deg.pow_(-1)
  edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices]
  return edge_index, edge_weight

In [327]:
class ConstantODEblock(nn.Module):
  def __init__(self, odefunc, opt, data, device, t=torch.tensor([0, 1])):
    super(ConstantODEblock, self).__init__()
    self.opt = opt
    self.t = t

    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)


    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint
    self.train_integrator = odeint
    self.test_integrator = None
    self.set_tol()

    self.aug_dim = 2 if opt['augment'] else 1
    self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
    edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
                                                                   fill_value=opt['self_loop_weight'],
                                                                   num_nodes=data.num_nodes,
                                                                   dtype=data.x.dtype)
    self.odefunc.edge_index = edge_index.to(device)
    self.odefunc.edge_weight = edge_weight.to(device)

    if opt['adjoint']:
      from torchdiffeq import odeint_adjoint as odeint
    else:
      from torchdiffeq import odeint

    self.train_integrator = odeint
    self.test_integrator = odeint
    self.set_tol()

  def set_x0(self, x0):
    self.odefunc.x0 = x0.clone().detach()

  def set_tol(self):
    self.atol = self.opt['tol_scale'] * 1e-7
    self.rtol = self.opt['tol_scale'] * 1e-9
    if self.opt['adjoint']:
      self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7
      self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9

  def reset_tol(self):
    self.atol = 1e-7
    self.rtol = 1e-9
    self.atol_adjoint = 1e-7
    self.rtol_adjoint = 1e-9

  def set_time(self, time):
    self.t = torch.tensor([0, time]).to(self.device)



  def forward(self, x):
    t = self.t.type_as(x)

    integrator = self.train_integrator if self.training else self.test_integrator


    func = self.odefunc
    state = x

    if self.opt["adjoint"] and self.training:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
        adjoint_method=self.opt['adjoint_method'],
        adjoint_options=dict(step_size = self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']),
        atol=self.atol,
        rtol=self.rtol,
        adjoint_atol=self.atol_adjoint,
        adjoint_rtol=self.rtol_adjoint)
    else:
      state_dt = integrator(
        func, state, t,
        method=self.opt['method'],
        options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
        atol=self.atol,
        rtol=self.rtol)


    z = state_dt[1]
    return z

  def __repr__(self):
    return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
           + ")"

In [328]:
class GNN(MessagePassing):
  def __init__(self, opt, dataset, device=torch.device('cpu')):
    super(MessagePassing, self).__init__()
    self.opt = opt
    self.T = opt['time']
    self.num_classes = dataset.num_classes
    self.num_features = dataset.data.num_features
    self.num_nodes = dataset.data.num_nodes
    self.device = device

    if opt['beltrami']:
      self.mx = nn.Linear(self.num_features, opt['feat_hidden_dim'])
      self.mp = nn.Linear(opt['pos_enc_dim'], opt['pos_enc_hidden_dim'])
      opt['hidden_dim'] = opt['feat_hidden_dim'] + opt['pos_enc_hidden_dim']
    else:
      self.m1 = nn.Linear(self.num_features, opt['hidden_dim'])

    if self.opt['use_mlp']:
      self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
      self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
    if opt['use_labels']:
      # todo - fastest way to propagate this everywhere, but error prone - refactor later
      opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes
    else:
      self.hidden_dim = opt['hidden_dim']
    if opt['fc_out']:
      self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
    self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes)
    if self.opt['batch_norm']:
      self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim'])
      self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim'])

    self.f = LaplacianODEFunc
    block = ConstantODEblock
    time_tensor = torch.tensor([0, self.T]).to(device)
    self.odeblock = block(self.f, opt, dataset.data, device, t=time_tensor).to(device)

  def getNFE(self):
    return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe

  def resetNFE(self):
    self.odeblock.odefunc.nfe = 0
    self.odeblock.reg_odefunc.odefunc.nfe = 0

  def reset(self):
    self.m1.reset_parameters()
    self.m2.reset_parameters()

  def forward(self, x, pos_encoding=None):
    # Encode each node based on its feature.
    if self.opt['use_labels']:
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    if self.opt['beltrami']:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.mx(x)
      p = F.dropout(pos_encoding, self.opt['input_dropout'], training=self.training)
      p = self.mp(p)
      x = torch.cat([x, p], dim=1)
    else:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.m1(x)

    if self.opt['use_mlp']:
      x = F.dropout(x, self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper

    if self.opt['use_labels']:
      x = torch.cat([x, y], dim=-1)

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    if self.opt['augment']:
      c_aux = torch.zeros(x.shape).to(self.device)
      x = torch.cat([x, c_aux], dim=1)

    self.odeblock.set_x0(x)


    z = self.odeblock(x)

    if self.opt['augment']:
      z = torch.split(z, x.shape[1] // 2, dim=1)[0]

    # Activation.
    z = F.relu(z)

    if self.opt['fc_out']:
      z = self.fc(z)
      z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    z = self.m2(z)
    return z

In [329]:
def print_model_params(model):
  print(model)
  for name, param in model.named_parameters():
    if param.requires_grad:
      print(name)
      print(param.data.shape)

In [330]:
def set_train_val_test_split(
        seed: int,
        data: Data,
        num_development: int = 1500,
        num_per_class: int = 20) -> Data:
  rnd_state = np.random.RandomState(seed)
  num_nodes = data.y.shape[0]
  development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
  test_idx = [i for i in np.arange(num_nodes) if i not in development_idx]

  train_idx = []
  rnd_state = np.random.RandomState(seed)
  for c in range(data.y.max() + 1):
    class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
    train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

  val_idx = [i for i in development_idx if i not in train_idx]

  def get_mask(idx):
    mask = torch.zeros(num_nodes, dtype=torch.bool)
    mask[idx] = 1
    return mask

  data.train_mask = get_mask(train_idx)
  data.val_mask = get_mask(val_idx)
  data.test_mask = get_mask(test_idx)

  return data

In [331]:
@torch.no_grad()
def test(model, data, opt=None):  # opt required for runtime polymorphism
  model.eval()
  logits, accs = model(data.x), []
  for _, mask in data('train_mask', 'val_mask', 'test_mask'):
    pred = logits[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    accs.append(acc)
  return accs

In [332]:
# def get_optimizer(name, parameters, lr, weight_decay=0):
#   if name == 'sgd':
#     return torch.optim.SGD(parameters, lr=lr, weight_decay=weight_decay)
#   elif name == 'rmsprop':
#     return torch.optim.RMSprop(parameters, lr=lr, weight_decay=weight_decay)
#   elif name == 'adagrad':
#     return torch.optim.Adagrad(parameters, lr=lr, weight_decay=weight_decay)
#   elif name == 'adam':
#     return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
#   elif name == 'adamax':
#     return torch.optim.Adamax(parameters, lr=lr, weight_decay=weight_decay)
#   else:
#     raise Exception("Unsupported optimizer: {}".format(name))

In [333]:
# if opt['rewire_KNN'] or opt['fa_layer']:
#   model = GNN_KNN(opt, dataset, device).to(device) if opt["no_early"] else GNNKNNEarly(opt, dataset, device).to(device)
# else:
model = GNN(opt, dataset, device).to(device) # if opt["no_early"] else GNNEarly(opt, dataset, device).to(device)

# if not opt['planetoid_split'] and opt['dataset'] in ['Cora','Citeseer','Pubmed']:
#   dataset.data = set_train_val_test_split(np.random.randint(0, 1000), dataset.data, num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500)

data = dataset.data.to(device)

parameters = [p for p in model.parameters() if p.requires_grad]
# print_model_params(model)
optimizer = torch.optim.Adam(parameters, lr=opt['lr'], weight_decay=opt['decay'])
best_time = best_epoch = train_acc = val_acc = test_acc = 0

this_test = test

for epoch in range(1, opt['epoch']):
  start_time = time.time()

  loss = train(model, optimizer, data)
  tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt)

  best_time = opt['time']
  if tmp_val_acc > val_acc:
    best_epoch = epoch
    train_acc = tmp_train_acc
    val_acc = tmp_val_acc
    test_acc = tmp_test_acc
    best_time = opt['time']
  # if not opt['no_early'] and model.odeblock.test_integrator.solver.best_val > val_acc:
  #   best_epoch = epoch
  #   val_acc = model.odeblock.test_integrator.solver.best_val
  #   test_acc = model.odeblock.test_integrator.solver.best_test
  #   train_acc = model.odeblock.test_integrator.solver.best_train
  #   best_time = model.odeblock.test_integrator.solver.best_time

  log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Best time: {:.4f}'

  # print(log.format(epoch, time.time() - start_time, loss, train_acc, val_acc, test_acc, best_time))
print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d} and best time {:03f}'.format(val_acc, test_acc,
                                                                                                   best_epoch,
                                                                                                   best_time))


best val accuracy 0.806000 with test accuracy 0.815000 at epoch 33 and best time 18.294754


In [334]:
# Train the given model on the given graph for num_epochs
def train_standard_GCN(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    x = data.x
    y = data.y.squeeze()

    # Set up the loss and the optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    out = model(data)

    loss = loss_fn(out[data.train_mask], y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [335]:
@torch.no_grad()
def test_standard_GCN(model, data, opt=None):  # opt required for runtime polymorphism
  model.eval()
  logits, accs = model(data), []
  for _, mask in data('train_mask', 'val_mask', 'test_mask'):
    pred = logits[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    accs.append(acc)
  return accs

In [336]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, opt['hidden_dim'])
        self.conv2 = GCNConv(opt['hidden_dim'], dataset.num_classes)

    def forward(self, data):
      x, edge_index = data.x, data.edge_index

      x = self.conv1(x, edge_index)
      x = F.relu(x)
      x = F.dropout(x, training=self.training)
      x = self.conv2(x, edge_index)

      return F.log_softmax(x, dim=1)
# if not opt['planetoid_split'] and opt['dataset'] in ['Cora','Citeseer','Pubmed']:
#   dataset.data = set_train_val_test_split(np.random.randint(0, 1000), dataset.data, num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500)
model = Net().to(device)

data = dataset.data.to(device)

parameters = [p for p in model.parameters() if p.requires_grad]
# print_model_params(model)
optimizer = torch.optim.Adam(parameters, lr=opt['lr'], weight_decay=opt['decay'])
best_time = best_epoch = train_acc = val_acc = test_acc = 0

this_test = test_standard_GCN

for epoch in range(1, opt['epoch']):
  start_time = time.time()

  loss = train_standard_GCN(model, optimizer, data)

  tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt)

  best_time = opt['time']
  if tmp_val_acc > val_acc:
    best_epoch = epoch
    train_acc = tmp_train_acc
    val_acc = tmp_val_acc
    test_acc = tmp_test_acc
    best_time = opt['time']
  # if not opt['no_early'] and model.odeblock.test_integrator.solver.best_val > val_acc:
  #   best_epoch = epoch
  #   val_acc = model.odeblock.test_integrator.solver.best_val
  #   test_acc = model.odeblock.test_integrator.solver.best_test
  #   train_acc = model.odeblock.test_integrator.solver.best_train
  #   best_time = model.odeblock.test_integrator.solver.best_time

  log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Best time: {:.4f}'

  # print(log.format(epoch, time.time() - start_time, loss, train_acc, val_acc, test_acc, best_time))
print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d} and best time {:03f}'.format(val_acc, test_acc,
                                                                                                   best_epoch,
                                                                                                   best_time))

best val accuracy 0.812000 with test accuracy 0.823000 at epoch 6 and best time 18.294754
