In [1]:
import os

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch_geometric
from torch_geometric.datasets import WebKB
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx, k_hop_subgraph

from atten_layer import ATLayer

In [2]:
# Dataset cache directory. The default is the original machine-specific path;
# set the WEBKB_ROOT environment variable to run the notebook elsewhere.
DATA_ROOT = os.environ.get("WEBKB_ROOT", "/home/siddy/META/data")
dataset = WebKB(root=DATA_ROOT, name='cornell')
# NOTE(review): `data` rebinds the name used for `import torch.utils.data as data`
# in the import cell; later cells rely on this name being the graph object.
data = dataset[0]

In [3]:
data

Data(x=[183, 1703], edge_index=[2, 298], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10])

In [4]:
# NetworkX graph G: undirected conversion of the PyG `data` object
G = to_networkx(data, to_undirected=True)

In [5]:
def init_ball(radius: int, graph):
    """Collect, for every node, the edges of its shortest-path ball.

    For each center node the shortest paths to all nodes within ``radius``
    hops are computed, and the last edge of every non-trivial path is kept.
    With ``radius=2`` this yields the ``[center, neighbour]`` edges followed
    by the ``[neighbour, two_hop]`` edges, in the order the paths are returned.

    Args:
        radius: Maximum number of hops from the center (passed as ``cutoff``).
        graph: A NetworkX graph.

    Returns:
        dict mapping each center node to a list of ``[u, v]`` edges. Nodes with
        no neighbours map to an empty list.
    """
    edge_dict = {}
    for node in graph.nodes():
        paths = nx.single_source_shortest_path(graph, node, cutoff=radius)
        # Key by the node itself (identical to the positional index when nodes
        # are 0..N-1) and keep the final hop of every path, so radii larger
        # than 2 are no longer silently truncated. The single-node path to the
        # center itself carries no edge and is skipped.
        edge_dict[node] = [path[-2:] for path in paths.values() if len(path) >= 2]
    return edge_dict

In [6]:
# Build the 2-hop ball edge lists for every node of G.
edge_dict= init_ball(radius=2, graph=G)

In [29]:
edge_dict # NOTE: keys are the center nodes; each value lists the [u, v] edges of that node's ball

{0: [[0, 101], [0, 122], [101, 8], [101, 20], [101, 109]],
 1: [[1, 27], [27, 74], [27, 97]],
 2: [[2, 75], [2, 130], [75, 25], [75, 66]],
 3: [],
 4: [],
 5: [],
 6: [[6, 109],
  [6, 149],
  [109, 8],
  [109, 20],
  [109, 57],
  [109, 101],
  [109, 122],
  [149, 165]],
 7: [],
 8: [[8, 28],
  [8, 101],
  [8, 109],
  [8, 122],
  [8, 158],
  [28, 21],
  [28, 24],
  [28, 47],
  [28, 84],
  [101, 0],
  [101, 20],
  [109, 6],
  [109, 57],
  [158, 31],
  [158, 65],
  [158, 147]],
 9: [[9, 148], [148, 57], [148, 67]],
 10: [[10, 142], [10, 154], [10, 166], [154, 57]],
 11: [],
 12: [],
 13: [[13, 150], [150, 66], [150, 89], [150, 110], [150, 146]],
 14: [],
 15: [],
 16: [],
 17: [],
 18: [[18, 96], [96, 57], [96, 70], [96, 121], [96, 173]],
 19: [[19, 103], [103, 57], [103, 83]],
 20: [[20, 47],
  [20, 101],
  [20, 109],
  [20, 158],
  [47, 24],
  [47, 28],
  [101, 0],
  [101, 8],
  [101, 122],
  [109, 6],
  [109, 57],
  [158, 31],
  [158, 65],
  [158, 147]],
 21: [[21, 28],
  [21, 41],
  [

In [7]:
class Edge_atten(nn.Module):
    """Attention scorer that assigns a probability to every edge it is given.

    Node features are linearly projected and split into ``num_heads`` heads.
    For each edge the projected features of its two endpoints are concatenated
    and reduced to one logit per head with the learnable vector ``a``; the
    logits go through LeakyReLU and a softmax taken across the edges.
    """

    def __init__(self, in_channels, out_channels, num_heads=1, concat_heads=True, alpha=0.2):
        """
        Args:
            in_channels: Width of the input node features.
            out_channels: Projected width; when ``concat_heads`` is true it is
                the total over all heads and must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            concat_heads: If true, ``out_channels`` is divided among the heads.
            alpha: Negative slope of the LeakyReLU applied to the logits.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert out_channels % num_heads==0, "number of output channels must be multiple of count of heads"
            out_channels = out_channels // num_heads

        per_head = out_channels
        self.linear = nn.Linear(in_channels, per_head * num_heads)
        # One attention vector per head, applied to [source || target] features.
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * per_head))
        self.leakyrelu = nn.LeakyReLU(alpha)

        # Xavier-uniform initialisation (projection weight first, then `a`).
        for param in (self.linear.weight, self.a):
            nn.init.xavier_uniform_(param.data, gain=1.414)

    def forward(self, node_feats, edge_index):
        """Return attention probabilities of shape ``[num_edges, num_heads]``.

        Args:
            node_feats: Node feature matrix; a leading batch dimension of size
                one is added internally.
            edge_index: Tensor whose first row holds source node indices and
                whose second row holds target node indices.

        Returns:
            Tensor in which, for each head, the values over all supplied edges
            sum to one.
        """
        batched = torch.unsqueeze(node_feats, dim=0)
        batch_size = batched.size(0)
        num_nodes = batched.size(1)

        projected = self.linear(batched).view(batch_size, num_nodes, self.num_heads, -1)
        per_node = projected.view(batch_size * num_nodes, self.num_heads, -1)

        sources, targets = edge_index[0], edge_index[1]
        endpoint_feats = [
            torch.index_select(input=per_node, index=sources, dim=0),
            torch.index_select(input=per_node, index=targets, dim=0),
        ]
        pair_feats = torch.cat(endpoint_feats, dim=-1)

        scores = torch.einsum('bhc, hc->bh', pair_feats, self.a)
        activated = self.leakyrelu(scores)
        # dim=-2 is the edge axis of the [edges, heads] logits.
        return F.softmax(activated, dim=-2)

In [45]:
# Ball of center node 20 as a long tensor with sources in row 0 and targets in row 1.
ball_edges = torch.tensor(edge_dict[20], dtype=torch.long)
edge_list = ball_edges.permute(1, 0)

In [46]:
edge_list

tensor([[ 20,  20,  20,  20,  47,  47, 101, 101, 101, 109, 109, 158, 158, 158],
        [ 47, 101, 109, 158,  24,  28,   0,   8, 122,   6,  57,  31,  65, 147]])

In [47]:
# Attention scorer with the default single head; the input width is the node feature dimension.
edge_atten = Edge_atten(in_channels=data.x.size(1), out_channels=64)

In [48]:
# Score the edges of the selected ball with the attention module.
atten_probs = edge_atten(data.x, edge_list)
# Displayed with singleton dimensions removed and detached from autograd.
torch.squeeze(atten_probs).detach()

tensor([0.0748, 0.0634, 0.0737, 0.0570, 0.0767, 0.0630, 0.1253, 0.0733, 0.0811,
        0.0603, 0.0617, 0.0636, 0.0656, 0.0606])

In [49]:
# Convert the [2, E] edge tensor into a list of (u, v) tuples of plain Python ints.
# The tensor check makes the cell safe to re-run: once `edge_list` is already a
# list, the conversion is skipped instead of failing on `list.t()`.
if torch.is_tensor(edge_list):
    edge_list = [tuple(pair) for pair in edge_list.t().tolist()]

In [50]:
edge_list

[(20, 47),
 (20, 101),
 (20, 109),
 (20, 158),
 (47, 24),
 (47, 28),
 (101, 0),
 (101, 8),
 (101, 122),
 (109, 6),
 (109, 57),
 (158, 31),
 (158, 65),
 (158, 147)]

In [51]:
# Attention probabilities as a NumPy array, singleton dimensions removed and detached from autograd.
weights = torch.squeeze(atten_probs).detach().numpy()

In [52]:
# Attach each edge's attention weight as a third tuple element: (u, v, weight).
weighted_edge_list = [
    edge + (weights[position],)
    for position, edge in enumerate(edge_list)
]
print(weighted_edge_list)

[(20, 47, 0.07479365), (20, 101, 0.0633591), (20, 109, 0.07370044), (20, 158, 0.05704129), (47, 24, 0.07674385), (47, 28, 0.06300399), (101, 0, 0.1252799), (101, 8, 0.0733241), (101, 122, 0.08106825), (109, 6, 0.0602759), (109, 57, 0.061650183), (158, 31, 0.06355356), (158, 65, 0.065591834), (158, 147, 0.060613867)]


In [53]:
weighted_edge_list

[(20, 47, 0.07479365),
 (20, 101, 0.0633591),
 (20, 109, 0.07370044),
 (20, 158, 0.05704129),
 (47, 24, 0.07674385),
 (47, 28, 0.06300399),
 (101, 0, 0.1252799),
 (101, 8, 0.0733241),
 (101, 122, 0.08106825),
 (109, 6, 0.0602759),
 (109, 57, 0.061650183),
 (158, 31, 0.06355356),
 (158, 65, 0.065591834),
 (158, 147, 0.060613867)]

In [54]:
# Undirected graph of the ball, built from the (u, v, weight) triples.
# NOTE(review): the weights are attention probabilities; used as path costs below,
# a higher attention means a longer distance — confirm this is intended.
g_ball = nx.Graph()
g_ball.add_weighted_edges_from(weighted_edge_list)

In [57]:
# Weighted shortest-path distances and paths from center node 20 to every reachable node of the ball.
length, path = nx.single_source_bellman_ford(g_ball, 20, target=None, weight='weight')

In [58]:
length

{20: 0,
 47: 0.07479365170001984,
 101: 0.0633590966463089,
 109: 0.0737004429101944,
 158: 0.057041291147470474,
 24: 0.15153750032186508,
 28: 0.13779763877391815,
 0: 0.18863900005817413,
 8: 0.13668319582939148,
 122: 0.1444273442029953,
 6: 0.13397634401917458,
 57: 0.13535062596201897,
 31: 0.12059484794735909,
 65: 0.12263312563300133,
 147: 0.11765515804290771}

In [44]:
path

{1: [1],
 27: [1, 27],
 74: [1, 27, 74],
 97: [1, 27, 97],
 57: [1, 27, 74, 57],
 118: [1, 27, 97, 118],
 152: [1, 27, 97, 152],
 165: [1, 27, 97, 165]}

In [27]:
# Mean of all shortest-path distances in `length` (the source's own 0 distance is part of the average).
mean_r = np.mean(list(length.values()))

In [28]:
mean_r

0.2563408985733986

{0: [0],
 101: [0, 101],
 122: [0, 122],
 8: [0, 101, 8],
 20: [0, 101, 20],
 109: [0, 101, 109]}

[(0, 1), (3, 4)]

In [18]:
from torch_geometric.nn import MessagePassing

class BallGCNConv(MessagePassing):
    """GCN-style convolution over one ball that returns only the center's row.

    Node features are linearly transformed, propagated along ``edge_index``
    with symmetric degree normalisation (optionally scaled by per-edge
    weights), summed at the target nodes, and the row of the center node is
    returned with a bias added.

    NOTE(review): messages travel from ``edge_index[0]`` to ``edge_index[1]``
    under MessagePassing's default flow and no self-loops are added. If every
    edge of the ball points away from the center, the center never appears as
    a target and its aggregated row is zero before the bias — confirm the
    intended edge orientation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        # `Linear` / `Parameter` were referenced without being imported; use
        # the `nn` namespace that the notebook already imports.
        self.lin = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer and zero the bias."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, edge_weight=None, center=None):
        """Run one propagation step and return the center node's features.

        Args:
            x: Node feature matrix of shape ``[N, in_channels]``.
            edge_index: Edge indices of shape ``[2, E]``, indexing rows of ``x``.
            edge_weight: Optional per-edge weights with ``E`` elements in total
                (e.g. ``[E]`` or ``[E, 1]``). ``None`` means unit weights.
            center: Optional index of the ball's center node. When omitted, the
                smallest source index is used, as before; that heuristic picks
                the wrong node whenever a non-center source has a smaller id
                than the center, so pass ``center`` explicitly when known.

        Returns:
            Tensor of shape ``[out_channels]`` (singleton dimensions squeezed).

        Raises:
            ValueError: If ``edge_weight`` does not have one value per edge, or
                if ``edge_index`` is empty and ``center`` is not given.
        """
        # Linearly transform the node feature matrix.
        x = self.lin(x)

        row, col = edge_index
        num_edges = row.numel()
        if edge_weight is None:
            edge_weight = torch.ones(num_edges, dtype=x.dtype, device=x.device)
        else:
            edge_weight = edge_weight.reshape(-1).to(dtype=x.dtype, device=x.device)
            if edge_weight.numel() != num_edges:
                raise ValueError(
                    f"edge_weight has {edge_weight.numel()} values but edge_index has {num_edges} edges"
                )

        # Symmetric normalisation: (weighted) in-degree to the power -1/2.
        # With unit weights `deg` is the plain in-degree of each node.
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg = deg.scatter_add_(0, col, edge_weight)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes that are never a target have degree 0 -> inf; zero them out.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

        # Propagate messages; edge weights are already folded into `norm`.
        out = self.propagate(edge_index, x=x, norm=norm)

        if center is None:
            if num_edges == 0:
                raise ValueError("edge_index is empty; pass `center` explicitly")
            center = edge_index[0].min()
        center_index = torch.as_tensor(center, dtype=torch.long, device=out.device).reshape(1)
        out = torch.index_select(input=out, index=center_index, dim=0)

        # Bias (out-of-place so the selected row is not modified in place).
        out = out + self.bias
        return torch.squeeze(out)

    def message(self, x_j, norm):
        """Scale each source feature row by its edge's normalisation factor."""
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j

In [19]:
class BallGCN(nn.Module):
    """Ball-level model: one BallGCNConv layer followed by a linear layer.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of the convolution.
        out_channels: Output width of the final linear layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = BallGCNConv(in_channels, hidden_channels)
        # `Linear` was referenced without being imported; use nn.Linear.
        self.fc = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply dropout (p=0.5, training mode only), the ball convolution, then the linear layer.

        Args:
            x: Node feature matrix.
            edge_index: Edge indices of the ball.
            edge_weight: Optional per-edge weights forwarded to the convolution.
        """
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(self.conv1(x, edge_index, edge_weight))
        return x

In [57]:
# Full node feature matrix of the dataset graph.
x= data.x
print(x.shape)

torch.Size([183, 1703])


In [65]:
x_ball.shape

torch.Size([10, 1703])

In [61]:
x_ball

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [13]:
edge_list_0[0]

tensor([  0,   0, 101, 101, 101])