# Implementing your own model

In this tutorial we show how to implement your own model and test it on a dataset. 

This particular example uses the MUTAG dataset, applies a hypergraph lifting to turn its graphs into hypergraphs, and defines a model that operates on them.

We train the model on the training split, monitor it on the validation split, and finally evaluate it on the test split.

### <font color='289C4E'>Table of contents</font><a class='anchor' id='top'></a>
&emsp;[1. Imports](#sec1)

&emsp;[2. Configurations and utilities](#sec2)

&emsp;[3. Loading the data](#sec3)

&emsp;[4. Backbone definition](#sec4)

&emsp;[5. Model initialization](#sec5)

&emsp;[6. Training](#sec6)

&emsp;[7. Testing the model](#sec7)

## 1. Imports <a class="anchor" id="sec1"></a>

In [22]:
import torch
import lightning as pl
# Hydra related imports
from omegaconf import OmegaConf
# Data related imports
from topobench.data.loaders.graph import TUDatasetLoader
from topobench.dataloader.dataloader import TBDataloader
from topobench.data.preprocessor import PreProcessor
# Model related imports
from topobench.model.model import TBModel
from topomodelx.nn.simplicial.scn2 import SCN2
from topobench.nn.wrappers.simplicial import SCNWrapper
from topobench.nn.encoders import AllCellFeatureEncoder
from topobench.nn.readouts import PropagateSignalDown
# Optimization related imports
from topobench.loss.loss import TBLoss
from topobench.optimizer import TBOptimizer
from topobench.evaluator.evaluator import TBEvaluator

## 2. Configurations and utilities <a class="anchor" id="sec2"></a>

Configurations can be specified in YAML files or, as in this example, directly in code.

In [23]:
loader_config = {
    "data_domain": "graph",
    "data_type": "TUDataset",
    "data_name": "MUTAG",
    "data_dir": "./data/MUTAG/"}

transform_config = {
    "khop_lifting": {
        "transform_type": "lifting",
        "transform_name": "HypergraphKHopLifting",
        "k_value": 1,
    }
}

split_config = {
    "learning_setting": "inductive",
    "split_type": "random",
    "data_seed": 0,
    "data_split_dir": "./data/MUTAG/splits/",
    "train_prop": 0.5,
}

in_channels = 7
out_channels = 2
dim_hidden = 16

readout_config = {
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
}

loss_config = {
    "dataset_loss": 
        {
            "task": "classification", 
            "loss_type": "cross_entropy"
        }
}

evaluator_config = {"task": "classification",
                    "num_classes": out_channels,
                    "metrics": ["accuracy", "precision", "recall"]}

optimizer_config = {"optimizer_id": "Adam",
                    "parameters":
                        {"lr": 0.001,"weight_decay": 0.0005}
                    }

loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)

## 3. Loading the data <a class="anchor" id="sec3"></a>

In this example we use the MUTAG dataset. It is a graph dataset, so we apply the k-hop lifting to transform its graphs into hypergraphs.

We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered.
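
To make the idea of a k-hop lifting concrete, here is a minimal pure-Python sketch. It assumes that the lifting creates one hyperedge per node, containing that node and every node within `k` hops of it. The function name `khop_hyperedges` is hypothetical and this is not the TopoBench implementation, only an illustration of the concept.

```python
from collections import deque


def khop_hyperedges(num_nodes, edges, k=1):
    """Return one hyperedge per node: the node plus everything within k hops."""
    adjacency = {v: set() for v in range(num_nodes)}
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)

    hyperedges = []
    for source in range(num_nodes):
        visited = {source}
        frontier = deque([(source, 0)])
        # Breadth-first search that stops expanding at depth k
        while frontier:
            node, depth = frontier.popleft()
            if depth == k:
                continue
            for neighbour in adjacency[node]:
                if neighbour not in visited:
                    visited.add(neighbour)
                    frontier.append((neighbour, depth + 1))
        hyperedges.append(sorted(visited))
    return hyperedges


# Path graph 0 - 1 - 2 - 3
path_edges = [(0, 1), (1, 2), (2, 3)]
print(khop_hyperedges(4, path_edges, k=1))
# [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3]]
```

Under this assumption a graph with $n$ nodes yields $n$ hyperedges, and a larger `k` produces larger, more overlapping hyperedges.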

In [24]:
graph_loader = TUDatasetLoader(loader_config)

dataset, dataset_dir = graph_loader.load()

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

Transform parameters are the same, using existing data_dir: data\MUTAG\MUTAG\khop_lifting\1116229528
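
The split configuration only fixes `train_prop` and `data_seed`. As a rough illustration of what a seeded random split does, the sketch below shuffles the sample indices and assumes the part not used for training is divided evenly between validation and test. That even division is an assumption on our side (it is consistent with the accuracies reported later, since 0.7234 is approximately 34/47), and `random_split` is a hypothetical helper, not the TopoBench routine.

```python
import random


def random_split(num_samples, train_prop, seed=0):
    """Shuffle indices with a fixed seed and cut them into train/val/test."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    n_train = int(train_prop * num_samples)
    # Assumption: the remainder is shared evenly between validation and test
    n_val = (num_samples - n_train) // 2
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


# MUTAG contains 188 graphs
train_idx, val_idx, test_idx = random_split(188, train_prop=0.5, seed=0)
print(len(train_idx), len(val_idx), len(test_idx))  # 94 47 47
```

Fixing the seed makes the split reproducible, which matters when comparing models on a dataset as small as MUTAG.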


## 4. Backbone definition <a class="anchor" id="sec4"></a>

To implement a new model we only need to define its layers and the `forward` method.

For a hypergraph with $n$ nodes and $m$ hyperedges, this model first computes the hyperedge features as $X_1 = B_1 \cdot X_0$, where $X_0$ holds the node features and $B_1 \in \mathbb{R}^{m \times n}$ is the incidence matrix, with $(B_1)_{ij}=1$ if node $j$ belongs to hyperedge $i$ and $0$ otherwise.

Then the outputs are computed as $X^{'}_0=\text{ReLU}(W_0 \cdot X_0 + b_0)$ and $X^{'}_1=\text{ReLU}(W_1 \cdot X_1 + b_1)$, i.e. two linear layers with weights $W_0, W_1$ and biases $b_0, b_1$, each followed by a ReLU activation.
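
The two steps can be checked by hand on a tiny hypergraph. The sketch below uses plain Python lists instead of tensors, stores the incidence matrix with one row per hyperedge so that the product with the node features is well defined, and applies the linear layers row by row in the same convention as `torch.nn.Linear`. All numbers and helper names (`matmul`, `linear_relu`) are made up for illustration.

```python
def matmul(a, b):
    """Multiply two matrices stored as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def linear_relu(x, weight, bias):
    """Apply ReLU(x @ weight^T + bias) row by row, like Linear followed by ReLU."""
    out = []
    for row in x:
        out.append([
            max(0.0, sum(w * v for w, v in zip(weight_row, row)) + b)
            for weight_row, b in zip(weight, bias)
        ])
    return out


# 3 nodes, 2 hyperedges: hyperedge 0 = {0, 1}, hyperedge 1 = {1, 2}
incidence = [[1.0, 1.0, 0.0],
             [0.0, 1.0, 1.0]]
x_0 = [[1.0, 0.0],
       [0.0, 1.0],
       [2.0, -1.0]]

# Hyperedge features are the sums of the features of their member nodes
x_1 = matmul(incidence, x_0)           # [[1.0, 1.0], [2.0, 0.0]]

w_0, b_0 = [[1.0, -1.0], [0.0, 1.0]], [0.0, 0.5]
w_1, b_1 = [[1.0, 0.0], [0.0, 1.0]], [-1.5, 0.0]

x_0_out = linear_relu(x_0, w_0, b_0)   # [[1.0, 0.5], [0.0, 1.5], [3.0, 0.0]]
x_1_out = linear_relu(x_1, w_1, b_1)   # [[0.0, 1.0], [0.5, 0.0]]
```

Note that node and hyperedge features are transformed by separate layers, so the two ranks do not share weights.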

In [25]:
class myModel(pl.LightningModule):
    """Toy hypergraph backbone: one linear layer for nodes, one for hyperedges."""

    def __init__(self, dim_hidden):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)  # node features
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)  # hyperedge features

    def forward(self, batch):
        x_0 = batch.x_0
        incidence_hyperedges = batch.incidence_hyperedges
        # Aggregate node features into hyperedge features
        x_1 = torch.sparse.mm(incidence_hyperedges, x_0)

        # Linear layer followed by ReLU on each rank
        x_0 = self.linear_0(x_0)
        x_0 = torch.relu(x_0)
        x_1 = self.linear_1(x_1)
        x_1 = torch.relu(x_1)

        # Return the labels, the batch assignment, and the updated features
        model_out = {"labels": batch.y, "batch_0": batch.batch_0}
        model_out["x_0"] = x_0
        model_out["hyperedge"] = x_1
        return model_out

## 5. Model initialization <a class="anchor" id="sec5"></a>

Now that the backbone is defined we can create the `TBModel`, which takes care of everything else that is needed to train it.

First we need to instantiate a few classes that specify the behaviour of the model: the readout, the loss, the feature encoder, the evaluator, and the optimizer.

In [26]:
backbone = myModel(dim_hidden)

readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)

evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)
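
The readout above is configured with `task_level` set to `"graph"` and `pooling_type` set to `"sum"`. As a rough sketch of what sum pooling means for a graph-level task, the snippet below adds up the node features of each graph in a batch, using a batch index vector like `batch_0`. It is a plain-Python illustration with a hypothetical helper `sum_pool`, not the `PropagateSignalDown` implementation.

```python
def sum_pool(node_features, batch_index):
    """Sum node feature vectors per graph, given each node's graph index."""
    num_graphs = max(batch_index) + 1
    dim = len(node_features[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for features, graph_id in zip(node_features, batch_index):
        for d, value in enumerate(features):
            pooled[graph_id][d] += value
    return pooled


# Two graphs in one batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1
pooled = sum_pool([[1.0, 0.0], [2.0, 1.0], [0.5, 0.5]], [0, 0, 1])
print(pooled)  # [[3.0, 1.0], [0.5, 0.5]]
```

Each graph ends up with a single vector regardless of its size, which is what a graph-level classifier needs as input.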

Now we can instantiate the TBModel.

In [27]:
model = TBModel(backbone=backbone,
                 backbone_wrapper=None,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False)

In [28]:
print(isinstance(model, pl.LightningModule))

True


In [29]:
model

TBModel(backbone=myModel(
  (linear_0): Linear(in_features=16, out_features=16, bias=True)
  (linear_1): Linear(in_features=16, out_features=16, bias=True)
), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss. Losses: [DatasetLoss(task=classification, loss_type=cross_entropy)], feature_encoder=AllCellFeatureEncoder(in_channels=[7], out_channels=16, dimensions=range(0, 1)))

## 6. Training <a class="anchor" id="sec6"></a>

Now we can use the `lightning` trainer to train the model.

In [30]:
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=50, accelerator="cpu", enable_progress_bar=False, log_every_n_steps=1)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | feature_encoder | AllCellFeatureEncoder | 448    | train
1 | backbone        | myModel               | 544    | train
2 | readout         | PropagateSignalDown   | 34     | train
3 | val_acc_best    | MeanMetric            | 0      | train
------------------------------------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode
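
The parameter counts in the summary can be verified by hand. A linear layer with bias has `in_features * out_features + out_features` parameters. The sketch below reproduces the backbone count exactly and, under the assumption that the readout consists of a single linear layer from `hidden_dim` to `out_channels`, the readout count as well.

```python
def linear_params(in_features, out_features, bias=True):
    """Number of trainable parameters of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


dim_hidden, out_channels = 16, 2

# The backbone holds two Linear(16, 16) layers
backbone_params = 2 * linear_params(dim_hidden, dim_hidden)
# Assumption: the readout is one Linear(16, 2) layer
readout_params = linear_params(dim_hidden, out_channels)

print(backbone_params)  # 544
print(readout_params)   # 34
```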


`Trainer.fit` stopped: `max_epochs=50` reached.



In [31]:
print('      Training metrics\n', '-'*26)
for key in train_metrics:
    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))

      Training metrics
 --------------------------
train/accuracy:       0.7553
train/precision:      0.7298
train/recall:         0.6864
val/loss:             0.4941
val/accuracy:         0.7234
val/precision:        0.6992
val/recall:           0.6021
train/loss:           0.4421


## 7. Testing the model <a class="anchor" id="sec7"></a>

Finally, we can test the model and obtain the results.

In [32]:
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics


In [33]:
print('      Testing metrics\n', '-'*25)
for key in test_metrics:
    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))

      Testing metrics
 -------------------------
test/loss:           0.4420
test/accuracy:       0.7234
test/precision:      0.7703
test/recall:         0.6304
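
For reference, the three reported metrics can be computed from predictions as sketched below. The example assumes macro averaging over classes for precision and recall. Which averaging `TBEvaluator` actually uses is not stated here, so treat that as an assumption. The labels and the helper `classification_metrics` are made up for illustration.

```python
def classification_metrics(y_true, y_pred, num_classes):
    """Return accuracy, macro precision and macro recall."""
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    precisions, recalls = [], []
    for c in range(num_classes):
        true_pos = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        predicted = sum(p == c for p in y_pred)
        actual = sum(t == c for t in y_true)
        # Guard against empty classes to avoid division by zero
        precisions.append(true_pos / predicted if predicted else 0.0)
        recalls.append(true_pos / actual if actual else 0.0)
    return accuracy, sum(precisions) / num_classes, sum(recalls) / num_classes


y_true = [0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0, 1]
accuracy, precision, recall = classification_metrics(y_true, y_pred, num_classes=2)
print(f"{accuracy:.4f} {precision:.4f} {recall:.4f}")  # 0.6667 0.6250 0.6250
```

On an imbalanced dataset, accuracy can look good while recall on the minority class stays low, which is why it helps to report all three.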


## GraphGPS

In [None]:
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from topobench.data.loaders.graph.modecule_datasets import MoleculeDatasetLoader
from topobench.data.preprocessor import PreProcessor
from topobench.dataloader.dataloader import TBDataloader

# ----------------------------
# Configuration
# ----------------------------

loader_config = OmegaConf.create({
    "data_domain": "graph",
    "data_type": "MoleculeDataset",
    "data_name": "ZINC",
    "data_dir": "./data/ZINC/"
})

transform_config = OmegaConf.create({})
split_config = OmegaConf.create({
    "learning_setting": "inductive",
    "split_type": "random",
    "data_seed": 0,
    "data_split_dir": "./data/ZINC/splits/",
    "train_prop": 0.8,
})

in_channels = 1
out_channels = 1
dim_hidden = 64
batch_size = 32

readout_config = OmegaConf.create({
    "readout_name": "PropagateSignalDown",
    "num_cell_dimensions": 1,
    "hidden_dim": dim_hidden,
    "out_channels": out_channels,
    "task_level": "graph",
    "pooling_type": "sum",
})

loss_config = OmegaConf.create({
    "dataset_loss": {
        "task": "regression",
        "loss_type": "mae"
    }
})

evaluator_config = OmegaConf.create({
    "task": "regression",
    "metrics": ["mae", "mse"],
    "num_classes": 1
})

optimizer_config = OmegaConf.create({
    "optimizer_id": "Adam",
    "parameters": {
        "lr": 0.001,
        "weight_decay": 0.0005
    }
})


def safe_load_zinc_dataset(loader_config):
    """Load ZINC; on a PermissionError, rename `molecules` to `raw` and retry."""
    try:
        loader = MoleculeDatasetLoader(loader_config)
        return loader.load()
    except PermissionError:
        base_dir = Path(loader_config.data_dir)
        molecules_dir = base_dir / "molecules"
        raw_dir = base_dir / "raw"
        # Only rename when the extracted folder exists and `raw` does not
        if molecules_dir.exists() and not raw_dir.exists():
            os.rename(molecules_dir, raw_dir)
        loader = MoleculeDatasetLoader(loader_config)
        return loader.load()

def cast_all_xi_to_float(dataset):
    """Cast int64 tensors found in list-valued samples to float32, in place."""
    for idx in range(len(dataset)):
        data = dataset[idx][0]
        if isinstance(data, list):
            for xi_idx, tensor in enumerate(data):
                # torch.long is an alias of torch.int64, so one check suffices
                if isinstance(tensor, torch.Tensor) and tensor.dtype == torch.int64:
                    data[xi_idx] = tensor.to(torch.float32)
                    print(f"[i={idx}][x_{xi_idx}] casted to: {data[xi_idx].dtype}")

dataset, dataset_dir = safe_load_zinc_dataset(loader_config)
preprocessor = PreProcessor(dataset, dataset_dir)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)

cast_all_xi_to_float(dataset_train)
cast_all_xi_to_float(dataset_val)
cast_all_xi_to_float(dataset_test)

datamodule = TBDataloader(
    dataset_train,
    dataset_val,
    dataset_test,
    batch_size=batch_size
)


first_batch = next(iter(datamodule.train_dataloader()))
print(f"✅ Input feature dtype: {first_batch[0].x.dtype}")  # should be torch.float32


[i=0][x_0] casted to: torch.float32
[i=0][x_1] casted to: torch.float32
[i=0][x_2] casted to: torch.float32
[i=0][x_3] casted to: torch.float32
[i=0][x_4] casted to: torch.float32
[i=0][x_5] casted to: torch.float32
[i=1][x_0] casted to: torch.float32
[i=1][x_1] casted to: torch.float32
[i=1][x_2] casted to: torch.float32
[i=1][x_3] casted to: torch.float32
[i=1][x_4] casted to: torch.float32
[i=1][x_5] casted to: torch.float32
[i=2][x_0] casted to: torch.float32
[i=2][x_1] casted to: torch.float32
[i=2][x_2] casted to: torch.float32
[i=2][x_3] casted to: torch.float32
[i=2][x_4] casted to: torch.float32
[i=2][x_5] casted to: torch.float32
[i=3][x_0] casted to: torch.float32
[i=3][x_1] casted to: torch.float32
[i=3][x_2] casted to: torch.float32
[i=3][x_3] casted to: torch.float32
[i=3][x_4] casted to: torch.float32
[i=3][x_5] casted to: torch.float32
[i=4][x_0] casted to: torch.float32
[i=4][x_1] casted to: torch.float32
[i=4][x_2] casted to: torch.float32
[i=4][x_3] casted to: torch.

In [36]:
import torch
import torch.nn as nn
import torch_geometric.nn as pygnn
from torch_geometric.data import Batch
from torch_geometric.nn import Linear as Linear_pyg
from torch_geometric.utils import to_dense_batch

from topobench.nn.backbones.graph.gatedgcn_layer import GatedGCNLayer
from topobench.nn.backbones.graph.gine_conv_layer import GINEConvESLapPE


class GraphGPSModel(pl.LightningModule):
    def __init__(
        self,
        dim_h: int,
        edge_dim: int,                     # size of raw edge_attr
        local_gnn_type: str = "GINE",
        global_model_type: str = "Transformer",
        num_heads: int = 8,
        act: str = "relu",
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        layer_norm: bool = False,
        batch_norm: bool = True,
        equivstable_pe: bool = False,
    ):
        super().__init__()

        self.dim_h = dim_h
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.equivstable_pe = equivstable_pe
        self.activation = nn.ReLU if act == "relu" else nn.SiLU

        # ---- Local GNN ----------------------------------------------------
        if local_gnn_type == "GINE":
            gin_nn = nn.Sequential(
                Linear_pyg(dim_h, dim_h),
                self.activation(),
                Linear_pyg(dim_h, dim_h),
            )
            if equivstable_pe:
                # custom layer already handles edge projection internally
                self.local_model = GINEConvESLapPE(gin_nn, edge_dim=edge_dim)
            else:
                self.local_model = pygnn.GINEConv(gin_nn, edge_dim=edge_dim)
        elif local_gnn_type == "CustomGatedGCN":
            self.local_model = GatedGCNLayer(
                dim_h,
                dim_h,
                dropout=dropout,
                residual=True,
                act=act,
                equivstable_pe=equivstable_pe,
            )
        else:
            raise ValueError(
                f"Unsupported local GNN: {local_gnn_type}. "
                "Choose 'GINE' or 'CustomGatedGCN'."
            )

        self.local_gnn_type = local_gnn_type
        self.dropout_local = nn.Dropout(dropout)

        # ---- Global Transformer block ------------------------------------
        if global_model_type == "Transformer":
            self.self_attn = nn.MultiheadAttention(
                dim_h, num_heads, dropout=attn_dropout, batch_first=True
            )
        else:
            raise ValueError(
                f"Unsupported global model: {global_model_type}. "
                "Choose 'Transformer'."
            )
        self.dropout_attn = nn.Dropout(dropout)

        # ---- Normalisation layers ----------------------------------------
        if layer_norm and batch_norm:
            raise ValueError("Cannot use both layer_norm and batch_norm")
        Norm = pygnn.norm.LayerNorm if layer_norm else nn.BatchNorm1d
        self.norm1_local = Norm(dim_h)
        self.norm1_attn = Norm(dim_h)
        self.norm2 = Norm(dim_h)

        # ---- Feed-forward block ------------------------------------------
        self.ff_linear1 = nn.Linear(dim_h, dim_h * 2)
        self.ff_linear2 = nn.Linear(dim_h * 2, dim_h)
        self.ff_activation = self.activation()
        self.ff_dropout1 = nn.Dropout(dropout)
        self.ff_dropout2 = nn.Dropout(dropout)

    # ---------------------------------------------------------------------
    # Forward
    # ---------------------------------------------------------------------
    def forward(self, batch):
        h = batch.x
        h_in1 = h
        outputs = []

        # ---------- Local pass -------------------------------------------
        if self.local_gnn_type == "CustomGatedGCN":
            es = batch.pe_EquivStableLapPE if self.equivstable_pe else None
            local_out = self.local_model(
                Batch(
                    batch=batch,
                    x=h,
                    edge_index=batch.edge_index,
                    edge_attr=batch.edge_attr,
                    pe_EquivStableLapPE=es,
                )
            ).x
        else:  # GINE
            pe = batch.pe_EquivStableLapPE if self.equivstable_pe else None
            local_out = (
                self.local_model(h, batch.edge_index, batch.edge_attr, pe)
                if hasattr(self.local_model, "equivstable_pe")
                else self.local_model(h, batch.edge_index, batch.edge_attr)
            )

        local = self.dropout_local(local_out)
        local = h_in1 + local
        # PyG's LayerNorm takes the batch vector; BatchNorm1d does not
        local = (
            self.norm1_local(local, batch.batch)
            if self.layer_norm
            else self.norm1_local(local)
        )
        outputs.append(local)

        # ---------- Global Transformer pass ------------------------------
        h_dense, mask = to_dense_batch(h, batch.batch)
        attn_out = self.self_attn(h_dense, h_dense, h_dense, key_padding_mask=~mask)[0]
        attn = attn_out[mask]
        attn = self.dropout_attn(attn)
        attn = h_in1 + attn
        # PyG's LayerNorm takes the batch vector; BatchNorm1d does not
        attn = (
            self.norm1_attn(attn, batch.batch)
            if self.layer_norm
            else self.norm1_attn(attn)
        )
        outputs.append(attn)

        # ---------- Combine + feed-forward -------------------------------
        h = sum(outputs)
        ff = self.ff_dropout1(self.ff_activation(self.ff_linear1(h)))
        ff = self.ff_dropout2(self.ff_linear2(ff))
        h = h + ff
        # PyG's LayerNorm takes the batch vector; BatchNorm1d does not
        h = (
            self.norm2(h, batch.batch)
            if self.layer_norm
            else self.norm2(h)
        )

        batch.x = h
        return {"labels": batch.y, "batch": batch.batch, "x": h}
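
The global pass relies on `to_dense_batch` to turn the flat node tensor into a padded `[num_graphs, max_nodes, dim]` tensor plus a boolean mask, and passes the inverted mask as `key_padding_mask` so that attention ignores the padding. The plain-Python sketch below mimics that behaviour on nested lists to show what the mask looks like. The helper `to_dense` is hypothetical and only illustrates the idea.

```python
def to_dense(features, batch_index, pad_value=0.0):
    """Group node features per graph and pad each graph to the same length."""
    num_graphs = max(batch_index) + 1
    groups = [[] for _ in range(num_graphs)]
    for feature, graph_id in zip(features, batch_index):
        groups[graph_id].append(feature)

    max_nodes = max(len(group) for group in groups)
    dim = len(features[0])
    dense, mask = [], []
    for group in groups:
        padding = max_nodes - len(group)
        dense.append(group + [[pad_value] * dim for _ in range(padding)])
        # True marks a real node, False marks padding
        mask.append([True] * len(group) + [False] * padding)
    return dense, mask


features = [[1.0], [2.0], [3.0], [4.0], [5.0]]
batch_index = [0, 0, 1, 1, 1]  # graph 0 has 2 nodes, graph 1 has 3
dense, mask = to_dense(features, batch_index)
print(mask)  # [[True, True, False], [True, True, True]]

# Attention expects True where a key should be ignored, hence the inversion
key_padding_mask = [[not m for m in row] for row in mask]

# Indexing with the mask recovers the original flat node ordering
flat = [row for graph, keep in zip(dense, mask) for row, k in zip(graph, keep) if k]
```

Selecting the attention output with the mask, as `attn_out[mask]` does in the model, drops the padded positions and restores one row per node.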


In [37]:
backbone = GraphGPSModel(
    dim_h=7,
    edge_dim=4,
    local_gnn_type="GINE",
    num_heads=1,
)
readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)

evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

# Keep the backbone and the wrapped model under separate names
model = TBModel(backbone=backbone,
                 backbone_wrapper=None,
                 readout=readout,
                 loss=loss,
                 feature_encoder=feature_encoder,
                 evaluator=evaluator,
                 optimizer=optimizer,
                 compile=False)
print(isinstance(model, pl.LightningModule))


True


In [38]:
# Increase the number of epochs to get better results
trainer = pl.Trainer(max_epochs=50, accelerator="cpu", enable_progress_bar=False, log_every_n_steps=1)

trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | feature_encoder | AllCellFeatureEncoder | 4.5 K  | train
1 | backbone        | GraphGPSModel         | 630    | train
2 | readout         | PropagateSignalDown   | 65     | train
3 | val_acc_best    | MeanMetric            | 0      | train
------------------------------------------------------------------
5.2 K     Trainable params
0         Non-trainable params
5.2 K     Total params
0.021     Total estimated model params size (MB)
30        Modules in train mode
0         Modules in eval mode


⚠️ Casting x from torch.int64 to float32 in BaseEncoder
[DEBUG] x.shape = torch.Size([711, 1]), expected in_features = 1


RuntimeError: mat1 and mat2 must have the same dtype, but got Long and Float

In [None]:
print('      Training metrics\n', '-'*26)
for key in train_metrics:
    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))

