# Training TBModel on Auditory Cortex Data for 1 and 2/3 regions.

This notebook demonstrates loading the A123 auditory cortex dataset, applying a simple hypergraph lifting, defining a small backbone, and training a `TBModel` using `TBLoss` and `TBOptimizer`.

Requirements: the project must be importable (installed or on `PYTHONPATH`). Optional dependencies (torch_geometric, networkx, ripser/persim) are needed only for advanced features.

In [4]:
import os
os.chdir('..')

In [5]:
# 1) Imports
import torch
import lightning as pl
from omegaconf import OmegaConf

# Data loading / preprocessing utilities from the repo
from topobench.data.loaders.graph.a123_loader import A123DatasetLoader
from topobench.dataloader.dataloader import TBDataloader
from topobench.data.preprocessor import PreProcessor

# Model / training building blocks
from topobench.model.model import TBModel
# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)
# from topomodelx.nn.simplicial.scn2 import SCN2
from topobench.nn.wrappers.simplicial import SCNWrapper
from topobench.nn.encoders import AllCellFeatureEncoder
from topobench.nn.readouts import PropagateSignalDown

# Optimization / evaluation
from topobench.loss.loss import TBLoss
from topobench.optimizer import TBOptimizer
from topobench.evaluator.evaluator import TBEvaluator

print('Imports OK')

Imports OK


In [6]:
# 2) Configurations and utilities
loader_config = {
    'data_domain': 'graph',
    'data_type': 'A123',
    # the loader/dataset expects the dataset name key used in the dataset class
    'data_name': 'a123_cortex_m',
    'data_dir': './data/a123/'
}

# Transform config: single transform with transform_name and transform_type
# PreProcessor expects either {"transform_name": ...} (single) or {"key1": {...}, "key2": {...}} (multiple)
transform_config = {
    'transform_type': 'lifting',
    'transform_name': 'HypergraphKHopLifting',
    'k_value': 1,
}

split_config = {
    'learning_setting': 'inductive',
    'split_type': 'random',
    'data_seed': 0,
    'data_split_dir': './data/a123/splits/',
    'train_prop': 0.5,
}

# model / task hyperparameters
# A123 sample node features are: [mean_corr, std_corr, noise_diag] => 3 channels
in_channels = 3
# Multiclass classification: 9 frequency bins (bf_bin 0-8)
out_channels = 9
dim_hidden = 16
n_bins = 9  # default binning from extract_samples
hodge_k = 10  # Number of Hodge L1 eigenvalues to use (from dataset config)

readout_config = {
    'readout_name': 'PropagateSignalDown',
    'num_cell_dimensions': 1,
    'hidden_dim': dim_hidden,
    'out_channels': out_channels,
    'task_level': 'graph',
    'pooling_type': 'sum',
}

loss_config = {
    'dataset_loss': {
        'task': 'classification',
        'loss_type': 'cross_entropy',
    }
}

evaluator_config = {
    'task': 'classification',
    'num_classes': out_channels,
    'metrics': ['accuracy', 'precision', 'recall'],
}

optimizer_config = {
    'optimizer_id': 'Adam',
    'parameters': {'lr': 0.001, 'weight_decay': 0.0005},
}

# convert to OmegaConf (the project often expects DictConfig)
loader_config = OmegaConf.create(loader_config)
transform_config = OmegaConf.create(transform_config)
split_config = OmegaConf.create(split_config)
readout_config = OmegaConf.create(readout_config)
loss_config = OmegaConf.create(loss_config)
evaluator_config = OmegaConf.create(evaluator_config)
optimizer_config = OmegaConf.create(optimizer_config)

print('Configs created')

Configs created


In [7]:
# 3) Loading the data

# Use the A123-specific loader (A123DatasetLoader) to construct the dataset
graph_loader = A123DatasetLoader(loader_config)

dataset, dataset_dir = graph_loader.load()
print('Dataset loaded')

preprocessor = PreProcessor(dataset, dataset_dir, transform_config)
dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)
print('Dataset splits created')

# create the TopoBench datamodule / dataloader wrappers
datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)

print('Datasets and datamodule ready')

Processing...


[A123] Processing dataset from: data/a123_cortex_m/raw
[A123] Files in raw_dir: ['Auditory cortex data', '__MACOSX']
[A123] Starting extract_samples()...
Processing session 0: allPlanesVariables27-Feb-2021.mat
Processing session 1: allPlanesVariables27-Feb-2021.mat
Processing session 1: allPlanesVariables27-Feb-2021.mat
Processing session 2: allPlanesVariables27-Feb-2021.mat
Processing session 2: allPlanesVariables27-Feb-2021.mat
Processing session 3: allPlanesVariables27-Feb-2021.mat
Processing session 3: allPlanesVariables27-Feb-2021.mat
Processing session 4: allPlanesVariables27-Feb-2021.mat
Processing session 4: allPlanesVariables27-Feb-2021.mat
Processing session 5: allPlanesVariables27-Feb-2021.mat
Processing session 5: allPlanesVariables27-Feb-2021.mat
Processing session 6: allPlanesVariables27-Feb-2021.mat
Processing session 6: allPlanesVariables27-Feb-2021.mat
Processing session 7: allPlanesVariables27-Feb-2021.mat
Processing session 7: allPlanesVariables27-Feb-2021.mat
[A123]

Done!
Processing...


Dataset splits created
Datasets and datamodule ready


Done!


## 4) Backbone definition

We implement a tiny backbone as a `pl.LightningModule` which computes hyperedge features from node features, $X_1 = B_1 \cdot X_0$, and applies one linear layer with ReLU to each of $X_0$ and $X_1$.
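
As a minimal sketch of the hyperedge computation described above, the snippet below builds a toy sparse incidence matrix and applies `torch.sparse.mm`. The toy sizes, the name `B_1`, and the `[num_hyperedges, num_nodes]` orientation are assumptions for illustration; the preprocessor may store the transpose.

```python
import torch

# Toy incidence matrix B_1 with shape [num_hyperedges, num_nodes] (assumed orientation).
# Hyperedge 0 contains nodes 0 and 1; hyperedge 1 contains nodes 1 and 2.
indices = torch.tensor([[0, 0, 1, 1],   # hyperedge index of each nonzero entry
                        [0, 1, 1, 2]])  # node index of each nonzero entry
values = torch.ones(4)
B_1 = torch.sparse_coo_tensor(indices, values, size=(2, 3))

# Node features X_0 with shape [num_nodes, num_features].
X_0 = torch.tensor([[1.0, 0.0],
                    [0.0, 2.0],
                    [3.0, 1.0]])

# X_1 = B_1 X_0: each hyperedge feature is the sum of its member nodes' features.
X_1 = torch.sparse.mm(B_1, X_0)
print(X_1)
```

With an unnormalized incidence matrix like this one, hyperedges with more members produce larger feature magnitudes, which is worth keeping in mind when choosing a pooling or normalization scheme.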

### Using Hodge L1 Topological Features

The A123 dataset computes **Hodge L1 eigenvalues** for each graph sample during processing (controlled by the `hodge_k` parameter in config). These topological features are attached to each `Data` object as `data.hodge_l1` and are automatically batched into `batch.hodge_l1` during training.

To use Hodge L1 features in your model:
1. Access `batch.hodge_l1` in your backbone or readout forward pass. Each graph contributes a vector of length `hodge_k`; after batching, the tensor may arrive flattened with shape `[batch_size * hodge_k]` (as in the debugging output later in this notebook), so reshape it to `[batch_size, hodge_k]` first
2. Process it as additional graph-level information (e.g., concatenate with pooled node features or pass to a separate MLP)
3. Fuse with the main graph embedding before the final classifier

Example usage in a custom backbone or readout:
```python
if hasattr(batch, 'hodge_l1') and batch.hodge_l1 is not None:
    # Restore one row per graph: [batch_size, hodge_k]
    hodge_features = batch.hodge_l1.reshape(batch.num_graphs, -1)
    # Process and fuse with other features as needed
```

See `configs/dataset/graph/a123.yaml` for the `hodge_k` parameter configuration.

In [None]:
# Quick check: verify that hodge_l1 features are present in a batch

batch = next(iter(datamodule.train_dataloader()))
print("Batch keys:", dir(batch))
if hasattr(batch, 'hodge_l1'):
    # Per-graph vectors may be flattened by batching to [batch_size * hodge_k]
    print(f"hodge_l1 shape: {batch.hodge_l1.shape}")
    print(f"hodge_l1 sample:\n{batch.hodge_l1[0]}")
else:
    print("hodge_l1 not found in batch (may need to reprocess dataset)")


Batch keys: ['__abstractmethods__', '__annotations__', '__call__', '__cat_dim__', '__class__', '__contains__', '__copy__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__inc__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_edge_attr_cls', '_edge_to_layout', '_edges_to_layout', '_find_parent', '_get_edge_index', '_get_tensor', '_get_tensor_size', '_multi_get_tensor', '_put_edge_index', '_put_tensor', '_remove_edge_index', '_remove_tensor', '_store', '_tensor_attr_cls', '_to_type', '_union', 'apply', 'apply_', 'batch', 'batch_size', 'clone', 'coalesce', 'concat', 'connected_components', 'contai

In [None]:
# class MyBackbone(pl.LightningModule):
#     def __init__(self, dim_hidden):
#         super().__init__()
#         self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
#         self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

#     def forward(self, batch):
#         # batch.x_0: node features (dense tensor of shape [N, dim_hidden])
#         # batch.incidence_hyperedges: sparse incidence matrix with shape [m, n] or [n, m] depending on preprocessor convention
#         x_0 = batch.x_0
#         incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)
#         if incidence_hyperedges is None:
#             # fallback: try incidence as batch.incidence if available
#             incidence_hyperedges = getattr(batch, 'incidence', None)

#         # compute hyperedge features X_1 = B_1 dot X_0 (we assume B_1 is sparse and transposed appropriately)
#         x_1 = None
#         if incidence_hyperedges is not None:
#             try:
#                 x_1 = torch.sparse.mm(incidence_hyperedges, x_0)
#             except Exception:
#                 # if orientation differs, try transpose
#                 x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0)
#         else:
#             # no incidence available: create a zero hyperedge feature placeholder
#             x_1 = torch.zeros_like(x_0)

#         x_0 = self.linear_0(x_0)
#         x_0 = torch.relu(x_0)

#         x_1 = self.linear_1(x_1)
#         x_1 = torch.relu(x_1)

#         model_out = {'labels': batch.y, 'batch_0': getattr(batch, 'batch_0', None)}
#         model_out['x_0'] = x_0
#         model_out['hyperedge'] = x_1
#         return model_out

# print('Backbone defined')

Backbone defined


In [42]:
# IMPROVED: Properly fuse Hodge features with node embeddings
class MyBackbone(pl.LightningModule):
    """Backbone that properly integrates Hodge L1 features into node embeddings."""
    
    def __init__(self, dim_hidden, hodge_k=10, use_hodge=True):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.config_hodge_k = hodge_k  # Store config value
        self.use_hodge = use_hodge
        self.actual_hodge_k = None  # Will be validated on first forward pass
        self.layers_initialized = False
        
        # Create dummy linear layers - will be recreated with correct dimensions on first forward pass
        self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def _initialize_layers(self, x_0_shape, hodge_k):
        """Initialize linear layers with correct dimensions based on data."""
        if self.layers_initialized:
            return
        
        in_features_0 = x_0_shape[1] + (hodge_k if self.use_hodge else 0)
        
        # Recreate linear layers with correct input dimensions on the device of
        # the placeholder layers (x_0_shape is a torch.Size and carries no device).
        # NOTE: parameters created here are not tracked by an optimizer that was
        # built before this first forward pass.
        device = self.linear_0.weight.device
        self.linear_0 = torch.nn.Linear(in_features_0, self.dim_hidden).to(device)
        self.linear_1 = torch.nn.Linear(self.dim_hidden, self.dim_hidden).to(device)
        
        self.layers_initialized = True
        print(f"[Backbone] Initialized layers: linear_0({in_features_0}→{self.dim_hidden}), linear_1({self.dim_hidden}→{self.dim_hidden})")

    def forward(self, batch):
        x_0 = batch.x_0  # Shape: [num_nodes, dim_hidden]
        batch_0 = getattr(batch, 'batch_0', None)
        
        # Store original x_0 for hyperedge computation
        x_0_original = x_0
        x_0_augmented = x_0
        
        # Augment with Hodge features if enabled
        if self.use_hodge and hasattr(batch, 'hodge_l1') and batch.hodge_l1 is not None:
            hodge_l1 = batch.hodge_l1
            batch_size = batch.num_graphs if hasattr(batch, 'num_graphs') else batch.y.shape[0]
            
            # Validate hodge_k on first pass (strict mode: the value found in the data must match the config)
            if self.actual_hodge_k is None:
                if hodge_l1.dim() == 1:
                    inferred_hodge_k = hodge_l1.shape[0] // batch_size
                else:
                    inferred_hodge_k = hodge_l1.shape[1]
                
                # Check if it matches config - raise error if mismatch
                if inferred_hodge_k != self.config_hodge_k:
                    raise ValueError(
                        f"\n{'='*70}\n"
                        f"HODGE FEATURES CONFIGURATION MISMATCH\n"
                        f"{'='*70}\n"
                        f"Config hodge_k:  {self.config_hodge_k}\n"
                        f"Actual hodge_k:  {inferred_hodge_k}\n"
                        f"{'='*70}\n"
                        f"\nPossible causes:\n"
                        f"  1. Dataset was preprocessed with hodge_k={inferred_hodge_k}\n"
                        f"  2. Config parameter doesn't match the data\n"
                        f"\nSolutions:\n"
                        f"  A) Update config: hodge_k={inferred_hodge_k}\n"
                        f"  B) Disable Hodge features: use_hodge=False\n"
                        f"  C) Reprocess dataset with hodge_k={self.config_hodge_k}\n"
                        f"{'='*70}\n"
                    )
                
                self.actual_hodge_k = inferred_hodge_k
                # Initialize layers if this is first forward pass
                if not self.layers_initialized:
                    self._initialize_layers(x_0.shape, self.actual_hodge_k)
                    print(f"[Backbone] Hodge features validated: hodge_k={self.actual_hodge_k} ✓")
            
            # Reshape and broadcast Hodge features
            if hodge_l1.dim() == 1:
                hodge_l1 = hodge_l1.reshape(batch_size, self.actual_hodge_k)
            
            hodge_expanded = hodge_l1[batch_0]  # [num_nodes, actual_hodge_k]
            x_0_augmented = torch.cat([x_0, hodge_expanded], dim=1)
        else:
            # Initialize layers even without Hodge features
            if not self.layers_initialized:
                self._initialize_layers(x_0.shape, 0)
        
        # Get incidence matrix
        incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)
        if incidence_hyperedges is None:
            incidence_hyperedges = getattr(batch, 'incidence', None)

        # Compute hyperedge features from ORIGINAL x_0
        if incidence_hyperedges is not None:
            try:
                x_1 = torch.sparse.mm(incidence_hyperedges, x_0_original)
            except Exception:
                x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0_original)
        else:
            x_1 = torch.zeros(x_0_original.shape[0], self.dim_hidden, device=x_0_original.device)

        # Apply transformations
        x_0_out = self.linear_0(x_0_augmented)
        x_0_out = torch.relu(x_0_out)

        x_1 = self.linear_1(x_1)
        x_1 = torch.relu(x_1)

        model_out = {
            'labels': batch.y,
            'batch_0': batch_0,
            'x_0': x_0_out,
            'hyperedge': x_1,
        }
        return model_out

print('Improved Backbone with Hodge Feature Fusion defined (strict validation mode)')


Improved Backbone with Hodge Feature Fusion defined (strict validation mode)


In [None]:
# 5) Model initialization (components)
# Enable Hodge features - backbone will validate hodge_k on first forward pass
# If hodge_k mismatch detected, will raise ValueError with helpful suggestions
backbone = MyBackbone(dim_hidden, hodge_k=hodge_k, use_hodge=True)
readout = PropagateSignalDown(**readout_config)
loss = TBLoss(**loss_config)
feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)
evaluator = TBEvaluator(**evaluator_config)
optimizer = TBOptimizer(**optimizer_config)

print('Components instantiated (Hodge features ENABLED with strict validation)')


Components instantiated (Hodge features disabled for now)


In [39]:
# 6) Instantiate TBModel
model = TBModel(backbone=backbone,
                backbone_wrapper=None,
                readout=readout,
                loss=loss,
                feature_encoder=feature_encoder,
                evaluator=evaluator,
                optimizer=optimizer,
                compile=False)

# Print a short summary (repr) to verify construction
print(model)

TBModel(backbone=MyBackbone(
  (linear_0): Linear(in_features=26, out_features=16, bias=True)
  (linear_1): Linear(in_features=16, out_features=16, bias=True)
), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1)))


In [41]:
# 7) Training loop (Lightning trainer)
# Suppress some warnings for cleaner output
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics')

trainer = pl.Trainer(
    max_epochs=50,  # reduced for faster iteration
    accelerator='cpu',
    enable_progress_bar=True,
    log_every_n_steps=1,
    enable_model_summary=False,  # skip the model summary printout
)
trainer.fit(model, datamodule)
train_metrics = trainer.callback_metrics

print('\nTraining finished. Collected metrics:')
for key, val in train_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except Exception:
        print(key, val)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
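
A likely cause of this error, offered as a hypothesis rather than a confirmed diagnosis: `MyBackbone` recreates `linear_0` and `linear_1` inside its first forward pass, and with Lightning that first pass is the validation sanity check, which by default runs under `torch.inference_mode()`. Parameters allocated there are inference tensors and cannot take part in autograd later. The sketch below reproduces the mechanism with plain PyTorch; the layer sizes are arbitrary.

```python
import torch

# Layers created inside inference mode hold inference tensors as parameters.
with torch.inference_mode():
    lazy_layer = torch.nn.Linear(4, 2)

print(lazy_layer.weight.is_inference())

# Using such a layer in a training-style forward/backward pass fails.
x = torch.randn(3, 4, requires_grad=True)
try:
    lazy_layer(x).sum().backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)

# Layers created outside inference mode (for example in __init__) train normally.
eager_layer = torch.nn.Linear(4, 2)
eager_layer(x).sum().backward()
print(eager_layer.weight.grad.shape)
```

If this is the cause, building the layers with their final sizes in `__init__` (as `MyBackboneWithHodgeFusion` later in this notebook does) avoids the problem. Passing `inference_mode=False` to `pl.Trainer` is another option.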

In [40]:
# 8) Testing and printing metrics
trainer.test(model, datamodule)
test_metrics = trainer.callback_metrics
print('\nTest metrics:')
for key, val in test_metrics.items():
    try:
        print(f'{key:25s} {float(val):.4f}')
    except Exception:
        print(key, val)

/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[Backbone] Inferred hodge_k=6 (config had 10)
[Backbone] Recreated linear_0 with in_features=22



Test metrics:
test/loss                 36.3307
test/accuracy             0.0476
test/precision            0.0108
test/recall               0.0667


### Without Hodge features (100 epochs)

| Test metric | DataLoader 0 |
|-------------|--------------|
| test/accuracy | 0.2222222238779068 |
| test/loss | 2.0435843467712402 |
| test/precision | 0.05868902429938316 |
| test/recall | 0.16466346383094788 |

In [32]:

# DEBUG: Check backbone shapes before training
print("="*60)
print("DEBUGGING BACKBONE SHAPES")
print("="*60)

batch = next(iter(datamodule.val_dataloader()))
print(f"\nBatch info:")
print(f"  batch.x_0 shape: {batch.x_0.shape}")
print(f"  batch.y shape: {batch.y.shape}")
print(f"  batch.batch_0 shape: {batch.batch_0.shape}")
print(f"  num_graphs: {batch.num_graphs}")
if hasattr(batch, 'hodge_l1'):
    print(f"  batch.hodge_l1 shape: {batch.hodge_l1.shape}")
    print(f"  batch.hodge_l1 type: {type(batch.hodge_l1)}")
if hasattr(batch, 'incidence_hyperedges'):
    print(f"  batch.incidence_hyperedges shape: {batch.incidence_hyperedges.shape}")
elif hasattr(batch, 'incidence'):
    print(f"  batch.incidence shape: {batch.incidence.shape}")

print(f"\nBackbone __init__ parameters:")
print(f"  dim_hidden: {dim_hidden}")
print(f"  hodge_k: {hodge_k}")
print(f"  use_hodge: {backbone.use_hodge}")
print(f"  backbone.linear_0.in_features: {backbone.linear_0.in_features}")
print(f"  backbone.linear_0.out_features: {backbone.linear_0.out_features}")
print(f"  backbone.linear_1.in_features: {backbone.linear_1.in_features}")
print(f"  backbone.linear_1.out_features: {backbone.linear_1.out_features}")

# First, run feature encoder
print(f"\n--- Feature Encoder Stage ---")
model_out_encoded = feature_encoder(batch)
print(f"After feature encoder:")
print(f"  model_out_encoded['x_0'] shape: {model_out_encoded.get('x_0', batch.x_0).shape}")
print(f"  model_out_encoded keys: {list(model_out_encoded.keys())}")

# Test backbone alone on encoded features
print(f"\n--- Backbone Stage ---")
try:
    model_out_backbone = backbone(model_out_encoded)
    print(f"✓ Backbone forward pass successful!")
    print(f"  model_out['x_0'] shape: {model_out_backbone['x_0'].shape}")
    print(f"  model_out['hyperedge'] shape: {model_out_backbone['hyperedge'].shape}")
except Exception as e:
    print(f"✗ Backbone forward pass failed:")
    print(f"  Error: {e}")
    import traceback
    traceback.print_exc()


DEBUGGING BACKBONE SHAPES

Batch info:
  batch.x_0 shape: torch.Size([778, 3])
  batch.y shape: torch.Size([32])
  batch.batch_0 shape: torch.Size([778])
  num_graphs: 32
  batch.hodge_l1 shape: torch.Size([192])
  batch.hodge_l1 type: <class 'torch.Tensor'>
  batch.incidence_hyperedges shape: torch.Size([778, 778])

Backbone __init__ parameters:
  dim_hidden: 16
  hodge_k: 10
  use_hodge: False
  backbone.linear_0.in_features: 16
  backbone.linear_0.out_features: 16
  backbone.linear_1.in_features: 16
  backbone.linear_1.out_features: 16

--- Feature Encoder Stage ---
After feature encoder:
  model_out_encoded['x_0'] shape: torch.Size([778, 16])
  model_out_encoded keys: ['test_mask', 'edge_index', 'session_id', 'val_mask', 'batch_0', 'ptr', 'y', 'batch_hyperedges', 'train_mask', 'num_hyperedges', 'edge_attr', 'layer', 'hodge_l1', 'x_0', 'x', 'x_hyperedges', 'incidence_hyperedges']

--- Backbone Stage ---
✓ Backbone forward pass successful!
  model_out['x_0'] shape: torch.Size([778, 1

## Why Hodge L1 Features Lead to Overfitting

### Problems with the Current Implementation:

1. **Disconnected Features**: The `hodge_emb` is computed but never used in the readout or loss
   - It adds learnable parameters that aren't tied to the prediction task
   - Model can overfit to training noise without helping generalization

2. **Improper Integration**: Features aren't fused with the main prediction pathway
   - The readout only uses `x_0`, `x_1`, etc. cell features
   - Hodge features are isolated and ignored during inference

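3. **Batching Issues**: Per-graph Hodge vectors are concatenated along dimension 0 when graphs are batched
   - The result is a flattened 1D tensor of shape `[batch_size * hodge_k]` instead of `[batch_size, hodge_k]`
   - Using it without reshaping and indexing by `batch_0` pairs nodes with the wrong graph's features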

### Solutions:

#### Option 1: Properly Fuse Hodge Features into Node Embeddings (Recommended)
Instead of computing separate `hodge_emb`, **augment the initial node features** with Hodge information:

```python
class MyBackbone(pl.LightningModule):
    def __init__(self, dim_hidden, hodge_k=10):
        super().__init__()
        self.linear_0 = torch.nn.Linear(dim_hidden + hodge_k, dim_hidden)  # Include hodge_k
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)
        self.hodge_k = hodge_k

    def forward(self, batch):
        x_0 = batch.x_0
        batch_size = batch.num_graphs if hasattr(batch, 'num_graphs') else batch.y.shape[0]
        
        # Expand hodge features to match node dimension
        hodge_l1 = None
        if hasattr(batch, 'hodge_l1') and batch.hodge_l1 is not None:
            # Reshape from flattened to [batch_size, hodge_k]
            hodge_l1 = batch.hodge_l1.reshape(batch_size, -1)
            
            # Broadcast to all nodes in each graph
            batch_0 = batch.batch_0
            hodge_expanded = hodge_l1[batch_0]  # [num_nodes, hodge_k]
            
            # Concatenate with node features
            x_0 = torch.cat([x_0, hodge_expanded], dim=1)
        
        # Now x_0 has shape [num_nodes, dim_hidden + hodge_k]
        x_0 = self.linear_0(x_0)
        x_0 = torch.relu(x_0)
        # ... rest of backbone
```

#### Option 2: Use Hodge Features as Graph-Level Auxiliary Information
Fuse Hodge features after graph-level pooling in the readout or create a custom readout.
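
A minimal sketch of Option 2, assuming sum pooling over nodes and a hypothetical `HodgeFusionHead` module that is not part of TopoBench. It pools node embeddings per graph, concatenates each graph's Hodge vector, and classifies the result.

```python
import torch

class HodgeFusionHead(torch.nn.Module):
    """Hypothetical graph-level head: pool nodes, append Hodge features, classify."""

    def __init__(self, dim_hidden, hodge_k, out_channels):
        super().__init__()
        self.hodge_k = hodge_k
        self.classifier = torch.nn.Linear(dim_hidden + hodge_k, out_channels)

    def forward(self, x_0, batch_0, hodge_l1, num_graphs):
        # Sum-pool node embeddings into one vector per graph: [num_graphs, dim_hidden].
        pooled = torch.zeros(num_graphs, x_0.shape[1], dtype=x_0.dtype, device=x_0.device)
        pooled = pooled.index_add(0, batch_0, x_0)
        # Restore [num_graphs, hodge_k] if batching flattened the per-graph vectors.
        hodge = hodge_l1.reshape(num_graphs, self.hodge_k).to(x_0.dtype)
        # Concatenate pooled node features with graph-level Hodge features.
        return self.classifier(torch.cat([pooled, hodge], dim=1))

# Toy usage: 2 graphs with 3 and 2 nodes, dim_hidden=4, hodge_k=3, 9 classes.
head = HodgeFusionHead(dim_hidden=4, hodge_k=3, out_channels=9)
x_0 = torch.randn(5, 4)
batch_0 = torch.tensor([0, 0, 0, 1, 1])
hodge_l1 = torch.arange(6, dtype=torch.float32)  # flattened [2 * 3]
logits = head(x_0, batch_0, hodge_l1, num_graphs=2)
print(logits.shape)
```

Fusing after pooling keeps the Hodge vector out of the node-level layers, so the backbone's input size does not depend on `hodge_k`.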

#### Option 3: Skip Hodge Features for Now
If integration is complex, simply don't use them. In the run without Hodge features reported above, the base model reached a test accuracy of 0.2222 over 9 classes.

In [None]:

# IMPROVED: Properly fuse Hodge features with node embeddings
class MyBackboneWithHodgeFusion(pl.LightningModule):
    """Backbone that properly integrates Hodge L1 features into node embeddings."""
    
    def __init__(self, dim_hidden, hodge_k=10, use_hodge=True):
        super().__init__()
        self.hodge_k = hodge_k
        self.use_hodge = use_hodge
        
        # Input dimension depends on whether we use hodge features
        in_dim = dim_hidden + (hodge_k if use_hodge else 0)
        
        self.linear_0 = torch.nn.Linear(in_dim, dim_hidden)
        self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)

    def forward(self, batch):
        x_0 = batch.x_0  # Shape: [num_nodes, dim_hidden]
        batch_0 = getattr(batch, 'batch_0', None)  # Maps nodes to graphs
        x_0_augmented = x_0

        # Optionally augment node features with graph-level Hodge features
        if self.use_hodge:
            hodge_l1 = getattr(batch, 'hodge_l1', None)
            if hodge_l1 is None or batch_0 is None:
                raise ValueError("use_hodge=True requires batch.hodge_l1 and batch.batch_0")
            batch_size = batch.num_graphs if hasattr(batch, 'num_graphs') else batch.y.shape[0]

            # Reshape hodge_l1 from [batch_size * hodge_k] to [batch_size, hodge_k]
            if hodge_l1.dim() == 1:
                hodge_l1 = hodge_l1.reshape(batch_size, -1)
            if hodge_l1.shape[1] != self.hodge_k:
                raise ValueError(
                    f"Data has hodge_k={hodge_l1.shape[1]} but backbone was built with hodge_k={self.hodge_k}"
                )

            # Broadcast graph-level features to each node in the graph
            # batch_0[i] gives the graph index for node i
            hodge_expanded = hodge_l1[batch_0].to(x_0.dtype)  # [num_nodes, hodge_k]

            # Concatenate: [num_nodes, dim_hidden + hodge_k]
            x_0_augmented = torch.cat([x_0, hodge_expanded], dim=1)

        # Process through network
        incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)
        if incidence_hyperedges is None:
            incidence_hyperedges = getattr(batch, 'incidence', None)

        # Compute hyperedge features from the un-augmented node features so that
        # x_1 has dim_hidden columns, matching linear_1
        if incidence_hyperedges is not None:
            try:
                x_1 = torch.sparse.mm(incidence_hyperedges, x_0)
            except Exception:
                x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0)
        else:
            x_1 = torch.zeros_like(x_0)

        # Apply linear transformations
        x_0_out = torch.relu(self.linear_0(x_0_augmented))
        x_1 = torch.relu(self.linear_1(x_1))

        model_out = {
            'labels': batch.y,
            'batch_0': batch_0,
            'x_0': x_0_out,
            'hyperedge': x_1,
        }
        return model_out

print('Improved Backbone with Hodge Feature Fusion defined')


In [None]:

# Experiment: Compare models with and without Hodge feature integration
import pandas as pd

results = []

for use_hodge in [False, True]:
    model_name = f"Model {'WITH' if use_hodge else 'WITHOUT'} Hodge Features"
    print(f"\n{'='*60}")
    print(f"Training: {model_name}")
    print(f"{'='*60}\n")
    
    # Create fresh backbone for this experiment
    backbone = MyBackboneWithHodgeFusion(dim_hidden, hodge_k=hodge_k, use_hodge=use_hodge)
    readout = PropagateSignalDown(**readout_config)
    loss = TBLoss(**loss_config)
    feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)
    evaluator = TBEvaluator(**evaluator_config)
    optimizer = TBOptimizer(**optimizer_config)
    
    model = TBModel(
        backbone=backbone,
        backbone_wrapper=None,
        readout=readout,
        loss=loss,
        feature_encoder=feature_encoder,
        evaluator=evaluator,
        optimizer=optimizer,
        compile=False
    )
    
    trainer = pl.Trainer(
        max_epochs=50,  # Reduced epochs for faster comparison
        accelerator='cpu',
        enable_progress_bar=True,
        log_every_n_steps=5,
        enable_model_summary=False,
    )
    
    trainer.fit(model, datamodule)
    
    # Get metrics (copy the dicts: trainer.callback_metrics may be updated in place by later runs)
    train_metrics = dict(trainer.callback_metrics)
    trainer.test(model, datamodule)
    test_metrics = dict(trainer.callback_metrics)
    
    results.append({
        'Model': model_name,
        'Train Loss': float(train_metrics.get('train_loss_epoch', 0)),
        'Test Accuracy': float(test_metrics.get('test/accuracy', 0)),
        'Test Loss': float(test_metrics.get('test/loss', 0)),
        'Test Precision': float(test_metrics.get('test/precision', 0)),
        'Test Recall': float(test_metrics.get('test/recall', 0)),
    })

# Display comparison
results_df = pd.DataFrame(results)
print("\n" + "="*80)
print("COMPARISON RESULTS")
print("="*80)
print(results_df.to_string(index=False))
print("\nKey Observation:")
print("If Hodge features are not properly integrated, the model WITH Hodge features")
print("may show WORSE test performance due to overfitting on irrelevant feature space.")


## Summary: Hodge L1 Features and Overfitting

### Root Causes of Poor Performance:

| Issue | Impact | Solution |
|-------|--------|----------|
| **Features Not Used** | `hodge_emb` computed but ignored by readout & loss | Fuse into node features or custom readout |
| **Unintegrated Params** | Extra learnable params → overfitting on noise | Connect to prediction pathway |
| **Batching Artifacts** | Graph features concatenated incorrectly | Reshape and broadcast to nodes properly |
| **Model Capacity** | More params without improved signal → overfit | Use regularization (dropout, L2) |

### Key Insights:

1. **Topological features alone aren't enough** - they must be properly integrated into the model's representation learning process

2. **Concatenating features ≠ Learning from them** - simply adding features without ensuring they're used in predictions doesn't help

3. **Graph-level features need broadcasting** - Hodge L1 is computed per graph but the backbone operates on node-level features, so each graph's vector must be mapped to its nodes via `batch_0`

### Recommendations:

✅ **DO:**
- Fuse Hodge features with node embeddings early (as shown in `MyBackboneWithHodgeFusion`)
- Validate that features actually affect predictions (check gradients)
- Use early stopping and validation curves to detect overfitting
- Consider dimensionality reduction of Hodge features if hodge_k is large

❌ **DON'T:**
- Compute features that aren't used downstream
- Add disconnected learnable layers
- Ignore the batching structure of PyTorch Geometric
- Assume more features = better performance
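
One way to carry out the gradient check recommended above, sketched on a toy layer rather than the full `TBModel`: after a backward pass, inspect the columns of the first linear layer's weight gradient that correspond to the Hodge inputs. The sizes and random tensors are placeholders.

```python
import torch

torch.manual_seed(0)
dim_hidden, hodge_k, num_nodes = 4, 3, 5

# Stand-in for the backbone's first layer, which takes [node features ; Hodge features].
linear_0 = torch.nn.Linear(dim_hidden + hodge_k, dim_hidden)

x_0 = torch.randn(num_nodes, dim_hidden)
hodge_expanded = torch.randn(num_nodes, hodge_k)  # per-node copy of graph-level features
out = linear_0(torch.cat([x_0, hodge_expanded], dim=1))
out.sum().backward()

# Columns from index dim_hidden onward of the weight gradient belong to the Hodge inputs.
hodge_grad_norm = linear_0.weight.grad[:, dim_hidden:].norm().item()
print(f"gradient norm on Hodge columns: {hodge_grad_norm:.4f}")
```

A norm that stays at zero across training steps would indicate that the Hodge inputs are not influencing the loss.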

### Testing Hypothesis:

Run the comparison experiment above. If results show:
- **WITH Hodge**: Better test accuracy → Good integration
- **WITH Hodge**: Worse test accuracy → Poor integration (as expected with current code)

## Technical Deep Dive: Why Hodge Features Break the Model

### The Problem in Your Current Code:

```python
# Current (BROKEN) approach:
hodge_emb = self.hodge_encoder(batch.hodge_l1)  # Computes embeddings
model_out['hodge_emb'] = hodge_emb               # Stores them
# ❌ But hodge_emb is NEVER used again!
```

### What Happens During Training:

1. **Forward pass**: The Hodge encoder produces `hodge_emb`, but nothing downstream consumes it
2. **Backward pass**: Gradients from the loss do not flow through `hodge_encoder`, because its output is not part of the loss computation
3. **Result**: The encoder's parameters receive no gradient and stay at their initial values
4. **Side effect**: The model has to fit everything through the other paths, which may contribute to overfitting
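
The backward-pass behaviour described above can be checked on a toy module. `DisconnectedBackbone` below is a hypothetical stand-in, not the notebook's backbone: it has a `hodge_encoder` whose output is computed and then dropped.

```python
import torch

class DisconnectedBackbone(torch.nn.Module):
    """Toy module whose hodge_encoder output never reaches the returned prediction."""

    def __init__(self):
        super().__init__()
        self.hodge_encoder = torch.nn.Linear(3, 4)
        self.linear_0 = torch.nn.Linear(4, 2)

    def forward(self, x_0, hodge_l1):
        hodge_emb = self.hodge_encoder(hodge_l1)  # computed, then dropped
        return self.linear_0(x_0)

model = DisconnectedBackbone()
loss = model(torch.randn(5, 4), torch.randn(2, 3)).sum()
loss.backward()

# linear_0 is on the path to the loss and receives a gradient.
print(model.linear_0.weight.grad is not None)
# hodge_encoder is not on the path to the loss, so its gradient is never populated.
print(model.hodge_encoder.weight.grad is None)
```

Checking for `grad is None` after one backward pass is a quick way to find layers that are disconnected from the loss.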

### The Batching Issue:

```
Individual samples:  hodge_l1[0] has shape [10]  (hodge_k=10)
                    hodge_l1[1] has shape [10]
                    hodge_l1[2] has shape [10]

After batching with from_data_list():
batch.hodge_l1 has shape [30]  ← Concatenated!

Your code tries:
batch_size = 3
hodge_l1.reshape(3, -1)  → [3, 10]  ✓ Correct!

But association with nodes is wrong:
Each graph has different numbers of nodes
batch_0[i] tells you which graph node i belongs to
You need to use batch_0 to properly broadcast!
```
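
The reshape-then-broadcast step can be tried on a toy batch. The values, `hodge_k = 3`, and the two-graph layout below are made up for illustration.

```python
import torch

hodge_k = 3
# Per-graph Hodge vectors for 2 graphs, flattened by batching to [2 * hodge_k].
flat_hodge = torch.tensor([0.1, 0.2, 0.3,   # graph 0
                           1.1, 1.2, 1.3])  # graph 1
batch_0 = torch.tensor([0, 0, 0, 1, 1])     # graph index of each of the 5 nodes
num_graphs = int(batch_0.max()) + 1

# Restore one row per graph, then index by batch_0 to give each node its graph's row.
hodge_l1 = flat_hodge.reshape(num_graphs, hodge_k)   # [2, 3]
hodge_expanded = hodge_l1[batch_0]                   # [5, 3]
print(hodge_expanded)
```

Indexing with `batch_0` handles graphs of different sizes without any padding, because each node simply looks up the row of its own graph.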

### The Correct Integration Pattern:

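```python
# ✅ CORRECT approach - fuse into node features:
batch_0 = batch.batch_0  # [num_nodes] - graph index for each node
hodge_l1 = batch.hodge_l1.reshape(batch.num_graphs, -1)  # [batch_size, hodge_k]
hodge_expanded = hodge_l1[batch_0]  # [num_nodes, hodge_k], one row per node
x_0 = torch.cat([batch.x_0, hodge_expanded], dim=1)

# Now gradients flow: loss → x_0 → weights of the layer that consumes the Hodge columns
```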

### Mathematical Perspective:

Let $x_{n,i}$ be the features of node $n$ in graph $i$, $x_i^{\mathrm{agg}}$ the node features of graph $i$ after aggregation (pooling), and $h_i$ the Hodge features of graph $i$.

**Wrong way** (current):
$$\hat{y}_i = f(x_i^{\mathrm{agg}})$$
where $h_i$ exists but is never used.

**Right way** (improved):
$$\hat{y}_i = f([x_i^{\mathrm{agg}}; h_i])$$
where graph-level features are concatenated with aggregated node features.
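
For comparison, the node-level fusion used in `MyBackboneWithHodgeFusion` can be written in a simplified form. This is a sketch under assumptions: sum pooling over the node set $V_i$ of graph $i$ (the `pooling_type` set in `readout_config`), the hyperedge branch ignored, $W$ and $b$ standing for the weights of `linear_0`, and $g$ for the remaining readout layers:

$$\hat{y}_i = g\Big(\sum_{n \in V_i} \mathrm{ReLU}\big(W\,[x_{n,i};\, h_i] + b\big)\Big)$$

Here $h_i$ enters before pooling, so every node embedding of graph $i$ depends on it, whereas the graph-level form above appends $h_i$ once after pooling.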