# Implementing your own lifting

In this tutorial we show how you can implement your own lifting and test it on a dataset. 

This particular example uses the MUTAG dataset. The lifting is similar to `SimplicialCliqueLifting`, but it discards every clique that has more nodes than the largest simplex we want to consider.

We test this lifting using the SCN2 model from `TopoModelX`.

## Imports

In [3]:
from itertools import combinations
from typing import Any

import lightning as pl
import networkx as nx
import torch
import torch_geometric
from omegaconf import OmegaConf
from topomodelx.nn.simplicial.scn2 import SCN2
from toponetx.classes import SimplicialComplex

from topobenchmarkx.data.load import GraphLoader
from topobenchmarkx.data.preprocess import PreProcessor
from topobenchmarkx.dataloader import TBXDataloader
from topobenchmarkx.evaluator import TBXEvaluator
from topobenchmarkx.loss import TBXLoss
from topobenchmarkx.model import TBXModel
from topobenchmarkx.nn.encoders import AllCellFeatureEncoder
from topobenchmarkx.nn.readouts import PropagateSignalDown
from topobenchmarkx.nn.wrappers.simplicial import SCNWrapper
from topobenchmarkx.transforms.liftings.graph2simplicial import (
    Graph2SimplicialLifting,
)

## Configurations and utilities

Configurations can be specified in YAML files or directly in code, as in this example. To keep the notebook tidy, the configuration for the lifting is defined here together with the others, even though the lifting class itself is defined later in the notebook.

In [4]:
loader_config = {
    "data_domain": "graph",
    "data_type": "TUDataset",
    "data_name": "MUTAG",
    "data_dir": "./data/MUTAG/",
}

transform_config = {
    "clique_lifting": {
        "_target_": "__main__.SimplicialCliquesLEQLifting",
        "transform_name": "SimplicialCliquesLEQLifting",
        "transform_type": "lifting",
        "complex_dim": 3,
    }
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "k-fold",
    "data_seed": 0,
    "data_split_dir": "./data/MUTAG/splits/",
    "k": 10,
}

in_channels = 7
out_channels = 2
dim_hidden = 128

wrapper_config = {
    "out_channels": dim_hidden,
    "num_cell_dimensions": 3,
}

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {"task": "classification", "loss_type": "cross_entropy"}

evaluator_config = {
    "task": "classification",
    "num_classes": out_channels,
    "classification_metrics": ["accuracy", "precision", "recall"],
}

loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
wrapper_config = OmegaConf.create(wrapper_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)

In [5]:
def wrapper(**factory_kwargs):
    def factory(backbone):
        return SCNWrapper(backbone, **factory_kwargs)
    return factory

def scheduler(**factory_kwargs):
    def factory(optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, **factory_kwargs)
    return factory
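
The two helpers above return factories rather than finished objects: the keyword arguments are fixed now, while the missing argument (the backbone or the optimizer) is supplied later, presumably once that object exists. As a minimal sketch of the same pattern, with a hypothetical `Scaler` class standing in for `SCNWrapper` or `StepLR`, `functools.partial` gives an equivalent result:

```python
from functools import partial


class Scaler:
    """Hypothetical stand-in for a class such as SCNWrapper or StepLR."""

    def __init__(self, target, factor=1.0):
        self.target = target
        self.factor = factor


def make_factory(**factory_kwargs):
    """Fix the keyword arguments now and take the positional argument later."""

    def factory(target):
        return Scaler(target, **factory_kwargs)

    return factory


# Closure-based factory, following the shape of the tutorial's helpers.
closure_factory = make_factory(factor=0.5)
# Equivalent factory built with functools.partial.
partial_factory = partial(Scaler, factor=0.5)

a = closure_factory("optimizer placeholder")
b = partial_factory("optimizer placeholder")
print(a.factor, b.factor, a.target == b.target)
```

Either form defers construction until the positional argument is available; the closure version is slightly more verbose but leaves room for extra logic inside `factory`.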

## Defining the lifting

Here we define the lifting we intend to use. `SimplicialCliquesLEQLifting` finds the maximal cliques of the graph, keeps those with at most `complex_dim + 1` nodes (the size of the largest simplex we want to consider), and creates simplices from them. Its configuration was already defined above with the other configurations.

In [6]:
class SimplicialCliquesLEQLifting(Graph2SimplicialLifting):
    r"""Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. Only the maximal cliques with at most ``complex_dim + 1`` nodes are considered.
    
    Args:
        kwargs (optional): Additional arguments for the class.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. Only the maximal cliques with at most ``complex_dim + 1`` nodes are considered.

        Args:
            data (torch_geometric.data.Data): The input data to be lifted.
        Returns:
            dict: The lifted topology.
        """
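        graph = self._generate_graph_from_data(data)
        # Initialise the complex from the graph itself.
        simplicial_complex = SimplicialComplex(graph)
        # nx.find_cliques yields the maximal cliques of the graph.
        cliques = nx.find_cliques(graph)

        # simplices[i - 2] collects the i-simplices for i = 2, ..., complex_dim.
        simplices: list[set[tuple[Any, ...]]] = [set() for _ in range(2, self.complex_dim + 1)]
        for clique in cliques:
            # Skip maximal cliques with more nodes than a complex_dim-simplex.
            if len(clique) <= self.complex_dim + 1:
                for i in range(2, self.complex_dim + 1):
                    # Every (i + 1)-node subset of a clique spans an i-simplex.
                    for c in combinations(clique, i + 1):
                        simplices[i - 2].add(tuple(c))

        for set_k_simplices in simplices:
            simplicial_complex.add_simplices_from(list(set_k_simplices))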

        return self._get_lifted_topology(simplicial_complex, graph)
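
To see the filtering in isolation, the sketch below applies the same size check and subset enumeration to a hand-written list of maximal cliques, using only the standard library. The function name `higher_simplices` and the toy cliques are hypothetical; nodes are sorted inside each clique only to make the output deterministic.

```python
from itertools import combinations


def higher_simplices(maximal_cliques, complex_dim):
    """Return {dimension: set of simplices} for dimensions 2..complex_dim.

    Mirrors the filtering in ``lift_topology``: a maximal clique contributes
    only if it has at most ``complex_dim + 1`` nodes.
    """
    simplices = {dim: set() for dim in range(2, complex_dim + 1)}
    for clique in maximal_cliques:
        if len(clique) > complex_dim + 1:
            continue  # the whole clique is skipped, including its sub-cliques
        for dim in range(2, complex_dim + 1):
            for nodes in combinations(sorted(clique), dim + 1):
                simplices[dim].add(nodes)
    return simplices


# Hand-written maximal cliques of a toy graph (hypothetical data):
# a 5-clique, a triangle, and a single edge.
toy_cliques = [[0, 1, 2, 3, 4], [4, 5, 6], [6, 7]]

result = higher_simplices(toy_cliques, complex_dim=3)
print(result)
```

With `complex_dim=3` the 5-node clique is skipped entirely, so none of its triangles or tetrahedra are added, whereas a lifting that does not filter by clique size would keep them. Raising `complex_dim` to 4 lets that clique through and adds all of its faces.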


## Loading the data

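In this example we use the MUTAG dataset. We first register the new lifting in the `TRANSFORMS` dictionary under the name used as `transform_name` in the configuration. We then load the dataset, apply the lifting through the `PreProcessor`, and build the data splits and the dataloader.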

In [7]:
from topobenchmarkx.transforms import TRANSFORMS

TRANSFORMS["SimplicialCliquesLEQLifting"] = SimplicialCliquesLEQLifting
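
Registering the class under a string key lets a configuration refer to it by name. The following is a minimal sketch of that registry pattern, not the actual `topobenchmarkx` implementation; `REGISTRY`, `DummyLifting`, and `build_transform` are hypothetical names used only for illustration.

```python
REGISTRY = {}  # hypothetical stand-in for a name -> class mapping such as TRANSFORMS


class DummyLifting:
    """Hypothetical transform that just stores its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def build_transform(config):
    """Look up the class named by ``transform_name`` and instantiate it."""
    name = config["transform_name"]
    if name not in REGISTRY:
        raise KeyError(f"Transform {name!r} is not registered")
    params = {k: v for k, v in config.items() if k != "transform_name"}
    return REGISTRY[name](**params)


# Register the class under the name the configuration will use.
REGISTRY["DummyLifting"] = DummyLifting

lifting = build_transform({"transform_name": "DummyLifting", "complex_dim": 3})
print(type(lifting).__name__, lifting.kwargs)
```

Under this pattern, a class that is never registered fails with a `KeyError` at lookup time, which is why the registration has to happen before the configuration is consumed.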

In [8]:
graph_loader = GraphLoader(loader_config)

dataset, dataset_dir = graph_loader.load()

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBXDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

Processing...
Done!


## Model initialization

We can create the backbone by instantiating the SCN2 model from TopoModelX. Then the `SCNWrapper` and the `TBXModel` take care of the rest.

In [9]:
backbone = SCN2(in_channels_0=dim_hidden,in_channels_1=dim_hidden,in_channels_2=dim_hidden)
backbone_wrapper = wrapper(**wrapper_config)

readout = PropagateSignalDown(**readout_config)
loss = TBXLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)

evaluator = TBXEvaluator(**evaluator_config)
optimizer = torch.optim.Adam
scheduler = scheduler(step_size=50, gamma=0.5)

In [10]:
model = TBXModel(backbone=backbone,
                 backbone_wrapper=backbone_wrapper,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 compile=False,)
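
With `step_size=50` and `gamma=0.5`, `StepLR` halves the learning rate every 50 epochs. The sketch below reproduces that schedule in plain Python using the closed form `lr = base_lr * gamma ** (epoch // step_size)`. The base learning rate of `1e-3` is an assumption (it is the default of `torch.optim.Adam`), since the value actually used by `TBXModel` is not shown here.

```python
def step_lr(base_lr, epoch, step_size=50, gamma=0.5):
    """Learning rate at a given epoch under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


base_lr = 1e-3  # assumption: default learning rate of torch.optim.Adam

for epoch in (0, 49, 50, 100, 150, 199):
    print(f"epoch {epoch:3d}: lr = {step_lr(base_lr, epoch):.6f}")
```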

## Training

Now we can use the `lightning` trainer to train the model. We are prompted to connect a Wandb account to monitor training, but we can also obtain the final training metrics directly from the trainer.

In [12]:
trainer = pl.Trainer(max_epochs=200, accelerator="cpu", enable_progress_bar=False)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name            | Type                  | Params
----------------------------------------------------------
0 | feature_encoder | AllCellFeatureEncoder | 53.8 K
1 | backbone        | SCNWrapper            | 99.1 K
2 | readout         | PropagateSignalDown   | 258   
3 | val_acc_best    | MeanMetric            | 0     
----------------------------------------------------------
153 K     Trainable params
0         Non-trainable params
153 K     Total params
0.612     Total estimated model params size (MB)

In [13]:
for key in train_metrics:
    print(key,":    ", train_metrics[key].item())

train/accuracy :     0.8402366638183594
train/precision :     0.82198166847229
train/recall :     0.8191572427749634
val/loss :     0.6285274624824524
val/accuracy :     0.7368420958518982
val/precision :     0.7083333730697632
val/recall :     0.6282051205635071
train/loss :     0.3544183671474457


## Testing the model

Finally, we can test the model and obtain the results.

In [14]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics




/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
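
The test cell stores its metrics in `test_metrics` without printing them. A minimal sketch for displaying a single split follows, assuming the keys are prefixed with the split name as in the training output (`train/accuracy`, `val/loss`, and so on). `select_split` is a hypothetical helper, and the values in the demo dictionary are rounded copies of metrics printed after training.

```python
def select_split(metrics, split):
    """Return {metric_name: float} for keys of the form '<split>/<metric_name>'."""
    prefix = f"{split}/"
    selected = {}
    for key, value in metrics.items():
        if key.startswith(prefix):
            # Tensors expose .item(); plain numbers are passed through float().
            number = value.item() if hasattr(value, "item") else value
            selected[key[len(prefix):]] = float(number)
    return selected


# Demo with plain floats (rounded values from the training output).
demo_metrics = {"train/accuracy": 0.8402, "val/accuracy": 0.7368, "val/loss": 0.6285}
val_only = select_split(demo_metrics, "val")
print(val_only)
```

In the notebook, `select_split(test_metrics, "test")` would be the corresponding call, provided the test keys follow the same naming.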
