# Train an All-Set TNN

In this notebook, we create and train AllSet (Chien et al., [2021](https://arxiv.org/abs/2106.13264)), a two-step message passing network in the hypergraph domain. We use Cora, a benchmark citation dataset, lift its graph to a hypergraph, and train the model to perform node-level classification.

Vertex to edge:

🟧 $\quad m_{\rightarrow z}^{(\rightarrow 1)} = AGG_{y \in \mathcal{B}(z)} (h_y^{t, (0)}, h_z^{t,(1)})$

🟦 $\quad h_z^{t+1,(1)} = \sigma(m_{\rightarrow z}^{(\rightarrow 1)})$ 

Edge to vertex: 

🟧 $\quad m_{\rightarrow x}^{(\rightarrow 0)} = AGG_{z \in \mathcal{C}(x)} (h_z^{t+1,(1)}, h_x^{t,(0)})$ 

🟦 $\quad h_x^{t+1,(0)} = \sigma(m_{\rightarrow x}^{(\rightarrow 0)})$
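
To make the two steps concrete, here is a minimal, dependency-free sketch on a toy hypergraph. It assumes sum aggregation for $AGG$ and ReLU for $\sigma$, uses scalar features, and ignores the second argument of $AGG$. The names `hyperedges`, `node_feats`, and `relu` are illustrative and are not part of TopoModelX.

```python
# Toy hypergraph: 4 nodes, 2 hyperedges (illustrative values, not from Cora).
hyperedges = [[0, 1, 2], [2, 3]]
node_feats = [1.0, 2.0, 3.0, 4.0]  # h_x^{t,(0)}, one scalar per node


def relu(value):
    """Assumed choice for the nonlinearity sigma."""
    return max(0.0, value)


# Vertex to edge: each hyperedge z aggregates (here: sums) its member nodes.
edge_feats = [relu(sum(node_feats[x] for x in members)) for members in hyperedges]

# Edge to vertex: each node x aggregates the updated hyperedges that contain it.
new_node_feats = [
    relu(sum(edge_feats[z] for z, members in enumerate(hyperedges) if x in members))
    for x in range(len(node_feats))
]

print(edge_feats)      # [6.0, 7.0]
print(new_node_feats)  # [6.0, 6.0, 13.0, 7.0]
```

Node 2 belongs to both hyperedges, so it is the only node that receives messages from both.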

### Additional theoretical clarifications
Given a hypergraph $G=(\mathcal{V}, \mathcal{E})$, let $\textbf{X} \in \mathbb{R}^{|\mathcal{V}| \times F}$ and $\textbf{Z} \in \mathbb{R}^{|\mathcal{E}| \times F'}$ denote the hidden node and hyperedge representations, respectively. Additionally, define $V_{e, \textbf{X}} = \{\textbf{X}_{u,:}: u \in e\}$ as the multiset of hidden node representations in the hyperedge $e$ and $E_{v, \textbf{Z}} = \{\textbf{Z}_{e,:}: v \in e\}$ as the multiset of hidden representations of hyperedges containing $v$.

In this setting, the two general update rules that AllSet's framework puts in place in each layer are:

🔷 $\textbf{Z}_{e,:}^{(t+1)} = f_{\mathcal{V} \rightarrow \mathcal{E}}(V_{e, \textbf{X}^{(t)}}; \textbf{Z}_{e,:}^{(t)})$

🔷 $\textbf{X}_{v,:}^{(t+1)} = f_{\mathcal{E} \rightarrow \mathcal{V}}(E_{v, \textbf{Z}^{(t+1)}}; \textbf{X}_{v,:}^{(t)})$

in which $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are two functions that are permutation-invariant with respect to their first input. The representations $\textbf{Z}_{e,:}^{(0)}$ and $\textbf{X}_{v,:}^{(0)}$ are initialized with the hyperedge and node features, respectively, if available; otherwise they are set to all-zero vectors.

In the practical implementation of the model, $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are parametrized and *learnt* for each dataset and task, and the information in their second argument is not used.
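
Permutation invariance can be checked numerically. The sketch below assumes a DeepSets-style form, $f(S) = \rho\left(\sum_{s \in S} \phi(s)\right)$, with hand-picked `phi` and `rho` standing in for learned MLPs. Both functions are hypothetical and only illustrate why sum pooling makes the output independent of the order in which the multiset is listed.

```python
from itertools import permutations


def phi(x):
    """Hypothetical element-wise encoder (stands in for a learned MLP)."""
    return [max(0.0, 2.0 * v - 1.0) for v in x]


def rho(pooled):
    """Hypothetical decoder applied to the pooled vector."""
    return [0.5 * v for v in pooled]


def set_function(multiset):
    """DeepSets-style function: encode each element, sum, then decode."""
    encoded = [phi(x) for x in multiset]
    pooled = [sum(column) for column in zip(*encoded)]
    return rho(pooled)


# Multiset of hidden node representations in one hyperedge (toy values).
V_e = [[1.0, 0.0], [2.0, 3.0], [0.5, 1.0]]

# Evaluate the function on every ordering of the multiset.
outputs = {tuple(set_function(list(order))) for order in permutations(V_e)}
print(outputs)  # {(2.0, 3.0)}
```

All six orderings produce the same output, which is the property required of $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$.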


In [4]:
"""
Train an AllSet model on a hypergraph lifted from the Cora citation graph.

"""

import torch
import numpy as np
from torch_geometric.utils import to_undirected
import torch_geometric.datasets as geom_datasets

from topomodelx.nn.hypergraph.allset import AllSet


If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, Cora, a benchmark citation dataset for node classification. We then lift the graph into our domain of choice, a hypergraph.


In [6]:

cora = geom_datasets.Planetoid(root="/TopoModelX/data/cora", name="Cora")
# Cora contains a single graph; indexing avoids the deprecated `.data` attribute.
data = cora[0]

x_0s = data.x
y = data.y
edge_index = data.edge_index

train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask




## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structure (i.e., its representative matrix) that we will use to send messages from nodes to hyperedges. For this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.

For the Cora citation dataset, we lift the graph structure to the hypergraph domain by creating one hyperedge from the 1-hop graph neighbourhood of each node.
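
As a toy illustration of this lifting, the sketch below applies the same idea to a 4-cycle written as a plain adjacency list. The graph and the name `adjacency` are made up for the example. As in the cell that follows, a node's neighbourhood does not include the node itself, and identical neighbourhoods are merged into a single hyperedge.

```python
# Undirected toy graph (the 4-cycle 0-2-1-3-0) as an adjacency list.
adjacency = {0: [2, 3], 1: [2, 3], 2: [0, 1], 3: [0, 1]}

seen = set()
hyperedges = []
for node in sorted(adjacency):
    # Sort so that equal neighbourhoods compare equal regardless of order.
    neighborhood = tuple(sorted(adjacency[node]))
    if neighborhood not in seen:
        seen.add(neighborhood)
        hyperedges.append(neighborhood)

print(hyperedges)  # [(2, 3), (0, 1)]
```

Four nodes yield only two hyperedges here, because nodes 0 and 1 share a neighbourhood, as do nodes 2 and 3.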


In [7]:
# Ensure the graph is undirected (optional but often useful for one-hop neighborhoods).
edge_index = to_undirected(edge_index)

# Create a list of one-hop neighborhoods for each node.
one_hop_neighborhoods = []
for node in range(data.num_nodes):
    # Get the one-hop neighbors of the current node.
    # Use the undirected edge_index computed above, not the raw data.edge_index.
    neighbors = edge_index[1, edge_index[0] == node]

    # Append the neighbors to the list of one-hop neighborhoods.
    one_hop_neighborhoods.append(neighbors.numpy())

# Detect and eliminate duplicate hyperedges.
unique_hyperedges = set()
hyperedges = []
for neighborhood in one_hop_neighborhoods:
    # Sort the neighborhood to ensure consistent comparison.
    neighborhood = tuple(sorted(neighborhood))
    if neighborhood not in unique_hyperedges:
        hyperedges.append(list(neighborhood))
        unique_hyperedges.add(neighborhood)    

Additionally, we print the statistics of the resulting hyperedges.

In [8]:

# Calculate hyperedge statistics.
hyperedge_sizes = [len(he) for he in hyperedges]
min_size = min(hyperedge_sizes)
max_size = max(hyperedge_sizes)
mean_size = np.mean(hyperedge_sizes)
median_size = np.median(hyperedge_sizes)
std_size = np.std(hyperedge_sizes)
num_single_node_hyperedges = sum(np.array(hyperedge_sizes) == 1)

# Print the hyperedge statistics.
print(f'Hyperedge statistics: ')
print('Number of hyperedges without duplicated hyperedges', len(hyperedges))
print(f'min = {min_size}, ')
print(f'max = {max_size}, ')
print(f'mean = {mean_size}, ')
print(f'median = {median_size}, ')
print(f'std = {std_size}, ')
print(f'Number of hyperedges with size equal to one = {num_single_node_hyperedges}')


Hyperedge statistics: 
Number of hyperedges without duplicated hyperedges 2581
min = 1, 
max = 168, 
mean = 4.003099573808601, 
median = 3.0, 
std = 5.327622607829558, 
Number of hyperedges with size equal to one = 412


Next, we construct the incidence matrix.

In [9]:
max_edges = len(hyperedges)
incidence_1 = np.zeros((x_0s.shape[0], max_edges))
for col, neighborhood in enumerate(hyperedges):
    for row in neighborhood:
        incidence_1[row, col] = 1

# Every hyperedge must contain a node, and every node must lie in a hyperedge.
assert (incidence_1.sum(0) > 0).all(), "Some hyperedges are empty"
assert (incidence_1.sum(1) > 0).all(), "Some nodes are not in any hyperedges"
incidence_1 = torch.Tensor(incidence_1).to_sparse_coo()
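
The cell above stores the incidence matrix in sparse COO form, which keeps only the coordinates of the nonzero entries. With a mean hyperedge size of about 4, almost all entries of the Cora incidence matrix are zero. The following dependency-free sketch shows the idea on a toy hypergraph; the names `dense` and `coo` are illustrative, and the list of coordinate pairs is a simplified stand-in for what a sparse tensor stores.

```python
# Toy hypergraph: 4 nodes, 2 hyperedges (illustrative, not from Cora).
hyperedges = [[0, 1, 2], [2, 3]]
n_nodes = 4

# Dense incidence matrix: rows are nodes, columns are hyperedges.
dense = [[0] * len(hyperedges) for _ in range(n_nodes)]
for col, members in enumerate(hyperedges):
    for row in members:
        dense[row][col] = 1

# COO form: keep only the (row, col) coordinates of the nonzero entries.
coo = [
    (row, col)
    for row in range(n_nodes)
    for col in range(len(hyperedges))
    if dense[row][col]
]

# Row sums count hyperedges per node; column sums give hyperedge sizes.
node_degrees = [sum(row) for row in dense]
hyperedge_sizes = [sum(column) for column in zip(*dense)]

print(coo)              # [(0, 0), (1, 0), (2, 0), (2, 1), (3, 1)]
print(node_degrees)     # [1, 1, 2, 1]
print(hyperedge_sizes)  # [3, 2]
```

The two assertions in the cell above correspond to checking that every entry of `hyperedge_sizes` and `node_degrees` is positive.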

# Create the Neural Network

In [10]:
in_channels = x_0s.shape[1]
hidden_channels = 128
out_channels = torch.unique(y).shape[0]
task_level = "graph" if out_channels == 1 else "node"
n_layers = 1

# Define the model
model = AllSet(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=n_layers,
    mlp_num_layers=1,
    task_level=task_level,
)
model = model.to(device)

# Train the Neural Network

We specify the optimizer, the loss, and an accuracy metric.

In [11]:
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.01)

# Categorical cross-entropy loss
loss_fn = torch.nn.CrossEntropyLoss()

# Accuracy
acc_fn = lambda y, y_hat: (y == y_hat).float().mean()
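
As a sanity check on the scale of the loss, the sketch below computes the per-sample cross-entropy and the accuracy in plain Python. The helper names `cross_entropy` and `accuracy` are hypothetical and only mirror what `loss_fn` and `acc_fn` compute. With seven classes, as in Cora, a model that outputs identical logits for every class has a loss of $\ln 7 \approx 1.9459$, which is close to the first-epoch loss in the training log of this notebook.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax of the true class, for one sample."""
    shift = max(logits)  # subtract the max for numerical stability
    log_sum_exp = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return log_sum_exp - logits[label]


def accuracy(labels, predictions):
    """Fraction of predictions that match the labels."""
    return sum(int(a == b) for a, b in zip(labels, predictions)) / len(labels)


num_classes = 7  # Cora has seven classes
uniform_loss = cross_entropy([0.0] * num_classes, label=3)
print(round(uniform_loss, 4))  # 1.9459

print(accuracy([0, 1, 1, 2], [0, 1, 2, 2]))  # 0.75
```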

In [12]:
# The features and labels are already tensors, so cast and move them directly.
x_0s = x_0s.float().to(device)
incidence_1 = incidence_1.float().to(device)
y = y.long().to(device)


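The following cell performs the training, looping over the network for a small number of epochs. We keep training minimal for the purpose of rapid testing. Note that the `loss` printed at each epoch is the running mean of the training loss over all epochs so far, not the loss of that epoch alone.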

In [13]:
torch.manual_seed(0)
test_interval = 5
num_epochs = 30

epoch_loss = []
for epoch_i in range(1, num_epochs + 1):
    
    model.train()

    opt.zero_grad()
    
    # Forward pass on the node features and the incidence matrix
    y_hat = model(x_0s, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])

    loss.backward()
    opt.step()
    epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}, \
        acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        model.eval()
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            y_hat = model(x_0s, incidence_1)

        loss = loss_fn(y_hat[val_mask], y[val_mask])
        print(f"Val_loss: {loss:.4f}, Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}", flush=True)

        loss = loss_fn(y_hat[test_mask], y[test_mask])
        print(f"Test_loss: {loss:.4f}, Test_acc: {acc_fn(y_hat[test_mask].argmax(1), y[test_mask]):.4f}", flush=True)

Epoch: 1 loss: 1.9463,         acc: 0.1429
Epoch: 2 loss: 1.9444,         acc: 0.2500
Epoch: 3 loss: 1.9251,         acc: 0.1857
Epoch: 4 loss: 1.8896,         acc: 0.1429
Epoch: 5 loss: 1.8660,         acc: 0.4000
Val_loss: 1.8107, Val_acc: 0.4660
Test_loss: 1.8106, Test_acc: 0.4580


Epoch: 6 loss: 1.8350,         acc: 0.4071
Epoch: 7 loss: 1.7756,         acc: 0.4429
Epoch: 8 loss: 1.6999,         acc: 0.4429
Epoch: 9 loss: 1.6292,         acc: 0.5214
Epoch: 10 loss: 1.5617,         acc: 0.5643
Val_loss: 1.6181, Val_acc: 0.5460
Test_loss: 1.5937, Test_acc: 0.5300
Epoch: 11 loss: 1.4966,         acc: 0.5857
Epoch: 12 loss: 1.4400,         acc: 0.6357
Epoch: 13 loss: 1.3909,         acc: 0.8071
Epoch: 14 loss: 1.3518,         acc: 0.8000
Epoch: 15 loss: 1.3084,         acc: 0.8143
Val_loss: 2.4014, Val_acc: 0.6240
Test_loss: 2.4001, Test_acc: 0.6330
Epoch: 16 loss: 1.2680,         acc: 0.7857
Epoch: 17 loss: 1.2288,         acc: 0.7357
Epoch: 18 loss: 1.1883,         acc: 0.8143
Epoch: 19 loss: 1.1475,         acc: 0.8429
Epoch: 20 loss: 1.1041,         acc: 0.8714
Val_loss: 3.5885, Val_acc: 0.5960
Test_loss: 3.2893, Test_acc: 0.6270
Epoch: 21 loss: 1.0625,         acc: 0.8857
Epoch: 22 loss: 1.0231,         acc: 0.9286
Epoch: 23 loss: 0.9872,         acc: 0.9000
Ep