# Train a Simplicial Convolutional Neural Network (SCNN)

In this notebook, we will create and train a convolutional neural network in the simplicial complex domain, as proposed in the paper by [Yang et. al : SIMPLICIAL CONVOLUTIONAL NEURAL NETWORKS (2022)](https://arxiv.org/pdf/2110.02585.pdf). 

We train the model to perform binary node classification using the karate club dataset.  

In [1]:
import torch
import numpy as np

from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph


In [2]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)
max_rank = dataset.dim
print(max_rank)

Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
4


# Get incidence matrices and Hodge Laplacians

In [3]:
incidence_1 = dataset.incidence_matrix(rank=1)
incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
incidence_2 = dataset.incidence_matrix(rank=2)
incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()
print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The incidence matrix B2 has shape: {incidence_2.shape}.")

The incidence matrix B1 has shape: torch.Size([34, 78]).
The incidence matrix B2 has shape: torch.Size([78, 45]).


In [4]:
laplacian_0  = dataset.hodge_laplacian_matrix(rank=0,weight=True)
laplacian_down_1 = dataset.down_laplacian_matrix(rank=1,weight=True)
laplacian_up_1 = dataset.up_laplacian_matrix(rank=1,weight=True)
laplacian_down_2 = dataset.down_laplacian_matrix(rank=2,weight=True)
laplacian_up_2 = dataset.up_laplacian_matrix(rank=2,weight=True)

laplacian_0 = torch.from_numpy(laplacian_0.todense()).to_sparse()
laplacian_down_1 = torch.from_numpy(laplacian_down_1.todense()).to_sparse()
laplacian_up_1 = torch.from_numpy(laplacian_up_1.todense()).to_sparse()
laplacian_down_2 = torch.from_numpy(laplacian_down_2.todense()).to_sparse()
laplacian_up_2 = torch.from_numpy(laplacian_up_2.todense()).to_sparse()
    


# Import signal

In [5]:
x_0 = []
for _, v in dataset.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = np.stack(x_1)
chennel_edges = x_1.shape[-1]
x_2 = []
for k, v in dataset.get_simplex_attributes("face_feat").items():
    x_2.append(v)
x_2 = np.stack(x_2)
channel_faces = x_2.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.


In [6]:
"""A function to obtain features based on the input: rank
"""
def get_simplicial_features(dataset,rank):
    if rank == 0: 
        which_feat = "node_feat"
    elif rank == 1:
        which_feat = "edge_feat"
    elif rank == 2:
        which_feat = "face_feat"
    else:
        raise ValueError(f"input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces") 
    
    x = []
    for _, v in dataset.get_simplex_attributes(which_feat).items():
        x.append(v)
    
    x = torch.tensor(np.stack(x))
    return x


# Define binary labels

In [7]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[-4:]
y_train = y_true[:30]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the SCNN
Use the SCNNLayer classm we create a neural network with stacked layers. A final linear layer produces an output with shape $n_{\rm{nodes}}\times 2$, so we can compare with the binary labels

In [8]:
from topomodelx.nn.simplicial.scnn_layer import SCNNLayer

class SCNN(torch.nn.Module):
    """Simplicial convolutional neural network implementation for binary node classification. 
    
    Note: At the last layer, we obtain the output on simplcies, e.g., edges. To perform the node classification task, we consider a projection step which obtains the node labels from the edge output. 

    Parameters
    ----------
    
    """
    def __init__(self, in_channels, intermediate_channels, out_channels,conv_order_down,conv_order_up,aggr_norm=False,update_func=None, n_layers=2):
        super().__init__()
        # First layer -- initial layer has the in_channels as input, and inter_channels as the output
        layers = [SCNNLayer(in_channels=in_channels,out_channels=intermediate_channels,conv_order_down=conv_order_down,conv_order_up=conv_order_up)]

        for _ in range(n_layers-1):
            layers.append(
                SCNNLayer(in_channels=intermediate_channels,out_channels=out_channels,conv_order_down=conv_order_down,conv_order_up=conv_order_up,aggr_norm=aggr_norm,update_func=update_func)
            )
            
        self.linear = torch.nn.Linear(out_channels,2)
        self.layers = layers 

    def forward(self, x, laplacian_down, laplacian_up, incidence_1):
        """Forward computation.
        
        Parameters
        ---------
        x: tensor
            shape = [n_simplices, channels]
            node/edge/face features
        
        laplacian: tensor
            shape = [n_simplices,n_simplices]
            For node features, laplacian_down = None
        """
        for layer in self.layers:
            x = layer(x,laplacian_down,laplacian_up)
        """Project the output from edges to nodes 
        incidence_1 @ x
        """
        logits = self.linear(incidence_1 @ x)
        return torch.softmax(logits,dim=-1)

# Train the SCNN 

In [9]:
"""Select the simplex order, i.e., on which level of simplices the learning will be performed 
"""
rank = 1 # simplex level 
conv_order_down = 2
conv_order_up = 2
x = get_simplicial_features(dataset,rank)
channels_x = x.shape[-1]
if rank == 0: 
    laplacian_down = None
    laplacian_up = laplacian_0 # the graph laplacian 
    conv_order_down = 0
elif rank == 1:
    laplacian_down = laplacian_down_1 
    laplacian_up = laplacian_up_1 
elif rank == 2:
    laplacian_down = laplacian_down_2 
    laplacian_up = laplacian_up_2 
else: 
    raise ValueError(
        f"Rank must be not larger than 2 on this dataset"
    )
    
intermediate_channels = 16
out_channels = intermediate_channels
num_layers = 5
model = SCNN(in_channels=channels_nodes,intermediate_channels=intermediate_channels,out_channels=out_channels,conv_order_down=conv_order_down,conv_order_up=conv_order_up,n_layers=num_layers
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [14]:
test_interval = 2
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x, laplacian_down, laplacian_up, incidence_1)
    print(y_hat.shape)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[-len(y_train) :] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x, laplacian_down, laplacian_up, incidence_1)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

torch.Size([34, 2])
Epoch: 1 loss: 0.6365 Train_acc: 0.6667
torch.Size([34, 2])
Epoch: 2 loss: 0.6365 Train_acc: 0.6667
Test_acc: 0.7500
torch.Size([34, 2])
Epoch: 3 loss: 0.6365 Train_acc: 0.6667
torch.Size([34, 2])
Epoch: 4 loss: 0.6365 Train_acc: 0.6667
Test_acc: 0.7500
torch.Size([34, 2])
Epoch: 5 loss: 0.6365 Train_acc: 0.6667
