# Train a Simplex Convolutional Network (SCN) of Rank 2

This notebook illustrates the SCN layer proposed in [YSB22]_ for a simplicial complex of
rank 2, that is, one containing only 0-cells (nodes), 1-cells (edges), and 2-cells (faces).

References
----------
.. [YSB22] Ruochen Yang, Frederic Sala, and Paul Bogdan.
    Efficient Representation Learning for Higher-Order Data with 
    Simplicial Complexes. In Bastian Rieck and Razvan Pascanu, editors, 
    Proceedings of the First Learning on Graphs Conference, volume 198 
    of Proceedings of Machine Learning Research, pages 13:1–13:21. PMLR, 
    09–12 Dec 2022. https://proceedings.mlr.press/v198/yang22a.html.

In [1]:
import torch
import numpy as np
import toponetx.datasets as datasets

from sklearn.model_selection import train_test_split
from topomodelx.nn.simplicial.scn2 import SCN2

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import dataset ##

According to the original paper, SCN performs well on simplex classification. We therefore use shrec_16, a benchmark dataset for 3D mesh classification.

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
ys = ys.reshape((-1, 1))  # one label per complex, without hard-coding the dataset size
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.
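
As a quick sanity check on the counts printed above, the sketch below computes the Euler characteristic V - E + F and tests the edge-face relation that holds for a closed triangle mesh. The helper name `euler_characteristic` is introduced here for illustration only, and treating the mesh as closed is an assumption rather than something the dataset documentation is quoted as stating.

```python
def euler_characteristic(n_nodes: int, n_edges: int, n_faces: int) -> int:
    """Return V - E + F for a 2-dimensional complex."""
    return n_nodes - n_edges + n_faces


# Counts printed above for the complex at index 6.
n_nodes, n_edges, n_faces = 252, 750, 500

chi = euler_characteristic(n_nodes, n_edges, n_faces)

# In a closed triangle mesh every face has 3 edges and every edge
# borders exactly 2 faces, so 3 * F == 2 * E.
is_closed_like = 3 * n_faces == 2 * n_edges

print(chi, is_closed_like)  # 2 True
```

An Euler characteristic of 2 is the value a closed, sphere-like triangulated surface has, so these counts are at least consistent with a well-formed mesh.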


## Define neighborhood structures ##

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we need the normalized Laplacian matrices on nodes, edges, and faces. We also convert the neighborhood structures to torch tensors.

In [5]:
laplacian_0s = []
laplacian_1s = []
laplacian_2s = []
for x in simplexes:
    laplacian_0 = x.normalized_laplacian_matrix(rank=0)
    laplacian_1 = x.normalized_laplacian_matrix(rank=1)
    laplacian_2 = x.normalized_laplacian_matrix(rank=2)

    laplacian_0 = torch.from_numpy(laplacian_0.todense()).to_sparse()
    laplacian_1 = torch.from_numpy(laplacian_1.todense()).to_sparse()
    laplacian_2 = torch.from_numpy(laplacian_2.todense()).to_sparse()

    laplacian_0s.append(laplacian_0)
    laplacian_1s.append(laplacian_1)
    laplacian_2s.append(laplacian_2)
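
To make the Laplacians used above more concrete, here is a minimal sketch that builds the unnormalized Hodge Laplacians of a single filled triangle from its boundary matrices. It does not reproduce the normalization applied by `normalized_laplacian_matrix`; the toy complex and the names `B1`, `B2`, `L0`, `L1`, `L2` are assumptions made for illustration.

```python
import numpy as np

# Toy complex: one filled triangle with nodes 0, 1, 2,
# edges ordered (0,1), (0,2), (1,2), and the single face (0,1,2).

# B1 maps edges to nodes: -1 at the tail node, +1 at the head node.
B1 = np.array([
    [-1, -1,  0],
    [ 1,  0, -1],
    [ 0,  1,  1],
])

# B2 maps the face to its edges: boundary = (1,2) - (0,2) + (0,1).
B2 = np.array([[1], [-1], [1]])

# The boundary of a boundary vanishes.
assert not (B1 @ B2).any()

# Unnormalized Hodge Laplacians on nodes, edges, and faces.
L0 = B1 @ B1.T               # graph Laplacian on nodes
L1 = B1.T @ B1 + B2 @ B2.T   # lower + upper part on edges
L2 = B2.T @ B2               # faces only have a lower part at rank 2

print(L0)
print(L1)
print(L2)
```

For this toy triangle, `L0` is the usual graph Laplacian of a 3-cycle, `L1` is three times the identity, and `L2` is the 1x1 matrix `[[3]]`.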

# Train the Neural Network

We specify the model using the feature dimension of each rank, then define an optimizer and a loss function.

In [6]:
in_channels_0 = x_0s[i_complex].shape[1]
in_channels_1 = x_1s[i_complex].shape[1]
in_channels_2 = x_2s[i_complex].shape[1]

In [7]:
model = SCN2(in_channels_0, in_channels_1, in_channels_2, num_classes=1)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()
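
The model consumes one feature matrix and one Laplacian per rank. As a rough illustration of what a Laplacian-based convolution step can look like, the sketch below applies `relu(L @ X @ W)` to node features. It is written in NumPy to stay dependency-light and is not the `SCN2` implementation; the function name `laplacian_conv`, the weight shape, and the toy Laplacian are assumptions.

```python
import numpy as np


def laplacian_conv(laplacian: np.ndarray, x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """One Laplacian-based convolution step: relu(L @ X @ W)."""
    return np.maximum(laplacian @ x @ weight, 0.0)


rng = np.random.default_rng(0)

# Graph Laplacian of a 3-node cycle, standing in for a rank-0 Laplacian.
laplacian_0 = np.array([
    [ 2.0, -1.0, -1.0],
    [-1.0,  2.0, -1.0],
    [-1.0, -1.0,  2.0],
])

x_0 = rng.normal(size=(3, 6))     # 3 nodes, 6 input features each
weight = rng.normal(size=(6, 4))  # hypothetical learnable weights: 6 -> 4 channels

h_0 = laplacian_conv(laplacian_0, x_0, weight)
print(h_0.shape)  # (3, 4)
```

The same step would apply to edges and faces with their own Laplacians and weights; how the per-rank outputs are pooled into a single prediction is left to the model.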

In [8]:
test_size = 0.2
x_0s_train, x_0s_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1s_train, x_1s_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2s_train, x_2s_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

laplacian_0s_train, laplacian_0s_test = train_test_split(
    laplacian_0s, test_size=test_size, shuffle=False
)
laplacian_1s_train, laplacian_1s_test = train_test_split(
    laplacian_1s, test_size=test_size, shuffle=False
)
laplacian_2s_train, laplacian_2s_test = train_test_split(
    laplacian_2s, test_size=test_size, shuffle=False
)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
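
Because every call above uses `shuffle=False` with the same `test_size`, each list is cut at the same position, so features, Laplacians, and labels stay aligned. The pure-Python sketch below mimics that contiguous split; `split_tail` is a hypothetical helper, not part of scikit-learn, and it assumes the test count is the ceiling of `test_size * n`.

```python
import math


def split_tail(items, test_size):
    """Keep the first items for training and the last ones for testing."""
    n_test = math.ceil(len(items) * test_size)
    n_train = len(items) - n_test
    return items[:n_train], items[n_train:]


# Stand-ins for per-complex features and labels.
features = [f"x_{i}" for i in range(100)]
labels = list(range(100))

f_train, f_test = split_tail(features, 0.2)
l_train, l_test = split_tail(labels, 0.2)

print(len(f_train), len(f_test))  # 80 20
print(f_test[0], l_test[0])       # x_80 80
```

Note that an unshuffled split assumes the dataset order does not group samples by class; if it does, the test set may not cover all labels.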

The following cell trains the network for a small number of epochs and evaluates it on the test set every `test_interval` epochs.

In [9]:
test_interval = 2
num_epochs = 8
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(
        x_0s_train,
        x_1s_train,
        x_2s_train,
        laplacian_0s_train,
        laplacian_1s_train,
        laplacian_2s_train,
        y_train,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        laplacian_0, laplacian_1, laplacian_2 = (
            laplacian_0.float().to(device),
            laplacian_1.float().to(device),
            laplacian_2.float().to(device),
        )
        optimizer.zero_grad()
        y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)
        loss = loss_fn(y_hat, y)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2, y in zip(
                x_0s_test,
                x_1s_test,
                x_2s_test,
                laplacian_0s_test,
                laplacian_1s_test,
                laplacian_2s_test,
                y_test,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                laplacian_0, laplacian_1, laplacian_2 = (
                    laplacian_0.float().to(device),
                    laplacian_1.float().to(device),
                    laplacian_2.float().to(device),
                )
                y_hat = model(x_0, x_1, x_2, laplacian_0, laplacian_1, laplacian_2)
                test_loss = loss_fn(y_hat, y)
            # Note: test_loss holds the loss of the last test complex only.
            print(f"Test_loss: {test_loss:.4f}", flush=True)

Epoch: 1 loss: 304.6056
Epoch: 2 loss: 282.2707
Test_loss: 514.9831
Epoch: 3 loss: 232.8605
Epoch: 4 loss: 147.0164
Test_loss: 227.0106
Epoch: 5 loss: 87.9957
Epoch: 6 loss: 77.5624
Test_loss: 104.0872
Epoch: 7 loss: 77.0871
Epoch: 8 loss: 77.0391
Test_loss: 97.5802
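
The test loop overwrites `test_loss` on every iteration, so each `Test_loss` value above reflects a single test complex. A common alternative is to average the loss over all test samples. The sketch below shows the difference on toy numbers; `squared_error`, `average_test_loss`, and the stand-in `predict` function are hypothetical and unrelated to the trained model.

```python
def squared_error(y_hat: float, y: float) -> float:
    """Squared error for a single scalar prediction."""
    return (y_hat - y) ** 2


def average_test_loss(predict, samples) -> float:
    """Mean of the per-sample losses over the whole test set."""
    losses = [squared_error(predict(x), y) for x, y in samples]
    return sum(losses) / len(losses)


def predict(x: float) -> float:
    """Stand-in model: doubles its input."""
    return 2.0 * x


# Toy (input, target) pairs.
samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 9.0)]

avg_loss = average_test_loss(predict, samples)
last_loss = squared_error(predict(samples[-1][0]), samples[-1][1])

print(avg_loss, last_loss)  # 3.0 9.0
```

In the notebook, the same idea would amount to collecting `test_loss.item()` in a list inside the loop and printing its mean, which would yield different numbers from those recorded above.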
