# In [1]:
from preload import load_data
from model import *
# from edge import *
from utils import *
from sampler import *
from sklearn import metrics
from sklearn.metrics import accuracy_score
import os

# ---------------------------------------------------------------------------
# Experiment configuration.
# ---------------------------------------------------------------------------
root_path = os.getcwd() + '/'  # current working directory with a trailing slash

dataset = 'reddit'
batch_size = 256
# NOTE(review): node_iters / node_early_stop / node_least_iter are not read in
# this section; kept in case other cells depend on them.
node_iters = 20
emb_size = 128               # embedding size passed to both Encoder layers
node_early_stop = 20
node_least_iter = 100
num_neigh0 = 5               # neighbours sampled by the outer layer (enc2.num_sample)
num_neigh1 = 5               # neighbours sampled by the inner layer (enc1.num_sample)
is_cuda = False
if dataset in ('pubmed', 'reddit'):
    # Dataset-specific overrides for pubmed / reddit.
    batch_size = 128
    num_neigh0 = 25
    num_neigh1 = 10
    node_least_iter = 80

eva_size = 16                # validation nodes sampled per evaluation step

model_name = 'model_save/' + dataset + '.pkl'
# torch.save() does not create parent directories; make sure the checkpoint
# directory exists so the first save in the training loop cannot fail.
os.makedirs(os.path.dirname(model_name), exist_ok=True)

eva_iter = 4                 # evaluate every `eva_iter` training batches
patience = 5                 # evaluations without improvement before early stop

from preload import load_data, load_reddit_data

# Only the Reddit loader and sampler are wired up below. Fail loudly instead of
# silently training on Reddit while naming the checkpoint / choosing
# hyperparameters after a different `dataset` value.
if dataset != 'reddit':
    raise ValueError(
        "this script only supports dataset='reddit', got %r" % (dataset,))

# NOTE(review): cuda=False is hard-coded here rather than following `is_cuda`.
adj, features, labels, train_index, val_index, test_index = load_reddit_data(cuda=False)

s = SamplerReddit(adj)

sampler = s.normal_sample
num_nodes = features.shape[0]
num_feature = features.shape[1]
# labels has one column per class; the training loop turns rows into class ids
# with argmax(axis=1).
num_classes = labels.shape[1]
# Lookup table holding the raw node features; frozen (requires_grad=False), so
# the optimizer's requires_grad filter below excludes it.
feat = nn.Embedding(num_nodes, num_feature)
feat.weight = nn.Parameter(torch.FloatTensor(features), requires_grad=False)

# Layer 1: aggregate raw features of sampled neighbours.
agg1 = MeanAggregator(feat, sampler, is_softgate=True, cuda=is_cuda)
enc1 = Encoder(feat, num_feature, emb_size, agg1, gcn=False, cuda=is_cuda)
# Layer 2: aggregate layer-1 embeddings. The transpose presumably converts
# enc1's output to one row per node — TODO confirm against Encoder.forward.
agg2 = MeanAggregator(
    lambda nodes: enc1(nodes).t(), sampler, is_softgate=True, cuda=is_cuda)
enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, emb_size, agg2,
    base_model=enc1, gcn=False, cuda=is_cuda)
enc1.num_sample = num_neigh1
enc2.num_sample = num_neigh0
graphsage = SupervisedGraphSage(num_classes, enc2, drop_out=0.5, cuda=is_cuda)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, graphsage.parameters()), lr=1.0)

def batch_eva(net, x, y, batch_size):
    """Return the classification accuracy of ``net`` on the nodes ``x``.

    Nodes are passed to ``net.forward`` in slices of at most ``batch_size``.
    The argmax over each output row is the predicted class. The concatenated
    predictions are scored against ``y`` with ``accuracy_score``.

    Args:
        net: model exposing ``forward(batch_nodes)`` that returns a tensor
            with one row of class scores per node.
        x: node ids (anything supporting ``len`` and slicing).
        y: true class ids aligned with ``x``.
        batch_size: maximum number of nodes per forward pass.

    Returns:
        The accuracy computed by ``accuracy_score(y, pred)``.

    Raises:
        ValueError: if ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("batch_eva() needs at least one node to evaluate")
    out_list = []
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        for start in range(0, len(x), batch_size):
            batch_nodes = x[start:start + batch_size]  # slicing clamps the end
            scores = net.forward(batch_nodes)
            out_list.append(scores.data.cpu().numpy().argmax(axis=1))
    pred = np.concatenate(out_list, axis=0)
    return accuracy_score(y, pred)

# ---------------------------------------------------------------------------
# Training loop: one random mini-batch of training nodes per iteration, with an
# evaluation every `eva_iter` iterations on a small random validation sample.
# The model with the best sampled validation accuracy is checkpointed to
# `model_name`; training stops after `patience` evaluations without improvement.
# ---------------------------------------------------------------------------
score_list = []
times = []                 # wall-clock seconds of each zero_grad/loss/backward/step
patience_count = 0
# Start below any achievable accuracy so the first evaluation always writes a
# checkpoint. With 0, an all-wrong first sample (va_acc == 0.0) skips the save,
# and an early stop could leave no checkpoint for the final test to load.
max_score = -1.0
# Convert the index collections once instead of on every iteration.
train_pool = list(train_index)
val_pool = list(val_index)
for batch in range(100000):
    batch_nodes = random.sample(train_pool, batch_size)
    start_time = time.time()
    optimizer.zero_grad()
    label = Variable(torch.LongTensor(labels[batch_nodes].argmax(axis=1)))
    if is_cuda:
        label = label.cuda()
    tr_output, loss = graphsage.loss(batch_nodes, label, is_train=True)
    loss.backward()
    optimizer.step()
    end_time = time.time()
    times.append(end_time - start_time)
    if batch % eva_iter == 0:
        va_nodes = random.sample(val_pool, eva_size)
        tr_acc = batch_eva(graphsage, batch_nodes, labels[batch_nodes].argmax(axis=1), batch_size)
        va_acc = batch_eva(graphsage, va_nodes, labels[va_nodes].argmax(axis=1), batch_size)
        if va_acc > max_score:
            max_score = va_acc
            patience_count = 0
            torch.save(graphsage.state_dict(), model_name)
        else:
            patience_count += 1
            if patience_count == patience:
                print('Early Stop')
                break
        # NOTE: the "Epoch" label actually reports the mini-batch index.
        print('Epoch: %d,' % (batch),
              '|Pat: %d/%d' % (patience_count, patience),
              '|loss: %.4f' % loss.item(),  # .item() also works for CUDA tensors
              '|tr_acc: %.4f' % tr_acc,
              '|va_acc: %.4f' % va_acc)
print('*'*20, 'Final Test', '*'*20)
# Evaluate the best checkpoint written by the training loop. If none exists
# (e.g. no evaluation ever beat the initial best score), fall back to the
# in-memory weights instead of crashing inside torch.load.
# NOTE(review): a checkpoint left over from an earlier run would also be loaded.
if os.path.exists(model_name):
    graphsage.load_state_dict(torch.load(model_name))
else:
    print('Warning: checkpoint %s not found; testing current weights' % model_name)
ts_acc = batch_eva(graphsage, test_index, labels[test_index].argmax(axis=1), batch_size)
print('Test Acc:%.4f' % ts_acc)