In [None]:
# this notebook contains the 

In [1]:
import networkx as nx
import numpy as np
import torch_geometric
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.datasets import WikipediaNetwork, WebKB
from torch_geometric.utils import to_networkx
from torch_geometric.utils import k_hop_subgraph

import wandb

In [2]:
run_ball_attention = wandb.init(
    project = "BALL-Attention",
    config = {
        "architecture": "GCN+Ball-atten",
        "dataset":"chameleon",
        "epoch": 100,
        "lr": 0.001,
        "weight_decay":5e-4,
        "Batch size": 1,
    }
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msidgraph[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#dataset = WebKB(root="/home/siddy/META/data", name='wisconsin')
dataset = WikipediaNetwork(root="/home/siddy/META/data", name='chameleon')
data = dataset[0]

 loading planetoid datasets
import torch_geometric.transforms as T
transform = T.Compose([
    T.NormalizeFeatures(),
])

In [4]:
dataset.num_classes

5

In [5]:
data

Data(x=[2277, 2325], edge_index=[2, 36101], y=[2277], train_mask=[2277, 10], val_mask=[2277, 10], test_mask=[2277, 10])

In [6]:
data.x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [7]:
data.train_mask[:,0].shape

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


torch.Size([2277])

In [8]:
G = to_networkx(data, to_undirected=True)

In [9]:
def init_ball(radius: int, graph):
    edge_dict = {}
    for index, node in enumerate(graph.nodes()):
        paths = nx.single_source_shortest_path(graph, node, radius)
        if index not in edge_dict:
            edge_dict[index] = []
        for key, value in paths.items():
            if len(value) == 2:
                edge_dict[index].append(value)
            elif len(value)==3:
                edge_dict[index].append(value[1:])
    return edge_dict

In [10]:
edge_dict= init_ball(radius=2, graph=G)

In [11]:
class Edge_atten(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads=1, concat_heads=True, alpha=0.2):
        super().__init__()
        self.num_heads=num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert out_channels % num_heads==0, "number of output channels must be multiple of count of heads"
            out_channels = out_channels // num_heads

        self.linear = nn.Linear(in_channels, out_channels*num_heads)
        self.a = nn.Parameter(torch.Tensor(num_heads, 2*out_channels))
        self.leakyrelu = nn.LeakyReLU(alpha)

        #xavier uniform initialization
        nn.init.xavier_uniform_(self.linear.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
    
    def forward(self, node_feats, edge_index):
        node_feats = torch.unsqueeze(node_feats, dim=0)
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)
        node_feats = self.linear(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)
        node_feats_flat = node_feats.view(batch_size*num_nodes, self.num_heads, -1)
        edge_indices_row = edge_index[0]
        edge_indices_col = edge_index[1]
        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0)
        ], dim=-1)
        attn_logits = torch.einsum('bhc, hc->bh', a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)
        attn_probs = F.softmax(attn_logits, dim=-2)
        return attn_probs

In [12]:
class BallGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index, edge_weight:None):
        # what is the shape of inpur x ? - needed [N, in_channels]
        # edge indices shape needed is [2, E]

        #add self_loops to the adjacency matrix, how to give num nodes?
        #edge_index, _ = add_self_loops(edge_index)
        #print(edge_index)
        # linearly transform node feature matrix
        x = self.lin(x)
        #x = torch.index_select(input=x, index=edge_index[0], dim=0)
        # x_ball = torch.cat([torch.index_select(input=x, index=edge_index[0], dim=0), NOTE THAT IT WILL GIVE INDEX OUT OF RANGE ONE OPTION IS TO GO WITH REINDEXING
        #             torch.index_select(input=x, index=edge_index[1], dim=0)],dim=0)
        #compute normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(0.5)
        deg_inv_sqrt[deg_inv_sqrt==float('inf')] = 0
        #print(deg_inv_sqrt.shape)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # propagating messages
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, norm=norm)
        out = torch.index_select(input=out, index=min(edge_index[0]), dim=0) #NOTE TRICK IS TO PICK MIN EDGE INDEX AS IT WILL CORRESPOND TO THE CENTER NODE OF THE BALL
        # bias
        out += self.bias
        return torch.squeeze(out)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # normalize node features
        return norm.view(-1,1) *x_j

Exception in thread SystemMonitor:
Traceback (most recent call last):
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/siddy/anaconda3/envs/GDL/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/system_monitor.py", line 118, in _start
    asset.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/cpu.py", line 166, in start
    self.metrics_monitor.start()
  File "/home/siddy/.local/lib/python3.8/site-packages/wandb/sdk/internal/system/assets/interfaces.py", line 168, in start
    logger.info(f"Started {self._process.name}")
AttributeError: 'NoneType' object has no attribute 'name'


In [13]:
class BallGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = BallGCNConv(in_channels, hidden_channels)
        self.fc = Linear(hidden_channels, out_channels)
    def forward(self, x, edge_index, edge_weight):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc(self.conv1(x, edge_index, edge_weight))
        return x

In [14]:
edge_attn = Edge_atten(in_channels=data.x.size(1), out_channels=64)
ball_gcn = BallGCN(in_channels=data.x.size(1), hidden_channels=64, out_channels=dataset.num_classes)
msg = nn.Linear(in_features=data.x.size(1), out_features=dataset.num_classes, bias=False)

In [15]:
optimizer = torch.optim.Adam(set(edge_attn.parameters())| set(ball_gcn.parameters()) | set(msg.parameters()), lr=0.001, weight_decay=5e-4)

In [16]:
def train():
    edge_attn.train()
    ball_gcn.train()
    feats = []
    for node, ball in edge_dict.items():
        if len(ball)<1:
            index_feat = msg(data.x[node])
            #print(index_feat)
            #print(index_feat.shape)
            # np.vstack((feats, np.array(index_feat.detach())))
            feats.append(index_feat)
        else:
            edge_list = torch.permute(torch.tensor(ball, dtype=torch.long), (1,0))
            attn_probs = edge_attn(data.x, edge_list)
            out = ball_gcn(data.x, edge_list, attn_probs)
            # print(out)
            # print(out.shape)
            # np.vstack((feats, np.array(out.detach())))
            feats.append(out)
    feat = torch.stack(feats, -2)
    #print(feat)
    print(feat.shape)
    loss = F.cross_entropy(feat[data.train_mask[:,0]], data.y[data.train_mask[:,0]])
    loss.backward()
    optimizer.step()
    return float(loss)



In [17]:
@torch.no_grad()
def test():
    edge_attn.eval()
    ball_gcn.eval()
    feats = []
    for node, ball in edge_dict.items():
        if len(ball)<1:
            index_feat = msg(data.x[node]).argmax(dim=-1)
            feats.append(index_feat)
        else:
            edge_list = torch.permute(torch.tensor(ball, dtype=torch.long), (1,0))
            attn_probs = edge_attn(data.x, edge_list)
            out = ball_gcn(data.x, edge_list, attn_probs).argmax(dim=-1)
            feats.append(out)
    feat = torch.stack(feats, -1)
    print(feat.shape)
    accs=[]
    for mask in [data.train_mask[:,0], data.val_mask[:,0], data.test_mask[:,0]]:
        accs.append(int((feat[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs   
    

In [18]:
import time
epochs=100
best_val_acc = final_test_acc = 0
times = []
for epoch in range(1, epochs+1):
    start = time.time()
    loss= train()
    wandb.log({"Loss":loss})
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f"Epoch:{epoch}, Loss:{loss}, Train:{train_acc}, Val:{val_acc}, Test:{test_acc}")
    wandb.log({"train_acc":train_acc})
    wandb.log({"test_acc":test_acc})
    wandb.log({"val_acc":val_acc})
    times.append(time.time()-start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")

: 

In [None]:
run_ball_attention.finish()

0,1
Loss,█▇▇▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁▂▂▂▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄██████████████
train_acc,▁▂▃▃▇███████████████████████████████████
val_acc,▁▃▃▄▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆██████████████

0,1
Loss,1.10178
test_acc,0.60784
train_acc,0.525
val_acc,0.55
