In [105]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.parameter import Parameter
import numpy as np
import torch
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import math
from layers import LPA_GCN_layer

In [86]:
class LPA_GCN_model(nn.Module):
    def __init__(self, A, len_walk, F, class_number, hid):
        super(LPA_GCN_model, self).__init__()
        self.lg1 = LPA_GCN_layer(F = F, O = hid, A=A, len_walk=len_walk)
        self.lg2 = LPA_GCN_layer(F = hid, O = class_number, A=A, len_walk=len_walk)

    def forward(self, X, A, y):
        X, y_hat = self.lg1(X, A, y)
        X = F.relu(X)
        X, y_hat = self.lg2(X, A, y_hat)
        return F.log_softmax(X, dim=1), F.log_softmax(y_hat,dim=1)

In [95]:
class LPA_GCN():
    def __init__(self, A, X, y, lamb, device='cuda', len_walk=2, F=1433, class_number=7, hid=200, val = 0.3):
        if device == 'cuda':
            self.device= torch.device('cuda')
        else:
            assert('only support cuda')
        le = preprocessing.LabelEncoder()
        one = preprocessing.OneHotEncoder(sparse=False)
        y_ = np.reshape(y, (-1, 1))
        one.fit(y_)
        labels = np.array(one.transform(y_))
        labels = torch.from_numpy(labels).type(torch.float)
        le.fit(y)
        y = le.transform(y)
        
        X = torch.tensor(X)
        X = X.type(torch.float)
        y = torch.tensor(y)
        y = y.type(torch.long)
        y = y.to(self.device)
        A = torch.from_numpy(A).float()
        
        self.X = X.to(self.device)
        self.A = A.to(self.device)
        self.y = y.to(self.device)
        self.Lambda = lamb
        self.labels = labels.to(self.device)
        
        train_idx = np.random.choice(self.X.shape[0], round(self.X.shape[0]*(1-val)), replace=False)
        val_idx = np.array([x for x in range(X.shape[0]) if x not in train_idx])
        print("Train length :{a}, Validation length :{b}".format(a=len(train_idx), b=len(val_idx)))
        
        self.idx_train = torch.LongTensor(train_idx)
        self.idx_val = torch.LongTensor(val_idx)
        
        self.lpa_gcn = LPA_GCN_model(A = self.A, len_walk=len_walk, F=F, class_number = class_number, hid = hid)
        self.lpa_gcn.to(self.device)
    
    def train(self, optimizer, epoch):
        self.lpa_gcn.train()
        optimizer.zero_grad()
        output, y_hat = self.lpa_gcn(self.X, self.A, self.labels)
        loss_gcn = F.cross_entropy(output[self.idx_train], self.y[self.idx_train])
        loss_lpa = F.nll_loss(y_hat, self.y)
        loss_train = loss_gcn + self.Lambda * loss_lpa
        loss_train.backward(retain_graph=True)
        optimizer.step()
        print('Epoch: {x}'.format(x=epoch))
        print('training loss {:.4f}'.format(loss_train.item()))
            
    def test(self):
        self.lpa_gcn.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            output, y_hat = self.lpa_gcn(self.X, self.A, self.labels)
            #print(self.idx_val)
            test_loss = F.cross_entropy(output[self.idx_val], self.y[self.idx_val], reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)[self.idx_val]
            correct += pred.eq(self.y[self.idx_val].view_as(pred)).sum().item()

        test_loss /= len(self.idx_val)
        print('Average loss: {:.4f}, Accuracy: {:.4f}%'.format(test_loss, 100. * correct / len(self.idx_val)))
        return 100. * correct / len(self.idx_val)
    
    def train_model(self, epochs=50, lr=1e-3):
        lpa_gcn_test_acc = []
        optimizer = optim.Adam(self.lpa_gcn.parameters(), lr=lr)#, weight_decay=1e-1)

        for epoch in range(epochs):
            self.train(optimizer, epoch)
            accs = self.test()
            lpa_gcn_test_acc.append(accs)
        accs = {'acc': lpa_gcn_test_acc}
        return accs

In [96]:
from coraloader import cora_loader
cora = cora_loader('data' + '/cora.content', 'data' + '/cora.cites', None)
X, y, A = cora.get_train()

In [103]:
test = LPA_GCN(A, X, y, 0)

Train length :1896, Validation length :812


In [102]:
test.train_model()

Epoch: 0
training loss 3.5086
Average loss: 1.9807, Accuracy: 17.9803%
Epoch: 1
training loss 3.3653
Average loss: 1.8593, Accuracy: 30.5419%
Epoch: 2
training loss 3.2347
Average loss: 1.7499, Accuracy: 50.9852%
Epoch: 3
training loss 3.1169
Average loss: 1.6522, Accuracy: 55.9113%
Epoch: 4
training loss 3.0114
Average loss: 1.5645, Accuracy: 56.0345%
Epoch: 5
training loss 2.9165
Average loss: 1.4850, Accuracy: 54.8030%
Epoch: 6
training loss 2.8302
Average loss: 1.4112, Accuracy: 55.0493%
Epoch: 7
training loss 2.7501
Average loss: 1.3408, Accuracy: 56.2808%
Epoch: 8
training loss 2.6739
Average loss: 1.2718, Accuracy: 58.7438%
Epoch: 9
training loss 2.5998
Average loss: 1.2034, Accuracy: 62.4384%
Epoch: 10
training loss 2.5266
Average loss: 1.1355, Accuracy: 66.1330%
Epoch: 11
training loss 2.4543
Average loss: 1.0687, Accuracy: 70.8128%
Epoch: 12
training loss 2.3835
Average loss: 1.0044, Accuracy: 73.3990%
Epoch: 13
training loss 2.3157
Average loss: 0.9439, Accuracy: 76.2315%
Ep

{'acc': [17.980295566502463,
  30.541871921182267,
  50.98522167487685,
  55.91133004926108,
  56.03448275862069,
  54.80295566502463,
  55.04926108374384,
  56.2807881773399,
  58.74384236453202,
  62.4384236453202,
  66.13300492610837,
  70.8128078817734,
  73.39901477832512,
  76.23152709359606,
  78.57142857142857,
  79.92610837438424,
  80.78817733990148,
  82.38916256157636,
  83.86699507389163,
  84.35960591133005,
  85.34482758620689,
  85.5911330049261,
  86.57635467980296,
  86.69950738916256,
  86.94581280788178,
  87.19211822660098,
  87.06896551724138,
  86.94581280788178,
  86.94581280788178,
  86.82266009852216,
  86.69950738916256,
  86.69950738916256,
  86.69950738916256,
  86.82266009852216,
  86.82266009852216,
  86.94581280788178,
  86.82266009852216,
  86.94581280788178,
  86.94581280788178,
  86.94581280788178,
  87.19211822660098,
  87.06896551724138,
  87.5615763546798,
  87.4384236453202,
  87.5615763546798,
  87.5615763546798,
  87.06896551724138,
  87.0689655