

# Train an All-Set-Transformer TNN

In this notebook, we will create and train a two-step message passing network named AllSetTransformer (Chien et al., [2021](https://arxiv.org/abs/2106.13264)) in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph. 

Following the "awesome-tnns" [github repo.](https://github.com/awesome-tnns/awesome-tnns/blob/main/Hypergraphs.md)

🟧 $\quad m_{\rightarrow z}^{(\rightarrow 1)} = AGG_{y \in \mathcal{B}(z)} (h_y^{t, (0)}, h_z^{t,(1)}) \quad \text{with attention}$ 

🟦 $\quad h_z^{t+1,(1)} = \text{LN}(m_{\rightarrow z}^{(\rightarrow 1)} + \text{MLP}(m_{\rightarrow z}^{(\rightarrow 1)} ))$ 

Edge to vertex: 

🟧 $\quad m_{\rightarrow x}^{(\rightarrow 0)} = AGG_{z \in \mathcal{C}(x)} (h_z^{t+1,(1)}, h_x^{t,(0)}) \quad \text{with attention}$ 

🟦 $\quad h_x^{t+1,(0)} = \text{LN}(m_{\rightarrow x}^{(\rightarrow 0)} + \text{MLP}(m_{\rightarrow x}^{(\rightarrow 0)} ))$



### Additional theoretical clarifications

Given a hypergraph $G=(\mathcal{V}, \mathcal{E})$, let $\textbf{X} \in \mathbb{R}^{|\mathcal{V}| \times F}$ and $\textbf{Z} \in \mathbb{R}^{|\mathcal{E}| \times F'}$ denote the hidden node and hyperedge representations, respectively. Additionally, define $V_{e, \textbf{X}} = \{\textbf{X}_{u,:}: u \in e\}$ as the multiset of hidden node representations in the hyperedge $e$ and $E_{v, \textbf{Z}} = \{\textbf{Z}_{e,:}: v \in e\}$ as the multiset of hidden representations of hyperedges containing $v$.

\
In this setting, the two general update rules that AllSet's framework puts in place in each layer are:

🔷 $\textbf{Z}_{e,:}^{(t+1)} = f_{\mathcal{V} \rightarrow \mathcal{E}}(V_{e, \textbf{X}^{(t)}}; \textbf{Z}_{e,:}^{(t)})$

🔷 $\textbf{X}_{v,:}^{(t+1)} = f_{\mathcal{E} \rightarrow \mathcal{V}}(E_{v, \textbf{Z}^{(t+1)}}; \textbf{X}_{v,:}^{(t)})$

in which $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are two permutation invariant functions with respect to their first input. The matrices $\textbf{Z}_{e,:}^{(0)}$ and $\textbf{X}_{v,:}^{(0)}$ are initialized with the hyperedge and node features respectively, if available, otherwise they are set to be all-zero matrices.

In the practical implementation of the model, $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are parametrized and $learnt$ for each dataset and task, and the information of their second argument is not utilized. The option achieving the best results makes use of attention-based layers, giving rise to the so-called AllSetTransformer architecture.

\
We now dive deep into the details of AllSetTransformer, describing how the update functions $f_{\mathcal{V} \rightarrow \mathcal{E}}$ and $f_{\mathcal{E} \rightarrow \mathcal{V}}$ are iteratively defined.
Their input is a matrix $\textbf{S} \in \mathbb{R}^{|S| \times F}$ which corresponds the multiset of $F$-dimensional feature vectors:

1️⃣ $\textbf{K}^{(i)} = \text{MLP}^{K, i}(\textbf{S}), \textbf{V}^{(i)} = \text{MLP}^{V, i}(\textbf{S})$, where $i \in \{1, ..., h\},$

2️⃣ $ \textbf{O}^{(i)} = \omega (\theta^{(i)}(\textbf{K}^{(i)})^{T}) \textbf{V}^{(i)},$  

3️⃣ $\theta  \overset{\Delta}{=} \mathbin\Vert_{i=1}^{h} \theta^{(i)}, $

4️⃣ $ \text{MH}_{h, \omega}(\theta, \textbf{S}, \textbf{S}) = \mathbin\Vert_{i=1}^{h} \textbf{O}^{(i)}, $

5️⃣ $ \textbf{Y} = \text{LN} (\theta + \text{MH}_{h, \omega}(\theta, \textbf{S}, \textbf{S})), $

6️⃣ $f_{\mathcal{V} \rightarrow \mathcal{E}}(\textbf{S}) =  f_{\mathcal{E} \rightarrow \mathcal{V}}(\textbf{S}) = \text{LN} (\textbf{Y} + \text{MLP}(\textbf{Y}))$.

\

The elements and operations used in these steps are defined as follows:

🔶 $\text{LN}$ means layer normalization (Ba et al., [2016](https://arxiv.org/abs/1607.06450)),

🔶 $\mathbin\Vert$ represents concatenation,

🔶 $\theta \in \mathbb{R}^{1 \times hF_{h}}$ is a learnable weight,

🔶 $\text{MH}_{h, \omega}$ denotes a multihead attention mechanism with $h$ heads and activation function $\omega$ (Vaswani et al., [2017](https://proceedings.neurips.cc/paper_files/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html)),

🔶 all $\text{MLP}$ modules are multi-layer perceptrons that operate row-wise, so they are applied identically and independently to each multiset element of $\textbf{S}$.





In [1]:
import torch
import numpy as np
import toponetx.datasets as datasets
from sklearn.model_selection import train_test_split

from topomodelx.nn.hypergraph.allset_transformer import AllSetTransformer
from topomodelx.utils.sparse import from_sparse

# %load_ext autoreload
# %autoreload 2

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [2]:
device = torch.device("cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a hypergraph.

We will also retrieve:
- input signal on the edges for each of these hypergraphs, as that will be what we feed the model in input
- the label associated to the hypergraph

In [3]:
# Load data
import torch_geometric.datasets as geom_datasets
cora = geom_datasets.Planetoid(root="/TopoModelX/data/cora", name="Cora")


x_0s = cora.data.x
edge_index = cora.data.edge_index

y = cora.data.y

# train_mask = cora.data.train_mask
# val_mask = cora.data.val_mask
# test_mask = cora.data.test_mask

# make custom train test val split
idxs = np.arange(len(y))
train_idxs, test_idxs = train_test_split(idxs, test_size=0.2, random_state=42)
train_idxs, val_idxs = train_test_split(train_idxs, test_size=0.2, random_state=42)



train_mask = torch.zeros(len(y), dtype=torch.bool)
train_mask[train_idxs] = True

val_mask = torch.zeros(len(y), dtype=torch.bool)
val_mask[val_idxs] = True

test_mask = torch.zeros(len(y), dtype=torch.bool)
test_mask[test_idxs] = True





In [4]:

import networkx as nx

G = nx.Graph()
for edge in edge_index.numpy().T:
    G.add_edge(edge[0], edge[1])

print('Number of hyperedges', G.number_of_edges())
# check if dirrected
print(G.is_directed())

#check if there is isolated nodes in the graph
print('Number of isolated nodes', len(list(nx.isolates(G))))

Number of hyperedges 5278
False
Number of isolated nodes 0


In [5]:
# from toponetx.classes.simplicial_complex import SimplicialComplex
# a = SimplicialComplex([[0,1,2,3], [3,4,5,6]])
# a.to_hypergraph()

In [6]:
# from toponetx.datasets.graph import coauthorship, karate_club
# coa = coauthorship()

## Define neighborhood structures and lift into hypergraph domain. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.

Once we have recorded the incidence matrix (note that all incidence amtrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges will become pairwise hyperedges, and faces in the simplciial complex will become 3-wise hyperedges.

In [7]:
hyperedges = []
for node in G.nodes():
    hyperedge = sorted(list(G.neighbors(node)) + [node])
    hyperedges.append(tuple(hyperedge))
    
        
print('Number of hyperedges', len(hyperedges))

# Delete duplicates
hyperedges = list(set(hyperedges))
print('Number of hyperedges without duplicated hyperedges', len(hyperedges))

print(f'Hyperedge statistics: \
    min = {min([len(he) for he in hyperedges])} \
    max = {max([len(he) for he in hyperedges])} \
    mean = {np.mean([len(he) for he in hyperedges])} \
    median = {np.median([len(he) for he in hyperedges])} \
    std = {np.std([len(he) for he in hyperedges])}')

# Construct hypergraph
from hypernetx import Hypergraph
hyperedges = {f'e{idx}':list(he) for idx, he in enumerate(hyperedges)}
incidence_1 = Hypergraph(hyperedges, static=True).incidence_matrix()
incidence_1 = from_sparse(incidence_1)

Number of hyperedges 2708
Number of hyperedges without duplicated hyperedges 2590
Hyperedge statistics:     min = 2     max = 169     mean = 4.991891891891892     median = 4.0     std = 5.322452976509452


# Define the Neural Network



In [8]:
in_channels = x_0s.shape[1]
hid_dim = 128
out_dim = len(torch.unique(y))
heads = 8
n_layers = 1
mlp_num_layers = 2
#Q_n = 1


# Define the model
model = AllSetTransformer(
    in_channels=in_channels,
    hidden_channels=hid_dim,
    heads=heads,
    out_channels=out_dim,
    n_layers=n_layers,
    mlp_num_layers=mlp_num_layers,
    dropout=0.25,
)
model = model.to(device)

# Train the Neural Network

We specify the model, the loss, and an optimizer.

In [9]:
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.01)
# Categorial cross-entropy loss
loss_fn = torch.nn.CrossEntropyLoss()
# Accuracy
acc_fn = lambda y, y_hat: (y == y_hat).float().mean()


Split the dataset into train and test sets.

In [56]:
# test_size = 0.2
# x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
# incidence_1_train, incidence_1_test = train_test_split(
#     incidence_1_list, test_size=test_size, shuffle=False
# )
# y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

The following cell performs the training, looping over the network for a low amount of epochs. We keep training minimal for the purpose of rapid testing.

In [10]:
x_0s = torch.tensor(x_0s)
x_0s, incidence_1, y = (
            x_0s.float().to(device),
            incidence_1.float().to(device),
            torch.tensor(y, dtype=torch.long).to(device),
        )

test_interval = 1
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()

    opt.zero_grad()
    # Extract edge_index from sparse incidence matrix
    # edge_index, _ = to_edge_index(incidence_1)
    y_hat = model(x_0s, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])

    loss.backward()
    opt.step()
    epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}, \
        acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        with torch.no_grad(): 
           # y_hat = model(x_0s, incidence_1)
            loss = loss_fn(y_hat[val_mask], y[val_mask])

            print(f"Test_loss: {loss:.4f}, \
                   Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}", flush=True)

  x_0s = torch.tensor(x_0s)
  torch.tensor(y, dtype=torch.long).to(device),


Epoch: 1 loss: 2.0500,         acc: 0.1253
Test_loss: 2.0209,                    Val_acc: 0.1359
Epoch: 2 loss: 2.1502,         acc: 0.2552
Test_loss: 2.1461,                    Val_acc: 0.2373
Epoch: 3 loss: 2.0653,         acc: 0.1686
Test_loss: 2.0859,                    Val_acc: 0.1682
Epoch: 4 loss: 1.8959,         acc: 0.2841
Test_loss: 1.9096,                    Val_acc: 0.2995
Epoch: 5 loss: 1.9125,         acc: 0.2737
Test_loss: 1.9260,                    Val_acc: 0.2558
Epoch: 6 loss: 1.8842,         acc: 0.2812
Test_loss: 1.8673,                    Val_acc: 0.2995
Epoch: 7 loss: 1.8734,         acc: 0.2852
Test_loss: 1.8388,                    Val_acc: 0.3018
Epoch: 8 loss: 1.8695,         acc: 0.2846
Test_loss: 1.8469,                    Val_acc: 0.3065
Epoch: 9 loss: 1.8626,         acc: 0.2639
Test_loss: 1.8628,                    Val_acc: 0.2581
Epoch: 10 loss: 1.8571,         acc: 0.2725
Test_loss: 1.8502,                    Val_acc: 0.2811
Epoch: 11 loss: 1.8486,      

KeyboardInterrupt: 

In [31]:
from topomodelx.nn.hypergraph.allset import AllSet





In [40]:
from topomodelx.nn.hypergraph.allset import AllSet


in_channels = x_0s.shape[1]
hidden_channels = 64
out_channels = len(torch.unique(y))


# Define the model
model = AllSet(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    n_layers=1,
    mlp_num_layers=1,
)
model = model.to(device)
# Optimizer and loss
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

In [41]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()

    opt.zero_grad()
    # Extract edge_index from sparse incidence matrix
    # edge_index, _ = to_edge_index(incidence_1)
    y_hat = model(x_0s, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])
    
    loss.backward()
    opt.step()
    epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}, \
        acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        with torch.no_grad(): 
           # y_hat = model(x_0s, incidence_1)
            loss = loss_fn(y_hat[val_mask], y[val_mask])

            print(f"Test_loss: {loss:.4f}, \
                   Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}", flush=True)

Epoch: 1 loss: 1.9529,         acc: 0.1236
Epoch: 2 loss: 1.9349,         acc: 0.2667
Epoch: 3 loss: 1.8677,         acc: 0.2806
Epoch: 4 loss: 2.0707,         acc: 0.2858
Epoch: 5 loss: 1.8432,         acc: 0.2858
Epoch: 6 loss: 1.8947,         acc: 0.2846
Epoch: 7 loss: 1.9171,         acc: 0.2973
Epoch: 8 loss: 1.9251,         acc: 0.2846
Epoch: 9 loss: 1.9260,         acc: 0.2321
Epoch: 10 loss: 1.9237,         acc: 0.2540
Test_loss: 1.9265,                    Val_acc: 0.2350
Epoch: 11 loss: 1.9183,         acc: 0.2494
Epoch: 12 loss: 1.9095,         acc: 0.2286
Epoch: 13 loss: 1.8956,         acc: 0.2350
Epoch: 14 loss: 1.8771,         acc: 0.2061
Epoch: 15 loss: 1.8649,         acc: 0.1940
Epoch: 16 loss: 1.8560,         acc: 0.2246
Epoch: 17 loss: 1.8444,         acc: 0.2598
Epoch: 18 loss: 1.8181,         acc: 0.2627
Epoch: 19 loss: 1.8011,         acc: 0.2823
Epoch: 20 loss: 1.7956,         acc: 0.2852
Test_loss: 1.8108,                    Val_acc: 0.3041
Epoch: 21 loss: 1.783

In [30]:
loss

tensor(1.9480, grad_fn=<NllLossBackward0>)