# Train a Simplicial 2-complex convolutional neural network (SCConv)


In this notebook, we create and train a simplicial 2-complex convolutional neural network (SCConv) in the simplicial complex domain, as proposed in the paper by [Bunch et al.: Simplicial 2-Complex Convolutional Neural Networks (2020)](https://openreview.net/pdf?id=Sc8glB-k6e9).


We train the model to perform classification on the SHREC 16 mesh dataset.

The equations of one layer of this neural network are given by:

🟥 $\quad m_{y\rightarrow x}^{(0\rightarrow 0)} = ({\tilde{A}_{\uparrow,0}})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0\rightarrow0)}$

🟥 $\quad m^{(1\rightarrow0)}_{y\rightarrow x}  = (B_1)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1\rightarrow 0)}$

🟥 $\quad m^{(0 \rightarrow 1)}_{y \rightarrow x}  = (\tilde B_1)_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow1)}$

🟥 $\quad m^{(1\rightarrow1)}_{y\rightarrow x} = ({\tilde{A}_{\downarrow,1}} + {\tilde{A}_{\uparrow,1}})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1\rightarrow1)}$

🟥 $\quad m^{(2\rightarrow1)}_{y \rightarrow x}  = (B_2)_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow1)}$

🟥 $\quad m^{(1 \rightarrow 2)}_{y \rightarrow x}  = (\tilde B_2)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1 \rightarrow 2)}$

🟥 $\quad m^{(2 \rightarrow 2)}_{y \rightarrow x}  = ({\tilde{A}_{\downarrow,2}})_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow 2)}$

🟧 $\quad m_x^{(0 \rightarrow 0)}  = \sum_{y \in \mathcal{L}_\uparrow(x)} m_{y \rightarrow x}^{(0 \rightarrow 0)}$

🟧 $\quad m_x^{(1 \rightarrow 0)}  = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}$

🟧 $\quad m_x^{(0 \rightarrow 1)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}$

🟧 $\quad m_x^{(1 \rightarrow 1)}  = \sum_{y \in (\mathcal{L}_\uparrow(x) + \mathcal{L}_\downarrow(x))} m_{y \rightarrow x}^{(1 \rightarrow 1)}$

🟧 $\quad m_x^{(2 \rightarrow 1)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(2 \rightarrow 1)}$

🟧 $\quad m_x^{(1 \rightarrow 2)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(1 \rightarrow 2)}$

🟧 $\quad m_x^{(2 \rightarrow 2)}  = \sum_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow x}^{(2 \rightarrow 2)}$

🟩 $\quad m_x^{(0)}  = m_x^{(1\rightarrow0)}+ m_x^{(0\rightarrow0)}$

🟩 $\quad m_x^{(1)}  = m_x^{(0\rightarrow1)}+ m_x^{(1\rightarrow1)}+ m_x^{(2\rightarrow1)}$

🟩 $\quad m_x^{(2)}  = m_x^{(1\rightarrow2)}+ m_x^{(2\rightarrow2)}$

🟦 $\quad h^{t+1, (0)}_x  = \sigma(m_x^{(0)})$

🟦 $\quad h^{t+1, (1)}_x  = \sigma(m_x^{(1)})$

🟦 $\quad h^{t+1, (2)}_x  = \sigma(m_x^{(2)})$


The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
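
To make the rank-0 update concrete, the sketch below evaluates the node equations on a toy complex (a single filled triangle) with NumPy. The feature sizes and random weights are illustrative assumptions, the adjacency is left unnormalized for brevity, and ReLU stands in for $\sigma$; this is not the TopoModelX implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy complex: one filled triangle with nodes {0, 1, 2}
# and edges (0,1), (0,2), (1,2).
# Unsigned incidence B_1 has shape (n_nodes, n_edges).
B1 = np.array([[1, 1, 0],
               [1, 0, 1],
               [0, 1, 1]], dtype=float)

# Upper adjacency of nodes: two nodes are neighbours when they share an edge.
A_up_0 = B1 @ B1.T
np.fill_diagonal(A_up_0, 0.0)

h0 = rng.normal(size=(3, 4))        # node features h^{t,(0)}, 4 channels (assumed)
h1 = rng.normal(size=(3, 5))        # edge features h^{t,(1)}, 5 channels (assumed)
theta_00 = rng.normal(size=(4, 4))  # Theta^{t,(0->0)}
theta_10 = rng.normal(size=(5, 4))  # Theta^{t,(1->0)}

# Messages and the within-neighbourhood sums collapse into matrix products.
m_00 = A_up_0 @ h0 @ theta_00       # m_x^{(0->0)}
m_10 = B1 @ h1 @ theta_10           # m_x^{(1->0)}

# Between-neighbourhood aggregation, then the update.
m_0 = m_00 + m_10
h0_next = np.maximum(m_0, 0.0)      # sigma = ReLU (assumed)

print(h0_next.shape)
```

The same pattern applies to ranks 1 and 2, with the matrices and weights swapped according to the equations above.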

In [103]:
import torch
import numpy as np

import toponetx.datasets as datasets

from scipy.sparse import coo_matrix
from scipy.sparse import diags

from topomodelx.base.aggregation import Aggregation

from topomodelx.nn.simplicial.scconv import SCConv

In [93]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Pre-processing

## Import dataset ##

The first step is to import SHREC 16, a benchmark dataset for 3D mesh classification. Each sample provides node, edge, and face features, a class label, and the simplicial complex itself.

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [97]:
# l = np.unique(ys, return_counts=True)
# print(l)
print(len(np.unique(ys)))

30


In [4]:
i_complex = 0
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 0th simplicial complex has 252 nodes with features of dimension 6.
The 0th simplicial complex has 750 edges with features of dimension 10.
The 0th simplicial complex has 500 faces with features of dimension 7.
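
These counts admit a quick sanity check. In a closed triangle mesh every edge borders two faces, so $3F = 2E$, and the Euler characteristic $V - E + F$ equals 2 for a closed surface of genus 0. The snippet below hard-codes the numbers printed above and checks both identities; the counts are consistent with such a surface.

```python
# Counts printed above for complex 0.
n_nodes, n_edges, n_faces = 252, 750, 500

# Euler characteristic V - E + F.
euler_characteristic = n_nodes - n_edges + n_faces

# Each face has three edges and each edge is shared by two faces.
closed_mesh = 3 * n_faces == 2 * n_edges

print(euler_characteristic, closed_mesh)  # 2 True
```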


## Helper functions ##


In [27]:
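def normalize_higher_order_adj(A_opt):
    """Symmetrically normalize a square adjacency operator.

    Args:
        A_opt: sparse operator that maps k-cochains to k-cochains,
            shape [num_of_k_simplices, num_of_k_simplices].

    Returns:
        D^{-0.5} * A_opt * D^{-0.5} as a COO matrix, where D is the
        diagonal matrix of absolute row sums of A_opt.
    """
    # Absolute row sums play the role of degrees.
    rowsum = np.array(np.abs(A_opt).sum(1))
    r_inv_sqrt = np.power(rowsum, -0.5).flatten()
    # Rows that sum to zero give inf above; map them to 0 instead.
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.0
    r_mat_inv_sqrt = diags(r_inv_sqrt)
    # (A D^{-0.5})^T D^{-0.5} = D^{-0.5} A^T D^{-0.5}, which equals
    # D^{-0.5} A D^{-0.5} when A is symmetric.
    A_opt_to = A_opt.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt)

    return coo_matrix(A_opt_to)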

In [19]:
# incidence_1_list

In [43]:
adjacency_1 = simplexes[13].adjacency_matrix(rank=1, signed=False)
incidence_1 = simplexes[13].incidence_matrix(rank=1, signed=False)

# k = normalize_higher_order_adj(adjacency_1)
# print(k)

print(adjacency_1.todense().shape)
k = normalize_higher_order_adj(adjacency_1)
print(k.todense().shape)

(750, 750)
(750, 750)
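
The shapes confirm that normalization preserves dimensions, but not what it does to the entries. As a small illustration, the dense NumPy sketch below applies the same $D^{-1/2} A D^{-1/2}$ scaling to a hand-written 4-node adjacency matrix (a path graph plus one isolated node, chosen here only for illustration).

```python
import numpy as np

# Adjacency of the path graph 0 - 1 - 2, plus an isolated node 3.
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 0],
              [0, 0, 0, 0]], dtype=float)

degree = np.abs(A).sum(axis=1)

# D^{-1/2}, with zero-degree rows mapped to 0 instead of inf.
d_inv_sqrt = np.zeros_like(degree)
nonzero = degree > 0
d_inv_sqrt[nonzero] = degree[nonzero] ** -0.5

# Scale rows and columns: entry (i, j) becomes A_ij / sqrt(d_i * d_j).
A_norm = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

print(np.round(A_norm, 4))
```

The edge between node 0 (degree 1) and node 1 (degree 2) is scaled to $1/\sqrt{2}$, the result stays symmetric, and the isolated node keeps an all-zero row.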


# Define Neighbourhood Structures

We create the neighbourhood structures expected by SCConv. The SCConv layer expects the following neighbourhood structures:
* incidence_1 $B_1$
* incidence_1_norm $\tilde{B}_1$
* incidence_2 $B_2$
* incidence_2_norm $\tilde{B}_2$
* adjacency_up_0_norm $\tilde{A}_{\uparrow,0}$
* adjacency_up_1_norm $\tilde{A}_{\uparrow,1}$
* adjacency_down_1_norm $\tilde{A}_{\downarrow,1}$
* adjacency_down_2_norm $\tilde{A}_{\downarrow,2}$
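
To see how the incidence matrices in this list are oriented, the sketch below builds unsigned $B_1$ and $B_2$ by hand for a toy complex of two triangles sharing an edge. The helper `unsigned_incidence` is hypothetical and written only for this illustration; the feature widths 6, 10, and 7 mirror the dataset above. Reading the equations, $B_1$ carries edge features to nodes, while an operator shaped like its transpose is needed to carry node features to edges.

```python
import numpy as np
from itertools import combinations

# Toy complex: two triangles sharing the edge (1, 2).
faces = [(0, 1, 2), (1, 2, 3)]
edges = sorted({e for f in faces for e in combinations(f, 2)})
nodes = sorted({(v,) for f in faces for v in f})

def unsigned_incidence(lower, upper):
    """Return B with B[i, j] = 1 when lower[i] is contained in upper[j]."""
    B = np.zeros((len(lower), len(upper)))
    for j, up in enumerate(upper):
        for i, low in enumerate(lower):
            if set(low) <= set(up):
                B[i, j] = 1.0
    return B

B1 = unsigned_incidence(nodes, edges)  # shape (n_nodes, n_edges)
B2 = unsigned_incidence(edges, faces)  # shape (n_edges, n_faces)

h0 = np.ones((len(nodes), 6))    # node features
h1 = np.ones((len(edges), 10))   # edge features
h2 = np.ones((len(faces), 7))    # face features

edge_to_node = B1 @ h1      # one row per node
node_to_edge = B1.T @ h0    # one row per edge
face_to_edge = B2 @ h2      # one row per edge
edge_to_face = B2.T @ h1    # one row per face

print(B1.shape, B2.shape)
```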

In [99]:
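def get_neighborhoods(simplexes):
    """Collect the eight neighbourhood structures listed above, one entry per complex.

    NOTE: the loop that fills the lists is still commented out below,
    so every list is currently returned empty.
    """
    incidence_1_list = []
    incidence_1_norm_list = []
    incidence_2_list = []
    incidence_2_norm_list = []
    adjacency_up_0_norm_list = []
    adjacency_up_1_norm_list = []
    adjacency_down_1_norm_list = []
    adjacency_down_2_norm_list = []
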
    # incidence_1_list = []
    # incidence_2_list = []
    # up_laplacian_1_list = []
    # up_laplacian_2_list = []
    # down_laplacian_1_list = []
    # down_laplacian_2_list = []
    # for simplex in simplexes:
    #     B1 = simplex.incidence_matrix(rank=1, signed=False)
    #     B2 = simplex.incidence_matrix(rank=2, signed=False)
    #
    #     up_laplacian_1 = simplex.up_laplacian_matrix(rank=0)  #1
    #     up_laplacian_2 = simplex.up_laplacian_matrix(rank=1)  #2
    #
    #     down_laplacian_1 = simplex.down_laplacian_matrix(rank=1)  #1
    #     down_laplacian_2 = simplex.down_laplacian_matrix(rank=2)  #2
    #
    #     incidence_1 = torch.from_numpy(B1.todense()).to_sparse()
    #     incidence_2 = torch.from_numpy(B2.todense()).to_sparse()
    #
    #     up_laplacian_1 = torch.from_numpy(up_laplacian_1.todense()).to_sparse()
    #     up_laplacian_2 = torch.from_numpy(up_laplacian_2.todense()).to_sparse()
    #
    #     down_laplacian_1 = torch.from_numpy(down_laplacian_1.todense()).to_sparse()
    #     down_laplacian_2 = torch.from_numpy(down_laplacian_2.todense()).to_sparse()
    #
    #     incidence_1_list.append(incidence_1)
    #     incidence_2_list.append(incidence_2)
    #     up_laplacian_1_list.append(up_laplacian_1)
    #     up_laplacian_2_list.append(up_laplacian_2)
    #     down_laplacian_1_list.append(down_laplacian_1)
    #     down_laplacian_2_list.append(down_laplacian_2)

    return (
        incidence_1_list,
        incidence_1_norm_list,
        incidence_2_list,
        incidence_2_norm_list,
        adjacency_up_0_norm_list,
        adjacency_up_1_norm_list,
        adjacency_down_1_norm_list,
        adjacency_down_2_norm_list,
    )


(
    incidence_1_list,
    incidence_1_norm_list,
    incidence_2_list,
    incidence_2_norm_list,
    adjacency_up_0_norm_list,
    adjacency_up_1_norm_list,
    adjacency_down_1_norm_list,
    adjacency_down_2_norm_list,
) = get_neighborhoods(simplexes)
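
As written, `get_neighborhoods` returns eight empty lists. The sketch below shows one possible way to fill them from unsigned incidence matrices using SciPy only. Everything in it is an assumption made for illustration: the helper names are hypothetical, the adjacencies are derived from $B_1$ and $B_2$, and the $\tilde{B}$ operators are taken to be row-normalized transposes. It is not the notebook's or the library's method. In the notebook, `B1` and `B2` would come from `simplex.incidence_matrix(rank=..., signed=False)`, and each entry would still need converting to a torch tensor before the training loop calls `.float().to(device)` on it.

```python
import numpy as np
from scipy.sparse import csr_matrix, diags

def sym_normalize(A):
    """Return D^{-1/2} A D^{-1/2} for a square sparse matrix A."""
    deg = np.asarray(abs(A).sum(axis=1)).ravel().astype(float)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    D = diags(inv_sqrt)
    return (D @ A @ D).tocoo()

def row_normalize(M):
    """Scale each row to sum to 1 (assumed stand-in for the tilde-B scaling)."""
    rowsum = np.asarray(abs(M).sum(axis=1)).ravel().astype(float)
    inv = np.zeros_like(rowsum)
    inv[rowsum > 0] = 1.0 / rowsum[rowsum > 0]
    return (diags(inv) @ M).tocoo()

def off_diagonal(M):
    """Drop the diagonal so that a cell is not its own neighbour."""
    M = (M - diags(M.diagonal())).tocsr()
    M.eliminate_zeros()
    return M

def neighborhoods_from_incidence(B1, B2):
    """Hypothetical helper: the eight structures for one complex."""
    return {
        "incidence_1": B1.tocoo(),
        "incidence_1_norm": row_normalize(B1.T),
        "incidence_2": B2.tocoo(),
        "incidence_2_norm": row_normalize(B2.T),
        "adjacency_up_0_norm": sym_normalize(off_diagonal(B1 @ B1.T)),
        "adjacency_up_1_norm": sym_normalize(off_diagonal(B2 @ B2.T)),
        "adjacency_down_1_norm": sym_normalize(off_diagonal(B1.T @ B1)),
        "adjacency_down_2_norm": sym_normalize(off_diagonal(B2.T @ B2)),
    }

# Toy complex: two triangles sharing an edge (4 nodes, 5 edges, 2 faces).
B1 = csr_matrix(np.array([[1, 1, 0, 0, 0],
                          [1, 0, 1, 1, 0],
                          [0, 1, 1, 0, 1],
                          [0, 0, 0, 1, 1]], dtype=float))
B2 = csr_matrix(np.array([[1, 0], [1, 0], [1, 1], [0, 1], [0, 1]], dtype=float))

# One list per structure, one entry per complex.
lists = {}
for b1, b2 in [(B1, B2)]:
    for name, matrix in neighborhoods_from_incidence(b1, b2).items():
        lists.setdefault(name, []).append(matrix)

print({name: entries[0].shape for name, entries in lists.items()})
```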

# Create and Train the Neural Network

## Prepare training and test data

In [100]:
# ToDo: apply train/test splitting

x_0_train = x_0s
x_1_train = x_1s
x_2_train = x_2s

(
    incidence_1_list,
    incidence_1_norm_list,
    incidence_2_list,
    incidence_2_norm_list,
    adjacency_up_0_norm_list,
    adjacency_up_1_norm_list,
    adjacency_down_1_norm_list,
    adjacency_down_2_norm_list,
) = get_neighborhoods(simplexes)
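
The ToDo above leaves the train/test split open, so the loop further down evaluates on the training data. A minimal sketch of one way to split, assuming a random 80/20 partition of sample indices: `split_indices` and `take` are hypothetical helpers, and `labels` is a stand-in for `ys`.

```python
import numpy as np

def split_indices(n_samples, test_fraction=0.2, seed=0):
    """Return disjoint (train_idx, test_idx) arrays covering range(n_samples)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_test = int(round(test_fraction * n_samples))
    return perm[n_test:], perm[:n_test]

def take(seq, idx):
    """Index a plain Python list (e.g. a list of sparse matrices) by an index array."""
    return [seq[i] for i in idx]

train_idx, test_idx = split_indices(n_samples=10, test_fraction=0.2)

labels = np.arange(10)                  # stand-in for ys
per_complex = [f"complex_{i}" for i in range(10)]  # stand-in for any *_list

y_train, y_test = labels[train_idx], labels[test_idx]
train_items, test_items = take(per_complex, train_idx), take(per_complex, test_idx)

print(len(train_items), len(test_items))  # 8 2
```

Applying the same index arrays to the features, labels, and every neighbourhood list keeps the samples aligned across all of them.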

In [105]:
model = SCConv(
    node_channels=x_0s[0].shape[1],
    edge_channels=x_1s[0].shape[1],
    face_channels=x_2s[0].shape[1],
    n_classes=len(np.unique(ys)),
    n_layers=2,
)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()



The following cell performs the training, iterating over the dataset for a small number of epochs.

In [113]:
test_interval = 1
num_epochs = 5

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for (
        x_0,
        x_1,
        x_2,
        incid1,
        incid1_norm,
        incid2,
        incid2_norm,
        adj0_up_norm,
        adj1_up_norm,
        adj1_down_norm,
        adj2_down_norm,
        y,
    ) in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        incidence_1_list,
        incidence_1_norm_list,
        incidence_2_list,
        incidence_2_norm_list,
        adjacency_up_0_norm_list,
        adjacency_up_1_norm_list,
        adjacency_down_1_norm_list,
        adjacency_down_2_norm_list,
        ys,
    ):
        (
            x_0,
            x_1,
            x_2,
            y,
            incid1,
            incid1_norm,
            incid2,
            incid2_norm,
            adj0_up_norm,
            adj1_up_norm,
            adj1_down_norm,
            adj2_down_norm,
        ) = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
            incid1.float().to(device),
            incid1_norm.float().to(device),
            incid2.float().to(device),
            incid2_norm.float().to(device),
            adj0_up_norm.float().to(device),
            adj1_up_norm.float().to(device),
            adj1_down_norm.float().to(device),
            adj2_down_norm.float().to(device),
        )

        opt.zero_grad()
        y_hat = model(
            x_0,
            x_1,
            x_2,
            y,
            incid1,
            incid1_norm,
            incid2,
            incid2_norm,
            adj0_up_norm,
            adj1_up_norm,
            adj1_down_norm,
            adj2_down_norm,
        )
        loss = loss_fn(y_hat.flatten(), y)
        loss.backward()

        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss)}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        correct_count = 0
        with torch.no_grad():
            for (
                x_0t,
                x_1t,
                x_2t,
                incid1t,
                incid1_normt,
                incid2t,
                incid2_normt,
                adj0_up_normt,
                adj1_up_normt,
                adj1_down_normt,
                adj2_down_normt,
                yt,
            ) in zip(
                x_0_train,
                x_1_train,
                x_2_train,
                incidence_1_list,
                incidence_1_norm_list,
                incidence_2_list,
                incidence_2_norm_list,
                adjacency_up_0_norm_list,
                adjacency_up_1_norm_list,
                adjacency_down_1_norm_list,
                adjacency_down_2_norm_list,
                ys,
            ):
                (
                    x_0t,
                    x_1t,
                    x_2t,
                    yt,
                    incid1t,
                    incid1_normt,
                    incid2t,
                    incid2_normt,
                    adj0_up_normt,
                    adj1_up_normt,
                    adj1_down_normt,
                    adj2_down_normt,
                ) = (
                    torch.tensor(x_0t).float().to(device),
                    torch.tensor(x_1t).float().to(device),
                    torch.tensor(x_2t).float().to(device),
                    torch.tensor(yt).float().to(device),
                    incid1t.float().to(device),
                    incid1_normt.float().to(device),
                    incid2t.float().to(device),
                    incid2_normt.float().to(device),
                    adj0_up_normt.float().to(device),
                    adj1_up_normt.float().to(device),
                    adj1_down_normt.float().to(device),
                    adj2_down_normt.float().to(device),
                )

                y_hat = model(
                    x_0t,
                    x_1t,
                    x_2t,
                    yt,
                    incid1t,
                    incid1_normt,
                    incid2t,
                    incid2_normt,
                    adj0_up_normt,
                    adj1_up_normt,
                    adj1_down_normt,
                    adj2_down_normt,
                )
                test_loss = loss_fn(y_hat, yt)

                if round(y_hat.item()) == round(yt.item()):
                    correct_count += 1

                print(f"Test_loss: {test_loss}", flush=True)

Epoch: 1 loss: nan
Epoch: 2 loss: nan
Epoch: 3 loss: nan
Epoch: 4 loss: nan
Epoch: 5 loss: nan
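
A loss of `nan` in every epoch, with no `Test_loss` lines at all, is consistent with the loop bodies never running. `get_neighborhoods` above returns empty lists, `zip` stops at its shortest input, and `np.mean` of an empty list is `nan`. The snippet below reproduces that behaviour in isolation with stand-in data.

```python
import math
import warnings

import numpy as np

features = [1.0, 2.0, 3.0]   # stand-in for x_0_train
empty_neighborhoods = []     # what get_neighborhoods currently returns

epoch_loss = []
for x, nb in zip(features, empty_neighborhoods):
    epoch_loss.append(0.0)   # never reached: zip yields nothing

# np.mean of an empty list returns nan and emits a RuntimeWarning.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    mean_loss = np.mean(epoch_loss)

print(len(epoch_loss), math.isnan(mean_loss))  # 0 True
```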
