# On-Disk Transductive Learning: Extended Context Batching

Enhance your existing node samplers with **structure-aware context expansion**!

**What you'll learn:**
- 🔧 How to enhance ANY existing node sampler
- 📈 Controlled context expansion for complete structures
- 🎯 Core vs context nodes (for selective loss)
- 🔄 Backward compatibility with existing samplers

**📚 Prerequisites:** Complete `tutorial_ondisk_transductive_intro.ipynb` first!  
That tutorial covers: transductive learning, node samplers (Random, ClusterAware), cluster-aware sampling, structure loss problem, and on-disk indexing.

**⏱️ Time:** 15-20 minutes

## Why Extended Context Batching?

**Building on Tutorial 1:** We learned that cluster-aware node samplers lose structures at boundaries.
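
To make the boundary problem concrete, here is a minimal pure-Python sketch on a toy graph. The helper names (`triangles`, `split_by_batch`) and the graph itself are made up for illustration and are not part of TopoBench.

```python
def triangles(edges):
    """Enumerate triangles of an undirected graph as sorted node tuples."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    found = set()
    for u, v in edges:
        # Every common neighbor of u and v closes a triangle.
        for w in adj[u] & adj[v]:
            found.add(tuple(sorted((u, v, w))))
    return sorted(found)


def split_by_batch(tris, batches):
    """Separate triangles fully inside one batch from those that cross batches."""
    kept = [t for t in tris if any(set(t) <= b for b in batches)]
    lost = [t for t in tris if t not in kept]
    return kept, lost


# Two tight communities {0, 1, 2} and {3, 4, 5}, bridged by edges (2, 3) and (2, 4).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3), (2, 4)]
batches = [{0, 1, 2}, {3, 4, 5}]  # what a cluster-aware sampler might produce

kept, lost = split_by_batch(triangles(edges), batches)
print("complete in a batch:", kept)
print("lost at the boundary:", lost)
```

Triangle `(2, 3, 4)` spans both communities, so neither batch sees it in full. Extended context batching targets exactly this kind of structure.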

**This approach: Enhance, don't replace!**
- Keep your existing node sampler (Louvain, METIS, etc.)
- **Add context nodes** to complete structures
- **Control expansion** with max_expansion_ratio
- **Result:** 95-100% completeness with modest memory increase

**When to use:**
- You already have a node sampler you like
- You want **backward compatibility**
- You can afford a 20-50% memory increase per batch
- You use enumerable structures (cliques, cycles)

## Which Topological Structures Benefit?

### ✅ Benefits from Extended Context:
- **SimplicialCliqueLifting**: Cliques/triangles enumerated from graph
- **CellCycleLifting**: Cycles discovered in graph
- **Any enumerable structure** from the graph topology

### ❌ Does NOT benefit:
- **HypergraphKHopLifting**: Neighborhoods generated per node (already complete)
- **KernelLifting**: Structures from kernels, not graph enumeration
- **Feature-based** transforms that don't rely on graph structure

**Why?** Extended context helps when structures **cross batch boundaries**. If structures are generated per-node (like k-hop), they're already complete within the sampled nodes.

### Table of Contents

1. [Setup and Data Loading](#setup)
2. [Using Existing Node Samplers](#samplers)
3. [Load Dataset Splits](#dataloader)
4. [Core vs Context Nodes](#core-context)
5. [Training with TBModel](#training)
6. [Do Liftings Need Modification?](#liftings)

<a id='setup'></a>
## 1. Setup and Data Loading

Following TopoBench conventions, we define dataset and loader classes.

In [None]:
# Imports (run this cell first; the class definitions below depend on them)
import networkx as nx
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
import lightning as pl
from omegaconf import OmegaConf, DictConfig

# TopoBench imports
from topobench.data.loaders.base import AbstractLoader
from topobench.data.preprocessor import OnDiskTransductivePreprocessor
from topobench.dataloader import TBDataloader
from topobench.model import TBModel
from topobench.nn.backbones.simplicial import SCCNNCustom
from topobench.nn.readouts.simplicial_readout import SimplicialReadout
from topobench.loss import TBLoss
from topobench.optimizer import TBOptimizer

print("✓ Imports complete")

### Define Dataset Class (TopoBench Style)

**Important: Why `InMemoryDataset` is fine here**

You might wonder: "If we're doing on-disk processing, why use `InMemoryDataset`?"

**Answer:** We keep the **graph** in memory, but the **structures** on disk!

- **Graph data** (nodes, edges, features): ~50-200 MB for 8K nodes → **fits in RAM** ✅
- **Topological structures** (triangles, etc.): 1-10 GB → **stored on disk** ✅

**Why this works:**
1. The base graph must be in memory anyway (for subgraph extraction)
2. The **bottleneck** is the structures, not the graph itself
3. Our on-disk index (SQLite) handles the structures
4. Result: Constant-memory training on large graphs!

**Would on-demand graph loading help?** No, because:
- We need the full graph in memory to extract subgraphs
- The preprocessor needs the full graph to query edges
- The graph itself is not the memory problem
- The structures are the problem (and we solve that!)

**Bottom line:** `InMemoryDataset` for the graph + on-disk index for the structures = perfect combo!

In [None]:
class CommunityGraphDataset(InMemoryDataset):
    """Large graph with clear community structure."""

    def __init__(self, root, name, parameters: DictConfig):
        self.name = name
        self.parameters = parameters
        super().__init__(root)

        # Load the processed file; both the 4-tuple and 3-tuple layouts are handled
        out = fs.torch_load(self.processed_paths[0])
        if len(out) == 4:
            data, self.slices, self.sizes, data_cls = out
            self.data = data_cls.from_dict(data) if isinstance(data, dict) else data
        else:
            data, self.slices, self.sizes = out
            self.data = data

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return "data.pt"

    def download(self):
        # Nothing to download: the graph is generated in process()
        pass

    def process(self):
        """Generate graph with community structure."""
        from networkx.generators.community import stochastic_block_model

        n = self.parameters.num_nodes
        num_communities = self.parameters.num_communities
        nodes_per_comm = n // num_communities
        sizes = [nodes_per_comm] * num_communities

        # High intra-community, low inter-community edge probability
        p_in = 0.3
        p_out = 0.02
        probs = [[p_in if i == j else p_out for j in range(num_communities)]
                 for i in range(num_communities)]

        G = stochastic_block_model(sizes, probs, seed=42)

        # Convert to PyG Data; store each undirected edge in both directions
        edges = list(G.edges())
        edge_index = torch.tensor(edges, dtype=torch.long).t()
        edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)

        # Random features and labels (demo data only)
        x = torch.randn(n, self.parameters.num_features)
        y = torch.randint(0, self.parameters.num_classes, (n,))

        # Transductive splits: 60% train, 20% val, 20% test by node index
        train_mask = torch.zeros(n, dtype=torch.bool)
        val_mask = torch.zeros(n, dtype=torch.bool)
        test_mask = torch.zeros(n, dtype=torch.bool)

        train_mask[:int(0.6 * n)] = True
        val_mask[int(0.6 * n):int(0.8 * n)] = True
        test_mask[int(0.8 * n):] = True

        data = Data(
            x=x, edge_index=edge_index, y=y, num_nodes=n,
            train_mask=train_mask, val_mask=val_mask, test_mask=test_mask
        )

        self.data, self.slices = self.collate([data])
        fs.torch_save(
            (self._data.to_dict(), self.slices, {}, self._data.__class__),
            self.processed_paths[0]
        )

print("✓ Dataset class defined")

### Define Loader Class (TopoBench Style)

In [None]:
class CommunityGraphLoader(AbstractLoader):
    """Loader for community-structured graphs."""

    def __init__(self, parameters: DictConfig):
        super().__init__(parameters)

    def load_dataset(self):
        return CommunityGraphDataset(
            str(self.root_data_dir),
            self.parameters.data_name,
            self.parameters
        )

print("✓ Loader class defined")

In [None]:
# Configuration
config = OmegaConf.create({
    "data_dir": "./data/",
    "data_name": "CommunityGraph",
    "num_nodes": 8000,
    "num_communities": 8,
    "num_features": 32,
    "num_classes": 4
})

# Load dataset
loader = CommunityGraphLoader(config)
dataset, _ = loader.load()
graph_data = dataset[0]

print(f"\n✓ Graph loaded: {graph_data.num_nodes:,} nodes")
print(f"  Edges: {graph_data.edge_index.size(1):,}")
print(f"  Communities: {config.num_communities} (clear cluster structure)")
print(f"  Train nodes: {graph_data.train_mask.sum().item():,}")
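
As a back-of-the-envelope check on the memory argument from earlier in this section, the expected edge and triangle counts of this stochastic block model can be computed in closed form. This is a rough sketch: it assumes `edge_index` is stored as int64 and that each triangle costs three 8-byte node IDs, ignoring features and any index overhead.

```python
from math import comb

n, k = 8000, 8               # nodes and communities, as configured above
m = n // k                   # nodes per community
p_in, p_out = 0.3, 0.02      # edge probabilities used in process()

# Expected undirected edges: pairs inside a community plus pairs across communities.
intra_edges = k * comb(m, 2) * p_in
inter_edges = comb(k, 2) * m * m * p_out
edges = intra_edges + inter_edges

# edge_index holds both directions: 2 rows x 2E columns x 8 bytes (int64).
edge_index_mb = 2 * (2 * edges) * 8 / 1e6

# Expected triangles, grouped by how many communities the three nodes span.
tri_one = k * comb(m, 3) * p_in ** 3
tri_two = k * comb(m, 2) * (k - 1) * m * p_in * p_out ** 2
tri_three = comb(k, 3) * m ** 3 * p_out ** 3
tris = tri_one + tri_two + tri_three

# Assumption: three 8-byte node IDs per triangle, no index overhead.
triangles_mb = tris * 3 * 8 / 1e6

print(f"expected edges:     {edges:,.0f}")
print(f"edge_index size:    {edge_index_mb:,.0f} MB")
print(f"expected triangles: {tris:,.0f}")
print(f"raw triangle data:  {triangles_mb:,.0f} MB")
```

Under these assumptions the edge list comes to roughly 56 MB, while the triangles alone approach a gigabyte before any indexing, which is why the structures rather than the graph go to disk.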

<a id='samplers'></a>
## 2. Using Existing Node Samplers

**Key Feature:** Extended context works with **any node sampler from Tutorial 1**!

**Recap from Tutorial 1:** Node samplers yield batches of node IDs.  
Options: `louvain`, `metis`, `leiden`, `label_propagation`, or custom samplers
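
As a reminder of what "yielding batches of node IDs" means, here is a simplified stand-in for a cluster-aware sampler. `cluster_batches` is a hypothetical helper, not the TopoBench `ClusterAwareNodeSampler`, and it assumes the communities have already been computed (for instance by a Louvain implementation).

```python
import random


def cluster_batches(communities, nodes_per_batch, seed=0):
    """Yield lists of node IDs, drawing each batch from a single community."""
    rng = random.Random(seed)
    for members in communities:
        members = list(members)
        rng.shuffle(members)
        # Cut the shuffled community into chunks of at most nodes_per_batch.
        for start in range(0, len(members), nodes_per_batch):
            yield members[start:start + nodes_per_batch]


# Hypothetical communities: nodes 0-4 and nodes 5-7.
communities = [range(0, 5), range(5, 8)]
batches = list(cluster_batches(communities, nodes_per_batch=2))

for i, batch in enumerate(batches):
    print(f"batch {i}: {sorted(batch)}")
```

Every node lands in exactly one batch and no batch mixes communities, which keeps most structures intact but still cuts the ones that bridge communities.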

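We use the cluster-aware sampler with Louvain, selected later through `sampler_method` in the split configuration. First, configure the lifting transform and create the preprocessor: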

In [None]:
# Configure transform
transforms_config = OmegaConf.create({
    "clique_lifting": {
        "transform_type": "lifting",
        "transform_name": "SimplicialCliqueLifting",
        "complex_dim": 2
    }
})

# Create preprocessor (no need to build_index manually - it's automatic!)
preprocessor = OnDiskTransductivePreprocessor(
    graph_data=graph_data,
    data_dir="./index/extended_context_demo",
    transforms_config=transforms_config,
    max_structure_size=3
)

print("✓ Preprocessor created")

<a id='dataloader'></a>
## 3. Load Dataset Splits (TopoBench Style)

**High-level API:** Exactly like inductive learning!

**What happens under the hood:**
1. Builds structure index (if it does not exist) → saved to disk
2. Creates train/val/test datasets based on masks
3. Each dataset wraps a loader for extended context sampling

**Result:** Train, val, test datasets ready for TBDataloader!

In [None]:
# Create split configuration (like inductive learning!)
split_config = OmegaConf.create({
    "strategy": "extended_context",
    "nodes_per_batch": 1000,       # Core nodes per batch
    "max_expansion_ratio": 1.5,    # Allow up to 50% expansion
    "sampler_method": "louvain",   # Community detection
})

print("🗄️ Loading dataset splits (builds index if needed)...\n")

# Load splits - EXACTLY like inductive learning!
train, val, test = preprocessor.load_dataset_splits(split_config)

print("\n✓ Dataset splits loaded!")
print(f"  Train: {len(train)} batches")
print(f"  Val: {len(val)} batches")
print(f"  Test: {len(test)} batches")
print("  Strategy: Extended context batching")

# Inspect a sample batch
sample_batch = next(iter(train))
print("\n📦 Sample training batch:")
print(f"  Total nodes: {sample_batch.num_nodes}")
print(f"  Core nodes: {sample_batch.core_mask.sum().item()}")
print(f"  Context nodes: {sample_batch.num_nodes - sample_batch.core_mask.sum().item()}")
print(f"  Expansion ratio: {sample_batch.expansion_ratio:.2f}x")
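
To make `max_expansion_ratio` concrete, here is one way a capped expansion could work. This is a sketch under stated assumptions: `expand_with_context` is a hypothetical helper, and the greedy, order-dependent policy is chosen for illustration and may differ from what the preprocessor actually does.

```python
def expand_with_context(core, structures, max_expansion_ratio=1.5):
    """Add context nodes that complete structures touching the core nodes.

    The batch may grow to at most int(len(core) * max_expansion_ratio) nodes.
    A structure is completed only if all of its missing nodes still fit.
    """
    core = set(core)
    budget = int(len(core) * max_expansion_ratio) - len(core)
    context = set()
    for s in structures:
        members = set(s)
        if not members & core:
            continue  # structure does not touch this batch
        missing = members - core - context
        if len(missing) <= budget - len(context):
            context |= missing
    return core, context


core_nodes = [0, 1, 2, 3]
structures = [(0, 1, 2), (2, 3, 4), (3, 4, 5), (3, 6, 7), (8, 9, 10)]

core_out, context = expand_with_context(core_nodes, structures, max_expansion_ratio=1.5)
ratio = (len(core_out) + len(context)) / len(core_out)

print("core:", sorted(core_out))
print("context:", sorted(context))
print(f"expansion ratio: {ratio:.2f}x")
```

With four core nodes and a ratio of 1.5 the budget is two context nodes. Nodes 4 and 5 are added to complete `(2, 3, 4)` and `(3, 4, 5)`, while `(3, 6, 7)` stays incomplete because it would exceed the cap.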

### Create Datamodule (TopoBench Style)

**Exactly like inductive learning:** Use TBDataloader!

In [None]:
# Create datamodule - EXACTLY like inductive learning!
datamodule = TBDataloader(
    dataset_train=train,
    dataset_val=val,
    dataset_test=test,
    batch_size=1,  # Already batched by dataset
    num_workers=0
)

print("✓ Datamodule created (TopoBench style)")
print("  This is IDENTICAL to inductive learning!")
print("  Same API, same workflow, just different sampling strategy")

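<a id='core-context'></a>
## 4. Core vs Context Nodes

**Core nodes** are the nodes drawn by the sampler; **context nodes** are the ones added to complete structures. Each batch carries a `core_mask` that marks the core nodes.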

In [None]:
# Inspect a sample batch (`train` is the training split loaded above)
sample_batch = next(iter(train))

print("\n📦 Sample Batch Analysis:")
print(f"  Total nodes (with context): {sample_batch.num_nodes}")
print(f"  Core nodes (sampled): {sample_batch.core_mask.sum().item()}")
print(f"  Context nodes (added): {sample_batch.num_nodes - sample_batch.core_mask.sum().item()}")
print(f"  Expansion ratio: {sample_batch.expansion_ratio:.2f}x")

print(f"\n  Edges: {sample_batch.edge_index.size(1)}")
print(f"  Structures: {sample_batch.num_structures}")

print(f"\n  Core mask: {sample_batch.core_mask.sum().item()} True values")
print("  💡 Use core_mask to compute loss only on sampled nodes")

<a id='training'></a>
## 5. Training with TBModel

**Integration:** Works seamlessly with TopoBench's TBModel!

In [None]:
# Define model
HIDDEN_DIM = 64
OUT_CHANNELS = 4
IN_CHANNELS = 32

model = TBModel(
    backbone=SCCNNCustom(
        in_channels_all=(IN_CHANNELS, HIDDEN_DIM, HIDDEN_DIM),
        hidden_channels_all=(HIDDEN_DIM, HIDDEN_DIM, HIDDEN_DIM),
        conv_order=1,
        sc_order=2,
        n_layers=2
    ),
    readout=SimplicialReadout(
        HIDDEN_DIM, OUT_CHANNELS, task_level="node"
    ),
    loss=TBLoss(
        dataset_loss={"task": "classification", "loss_type": "cross_entropy"}
    ),
    optimizer=TBOptimizer(
        optimizer_id="Adam", parameters={"lr": 0.01}
    )
)

print("✓ TBModel created")

In [None]:
# Train - EXACTLY like inductive learning!
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
    enable_progress_bar=True
)

print("\n🚀 Training with extended context...\n")
print("  Note: In production, consider using core_mask for loss")
print("  For this demo, we train on all nodes\n")

trainer.fit(model, datamodule)

print("\n✅ Training complete!")
print(f"  Trained on {graph_data.train_mask.sum().item():,} train nodes")
print(f"  Validated on {graph_data.val_mask.sum().item():,} val nodes")
print("  Standard node sampler + extended context!")
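
The demo above trains on all nodes in each batch. A selective loss restricted to core nodes could look like the dependency-free sketch below. `masked_cross_entropy` is a hypothetical helper; with PyTorch tensors the same selection is usually written as `F.cross_entropy(logits[mask], labels[mask])`. How to wire this into `TBLoss` is not covered here.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over the rows where mask is True."""
    total, count = 0.0, 0
    for row, y, keep in zip(logits, labels, mask):
        if not keep:
            continue  # context node: excluded from the loss
        # Log of the softmax normalizer, with the max subtracted for stability.
        top = max(row)
        log_norm = top + math.log(sum(math.exp(v - top) for v in row))
        total += log_norm - row[y]
        count += 1
    if count == 0:
        raise ValueError("mask selects no nodes")
    return total / count


# Three nodes, two classes; the last node is a badly misclassified context node.
logits = [[2.0, 0.0], [0.0, 2.0], [5.0, -5.0]]
labels = [0, 1, 1]
core_mask = [True, True, False]

core_loss = masked_cross_entropy(logits, labels, core_mask)
all_loss = masked_cross_entropy(logits, labels, [True, True, True])

print(f"loss on core nodes only: {core_loss:.4f}")
print(f"loss on all nodes:       {all_loss:.4f}")
```

Context nodes still take part in message passing; the mask only keeps them from contributing to the gradient signal through the loss.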

<a id='liftings'></a>
## 6. Do Liftings Need Modification?

**Short Answer:** No! Existing liftings work as-is.

**What's different:**
- **Index building:** Enumerates structures from full graph once
- **Batch-time:** Lifting applied to expanded mini-batch (as normal)
- **Advantage:** Structures queried from index + context nodes added

**Example:**

In [None]:
from topobench.transforms.liftings.graph2simplicial import SimplicialCliqueLifting

# Your existing lifting works unchanged!
lifting = SimplicialCliqueLifting(complex_dim=2)

print("✓ Existing SimplicialCliqueLifting works as-is")
print("\n  What happens:")
print("  1. Sampler provides core nodes from one community")
print("  2. Extended context adds neighbors for complete structures")
print("  3. Lifting applied to expanded batch (standard process)")
print("  4. Result: More complete structures in mini-batch")
print("\n  No lifting changes needed!")
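
The "enumerate once, query per batch" flow described in this section can be sketched with the standard-library `sqlite3` module. The schema and the helper names (`build_index`, `structures_touching`) are hypothetical and only illustrate the idea; the actual on-disk index layout may differ.

```python
import sqlite3


def build_index(conn, structures):
    """Store one (node, structure id) row per member of every structure."""
    conn.execute("CREATE TABLE membership (node INTEGER, sid INTEGER)")
    conn.execute("CREATE INDEX idx_node ON membership (node)")
    conn.executemany(
        "INSERT INTO membership VALUES (?, ?)",
        [(node, sid) for sid, s in enumerate(structures) for node in s],
    )
    conn.commit()


def structures_touching(conn, nodes):
    """Return {sid: member nodes} for structures with at least one node in `nodes`."""
    marks = ",".join("?" * len(nodes))  # one placeholder per queried node
    rows = conn.execute(
        "SELECT node, sid FROM membership WHERE sid IN "
        f"(SELECT DISTINCT sid FROM membership WHERE node IN ({marks})) "
        "ORDER BY sid, node",
        list(nodes),
    ).fetchall()
    out = {}
    for node, sid in rows:
        out.setdefault(sid, []).append(node)
    return out


# ":memory:" keeps the demo self-contained; a file path would put the index on disk.
conn = sqlite3.connect(":memory:")
build_index(conn, [(0, 1, 2), (2, 3, 4), (3, 4, 5), (6, 7, 8)])

core = [0, 1, 2]
hits = structures_touching(conn, core)
context = {n for members in hits.values() for n in members} - set(core)

print("structures touching the batch:", hits)
print("context nodes needed:", sorted(context))
conn.close()
```

Only the structures that touch the current batch are read, so memory use depends on the batch rather than on the total number of structures.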

## Summary

**What we learned:**
1. ✅ **On-disk indexing:** Structures live in a SQLite index on disk, not in RAM
2. ✅ **Backward compatible:** Works with existing node samplers
3. ✅ **Controlled expansion:** Set `max_expansion_ratio`
4. ✅ **Core vs context:** Distinguish sampled vs added nodes
5. ✅ **Which structures benefit:** Enumerable structures (cliques, cycles)
6. ✅ **No lifting changes:** Existing liftings work as-is

**When to use this approach:**
- ✅ You have existing node samplers you like (Louvain, METIS, etc.)
- ✅ Using SimplicialCliqueLifting or cycle-based liftings
- ✅ Want better structure completeness without changing workflow
- ✅ Can afford 20-50% memory increase per batch

**When NOT to use:**
- ❌ Using k-hop or kernel-based liftings (structures are generated per node or from kernels, not enumerated from the graph)
- ❌ Memory is extremely constrained
- ❌ Transforms don't rely on graph structure enumeration

**Alternative:** See `tutorial_ondisk_transductive_structure_centric.ipynb` for a structure-first approach!

In [None]:
# Cleanup
preprocessor.close()
print("✓ Tutorial complete!")