# Implementing your own model

In this tutorial we show how you can implement your own model and test it on a dataset. 

This example uses the MUTAG dataset, applies a hypergraph lifting to turn its graphs into hypergraphs, and defines a model that operates on them. We train the model on the training split, monitor it on the validation split, and finally evaluate it on the test split.

## Imports

In [1]:
import lightning as pl
import torch
from omegaconf import OmegaConf

from topobenchmarkx.data.load import GraphLoader
from topobenchmarkx.data.preprocess import PreProcessor
from topobenchmarkx.dataloader import TBXDataloader
from topobenchmarkx.evaluator import TBXEvaluator
from topobenchmarkx.loss import TBXLoss
from topobenchmarkx.model import TBXModel
from topobenchmarkx.nn.encoders import AllCellFeatureEncoder
from topobenchmarkx.nn.readouts import PropagateSignalDown

## Configurations and utilities

Configurations can be written in YAML files or, as in this example, defined directly in code as dictionaries and converted with `OmegaConf.create`.

In [2]:
loader_config = {
    "data_domain": "graph",
    "data_type": "TUDataset",
    "data_name": "MUTAG",
    "data_dir": "./data/MUTAG/",
}

transform_config = {
    "khop_lifting": {
        "transform_type": "lifting",
        "transform_name": "HypergraphKHopLifting",
        "k_value": 1,
    }
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "k-fold",
    "data_seed": 0,
    "data_split_dir": "./data/MUTAG/splits/",
    "k": 10,
}

in_channels = [7]
out_channels = 2
dim_hidden = 16

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {"task": "classification", "loss_type": "cross_entropy"}

evaluator_config = {"task": "classification",
                    "num_classes": out_channels,
                    "classification_metrics": ["accuracy", "precision", "recall"]}

loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)

In [3]:
def scheduler(**factory_kwargs):
    def factory(optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, **factory_kwargs)
    return factory
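
The helper above returns a factory rather than a scheduler, because a scheduler can only be built once an optimizer exists. The sketch below calls the factory by hand on a throwaway optimizer to show the resulting learning-rate schedule; how `TBXModel` invokes the factory internally is an assumption here, and the single dummy parameter and the learning rate of 0.01 are illustrative only.

```python
import torch


def scheduler(**factory_kwargs):
    # Same helper as in the tutorial: capture the StepLR arguments now,
    # build the scheduler later when an optimizer is available.
    def factory(optimizer):
        return torch.optim.lr_scheduler.StepLR(optimizer, **factory_kwargs)
    return factory


# Throwaway parameter and optimizer, only used to drive the scheduler.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.01)

make_scheduler = scheduler(step_size=50, gamma=0.5)
lr_scheduler = make_scheduler(optimizer)

lrs = []
for epoch in range(100):
    optimizer.step()      # no gradients are set, so the parameter is unchanged
    lr_scheduler.step()   # advance the schedule by one epoch
    lrs.append(optimizer.param_groups[0]["lr"])

# The learning rate is halved after every 50 scheduler steps.
print(lrs[48], lrs[49], lrs[99])
```

With `step_size=50` and `gamma=0.5`, a 100-epoch run therefore spends its first 50 epochs at the initial learning rate and the remainder at half of it.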

## Loading the data

In this example we use the MUTAG dataset. Because MUTAG is a graph dataset, we apply the k-hop lifting to transform each graph into a hypergraph. See the README of the [repository](https://github.com/pyt-team/TopoBenchmarkX) to learn more about the available liftings.
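
To give an intuition for what a k-hop lifting produces, here is a small pure-Python sketch that builds one hyperedge per node from that node's k-hop neighbourhood. It is not the library's implementation: the function names are hypothetical, and including the node itself in its own hyperedge and treating edges as undirected are assumptions.

```python
from collections import deque


def khop_hyperedges(num_nodes, edges, k=1):
    """Return one hyperedge per node: the node plus every node within k hops."""
    neighbors = {v: set() for v in range(num_nodes)}
    for u, v in edges:  # treat edges as undirected
        neighbors[u].add(v)
        neighbors[v].add(u)

    hyperedges = []
    for source in range(num_nodes):
        seen = {source}
        frontier = deque([(source, 0)])
        while frontier:  # breadth-first search, cut off at depth k
            node, dist = frontier.popleft()
            if dist == k:
                continue
            for nxt in neighbors[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    frontier.append((nxt, dist + 1))
        hyperedges.append(sorted(seen))
    return hyperedges


def incidence_matrix(num_nodes, hyperedges):
    """Dense node-by-hyperedge incidence matrix as nested lists of 0/1."""
    matrix = [[0] * len(hyperedges) for _ in range(num_nodes)]
    for col, members in enumerate(hyperedges):
        for row in members:
            matrix[row][col] = 1
    return matrix


# Path graph 0 - 1 - 2 - 3 lifted with k = 1.
edges = [(0, 1), (1, 2), (2, 3)]
hyperedges = khop_hyperedges(4, edges, k=1)
incidence = incidence_matrix(4, hyperedges)
print(hyperedges)  # [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3]]
```

Under these assumptions the lifted hypergraph has exactly as many hyperedges as nodes, so its incidence matrix is square.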

In [4]:
graph_loader = GraphLoader(loader_config)

dataset, dataset_dir = graph_loader.load()

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBXDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Processing...
Done!
Processing...
Done!


## Backbone definition

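To implement a new model we only need to define its `forward` method. It receives a batch and returns a dictionary; in this example the dictionary carries the labels, `batch_0`, and the updated node (`x_0`) and hyperedge (`hyperedge`) features.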

In [5]:
class myModel(pl.LightningModule):
    def __init__(self, dim_hidden):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        x_0 = batch.x_0
        incidence_hyperedges = batch.incidence_hyperedges
        x_1 = torch.sparse.mm(incidence_hyperedges, x_0)
        
        x_0 = self.linear_0(x_0)
        x_1 = self.linear_1(x_1)
        
        model_out = {"labels": batch.y, "batch_0": batch.batch_0}
        model_out["x_0"] = x_0
        model_out["hyperedge"] = x_1
        return model_out
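
Before wiring a backbone into the full pipeline, it can help to run its `forward` on a hand-made batch and inspect the output shapes. The sketch below does this with a `SimpleNamespace` standing in for a real batch and a plain `torch.nn.Module` mirroring the backbone above; the class name `TinyBackbone`, the toy incidence matrix, and the label are all illustrative. Under the assumption that rows of `incidence_hyperedges` index nodes and columns index hyperedges, the product in `forward` only has matching shapes when there are as many hyperedges as nodes, which holds in this toy case; a lifting that produces a non-square incidence matrix would likely need a transpose.

```python
from types import SimpleNamespace

import torch


class TinyBackbone(torch.nn.Module):
    """Stand-in with the same forward logic as the tutorial's backbone."""

    def __init__(self, dim_hidden):
        super().__init__()
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        # Aggregate node features through the (sparse) incidence matrix.
        x_1 = torch.sparse.mm(batch.incidence_hyperedges, batch.x_0)
        return {
            "labels": batch.y,
            "batch_0": batch.batch_0,
            "x_0": self.linear_0(batch.x_0),
            "hyperedge": self.linear_1(x_1),
        }


dim_hidden = 16
# Toy 4-node, 4-hyperedge incidence matrix (illustrative values).
incidence = torch.tensor(
    [[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]],
    dtype=torch.float32,
).to_sparse()

fake_batch = SimpleNamespace(
    x_0=torch.randn(4, dim_hidden),           # already-encoded node features
    incidence_hyperedges=incidence,
    y=torch.tensor([1]),                      # one graph-level label
    batch_0=torch.zeros(4, dtype=torch.long), # all nodes belong to graph 0
)

out = TinyBackbone(dim_hidden)(fake_batch)
shapes = {name: tuple(out[name].shape) for name in ("x_0", "hyperedge")}
print(shapes)
```

A check like this catches shape mismatches and missing dictionary keys early, before they surface inside the readout or the loss.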

## Model initialization

Now that the backbone is defined we can create the `TBXModel`, which takes care of everything else needed to train it. First we instantiate the components that specify the model's behaviour: the readout, the loss, the feature encoder, the evaluator, the optimizer, and the scheduler.

In [6]:
backbone = myModel(dim_hidden)

readout = PropagateSignalDown(**readout_config)
loss = TBXLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=in_channels, out_channels=dim_hidden)

evaluator = TBXEvaluator(**evaluator_config)
optimizer = torch.optim.Adam
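# Note: this rebinds `scheduler` from the helper function to the factory it returns,
# so re-running this cell without re-running the helper's cell raises a TypeError.
scheduler = scheduler(step_size=50, gamma=0.5)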

In [7]:
model = TBXModel(backbone=backbone,
                 backbone_wrapper=None,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 compile=False,)

## Training

Now we can use the `lightning` trainer to train the model. We are prompted to connect a Wandb account to monitor training, but we can also read the final training metrics directly from the trainer via `trainer.callback_metrics`.

In [8]:
trainer = pl.Trainer(max_epochs=100, accelerator="cpu", enable_progress_bar=False)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: attribute 'scheduler' removed f

In [9]:
for key in train_metrics:
    print(key,":    ", train_metrics[key].item())

train/accuracy :     0.668639063835144
train/precision :     0.667670726776123
train/recall :     0.5130795836448669
val/loss :     0.5384264588356018
val/accuracy :     0.6842105388641357
val/precision :     0.34210526943206787
val/recall :     0.5
train/loss :     0.56179279088974
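
The validation numbers above have a recognisable pattern: precision is exactly half the accuracy and recall is exactly 0.5. One reading consistent with this, assuming the evaluator macro-averages precision and recall over the two classes, is that the model predicts the same class for every validation graph. The sketch below reproduces the three values with a hypothetical fold of 19 graphs, 13 of which belong to the predicted class; the fold composition is an illustration, not something read from the dataset.

```python
def macro_scores(y_true, y_pred, classes=(0, 1)):
    """Accuracy plus macro-averaged precision and recall.

    A class with an empty denominator contributes 0 to the average.
    """
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precisions, recalls = [], []
    for c in classes:
        tp = sum(t == c and p == c for t, p in pairs)
        predicted = sum(p == c for p in y_pred)
        actual = sum(t == c for t in y_true)
        precisions.append(tp / predicted if predicted else 0.0)
        recalls.append(tp / actual if actual else 0.0)
    return accuracy, sum(precisions) / len(classes), sum(recalls) / len(classes)


# Hypothetical fold: 13 graphs of class 1, 6 of class 0, always predicting class 1.
y_true = [1] * 13 + [0] * 6
y_pred = [1] * 19
accuracy, precision, recall = macro_scores(y_true, y_pred)
print(accuracy, precision, recall)
```

If this reading is right, the reported validation accuracy mostly reflects the class balance of the fold, so it is worth inspecting the predictions before drawing conclusions about the model.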


## Testing the model

Finally, we can test the model and obtain the results.

In [10]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics




/opt/miniconda3/envs/topox/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
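
The cell above stores the test metrics but does not display them. A small helper along the following lines could be used to print them, in the same spirit as the loop used for the training metrics. The helper name is hypothetical, the `test/` key prefix is an assumption by analogy with the `train/` and `val/` keys shown earlier, and the example dictionary is a stand-in with made-up values rather than real results.

```python
def metrics_to_floats(metrics, prefix=""):
    """Convert scalar metrics to plain floats, keeping keys that start with prefix."""
    result = {}
    for name, value in metrics.items():
        if not name.startswith(prefix):
            continue
        # Tensor-like values expose .item(); plain numbers go through float().
        result[name] = float(value.item()) if hasattr(value, "item") else float(value)
    return result


# Stand-in for trainer.callback_metrics; keys and values are illustrative only.
example_metrics = {"train/loss": 0.5, "test/accuracy": 0.75, "test/loss": 0.25}

test_only = metrics_to_floats(example_metrics, prefix="test/")
for name, value in sorted(test_only.items()):
    print(f"{name}: {value:.4f}")
```

In the notebook, the same call would take `test_metrics` in place of `example_metrics`.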
