# Training TBModel on Auditory Cortex Data for 1 and 2/3 regions.

This notebook demonstrates loading the A123 auditory cortex dataset (`a123_cortex_m`), applying a simple hypergraph lifting (`HypergraphKHopLifting`), defining a small backbone, and training a `TBModel` using `TBLoss` and `TBOptimizer`.

Requirements: the project must be importable (installed, or available on `PYTHONPATH`), plus the optional dependencies (torch_geometric, networkx, ripser/persim) if you want the advanced features.

In [5]:
import os

# Work from the parent directory: the configs below use paths such as
# './data/a123/' that are relative to it.
# The starting directory is remembered the first time this cell runs, so that
# re-running the cell always lands in the same place instead of climbing one
# level higher on every execution.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

In [6]:
# 1) Imports
import torch
import lightning as pl
from omegaconf import OmegaConf

# Data loading / preprocessing utilities from the repo
from topobench.data.loaders.graph.a123_loader import A123DatasetLoader
from topobench.dataloader.dataloader import TBDataloader
from topobench.data.preprocessor import PreProcessor

# Model / training building blocks
from topobench.model.model import TBModel
# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)
# from topomodelx.nn.simplicial.scn2 import SCN2
from topobench.nn.wrappers.simplicial import SCNWrapper
from topobench.nn.encoders import AllCellFeatureEncoder
from topobench.nn.readouts import PropagateSignalDown

# Optimization / evaluation
from topobench.loss.loss import TBLoss
from topobench.optimizer import TBOptimizer
from topobench.evaluator.evaluator import TBEvaluator

print('Imports OK')

Imports OK


  from pkg_resources import parse_version


In [7]:
# 2) Configurations and utilities

# Seed the random number generators up front so that weight initialisation and
# batch shuffling are repeatable when the notebook is restarted and run top to bottom.
SEED = 0
pl.seed_everything(SEED, workers=True)

# model / task hyperparameters
# A123 sample node features are: [mean_corr, std_corr, noise_diag] => 3 channels
in_channels = 3
# Multiclass classification: 9 frequency bins (bf_bin 0-8).
# `n_bins` is the single source of truth; the number of output classes follows it.
n_bins = 9  # default binning from extract_samples
out_channels = n_bins
dim_hidden = 16

# All configs are created as OmegaConf objects (the project often expects DictConfig).
loader_config = OmegaConf.create({
    'data_domain': 'graph',
    'data_type': 'A123',
    # the loader/dataset expects the dataset name key used in the dataset class
    'data_name': 'a123_cortex_m',
    'data_dir': './data/a123/',
})

# Transform config: single transform with transform_name and transform_type
# PreProcessor expects either {"transform_name": ...} (single) or {"key1": {...}, "key2": {...}} (multiple)
transform_config = OmegaConf.create({
    'transform_type': 'lifting',
    'transform_name': 'HypergraphKHopLifting',
    'k_value': 1,
})

split_config = OmegaConf.create({
    'learning_setting': 'inductive',
    'split_type': 'random',
    'data_seed': SEED,
    'data_split_dir': './data/a123/splits/',
    'train_prop': 0.5,
})

readout_config = OmegaConf.create({
    'readout_name': 'PropagateSignalDown',
    'num_cell_dimensions': 1,
    'hidden_dim': dim_hidden,
    'out_channels': out_channels,
    'task_level': 'graph',
    'pooling_type': 'sum',
})

loss_config = OmegaConf.create({
    'dataset_loss': {
        'task': 'classification',
        'loss_type': 'cross_entropy',
    }
})

evaluator_config = OmegaConf.create({
    'task': 'classification',
    'num_classes': out_channels,
    'metrics': ['accuracy', 'precision', 'recall'],
})

optimizer_config = OmegaConf.create({
    'optimizer_id': 'Adam',
    'parameters': {'lr': 0.001, 'weight_decay': 0.0005},
})

print('Configs created')

Configs created


In [8]:
# 3) Loading the data

# Build the raw dataset through the A123-specific loader.
graph_loader = A123DatasetLoader(loader_config)
dataset, dataset_dir = graph_loader.load()
print('Dataset loaded')

# Apply the configured transform and produce the train / val / test splits.
preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(
    split_config
)
print('Dataset splits created')

# Wrap the three splits in the TopoBench datamodule used by the Lightning trainer.
datamodule = TBDataloader(
    dataset_train,
    dataset_val,
    dataset_test,
    batch_size=32,
)
print('Datasets and datamodule ready')

Dataset loaded
Transform parameters are the same, using existing data_dir: data/a123/a123_cortex_m/transform_type_transform_name_k_value/563224662
Dataset splits created
Datasets and datamodule ready


## 4) Backbone definition

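We implement a tiny backbone as a `pl.LightningModule` that computes hyperedge features from the node features, $X_1 = B_1 \cdot X_0$, and then applies one linear layer followed by a ReLU to the node features and another linear layer followed by a ReLU to the hyperedge features.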

In [9]:
class MyBackbone(pl.LightningModule):
    """Tiny backbone with one linear layer + ReLU for nodes and one for hyperedges.

    Hyperedge features are obtained by multiplying the incidence matrix with the
    node features; node and hyperedge features are then each passed through
    their own ``Linear(dim_hidden, dim_hidden)`` layer followed by a ReLU.

    Parameters
    ----------
    dim_hidden : int
        Input and output width of both linear layers.
    """

    def __init__(self, dim_hidden):
        super().__init__()
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        """Compute node and hyperedge representations for a batch.

        Parameters
        ----------
        batch : object
            Must expose ``x_0`` (node features) and ``y`` (labels). The
            incidence matrix is read from ``incidence_hyperedges`` or, if that
            is missing, from ``incidence``; ``batch_0`` is forwarded when present.
            Presumably ``x_0`` has shape ``[n_nodes, dim_hidden]`` (already
            encoded by the feature encoder) -- confirm against the caller.

        Returns
        -------
        dict
            ``labels``, ``batch_0``, ``x_0`` (transformed node features) and
            ``hyperedge`` (transformed hyperedge features).
        """
        x_0 = batch.x_0

        # Prefer the hypergraph incidence; fall back to a generic `incidence` attribute.
        incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)
        if incidence_hyperedges is None:
            incidence_hyperedges = getattr(batch, 'incidence', None)

        if incidence_hyperedges is None:
            # No incidence available: use a zero placeholder shaped like the node features.
            x_1 = torch.zeros_like(x_0)
        else:
            # Pick the orientation from the shapes instead of retrying after an
            # exception, so unrelated failures (dtype, device, ...) surface as-is.
            n_nodes = x_0.shape[0]
            if (
                incidence_hyperedges.shape[-1] != n_nodes
                and incidence_hyperedges.shape[0] == n_nodes
            ):
                incidence_hyperedges = incidence_hyperedges.t()
            # Hyperedge features X_1 = B_1 . X_0
            x_1 = torch.sparse.mm(incidence_hyperedges, x_0)

        x_0 = torch.relu(self.linear_0(x_0))
        x_1 = torch.relu(self.linear_1(x_1))

        return {
            'labels': batch.y,
            'batch_0': getattr(batch, 'batch_0', None),
            'x_0': x_0,
            'hyperedge': x_1,
        }

print('Backbone defined')

Backbone defined


In [10]:
# 5) Model initialization (components)
# The construction order is kept unchanged: modules with learnable weights draw
# from the random number generator when they are created.

# Backbone operating on features of width `dim_hidden`.
backbone = MyBackbone(dim_hidden)

# Readout and loss built from their configs.
readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)

# Encoder lifting the raw input channels to the hidden width.
feature_encoder = AllCellFeatureEncoder(
    in_channels=[in_channels],
    out_channels=dim_hidden,
)

# Metrics and optimizer wrappers.
evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

print('Components instantiated')

Components instantiated


In [11]:
# 6) Instantiate TBModel
# Assemble the full model from the components created above; no backbone
# wrapper is used and compilation is disabled.
model = TBModel(
    backbone=backbone,
    backbone_wrapper=None,
    readout=readout,
    loss=loss,
    feature_encoder=feature_encoder,
    evaluator=evaluator,
    optimizer=optimizer,
    compile=False,
)

# Show the model's textual representation as a quick construction check.
print(model)

TBModel(backbone=MyBackbone(
  (linear_0): Linear(in_features=16, out_features=16, bias=True)
  (linear_1): Linear(in_features=16, out_features=16, bias=True)
), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1)))


In [14]:
# 7) Training loop (Lightning trainer)
import warnings

# Silence torchmetrics UserWarnings so the training log stays readable.
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

trainer = pl.Trainer(
    max_epochs=100,  # reduced for faster iteration
    accelerator='cpu',
    enable_progress_bar=True,
    log_every_n_steps=1,
    enable_model_summary=False,  # skip the model summary printout
)
trainer.fit(model, datamodule)

# Copy the metrics so `train_metrics` is a snapshot that does not share state
# with the Trainer (the same trainer is reused for testing afterwards).
train_metrics = dict(trainer.callback_metrics)

print('\nTraining finished. Collected metrics:')
for key, val in train_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except (TypeError, ValueError, RuntimeError):
        # Value (or key) cannot be formatted as a single float: print it raw.
        print(key, val)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.



Training finished. Collected metrics:
train/accuracy            0.2320
train/precision           0.1971
train/recall              0.2260
val/loss                  2.1351
val/accuracy              0.1129
val/precision             0.0566
val/recall                0.0991
train/loss                1.9733


In [15]:
# 8) Testing and printing metrics
trainer.test(model, datamodule)

# Copy the metrics so `test_metrics` is a snapshot that does not share state
# with the Trainer.
test_metrics = dict(trainer.callback_metrics)

print('\nTest metrics:')
for key, val in test_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except (TypeError, ValueError, RuntimeError):
        # Value (or key) cannot be formatted as a single float: print it raw.
        print(key, val)

/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]


Test metrics:
test/loss                 2.1252
test/accuracy             0.1270
test/precision            0.0508
test/recall               0.1250
