# Using a new dataset

In this tutorial, we show how to use a dataset that is not included in the library.

This example uses the ENZYMES dataset, applies a simplicial lifting to turn its graphs into simplicial complexes, and trains the SAN model on them. We train the model on the training split, monitor it on the validation split, and finally evaluate it on the test split.

## Imports

In [1]:
import lightning as pl
import torch
from omegaconf import OmegaConf
from topomodelx.nn.simplicial.san import SAN
from torch_geometric.datasets import TUDataset

from topobenchmarkx.data.preprocess import PreProcessor
from topobenchmarkx.dataloader.dataloader import TBXDataloader
from topobenchmarkx.evaluator.evaluator import TBXEvaluator
from topobenchmarkx.loss.loss import TBXLoss
from topobenchmarkx.model.model import TBXModel
from topobenchmarkx.nn.encoders import AllCellFeatureEncoder
from topobenchmarkx.nn.readouts import PropagateSignalDown
from topobenchmarkx.nn.wrappers.simplicial import SANWrapper

## Configurations and utilities

Configurations can be specified in YAML files or, as in this example, directly in your code.

In [2]:
transform_config = {
    "clique_lifting": {
        "transform_type": "lifting",
        "transform_name": "SimplicialCliqueLifting",
        "complex_dim": 3,
    }
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "k-fold",
    "data_seed": 0,
    "data_split_dir": "./data/ENZYMES/splits/",
    "k": 10,
}

in_channels = 3
out_channels = 6
dim_hidden = 16

wrapper_config = {
    "out_channels": dim_hidden,
    "num_cell_dimensions": 3,
}

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {"task": "classification", "loss_type": "cross_entropy"}

evaluator_config = {
    "task": "classification",
    "num_classes": out_channels,
    "classification_metrics": ["accuracy", "precision", "recall"],
}

transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
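
The `split_config` above requests a 10-fold split seeded with `data_seed=0`. How TopoBenchmarkX builds its folds (and how it derives the test split from them) is not shown in this tutorial, so the following is only a hypothetical stdlib sketch of the general idea: shuffle the sample indices with a fixed seed, cut them into `k` folds, and hold one fold out. The helper name `kfold_indices` is illustrative, and a real implementation may additionally stratify by class label.

```python
import random


def kfold_indices(num_samples, k, fold, seed):
    """Hypothetical helper: return (train, held_out) index lists for one fold."""
    if not 0 <= fold < k:
        raise ValueError("fold must be in [0, k)")
    indices = list(range(num_samples))
    # A seeded generator makes the split reproducible across runs
    random.Random(seed).shuffle(indices)
    # Round-robin assignment gives folds whose sizes differ by at most one
    folds = [indices[i::k] for i in range(k)]
    held_out = sorted(folds[fold])
    train = sorted(i for j, f in enumerate(folds) if j != fold for i in f)
    return train, held_out


# ENZYMES contains 600 graphs; with k=10 each fold holds out 60 of them
train_idx, held_out_idx = kfold_indices(600, k=10, fold=0, seed=0)
print(len(train_idx), len(held_out_idx))  # 540 60
```

Because the seed fixes the shuffle, every run (and every fold index) sees the same partition, which is what makes results across folds comparable.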

In [3]:
def wrapper(**factory_kwargs):
    def factory(backbone):
        return SANWrapper(backbone, **factory_kwargs)
    return factory

def scheduler(**factory_kwargs):
    def factory(optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, **factory_kwargs)
    return factory
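
Both helpers capture keyword arguments now and return a one-argument callable that builds the object later, once the backbone or optimizer exists; presumably `TBXModel` calls these factories itself. The standard library's `functools.partial` expresses the same deferred construction. The sketch below compares the two, using a hypothetical `StepSchedule` class as a stand-in for `StepLR` so that it runs without PyTorch.

```python
from functools import partial


class StepSchedule:
    """Hypothetical stand-in for torch.optim.lr_scheduler.StepLR."""

    def __init__(self, optimizer, step_size, gamma):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma


def scheduler(**factory_kwargs):
    # Same closure pattern as in the cell above
    def factory(optimizer):
        return StepSchedule(optimizer, **factory_kwargs)
    return factory


closure_factory = scheduler(step_size=50, gamma=0.5)
partial_factory = partial(StepSchedule, step_size=50, gamma=0.5)

# Both callables take only the optimizer and forward the captured kwargs
a = closure_factory("my-optimizer")
b = partial_factory("my-optimizer")
print(vars(a) == vars(b))  # True
```

In the same way, `partial(SANWrapper, **wrapper_config)` yields a callable that takes only the backbone, equivalent to the `wrapper` helper; the closures simply make the deferred call explicit.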

## Loading the data

In this example, we use the ENZYMES dataset. Since it is a graph dataset, we apply the clique lifting to transform each graph into a simplicial complex. We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchmarkX) to learn more about the various liftings offered.
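
The clique lifting turns every clique of a graph into a simplex: nodes become 0-simplices, edges 1-simplices, triangles 2-simplices, and so on. The brute-force sketch below illustrates the idea with the standard library only. It is not TopoBenchmarkX's `SimplicialCliqueLifting`, the function name `clique_lifting` is hypothetical, and it assumes `complex_dim` is the largest simplex dimension to keep.

```python
from itertools import combinations


def clique_lifting(num_nodes, edges, complex_dim):
    """Return {dimension: list of simplices} for the clique complex of a graph.

    Each simplex is a sorted tuple of node indices; a k-simplex has k + 1 nodes
    that are all pairwise connected.
    """
    adjacency = {node: set() for node in range(num_nodes)}
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)

    nodes = range(num_nodes)
    simplices = {0: [(node,) for node in nodes]}
    for k in range(1, complex_dim + 1):
        # Brute force: keep every (k + 1)-subset whose nodes form a clique
        simplices[k] = [
            candidate
            for candidate in combinations(nodes, k + 1)
            if all(b in adjacency[a] for a, b in combinations(candidate, 2))
        ]
    return simplices


# A triangle 0-1-2 with an extra edge 2-3
complex_ = clique_lifting(4, [(0, 1), (1, 2), (0, 2), (2, 3)], complex_dim=3)
print(complex_[2])  # [(0, 1, 2)]
```

Enumerating all subsets is fine for small graphs such as those in ENZYMES, but it grows combinatorially, so practical implementations enumerate cliques directly instead.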

In [4]:
dataset_dir = "./data/ENZYMES/"
dataset = TUDataset(root=dataset_dir, name="ENZYMES")

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBXDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!
Processing...
Done!


## Model initialization

We create the backbone by instantiating the SAN model from TopoModelX. The `SANWrapper` and `TBXModel` then take care of the rest.
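
Among the pieces combined by `TBXModel`, the readout is configured with `task_level="graph"` and `pooling_type="sum"`. We assume this means that the final node-level representations of each graph are summed into a single vector, as `torch_geometric`'s `global_add_pool` does; the exact behavior of `PropagateSignalDown` is not shown here. The hypothetical `sum_pool` helper below sketches that pooling step in plain Python, using a batch vector that maps each node to its graph.

```python
def sum_pool(node_features, batch, num_graphs):
    """Sum the feature vectors of all nodes that belong to the same graph.

    node_features: list of equal-length feature lists, one per node
    batch: batch[i] is the index of the graph that node i belongs to
    """
    dim = len(node_features[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for features, graph in zip(node_features, batch):
        for j, value in enumerate(features):
            pooled[graph][j] += value
    return pooled


# Two graphs in one batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1
print(sum_pool([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 0, 1], num_graphs=2))
# [[4.0, 6.0], [5.0, 6.0]]
```

Unlike mean pooling, summation keeps information about graph size, since a larger graph contributes more terms to its pooled vector.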

In [5]:
backbone = SAN(in_channels=dim_hidden, hidden_channels=dim_hidden)
# Store the factory output under a new name so that re-running this cell
# does not overwrite the `wrapper` and `scheduler` helpers defined above
backbone_wrapper = wrapper(**wrapper_config)

readout = PropagateSignalDown(**readout_config)
loss = TBXLoss(**loss_config)
# Encode the raw input features into dim_hidden-dimensional embeddings
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels], out_channels=dim_hidden)

evaluator = TBXEvaluator(**evaluator_config)
optimizer = torch.optim.Adam
lr_scheduler = scheduler(step_size=50, gamma=0.5)

In [6]:
model = TBXModel(backbone=backbone,
                 backbone_wrapper=backbone_wrapper,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 scheduler=lr_scheduler,
                 compile=False)

## Training

Now we can use the `lightning` trainer to train the model.
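
The scheduler built earlier wraps `torch.optim.lr_scheduler.StepLR` with `step_size=50` and `gamma=0.5`. `StepLR` multiplies the learning rate by `gamma` every `step_size` scheduler steps, so after `n` steps the rate is `base_lr * gamma ** (n // step_size)`. Assuming the scheduler is stepped once per epoch (Lightning's default interval; whether `TBXModel` changes it is not shown here) and that Adam keeps PyTorch's default `lr=1e-3`, the sketch below shows that the short 5-epoch run below never reaches the first decay.

```python
def step_lr(base_lr, step, step_size, gamma):
    """Learning rate that StepLR yields after `step` scheduler steps."""
    return base_lr * gamma ** (step // step_size)


base_lr = 1e-3  # assumption: PyTorch's default Adam learning rate
for epoch in (0, 4, 49, 50, 100):
    print(epoch, step_lr(base_lr, epoch, step_size=50, gamma=0.5))
```

Epochs 0 to 49 all use the base rate, epoch 50 halves it, and epoch 100 quarters it, so the schedule only matters once `max_epochs` exceeds `step_size`.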

In [7]:
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=5, accelerator="cpu", enable_progress_bar=False)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: attribute 'backbone_wrapper' re

In [8]:
for key in train_metrics:
    print(key,":    ", train_metrics[key].item())

train/accuracy :     0.18703703582286835
train/precision :     0.204917311668396
train/recall :     0.18662351369857788
val/loss :     2.2101078033447266
val/accuracy :     0.15000000596046448
val/precision :     0.1944444477558136
val/recall :     0.11645299196243286
train/loss :     2.1050922870635986
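
To put these numbers in context: ENZYMES contains 600 graphs split evenly across 6 classes (100 per class), so always predicting a single class yields an accuracy of about 1/6 ≈ 0.167 on a balanced split. The snippet below compares the accuracies printed above with that baseline; after only 5 epochs the model is still close to chance, which is why the training cell suggests increasing the number of epochs.

```python
from fractions import Fraction

num_graphs, num_classes = 600, 6
per_class = num_graphs // num_classes  # 100 graphs per class
chance = Fraction(per_class, num_graphs)  # single-class accuracy on balanced data

# Accuracies printed by the cell above
logged = {"train/accuracy": 0.18703703582286835, "val/accuracy": 0.15000000596046448}
for key, value in logged.items():
    diff = value - float(chance)
    print(f"{key}: {value:.3f} (chance {float(chance):.3f}, difference {diff:+.3f})")
```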


## Testing the model

Finally, we can test the model and obtain the results.
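
In Lightning, `trainer.test` returns a list with one dictionary of logged metrics per test dataloader, so the results can also be read from its return value instead of `trainer.callback_metrics`. A small helper such as the hypothetical `format_metrics` below prints them in a stable, sorted order.

```python
def format_metrics(results, digits=4):
    """Hypothetical helper: format the list of metric dicts returned by trainer.test."""
    lines = []
    for index, metrics in enumerate(results):
        # Sort keys so the output order does not depend on logging order
        for key in sorted(metrics):
            lines.append(f"[dataloader {index}] {key}: {metrics[key]:.{digits}f}")
    return lines


# Usage (requires the objects defined in this tutorial):
# for line in format_metrics(trainer.test(model, datamodule)):
#     print(line)
```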

In [9]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics

/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
