# Hierarchical Graph Pooling

This is a notebook exploring the architecture described in https://arxiv.org/pdf/1911.05954.pdf

In [8]:
import sys

import numpy as np
import torch

sys.path.append("..")

In [19]:
from lightning_modules.utils import make_mlp
from torch_scatter import scatter_mean

## Roadmap

1. Set up fake graph, with node features $x$ and edge indices $e$
2. Set up node network $M_n$ as $d \rightarrow d$ MLP, and edge attention network $M_e$ as $2d \rightarrow 1$ MLP
3. Make Manhattan distance function $D$
4. Make node information score function $p$
5. Rank top-k most informative nodes
-- Stop here and test that this trains!

### 1. Set up fake graph

In [39]:
x = torch.randint(0, 5, (40, 3)).float()
e = torch.randint(0, len(x), (2, 1000))

In [40]:
hidden_channels = 32

### 2. Set up networks

In [41]:
input_network = make_mlp(x.shape[1], [hidden_channels]*3,
             hidden_activation='ReLU',
             output_activation='ReLU',
             layer_norm=True)

node_network = make_mlp(hidden_channels, [hidden_channels]*3,
             hidden_activation='ReLU',
             output_activation='ReLU',
             layer_norm=True)

edge_network = make_mlp(2*x.shape[1], [hidden_channels]*3+[1],
             hidden_activation='ReLU',
             output_activation=None)

In [42]:
x = input_network(x)

### 3. Make distance function

In [43]:
src = torch.cat([e, e.flip(1)], axis=-1)
reconstructed_x = scatter_mean(x[src[0]], src[1], dim=0, dim_size=x.shape[0])

In [61]:
D = torch.abs(x - reconstructed_x)

### 4. Sum distance for node information score

In [62]:
p = torch.sum(D, dim=1)

In [63]:
p

tensor([ 4.1535,  9.7257,  5.0085,  6.2467,  7.0837,  4.9354,  4.7244,  6.7053,
         4.6235,  6.7910,  6.7414,  9.4983,  8.7384,  5.6556,  8.0180,  4.7247,
         4.6897,  7.0028,  3.8158,  5.8215,  7.2737, 12.0679,  6.2713, 10.5605,
         4.5469,  9.0993,  4.1239,  8.4504,  8.0633,  9.1488,  5.8921,  8.5343,
         8.1766,  7.5356,  8.8293,  5.1800,  4.7730,  5.8813,  4.2438,  8.2299],
       grad_fn=<SumBackward1>)

### 5. Rank most informative nodes

In [68]:
rank_idx = torch.argsort(p, descending=True)

In [70]:
top_k = 10
top_k_idx = rank_idx[top_k]