# Training TBModel on Auditory Cortex Data

This notebook demonstrates three different tasks using the A123 mouse auditory cortex dataset:

1. **Graph-level Classification**: Predict frequency bin (0-8) from graph structure
2. **Triangle Classification**: Classify topological role of triangles in the correlation graph
3. **Triangle Common-Neighbors**: Predict the number of common neighbors for triangles

We'll show how to load the dataset, apply lifting transformations, define a backbone, and train a `TBModel` using `TBLoss` and `TBOptimizer`.

Requirements: the `topobench` package must be importable, either installed or on `PYTHONPATH`. The optional dependencies (`torch_geometric`, `networkx`, `ripser`/`persim`) are needed only for advanced features.

In [1]:
import os
os.chdir('..')

In [2]:
# 1) Imports
import torch
import lightning as pl
from omegaconf import OmegaConf

# Data loading / preprocessing utilities from the repo
from topobench.data.loaders.graph.a123_loader import A123DatasetLoader
from topobench.dataloader.dataloader import TBDataloader
from topobench.data.preprocessor import PreProcessor

# Model / training building blocks
from topobench.model.model import TBModel
# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)
# from topomodelx.nn.simplicial.scn2 import SCN2
from topobench.nn.wrappers.simplicial import SCNWrapper
from topobench.nn.encoders import AllCellFeatureEncoder
from topobench.nn.readouts import PropagateSignalDown

# Optimization / evaluation
from topobench.loss.loss import TBLoss
from topobench.optimizer import TBOptimizer
from topobench.evaluator.evaluator import TBEvaluator

print('Imports OK')

Imports OK




In [3]:
# 2) Configurations for different tasks
# Note: We'll demonstrate each task separately by changing the specific_task parameter

# Common loader config
loader_config_base = {
    'data_domain': 'graph',
    'data_type': 'A123',
    'data_name': 'a123_cortex_m',
    'data_dir': './data/a123/',
    'corr_threshold': 0.3,  # Higher threshold ensures graphs have meaningful edges
}

# Transform config: using CellCycleLifting (more robust for graphs with few edges)
# CellCycleLifting finds cycles and lifts them to 2-cells, handles empty graphs gracefully
transform_config = {
    'transform_type': 'lifting',
    'transform_name': 'CellCycleLifting',
    'max_cell_length': None,  # No limit on cycle length
}

split_config = {
    'learning_setting': 'inductive',
    'split_type': 'random',
    'data_seed': 0,
    'data_split_dir': './data/a123/splits/',
    'train_prop': 0.5,
}

# Task configurations
tasks = {
    'graph_classification': {
        'description': 'Graph-level classification (predict frequency bin 0-8)',
        'specific_task': 'classification',
        'in_channels': 3,
        'out_channels': 9,
        'task_level': 'graph',
    },
    'triangle_classification': {
        'description': 'Triangle classification (predict topological role, 9 classes)',
        'specific_task': 'triangle_classification',
        'in_channels': 3,
        'out_channels': 9,
        'task_level': 'graph',
    },
    'triangle_common_neighbors': {
        'description': 'Triangle common-neighbors (predict # common neighbors, 9 classes)',
        'specific_task': 'triangle_common_neighbors',
        'in_channels': 3,
        'out_channels': 9,
        'task_level': 'graph',
    }
}

# Select task to run (change to 'graph_classification' or 'triangle_common_neighbors' to run a different task)
TASK_NAME = 'triangle_classification'
TASK_CONFIG = tasks[TASK_NAME]

print(f"Selected task: {TASK_NAME}")
print(f"Description: {TASK_CONFIG['description']}")

# Create loader config with specific task
loader_config = OmegaConf.create({**loader_config_base, 'specific_task': TASK_CONFIG['specific_task']})

dim_hidden = 16
in_channels = TASK_CONFIG['in_channels']
out_channels = TASK_CONFIG['out_channels']

readout_config = {
    'readout_name': 'PropagateSignalDown',
    'num_cell_dimensions': 1,
    'hidden_dim': dim_hidden,
    'out_channels': out_channels,
    'task_level': TASK_CONFIG['task_level'],
    'pooling_type': 'sum',
}

loss_config = {
    'dataset_loss': {
        'task': 'classification',
        'loss_type': 'cross_entropy',
    }
}

evaluator_config = {
    'task': 'classification',
    'num_classes': out_channels,
    'metrics': ['f1', 'precision', 'recall', 'accuracy'],
}

optimizer_config = {
    'optimizer_id': 'Adam',
    'parameters': {'lr': 0.001, 'weight_decay': 0.0005},
}

# Convert to OmegaConf
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)

print('Configs created')
print(f"Loader config: {loader_config}")
print(f"Input channels: {in_channels}, Output channels: {out_channels}")

Selected task: triangle_classification
Description: Triangle classification (predict topological role, 9 classes)
Configs created
Loader config: {'data_domain': 'graph', 'data_type': 'A123', 'data_name': 'a123_cortex_m', 'data_dir': './data/a123/', 'corr_threshold': 0.3, 'specific_task': 'triangle_classification'}
Input channels: 3, Output channels: 9
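
The `corr_threshold` setting decides which pairs of neurons are joined by an edge. As a rough sketch of the idea (not the loader's actual code), the example below keeps an edge whenever the correlation exceeds the threshold. The strict `>` rule, the helper name `threshold_edges`, and the toy matrix are all assumptions for illustration.

```python
def threshold_edges(corr, threshold):
    """Return undirected edges (i, j), i < j, whose correlation exceeds threshold."""
    n = len(corr)
    return [
        (i, j)
        for i in range(n)
        for j in range(i + 1, n)
        if corr[i][j] > threshold
    ]


# Toy symmetric correlation matrix for four neurons (made-up values).
corr = [
    [1.0, 0.5, 0.2, 0.4],
    [0.5, 1.0, 0.35, 0.1],
    [0.2, 0.35, 1.0, 0.05],
    [0.4, 0.1, 0.05, 1.0],
]

edges_low = threshold_edges(corr, 0.1)
edges_high = threshold_edges(corr, 0.3)
print(len(edges_low), len(edges_high))  # 4 3
```

Raising the threshold removes weakly correlated pairs, which is why a higher value yields sparser graphs.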


In [4]:
# 3) Loading the data

# Use the A123-specific loader (A123DatasetLoader) to construct the dataset
graph_loader = A123DatasetLoader(loader_config)

dataset, dataset_dir = graph_loader.load()
print(f'Dataset loaded: {len(dataset)} samples')

# For triangle-level tasks, skip lifting transformations (triangles have no edge_index)
# Only apply lifting for graph-level classification
task_type = TASK_CONFIG['specific_task']
if task_type in ['triangle_classification', 'triangle_common_neighbors']:
    # Skip lifting for triangle tasks - they don't have graph structure
    print(f"Task '{task_type}' uses triangle-level features (no edge_index)")
    print("Skipping lifting transformation for triangle data")
    preprocessor = PreProcessor(dataset, dataset_dir, transforms_config=None)
else:
    # Apply lifting for graph-level tasks
    preprocessor = PreProcessor(dataset, dataset_dir, transform_config)

dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
print('Dataset splits created:')
print(f'  Train: {len(dataset_train)} samples')
print(f'  Val: {len(dataset_val)} samples')
print(f'  Test: {len(dataset_test)} samples')

# create the TopoBench datamodule / dataloader wrappers
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

print('Datasets and datamodule ready')

Processing...


[A123] Processing dataset from: data/a123_cortex_m/raw
[A123] Files in raw_dir: ['Auditory cortex data', '__MACOSX']
[A123] Starting extract_samples()...
Processing session 0: allPlanesVariables27-Feb-2021.mat
Processing session 1: allPlanesVariables27-Feb-2021.mat
Processing session 2: allPlanesVariables27-Feb-2021.mat
Processing session 3: allPlanesVariables27-Feb-2021.mat
Processing session 4: allPlanesVariables27-Feb-2021.mat
Processing session 5: allPlanesVariables27-Feb-2021.mat
Processing session 6: allPlanesVariables27-Feb-2021.mat
Processing session 7: allPlanesVariables27-Feb-2021.mat
[A123]

Done!
Processing...


[A123 Loader] Loaded triangle classification task dataset
Dataset loaded: 335458 samples
Task 'triangle_classification' uses triangle-level features (no edge_index)
Skipping lifting transformation for triangle data


Done!


Dataset splits created:
  Train: 167729 samples
  Val: 83864 samples
  Test: 83865 samples
Datasets and datamodule ready
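
The split sizes follow from `train_prop: 0.5`. One arithmetic that reproduces the reported counts gives half of the samples to training and divides the remainder evenly between validation and test. This is a sketch of the arithmetic only; the rounding used by the actual splitting code is an assumption.

```python
n_total = 335458   # samples reported by the loader
train_prop = 0.5   # from split_config

n_train = int(train_prop * n_total)
n_rest = n_total - n_train
n_val = n_rest // 2          # assumed: remainder split evenly, rounding down
n_test = n_rest - n_val

print(n_train, n_val, n_test)  # 167729 83864 83865
```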


In [13]:
import numpy as np  # required by the undersampling helper below


def undersample_majority_class(dataset, target_samples_per_class=100, random_state=42):
    """
    Undersample all classes to a target number of samples per class.
    
    Parameters
    ----------
    dataset : DataloadDataset
        Dataset to undersample
    target_samples_per_class : int
        Target number of samples per class (default: 100)
    random_state : int
        Random seed for reproducibility
        
    Returns
    -------
    DataloadDataset
        Undersampled dataset
    """
    np.random.seed(random_state)
    
    # Handle DataloadDataset which returns (values, keys) tuples
    labels = []
    for item in dataset:
        # The dataset returns (values_list, keys_list)
        if isinstance(item, (list, tuple)) and len(item) == 2:
            values, keys = item
            # The 'y' label is the last value in the list
            y = values[-1]
        else:
            # Fallback: try to access .y attribute
            if hasattr(item, 'y'):
                y = item.y
            else:
                continue  # Skip if we can't extract label
        
        # Convert tensor to scalar
        if hasattr(y, 'item'):
            labels.append(int(y.item()))
        elif hasattr(y, '__len__') and len(y) == 1:
            # Single-element tensor or array
            labels.append(int(y[0]))
        else:
            labels.append(int(y))
    
    # Check if we extracted any labels
    if len(labels) == 0:
        raise ValueError(f"No labels extracted from dataset of size {len(dataset)}")
    
    labels = np.array(labels)
    unique_labels, counts = np.unique(labels, return_counts=True)
    
    print(f"Original class distribution:")
    for label, count in zip(unique_labels, counts):
        print(f"  Class {label}: {count} samples")
    
    # Get indices for each class
    indices_by_class = {label: np.where(labels == label)[0] for label in unique_labels}
    
    # Undersample each class to target_samples_per_class (or fewer if class has fewer samples)
    undersampled_indices = []
    for label in unique_labels:
        indices = indices_by_class[label]
        # Select up to target_samples_per_class indices from this class
        actual_samples = min(len(indices), target_samples_per_class)
        selected = np.random.choice(indices, size=actual_samples, replace=False)
        undersampled_indices.extend(selected)
    
    # Shuffle the final indices
    undersampled_indices = np.random.permutation(undersampled_indices)
    
    # Create subset of dataset
    from torch.utils.data import Subset
    undersampled_dataset = Subset(dataset, undersampled_indices)
    
    # Get new label distribution
    new_labels = labels[undersampled_indices]
    new_unique, new_counts = np.unique(new_labels, return_counts=True)
    
    print(f"\nAfter undersampling to {target_samples_per_class} per class:")
    
    for label, count in zip(new_unique, new_counts):
        print(f"  Class {label}: {count} samples")
    
    imbalance_ratio_before = counts.max() / counts.min()
    imbalance_ratio_after = new_counts.max() / new_counts.min()
    print(f"\nImbalance ratio: {imbalance_ratio_before:.2f} → {imbalance_ratio_after:.2f}")
    print(f"Dataset size: {len(dataset)} → {len(undersampled_dataset)} samples\n")
    
    return undersampled_dataset

# Apply undersampling to training set
print("Undersampling training set...")
dataset_train = undersample_majority_class(dataset_train, target_samples_per_class=100, random_state=0)

# Optionally also undersample validation set for consistency
print("Undersampling validation set...")
dataset_val = undersample_majority_class(dataset_val, target_samples_per_class=100, random_state=0)

# Optionally also undersample test set for consistency
print("Undersampling test set...")
dataset_test = undersample_majority_class(dataset_test, target_samples_per_class=100, random_state=0)

# Recreate datamodule with undersampled datasets
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

print('Datasets rebalanced and datamodule recreated')

Undersampling training set...
Original class distribution:
  Class 0: 66 samples
  Class 1: 66 samples
  Class 3: 66 samples
  Class 4: 66 samples
  Class 6: 66 samples
  Class 7: 66 samples

After undersampling to 100 per class:
  Class 0: 66 samples
  Class 1: 66 samples
  Class 3: 66 samples
  Class 4: 66 samples
  Class 6: 66 samples
  Class 7: 66 samples

Imbalance ratio: 1.00 → 1.00
Dataset size: 396 → 396 samples

Undersampling validation set...
Original class distribution:
  Class 0: 38 samples
  Class 1: 38 samples
  Class 3: 38 samples
  Class 4: 38 samples
  Class 6: 38 samples
  Class 7: 38 samples

After undersampling to 100 per class:
  Class 0: 38 samples
  Class 1: 38 samples
  Class 3: 38 samples
  Class 4: 38 samples
  Class 6: 38 samples
  Class 7: 38 samples

Imbalance ratio: 1.00 → 1.00
Dataset size: 228 → 228 samples

Undersampling test set...
Original class distribution:
  Class 0: 39 samples
  Class 1: 39 samples
  Class 3: 39 samples
  Class 4: 39 samples
  Cla

In [14]:
# Debug: Inspect the actual dataset structure
print("Inspecting dataset structure...")
sample = dataset_train[0]
print(f"Type of sample: {type(sample)}")
print(f"Sample: {sample}")

if isinstance(sample, (tuple, list)):
    print(f"\nFirst element type: {type(sample[0])}")
    print(f"First element: {sample[0]}")
    if len(sample) > 1:
        print(f"Second element: {sample[1]}")
        
    # Try to access y from first element
    first = sample[0]
    print(f"\nAttributes of first element: {dir(first)}")
    if hasattr(first, 'y'):
        print(f"  .y = {first.y}")
    if hasattr(first, '__dict__'):
        print(f"  __dict__ = {first.__dict__}")


Inspecting dataset structure...
Type of sample: <class 'tuple'>
Sample: ([tensor([89]), tensor([[0.5515, 0.7086, 0.6578]]), 'bridge_strong', tensor([1]), tensor([0]), tensor([0]), tensor([ 3, 13, 35]), tensor([3])], ['graph_idx', 'x', 'role', 'train_mask', 'val_mask', 'test_mask', 'nodes', 'y'])

First element type: <class 'list'>
First element: [tensor([89]), tensor([[0.5515, 0.7086, 0.6578]]), 'bridge_strong', tensor([1]), tensor([0]), tensor([0]), tensor([ 3, 13, 35]), tensor([3])]
Second element: ['graph_idx', 'x', 'role', 'train_mask', 'val_mask', 'test_mask', 'nodes', 'y']

Attributes of first element: ['__add__', '__class__', '__class_getitem__', '__contains__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__iadd__', '__imul__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', 
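
The inspected sample is a `(values, keys)` pair, so a field such as the label can be looked up by name rather than by position. A minimal sketch, using plain Python lists as stand-ins for the tensors printed above; the helper name `sample_to_dict` and the variable `toy_sample` are hypothetical.

```python
def sample_to_dict(sample):
    """Pair each key with its value for a (values, keys) sample."""
    values, keys = sample
    if len(values) != len(keys):
        raise ValueError("values and keys must have the same length")
    return dict(zip(keys, values))


# Stand-in for dataset_train[0]; tensors are replaced by plain lists.
toy_sample = (
    [[89], [[0.5515, 0.7086, 0.6578]], "bridge_strong", [1], [0], [0], [3, 13, 35], [3]],
    ["graph_idx", "x", "role", "train_mask", "val_mask", "test_mask", "nodes", "y"],
)

fields = sample_to_dict(toy_sample)
print(fields["role"], fields["y"][0])  # bridge_strong 3
```

Looking up `'y'` by key avoids relying on the label being the last entry of the values list.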

## 4) Backbone definition

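We implement a tiny backbone as a `pl.LightningModule`. It computes hyperedge features from node features as $X_1 = B_1 \cdot X_0$, where $B_1$ is the incidence matrix (transposed if its orientation requires it), and then passes $X_0$ and $X_1$ each through its own linear layer followed by ReLU.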
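
To make the shapes concrete, here is a small NumPy sketch of the hyperedge feature computation. It assumes an incidence matrix stored as `[n_nodes, n_hyperedges]`, in which case the product is taken with its transpose. The toy values are made up, and NumPy stands in for the sparse PyTorch product used in the backbone.

```python
import numpy as np

# Toy complex: 4 nodes, 2 hyperedges.
# Entry [v, e] is 1 when node v belongs to hyperedge e.
incidence = np.array(
    [
        [1, 0],
        [1, 1],
        [0, 1],
        [0, 1],
    ],
    dtype=float,
)

# Node features X_0 with 3 channels per node.
x_0 = np.arange(12, dtype=float).reshape(4, 3)

# With a [n_nodes, n_hyperedges] layout, hyperedge features are B^T X_0:
# each hyperedge receives the sum of its member nodes' features.
x_1 = incidence.T @ x_0

print(x_1.shape)  # (2, 3)
print(x_1)
```

The first hyperedge contains nodes 0 and 1, so its feature row is the sum of those two node rows.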

In [15]:
class MyBackbone(pl.LightningModule):
    def __init__(self, dim_hidden):
        super().__init__()
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        # batch.x_0: node features (dense tensor of shape [N, dim_hidden])
        # batch.incidence_hyperedges: sparse incidence matrix with shape [m, n] or [n, m] depending on preprocessor convention
        x_0 = batch.x_0
        incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)
        if incidence_hyperedges is None:
            # fallback: try incidence as batch.incidence if available
            incidence_hyperedges = getattr(batch, 'incidence', None)

        # compute hyperedge features X_1 = B_1 dot X_0 (we assume B_1 is sparse and transposed appropriately)
        x_1 = None
        if incidence_hyperedges is not None:
            try:
                x_1 = torch.sparse.mm(incidence_hyperedges, x_0)
            except Exception:
                # if orientation differs, try transpose
                x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0)
        else:
            # no incidence available: create a zero hyperedge feature placeholder
            x_1 = torch.zeros_like(x_0)

        x_0 = self.linear_0(x_0)
        x_0 = torch.relu(x_0)

        x_1 = self.linear_1(x_1)
        x_1 = torch.relu(x_1)

        model_out = {'labels': batch.y, 'batch_0': getattr(batch, 'batch_0', None)}
        model_out['x_0'] = x_0
        model_out['hyperedge'] = x_1
        return model_out

print('Backbone defined')

Backbone defined


In [16]:
# 5) Model initialization (components)
backbone = MyBackbone(dim_hidden)
readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)
evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

print('Components instantiated')

Components instantiated


In [17]:
# 6) Instantiate TBModel
model = TBModel(backbone=backbone,
                backbone_wrapper=None,
                readout=readout,
                loss=loss,
                feature_encoder=feature_encoder,
                evaluator=evaluator,
                optimizer=optimizer,
                compile=False)

# Print a short summary (repr) to verify construction
print(model)

TBModel(backbone=MyBackbone(
  (linear_0): Linear(in_features=16, out_features=16, bias=True)
  (linear_1): Linear(in_features=16, out_features=16, bias=True)
), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1)))


In [18]:
# 7) Training loop (Lightning trainer)
# Suppress some warnings for cleaner output
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

trainer = pl.Trainer(
    max_epochs=2,  # reduced for faster iteration
    accelerator='cpu',
    enable_progress_bar=True,
    log_every_n_steps=1,
    enable_model_summary=False,  # skip the model summary printout
)
trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

print('\nTraining finished. Collected metrics:')
for key, val in train_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except Exception:
        print(key, val)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.



Training finished. Collected metrics:
train/accuracy            0.1667
train/f1                  0.0476
train/precision           0.0278
train/recall              0.1667
val/loss                  2.1869
val/accuracy              0.1667
val/f1                    0.0476
val/precision             0.0278
val/recall                0.1667
train/loss                2.1941
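
As a point of reference for the loss values, a classifier that assigns equal probability to all 9 output classes has cross-entropy $\ln 9$. The short check below computes that baseline. Comparing it with the logged losses is an interpretation, not something the training run reports.

```python
import math

num_classes = 9  # out_channels in the configuration
uniform_loss = -math.log(1.0 / num_classes)

print(f"{uniform_loss:.4f}")  # 2.1972
```

The logged `train/loss` (2.1941) and `val/loss` (2.1869) sit just below this value.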


In [19]:
# 8) Testing and printing metrics
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics
print('\nTest metrics:')
for key, val in test_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except Exception:
        print(key, val)

/Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]


Test metrics:
test/loss                 2.1825
test/accuracy             0.1667
test/f1                   0.0476
test/precision            0.0278
test/recall               0.1667
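
The reported accuracy, precision, recall, and F1 match what macro-averaged metrics give when every sample is assigned to a single class of a balanced six-class set (six labels appear in the undersampling output). The sketch below checks that arithmetic in plain Python. It assumes macro averaging over the six classes present and does not call `TBEvaluator`; the helper name `macro_metrics` is hypothetical.

```python
def macro_metrics(y_true, y_pred, classes):
    """Accuracy plus macro-averaged precision, recall, and F1 (0 when undefined)."""
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1s.append(2 * precision * recall / denom if denom else 0.0)
        precisions.append(precision)
        recalls.append(recall)
    n = len(classes)
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return accuracy, sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


classes = [0, 1, 3, 4, 6, 7]  # labels present after undersampling
per_class = 39                # any balanced count gives the same ratios
y_true = [c for c in classes for _ in range(per_class)]
y_pred = [0] * len(y_true)    # constant prediction of a single class

accuracy, precision, recall, f1 = macro_metrics(y_true, y_pred, classes)
print(f"{accuracy:.4f} {precision:.4f} {recall:.4f} {f1:.4f}")
# 0.1667 0.0278 0.1667 0.0476
```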


## Running Other Tasks

To run a different task, set the `TASK_NAME` variable in the configuration cell (section 2, "Configurations for different tasks") to one of:
- `'graph_classification'`: Predict frequency bin from graph structure
- `'triangle_classification'` (selected in this notebook): Classify topological role of triangles (9 classes: 3 embedding × 3 weight)
- `'triangle_common_neighbors'`: Predict number of common neighbors for each triangle

Then re-run the configuration cell and all subsequent cells. The loader picks the matching task variant, and the model is configured with the correct number of output classes (9 for all tasks).

### Task Details:

**Task 1: Graph-level Classification**
- Input: Graph structure with node features (mean correlation, std correlation, noise diagonal)
- Output: Frequency bin (0-8) representing the best frequency
- Level: Graph-level prediction

**Task 2: Triangle Classification**
- Input: Topological features of triangles (3 edge weights from correlation matrix)
- Output: Triangle role classification (9 classes based on embedding × weight):
  - Embedding classes: Core (many common neighbors), Bridge (some), Isolated (few)
  - Weight classes: Strong (high correlation), Medium, Weak (low correlation)
- Level: Triangle (motif) level prediction
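
The nine roles can be viewed as a 3 × 3 grid of embedding and weight classes. The encoding below is an assumption, chosen because it is consistent with the one inspected sample in this notebook (`role='bridge_strong'`, `y=3`). The ordering of the classes and the helper name `role_to_class` are not taken from the dataset code.

```python
EMBEDDINGS = ["core", "bridge", "isolated"]  # assumed order
WEIGHTS = ["strong", "medium", "weak"]       # assumed order


def role_to_class(role):
    """Map a role string such as 'bridge_strong' to an index in 0..8."""
    embedding, weight = role.split("_")
    return len(WEIGHTS) * EMBEDDINGS.index(embedding) + WEIGHTS.index(weight)


print(role_to_class("bridge_strong"))  # 3
print(role_to_class("isolated_weak"))  # 8
```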

**Task 3: Triangle Common-Neighbors**
- Input: Triangle node degrees (structural features)
- Output: Number of common neighbors (0-8, mapping neighbors count to class)
- Level: Triangle (motif) level prediction
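
One way to read "common neighbors of a triangle" is the set of nodes adjacent to all three of its vertices. The sketch below uses that definition and caps the count at 8 to obtain a class in 0-8. Both the definition and the capping rule are assumptions, since the notebook does not spell out how counts map to classes; the helper names are hypothetical.

```python
def triangle_common_neighbors(adjacency, triangle):
    """Count nodes outside the triangle that are adjacent to all three vertices."""
    a, b, c = triangle
    shared = adjacency[a] & adjacency[b] & adjacency[c]
    return len(shared - {a, b, c})


def count_to_class(count, max_class=8):
    """Cap the neighbor count so it fits in classes 0..max_class (assumed rule)."""
    return min(count, max_class)


# Toy graph as adjacency sets: triangle (0, 1, 2) plus nodes 3 and 4.
adjacency = {
    0: {1, 2, 3, 4},
    1: {0, 2, 3},
    2: {0, 1, 3},
    3: {0, 1, 2},
    4: {0},
}

n_common = triangle_common_neighbors(adjacency, (0, 1, 2))
print(n_common, count_to_class(n_common))  # 1 1
```

Node 3 is adjacent to all three triangle vertices, while node 4 touches only node 0, so the count is 1.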