In [3]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [4]:
class GCNLayer(nn.Module):
    def __init__(self, A, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.A_hat = A+torch.eye(A.size(0))
        self.D = torch.diag(torch.sum(A,1))
        self.D = self.D.inverse().sqrt()
        self.W = nn.Parameter(torch.rand(in_feats, out_feats, requires_grad=True))

    def forward(self, X):
        out = torch.mm(torch.mm(self.D, self.A_hat), self.D)
        out = torch.mm(out, X)
        out = torch.mm(out, self.W)
        out = torch.relu(out)
        return out

class GCN(nn.Module):
    def __init__(self, A, in_feats, hide_feats, out_feats):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(A, in_feats, hide_feats)
        self.gcn2 = GCNLayer(A, hide_feats, out_feats)

    def forward(self, X):
        out = self.gcn1(X)
        out = self.gcn2(out)
        return out


In [5]:
def load_cora():
    raw_content = pd.read_csv('./data/cora/cora.content', sep='\t', header=None)
    num_nodes = raw_content.shape[0]

    id = list(raw_content[0])
    idx = list(raw_content.index)
    id2idx = dict(zip(id, idx))

    feat_data = raw_content.iloc[:, 1:-1]
    feat_data = torch.FloatTensor(feat_data.values)
    labels = pd.get_dummies(raw_content[1434])
    labels = torch.FloatTensor(labels.values)

    raw_cites = pd.read_csv('./data/cora/cora.cites', sep='\t', header=None)

    adj = np.zeros((num_nodes, num_nodes))
    for i, j in zip(raw_cites[0], raw_cites[1]):
        x = id2idx[i]
        y = id2idx[j]
        adj[x][y] = adj[y][x] = 1
    adj = torch.FloatTensor(adj)
    
    return feat_data, labels, adj, id2idx


In [7]:
feat_data, labels, adj, id2idx = load_cora()
gcn_model = GCN(A=adj, in_feats=feat_data.shape[1], hide_feats=256, out_feats=128)
out = gcn_model(feat_data)

In [8]:
out

tensor([[1476.1133, 1549.0864, 1572.0503,  ..., 1486.0237, 1546.4064,
         1558.3378],
        [1769.6311, 1852.2921, 1865.6219,  ..., 1771.0582, 1862.7021,
         1864.8555],
        [1482.4706, 1550.1504, 1565.7719,  ..., 1482.0724, 1554.6464,
         1559.7914],
        ...,
        [2022.5492, 2108.3916, 2134.3616,  ..., 2033.2152, 2125.6111,
         2133.7974],
        [2351.2390, 2456.5281, 2482.4014,  ..., 2357.4795, 2459.8433,
         2471.5010],
        [2568.5979, 2677.4792, 2720.9409,  ..., 2573.4099, 2674.2449,
         2707.1863]], grad_fn=<ReluBackward0>)