# Train a Simplicial Convolutional Neural Network (SCNN)

In this notebook, we create and train a convolutional neural network in the simplicial complex domain, as proposed in the paper by [Yang et al.: Simplicial Convolutional Neural Networks (2022)](https://arxiv.org/pdf/2110.02585.pdf).

We train the model on the small version of the SHREC 16 mesh dataset to predict one label per mesh (a complex-level task), using the edge features.

## Simplicial Convolutional Neural Networks <a href="https://arxiv.org/pdf/2110.02585.pdf">[SCNN]</a>

At layer $t$, given the input simplicial (edge) feature matrix $\mathbf{H}_t$, the SCNN layer is defined as 
$$
    \mathbf{H}_{t+1} = \sigma \Bigg[ \mathbf{H}_t\mathbf{\Theta}_t + \sum_{p_d=1}^{P_d}(\mathbf{L}_{\downarrow,1})^{p_d}\mathbf{H}_t\mathbf{\Theta}_{t,p_d} + \sum_{p_u=1}^{P_u}(\mathbf{L}_{\uparrow,1})^{p_u}\mathbf{H}_{t}\mathbf{\Theta}_{t,p_u} \Bigg]
$$
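where $P_d$ and $P_u$ are the lower and upper convolution orders, respectively, $\mathbf{L}_{\downarrow,1}$ and $\mathbf{L}_{\uparrow,1}$ are the lower and upper Laplacians of the edges, $\sigma$ is a nonlinearity, and $\mathbf{\Theta}_t$, $\mathbf{\Theta}_{t,p_d}$, and $\mathbf{\Theta}_{t,p_u}$ are the learnable weights.
The powers $(\mathbf{L}_{\uparrow,1})^{p_u}$ and $(\mathbf{L}_{\downarrow,1})^{p_d}$ perform higher-order upper and lower convolutions, aggregating information from $p_u$-hop upper and $p_d$-hop lower neighbors.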
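To make the update rule concrete, here is a minimal dense sketch of one SCNN layer in plain PyTorch. It is not the TopoModelX `SCNNLayer` implementation: the helper name `scnn_layer`, the use of dense matrices, and the sigmoid as $\sigma$ are assumptions made for illustration. The toy complex is a single filled triangle with edges ordered $(0,1), (0,2), (1,2)$; its lower and upper Laplacians are built from the signed incidence matrices as $\mathbf{B}_1^\top\mathbf{B}_1$ and $\mathbf{B}_2\mathbf{B}_2^\top$.

```python
import torch


def scnn_layer(h, lap_down, lap_up, theta_self, thetas_down, thetas_up, act=torch.sigmoid):
    """Dense sketch of one SCNN update H_{t+1} (illustrative, not TopoModelX's SCNNLayer).

    h:           [n_edges, c_in] edge features H_t
    lap_down:    [n_edges, n_edges] lower Laplacian L_{down,1}
    lap_up:      [n_edges, n_edges] upper Laplacian L_{up,1}
    theta_self:  [c_in, c_out] weight Theta_t
    thetas_down: list of P_d weights Theta_{t,p_d}, each [c_in, c_out]
    thetas_up:   list of P_u weights Theta_{t,p_u}, each [c_in, c_out]
    """
    out = h @ theta_self
    z = h
    for theta in thetas_down:  # p_d = 1, ..., P_d
        z = lap_down @ z  # z = (L_down)^{p_d} H_t, by repeated multiplication
        out = out + z @ theta
    z = h
    for theta in thetas_up:  # p_u = 1, ..., P_u
        z = lap_up @ z  # z = (L_up)^{p_u} H_t
        out = out + z @ theta
    return act(out)


# Toy complex: one filled triangle, nodes 0-2, edges (0,1), (0,2), (1,2)
B1 = torch.tensor([[-1.0, -1.0, 0.0], [1.0, 0.0, -1.0], [0.0, 1.0, 1.0]])  # nodes x edges
B2 = torch.tensor([[1.0], [-1.0], [1.0]])  # edges x faces
lap_down = B1.T @ B1
lap_up = B2 @ B2.T

torch.manual_seed(0)
c_in, c_out, P_d, P_u = 2, 4, 2, 2
h = torch.randn(3, c_in)
theta_self = torch.randn(c_in, c_out)
thetas_down = [torch.randn(c_in, c_out) for _ in range(P_d)]
thetas_up = [torch.randn(c_in, c_out) for _ in range(P_u)]

h_next = scnn_layer(h, lap_down, lap_up, theta_self, thetas_down, thetas_up)
print(h_next.shape)  # torch.Size([3, 4])
```

Applying the Laplacian repeatedly to the running product avoids forming the matrix powers explicitly, which matters when the Laplacians are large and sparse.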


To align with the notation in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031), the layer above can be written as the following sequence of message-passing steps.

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{p_u,(1 \rightarrow 2 \rightarrow 1)}  = ((L_{\uparrow,1})^{p_u})_{xy} \cdot h_y^{t,(1)} \cdot \theta^{t, p_u} $  -------- Aggregate from $p_u$-hop upper neighbor $y$ to $x$

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{p_d,(1 \rightarrow 0 \rightarrow 1)} = ((L_{\downarrow,1})^{p_d})_{xy} \cdot h_y^{t,(1)} \cdot \theta^{t, p_d} $ -------- Aggregate from $p_d$-hop lower neighbor $y$ to $x$

🟥 $\quad m^{(1 \rightarrow 1)}_{x \rightarrow x} = \theta^t \cdot h_x^{t, (1)}$ --------  Aggregate from $x$ itself

🟧 $\quad m_{x}^{p_u,(1 \rightarrow 2 \rightarrow 1)}  = \sum_{y \in \mathcal{L}_\uparrow(x)}m_{y \rightarrow \{z\} \rightarrow x}^{p_u,(1 \rightarrow 2 \rightarrow 1)}$  -------- Collect the aggregated information from the upper neighborhood

🟧 $\quad m_{x}^{p_d,(1 \rightarrow 0 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\downarrow(x)}m_{y \rightarrow \{z\} \rightarrow x}^{p_d,(1 \rightarrow 0 \rightarrow 1)}$ -------- Collect the aggregated information from the lower neighborhood

🟧 $\quad m^{(1 \rightarrow 1)}_{x} = m^{(1 \rightarrow 1)}_{x \rightarrow x}$

🟩 $\quad m_x^{(1)}  = m_x^{(1 \rightarrow 1)} + \sum_{p_u=1}^{P_u} m_{x}^{p_u,(1 \rightarrow 2 \rightarrow 1)} + \sum_{p_d=1}^{P_d} m_{x}^{p_d,(1 \rightarrow 0 \rightarrow 1)}$ -------- Collect all the aggregated information 

🟦 $\quad h_x^{t+1, (1)} = \sigma(m_x^{(1)})$ -------- Pass through the nonlinearity
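The steps above can be checked numerically. The sketch below (a hypothetical helper `scnn_messages`, using dense tensors and a single filled triangle with edges $(0,1), (0,2), (1,2)$ as an assumed toy complex) computes $m_x^{(1)}$ edge by edge, summing the messages $((\mathbf{L})^{p})_{xy} \cdot h_y \cdot \theta$ over every $y$ with a nonzero entry, and compares the result with the matrix form of the layer before the nonlinearity.

```python
import torch


def scnn_messages(h, lap_down, lap_up, theta_self, thetas_down, thetas_up):
    """Per-edge message passing for one SCNN layer; returns m_x^{(1)} before sigma."""
    n = h.shape[0]
    # (L)^p for each hop count p; entry (x, y) weights the message from y to x
    powers_down = [torch.linalg.matrix_power(lap_down, p) for p in range(1, len(thetas_down) + 1)]
    powers_up = [torch.linalg.matrix_power(lap_up, p) for p in range(1, len(thetas_up) + 1)]
    rows = []
    for x in range(n):
        m_x = h[x] @ theta_self  # self message m_{x -> x}
        for powers, thetas in ((powers_down, thetas_down), (powers_up, thetas_up)):
            for lap_p, theta in zip(powers, thetas):
                # messages from every p-hop neighbor y, collected by summation
                for y in range(n):
                    if lap_p[x, y] != 0:
                        m_x = m_x + lap_p[x, y] * (h[y] @ theta)
        rows.append(m_x)
    return torch.stack(rows)  # [n_edges, c_out]; apply sigma afterwards


# Toy complex: one filled triangle
B1 = torch.tensor([[-1.0, -1.0, 0.0], [1.0, 0.0, -1.0], [0.0, 1.0, 1.0]])  # nodes x edges
B2 = torch.tensor([[1.0], [-1.0], [1.0]])  # edges x faces
lap_down = B1.T @ B1
lap_up = B2 @ B2.T

torch.manual_seed(0)
h = torch.randn(3, 2)
theta_self = torch.randn(2, 4)
thetas_down = [torch.randn(2, 4) for _ in range(2)]
thetas_up = [torch.randn(2, 4) for _ in range(2)]

m = scnn_messages(h, lap_down, lap_up, theta_self, thetas_down, thetas_up)
matrix_form = (h @ theta_self
               + sum(torch.linalg.matrix_power(lap_down, p) @ h @ thetas_down[p - 1] for p in (1, 2))
               + sum(torch.linalg.matrix_power(lap_up, p) @ h @ thetas_up[p - 1] for p in (1, 2)))
print(torch.allclose(m, matrix_form, atol=1e-4))  # True
```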



In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split

from toponetx import SimplicialComplex
import toponetx.datasets as datasets

from topomodelx.nn.simplicial.scnn_layer import SCNNLayer

# Pre-processing

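## Import the SHREC 16 dataset ##

Each SHREC 16 mesh is already a simplicial complex (nodes, edges, and triangular faces), so no lifting from a graph is needed. We load the complexes together with their node, edge, and face features and their labels.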

In [2]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
 
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


## Input feature dimensions of each rank (edge features are used for classification below)

In [3]:
in_channels_0 = x_0s[-1].shape[1]
in_channels_1 = x_1s[-1].shape[1]
in_channels_2 = x_2s[-1].shape[1]

### Define Neighborhood Structures
Get incidence matrices and Hodge Laplacians
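For intuition, the rank-1 Laplacians can be built by hand from the signed incidence (boundary) matrices: $\mathbf{L}_{\downarrow,1} = \mathbf{B}_1^\top\mathbf{B}_1$, $\mathbf{L}_{\uparrow,1} = \mathbf{B}_2\mathbf{B}_2^\top$, and the Hodge Laplacian $\mathbf{L}_1$ is their sum. The sketch below does this in NumPy for a single filled triangle, an assumed toy example with an arbitrary edge orientation; it ignores the weighting options of the TopoNetX methods used in the next cell.

```python
import numpy as np

# Toy complex: one filled triangle with edges (0,1), (0,2), (1,2), oriented low -> high
B1 = np.array([[-1, -1, 0],
               [1, 0, -1],
               [0, 1, 1]])       # node-to-edge incidence, shape [n_nodes, n_edges]
B2 = np.array([[1], [-1], [1]])  # edge-to-face incidence, shape [n_edges, n_faces]

# The boundary of a boundary is zero, so B1 @ B2 vanishes
assert not (B1 @ B2).any()

L_down_1 = B1.T @ B1     # couples edges that share a node
L_up_1 = B2 @ B2.T       # couples edges that share a face
L_1 = L_down_1 + L_up_1  # Hodge Laplacian of rank 1

print(L_down_1)
print(L_1)  # 3 * identity for this triangle
```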

In [4]:
max_rank = 2  # the complexes have dimension 2: nodes, edges, and faces
incidence_1_list = []
incidence_2_list = []

laplacian_0_list = []
laplacian_down_1_list = []
laplacian_up_1_list = []
laplacian_2_list = []
 
for simplex in simplexes: 
    incidence_1 = simplex.incidence_matrix(rank=1)
    incidence_2 = simplex.incidence_matrix(rank=2)
    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0,weight=True)
    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1,weight=True)
    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1,weight=True)
    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2,weight=True)
    
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()
    laplacian_0 = torch.from_numpy(laplacian_0.todense()).to_sparse()
    laplacian_down_1 = torch.from_numpy(laplacian_down_1.todense()).to_sparse()
    laplacian_up_1 = torch.from_numpy(laplacian_up_1.todense()).to_sparse()
    laplacian_2 = torch.from_numpy(laplacian_2.todense()).to_sparse()
    
    incidence_1_list.append(incidence_1)
    incidence_2_list.append(incidence_2)
    laplacian_0_list.append(laplacian_0)
    laplacian_down_1_list.append(laplacian_down_1)
    laplacian_up_1_list.append(laplacian_up_1)
    laplacian_2_list.append(laplacian_2)
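The cell above densifies each matrix with `.todense()` before converting it back to a sparse tensor, which allocates a full $n \times n$ array per mesh. A direct conversion avoids that intermediate. The helper `scipy_to_torch_sparse` below is hypothetical and assumes the TopoNetX matrices are SciPy sparse matrices (they expose `.todense()`, as used above); it also casts the values to `float32` to match the feature tensors, a choice made here, whereas the cell above keeps the matrix's own dtype.

```python
import numpy as np
import scipy.sparse as sp
import torch


def scipy_to_torch_sparse(matrix, dtype=torch.float32):
    """Convert a SciPy sparse matrix to a coalesced torch sparse COO tensor without densifying."""
    coo = sp.coo_matrix(matrix)  # accepts CSR, CSC, COO, ... inputs
    indices = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
    values = torch.from_numpy(coo.data).to(dtype)
    return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce()


example = sp.csr_matrix(np.array([[0.0, 2.0], [3.0, 0.0]]))
print(scipy_to_torch_sparse(example).to_dense())  # dense [[0, 2], [3, 0]]
```

Inside the loop, a line such as `laplacian_up_1 = scipy_to_torch_sparse(laplacian_up_1)` could replace the `.todense()` round trip.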

# Create the SCNN

In [5]:
class SCNN(torch.nn.Module):
    """Simplicial convolutional neural network for complex-level classification.

    Note: the last layer outputs features on simplices of a single rank, e.g., edges.
    To obtain one prediction per complex, we pass these features through a linear
    layer and average the result over all simplices.

    Parameters
    ----------
    in_channels: int
        Dimension of input features
    intermediate_channels: int
        Dimension of features of intermediate layers
    out_channels: int
        Dimension of output features of the last SCNN layer
    conv_order_down: int
        Order of lower convolution
    conv_order_up: int
        Order of upper convolution
    aggr_norm: bool
        Passed to every SCNNLayer (see the SCNNLayer documentation)
    update_func: str or None
        Update function, passed to every SCNNLayer
    n_layers: int
        Number of layers
    """
    def __init__(
        self,
        in_channels,
        intermediate_channels,
        out_channels,
        conv_order_down,
        conv_order_up,
        aggr_norm=False,
        update_func=None,
        n_layers=2,
    ):
        super().__init__()
        # Channel sizes: in -> intermediate -> ... -> intermediate -> out
        channels = [in_channels] + [intermediate_channels] * (n_layers - 1) + [out_channels]
        self.layers = torch.nn.ModuleList(
            SCNNLayer(
                in_channels=channels[i],
                out_channels=channels[i + 1],
                conv_order_down=conv_order_down,
                conv_order_up=conv_order_up,
                aggr_norm=aggr_norm,
                update_func=update_func,
            )
            for i in range(n_layers)
        )
        # Readout: map the features of each simplex to a single score
        self.linear = torch.nn.Linear(out_channels, 1)

    def forward(self, x, laplacian_down, laplacian_up):
        """Forward computation.

        Parameters
        ----------
        x: tensor
            shape = [n_simplices, in_channels]
            node/edge/face features
        laplacian_down: tensor or None
            shape = [n_simplices, n_simplices]
            lower Laplacian; None for node features (rank 0)
        laplacian_up: tensor or None
            shape = [n_simplices, n_simplices]
            upper Laplacian; None for face features (rank 2) in a 2-dimensional complex

        Returns
        -------
        tensor
            shape = [1]
            prediction for the whole complex
        """
        for layer in self.layers:
            x = layer(x, laplacian_down, laplacian_up)

        # Score each simplex, then average over simplices (NaN entries are ignored)
        x_1 = self.linear(x)
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        # If every score is NaN, fall back to 0
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0

        return one_dimensional_cells_mean
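The readout at the end of `forward` turns per-simplex features into one number per complex. The sketch below isolates it on random features, assuming 5 edges and 4 output channels; the helper `readout` is hypothetical, and it uses `torch.where` as an out-of-place equivalent of the in-place NaN masking in the class.

```python
import torch


def readout(linear, simplex_features):
    """Hypothetical helper mirroring the SCNN readout: linear score per simplex, NaN-aware mean."""
    scores = linear(simplex_features)    # [n_simplices, 1]
    mean = torch.nanmean(scores, dim=0)  # [1], NaN entries ignored
    # Out-of-place version of `mean[torch.isnan(mean)] = 0`
    return torch.where(torch.isnan(mean), torch.zeros_like(mean), mean)


torch.manual_seed(0)
linear = torch.nn.Linear(4, 1)
edge_features = torch.randn(5, 4)  # assumed: 5 edges, 4 channels from the last SCNN layer
print(readout(linear, edge_features).shape)  # torch.Size([1])

with_nan = edge_features.clone()
with_nan[0] = float("nan")  # the NaN edge is skipped by nanmean
all_nan = torch.full((3, 4), float("nan"))  # every score is NaN, so the mean falls back to 0
```

The `[1]` output shape is what the training loop compares against each label.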

# Train the Neural Network

We select the rank of the simplices to learn on (edges by default), pick the matching features and Laplacians, and then define the model, the optimizer, and the loss function.

In [6]:
rank = 1  # rank of the simplices to learn on: 0 = nodes, 1 = edges, 2 = faces
conv_order_down = 2
conv_order_up = 2
intermediate_channels = 4
out_channels = intermediate_channels
num_layers = 2

# Select the features and Laplacians of the chosen rank
if rank == 0:
    laplacian_down = [None] * len(laplacian_0_list)  # nodes have no lower neighbors
    laplacian_up = laplacian_0_list  # the graph Laplacian
    conv_order_down = 0
    x = x_0s
    in_channels = in_channels_0
elif rank == 1:
    laplacian_down = laplacian_down_1_list
    laplacian_up = laplacian_up_1_list
    x = x_1s
    in_channels = in_channels_1
elif rank == 2:
    laplacian_down = laplacian_2_list  # equals the lower Laplacian: there are no 3-simplices
    laplacian_up = [None] * len(laplacian_2_list)  # faces have no upper neighbors
    conv_order_up = 0
    x = x_2s
    in_channels = in_channels_2
else:
    raise ValueError("Rank must be 0, 1, or 2 on this dataset")
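The branches above use only one Laplacian at ranks 0 and 2 because, in a 2-dimensional complex, nodes have nothing below them and faces have nothing above them. On a single filled triangle (an assumed toy example), $\mathbf{L}_0 = \mathbf{B}_1\mathbf{B}_1^\top$ is the familiar graph Laplacian $\mathbf{D} - \mathbf{A}$, and $\mathbf{L}_2 = \mathbf{B}_2^\top\mathbf{B}_2$ consists of the lower part alone:

```python
import numpy as np

# Toy complex: one filled triangle with edges (0,1), (0,2), (1,2)
B1 = np.array([[-1, -1, 0],
               [1, 0, -1],
               [0, 1, 1]])       # nodes x edges
B2 = np.array([[1], [-1], [1]])  # edges x faces

# Rank 0: only an "up" part, since nothing lies below the nodes
L_0 = B1 @ B1.T
A = (np.abs(B1) @ np.abs(B1).T > 0).astype(int) - np.eye(3, dtype=int)  # adjacency matrix
D = np.diag(A.sum(axis=1))                                              # degree matrix
assert (L_0 == D - A).all()  # the graph Laplacian

# Rank 2: only a "down" part, since a mesh has no 3-simplices
L_2 = B2.T @ B2
print(L_2)  # [[3]]
```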


In [7]:
model = SCNN(
    in_channels=in_channels,
    intermediate_channels=intermediate_channels,
    out_channels=out_channels,
    conv_order_down=conv_order_down,
    conv_order_up=conv_order_up,
    n_layers=num_layers,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
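`SCNN.forward` returns a tensor of shape `[1]`, while `torch.tensor(y)` turns a scalar label into a 0-d tensor (assuming SHREC stores one scalar label per mesh). `MSELoss` then relies on broadcasting, and PyTorch warns about the size mismatch. The loss value is the same in this case, but reshaping the target with `reshape_as` makes the intent explicit and avoids the warning. A minimal demonstration with made-up numbers:

```python
import warnings

import torch

loss_fn = torch.nn.MSELoss()
y_hat = torch.tensor([2.0])  # model output, shape [1]
y = torch.tensor(3.0)        # scalar label, shape []

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    broadcast_loss = loss_fn(y_hat, y)  # shapes differ: PyTorch emits a UserWarning

matched_loss = loss_fn(y_hat, y.reshape_as(y_hat))  # shapes agree
print(broadcast_loss.item(), matched_loss.item())  # 1.0 1.0
```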

In [8]:
test_size = 0.2
x_train, x_test = train_test_split(x, test_size=test_size, shuffle=False)

laplacian_down_train, laplacian_down_test = train_test_split(laplacian_down, test_size=test_size, shuffle=False)
laplacian_up_train, laplacian_up_test = train_test_split(laplacian_up, test_size=test_size, shuffle=False)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
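With `shuffle=False`, `train_test_split` keeps the original order and puts the last 20% of the complexes in the test set, so the separate calls above stay aligned. Passing all sequences to a single call is an equivalent alternative that also keeps them aligned if shuffling is turned on later. A toy sketch with 10 stand-in complexes (an assumption, not the SHREC data):

```python
from sklearn.model_selection import train_test_split

complexes = list(range(10))          # stand-ins for 10 meshes
labels = [i % 2 for i in range(10)]  # stand-in labels

c_train, c_test, y_train_toy, y_test_toy = train_test_split(
    complexes, labels, test_size=0.2, shuffle=False
)
print(c_train)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(c_test)   # [8, 9]
```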

In [9]:
test_interval = 1
num_epochs = 5

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_i, laplacian_down_i, laplacian_up_i, y_i in zip(
        x_train, laplacian_down_train, laplacian_up_train, y_train
    ):
        x_i = torch.tensor(x_i, dtype=torch.float)
        y_i = torch.tensor(y_i, dtype=torch.float)
        optimizer.zero_grad()

        y_hat = model(x_i, laplacian_down_i, laplacian_up_i)

        # y_hat has shape [1]; reshape the (scalar) label to match instead of broadcasting
        loss = loss_fn(y_hat, y_i.reshape_as(y_hat))

        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            for x_i, laplacian_down_i, laplacian_up_i, y_i in zip(
                x_test, laplacian_down_test, laplacian_up_test, y_test
            ):
                x_i = torch.tensor(x_i, dtype=torch.float)
                y_i = torch.tensor(y_i, dtype=torch.float)

                y_hat = model(x_i, laplacian_down_i, laplacian_up_i)
                loss = loss_fn(y_hat, y_i.reshape_as(y_hat))
            # Note: this reports the loss of the last test complex only
            print(f"Test_loss: {loss:.4f}", flush=True)
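The test loop above prints the loss of the last test complex only. Averaging over the whole test set gives a more stable estimate. The helper `evaluate` below is hypothetical (not part of TopoModelX) and assumes the model returns a tensor of shape `[1]`, as `SCNN.forward` does; the demo uses a stand-in model instead of the SCNN.

```python
import numpy as np
import torch


def evaluate(model, samples, loss_fn):
    """Mean loss over (x, laplacian_down, laplacian_up, y) tuples; hypothetical helper."""
    was_training = model.training
    model.eval()
    losses = []
    with torch.no_grad():
        for x, lap_down, lap_up, y in samples:
            x = torch.as_tensor(x, dtype=torch.float)
            y = torch.as_tensor(y, dtype=torch.float)
            y_hat = model(x, lap_down, lap_up)
            losses.append(loss_fn(y_hat, y.reshape_as(y_hat)).item())
    model.train(was_training)  # restore the previous mode
    return float(np.mean(losses))


class MeanModel(torch.nn.Module):
    """Stand-in model for the demo: predicts the mean input feature."""

    def forward(self, x, lap_down, lap_up):
        return x.mean().reshape(1)


toy_samples = [
    (np.array([[1.0], [3.0]]), None, None, 2.0),  # prediction 2, label 2: loss 0
    (np.array([[0.0], [0.0]]), None, None, 1.0),  # prediction 0, label 1: loss 1
]
print(evaluate(MeanModel(), toy_samples, torch.nn.MSELoss()))  # 0.5
```

In the notebook it could be called as `evaluate(model, zip(x_test, laplacian_down_test, laplacian_up_test, y_test), loss_fn)`.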



Epoch: 1 loss: 261.3912
Test_loss: 97.1429
Epoch: 2 loss: 94.8871
Test_loss: 120.7043
Epoch: 3 loss: 83.3380
Test_loss: 122.5341
Epoch: 4 loss: 83.5518
Test_loss: 120.2451
Epoch: 5 loss: 85.0916
Test_loss: 110.6804
