# Train a Simplicial Complex Convolutional Network (SCCN)

We create an SCCN model following [Yang et al.: Efficient Representation Learning for Higher-Order Data with
Simplicial Complexes (LoG 2022)](https://proceedings.mlr.press/v198/yang22a/yang22a.pdf).

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

The model operates on cells of all ranks up to some max rank $r_\mathrm{max}$.
The equations of one layer of this neural network are given by:

🟥 $\quad m_{{y \rightarrow x}}^{(r \rightarrow r)} = (H_{r})_{xy} \cdot h^{t,(r)}_y \cdot \Theta^{t,(r\to r)}$,    (for $0\leq r \leq r_\mathrm{max}$)

🟥 $\quad m_{{y \rightarrow x}}^{(r-1 \rightarrow r)} = (B_{r}^T)_{xy} \cdot h^{t,(r-1)}_y \cdot \Theta^{t,(r-1\to r)}$,    (for $1\leq r \leq r_\mathrm{max}$)

🟥 $\quad m_{{y \rightarrow x}}^{(r+1 \rightarrow r)} = (B_{r+1})_{xy} \cdot h^{t,(r+1)}_y \cdot \Theta^{t,(r+1\to r)}$,    (for $0\leq r \leq r_\mathrm{max}-1$)

🟧 $\quad m_{x}^{(r \rightarrow r)}  = \sum_{y \in \mathcal{L}_\downarrow(x)\bigcup \mathcal{L}_\uparrow(x)} m_{y \rightarrow x}^{(r \rightarrow r)}$

🟧 $\quad m_{x}^{(r-1 \rightarrow r)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r-1 \rightarrow r)}$

🟧 $\quad m_{x}^{(r+1 \rightarrow r)}  = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r+1 \rightarrow r)}$

🟩 $\quad m_x^{(r)}  = m_x^{(r \rightarrow r)} + m_x^{(r-1 \rightarrow r)} + m_x^{(r+1 \rightarrow r)}$

🟦 $\quad h_x^{t+1,(r)}  = \sigma(m_x^{(r)})$

The notation follows [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
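As a minimal sketch of the equations above for a single rank ($r = 1$), the following NumPy code builds a filled triangle and applies one update to the edge features. The helper name `sccn_update_rank1`, the random weights, and the use of the plain Hodge Laplacian for $H_1$ are assumptions made for illustration. This is not the `topomodelx` implementation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def sccn_update_rank1(h0, h1, h2, H1, B1, B2, theta_same, theta_low, theta_high):
    """One illustrative update of the edge features (hypothetical helper)."""
    m_same = H1 @ h1 @ theta_same    # edges -> edges   (r -> r)
    m_low = B1.T @ h0 @ theta_low    # nodes -> edges   (r-1 -> r)
    m_high = B2 @ h2 @ theta_high    # faces -> edges   (r+1 -> r)
    return sigmoid(m_same + m_low + m_high)


# Filled triangle: nodes 0, 1, 2; edges (0,1), (0,2), (1,2); one face (0,1,2).
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # nodes x edges
B2 = np.array([[1], [-1], [1]], dtype=float)                      # edges x faces
H1 = B1.T @ B1 + B2 @ B2.T  # assumption: unweighted 1-Hodge Laplacian

rng = np.random.default_rng(0)
channels = 4
h0 = rng.normal(size=(3, channels))  # node features
h1 = rng.normal(size=(3, channels))  # edge features
h2 = rng.normal(size=(1, channels))  # face features
theta_same, theta_low, theta_high = (
    0.1 * rng.normal(size=(channels, channels)) for _ in range(3)
)

h1_next = sccn_update_rank1(h0, h1, h2, H1, B1, B2, theta_same, theta_low, theta_high)
print(h1_next.shape)  # (3, 4): one updated feature vector per edge
```

The sums over neighbors in the aggregation equations are carried out implicitly by the matrix products, since the nonzero entries of $H_r$, $B_r^T$ and $B_{r+1}$ select exactly the neighboring cells.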

In [1]:
import torch
import numpy as np

import toponetx.datasets.graph as graph

from topomodelx.nn.simplicial.sccn import SCCN

# Pre-processing

## Import dataset ##

The first step is to import the [Karate Club](https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.
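To make the lifting step concrete, here is a small pure-Python sketch of a clique lift, in which every $(k+1)$-clique of the graph becomes a $k$-simplex. That the dataset loader uses this kind of lift is an assumption (it is consistent with the shape printed below); the function `clique_complex` and the toy graph are hypothetical.

```python
from itertools import combinations


def clique_complex(edges, max_dim):
    """Return {dim: list of simplices} for the clique complex of a small graph.

    Brute force over node subsets, so only suitable for tiny examples.
    """
    nodes = sorted({v for edge in edges for v in edge})
    edge_set = {frozenset(edge) for edge in edges}
    simplices = {0: [(v,) for v in nodes]}
    for dim in range(1, max_dim + 1):
        simplices[dim] = [
            candidate
            for candidate in combinations(nodes, dim + 1)
            # A candidate is a simplex if all of its node pairs are edges.
            if all(frozenset(pair) in edge_set for pair in combinations(candidate, 2))
        ]
    return simplices


# Toy graph: a triangle (0, 1, 2) with a pendant edge (2, 3).
toy_edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
toy_complex = clique_complex(toy_edges, max_dim=2)
toy_shape = tuple(len(toy_complex[d]) for d in sorted(toy_complex))
print(toy_shape)  # (4, 4, 1): four nodes, four edges, one triangle
```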

Since our task is node classification, we must retrieve an input signal on the nodes. The signal has shape $n_\text{nodes} \times$ `in_channels`, where `in_channels` is the dimension of each cell's feature vector. Here that dimension is set by the `feat_dim` argument.

In [2]:
dataset = graph.karate_club(complex_type="simplicial", feat_dim=8)
print(dataset)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4


## Define neighborhood structures. ##

Our implementation allows for features on cells up to an arbitrary maximum rank. In this dataset, features are available up to tetrahedra, so we can use at most `max_rank = 3`, which is what we choose.

We define incidence and adjacency matrices up to the max rank and put them in dictionaries indexed by the rank, as expected by the `SCCNLayer`.
The form of the adjacency and incidence matrices could be chosen arbitrarily. Here we follow the original formulation by Yang et al. quite closely and take the adjacencies to be the $r$-Hodge Laplacians $H_r$ plus $2I$ (or just $I$ for $r\in\{0, r_\mathrm{max}\}$), which allows cells to pass messages to themselves. The incidence matrices are the usual boundary matrices $B_r$.
One could additionally weight or normalize these matrices as suggested by Yang et al., but we refrain from doing so for simplicity.
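For reference, the textbook relation between boundary matrices and Hodge Laplacians can be checked on a filled triangle. The sketch below uses signed boundary matrices written out by hand; how the library's `adjacency_matrix` and `coadjacency_matrix` treat signs and diagonals is not verified here, so treat the correspondence with the cell that follows as an assumption.

```python
import numpy as np

# Filled triangle: nodes 0, 1, 2; edges (0,1), (0,2), (1,2); one face (0,1,2).
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # nodes x edges
B2 = np.array([[1], [-1], [1]], dtype=float)                      # edges x faces

# The boundary of a boundary vanishes.
boundary_of_boundary = B1 @ B2

# Hodge Laplacians: "down" part B_r^T B_r plus "up" part B_{r+1} B_{r+1}^T.
L0 = B1 @ B1.T               # rank 0 has only an up part (the graph Laplacian)
L1 = B1.T @ B1 + B2 @ B2.T   # rank 1 has both parts
L2 = B2.T @ B2               # top rank has only a down part

# Adding a multiple of the identity lets each cell message itself.
H1_with_self_loops = L1 + 2 * np.eye(3)

print(L1)  # equals 3 * identity for the filled triangle
```

Ranks 0 and $r_\mathrm{max}$ have only one of the two Laplacian terms, which matches the text's choice of adding $I$ rather than $2I$ at those ranks.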

In [3]:
max_rank = 3  # There are features up to tetrahedron order in the dataset

In [4]:
def sparse_to_torch(X):
    return torch.from_numpy(X.todense()).to_sparse()


incidences = {
    f"rank_{r}": sparse_to_torch(dataset.incidence_matrix(rank=r))
    for r in range(1, max_rank + 1)
}

adjacencies = {}
adjacencies["rank_0"] = (
    sparse_to_torch(dataset.adjacency_matrix(rank=0))
    + torch.eye(dataset.shape[0]).to_sparse()
)
for r in range(1, max_rank):
    adjacencies[f"rank_{r}"] = (
        sparse_to_torch(
            dataset.adjacency_matrix(rank=r) + dataset.coadjacency_matrix(rank=r)
        )
        + 2 * torch.eye(dataset.shape[r]).to_sparse()
    )
adjacencies[f"rank_{max_rank}"] = (
    sparse_to_torch(dataset.coadjacency_matrix(rank=max_rank))
    + torch.eye(dataset.shape[max_rank]).to_sparse()
)

for r in range(max_rank + 1):
    print(f"The adjacency matrix H{r} has shape: {adjacencies[f'rank_{r}'].shape}.")
    if r > 0:
        print(f"The incidence matrix B{r} has shape: {incidences[f'rank_{r}'].shape}.")

The adjacency matrix H0 has shape: torch.Size([34, 34]).
The adjacency matrix H1 has shape: torch.Size([78, 78]).
The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix H2 has shape: torch.Size([45, 45]).
The incidence matrix B2 has shape: torch.Size([78, 45]).
The adjacency matrix H3 has shape: torch.Size([11, 11]).
The incidence matrix B3 has shape: torch.Size([45, 11]).




## Import signal ##

We import the features at each rank.

In [5]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]

In [6]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

There are 34 nodes with features of dimension 8.


Load edge features.

In [7]:
x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1))

In [8]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

There are 78 edges with features of dimension 8.


Similarly for face features:

In [9]:
x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = torch.tensor(np.stack(x_2))

In [10]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 45 faces with features of dimension 8.


Higher order features:

In [11]:
x_3 = []
for k, v in dataset.get_simplex_attributes("tetrahedron_feat").items():
    x_3.append(v)
x_3 = torch.tensor(np.stack(x_3))

In [12]:
print(
    f"There are {x_3.shape[0]} tetrahedrons with features of dimension {x_3.shape[1]}."
)

There are 11 tetrahedrons with features of dimension 8.


The features are organized in a dictionary keeping track of their rank, similar to the adjacencies/incidences earlier.

In [13]:
features = {"rank_0": x_0, "rank_1": x_1, "rank_2": x_2, "rank_3": x_3}
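The four loading cells above share one pattern: collect the values of an attribute dictionary and stack them into an array with one row per cell. A possible way to factor this out is sketched below with NumPy and toy data. The helper `stack_attributes` and the toy dictionaries are hypothetical stand-ins, not part of the dataset API.

```python
import numpy as np


def stack_attributes(attributes):
    """Stack a {simplex: feature_vector} mapping into an (n_cells, channels) array."""
    return np.stack(list(attributes.values()))


# Toy stand-ins for the per-rank attribute dictionaries (hypothetical values).
toy_attributes = {
    0: {(0,): [1.0, 0.0], (1,): [0.0, 1.0], (2,): [1.0, 1.0]},
    1: {(0, 1): [0.5, 0.5], (1, 2): [0.2, 0.8]},
}

# Same "rank_<r>" key convention as the dictionaries used in this notebook.
toy_features = {
    f"rank_{rank}": stack_attributes(attributes)
    for rank, attributes in toy_attributes.items()
}

for key, value in toy_features.items():
    print(key, value.shape)
```

The row order follows the iteration order of each dictionary, so it only lines up with the incidence and adjacency matrices if the simplices are stored in the same order there.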

## Define binary labels ##
We retrieve the labels associated with the nodes. In the KarateClub dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.

We hold out the true labels of the last four nodes for testing.

In [14]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)

y_train = torch.from_numpy(y[:30])
y_test = torch.from_numpy(y[30:])
y_train, y_test

(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0,
         0, 0, 0, 0, 0, 0]),
 tensor([0, 0, 0, 0]))

# Create the Neural Network

Using the `SCCNLayer` class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes}$, so we can compare it with our binary labels.
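The structure described here (stacked layers followed by a linear readout) can be sketched in NumPy as follows. For brevity the sketch keeps only the rank-0 messages, and the function `forward`, the toy graph, and the random weights are all hypothetical. It illustrates the shapes involved and is not the `topomodelx` model.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(h0, H0, thetas, w, b):
    """Apply stacked rank-0 layers, then a linear readout to one logit per node."""
    for theta in thetas:
        h0 = sigmoid(H0 @ h0 @ theta)  # simplified layer: node-to-node messages only
    return h0 @ w + b                  # shape (n_nodes,)


# Path graph 0 - 1 - 2 with self-loops as a toy neighborhood matrix.
H0 = np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=float)

rng = np.random.default_rng(0)
channels = 4
h0 = rng.normal(size=(3, channels))
thetas = [0.1 * rng.normal(size=(channels, channels)) for _ in range(2)]  # two layers
w = rng.normal(size=channels)  # readout weights
b = 0.0                        # readout bias

logits = forward(h0, H0, thetas, w, b)
print(logits.shape)  # (3,): one logit per node
```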

# Train the Neural Network

We instantiate the model and specify an optimizer. The pre-made neighborhood structures are passed to the model in each forward call.

In [16]:
model = SCCN(
    channels=channels_nodes,
    max_rank=max_rank,
    n_layers=5,
    update_func="sigmoid",
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

The following cell trains the network for a small number of epochs. It typically reaches 100% train accuracy. Test accuracy varies more between runs, likely because the dataset is so small.
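The training cell below computes a binary cross-entropy loss directly on logits and turns logits into class predictions by thresholding at zero. The pure-Python sketch below shows why both choices are consistent: the stable logit form of the loss matches the naive sigmoid form, and `z > 0` is the same test as `sigmoid(z) > 0.5`. The function names and example numbers are illustrative only.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed stably from logits."""
    losses = [
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


def bce_naive(logits, targets):
    """Same loss via explicit probabilities (less stable for large |z|)."""
    losses = [
        -(y * math.log(sigmoid(z)) + (1 - y) * math.log(1 - sigmoid(z)))
        for z, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


def predict(logits):
    """Class 1 when the logit is positive, i.e. when sigmoid(logit) > 0.5."""
    return [1 if z > 0 else 0 for z in logits]


example_logits = [2.0, -1.0, 0.5]
example_targets = [1, 0, 0]

stable = bce_with_logits(example_logits, example_targets)
naive = bce_naive(example_logits, example_targets)
predictions = predict(example_logits)
accuracy = sum(p == t for p, t in zip(predictions, example_targets)) / len(example_targets)

print(predictions)  # [1, 0, 1]
print(abs(stable - naive) < 1e-12)  # True
```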

In [17]:
test_interval = 50
num_epochs = 200
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(features, incidences, adjacencies)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = (y_hat > 0).long()
    accuracy = (y_pred[: len(y_train)] == y_train).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(features, incidences, adjacencies)
            y_pred_test = (y_hat_test > 0).long()
            test_accuracy = (
                (y_pred_test[-len(y_test) :] == y_test).float().mean().item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.6721 Train_acc: 0.6333
Epoch: 2 loss: 0.6891 Train_acc: 0.5667
Epoch: 3 loss: 0.6284 Train_acc: 0.5667
Epoch: 4 loss: 0.6173 Train_acc: 0.6667
Epoch: 5 loss: 0.6110 Train_acc: 0.7000
Epoch: 6 loss: 0.5831 Train_acc: 0.7000
Epoch: 7 loss: 0.5695 Train_acc: 0.7000
Epoch: 8 loss: 0.5638 Train_acc: 0.7000
Epoch: 9 loss: 0.5493 Train_acc: 0.7333
Epoch: 10 loss: 0.5384 Train_acc: 0.7667
Epoch: 11 loss: 0.5141 Train_acc: 0.7333
Epoch: 12 loss: 0.5201 Train_acc: 0.6667
Epoch: 13 loss: 0.5201 Train_acc: 0.7000
Epoch: 14 loss: 0.5038 Train_acc: 0.6667
Epoch: 15 loss: 0.5016 Train_acc: 0.7333
Epoch: 16 loss: 0.4906 Train_acc: 0.7333
Epoch: 17 loss: 0.4763 Train_acc: 0.7000
Epoch: 18 loss: 0.4545 Train_acc: 0.7667
Epoch: 19 loss: 0.4483 Train_acc: 0.7667
Epoch: 20 loss: 0.4153 Train_acc: 0.8000
Epoch: 21 loss: 0.4062 Train_acc: 0.8000
Epoch: 22 loss: 0.3790 Train_acc: 0.8333
Epoch: 23 loss: 0.3916 Train_acc: 0.7667
Epoch: 24 loss: 0.3529 Train_acc: 0.8667
Epoch: 25 loss: 0.2900 Tr