In [100]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from numpy.random import choice
from torch.nn.parameter import Parameter
import math
from sklearn import preprocessing
import numpy as np
from torch import optim
import torch.nn.functional as F

In [101]:
def aggregate(A, X, len_walk, num_neigh, agg_func):
    norm = torch.div(A, torch.sum(A, axis=1))
    norm = torch.matrix_power(norm, len_walk)
    result = torch.zeros(X.shape)
    for i in range(A.shape[0]):
        x = A[i].cpu().detach().numpy()
        ind = np.random.choice(range(x.shape[0]), num_neigh, replace=False)
        if agg_func == "MEAN":
            result[i] = torch.mean(X[ind], axis=0)
        else:
            result[i] = torch.max(X[ind], axis=0).values
    return result

In [102]:
class SageLayer(nn.Module):
    def __init__(self, F, O, len_walk = 2, num_neigh = 10, agg_func="MEAN", bias=True): 
        super(SageLayer, self).__init__()
        self.F = F
        self.O = O
        self.weight = Parameter(torch.FloatTensor(2 * F, O))
        self.len_walk = len_walk
        self.num_neigh = num_neigh
        self.agg_func = agg_func
        if bias:
            self.bias = Parameter(torch.FloatTensor(O))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
    
    def forward(self, X, A):
        aggregated = aggregate(A, X, self.len_walk, self.num_neigh, self.agg_func)
        aggregated = aggregated.to(X.device)
        combined = torch.cat([X, aggregated], dim=1)
        combined = torch.mm(combined, self.weight)
        if self.bias is not None:
            return combined + self.bias
        else:
            return combined
        return combined

In [113]:
class GraphSage_model(nn.Module):
    def __init__(self, X, A, n=0, F=1433, class_num=7, 
                 agg_func='MEAN', hidden_neuron=200, len_walk=2, num_neigh=10, bias=True):
        super(GraphSage_model, self).__init__()

        self.F = F
        self.class_num = class_num
        self.n = n
        self.agg_func = agg_func
        self.X = X
        self.A = A
        
        self.gs1 = SageLayer(F, hidden_neuron, len_walk=len_walk, num_neigh=num_neigh, agg_func=agg_func, bias=bias)
        self.gsh = SageLayer(hidden_neuron, hidden_neuron, len_walk=len_walk, num_neigh=num_neigh, agg_func=agg_func, bias=bias)
        self.gs2 = SageLayer(hidden_neuron, self.class_num, len_walk=len_walk, num_neigh=num_neigh, agg_func=agg_func, bias=bias)
        
    def forward(self, X):
        X = self.gs1(X, self.A)
        X = F.relu(X)
        for i in range(self.n):
            X = self.gsh(X, self.A)
            X = F.relu(X)
        X = self.gs2(X, self.A)
        return F.log_softmax(X, dim=1)

In [118]:
class GraphSage():
    def __init__(self, A, X, y, device='cuda', n=0, F=1433, class_num=7, agg_func="MEAN", hidden_neuron=200,
                len_walk=2, num_neigh=10, bias=True, val_size=0.3):
        if device == 'cuda':
            self.device= torch.device('cuda')
        else:
            assert('only support cuda')
        le = preprocessing.LabelEncoder()
        le.fit(y)
        y = le.transform(y)
        
        X = torch.tensor(X)
        X = X.type(torch.float)
        y = torch.tensor(y)
        y = y.type(torch.long)
        y = y.to(self.device)
        A = torch.from_numpy(A).float()
        
        self.X = X.to(self.device)
        self.A = A.to(self.device)
        self.y = y.to(self.device)
        
        train_idx = np.random.choice(self.X.shape[0], round(self.X.shape[0]*(1-val_size)), replace=False)
        val_idx = np.array([x for x in range(X.shape[0]) if x not in train_idx])
        print("Train length :{a}, Validation length :{b}".format(a=len(train_idx), b=len(val_idx)))
        
        self.idx_train = torch.LongTensor(train_idx)
        self.idx_val = torch.LongTensor(val_idx)
        self.graphsage = GraphSage_model(self.X, self.A, n=n, F=F, agg_func=agg_func, hidden_neuron=hidden_neuron,
                                         class_num = class_num, len_walk=len_walk, bias=bias, num_neigh=num_neigh)
        self.graphsage.to(self.device)
        
    def train(self, optimizer, epoch):
        self.graphsage.train()
        optimizer.zero_grad()
        output = self.graphsage(self.X)
        loss = F.cross_entropy(output[self.idx_train], self.y[self.idx_train])
        loss.backward(retain_graph=True)
        optimizer.step()
        print('Epoch: {x}'.format(x=epoch))
        print('training loss {:.4f}'.format(loss.item()))
            
    def test(self):
        self.graphsage.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            output = self.graphsage(self.X)
            test_loss = F.cross_entropy(output[self.idx_val], self.y[self.idx_val], reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)[self.idx_val]
            correct += pred.eq(self.y[self.idx_val].view_as(pred)).sum().item()

        test_loss /= len(self.idx_val)
        print('Validtion: Average loss: {:.4f}, Accuracy: {:.4f}%'.format(test_loss, 100. * correct / len(self.idx_val)))
        return 100. * correct / len(self.idx_val)
    
    def train_epoch(self, epochs=50, lr=1e-3):
        acc = []
        optimizer = optim.Adam(self.graphsage.parameters(), lr=lr)#, weight_decay=1e-1)

        for epoch in range(epochs):
            self.train(optimizer, epoch)
            accs = self.test()
            acc.append(accs)
        accs = {'acc': acc}
        return accs

In [119]:
from coraloader import cora_loader
cora = cora_loader('data' + '/cora.content', 'data' + '/cora.cites', None)
X, y, A = cora.get_train()

In [120]:
model = GraphSage(A, X, y, agg_func="MAX")

Train length :1896, Validation length :812


In [121]:
model.train_epoch()

Epoch: 0
training loss 2.1360
Validtion: Average loss: 1.8881, Accuracy: 25.3695%
Epoch: 1
training loss 1.8759
Validtion: Average loss: 1.8104, Accuracy: 30.2956%
Epoch: 2
training loss 1.7535
Validtion: Average loss: 1.7967, Accuracy: 29.8030%
Epoch: 3
training loss 1.6908
Validtion: Average loss: 1.7592, Accuracy: 30.2956%
Epoch: 4
training loss 1.6239
Validtion: Average loss: 1.6924, Accuracy: 33.2512%
Epoch: 5
training loss 1.5471
Validtion: Average loss: 1.6409, Accuracy: 38.3005%
Epoch: 6
training loss 1.4714
Validtion: Average loss: 1.6047, Accuracy: 44.0887%
Epoch: 7
training loss 1.4066
Validtion: Average loss: 1.5485, Accuracy: 47.9064%
Epoch: 8
training loss 1.3351
Validtion: Average loss: 1.4947, Accuracy: 50.3695%
Epoch: 9
training loss 1.2567
Validtion: Average loss: 1.4508, Accuracy: 54.0640%
Epoch: 10
training loss 1.1881
Validtion: Average loss: 1.3943, Accuracy: 54.6798%
Epoch: 11
training loss 1.1179
Validtion: Average loss: 1.3464, Accuracy: 55.2956%
Epoch: 12
trai

{'acc': [25.36945812807882,
  30.295566502463053,
  29.80295566502463,
  30.295566502463053,
  33.251231527093594,
  38.30049261083744,
  44.08866995073892,
  47.9064039408867,
  50.36945812807882,
  54.064039408866996,
  54.679802955665025,
  55.29556650246305,
  55.78817733990148,
  58.86699507389162,
  60.59113300492611,
  62.5615763546798,
  64.90147783251231,
  66.37931034482759,
  68.5960591133005,
  69.08866995073892,
  70.44334975369458,
  69.45812807881774,
  71.18226600985221,
  70.6896551724138,
  71.18226600985221,
  71.55172413793103,
  73.52216748768473,
  71.18226600985221,
  72.78325123152709,
  73.64532019704434,
  73.39901477832512,
  73.76847290640394,
  74.01477832512315,
  73.15270935960591,
  74.38423645320196,
  73.64532019704434,
  74.13793103448276,
  74.38423645320196,
  74.13793103448276,
  74.50738916256158,
  75.24630541871922,
  74.75369458128078,
  73.27586206896552,
  74.75369458128078,
  74.01477832512315,
  73.64532019704434,
  74.01477832512315,
  74.