In [6]:
import time

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
class GCNLayer(nn.Module):
    def __init__(self, A, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.A_hat = A+torch.eye(A.size(0))
        self.D = torch.diag(torch.sum(A,1))
        self.D = self.D.inverse().sqrt()
        self.W = nn.Parameter(torch.rand(in_feats, out_feats, requires_grad=True))

    def forward(self, X):
        out = torch.mm(torch.mm(self.D, self.A_hat), self.D)
        out = torch.mm(out, X)
        out = torch.mm(out, self.W)
        return out

class GCN(nn.Module):
    def __init__(self, adj, in_feats, hide_feats, out_feats):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(adj, in_feats, hide_feats)
        self.gcn2 = GCNLayer(adj, hide_feats, out_feats)

    def forward(self, x):
        x = self.gcn1(x)
        x = F.relu(x)
        x = self.gcn2(x)
        return F.log_softmax(x, dim=1)

In [8]:
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def load_cora():
    raw_content = pd.read_csv('./data/cora/cora.content', sep='\t', header=None)
    num_nodes = raw_content.shape[0]

    id = list(raw_content[0])
    idx = list(raw_content.index)
    id2idx = dict(zip(id, idx))

    feat_data = raw_content.iloc[:, 1:-1]
    feat_data = torch.FloatTensor(feat_data.values)
    labels = pd.get_dummies(raw_content[1434])
    labels = np.where(labels)[1]
    labels = torch.LongTensor(labels)

    raw_cites = pd.read_csv('./data/cora/cora.cites', sep='\t', header=None)

    adj = np.zeros((num_nodes, num_nodes))
    for i, j in zip(raw_cites[0], raw_cites[1]):
        x = id2idx[i]
        y = id2idx[j]
        adj[x][y] = adj[y][x] = 1
    adj = torch.FloatTensor(adj)

    idx_train = range(2000)
    idx_val = range(2000, 2708)
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    
    return feat_data, labels, adj, id2idx, idx_train, idx_val

In [9]:
feat_data, labels, adj, id2idx, idx_train, idx_val = load_cora()
model = GCN(adj=adj, in_feats=feat_data.shape[1], hide_feats=16, out_feats=7)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [11]:
for epoch in range(100):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(feat_data)

    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])


    acc_train = accuracy(output, labels)
    print('Epoch: {:04d}'.format(epoch+1),
        'loss_train: {:.4f}'.format(loss_train.item()),
        'acc_train: {:.4f}'.format(acc_train.item()),
        'loss_val: {:.4f}'.format(loss_val.item()),
        'acc_val: {:.4f}'.format(acc_val.item()),
        'time: {:.4f}s'.format(time.time() - t))

Epoch: 0001 loss_train: 8.9460 acc_train: 0.2109 loss_val: 10.0348 acc_val: 0.2203 time: 0.3934s
Epoch: 0002 loss_train: 8.7698 acc_train: 0.1736 loss_val: 9.7931 acc_val: 0.1723 time: 0.4062s
Epoch: 0003 loss_train: 8.0958 acc_train: 0.1577 loss_val: 9.0525 acc_val: 0.1610 time: 0.3586s
Epoch: 0004 loss_train: 6.7422 acc_train: 0.1780 loss_val: 7.6366 acc_val: 0.1794 time: 0.4043s
Epoch: 0005 loss_train: 5.1445 acc_train: 0.2707 loss_val: 5.9815 acc_val: 0.2712 time: 0.3849s
Epoch: 0006 loss_train: 4.0486 acc_train: 0.4645 loss_val: 4.8192 acc_val: 0.4068 time: 0.3973s
Epoch: 0007 loss_train: 3.9014 acc_train: 0.3907 loss_val: 4.5827 acc_val: 0.3729 time: 0.3993s
Epoch: 0008 loss_train: 3.9209 acc_train: 0.3944 loss_val: 4.5837 acc_val: 0.3686 time: 0.3818s
Epoch: 0009 loss_train: 3.8632 acc_train: 0.3623 loss_val: 4.5110 acc_val: 0.3489 time: 0.3958s
Epoch: 0010 loss_train: 3.4777 acc_train: 0.3996 loss_val: 4.0771 acc_val: 0.3743 time: 0.4053s
Epoch: 0011 loss_train: 2.9543 acc_trai