In [1]:
import os
import argparse
import pickle


import numpy as np

import torch
from torch_geometric.nn import GCNConv, ChebConv  # noqa
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, GNNBenchmarkDataset, Reddit2, Flickr
from GNN import GNN
from GNN_KNN import GNN_KNN
import time
from data import get_dataset, Data

from graph_rewiring import get_two_hop, apply_gdc
from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T
from torch_geometric.utils import to_undirected
from graph_rewiring import make_symmetric, apply_pos_dist_rewire


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Argument list handed to `parser.parse_args(...)` further down in place of the
# real command line; left empty so every option starts from its argparse default.
customArgs = []

In [3]:
def get_cora_opt(opt):
  """Overlay the Cora experiment settings onto ``opt`` and return it.

  ``opt`` is modified in place: every key listed below is (over)written,
  any other key is left untouched, and the very same dictionary object is
  handed back to the caller.
  """
  cora_settings = {
    'dataset': 'Cora',
    'data': 'Planetoid',
    'hidden_dim': 16,
    'input_dropout': 0.5,
    'dropout': 0,
    'optimizer': 'rmsprop',
    'lr': 0.0047,
    'decay': 5e-4,
    'self_loop_weight': 0.555,
    'alpha': 0.918,
    'time': 12.1,
    'num_feature': 1433,
    'num_class': 7,
    'num_nodes': 2708,
    'epoch': 31,
    'augment': True,
    'attention_dropout': 0,
    'adjoint': False,
    'ode': 'ode',
  }
  opt.update(cora_settings)
  return opt

def get_computers_opt(opt):
  """Overlay the Amazon Computers experiment settings onto ``opt`` and return it.

  ``opt`` is modified in place and the same dictionary object is returned.

  Compared with the earlier version of this helper:
    * the dead ``opt['epoch'] = 400`` assignment (immediately overwritten by
      ``100``) is gone, so the effective epoch count is stated exactly once;
    * ``num_feature`` / ``num_class`` / ``num_nodes`` now describe the Amazon
      Computers graph (767 features, 10 classes, 13,752 nodes, as listed in
      the PyTorch Geometric ``Amazon`` dataset documentation) instead of the
      values copied over from the Cora helper.
  """
  opt['dataset'] = 'Computers'
  opt['hidden_dim'] = 16
  opt['input_dropout'] = 0.5
  opt['dropout'] = 0
  opt['optimizer'] = 'adam'
  opt['lr'] = 0.01
  opt['decay'] = 5e-4
  opt['self_loop_weight'] = 0.555
  opt['alpha'] = 0.918
  opt['epoch'] = 100
  opt['time'] = 12.1
  # Dataset statistics for Amazon Computers (not Cora).
  opt['num_feature'] = 767
  opt['num_class'] = 10
  opt['num_nodes'] = 13752
  opt['attention_dropout'] = 0
  opt['ode'] = 'ode'
  return opt

def get_flickr_opt(opt):
  """Overlay the Flickr experiment settings onto ``opt`` and return it.

  The dictionary is updated in place (existing entries with the same key are
  replaced) and the same object is returned.
  """
  # NOTE(review): the key is spelled 'feature_hidden_dim' here while the
  # argument parser in this notebook defines 'feat_hidden_dim' - confirm which
  # one the model actually reads.
  flickr_settings = (
    ('dataset', 'Flickr'),
    ('hidden_dim', 128),
    ('feature_hidden_dim', 64),
    ('input_dropout', 0.5),
    ('dropout', 0),
    ('optimizer', 'adam'),
    ('lr', 0.005451476553977102),
    ('decay', 0),
    ('self_loop_weight', 0.555),
    ('alpha', 1.0),
    ('time', 12.1),
    ('num_feature', 500),
    ('num_class', 7),
    ('num_nodes', 89250),
    ('epoch', 100),
    ('attention_dropout', 0),
    ('ode', 'ode'),
    ('gdc_avg_degree', 48),
    ('gdc_k', 48),
    ('gdc_method', 'ppr'),
    ('gdc_sparsification', 'topk'),
    ('gdc_threshold', 0.01),
  )
  for key, value in flickr_settings:
    opt[key] = value
  return opt

In [4]:
def get_optimizer(name, parameters, lr, weight_decay=0):
  """Build the torch optimizer identified by ``name``.

  Supported names are 'sgd', 'rmsprop', 'adagrad', 'adam' and 'adamax'; each
  is constructed with the given ``parameters``, ``lr`` and ``weight_decay``.
  Any other name raises ``Exception``.
  """
  known_optimizers = (
    ('sgd', torch.optim.SGD),
    ('rmsprop', torch.optim.RMSprop),
    ('adagrad', torch.optim.Adagrad),
    ('adam', torch.optim.Adam),
    ('adamax', torch.optim.Adamax),
  )
  for candidate, optimizer_cls in known_optimizers:
    if name == candidate:
      return optimizer_cls(parameters, lr=lr, weight_decay=weight_decay)
  raise Exception("Unsupported optimizer: {}".format(name))

In [5]:
def train(model, optimizer, data):
  """Perform a single optimisation step on the training nodes.

  The model is run on ``data.x``, cross-entropy is computed on the rows
  selected by ``data.train_mask``, optional regularisation terms are added,
  and one optimizer step is taken. The function-evaluation counters exposed
  by the model are recorded after the forward pass (``model.fm``) and again
  after the backward pass (``model.bm``), resetting the counter each time.

  Returns the loss of this step as a Python float.
  """
  model.train()
  optimizer.zero_grad()

  predictions = model(data.x)
  criterion = torch.nn.CrossEntropyLoss()
  loss = criterion(predictions[data.train_mask], data.y[data.train_mask])

  # Optional regularisation: each recorded state is averaged and weighted by
  # its coefficient; terms whose coefficient is zero are skipped.
  if model.odeblock.nreg > 0:
    mean_states = tuple(torch.mean(state) for state in model.reg_states)
    penalty = 0
    for mean_state, weight in zip(mean_states, model.regularization_coeffs):
      if weight != 0:
        penalty = penalty + mean_state * weight
    loss = loss + penalty

  # Record the evaluations used by the forward pass ("forward meter").
  model.fm.update(model.getNFE())
  model.resetNFE()

  # Back-propagate and apply the parameter update.
  loss.backward()
  optimizer.step()

  # Record the evaluations used by the backward pass ("backward meter").
  model.bm.update(model.getNFE())
  model.resetNFE()

  return loss.item()

@torch.no_grad()
def test(model, data):
  model.eval()
  logits, accs = model(data.x), []
  for _, mask in data('train_mask', 'val_mask', 'test_mask'):
    pred = logits[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    accs.append(acc)
  return accs

def print_model_params(model):
  """Print the model followed by the name and shape of each trainable parameter.

  Parameters whose ``requires_grad`` flag is false are skipped.
  """
  print(model)
  trainable = (
    (label, tensor)
    for label, tensor in model.named_parameters()
    if tensor.requires_grad
  )
  for label, tensor in trainable:
    print(label)
    print(tensor.data.shape)

In [6]:
def get_dataset_benchmark(opt: dict, data_dir, use_lcc: bool = False) -> InMemoryDataset:
    """Load the Flickr benchmark dataset and make sure it carries split masks.

    Args:
        opt: Option dictionary. The keys read here are ``'dataset'``,
            ``'rewiring'`` and ``'geom_gcn_splits'``.
        data_dir: Directory under which the ``Flickr`` folder is stored.
        use_lcc: When true, the graph is reduced to its largest connected
            component (features, labels and edges are re-indexed) and fresh,
            all-false masks are attached before the split is generated.

    Returns:
        The dataset object, with ``dataset.data`` possibly replaced by the
        reduced / rewired / re-split version.

    Notes:
        * The Flickr dataset is loaded regardless of ``opt['dataset']``; that
          value is only used for the 'ogbn-arxiv' branch and to choose the
          development-set size.
        * A random split (seed 12345) is generated only when the largest
          connected component was taken or no ``train_mask`` exists, and
          ``opt['geom_gcn_splits']`` is false.
    """
    ds = opt['dataset']
    dataset = Flickr(root=os.path.join(data_dir, 'Flickr'), transform=T.ToSparseTensor())

    if use_lcc:
        lcc = get_largest_connected_component(dataset)

        x_new = dataset.data.x[lcc]
        y_new = dataset.data.y[lcc]

        row, col = dataset.data.edge_index.numpy()
        # Set lookups instead of scanning the whole `lcc` array for every edge
        # endpoint; the selected edges are the same.
        lcc_nodes = set(lcc.tolist())
        edges = [[i, j] for i, j in zip(row, col) if i in lcc_nodes and j in lcc_nodes]
        edges = remap_edges(edges, get_node_mapper(lcc))

        data = Data(
          x=x_new,
          edge_index=torch.LongTensor(edges),
          y=y_new,
          train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
          test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
          val_mask=torch.zeros(y_new.size()[0], dtype=torch.bool)
        )
        dataset.data = data

    if opt['rewiring'] is not None:
        # `rewire` is neither imported nor defined in the notebook cells, so
        # this branch used to fail with a NameError.
        # NOTE(review): assumes the project's `data` module (already imported
        # by this notebook) exposes `rewire(data, opt, data_dir)` - confirm.
        from data import rewire
        dataset.data = rewire(dataset.data, opt, data_dir)

    train_mask_exists = hasattr(dataset.data, 'train_mask')

    if ds == 'ogbn-arxiv':
        split_idx = dataset.get_idx_split()
        ei = to_undirected(dataset.data.edge_index)
        data = Data(
            x=dataset.data.x,
            edge_index=ei,
            y=dataset.data.y,
            train_mask=split_idx['train'],
            test_mask=split_idx['test'],
            val_mask=split_idx['valid'])
        dataset.data = data
        train_mask_exists = True

    #todo this currently breaks with heterophilic datasets if you don't pass --geom_gcn_splits
    if (use_lcc or not train_mask_exists) and not opt['geom_gcn_splits']:
        dataset.data = set_train_val_test_split(
          12345,
          dataset.data,
          num_development=5000 if ds == "CoauthorCS" else 1500)

    return dataset


def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
    """Return the set of nodes reachable from ``start``.

    Edges are read from ``dataset.data.edge_index`` and followed from the
    first row (source) to the second row (target). ``start`` itself is always
    part of the result.
    """
    sources, targets = dataset.data.edge_index.numpy()
    seen = set()
    frontier = {start}
    while frontier:
        node = frontier.pop()
        seen.add(node)
        # Targets of every edge leaving `node` that are neither processed nor
        # already waiting in the frontier.
        for neighbour in targets[np.where(sources == node)[0]]:
            if neighbour not in seen and neighbour not in frontier:
                frontier.add(neighbour)
    return seen


def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray:
    """Return the node ids of the largest component as a NumPy array.

    Components are discovered one at a time with ``get_component``, always
    starting from the smallest node id that has not been assigned yet. When
    several components share the maximum size, the first one found wins.
    """
    unassigned = set(range(dataset.data.x.shape[0]))
    components = []
    while unassigned:
        component = get_component(dataset, min(unassigned))
        components.append(component)
        unassigned -= component
    sizes = [len(component) for component in components]
    largest = components[np.argmax(sizes)]
    return np.array(list(largest))


def get_node_mapper(lcc: np.ndarray) -> dict:
    """Map each node id in ``lcc`` to its position (0, 1, 2, ...) in ``lcc``."""
    return {node: position for position, node in enumerate(lcc)}


def remap_edges(edges: list, mapper: dict) -> list:
    """Translate edge endpoints through ``mapper``.

    ``edges`` is a list of ``[source, target]`` pairs; the result is a
    two-element list ``[sources, targets]`` holding the mapped ids.
    """
    sources = [edge[0] for edge in edges]
    targets = [edge[1] for edge in edges]
    mapped_sources = [mapper[node] for node in sources]
    mapped_targets = [mapper[node] for node in targets]
    return [mapped_sources, mapped_targets]


def set_train_val_test_split(
        seed: int,
        data: Data,
        num_development: int = 1500,
        num_per_class: int = 20) -> Data:
    """Attach reproducible boolean train/val/test masks to ``data``.

    ``num_development`` nodes are drawn without replacement as the development
    set; every other node becomes a test node. From the development set,
    ``num_per_class`` nodes per class are drawn for training and the rest of
    the development set is used for validation.

    Args:
        seed: Seed for the NumPy random generator.
        data: Object with a label tensor ``y``; it is modified in place.
        num_development: Size of the development (train + val) set.
        num_per_class: Number of training nodes drawn for each class.

    Returns:
        The same ``data`` object with ``train_mask``, ``val_mask`` and
        ``test_mask`` set.

    Raises:
        ValueError: Propagated from ``RandomState.choice`` when more samples
            are requested than are available.
    """
    rnd_state = np.random.RandomState(seed)
    num_nodes = data.y.shape[0]
    development_idx = rnd_state.choice(num_nodes, num_development, replace=False)

    # Set membership keeps this linear in the number of nodes instead of
    # scanning the development array once per node.
    development_set = set(development_idx.tolist())
    test_idx = [i for i in range(num_nodes) if i not in development_set]

    train_idx = []
    # The generator is re-created from the same seed before the per-class
    # draws; this is kept so the produced splits stay exactly as before.
    rnd_state = np.random.RandomState(seed)
    num_classes = data.y.max() + 1
    development_labels = data.y[development_idx].cpu()  # loop-invariant
    for c in range(num_classes):
        class_idx = development_idx[np.where(development_labels == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

    train_set = set(train_idx)
    val_idx = [i for i in development_idx if i not in train_set]

    def get_mask(idx):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = 1
        return mask

    data.train_mask = get_mask(train_idx)
    data.val_mask = get_mask(val_idx)
    data.test_mask = get_mask(test_idx)

    return data



In [7]:
import gc

def run(opt, run_count):
    """Train and evaluate one model, then pickle the per-epoch metrics.

    Args:
        opt: Option dictionary. Besides whatever the dataset loaders and the
            model constructors consume, this function reads ``'dataset'``,
            ``'rewire_KNN'``, ``'fa_layer'``, ``'optimizer'``, ``'lr'``,
            ``'decay'``, ``'epoch'``, ``'method'`` and ``'step_size'``.
        run_count: Index of this repetition; only used in the results filename.

    Returns:
        ``(train_acc, best_val_acc, test_acc)`` where ``train_acc`` and
        ``test_acc`` are the values of the last epoch and ``best_val_acc`` is
        the highest validation accuracy observed.

    Side effects:
        Prints the options, the model and one log line per epoch, and writes
        ``../results/<dataset>_<method>_stepsize_<step_size>_run_<run_count>.pickle``.
    """
    # Load dataset and create model
    if opt['dataset'] == 'Flickr':
        dataset = get_dataset_benchmark(opt, '../data', False)
    else:
        dataset = get_dataset(opt, '../data', False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if opt['rewire_KNN'] or opt['fa_layer']:
        model = GNN_KNN(opt, dataset, device).to(device)
    else:
        model = GNN(opt, dataset, device).to(device)
    data = dataset.data.to(device)
    print(opt)

    # Todo for some reason the submodule parameters inside the attention module don't show up when running on GPU.
    parameters = [p for p in model.parameters() if p.requires_grad]
    print_model_params(model)

    # Per-epoch metrics plus a running summary of the best epoch so far.
    results = {
        'time': [],
        'loss': [],
        'forward_nfe': [],
        'backward_nfe': [],
        'train_acc': [],
        'test_acc': [],
        'val_acc': [],
        'best_epoch': 0,
        'best_train_acc': 0.,
        'best_val_acc': 0.,
        'best_test_acc': 0.,
    }

    optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay'])
    # All "best" trackers start at 0 so they are defined even if validation
    # accuracy never improves on 0 or the loop body never executes.
    best_val_acc = best_test_acc = best_train_acc = best_epoch = 0
    train_acc = test_acc = 0
    overall_time = time.time()
    # NOTE: the range starts at 1, so opt['epoch'] - 1 epochs are executed.
    for epoch in range(1, opt['epoch']):
        start_time = time.time()

        loss = train(model, optimizer, data)
        train_acc, val_acc, test_acc = test(model, data)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_train_acc = train_acc
            best_epoch = epoch

        results['time'].append(time.time() - start_time)
        results['loss'].append(loss)
        results['forward_nfe'].append(model.fm.sum)
        results['backward_nfe'].append(model.bm.sum)
        results['train_acc'].append(train_acc)
        results['test_acc'].append(test_acc)
        results['val_acc'].append(val_acc)
        results['best_epoch'] = best_epoch
        results['best_train_acc'] = best_train_acc
        results['best_val_acc'] = best_val_acc
        results['best_test_acc'] = best_test_acc

        log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(epoch, results['time'][-1], results['loss'][-1], results['forward_nfe'][-1], results['backward_nfe'][-1], results['train_acc'][-1], results['val_acc'][-1], results['test_acc'][-1]))
        torch.cuda.empty_cache()
        gc.collect()

    print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc, best_test_acc, best_epoch))

    results['all_epochs_time'] = time.time() - overall_time

    # Persist the metrics; the directory is created on demand and the file
    # handle is closed deterministically by the context manager.
    results_path = f"../results/{opt['dataset']}_{opt['method']}_stepsize_{opt['step_size']}_run_{run_count}.pickle"
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, "wb") as results_file:
        pickle.dump(results, results_file)

    # NOTE(review): train_acc / test_acc are the last-epoch values, not the
    # ones recorded at the best validation epoch - confirm this is intended.
    return train_acc, best_val_acc, test_acc


In [8]:
def _str2bool(value):
    """Convert a command-line token such as 'true' / 'False' / '1' to a bool.

    Raises argparse.ArgumentTypeError for anything that is not recognisably
    boolean, so argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


# Command-line interface for the experiment. Every option ends up as a key of
# the `opt` dictionary once the arguments are parsed.
parser = argparse.ArgumentParser()
parser.add_argument('--use_cora_defaults', action='store_true',
                  help='Whether to run with best params for cora. Overrides the choice of dataset')
parser.add_argument('--dataset', type=str, default='Cora',
                  help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS')
parser.add_argument('--data_norm', type=str, default='rw',
                  help='rw for random walk, gcn for symmetric gcn norm')
parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.')
parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.')
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization')
parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.')
parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.')
parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.')
parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.')
parser.add_argument('--augment', action='store_true',
                  help='double the length of the feature vector by appending zeros to stabilist ODE learning')
parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha')
parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true', help='apply sigmoid before multiplying by alpha')
parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta')
parser.add_argument('--block', type=str, default='constant', help='constant, mixed, attention, SDE')
parser.add_argument('--function', type=str, default='laplacian', help='laplacian, transformer, dorsey, GAT, SDE')
parser.add_argument('--geom_gcn_splits', dest='geom_gcn_splits', action='store_true',
                      help='use the 10 fixed splits from '
                           'https://arxiv.org/abs/2002.05287')
# ODE args
parser.add_argument('--method', type=str, default='dopri5',
                  help="set the numerical solver: dopri5, euler, rk4, midpoint")
parser.add_argument('--step_size', type=float, default=1, help='fixed step size when using fixed step solvers e.g. rk4')
parser.add_argument(
    "--adjoint_method", type=str, default="adaptive_heun",
    help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint"
)
parser.add_argument('--adjoint_step_size', type=float, default=1, help='fixed step size when using fixed step adjoint solvers e.g. rk4')
# Without a converter, "--adjoint False" would be stored as the non-empty
# (hence truthy) string 'False'; _str2bool turns the token into a real bool.
parser.add_argument('--adjoint', type=_str2bool, default=False, help='use the adjoint ODE method to reduce memory footprint')
parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol')
parser.add_argument("--tol_scale_adjoint", type=float, default=1.0,
                  help="multiplier for adjoint_atol and adjoint_rtol")
parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run')
parser.add_argument('--add_source', dest='add_source', action='store_true',
                  help='If try get rid of alpha param and the beta*x0 source term')
# SDE args
parser.add_argument('--dt_min', type=float, default=1e-5, help='minimum timestep for the SDE solver')
parser.add_argument('--dt', type=float, default=1e-3, help='fixed step size')
parser.add_argument('--adaptive', dest='adaptive', action='store_true', help='use adaptive step sizes')
# Attention args
parser.add_argument('--leaky_relu_slope', type=float, default=0.2,
                  help='slope of the negative part of the leaky relu used in attention')
parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights')
parser.add_argument('--heads', type=int, default=4, help='number of attention heads')
parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols')
parser.add_argument('--attention_dim', type=int, default=64,
                  help='the size to project x to before calculating att scores')
parser.add_argument('--mix_features', dest='mix_features', action='store_true',
                  help='apply a feature transformation xW to the ODE')
parser.add_argument("--max_nfe", type=int, default=1000, help="Maximum number of function evaluations allowed.")
parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true', help="multiply attention scores by edge weights before softmax")
# regularisation args
parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2")
parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2")

parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2")
parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2")

# rewiring args
parser.add_argument('--rewiring', type=str, default=None, help="two_hop, gdc")
parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff")
parser.add_argument('--gdc_sparsification', type=str, default='topk', help="threshold, topk")
parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk")
parser.add_argument('--gdc_threshold', type=float, default=0.0001, help="obove this edge weight, keep edges when using threshold")
parser.add_argument('--gdc_avg_degree', type=int, default=64,
                  help="if gdc_threshold is not given can be calculated by specifying avg degree")
parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability")
parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernal diffusion for")

# NOTE: store_false means opt['not_lcc'] is True unless --not_lcc is passed.
parser.add_argument("--not_lcc", action="store_false", help="don't use the largest connected component")
parser.add_argument('--att_samp_pct', type=float, default=1,
                  help="float in [0,1). The percentage of edges to retain based on attention scores")
parser.add_argument('--use_flux', dest='use_flux', action='store_true',
                  help='incorporate the feature grad in attention based edge dropout')
parser.add_argument("--exact", action="store_true",
                  help="for small datasets can do exact diffusion. If dataset is too big for matrix inversion then you can't")
parser.add_argument('--M_nodes', type=int, default=64, help="new number of nodes to add")
parser.add_argument('--new_edges', type=str, default="random", help="random, random_walk, k_hop")
parser.add_argument('--sparsify', type=str, default="S_hat", help="S_hat, recalc_att")
parser.add_argument('--threshold_type', type=str, default="topk_adj", help="topk_adj, addD_rvR")
parser.add_argument('--rw_addD', type=float, default=0.02, help="percentage of new edges to add")
parser.add_argument('--rw_rmvR', type=float, default=0.02, help="percentage of edges to remove")
parser.add_argument('--rewire_KNN', action='store_true', help='perform KNN rewiring every few epochs')
parser.add_argument('--rewire_KNN_T', type=str, default="T0", help="T0, TN")
parser.add_argument('--rewire_KNN_epoch', type=int, default=5, help="frequency of epochs to rewire")
parser.add_argument('--rewire_KNN_k', type=int, default=64, help="target degree for KNN rewire")
parser.add_argument('--rewire_KNN_sym', action='store_true', help='make KNN symmetric')
parser.add_argument('--KNN_online', action='store_true', help='perform rewiring online')
parser.add_argument('--KNN_online_reps', type=int, default=4, help="how many online KNN its")
parser.add_argument('--KNN_space', type=str, default="pos_distance", help="Z,P,QKZ,QKp")

# Stefan's experiment args
parser.add_argument('--count_runs', type=int, default=10,
                  help="number of runs to average results over per parameter settings for each experiment")

# beltrami
parser.add_argument('--beltrami', action='store_true', help='perform diffusion beltrami style')
parser.add_argument('--fa_layer', action='store_true', help='add a bottleneck paper style layer with more edges')
parser.add_argument('--pos_enc_type', type=str, default="DW64",
                  help='positional encoder either GDC, DW64, DW128, DW256')
parser.add_argument('--pos_enc_orientation', type=str, default="row", help="row, col")
parser.add_argument('--feat_hidden_dim', type=int, default=64, help="dimension of features in beltrami")
parser.add_argument('--pos_enc_hidden_dim', type=int, default=32, help="dimension of position in beltrami")
parser.add_argument('--edge_sampling', action='store_true', help='perform edge sampling rewiring')
parser.add_argument('--edge_sampling_T', type=str, default="T0", help="T0, TN")
parser.add_argument('--edge_sampling_epoch', type=int, default=5, help="frequency of epochs to rewire")
parser.add_argument('--edge_sampling_add', type=float, default=0.64, help="percentage of new edges to add")
parser.add_argument('--edge_sampling_add_type', type=str, default="importance",
                  help="random, ,anchored, importance, degree")
parser.add_argument('--edge_sampling_rmv', type=float, default=0.32, help="percentage of edges to remove")
parser.add_argument('--edge_sampling_sym', action='store_true', help='make KNN symmetric')
parser.add_argument('--edge_sampling_online', action='store_true', help='perform rewiring online')
parser.add_argument('--edge_sampling_online_reps', type=int, default=4, help="how many online KNN its")
parser.add_argument('--edge_sampling_space', type=str, default="attention",
                  help="attention,pos_distance, z_distance, pos_distance_QK, z_distance_QK")
parser.add_argument('--symmetric_attention', action='store_true',
                  help='maks the attention symmetric for rewring in QK space')

parser.add_argument('--fa_layer_edge_sampling_rmv', type=float, default=0.8, help="percentage of edges to remove")
parser.add_argument('--gpu', type=int, default=0, help="GPU to run on (default 0)")
parser.add_argument('--pos_enc_csv', action='store_true', help="Generate pos encoding as a sparse CSV")

parser.add_argument('--pos_dist_quantile', type=float, default=0.001, help="percentage of N**2 edges to keep")

#added
parser.add_argument('--use_mlp', dest='use_mlp', action='store_true',
                  help='Add a fully connected layer to the encoder.')
parser.add_argument('--use_labels', dest='use_labels', action='store_true', help='Also diffuse labels')
parser.add_argument('--fc_out', dest='fc_out', action='store_true',
                  help='Add a fully connected layer to the decoder.')
parser.add_argument("--batch_norm", dest='batch_norm', action='store_true', help='search over reg params')

# Parse the notebook-supplied argument list and work with a plain dict from here on.
args = parser.parse_args(customArgs)
opt = vars(args)

# Dataset selector for this notebook run: 'Cora', 'Flickr' or 'Computer'.
opt['dataset'] = 'Computer'

# Overlay the dataset-specific hyper-parameters (each helper also rewrites
# opt['dataset'] to the name it uses for that dataset).
if opt['dataset'] == 'Cora':
    opt = get_cora_opt(opt)
elif opt['dataset'] == 'Computer':
    opt = get_computers_opt(opt)
elif opt['dataset'] == 'Flickr':
    opt = get_flickr_opt(opt)

# Solver configuration shared by all runs of the sweep.
opt['adjoint'] = True
#opt['method'] = 'explicit_adams'
opt['method'] = 'implicit_adams'
#opt['method'] = 'dopri5'
opt['adjoint_method'] = opt['method']
opt['max_iters'] = 100
opt['step_size'] = opt['dt_min'] = 0.01
opt['tol_scale'] = 100.0
opt['tol_scale_adjoint'] = 100.0
#added
opt['max_nfe'] = 100000
if opt['dataset'] == 'Flickr':
    # This used to be `opt['rewiring'] == 'gdc'`, a comparison whose result
    # was discarded, so GDC rewiring was never actually switched on.
    opt['rewiring'] = 'gdc'

# Run combination of experiments: count_runs repetitions for every step size.
for stepsize in [1.0,0.5,0.1,0.01]: #[0.5, 0.25, 0.1, 0.01]: # 2.0, 1.0
    print(f'*** Doing stepsize {stepsize} ***')
    for idx in range(opt['count_runs']):
        print(f'*** Doing run {idx} ***')
        # NOTE: I think setting dt_min may not be necessary, doing it just to be safe!
        opt['step_size'] = opt['dt_min'] = stepsize
        run(opt, idx)

*** Doing stepsize 1.0 ***
*** Doing run 0 ***
{'use_cora_defaults': False, 'dataset': 'Computers', 'data_norm': 'rw', 'hidden_dim': 16, 'input_dropout': 0.5, 'dropout': 0, 'optimizer': 'adam', 'lr': 0.01, 'decay': 0.0005, 'self_loop_weight': 0.555, 'epoch': 100, 'alpha': 0.918, 'time': 12.1, 'augment': False, 'alpha_dim': 'sc', 'no_alpha_sigmoid': False, 'beta_dim': 'sc', 'block': 'constant', 'function': 'laplacian', 'geom_gcn_splits': False, 'method': 'implicit_adams', 'step_size': 1.0, 'adjoint_method': 'implicit_adams', 'adjoint_step_size': 1, 'adjoint': True, 'tol_scale': 100.0, 'tol_scale_adjoint': 100.0, 'ode_blocks': 1, 'add_source': False, 'dt_min': 1.0, 'dt': 0.001, 'adaptive': False, 'leaky_relu_slope': 0.2, 'attention_dropout': 0, 'heads': 4, 'attention_norm_idx': 0, 'attention_dim': 64, 'mix_features': False, 'max_nfe': 100000, 'reweight_attention': False, 'jacobian_norm2': None, 'total_deriv': None, 'kinetic_energy': None, 'directional_penalty': None, 'rewiring': None, 'g



Epoch: 001, Runtime 0.828876, Loss 2.311668, forward nfe 119, backward nfe 294, Train: 0.0950, Val: 0.0262, Test: 0.0386
Epoch: 002, Runtime 1.912070, Loss 2.355083, forward nfe 2163, backward nfe 1318, Train: 0.1050, Val: 0.1346, Test: 0.1347
Epoch: 003, Runtime 1.914775, Loss 2.346608, forward nfe 4207, backward nfe 2342, Train: 0.1850, Val: 0.1069, Test: 0.1167
Epoch: 004, Runtime 1.914927, Loss 2.244116, forward nfe 6255, backward nfe 3366, Train: 0.1100, Val: 0.0808, Test: 0.0850
Epoch: 005, Runtime 1.918711, Loss 2.242403, forward nfe 8302, backward nfe 4392, Train: 0.1850, Val: 0.1185, Test: 0.1318
Epoch: 006, Runtime 1.942705, Loss 2.213710, forward nfe 10357, backward nfe 5426, Train: 0.2200, Val: 0.1754, Test: 0.1916
Epoch: 007, Runtime 1.958310, Loss 2.170889, forward nfe 12428, backward nfe 6450, Train: 0.1700, Val: 0.1677, Test: 0.1644
Epoch: 008, Runtime 1.994545, Loss 2.135077, forward nfe 14616, backward nfe 7483, Train: 0.2350, Val: 0.2254, Test: 0.2271
Epoch: 009, Run