In [105]:
import time
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from utils import load_citation, sgc_precompute, set_seed
from models import get_model
from metrics import accuracy
import pickle as pkl
from args import get_citation_args
from time import perf_counter

# Arguments
args = get_citation_args()

if args.tuned:
    if args.model == "SGC":
        with open("{}-tuning/{}.txt".format(args.model, args.dataset), 'rb') as f:
            args.weight_decay = pkl.load(f)['weight_decay']
            print("using tuned weight decay: {}".format(args.weight_decay))
    else:
        raise NotImplemented

# setting random seeds
set_seed(args.seed, args.cuda)

adj, features, labels, idx_train, idx_val, idx_test = load_citation(args.dataset, args.normalization, args.cuda)



In [106]:
degree = 6
features_bak = features.clone()

In [107]:
t = perf_counter()
feature_list = [features]
for i in range(degree):
    features = torch.spmm(adj, features)
    feature_list.append(features)
precompute_time = perf_counter()-t

In [108]:
len(feature_list)

7

In [89]:
args.degree

7

In [90]:
features.shape

torch.Size([2708, 1433])

In [91]:
features = torch.cat(feature_list,dim=1)

In [92]:
X_feature = features.cpu().data.numpy()

In [93]:
from sklearn.ensemble import RandomForestClassifier

In [94]:
clf = RandomForestClassifier(n_estimators=500, n_jobs=-1)

In [95]:
hop_index = 2

In [100]:
for hop_index in range(1,9):
    clf.fit(X_feature[idx_train][:,-1433*hop_index:],labels[idx_train])
    test_res = clf.predict(X_feature[idx_test][:,-1433*hop_index: ])
    print(hop_index, metrics.accuracy_score(labels[idx_test],test_res))

1 0.752
2 0.748
3 0.754
4 0.769
5 0.765
6 0.774
7 0.764


In [97]:
X_feature.shape

(2708, 11464)

In [99]:
8*1433

11464

In [36]:
from sklearn import metrics

In [37]:
metrics.accuracy_score(labels[idx_test],test_res)

0.791

In [20]:
X_feature.shape

(2708, 4299)

In [None]:
model = get_model(args.model, features.size(1), labels.max().item()+1, args.hidden, args.dropout, args.cuda)