In [2]:
import pandas as pd
import numpy as np
import networkx as nx
import torch
from torch import nn
import spektral
import hashlib
from spektral.datasets.citation import Citation
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import normalize

In [353]:
from math import cos


class CoraDataset(Dataset):
    def __init__(self, path='data/citeseer/', context_size=10, with_wl=True, wl_iterations=2, cutoff=99) -> None:
        """
        Args:
            path (str, optional): path containing the Cora dataset. Defaults to 'data/citeseer/'.
            context_size (int, optional): number of nodes in a target node context, represent the
                                          topk nodes sorted by intimacy score. Defaults to 10.
        """
        super().__init__()

        self.context_size = context_size
        self.with_wl = with_wl
        self.wl_iterations = wl_iterations
        self.cutoff = cutoff

        # load features and labels
        self.raw_features = pd.read_csv(f"{path}/citeseer.content", sep='\t', header=None, index_col=0)
        nodes_name_list = self.raw_features.index.astype(str)

        # dictionaries to map names and ids
        self.name_to_id = {name:idx for idx,name in enumerate(nodes_name_list)}
        self.id_to_name = {idx:name for idx,name in enumerate(nodes_name_list)}

        # cast index of dataframe to ids
        self.raw_features.index = self.name_to_id.values()
        self.raw_features = self.raw_features.sort_index()

        # save feature matrix and labels
        self.labels = self.raw_features.iloc[:,-1].astype('category').cat.codes
        self.raw_features = self.raw_features.iloc[:,:-1]

        # input dimensions
        self.raw_features_size = self.raw_features.shape[1]
        self.n = len(self.raw_features.index)

        # swap nodes on each row as they are listed as target-source
        df = pd.read_csv(f"{path}/citeseer.cites", sep='\t', header=None)
        self.edge_list = list(filter(lambda x: x[0] in nodes_name_list and x[1] in nodes_name_list, df.values.tolist()))
        self.edge_list = [(self.name_to_id[x[1]], self.name_to_id[x[0]]) for x in self.edge_list]
        self.node_list = self.raw_features.index.tolist()
        
        
        # distances between nodes
        G = nx.from_edgelist(self.edge_list, create_using=nx.Graph)
        self.distance_matrix = dict(nx.all_pairs_shortest_path_length(G), cutoff=cutoff)

        # pre-process graph to make data loader more efficient
        self.build_intimacy_matrix()
        self.build_contexts()

        self.wl_colors = None
        if with_wl:
            self.build_wl_coloring()
        self.max_wl_colors = max(self.wl_colors.values())

    def build_intimacy_matrix(self, alpha=0.15):
        # create adjacency matrix
        n = self.n
        adj_mat = np.zeros((n,n))
        name_to_id = self.name_to_id

        for x,y in self.edge_list:
            adj_mat[x, y] = 1
            adj_mat[y, x] = 1
        
        # compute inverse of diagonal degrees matrix
        dinv_mat = np.nan_to_num(np.diag(adj_mat.sum(axis=0)))

        # compute final matrix, for details see (1) in Graph-BERT by Zhang et al. '20 page 3
        Abar = np.matmul(adj_mat, dinv_mat)
        I = np.diag(np.ones(n))
        M = np.linalg.inv(I - (1-alpha)*Abar)

        self.intimacy_mat = alpha*M
    
    def build_contexts(self):
        context_list = np.zeros((self.n, self.context_size+1), dtype=np.int)

        # always include target node into its context
        context_list[:,0] = range(self.n)

        sorted_context_list = (-self.intimacy_mat).argsort(axis=1)

        # context of a node contains the topk nodes sorted by intimacy score
        for i in range(self.n):
            mask = [x in self.raw_features.index for x in sorted_context_list[i]]  # ids with a raw feature vector
            context_list[i,1:] = sorted_context_list[i][mask][:self.context_size]
        
        self.context_list = context_list
    
    def build_wl_coloring(self):
        G = nx.from_edgelist(self.edge_list, create_using=nx.Graph)

        # initialize node colors
        wl_colors = {node:1 for node in G.nodes}

        for _ in range(self.wl_iterations):
            for node in sorted(G.nodes):
                # combine colors from neighbors
                code_list = [wl_colors[node]] + [wl_colors[x] for x in sorted(G.neighbors(node))]
                code = "".join(map(str, code_list))

                # update node code
                wl_colors[node] = hashlib.md5(code.encode()).hexdigest()
        
        color_to_num = {color:i for i,color in enumerate(set(wl_colors.values()))}
        wl_colors = {node:color_to_num[c] for node,c in wl_colors.items()}
        self.wl_colors = wl_colors
    
    def __len__(self):
        """ each node together with its context represent an instance of the graph """
        return self.n
    
    def __getitem__(self, idx):
        # ids of nodes in the context
        context = self.context_list[idx,:]

        n = len(context)
        
        # raw feature vector embedding
        X = torch.tensor([self.raw_features.loc[i].tolist() for i in context])

        # Weisfeiler-Lehman absolute role embedding
        C = torch.tensor([self.wl_colors[i] for i in context])

        # intimacy based relative positional embedding
        I = torch.tensor(range(n))

        # hop based relative distance embedding
        source = context[0]
        keys = self.distance_matrix[source].keys()
        distances = [self.distance_matrix[source][i] if i in keys else self.cutoff for i in context]
        H = torch.tensor(distances)

        # labels
        y = torch.tensor([self.labels[i] for i in context])
        
        return X, C, I, H, y

In [354]:
cora = CoraDataset()

In [66]:
def position_embed(values, hidden_size):
    E = torch.zeros((len(values), hidden_size))

    for i,val in enumerate(values):
        for j in range(hidden_size//2):
            E[i,j] = np.math.sin(val/(10000**((2*i)/hidden_size)))
            E[i,j+1] = np.math.cos(val/10000**((2*i+1)/hidden_size))
    
    return E

In [94]:
class BertEmbeddings(nn.Module):
    def __init__(self, config: GraphBertConfig):
        super().__init__()
        self.raw_features_embeddings = nn.Linear(config.x_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    
    def forward(self, raw_features, wl_role_ids, init_pos_ids, hop_dis_ids):
        raw_feature_embeds = self.raw_features_embeddings(raw_features.type(torch.FloatTensor))
        role_embeddings = position_embed(wl_role_ids, config.hidden_size)
        position_embeddings = position_embed(init_pos_ids, config.hidden_size)
        hop_embeddings = position_embed(hop_dis_ids, config.hidden_size)

        embeddings = raw_feature_embeds + role_embeddings + position_embeddings + hop_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

In [95]:
bert_embed = BertEmbeddings(config)

In [100]:
X, C, I, H, y = cora[0]
x = bert_embed(X, C, I, H)

In [130]:
class GraphTransformer(nn.Module):
    def __init__(self, hidden_size=32):
        super().__init__()

        self.hidden_size = hidden_size

        q = torch.Tensor(hidden_size, hidden_size)
        k = torch.Tensor(hidden_size, hidden_size)
        v = torch.Tensor(hidden_size, hidden_size)

        self.q = nn.Parameter(q)
        self.k = nn.Parameter(k)
        self.v = nn.Parameter(v)

        nn.init.kaiming_uniform_(self.q, a=np.sqrt(5))
        nn.init.kaiming_uniform_(self.k, a=np.sqrt(5))
        nn.init.kaiming_uniform_(self.v, a=np.sqrt(5))
    
    def forward(self, x):
        q = torch.matmul(x, self.q)
        k = torch.matmul(x, self.k)
        v = torch.matmul(x, self.v)
        out = torch.matmul(q, k.T)
        
        out = torch.softmax(out/np.sqrt(self.hidden_size), dim=1)
        out = torch.matmul(out, v)

        return out

In [127]:
graph_transformer = GraphTransformer(config)

In [135]:
class GraphBert(nn.Module):
    def __init__(self, num_layers=2, hidden_size=32):
        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.graph_transformer_layers = [GraphTransformer(hidden_size=hidden_size) for _ in range(num_layers)]

    def forward(self, x):
        for i in range(self.num_layers):
            x = self.graph_transformer_layers[i](x)
        
        return x.mean(axis=0)

In [136]:
graph_bert = GraphBert()

In [139]:
class NodeRawAttributeReconstruction(nn.Module):
    def __init__(self, features_dim, num_layers=2, hidden_size=32):
        super().__init__()

        self.graph_bert = GraphBert(num_layers=num_layers, hidden_size=hidden_size)
        self.fc = nn.Linear(hidden_size, features_dim)
        self.act = nn.ReLU()
    
    def forward(self, x):
        x = self.graph_bert(x)
        x = self.fc(x)
        return self.act(x)

In [141]:
node_recons = NodeRawAttributeReconstruction(config.x_size)

In [161]:
num_epochs = 10

for epoch in range(num_epochs):
    for i,data in enumerate(cora):
        X, C, I, H, y = data

0 [   0 3300 1272 1023    0  929 1493  324  788 2884 2668]
1 [   1    1  284 1140 1319 2831 2436 1281 3061 2211 2212]
2 [   2    2    0 2209 2210 2211 2212 2213 2214 2215 2216]
3 [   3 3135 1931 1013  404   81 1704  381 3169 1032 2118]
4 [   4    4 1594  468 3096 2212 2213 2214 2215 2216 2217]
5 [   5  404 1931 1717 3135 2118 2964  998 2152 2106 2311]
6 [   6    6    0 2212 2213 2214 2215 2216 2217 2218 2219]
7 [   7 3135  811 2118   81 1013  488 1931  183    7 2106]
8 [   8  344 1362 2067 1447 2251 2627 3300 1931 1023 1889]
9 [   9 1570 2645 2636  128 1535 2062  177 1472  141 3296]
10 [  10   10 1930  666 1237 2118  596  488 3135   81  811]
11 [  11 1272 1023 3300  929  788 1493  324  489  190 2668]
12 [  12   12 2118   81 2361 1032  183  967 3135 1948 2106]
13 [  13   13    0 2211 2212 2213 2214 2215 2216 2217 2218]
14 [  14 3122 1665 2118   81  183 3135 1032  811 2106  488]
15 [  15  101 2354 2202  809 2887 1362  162 1021   15 2203]
16 [  16   16 1570 2645 2118  401   81 2190 1032 1

KeyError: 'raisamo99evaluating'

In [157]:
cora.distance_matrix[1]

{1: 0,
 978: 1,
 2831: 1,
 544: 2,
 1281: 2,
 2436: 2,
 602: 2,
 3061: 3,
 284: 3,
 1319: 4,
 1140: 4}