In [2]:
import argparse
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_scatter import scatter

from logger import Logger, SimpleLogger
from dataset import load_nc_dataset
from data_utils import normalize, gen_normalized_adjs, evaluate, evaluate_whole_graph, eval_acc, eval_rocauc, eval_f1, to_sparse_tensor, load_fixed_splits
from parse import parse_method_base, parse_method_ours, parse_method_gstopr, parser_add_main_args

# NOTE: for consistent data splits, see data_utils.rand_train_test_idx
def fix_seed(seed):
    """Seed every RNG this pipeline may draw from so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    all CUDA devices), and forces deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA (silently ignored); covers every GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different algorithms between runs.
    torch.backends.cudnn.benchmark = False


fix_seed(0)

### Parse args ###
parser = argparse.ArgumentParser(description='General Training Pipeline')
parser_add_main_args(parser)
# GSTOPR config
parser.add_argument('--r', default=0.2, type=float, help='selected ratio')

parser.add_argument('--noise', default=1., type=float, help='gumbel noise')
parser.add_argument('--temp', default=1., type=float, help='sinkhorn temperature')
parser.add_argument('--max_iter', default=10, type=int, help='sinkhorn max iter')
# args=[] makes argparse ignore the Jupyter kernel's own argv and use the defaults.
args = parser.parse_args(args=[])
print(args)

# Honour --cpu; otherwise use the requested GPU whenever CUDA is available.
use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)
device = torch.device("cuda:" + str(args.device)) if use_cuda else torch.device("cpu")


Namespace(K=3, T=1, beta=1.0, cached=False, cpu=False, data_dir='../data', dataset='elliptic', device=0, directed=False, display_step=1, dropout=0.0, epochs=200, gat_heads=2, gnn='gcn', gpr_alpha=0.1, hidden_channels=32, lp_alpha=0.1, lr=0.01, lr_a=0.005, max_iter=10, method='erm', no_bn=False, noise=1.0, num_layers=2, num_sample=5, r=0.2, rocauc=False, runs=5, sub_dataset='', temp=1.0, weight_decay=0.001)


In [3]:

def get_dataset(dataset, sub_dataset=None, data_dir=None):
    """Load one sub-graph of a node-classification dataset and attach size metadata.

    Args:
        dataset: dataset name; only ``'elliptic'`` is supported.
        sub_dataset: sub-graph index forwarded to ``load_nc_dataset``.
        data_dir: root data directory; when None, falls back to the
            notebook-level ``args.data_dir``.

    Returns:
        The loaded dataset object with 2-D ``label`` and three added
        attributes: ``n`` (node count), ``c`` (class count) and
        ``d`` (node-feature dimension).

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    if dataset != 'elliptic':
        raise ValueError('Invalid dataname')
    if data_dir is None:
        data_dir = args.data_dir  # notebook-global parsed config

    ds = load_nc_dataset(data_dir, 'elliptic', sub_dataset)

    # Normalise labels to 2-D (one column per target) so label.shape[1] below is valid.
    if len(ds.label.shape) == 1:
        ds.label = ds.label.unsqueeze(1)

    ds.n = ds.graph['num_nodes']
    # Class count: max label + 1, or the label width if labels have more columns.
    ds.c = max(ds.label.max().item() + 1, ds.label.shape[1])
    ds.d = ds.graph['node_feat'].shape[1]

    return ds

if args.dataset != 'elliptic':
    raise ValueError('Invalid dataname')

# Sub-graph indices per split: 6-10 train, 11-15 validation, 16-48 test.
tr_subs, val_subs, te_subs = list(range(6, 11)), list(range(11, 16)), list(range(16, 49))
datasets_tr = [get_dataset(dataset='elliptic', sub_dataset=sub) for sub in tr_subs]
datasets_val = [get_dataset(dataset='elliptic', sub_dataset=sub) for sub in val_subs]
datasets_te = [get_dataset(dataset='elliptic', sub_dataset=sub) for sub in te_subs]

# Summaries below report only the first train/val graph but every test graph.
dataset_tr = datasets_tr[0]
dataset_val = datasets_val[0]
print(f"Train num nodes {dataset_tr.n} | num classes {dataset_tr.c} | num node feats {dataset_tr.d}")
print(f"Val num nodes {dataset_val.n} | num classes {dataset_val.c} | num node feats {dataset_val.d}")
for i, dataset_te in enumerate(datasets_te):
    print(f"Test {i} num nodes {dataset_te.n} | num classes {dataset_te.c} | num node feats {dataset_te.d}")


Train num nodes 6048 | num classes 2 | num node feats 165
Val num nodes 2047 | num classes 2 | num node feats 165
Test 0 num nodes 3385 | num classes 2 | num node feats 165
Test 1 num nodes 1976 | num classes 2 | num node feats 165
Test 2 num nodes 3506 | num classes 2 | num node feats 165
Test 3 num nodes 4291 | num classes 2 | num node feats 165
Test 4 num nodes 3537 | num classes 2 | num node feats 165
Test 5 num nodes 5894 | num classes 2 | num node feats 165
Test 6 num nodes 4165 | num classes 2 | num node feats 165
Test 7 num nodes 4592 | num classes 2 | num node feats 165
Test 8 num nodes 2314 | num classes 2 | num node feats 165
Test 9 num nodes 2523 | num classes 2 | num node feats 165
Test 10 num nodes 1089 | num classes 2 | num node feats 165
Test 11 num nodes 1653 | num classes 2 | num node feats 165
Test 12 num nodes 4275 | num classes 2 | num node feats 165
Test 13 num nodes 2483 | num classes 2 | num node feats 165
Test 14 num nodes 2816 | num classes 2 | num node feats 

  edge_index = torch.tensor(A.nonzero(), dtype=torch.long)


In [5]:

### Load method ###
if args.method == 'erm':
    model = parse_method_base(args, datasets_tr, device)
elif args.method == 'gstopr':
    model = parse_method_gstopr(args, datasets_tr, device)
else:
    # Every other method name (the training loop below handles 'eerm') uses this builder.
    model = parse_method_ours(args, datasets_tr, device)

# Loss: BCE on logits for these datasets (or when ROC-AUC is requested);
# otherwise NLLLoss, which expects log-probabilities from the model.
if args.rocauc or args.dataset in ('twitch-e', 'fb100', 'elliptic'):
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.NLLLoss()

# Metric: ROC-AUC when explicitly requested via --rocauc, F1 for the BCE datasets,
# plain accuracy otherwise.
if args.rocauc:
    eval_func = eval_rocauc
elif args.dataset in ('twitch-e', 'fb100', 'elliptic'):
    eval_func = eval_f1
else:
    eval_func = eval_acc

logger = Logger(args.runs, args)

model.train()
print('MODEL:', model)
print('DATASET:', args.dataset)
print('DATASET:', args.dataset)

1
3
MODEL: Model(
  (gnn): GCN(
    (convs): ModuleList(
      (0): GCNConv(165, 32)
      (1): GCNConv(32, 2)
    )
    (bns): ModuleList(
      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (gl): ModuleList(
    (0): Graph_Editer()
    (1): Graph_Editer()
    (2): Graph_Editer()
    (3): Graph_Editer()
    (4): Graph_Editer()
  )
)
DATASET: elliptic


In [None]:


### Training loop ###
# Each run re-initialises the model, builds the optimizer(s) for the chosen
# method, trains for args.epochs epochs and records train/valid/test metrics.
for run in range(args.runs):
    model.reset_parameters()
    if args.method in ['erm', 'gstopr']:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.method == 'eerm':
        # Separate optimizers: one for the GNN backbone, one for the graph editers (model.gl).
        optimizer_gnn = torch.optim.AdamW(model.gnn.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        optimizer_aug = torch.optim.AdamW(model.gl.parameters(), lr=args.lr_a)

    for epoch in range(args.epochs):
        model.train()
        if args.method in ['erm', 'gstopr']:
            optimizer.zero_grad()
            loss = model(datasets_tr, criterion)
            loss.backward()
            optimizer.step()
        elif args.method == 'eerm':
            for m in range(args.T):
                Var, Mean, Log_p = model(datasets_tr, criterion)
                outer_loss = Var + args.beta * Mean
                # Editers are trained with reward * Log_p (reward = detached Var);
                # presumably Log_p is the log-probability of the sampled edits.
                reward = Var.detach()
                inner_loss = - reward * Log_p
                if m == 0:
                    # The GNN is stepped only on the first of the T inner iterations.
                    optimizer_gnn.zero_grad()
                    outer_loss.backward()
                    optimizer_gnn.step()
                optimizer_aug.zero_grad()
                inner_loss.backward()
                optimizer_aug.step()

        accs, test_outs = evaluate_whole_graph(args, model, datasets_tr, datasets_val, datasets_te, eval_func)
        logger.add_result(run, accs)

        # Only the methods trained above report progress; accs = [train, valid, *tests].
        if epoch % args.display_step == 0 and args.method in ('erm', 'gstopr', 'eerm'):
            if args.method == 'eerm':
                loss_info = f'Mean Loss: {Mean:.4f}, Var Loss: {Var:.4f}, '
            else:
                loss_info = f'Loss: {loss:.4f}, '
            print(f'Epoch: {epoch:02d}, '
                  f'{loss_info}'
                  f'Train: {100 * accs[0]:.2f}%, '
                  f'Valid: {100 * accs[1]:.2f}%, ')
            print(''.join(f'Test: {100 * test_acc:.2f}% ' for test_acc in accs[2:]))

    logger.print_statistics(run)

### Save results ###
results = logger.print_statistics()
results_dir = './results'
# open() and torch.save() both fail if the output folder does not exist yet.
os.makedirs(results_dir, exist_ok=True)
filename = f'{results_dir}/{args.dataset}.csv'
print(f"Saving results to {filename}")

# Row key: method, GSTOPR hyper-parameters (gstopr only), backbone GNN.
log = f"{args.method}," + (f"r={args.r},g={args.noise}," if args.method == 'gstopr' else '') + f"{args.gnn},"
# The raw results tensor is saved alongside, named after the row key.
torch.save(results, f"{results_dir}/" + log.replace(',', '-') + '.pt')
# One "mean ± std" cell per results column.
log += ''.join(f"{results[:, i].mean():.3f} ± {results[:, i].std():.3f},"
               for i in range(results.shape[1]))

# Explicit UTF-8: the row contains the non-ASCII '±'.
with open(filename, 'a', encoding='utf-8') as write_obj:
    write_obj.write(log + "\n")
