# Train a Cell Attention Network (CAN)

We create and train a simplified version of the CAN originally proposed in [Giusti et. al : Cell Attention Network (2022)](https://arxiv.org/abs/2209.08179).

TODO organize them in a better way.
Equations:

$\textbf{h}_x^{t+1} =  \phi^t \Bigg ( \textbf{h}_x^{t}, \bigotimes_{\mathcal{N}_k\in\mathcal N}\bigoplus_{y \in \mathcal{N}_k(x)}  \alpha_k(h_x^t,h_y^t)\Bigg )$ 

         
🟥 $\quad m_{(y \rightarrow x),k}^{(r)} = \alpha_k(h_x^t,h_y^t) = a_k(h_x^{t}, h_y^{t}) \cdot \psi_k^t(h_x^{t})\quad \forall \mathcal N_k$

🟧 $\quad m_{x,k}^{(r)} = \bigoplus_{y \in \mathcal{N}_k(x)}  m^{(r)}  _{(y \rightarrow x),k}$

🟩 $\quad m_{x}^{(r)} = \bigotimes_{\mathcal{N}_k\in\mathcal N}m_{x,k}^{(r)}$    

🟦 $\quad h_x^{t+1,(r)} = \phi^{t}(h_x^t, m_{x}^{(r)})$

Attentional Lift:

🟥 $\quad m_{(y,z) \rightarrow x}^{(0 \rightarrow 1)} = \alpha(h_y,h_e) = \Theta(h_z||h_y)$

🟦 $\quad h_x^{(1)} = \phi(h_x, m_x^{(1)})$

Attentional Pooling:

🟥 $\quad m_{x}^{(r)} = \gamma^t(h_x^t) = \tau^t (a^t\cdot h_x^t)$

🟦 $\quad h_x^{t+1,(r)} = \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1}$




### The Neural Network:

The equations of one layer of this neural network are given by:

1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad$

🟧 $\quad$ 

🟩 $\quad$

🟦 $\quad$

2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 $\quad$

🟧 $\quad$ 

🟩 $\quad$

🟦 $\quad$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

### The Task:

We train this model to perform entire complex classification on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

# Set-up


In [287]:
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from toponetx import CellComplex
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
import torch.nn.functional as F

from topomodelx.nn.cell.can_layer import CANLayer
from topomodelx.nn.cell.attentional_lift_layer import MultiHeadLiftLayer
from topomodelx.nn.cell.attentional_pooling_layer import PoolLayer

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [288]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Pre-processing

## Import data ##

We import a subset of MUTAG, a benchmark dataset for graph classification. 

We then lift each graph into our topological domain of choice, here: a cell complex.

We also retrieve:
- input signals `x_0` and `x_1` on the nodes (0-cells) and edges (1-cells) for each complex: these will be the model's inputs,
- a binary classification label `y` associated to the cell complex.

In [289]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
dataset = dataset
cc_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    cell_complex = CellComplex(to_networkx(graph))
    cc_list.append(cell_complex)
    x_0_list.append(graph.x)
    x_1_list.append(graph.edge_attr)
    y_list.append(int(graph.y))
else:
    print(graph)

i_cc = 0
print(f"Features on nodes for the {i_cc}th cell complex: {x_0_list[i_cc].shape}.")
print(f"Features on edges for the {i_cc}th cell complex: {x_1_list[i_cc].shape}.")
print(f"Label of {i_cc}th cell complex: {y_list[i_cc]}.")

Data(edge_index=[2, 36], x=[16, 7], edge_attr=[36, 4], y=[1])
Features on nodes for the 0th cell complex: torch.Size([17, 7]).
Features on edges for the 0th cell complex: torch.Size([38, 4]).
Label of 0th cell complex: 1.


## Define neighborhood structures. ##

Implementing the CCXN architecture will require to perform message passing along neighborhood structures of the cell complexes.

Thus, now we retrieve these neighborhood structures (i.e. their representative matrices) that we will use to send messages. 

For the CCXN, we need the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$ of each cell complex.

In [290]:
lower_neighborhood_list = []
upper_neighborhood_list = []
adjacency_0_list = []

for cell_complex in cc_list:
    adjacency_0 = cell_complex.adjacency_matrix(rank=0)
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    adjacency_0_list.append(adjacency_0)

    lower_neighborhood_t = cell_complex.down_laplacian_matrix(rank=1)
    lower_neighborhood_t = torch.from_numpy(lower_neighborhood_t.todense()).to_sparse()
    lower_neighborhood_list.append(lower_neighborhood_t)

    try:
        upper_neighborhood_t = cell_complex.up_laplacian_matrix(rank=1)
        upper_neighborhood_t = torch.from_numpy(
            upper_neighborhood_t.todense()
        ).to_sparse()
    except:
        upper_neighborhood_t = np.zeros(
            (lower_neighborhood_t.shape[0], lower_neighborhood_t.shape[0])
        )
        upper_neighborhood_t = torch.from_numpy(upper_neighborhood_t).to_sparse()

    upper_neighborhood_list.append(upper_neighborhood_t)

# Create the Neural Network

Using the CCXNLayer class, we create a neural network with stacked layers.

In [291]:
in_channels_0 = x_0_list[0].shape[-1]
in_channels_1 = x_1_list[0].shape[-1]
in_channels_2 = 5
print(
    f"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}."
)

The dimension of input features on nodes, edges and faces are: 7, 4 and 5.


In [292]:
class CAN(torch.nn.Module):
    def __init__(
        self,
        in_channels_0,
        in_channels_1,
        out_channels,
        num_classes,
        dropout=0.5,
        heads=3,
        concat=True,
        skip_connection=True,
        att_activation=torch.nn.LeakyReLU(0.2),
        n_layers=2,
        att_lift=True,
    ):
        super().__init__()

        if att_lift:
            self.lift_layer = MultiHeadLiftLayer(
                in_channels_0=in_channels_0,
                heads=in_channels_0,
                signal_lift_dropout=0.5,
            )
            in_channels_1 = in_channels_1 + in_channels_0

        layers = []

        layers.append(
            CANLayer(
                in_channels=in_channels_1,
                out_channels=out_channels,
                heads=heads,
                concat=concat,
                skip_connection=skip_connection,
                att_activation=att_activation,
                aggr_func="sum",
                update_func="relu",
            )
        )

        for _ in range(n_layers - 1):
            layers.append(
                CANLayer(
                    in_channels=out_channels * heads,
                    out_channels=out_channels,
                    dropout=dropout,
                    heads=heads,
                    concat=concat,
                    skip_connection=skip_connection,
                    att_activation=att_activation,
                    aggr_func="sum",
                    update_func="relu",
                )
            )

            layers.append(
                PoolLayer(
                    k_pool=0.5,
                    in_channels_0=out_channels * heads,
                    signal_pool_activation=torch.nn.Sigmoid(),
                    readout=True,
                )
            )

        self.layers = torch.nn.ModuleList(layers)
        self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
        self.lin_1 = torch.nn.Linear(128, num_classes)

    def forward(
        self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
    ):
        if hasattr(self, "lift_layer"):
            x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)

        for layer in self.layers:
            if isinstance(layer, PoolLayer):
                x_1, lower_neighborhood, upper_neighborhood = layer(
                    x_1, lower_neighborhood, upper_neighborhood
                )
            else:
                x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
                x_1 = F.dropout(x_1, p=0.5, training=self.training)

        # max pooling over all nodes in each graph
        x = x_1.max(dim=0)[0]

        # Feed-Foward Neural Network to predict the graph label
        out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))

        return out

# Train the Neural Network

We specify the model, initialize loss, and specify an optimizer. We first try it without any attention mechanism.

In [293]:
model = CAN(
    in_channels_0,
    in_channels_1,
    32,
    dropout=0.5,
    heads=2,
    num_classes=2,
    n_layers=2,
    att_lift=True,
)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
model

CAN(
  (lift_layer): MultiHeadLiftLayer(
    (lifts): LiftLayer()
  )
  (layers): ModuleList(
    (0): CANLayer(
      (lower_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=11, out_features=64, bias=False)
      )
      (upper_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=11, out_features=64, bias=False)
      )
      (lin): Linear(in_features=11, out_features=64, bias=False)
      (aggregation): Aggregation()
    )
    (1): CANLayer(
      (lower_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=64, out_features=64, bias=False)
      )
      (upper_att): MultiHeadCellAttention(
        (att_activation): LeakyReLU(negative_slope=0.2)
        (lin): Linear(in_features=64, out_features=64, bias=False)
      )
      (lin): Linear(in_features=64, out_features=64, bias=False)
 

We split the dataset into train and test sets.

In [294]:
test_size = 0.3
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)
lower_neighborhood_train, lower_neighborhood_test = train_test_split(
    lower_neighborhood_list, test_size=test_size, shuffle=False
)
upper_neighborhood_train, upper_neighborhood_test = train_test_split(
    upper_neighborhood_list, test_size=test_size, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)

In [295]:
test_interval = 1
num_epochs = 10
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
        x_0_train,
        x_1_train,
        adjacency_0_train,
        lower_neighborhood_train,
        upper_neighborhood_train,
        y_train,
    ):
        x_0 = x_0.float().to(device)
        x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)
        adjacency = adjacency.float().to(device)
        lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
            device
        ), upper_neighborhood.float().to(device)
        opt.zero_grad()
        y_hat = model(x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood)
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(
                x_0_test,
                x_1_test,
                adjacency_0_test,
                lower_neighborhood_test,
                upper_neighborhood_test,
                y_test,
            ):
                x_0 = x_0.float().to(device)
                x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(
                    device
                )
                adjacency = adjacency.float().to(device)
                lower_neighborhood, upper_neighborhood = lower_neighborhood.float().to(
                    device
                ), upper_neighborhood.float().to(device)
                y_hat = model(
                    x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood
                )
                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: 0.6238 Train_acc: 0.6870
Test_acc: 0.5965
Epoch: 2 loss: 0.6272 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 3 loss: 0.5953 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 4 loss: 0.5952 Train_acc: 0.6947
Test_acc: 0.5965
Epoch: 5 loss: 0.5804 Train_acc: 0.7023
Test_acc: 0.6316
Epoch: 6 loss: 0.5702 Train_acc: 0.7252
Test_acc: 0.6316
Epoch: 7 loss: 0.5534 Train_acc: 0.7176
Test_acc: 0.7193
Epoch: 8 loss: 0.5415 Train_acc: 0.7481
Test_acc: 0.7368
Epoch: 9 loss: 0.5556 Train_acc: 0.7328
Test_acc: 0.7544
Epoch: 10 loss: 0.5322 Train_acc: 0.7405
Test_acc: 0.7368
