# Implementing your own model

In this tutorial we show how to implement your own model and test it on a dataset. 

This particular example uses the MUTAG dataset, uses an hypergraph lifting to create hypergraphs, and defines a model to work on them. 

We train the model using the appropriate training and validation datasets, and finally test it on the test dataset.

### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>
&emsp;[1. Imports](##sec1)

&emsp;[2. Configurations and utilities](##sec2)

&emsp;[3. Loading the data](##sec3)

&emsp;[4. Backbone definition](##sec4)

&emsp;[5. Model initialization](##sec5)

&emsp;[6. Training](##sec6)

&emsp;[7. Testing the model](##sec7)

## 1. Imports <a class="anchor" id="sec1"></a>

In [3]:
import torch
import lightning as pl
# Hydra related imports
from omegaconf import OmegaConf
# Data related imports
from topobench.data.loaders.graph import TUDatasetLoader
from topobench.dataloader.dataloader import TBDataloader
from topobench.data.preprocessor import PreProcessor
# Model related imports
from topobench.model.model import TBModel
from topomodelx.nn.simplicial.scn2 import SCN2
from topobench.nn.wrappers.simplicial import SCNWrapper
from topobench.nn.encoders import AllCellFeatureEncoder
from topobench.nn.readouts import PropagateSignalDown
# Optimization related imports
from topobench.loss.loss import TBLoss
from topobench.optimizer import TBOptimizer
from topobench.evaluator.evaluator import TBEvaluator

## 2. Configurations and utilities <a class="anchor" id="sec2"></a>

Configurations can be specified using yaml files or directly specified in your code like in this example.

In [4]:
loader_config = {
    "data_domain": "graph",
    "data_type": "TUDataset",
    "data_name": "MUTAG",
    "data_dir": "./data/MUTAG/"}

transform_config = { "khop_lifting":
    {"transform_type": "lifting",
    "transform_name": "HypergraphKHopLifting",
    "k_value": 1,}
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "random",
    "data_seed": 0,
    "data_split_dir": "./data/MUTAG/splits/",
    "train_prop": 0.5,
}

in_channels = 7
out_channels = 2
dim_hidden = 16

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {
    "dataset_loss": 
        {
            "task": "classification", 
            "loss_type": "cross_entropy"
        }
}

evaluator_config = {"task": "classification",
                    "num_classes": out_channels,
                    "metrics": ["accuracy", "precision", "recall"]}

optimizer_config = {"optimizer_id": "Adam",
                    "parameters":
                        {"lr": 0.001,"weight_decay": 0.0005}
                    }

loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)

## 3. Loading the data <a class="anchor" id="sec3"></a>

In this example we use the MUTAG dataset. It is a graph dataset and we use the k-hop lifting to transform the graphs into hypergraphs. 

We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered.

In [5]:
graph_loader = TUDatasetLoader(loader_config)

dataset, dataset_dir = graph_loader.load()

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

Transform parameters are the same, using existing data_dir: data\MUTAG\MUTAG\khop_lifting\1116229528


## 4. Backbone definition <a class="anchor" id="sec4"></a>

To implement a new model we only need to define the forward method.

With a hypergraph with $n$ nodes and $m$ hyperedges this model simply calculates the hyperedge features as $X_1 = B_1 \cdot X_0$ where $B_1 \in \mathbb{R}^{n \times m}$ is the incidence matrix, where $B_{ij}=1$ if node $i$ belongs to hyperedge $j$ and is 0 otherwise.

Then the outputs are computed as $X^{'}_0=\text{ReLU}(W_0 \cdot X_0 + B_0)$ and $X^{'}_1=\text{ReLU}(W_1 \cdot X_1 + B_1)$, by simply using two linear layers with ReLU activation.

In [6]:
class myModel(pl.LightningModule):
    def __init__(self, dim_hidden):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        x_0 = batch.x_0
        incidence_hyperedges = batch.incidence_hyperedges
        x_1 = torch.sparse.mm(incidence_hyperedges, x_0)
        
        x_0 = self.linear_0(x_0)
        x_0 = torch.relu(x_0)
        x_1 = self.linear_1(x_1)
        x_1 = torch.relu(x_1)
        
        model_out = {"labels": batch.y, "batch_0": batch.batch_0}
        model_out["x_0"] = x_0
        model_out["hyperedge"] = x_1
        return model_out

## 5. Model initialization <a class="anchor" id="sec5"></a>

Now that the model is defined we can create the TBModel, which takes care of implementing everything else that is needed to train the model. 

First we need to implement a few classes to specify the behaviour of the model.

In [7]:
backbone = myModel(dim_hidden)

readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)

evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

Now we can instantiate the TBModel.

In [23]:
model = TBModel(backbone=backbone,
                 backbone_wrapper=None,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False)

## 6. Training <a class="anchor" id="sec6"></a>

Now we can use the `lightning` trainer to train the model.

In [9]:
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=50, accelerator="cpu", enable_progress_bar=False, log_every_n_steps=1)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | feature_encoder | AllCellFeatureEncoder | 448    | train
1 | backbone        | myModel               | 544    | train
2 | readout         | PropagateSignalDown   | 34     | train
3 | val_acc_best    | MeanMetric            | 0      | train
------------------------------------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode
c:\Users\giova\anaconda3\envs\tb\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_w

In [10]:
print('      Training metrics\n', '-'*26)
for key in train_metrics:
    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))

      Training metrics
 --------------------------
train/accuracy:       0.7660
train/precision:      0.7476
train/recall:         0.6943
val/loss:             0.5238
val/accuracy:         0.7447
val/precision:        0.7321
val/recall:           0.6354
train/loss:           0.4471


## 7. Testing the model <a class="anchor" id="sec7"></a>

Finally, we can test the model and obtain the results.

In [11]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics




c:\Users\giova\anaconda3\envs\tb\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


In [12]:
print('      Testing metrics\n', '-'*25)
for key in test_metrics:
    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))

      Testing metrics
 -------------------------
test/loss:           0.4870
test/accuracy:       0.7234
test/precision:      0.7340
test/recall:         0.6431


In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch_geometric.nn as pygnn
from torch_geometric.data import Batch
from torch_geometric.nn import Linear as Linear_pyg
from torch_geometric.utils import to_dense_batch
from topobench.nn.backbones.graph.gatedgcn_layer import GatedGCNLayer
from topobench.nn.backbones.graph.gine_conv_layer import GINEConvESLapPE

class GraphGPSModel(pl.LightningModule):
    def __init__(
        self,
        dim_h,
        local_gnn_type='GINE',
        global_model_type='Transformer',
        num_heads=8,
        act='relu',
        dropout=0.0,
        attn_dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        equivstable_pe=False
    ):
        super().__init__()
        self.dim_h = dim_h
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.equivstable_pe = equivstable_pe
        self.activation = nn.ReLU if act == 'relu' else nn.SiLU
        # --- Local GNN (GINE or CustomGatedGCN) ---
        if local_gnn_type == 'GINE':
            gin_nn = nn.Sequential(
                Linear_pyg(dim_h, dim_h),
                self.activation(),
                Linear_pyg(dim_h, dim_h),
            )
            if equivstable_pe:
                self.local_model = GINEConvESLapPE(gin_nn)
            else:
                self.local_model = pygnn.GINEConv(gin_nn)
        elif local_gnn_type == 'CustomGatedGCN':
            self.local_model = GatedGCNLayer(
                dim_h,
                dim_h,
                dropout=dropout,
                residual=True,
                act=act,
                equivstable_pe=equivstable_pe
            )
        else:
            raise ValueError(
                f"Unsupported local GNN: {local_gnn_type}. "
                "Choose 'GINE' or 'CustomGatedGCN'."
            )
        self.local_gnn_type = local_gnn_type
        self.dropout_local = nn.Dropout(dropout)

        if global_model_type == 'Transformer':
            self.self_attn = nn.MultiheadAttention(
                dim_h,
                num_heads,
                dropout=attn_dropout,
                batch_first=True
            )
        else:
            raise ValueError(
                f"Unsupported global model: {global_model_type}. "
                "Choose 'Transformer'."
            )
        self.dropout_attn = nn.Dropout(dropout)

        if layer_norm and batch_norm:
            raise ValueError("Cannot use both layer_norm and batch_norm")
        Norm = pygnn.norm.LayerNorm if layer_norm else nn.BatchNorm1d
        self.norm1_local = Norm(dim_h)
        self.norm1_attn = Norm(dim_h)
        self.norm2 = Norm(dim_h)

        # --- Feed-forward block ---
        self.ff_linear1 = nn.Linear(dim_h, dim_h * 2)
        self.ff_linear2 = nn.Linear(dim_h * 2, dim_h)
        self.ff_activation = self.activation()
        self.ff_dropout1 = nn.Dropout(dropout)
        self.ff_dropout2 = nn.Dropout(dropout)

    def forward(self, batch):
        h = batch.x
        h_in1 = h
        outputs = []

        # Local pass
        if self.local_gnn_type == 'CustomGatedGCN':
            es = batch.pe_EquivStableLapPE if self.equivstable_pe else None
            local_out = self.local_model(
                Batch(
                    batch=batch,
                    x=h,
                    edge_index=batch.edge_index,
                    edge_attr=batch.edge_attr,
                    pe_EquivStableLapPE=es
                )
            ).x
        else:  # GINE
            pe = batch.pe_EquivStableLapPE if self.equivstable_pe else None
            local_out = (
                self.local_model(h, batch.edge_index, batch.edge_attr, pe)
                if hasattr(self.local_model, 'equivstable_pe')
                else self.local_model(h, batch.edge_index, batch.edge_attr)
            )
        local = self.dropout_local(local_out)
        local = h_in1 + local
        local = (
            self.norm1_local(local, batch.batch)
            if hasattr(self.norm1_local, 'normalized_shape')
            else self.norm1_local(local)
        )
        outputs.append(local)

        # Global Transformer pass
        h_dense, mask = to_dense_batch(h, batch.batch)
        attn_out = self.self_attn(
            h_dense,
            h_dense,
            h_dense,
            key_padding_mask=~mask
        )[0]
        attn = attn_out[mask]
        attn = self.dropout_attn(attn)
        attn = h_in1 + attn
        attn = (
            self.norm1_attn(attn, batch.batch)
            if hasattr(self.norm1_attn, 'normalized_shape')
            else self.norm1_attn(attn)
        )
        outputs.append(attn)

        # Combine + FF
        h = sum(outputs)
        ff = self.ff_dropout1(self.ff_activation(self.ff_linear1(h)))
        ff = self.ff_dropout2(self.ff_linear2(ff))
        h = h + ff
        h = (
            self.norm2(h, batch.batch)
            if hasattr(self.norm2, 'normalized_shape')
            else self.norm2(h)
        )

        batch.x = h
        return {"labels": batch.y, "batch": batch.batch, "x": h}


In [27]:
backbone = GraphGPSModel(dim_hidden)

readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)

evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

model = TBModel(backbone=backbone,
                 backbone_wrapper=None,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False)
print(type(backbone))
trainer = pl.Trainer(max_epochs=50, accelerator="cpu", enable_progress_bar=False, log_every_n_steps=1)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


<class '__main__.GraphGPSModel'>


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TBModel`