Inspired by https://arxiv.org/pdf/1609.02907.pdf

Install everything in the Readme

In [2]:
import torch
import numpy as np

GCN spec:

Needs: <br>
-adjacency list of size nxn <br>
-output of size nx1 <br>
-classification or regression

Can also provide: <br>
-initial features of size nxf <br>
-hidden layer size (hl) <br>
-train/test split on nodes


TODO: <br>
-sticking it on a GPU, which really just means adding .cuda() and a function parameter <br>
-parameterize `predict` function, so you can predict all nodes or just a subset (train/test); easy to do

# Hand-written GCN class

In [203]:
class GCN:
    def __init__(self, adj_list, out_labels, method, nl=3, init_feats=None, train_test=0.5):
        """
        Creates important model variables and weights
        method: 'cat' or 'cont'
        """
        # assertions
        assert len(adj_list) == len(adj_list[0])
        assert len(out_labels) == len(adj_list)
        assert method in ['cat', 'cont']
        assert type(nl) == int and nl >= 2 # input and output at minimum
        assert init_feats is None or len(init_feats) == len(adj_list)
        assert type(train_test) == np.ndarray or (type(train_test) == float and train_test <= 1 and train_test > 0)
        
        self.adj_list = torch.from_numpy(adj_list).double()
        sqrt = True # turn to False to use 'simple' adj_list normalization
        D = GCN.create_D(self.adj_list, sqrt=sqrt)
        if sqrt:
            self.norm_adj_list = D @ self.adj_list @ D # https://tkipf.github.io/graph-convolutional-networks/#fn3
        else:
            self.norm_adj_list = D @ self.adj_list
        self.out_labels = torch.from_numpy(out_labels).view(1, -1)
        self.method = method
        if self.method == 'cat':
            self.nc = len(torch.unique(self.out_labels))
        # if no features given, use identity matrix
        if init_feats is None:
            init_feats = np.eye(len(adj_list))
        self.init_feats = torch.from_numpy(init_feats)
        self.nl = nl
        # can give train indices directly or give a train/test split
        if type(train_test) == np.ndarray:
            self.train_indices = train_test
        else:
            self.train_indices = np.random.choice(np.arange(len(adj_list)), int(train_test*len(adj_list)), replace=False)
        
        self.weight_list = []
        if self.method == 'cat':
            sizes = [max(40 - 5*i, 2*self.nc) for i in range(self.nl)]
        else:
            sizes = [max(40 - 5*i, 20) for i in range(self.nl)]
        sizes[-1] = self.nc if self.method == 'cat' else 1 # 1 for continuous output
        for i in range(self.nl):
            if i == 0:
                w = torch.randn(init_feats.shape[1], sizes[0], dtype=torch.double, requires_grad=True)
            else:
                w = torch.randn(sizes[i-1], sizes[i], dtype=torch.double, requires_grad=True)
            self.weight_list.append(w)
            
        self.lr = 1e-4
        self.epoch_stats = []
        
        
    def train(self, epochs=100):
        for i in range(epochs):
            self.train_epoch()
            
            
    def train_epoch(self):
        """
        Propogates the model through the GCN using the formula:
            L_next = A_hat @ L_cur @ w_cur
            where A_hat is the normalized adj matrix
        Then, backprops the model and updates the weights
        Stores, the loss in self.epoch_stats
        """
        l_prev = None
        for i, w in enumerate(self.weight_list):
            if i == 0:
                l_prev = self.gcn_layer(self.norm_adj_list, self.init_feats, w, 'relu')
            elif i != len(self.weight_list)-1:
                l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, 'relu')
            else:
                if self.method == 'cat':
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, 'softmax') # predictions!
                else:
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, None) # no activation for final layer
                    
        loss = self.compute_loss(l_prev, self.out_labels)
        self.epoch_stats.append(loss)
        
        # backprop
        loss.backward() # torch is amazing.. :)
        with torch.no_grad():
            for i in range(len(self.weight_list)):
                self.weight_list[i] -= self.lr * self.weight_list[i].grad # updates weight
                self.weight_list[i].grad.zero_() # resets to 0
                
                
    def predict(self):
        if self.method == 'cat':
            return self.predict_cat()
        else:
            return self.predict_cont()
        
        
    def predict_cont(self):
        with torch.no_grad():
            l_prev = None
            for i, w in enumerate(self.weight_list):
                if i == 0:
                    l_prev = self.gcn_layer(self.norm_adj_list, self.init_feats, w, 'relu')
                elif i != len(self.weight_list)-1:
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, 'relu')
                else:
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, None) # predictions!
            return l_prev
        
        
    def predict_cat(self):
        with torch.no_grad():
            l_prev = None
            for i, w in enumerate(self.weight_list):
                if i == 0:
                    l_prev = self.gcn_layer(self.norm_adj_list, self.init_feats, w, 'relu')
                elif i != len(self.weight_list)-1:
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, 'relu')
                else:
                    l_prev = self.gcn_layer(self.norm_adj_list, l_prev, w, 'softmax') # predictions!
            return torch.argmax(l_prev, axis=1)
    
    
    def compute_loss(self, preds, actual):
        if self.method == 'cat':
            return self.compute_cat_loss(preds, actual)
        else:
            return self.compute_cont_loss(preds, actual)
        
        
    def compute_cont_loss(self, preds, actual):
        """
        Uses MSE loss
        """
        masked_actual = actual.clone().detach().view(-1,1)
        masked_actual[~self.train_indices,:] = 0 # if not a train index, set to 0
        preds *= (masked_actual != 0).double()
        # by doing this, we are calculating the loss for the vertices specified as everything else is 0
        # this will also make sure the loss is only fed (via sum) from rows with train nodes
        # hence, weights are only updated with regards to this loss
        # and the GCN effectively still doesn't know about the vertices it never got the actual loss for !
        return ((preds - masked_actual)**2).sum()
        
    
    def compute_cat_loss(self, preds, actual):
        """
        Uses categorical cross-entropy loss
        """
        # this assigns each class a unique index. This is helpful becase the classes could be 2 and 8
        # for all I know. This will assign 2 to 0 and 8 to 1
        classes = torch.unique(actual).numpy()
        mapping = {}
        for i in range(len(classes)):
            mapping[classes[i]] = i
        # this makes a one-hot matrix based on the number of classes
        baseline = torch.eye(self.nc)
        # this is the final one-hot, where each row corresponds to the actual output in one-hot form
        one_hot = np.zeros(preds.shape, dtype=np.double)
        # this grabs the vertices that we can actually use to train and assigns the row in 
        # one-hot to the actual one-hot value (so if we are at vertex 2 and class=3, we turn 
        # row 2 into 0 1 0 0 ... 0)
        for i, v in enumerate(actual.numpy().reshape(-1)[self.train_indices]):
            index = self.train_indices[i]
            one_hot[index,:] = baseline[mapping[v]].numpy().tolist()
        one_hot = torch.from_numpy(one_hot)
        preds = -torch.log(preds)
        # by doing this, we are calculating the loss for the vertices specified as everything else is 0
        # this will also make sure the loss is only fed (via sum) from rows with non-zero one-hot
        # aka the vertices we set in the previous for loop
        # hence, weights are only updated with regards to this loss
        # and the GCN effectively still doesn't know about the vertices it never got the actual loss for !
        return (preds * one_hot).sum()
    
        
    def create_D(adj_list, sqrt=True):
        # creates the matrix for making A_hat
        # sqrt is optional, can simply invert the diagonal as well, but the paper recommends this approach
        D = torch.eye(len(adj_list), dtype=torch.double)
        diag_sums = adj_list.sum(dim=0)
        for i in range(len(adj_list)):
            if sqrt:
                D[i,i] = 1/torch.sqrt(diag_sums[i])
            else:
                D[i,i] = 1/diag_sums[i]
        return D
    
    
    def gcn_layer(self, ad_list, inp, weights, activation):
        """
        Propogates the model through the GCN using the formula:
            L_next = A_hat @ L_cur @ w_cur
            where A_hat is the normalized adj matrix
        """
        out = ad_list @ inp @ weights
        if activation == 'relu':
            return torch.relu(out)
        elif activation == 'softmax':
            return torch.softmax(out, dim=0)
        elif activation is None:
            return out
        else:
            raise ValueError()

# Karate Club Example
https://en.wikipedia.org/wiki/Zachary%27s_karate_club

In [74]:
import networkx

In [75]:
G = networkx.karate_club_graph()

In [76]:
adj_dict = dict(G.adjacency())

In [77]:
adj_m = np.eye(len(adj_dict.keys()))
for k in adj_dict:
    for n in adj_dict[k]:
        adj_m[k,n] = 1

In [78]:
adj_m # this has self loops, so A is a neighbor of A; this is good

array([[1., 1., 1., ..., 1., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 1., 0.],
       ...,
       [1., 0., 0., ..., 1., 1., 1.],
       [0., 0., 1., ..., 1., 1., 1.],
       [0., 0., 0., ..., 1., 1., 1.]])

In [79]:
out_labels = []
for d in G.nodes:
    if G.nodes[d]['club'] == 'Mr. Hi':
        out_labels.append(0)
    else:
        out_labels.append(1)
out_labels = np.array(out_labels)

In [80]:
out_labels # 0 = Mr. Hi, 1 = the other guy

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [81]:
# known = [3, 5, 24, 28]
# known = np.array(known) # can also pass in this into "train_test" to show which nodes are known

In [82]:
# giving the model about 0.3 of the data makes it predict nearly every node correctly
gcn = GCN(adj_m, out_labels, method='cat', nl=3, init_feats=None, train_test=0.3)

In [83]:
preds_init = gcn.predict(); preds_init # nonsense initial predictions

tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [84]:
out_labels

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [85]:
(preds_init.numpy() == out_labels).mean() # very bad clearly

0.3235294117647059

In [86]:
gcn.train(epochs=40)

In [87]:
preds_init = gcn.predict()
(preds_init.numpy() == out_labels).mean() # does this well while training on only half the labeled nodes

0.9117647058823529

In [88]:
gcn.epoch_stats

[tensor(52.7011, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(35.6980, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(32.5641, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(30.8256, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(29.7124, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(28.9693, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(28.4385, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(28.0156, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(27.6558, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(27.3410, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(27.0583, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(26.8021, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(26.5675, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(26.3478, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(26.1430, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(25.9503, dtype=torch.float64, grad_fn=<SumBackw

# Hand made dataset
This corresponds to the drawing in `proof_of_concept.jpg`

In [50]:
adj_m = [
        [1,1,1,0,0,0],
        [1,1,1,0,0,1],
        [1,1,1,1,1,0],
        [0,0,1,1,1,0],
        [0,0,1,1,1,1],
        [0,1,0,0,1,1]
                    ]

In [51]:
init_features = [
                [1, 5, 11, 0, 0, 0],
                [5, 1, 6, 0, 0, 13],
                [11, 6, 1, 7, 6, 0],
                [0, 0, 7, 1, 3, 0],
                [0, 0, 6, 3, 1, 6],
                [0, 13, 0, 0, 6, 1]
                                    ]

In [52]:
labels = [0, 0, 1, 1, 1, 1] # NOTE, only index 0, 2, and 3 are actually known
known = [0, 2, 3]

In [53]:
adj_m = np.array(adj_m)
init_feats = np.array(init_features, dtype=np.double)
labels = np.array(labels)
known = np.array(known)

In [54]:
np.place(init_feats, init_feats==0, 1000)

In [55]:
init_feats

array([[   1.,    5.,   11., 1000., 1000., 1000.],
       [   5.,    1.,    6., 1000., 1000.,   13.],
       [  11.,    6.,    1.,    7.,    6., 1000.],
       [1000., 1000.,    7.,    1.,    3., 1000.],
       [1000., 1000.,    6.,    3.,    1.,    6.],
       [1000.,   13., 1000., 1000.,    6.,    1.]])

In [56]:
# normalizes each row
norm_init = 1/init_feats.sum(axis=0); norm_init

array([0.00033146, 0.00049383, 0.00096993, 0.00033212, 0.00049603,
       0.00033113])

In [57]:
norm_init_feats = init_feats * norm_init.reshape(-1,1); norm_init_feats

array([[3.31455088e-04, 1.65727544e-03, 3.64600597e-03, 3.31455088e-01,
        3.31455088e-01, 3.31455088e-01],
       [2.46913580e-03, 4.93827160e-04, 2.96296296e-03, 4.93827160e-01,
        4.93827160e-01, 6.41975309e-03],
       [1.06692532e-02, 5.81959263e-03, 9.69932105e-04, 6.78952473e-03,
        5.81959263e-03, 9.69932105e-01],
       [3.32115576e-01, 3.32115576e-01, 2.32480903e-03, 3.32115576e-04,
        9.96346729e-04, 3.32115576e-01],
       [4.96031746e-01, 4.96031746e-01, 2.97619048e-03, 1.48809524e-03,
        4.96031746e-04, 2.97619048e-03],
       [3.31125828e-01, 4.30463576e-03, 3.31125828e-01, 3.31125828e-01,
        1.98675497e-03, 3.31125828e-04]])

In [128]:
gcn = GCN(adj_m, labels, method='cat', nl=3, init_feats=norm_init_feats, train_test=known)

In [129]:
preds_init = gcn.predict(); preds_init # randomly initial predictions

tensor([0, 0, 3, 2, 1, 2])

In [66]:
gcn.train(25)

In [67]:
preds = gcn.predict(); preds # pretty much what intuition suggests !!

tensor([0, 0, 1, 1, 1, 1])

In [68]:
gcn.epoch_stats # loss over epochs

[tensor(8.3114, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(7.9898, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(7.6924, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(7.4108, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(7.1532, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(6.9108, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(6.6719, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(6.4493, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(6.2470, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(6.0633, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.8927, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.7345, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.5918, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.4625, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.3448, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(5.2372, dtype=torch.float64, grad_fn=<SumBackward0>),
 tensor(

# Continuous Dataset

In [101]:
adj_m = [
        [1,1,1,0,0,0],
        [1,1,1,0,0,1],
        [1,1,1,1,1,0],
        [0,0,1,1,1,0],
        [0,0,1,1,1,1],
        [0,1,0,0,1,1]
                    ]

In [102]:
init_features = [
                [1, 5, 11, 0, 0, 0],
                [5, 1, 6, 0, 0, 13],
                [11, 6, 1, 7, 6, 0],
                [0, 0, 7, 1, 3, 0],
                [0, 0, 6, 3, 1, 6],
                [0, 13, 0, 0, 6, 1]
                                    ]

In [157]:
labels = [-0.4, -0.2, 0, 0.1, 0.2, 0.3] # NOTE, only index 0, 2, and 3 are actually known
known = [0, 2, 3]

In [158]:
adj_m = np.array(adj_m)
init_feats = np.array(init_features, dtype=np.double)
labels = np.array(labels)
known = np.array(known)

In [159]:
np.place(init_feats, init_feats==0, 1000)

In [160]:
init_feats

array([[   1.,    5.,   11., 1000., 1000., 1000.],
       [   5.,    1.,    6., 1000., 1000.,   13.],
       [  11.,    6.,    1.,    7.,    6., 1000.],
       [1000., 1000.,    7.,    1.,    3., 1000.],
       [1000., 1000.,    6.,    3.,    1.,    6.],
       [1000.,   13., 1000., 1000.,    6.,    1.]])

In [161]:
# normalizes each row
norm_init = 1/init_feats.sum(axis=0); norm_init

array([0.00033146, 0.00049383, 0.00096993, 0.00033212, 0.00049603,
       0.00033113])

In [162]:
norm_init_feats = init_feats * norm_init.reshape(-1,1); norm_init_feats

array([[3.31455088e-04, 1.65727544e-03, 3.64600597e-03, 3.31455088e-01,
        3.31455088e-01, 3.31455088e-01],
       [2.46913580e-03, 4.93827160e-04, 2.96296296e-03, 4.93827160e-01,
        4.93827160e-01, 6.41975309e-03],
       [1.06692532e-02, 5.81959263e-03, 9.69932105e-04, 6.78952473e-03,
        5.81959263e-03, 9.69932105e-01],
       [3.32115576e-01, 3.32115576e-01, 2.32480903e-03, 3.32115576e-04,
        9.96346729e-04, 3.32115576e-01],
       [4.96031746e-01, 4.96031746e-01, 2.97619048e-03, 1.48809524e-03,
        4.96031746e-04, 2.97619048e-03],
       [3.31125828e-01, 4.30463576e-03, 3.31125828e-01, 3.31125828e-01,
        1.98675497e-03, 3.31125828e-04]])

In [204]:
gcn = GCN(adj_m, labels, method='cont', nl=3, init_feats=norm_init_feats, train_test=known)

In [205]:
preds_init = gcn.predict(); preds_init # randomly initial predictions

tensor([[-0.8765],
        [-0.8096],
        [-2.4902],
        [-2.9867],
        [-2.6371],
        [-0.9308]], dtype=torch.float64)

In [220]:
gcn.train(100)

In [221]:
preds_init = gcn.predict(); preds_init # not perfect but it captures the trend, only so much u can do w small dataset

tensor([[-0.5171],
        [-0.0588],
        [-0.4675],
        [-0.2559],
        [ 0.1674],
        [ 0.6249]], dtype=torch.float64)