# Train a Convolutional Cell Complex Network (CCXN)

We create and train a simplified version of the CCXN originally proposed in [Hajij et al.: Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).

### The Neural Network:

The equations of one layer of this neural network are given by:

1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})$ 

🟧 $\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})$ 

🟩 $\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})$

2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
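
As a rough illustration of the two steps above, the following NumPy sketch applies them to a single filled triangle. It assumes sum aggregation, a linear message function and a ReLU update, and it ignores the self term $h_x$ in the update. The weight matrices `W_0` and `W_1_to_2`, the feature sizes and the random features are hypothetical, and orientation signs are dropped from the coboundary matrix for simplicity. This is a sketch of the idea, not the TopoModelX implementation.

```python
import numpy as np

# Toy cell complex: one filled triangle with nodes 0, 1, 2,
# edges (0,1), (0,2), (1,2) and a single face.

# Node adjacency matrix: two nodes are neighbors if they share an edge.
A_0 = np.array([[0, 1, 1],
                [1, 0, 1],
                [1, 1, 0]], dtype=float)

# Unsigned coboundary matrix B_2^T (faces x edges): the face touches all three edges.
B_2T = np.array([[1, 1, 1]], dtype=float)

rng = np.random.default_rng(0)
h_0 = rng.normal(size=(3, 4))       # node features, 4 channels (hypothetical)
h_1 = rng.normal(size=(3, 2))       # edge features, 2 channels (hypothetical)
W_0 = rng.normal(size=(4, 4))       # hypothetical weights, nodes -> nodes
W_1_to_2 = rng.normal(size=(2, 5))  # hypothetical weights, edges -> faces


def relu(x):
    return np.maximum(x, 0.0)


# Step 1: each node sums the transformed features of its adjacent nodes.
h_0_next = relu(A_0 @ h_0 @ W_0)

# Step 2: each face sums the transformed features of the edges on its boundary.
h_2_next = relu(B_2T @ h_1 @ W_1_to_2)

print(h_0_next.shape, h_2_next.shape)
```

Multiplying by the neighborhood matrix performs both the message routing and the sum aggregation in a single operation, which is why the layers below take these matrices as inputs.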

### The Task:

We train this model to perform entire complex classification on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

# Set-up


In [20]:
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from toponetx import CellComplex
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
import torch.nn.functional as F

from topomodelx.nn.cell.can_layer import CANLayer
from topomodelx.nn.cell.attentional_lift_layer import MultiHeadLiftLayer
from topomodelx.nn.cell.attentional_pooling_layer import PoolLayer

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

We import a subset of MUTAG, a benchmark dataset for graph classification. 

We then lift each graph into our topological domain of choice, here: a cell complex.

We also retrieve:
- input signals `x_0` and `x_1` on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model's inputs,
- a binary classification label `y` associated to the cell complex.

In [22]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
dataset = dataset[:20]
cc_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    cell_complex = CellComplex(to_networkx(graph))
    cc_list.append(cell_complex)
    x_0_list.append(graph.x)
    x_1_list.append(graph.edge_attr)
    y_list.append(int(graph.y))

# Print the last graph of the subset as a sanity check.
print(graph)

i_cc = 0
print(f"Features on nodes for the {i_cc}th cell complex: {x_0_list[i_cc].shape}.")
print(f"Features on edges for the {i_cc}th cell complex: {x_1_list[i_cc].shape}.")
print(f"Label of {i_cc}th cell complex: {y_list[i_cc]}.")

Data(edge_index=[2, 40], x=[18, 7], edge_attr=[40, 4], y=[1])
Features on nodes for the 0th cell complex: torch.Size([17, 7]).
Features on edges for the 0th cell complex: torch.Size([38, 4]).
Label of 0th cell complex: 1.


## Define neighborhood structures. ##

Implementing the architecture requires message passing along neighborhood structures of the cell complexes.

We therefore retrieve the matrices that represent these neighborhood structures and use them to send messages.

The code below retrieves, for each cell complex, the node adjacency matrix $A_{\uparrow, 0}$ together with the lower and upper Laplacians of rank 1. If the upper Laplacian cannot be computed (for instance when a complex has no 2-cells), an all-zero matrix of the same shape is used instead.
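
For intuition, the neighborhood matrices used in the next cell can be derived from the incidence (boundary) matrices of a complex. The sketch below builds them by hand for a single filled triangle, assuming the usual Hodge-Laplacian conventions $L_{\downarrow,1} = B_1^T B_1$ and $L_{\uparrow,1} = B_2 B_2^T$. The orientations chosen for `B_1` and `B_2` are an assumption made for this example, and TopoNetX may use different sign conventions internally.

```python
import numpy as np

# Filled triangle: nodes 0, 1, 2; edges e0=(0,1), e1=(0,2), e2=(1,2); one face.
# B_1 (nodes x edges): -1 at the tail of each edge, +1 at its head.
B_1 = np.array([[-1, -1,  0],
                [ 1,  0, -1],
                [ 0,  1,  1]])
# B_2 (edges x faces): the face boundary is e0 - e1 + e2.
B_2 = np.array([[ 1],
                [-1],
                [ 1]])

# Node adjacency: the graph Laplacian is L_0 = B_1 B_1^T = D - A, so A = D - L_0.
L_0 = B_1 @ B_1.T
A_0 = np.diag(np.diag(L_0)) - L_0

# Lower and upper Laplacians on the edges (rank 1).
L_down = B_1.T @ B_1
L_up = B_2 @ B_2.T

# Without any 2-cell, B_2 has zero columns and the upper Laplacian is all zeros.
L_up_empty = np.zeros((3, 0)) @ np.zeros((0, 3))

print(A_0)
print(L_down)
print(L_up)
```

The last lines show why an all-zero matrix is a natural stand-in for the upper Laplacian of a complex without faces: no pair of edges shares a common 2-cell.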

In [23]:
lower_neighborhood_list = []
upper_neighborhood_list = []
adjacency_0_list = []

for cell_complex in cc_list:
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    adjacency_0_list.append(adjacency_0)

    lower_neighborhood_t = cell_complex.down_laplacian_matrix(rank=1)
    lower_neighborhood_t = torch.from_numpy(lower_neighborhood_t.todense()).to_sparse()
    lower_neighborhood_list.append(lower_neighborhood_t)

    try:
        upper_neighborhood_t = cell_complex.up_laplacian_matrix(rank=1)
        upper_neighborhood_t = torch.from_numpy(
            upper_neighborhood_t.todense()
        ).to_sparse()
    except Exception:
        # Fall back to an all-zero matrix when the upper Laplacian cannot be computed.
        upper_neighborhood_t = np.zeros(
            (lower_neighborhood_t.shape[0], lower_neighborhood_t.shape[0])
        )
        upper_neighborhood_t = torch.from_numpy(upper_neighborhood_t).to_sparse()

    upper_neighborhood_list.append(upper_neighborhood_t)

i_cc = 0

# Create the Neural Network

Using the `CANLayer` class, together with an attentional lift layer and a pooling layer, we create a neural network with stacked layers.

In [24]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
in_channels_2 = 5
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)

The dimension of input features on nodes, edges and faces are: 7, 4 and 5.


In [35]:
class CAN(torch.nn.Module):
    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        out_channels,
        num_classes,
        dropout=0.5,
        heads=3,
        concat=True,
        skip_connection=True,
        att_activation=torch.nn.LeakyReLU(0.2),
        n_layers=2,
        att_lift=True,
    ):
        super().__init__()

        if att_lift:
            self.lift_layer = MultiHeadLiftLayer(
                in_channels_0=in_channels_0,
                heads=in_channels_0,
                signal_lift_dropout=0.5,
            )
            in_channels_1 = in_channels_1 + in_channels_0

        layers = []

        layers.append(
            CANLayer(
                in_channels=in_channels_1,
                out_channels=out_channels,
                heads=heads,
                concat=concat,
                skip_connection=skip_connection,
                att_activation=att_activation,
                aggr_func="sum",
                update_func="relu",
            )
        )

        for _ in range(n_layers - 1):
            layers.append(
                CANLayer(
                    in_channels=out_channels * heads,
                    out_channels=out_channels,
                    dropout=dropout,
                    heads=heads,
                    concat=concat,
                    skip_connection=skip_connection,
                    att_activation=att_activation,
                    aggr_func="sum",
                    update_func="relu",
                )
            )

            if _ == n_layers-2:
                layers.append(
                    PoolLayer(
                        k_pool=0.5,
                        in_channels_0=out_channels * heads,
                        signal_pool_activation=torch.nn.Sigmoid(),
                        readout=True,
                    )
                )

        self.layers = torch.nn.ModuleList(layers)
        self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
        self.lin_1 = torch.nn.Linear(128, num_classes)

    def forward(
        self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
    ):
        if hasattr(self, "lift_layer"):
            x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)

        for layer in self.layers:
            x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
            x_1 = F.dropout(x_1, p=0.5, training=self.training)

        # max pooling over all edges (1-cells) of the complex
        x = x_1.max(dim=0)[0]

        # Feed-forward neural network to predict the label of the complex
        out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))

        return out

# Train the Neural Network

We specify the model, initialize the loss, and specify an optimizer. The attentional lift is enabled here (`att_lift=True`).

In [36]:
model = CAN(
    in_channels_0,
    in_channels_1,
    32,
    dropout=0.5,
    heads=3,
    num_classes=2,
    n_layers=2,
    att_lift=True,
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)

In [37]:
model

CAN(
  (lift_layer): MultiHeadLiftLayer(
    (lifts): LiftLayer()
  )
  (layers): ModuleList(
    (0): CANLayer(
      (lower_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=11, out_features=96, bias=False)
      )
      (upper_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=11, out_features=96, bias=False)
      )
      (lin): Linear(in_features=11, out_features=96, bias=False)
      (aggregation): Aggregation()
    )
    (1): CANLayer(
      (lower_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=96, out_features=96, bias=False)
      )
      (upper_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=96, out_features=96, bias=False)
      )
      (lin): Linear(in_features=96, out_features=96, bias=False)
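
The feature sizes in this printout follow from the constructor arguments. A small sketch of the arithmetic, assuming (as the `CAN.__init__` code above suggests) that the attentional lift appends one lifted channel per head to the edge features and that `concat=True` stacks the outputs of all attention heads:

```python
# Values taken from the cells above.
in_channels_0, in_channels_1 = 7, 4   # node and edge feature sizes
out_channels, heads = 32, 3           # arguments passed to CAN(...)

# The lift layer is built with heads=in_channels_0, and __init__ then sets
# in_channels_1 = in_channels_1 + in_channels_0.
first_layer_in = in_channels_1 + in_channels_0

# With concat=True, each CANLayer outputs out_channels features per head.
layer_out = out_channels * heads

print(first_layer_in, layer_out)  # 11 96
```

These match `in_features=11` and `out_features=96` in the first `CANLayer`, and `in_features=96` in the second.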
 

We split the dataset into train and test sets.

In [38]:
test_size = 0.3
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
lower_neighborhood_train, lower_neighborhood_test = train_test_split(
    lower_neighborhood_list, test_size=test_size, shuffle=False
)
upper_neighborhood_train, upper_neighborhood_test = train_test_split(
    upper_neighborhood_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)
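
Because every call uses the same `test_size` and `shuffle=False`, the resulting lists stay aligned: the leading samples go to the training set and the trailing ones to the test set. A minimal pure-Python sketch of that behaviour, using a hypothetical helper `sequential_split` (not part of scikit-learn) that splits several lists in one call:

```python
import math


def sequential_split(*lists, test_size=0.3):
    """Split aligned lists into leading train parts and trailing test parts."""
    n = len(lists[0])
    if any(len(lst) != n for lst in lists):
        raise ValueError("All lists must have the same length.")
    n_test = math.ceil(test_size * n)
    n_train = n - n_test
    parts = []
    for lst in lists:
        # Same cut point for every list, so sample i stays aligned everywhere.
        parts.extend([lst[:n_train], lst[n_train:]])
    return parts


# Toy stand-ins for the 20 complexes and their labels.
features = list(range(20))
labels = [i % 2 for i in range(20)]
feat_train, feat_test, lab_train, lab_test = sequential_split(features, labels)

print(len(feat_train), len(feat_test))  # 14 6
```

Note that an unshuffled split on only 20 samples may give train and test sets with different class balances, which is acceptable for a quick smoke test but not for reporting accuracy.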

We train the network for a small number of epochs, keeping training minimal for the purpose of rapid testing.

In [39]:
test_interval = 1
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
        x_0_train,
        x_1_train,
        adjacency_0_train,
        lower_neighborhood_train,
        upper_neighborhood_train,
        y_train,
    ):
        x_0 = x_0.float().to(device)
        x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        adjacency = adjacency.float().to(device)
        lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
            device
        ), upper_neighborhood.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood)
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
                x_0_test,
                x_1_test,
                adjacency_0_test,
                lower_neighborhood_test,
                upper_neighborhood_test,
                y_test,
            ):
                x_0 = x_0.float().to(device)
                x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(
                    device
                )
                adjacency = adjacency.float().to(device)
                lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
                    device
                ), upper_neighborhood.float().to(device)
                y_hat = model(
                    x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood
                )
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

NotImplementedError: Could not run 'aten::index.Tensor' with arguments from the 'SparseCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::index.Tensor' is only available for these backends: [CPU, CUDA, HIP, MPS, IPU, XPU, HPU, VE, Meta, MTIA, PrivateUse1, PrivateUse2, PrivateUse3, FPGA, ORT, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedMeta, QuantizedMTIA, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, CustomRNGKeyId, MkldnnCPU, SparseCsrCPU, SparseCsrCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

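
The error message indicates that some operation tried to index a sparse COO tensor with an index tensor, which the `SparseCPU` backend of this PyTorch build does not support. Which call inside the model triggers it is not visible in the output above. One possible workaround, sketched here on a toy matrix rather than on the notebook's data, is to convert a neighborhood matrix to a dense tensor before indexing it:

```python
import torch

# Toy sparse neighborhood matrix (hypothetical stand-in for one of the
# adjacency or Laplacian matrices built earlier).
neighborhood = torch.tensor(
    [[0.0, 1.0, 0.0],
     [1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0]]
).to_sparse()

idx = torch.tensor([0, 2])

# Indexing the dense copy with an index tensor is supported on all backends.
rows = neighborhood.to_dense()[idx]

print(rows.shape)  # torch.Size([2, 3])
```

Densifying is affordable here because the MUTAG graphs are small (the first complex has 17 nodes), but it would not scale to large complexes.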
