In [1]:
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


from datasets import get_planetoid_dataset
from train_eval import run, evaluate

In [2]:
def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    argparse's ``type=bool`` applies ``bool()`` to the raw string, so every
    non-empty value -- including 'False' -- would be read as True.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, matching what bool('') returned before.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Dataset selection; these three are passed to get_planetoid_dataset().
parser.add_argument('--dataset', type=str, default="cora")
parser.add_argument('--split', type=str, default='public')
parser.add_argument('--normalize_features', type=_str2bool, default=True)

# Training schedule and model width; forwarded to run() or read by the model classes.
parser.add_argument('--runs', type=int, default=10)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--early_stopping', type=int, default=0)
parser.add_argument('--hidden', type=int, default=16)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--logger', type=str, default=None)

# Optimizer / preconditioner settings; forwarded unchanged to run(), which
# defines their exact meaning.
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--preconditioner', type=str, default=None)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--eps', type=float, default=0.01)
parser.add_argument('--update_freq', type=int, default=50)
parser.add_argument('--gamma', type=float, default=None)
parser.add_argument('--alpha', type=float, default=None)

# Name of the hyperparameter to sweep ('eps', 'update_freq' or 'gamma');
# any other value results in a single run().
parser.add_argument('--hyperparam', type=str, default=None)

# parse_known_args() does not fail on unrecognized argv entries; they are
# returned in `unknown` instead.
args, unknown = parser.parse_known_args()

In [3]:
# Load the dataset selected on the command line via the project helper.
dataset = get_planetoid_dataset(
    name=args.dataset,
    normalize_features=args.normalize_features,
    split=args.split,
)

In [4]:
class Net_orig(torch.nn.Module):
    """Two-layer GCN: GCNConv -> ReLU -> dropout -> GCNConv -> log-softmax.

    Layer widths come from ``dataset.num_features``, the module-level
    ``args.hidden`` and ``dataset.num_classes``; the dropout probability is
    ``args.dropout``.
    """

    def __init__(self, dataset):
        # Zero-argument super() cannot go stale: the previous explicit form
        # named a class (Net2) that does not exist, which raised NameError.
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialize the weights of both convolution layers."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        """Return log-probabilities computed from ``data.x`` and ``data.edge_index``."""
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=args.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

class CRD(torch.nn.Module):
    """Graph convolution followed by ReLU and dropout.

    :param d_in: input feature size passed to ``GCNConv``
    :param d_out: output feature size passed to ``GCNConv``
    :param p: dropout probability applied after the activation
    """

    def __init__(self, d_in, d_out, p):
        super().__init__()
        self.p = p
        # cached=True: per the PyG documentation, GCNConv then keeps its
        # normalized edge weights after the first call instead of recomputing.
        self.conv = GCNConv(d_in, d_out, cached=True)

    def reset_parameters(self):
        """Re-initialize the convolution weights."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        """Apply conv -> ReLU -> dropout; ``mask`` is accepted but not used here."""
        convolved = self.conv(x, edge_index)
        activated = F.relu(convolved)
        # Dropout is a no-op unless the module is in training mode.
        return F.dropout(activated, p=self.p, training=self.training)

class CLS(torch.nn.Module):
    """Graph convolution followed by a log-softmax over dimension 1.

    :param d_in: input feature size passed to ``GCNConv``
    :param d_out: output feature size passed to ``GCNConv``
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # cached=True: per the PyG documentation, GCNConv then keeps its
        # normalized edge weights after the first call instead of recomputing.
        self.conv = GCNConv(d_in, d_out, cached=True)

    def reset_parameters(self):
        """Re-initialize the convolution weights."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        """Return log-probabilities; ``mask`` is accepted but not used here."""
        scores = self.conv(x, edge_index)
        return F.log_softmax(scores, dim=1)
    
class Net(torch.nn.Module):
    """A CRD block feeding a CLS block.

    Sizes are taken from ``dataset.num_features``, the module-level
    ``args.hidden`` and ``dataset.num_classes``; ``args.dropout`` is the
    dropout probability given to the CRD block.
    """

    def __init__(self, dataset):
        super().__init__()
        hidden_dim = args.hidden
        self.crd = CRD(dataset.num_features, hidden_dim, args.dropout)
        self.cls = CLS(hidden_dim, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialize both sub-blocks."""
        for block in (self.crd, self.cls):
            block.reset_parameters()

    def forward(self, data):
        """Run ``data.x`` through crd then cls, using ``data.edge_index``.

        ``data.train_mask`` is passed to both sub-blocks as their ``mask`` argument.
        """
        features, edge_index = data.x, data.edge_index
        hidden = self.crd(features, edge_index, data.train_mask)
        return self.cls(hidden, edge_index, data.train_mask)



In [5]:
# Keyword arguments for run(): the dataset, one Net instance, and the
# command-line settings. Kept in a dict so the sweep cell can override
# a single entry between calls.
kwargs = dict(
    dataset=dataset,
    model=Net(dataset),
    str_optimizer=args.optimizer,
    str_preconditioner=args.preconditioner,
    runs=args.runs,
    epochs=args.epochs,
    lr=args.lr,
    weight_decay=args.weight_decay,
    early_stopping=args.early_stopping,
    logger=args.logger,
    momentum=args.momentum,
    eps=args.eps,
    update_freq=args.update_freq,
    gamma=args.gamma,
    alpha=args.alpha,
    hyperparam=args.hyperparam,
)

In [6]:
# Choose the grid of values for the hyperparameter named by --hyperparam.
# Any other value (including None) means a single run with kwargs as-is.
if args.hyperparam == 'eps':
    sweep_grid = np.logspace(-3, 0, 10, endpoint=True)
elif args.hyperparam == 'update_freq':
    sweep_grid = [4, 8, 16, 32, 64, 128]
elif args.hyperparam == 'gamma':
    sweep_grid = np.linspace(1., 10., 10, endpoint=True)
else:
    sweep_grid = None

if sweep_grid is None:
    run(**kwargs)
else:
    for param in sweep_grid:
        print(f"{args.hyperparam}: {param}")
        # Override just the swept entry; everything else in kwargs is reused.
        kwargs[args.hyperparam] = param
        run(**kwargs)

Val Loss: 0.8192, Test Accuracy: 80.87 ± 0.84, Duration: 8.379 



In [7]:
# Evaluates a newly constructed Net -- not the instance stored in
# kwargs['model'] that was handed to run().
fresh_model = Net(dataset)
evaluate(fresh_model, dataset[0])

{'train loss': 1.9452314376831055,
 'train acc': 0.15714285714285714,
 'val loss': 1.9449515342712402,
 'val acc': 0.252,
 'test loss': 1.9448139667510986,
 'test acc': 0.261}

In [8]:
# Build the model and fetch the graph before starting the clock, so the
# measurement covers only the evaluate() call and not setup work.
timed_model = Net(dataset)
timed_data = dataset[0]

# perf_counter() is the monotonic, high-resolution clock intended for
# measuring short durations; time.time() can jump with system clock changes.
start = time.perf_counter()
evaluate(timed_model, timed_data)
end = time.perf_counter()
t_inference = end - start
print(f"Time inference:{t_inference}  ")

Time inference:0.1298234462738037  


In [11]:
def get_num_parameters(model, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param model: object whose ``parameters()`` yields the tensors to count
    :param count_nonzero_only: only count nonzero weights
    :return: the element count as a plain Python int (0 when there are no parameters)
    """
    if count_nonzero_only:
        # count_nonzero() yields a 0-dim tensor; convert each one so the
        # total is a Python int, as the return annotation promises.
        return sum(int(param.count_nonzero()) for param in model.parameters())
    return sum(param.numel() for param in model.parameters())


def get_model_size(model, data_width=32, count_nonzero_only=False) -> int:
    """
    size of the model in bits: counted elements times bits per element
    :param model: model handed on to get_num_parameters
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    n_elements = get_num_parameters(model, count_nonzero_only)
    return n_elements * data_width

# Unit sizes expressed in bits (one byte = 8 bits); each step up is a factor of 1024.
Byte = 8
KiB = Byte * 2 ** 10
MiB = Byte * 2 ** 20
GiB = Byte * 2 ** 30



In [12]:
# Size in bits of a newly constructed Net at the default data_width of 32.
sized_model = Net(dataset)
get_model_size(sized_model)

738016