# Train a SCCNN

In this notebook, we will create and train a High Skip Network in the simplicial complex domain, as proposed in the paper by [Yang et. al : Convolutional Learning on Simplicial Complexes (2023)](https://arxiv.org/abs/2301.11163). 

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split



from toponetx import SimplicialComplex
import toponetx.datasets as datasets

from topomodelx.nn.simplicial.scnn_layer import SCNNLayer

# Pre-processing

## Import dataset ##

We must first lift our graph dataset into the simplicial complex domain.

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
 
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]

Loading shrec 16 small dataset...

done!


## Consider using edge features for classification 

In [7]:
in_channels_0 = x_0s[-1].shape[1]
in_channels_1 = x_1s[-1].shape[1]
in_channels_2 = x_2s[-1].shape[1]



In [15]:
max_rank = 2 # the order of the SC is two 
incidence_1_list = []
incidence_2_list = []

laplacian_0_list = []
laplacian_down_1_list = []
laplacian_up_1_list = []
laplacian_2_list = []
 
for simplex in simplexes: 
    incidence_1 = simplex.incidence_matrix(rank=1)
    incidence_2 = simplex.incidence_matrix(rank=2)
    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0,weight=True)
    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1,weight=True)
    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1,weight=True)
    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2,weight=True)
    
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()
    laplacian_0 = torch.from_numpy(laplacian_0.todense()).to_sparse()
    laplacian_down_1 = torch.from_numpy(laplacian_down_1.todense()).to_sparse()
    laplacian_up_1 = torch.from_numpy(laplacian_up_1.todense()).to_sparse()
    laplacian_2 = torch.from_numpy(laplacian_2.todense()).to_sparse()
    
    incidence_1_list.append(incidence_1)
    incidence_2_list.append(incidence_2)
    laplacian_0_list.append(laplacian_0)
    laplacian_down_1_list.append(laplacian_down_1)
    laplacian_up_1_list.append(laplacian_up_1)
    laplacian_2_list.append(laplacian_2)
    

    

# Create the SCCNN

In [16]:
class SCNN(torch.nn.Module):
    """SCCNN implementation for binary node classification 
    Note: In this task, we direcly consider the finaly output on the nodes, which is passed by a linear layer, as the label output. 

    Parameters
    """
    def __init__(self, in_channels, intermediate_channels,out_channels, conv_order_down, conv_order_up, aggr_norm=False,update_func="sigmoid", n_layers=2):
        super().__init__()
        # first layer 
        # self.in_linear_1 = torch.nn.Linear(in_channels_all[1],intermediate_channels_all[1]) 
        layers = [SCNNLayer(in_channels=in_channels,out_channels=intermediate_channels,conv_order_down=conv_order_down,conv_order_up=conv_order_up)]
        for _ in range(n_layers-1):
            layers.append(
                SCNNLayer(in_channels=intermediate_channels,out_channels=out_channels,conv_order_down=conv_order_down,conv_order_up=conv_order_up,aggr_norm=aggr_norm,update_func=update_func)
            )
            
        self.layers = layers    
        self.out_linear = torch.nn.Linear(out_channels,1)


    def forward(self,x,laplacian_down, laplacian_up):
        """Forward computation. 
        
        Parameters
        ----------
        """

        for layer in self.layers:
            x = layer(x,laplacian_down,laplacian_up)
        
        pooled_x = torch.max(x,dim=0)[0]
        y = torch.sigmoid(self.out_linear(pooled_x))[0]
 
        return y  

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [20]:
rank = 1 # simplex level 
conv_order_down = 2
conv_order_up = 2
intermediate_channels = 4
out_channels = intermediate_channels
num_layers = 2

# select the simplex level
if rank == 0: 
    laplacian_down = None
    laplacian_up = laplacian_0_list # the graph laplacian 
    conv_order_down = 0
    x = x_0s
    in_channels = in_channels_0
elif rank == 1:
    laplacian_down = laplacian_down_1_list
    laplacian_up = laplacian_up_1_list
    x = x_1s
    in_channels = in_channels_1
elif rank == 2:
    laplacian_down = laplacian_2_list
    laplacian_up = None 
    x = x_2s
    in_channels = in_channels_2 
else: 
    raise ValueError(
        f"Rank must be not larger than 2 on this dataset"
    )


In [18]:
model = SCNN(in_channels = in_channels,intermediate_channels = intermediate_channels,out_channels = out_channels, conv_order_down=conv_order_down,conv_order_up=conv_order_up,n_layers=num_layers
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

In [21]:
test_size = 0.2
x_train, x_test = train_test_split(x, test_size=test_size, shuffle=False)

laplacian_down_train, laplacian_down_test = train_test_split(laplacian_down, test_size=test_size, shuffle=False)
laplacian_up_train, laplacian_up_test = train_test_split(laplacian_up, test_size=test_size, shuffle=False)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)

In [24]:
test_interval = 1
num_epochs = 5

# select which feature to use for labeling
simplex_order_select = 1

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x, laplacian_down, laplacian_up, y in zip(x_train, laplacian_down_train, laplacian_up_train, y_train):

        x = torch.tensor(x,dtype=torch.float)
        y = torch.tensor(y,dtype=torch.float)
        optimizer.zero_grad()

        y_hat = model(x, laplacian_down, laplacian_up)

        # print(y_hat.shape)
        loss = loss_fn(y_hat, y)

        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()    

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x, laplacian_down, laplacian_up, y in zip(x_test, laplacian_down_test, laplacian_up_test, y_test):
    
                x = torch.tensor(x,dtype=torch.float)
                y = torch.tensor(y,dtype=torch.float)
                optimizer.zero_grad()

                y_hat = model(x, laplacian_down, laplacian_up)

                    
                loss = loss_fn(y_hat, y)
            print(f"Test_loss: {loss:.4f}", flush=True)

  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Epoch: 1 loss: 279.5853
Test_loss: 531.4301
Epoch: 2 loss: 275.5110
Test_loss: 529.9162
Epoch: 3 loss: 275.0299
Test_loss: 529.5111
Epoch: 4 loss: 274.8655
Test_loss: 529.3348
Epoch: 5 loss: 274.7860
Test_loss: 529.2397
