In [1]:
# import section
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import numpy as np
import torch
import pandas as pd
import networkx as nx

In [34]:
class two_layer_GraphNet(nn.Module):
    def __init__(self, A, device, F = 1433, class_number=7, hidden_neurons = 200):
        super(two_layer_GraphNet, self).__init__()
        # precompute adjacency matrix before training
        if A[0][0] == 0:
            A = A * 0.1 + np.identity(A.shape[0])
        self.A = torch.from_numpy(A).float()
        self.class_number = class_number
        self.fc1 = nn.Linear(F, hidden_neurons, bias=True)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_neurons, self.class_number, bias=True)
        self.device = device

    def forward(self, x):
        # training on full x, not batch
        x = x.float()
        # average all neighboors
        #print(x.shape)
        #A = self.A.float()
        A = self.A.to(self.device)
        #print(A.shape)
        #print(self.X.shape)
        #print(A.dtype, self.X.dtype, x.dtype)
        x = torch.matmul(A, x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x
    

In [40]:
class GNN():
    def __init__(self, hidden_neurons=200, learning_rate=1e-3, epoch=50, device='cuda'):
        self.hidden_neurons = hidden_neurons
        self.learning_rate = learning_rate
        self.epoch = epoch
        self.device = torch.device(device)
        self.kwargs = {'num_workers': 1, 'pin_memory': True}
    
    def train(self, model, device, train_loader, optimizer):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            output = model(data)
            #print(output, target)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            print('training loss {:.4f}'.format(loss.item()))

    def test(self, model, device, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('Average loss: {:.4f}, Accuracy: {:.4f}%'.format(test_loss, 100. * correct / len(test_loader.dataset)))
        return 100. * correct / len(test_loader.dataset)
    
    def encode_label(self, y):
        '''

        '''
        label_encoder = preprocessing.LabelEncoder()
        label_encoder.fit(y)
        y = label_encoder.transform(y)
        return y
    
    def fit(self, X, y, A):
        X = torch.from_numpy(X).type(torch.long)
        self.A = A
        y = torch.from_numpy(self.encode_label(y))
        self.dataset = TensorDataset(X, y)
        self.dataloader = DataLoader(self.dataset, batch_size=2708, shuffle=True, **self.kwargs)
        
    def train_epoch(self):
        model = two_layer_GraphNet(self.A, self.device)
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
        accs = {'acc': []}
        for epoch in range(self.epoch):
            self.train(model, self.device, self.dataloader , optimizer)
            acc = self.test(model, self.device, self.dataloader)
            accs['acc'].append(acc)
        return accs

In [41]:
gnn = GNN()

In [42]:
from coraloader import cora_loader
cora = cora_loader('data/cora.content', 'data/cora.cites', None)
X, y, A = cora.get_train()

In [43]:
gnn.fit(X, y, A)

In [44]:
gnn.train_epoch()

training loss 1.9513
Average loss: 1.9424, Accuracy: 15.2511%
training loss 1.9425
Average loss: 1.9342, Accuracy: 29.9852%
training loss 1.9342
Average loss: 1.9237, Accuracy: 38.7740%
training loss 1.9237
Average loss: 1.9089, Accuracy: 47.4520%
training loss 1.9091
Average loss: 1.8904, Accuracy: 57.2747%
training loss 1.8904
Average loss: 1.8699, Accuracy: 61.7799%
training loss 1.8684
Average loss: 1.8466, Accuracy: 60.1182%
training loss 1.8454
Average loss: 1.8214, Accuracy: 58.8996%
training loss 1.8211
Average loss: 1.7955, Accuracy: 55.6869%
training loss 1.7935
Average loss: 1.7634, Accuracy: 52.9542%
training loss 1.7656
Average loss: 1.7362, Accuracy: 50.8124%
training loss 1.7367
Average loss: 1.7041, Accuracy: 49.3722%
training loss 1.7059
Average loss: 1.6703, Accuracy: 49.6307%
training loss 1.6697
Average loss: 1.6385, Accuracy: 50.0000%
training loss 1.6392
Average loss: 1.6051, Accuracy: 49.6307%
training loss 1.6002
Average loss: 1.5661, Accuracy: 50.7016%
training

{'acc': [15.251107828655835,
  29.98522895125554,
  38.77400295420975,
  47.4519940915805,
  57.27474150664697,
  61.779911373707534,
  60.118168389955684,
  58.899556868537665,
  55.68685376661743,
  52.95420974889217,
  50.81240768094535,
  49.37223042836041,
  49.630723781388475,
  50.0,
  49.630723781388475,
  50.70162481536189,
  52.06794682422452,
  54.61595273264402,
  56.72082717872969,
  59.416543574593796,
  64.0324963072378,
  67.46676514032497,
  70.27326440177252,
  72.48892171344166,
  75.62776957163959,
  77.69571639586411,
  79.1728212703102,
  80.61299852289513,
  80.61299852289513,
  81.90546528803544,
  81.97932053175775,
  82.75480059084195,
  83.16100443131462,
  83.53028064992614,
  84.15805022156573,
  84.15805022156573,
  84.52732644017725,
  85.08124076809453,
  85.78286558345643,
  86.74298375184638,
  86.33677991137371,
  86.81683899556869,
  87.00147710487444,
  87.03840472673559,
  87.70310192023634,
  88.62629246676514,
  88.109305760709,
  88.700147710487