# HRNet Topological Optimization - Testing Notebook

This notebook tests all components of the topological optimization system.

## Requirements
- Python 3.8+
- PyTorch 2.0+
- TDA libraries (gudhi, ripser, persim)

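Run the cells in order, from top to bottom, to verify that everything works correctly. Later cells reuse variables created by earlier ones, such as the model, the data loader, and the extracted features.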

## 1. Environment Setup and Dependencies Check

In [None]:
# Report the interpreter in use and the working directory that relative
# paths used later in this notebook (hrnet_base/, ./data, test_output/) resolve against
import os
import sys

print(f"Python version: {sys.version}")
print("Working directory: " + os.getcwd())
print("=" * 80)

In [None]:
# Verify that every third-party dependency can be imported and report its version.
import importlib

# Import (module) names to probe; these are not always the same as pip names.
required_packages = [
    'torch',
    'torchvision',
    'numpy',
    'scipy',
    'matplotlib',
    'seaborn',
    'sklearn',
    'ripser',
    'persim',
    'tqdm'
]

# Import name -> pip distribution name, for packages where the two differ.
# `pip install sklearn` does not install scikit-learn (the `sklearn` PyPI name
# is a deprecated placeholder), so the install hint must use the real name.
PIP_NAMES = {'sklearn': 'scikit-learn'}


def pip_install_command(packages):
    """Return a ``pip install`` command line for the given import names.

    Names found in ``PIP_NAMES`` are translated to their pip distribution
    names; all other names are passed through unchanged.
    """
    return "pip install " + " ".join(PIP_NAMES.get(name, name) for name in packages)


print("Checking required packages...\n")
missing_packages = []

for package in required_packages:
    try:
        mod = importlib.import_module(package)
        version = getattr(mod, '__version__', 'unknown')
        print(f"✓ {package:15s} {version}")
    except ImportError:
        print(f"✗ {package:15s} NOT FOUND")
        missing_packages.append(package)

if missing_packages:
    print(f"\n⚠️  Missing packages: {', '.join(missing_packages)}")
    print("\nTo install missing packages, run:")
    print(pip_install_command(missing_packages))
else:
    print("\n✓ All required packages are installed!")

## 2. Import Project Modules

In [None]:
# Make the HRNet library importable from this notebook.
import sys
import os
from pathlib import Path

# Paths are resolved relative to the notebook's working directory
project_root = Path(os.getcwd())
hrnet_lib_path = project_root / 'hrnet_base' / 'lib'

# Notebook cells are often re-run; only insert the entry once so sys.path does
# not accumulate duplicates of the same directory.
if str(hrnet_lib_path) not in sys.path:
    sys.path.insert(0, str(hrnet_lib_path))

print(f"Project root: {project_root}")
print(f"HRNet lib path: {hrnet_lib_path}")
if not hrnet_lib_path.is_dir():
    # Surface a likely wrong working directory early instead of at import time
    print(f"⚠️  {hrnet_lib_path} does not exist; imports from the HRNet lib may fail")

In [None]:
# Numerical, deep-learning, plotting and progress-bar libraries used by the later cells
import numpy as np
import torch
import torchvision
from torch import nn
from torchvision import transforms
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Project-local modules; they must be importable from the working directory or sys.path
from topology_analyzer import TopologicalAnalyzer, TopologyAwareTraining
from train_enhanced import HRNetCIFAR

print("✓ All project modules imported successfully!")

## 3. Test Topological Analyzer

In [None]:
# Sanity-check TopologicalAnalyzer on a toy point cloud made of two Gaussian blobs
print("Testing TopologicalAnalyzer...\n")

np.random.seed(42)
# 50 points per blob in R^10, centred at +2 and -2 along the first coordinate
offset = np.zeros(10)
offset[0] = 2
cluster1 = np.random.randn(50, 10) + offset
cluster2 = np.random.randn(50, 10) - offset
data = np.vstack([cluster1, cluster2])

print(f"Test data shape: {data.shape}")

analyzer = TopologicalAnalyzer(max_dimension=1, distance_threshold=5.0)
print("✓ TopologicalAnalyzer initialized")

stats = analyzer.compute_persistence_diagram(data, label='test')

if not stats:
    print("✗ Failed to compute persistence diagram")
else:
    diagrams = stats['diagrams']
    print("\n✓ Persistence diagram computed successfully!")
    print(f"  Betti numbers: {stats['betti_numbers']}")
    print(f"  Persistence entropy: {stats['persistence_entropy']:.4f}")
    print(f"  Number of H_0 features: {len(diagrams[0])}")
    # The H_1 diagram may be absent, in which case zero features are reported
    n_h1 = len(diagrams[1]) if len(diagrams) > 1 else 0
    print(f"  Number of H_1 features: {n_h1}")

In [None]:
# Compare two independent standard-normal samples, the second shifted by 0.1 per coordinate
print("Testing bottleneck distance computation...\n")

data1 = np.random.randn(50, 10)
data2 = np.random.randn(50, 10) + 0.1

distance = analyzer.compute_bottleneck_distance(data1, data2, dimension=0)
print(f"Bottleneck distance (H_0): {distance:.4f}")

# An infinite (or NaN) distance is treated as a failed computation
if not distance < float('inf'):
    print("✗ Failed to compute bottleneck distance")
else:
    magnitude = 'small' if distance < 1.0 else 'large'
    print("✓ Bottleneck distance computed successfully!")
    print(f"  Distance is {magnitude} (expected: small for similar data)")

In [None]:
# Render the persistence diagram of the toy data to disk, then show it inline.
print("Testing persistence diagram visualization...\n")

# Output directory shared by the plotting cells in this notebook
output_dir = Path('test_output')
output_dir.mkdir(exist_ok=True)
persistence_png = output_dir / 'test_persistence.png'

analyzer.visualize_persistence_diagram(
    data,
    save_path=persistence_png
)

from IPython.display import Image, display

# Only report success if the analyzer actually produced the file
if persistence_png.exists():
    print(f"✓ Persistence diagram saved to {persistence_png.as_posix()}")
    display(Image(filename=str(persistence_png)))
else:
    print(f"✗ Persistence diagram was not written to {persistence_png.as_posix()}")

## 4. Test HRNet Model

In [None]:
# Build the CIFAR variant of HRNet and place it on the GPU when one is available
print("Testing HRNetCIFAR model...\n")

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

model = HRNetCIFAR(num_classes=10, width=18).to(device)

n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("✓ Model created successfully")
print(f"  Total parameters: {n_params:,}")
print(f"  Trainable parameters: {n_trainable:,}")

In [None]:
# Smoke-test both forward-pass signatures of the model on random input.
# This only checks shapes. Run in eval mode under no_grad so the check neither
# builds an autograd graph nor lets random noise update any normalization
# running statistics (relevant if HRNetCIFAR has BatchNorm layers — not visible
# here) that the later feature-extraction cells would then use.
print("Testing forward pass...\n")

batch_size = 4
dummy_input = torch.randn(batch_size, 3, 32, 32).to(device)

model.eval()
with torch.no_grad():
    # Forward pass without features: logits only
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
    print("✓ Standard forward pass successful")

    # Forward pass with features: (logits, features)
    output, features = model(dummy_input, return_features=True)
    print(f"\nOutput shape: {output.shape}")
    print(f"Features shape: {features.shape}")
    print("✓ Forward pass with feature extraction successful")

## 5. Test Data Loading

In [None]:
# Load the CIFAR-10 test split and wrap it in a DataLoader for the later cells.
print("Testing CIFAR-10 data loading...\n")

# Per-channel mean/std normalization; denormalize() used for plotting must use
# the same values to invert this transform.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download=True fetches CIFAR-10 into ./data if it is not already there;
# train=False then loads the complete test split (not a subset).
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

print(f"✓ CIFAR-10 test dataset loaded")
print(f"  Number of samples: {len(test_dataset)}")
print(f"  Classes: {test_dataset.classes}")

# shuffle=False keeps batch order deterministic across runs of this notebook
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2
)

print(f"✓ Data loader created")
print(f"  Number of batches: {len(test_loader)}")

In [None]:
print("Visualizing sample images...\n")  # header for this cell's output

# test_loader uses shuffle=False, so this is always the first batch of the test split
first_batch = next(iter(test_loader))
images, labels = first_batch

# Denormalize for visualization
def denormalize(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Undo ``transforms.Normalize`` so a normalized image can be displayed.

    Computes ``img * std + mean`` per channel, the inverse of the
    ``(x - mean) / std`` step of the CIFAR-10 transform.

    Args:
        img: Channel-first image tensor, ``(C, H, W)`` or batched
            ``(N, C, H, W)``, on any device.
        mean: Per-channel means used for normalization (defaults match the
            notebook's CIFAR-10 transform).
        std: Per-channel standard deviations used for normalization.

    Returns:
        The denormalized tensor, same shape and device as ``img``.
    """
    # Build the statistics on img's device so CUDA tensors work as well as CPU ones
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t

# Plot up to 16 images of the batch in a 2x8 grid, titled with their class names
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    # Leave surplus panels blank if the batch holds fewer than 16 images
    if i >= len(images):
        continue
    img = denormalize(images[i])
    img = img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow
    img = np.clip(img, 0, 1)
    ax.imshow(img)
    ax.set_title(test_dataset.classes[int(labels[i])], fontsize=10)

plt.tight_layout()
# Ensure the output directory exists so this cell does not depend on having
# run the persistence-diagram cell first
Path('test_output').mkdir(exist_ok=True)
plt.savefig('test_output/sample_images.png', dpi=100, bbox_inches='tight')
plt.show()

print("✓ Sample images visualized and saved")

## 6. Test Topology-Aware Training

In [None]:
# Check that TopologyAwareTraining combines the base criterion with its topological term
print("Testing TopologyAwareTraining...\n")

topology_trainer = TopologyAwareTraining(topology_weight=0.01)
print("✓ TopologyAwareTraining initialized")

criterion = nn.CrossEntropyLoss()

# One real CIFAR-10 batch, moved to the model's device
images, labels = next(iter(test_loader))
images = images.to(device)
labels = labels.to(device)

# Evaluating the loss once needs no gradients
with torch.no_grad():
    output, features = model(images, return_features=True)

combined_loss, loss_stats = topology_trainer.compute_combined_loss(output, labels, features, criterion)

print(f"\n✓ Combined loss computed successfully!")
print(f"  Base loss: {loss_stats['base_loss']:.4f}")
print(f"  Topological loss: {loss_stats['topo_loss']:.4f}")
print(f"  Total loss: {loss_stats['total_loss']:.4f}")

## 7. Test Feature Extraction and Analysis

In [None]:
# Extract model features for a quick subset: the first `max_batches` test batches.
from itertools import islice

print("Extracting features from test data...\n")

model.eval()
features_list = []
labels_list = []
max_batches = 10  # Use 10 batches for quick testing

# islice stops the loader after max_batches; passing `total` keeps the tqdm bar
# in line with the work actually done instead of the full loader length.
n_batches = min(max_batches, len(test_loader))
with torch.no_grad():
    for images, labels in tqdm(islice(test_loader, max_batches), total=n_batches, desc='Extracting'):
        images = images.to(device)
        _, features = model(images, return_features=True)

        # Keep CPU NumPy copies for the NumPy-based topology analysis below
        features_list.append(features.cpu().numpy())
        labels_list.extend(labels.numpy())

# Stack the per-batch arrays along the sample axis
all_features = np.concatenate(features_list, axis=0)
all_labels = np.array(labels_list)

print(f"\n✓ Features extracted")
print(f"  Shape: {all_features.shape}")
print(f"  Number of samples: {len(all_labels)}")

In [None]:
# Group the extracted feature vectors by their ground-truth label (0-9)
print("Organizing features by class...\n")

features_by_class = {class_id: all_features[all_labels == class_id] for class_id in range(10)}
for class_id, class_feats in features_by_class.items():
    print(f"  Class {class_id} ({test_dataset.classes[class_id]:12s}): {len(class_feats)} samples")

print("\n✓ Features organized by class")

In [None]:
# Persistent homology for each class, on at most 100 points per class to stay fast
print("Computing topological features per class...\n")

analyzer_test = TopologicalAnalyzer(max_dimension=1, distance_threshold=3.0)
class_topology = {}

for class_id in tqdm(range(10), desc='Analyzing classes'):
    features = features_by_class[class_id]

    # Classes with 10 or fewer points are skipped
    if len(features) <= 10:
        continue

    # Random subsample (without replacement) for large classes
    if len(features) > 100:
        features = features[np.random.choice(len(features), 100, replace=False)]

    stats = analyzer_test.compute_persistence_diagram(features, label=f'class_{class_id}')
    if stats:
        class_topology[class_id] = stats

print(f"\n✓ Topological analysis complete for {len(class_topology)} classes")

In [None]:
# Tabulate Betti numbers and persistence entropy for every analysed class
rule = "=" * 80
print("\nTopological Statistics per Class:")
print(rule)
print(f"{'Class':<12} {'Name':<15} {'Betti-0':<10} {'Betti-1':<10} {'Entropy':<12}")
print(rule)

for class_id, stats in sorted(class_topology.items()):
    betti = stats['betti_numbers']
    # Missing Betti numbers are shown as 0
    betti_0, betti_1 = (betti[k] if len(betti) > k else 0 for k in (0, 1))
    name = test_dataset.classes[class_id]
    entropy = stats['persistence_entropy']
    print(f"{class_id:<12} {name:<15} {betti_0:<10} {betti_1:<10} {entropy:<12.4f}")

print(rule)

## 8. Test Bottleneck Distance Matrix

In [None]:
# Pairwise H_0 bottleneck distances between the per-class feature clouds.
print("Computing inter-class bottleneck distances...\n")

# Match the class labels the heatmap uses instead of hard-coding 10
num_classes = len(test_dataset.classes)
min_samples = 10  # same "more than 10 points" rule as the per-class analysis

# Subsample each usable class once, so every pair involving class i compares
# the same points, rather than drawing a new random subset for every pair.
sampled = {}
for c in range(num_classes):
    feats = features_by_class.get(c)
    if feats is None or len(feats) <= min_samples:
        continue
    if len(feats) > 100:
        feats = feats[np.random.choice(len(feats), 100, replace=False)]
    sampled[c] = feats

# NaN marks pairs that could not be compared, so they are not mistaken for a
# distance of 0 (seaborn's heatmap leaves NaN cells blank). Diagonal stays 0.
distance_matrix = np.full((num_classes, num_classes), np.nan)
np.fill_diagonal(distance_matrix, 0.0)

for i in tqdm(range(num_classes), desc='Computing distances'):
    for j in range(i + 1, num_classes):
        if i not in sampled or j not in sampled:
            continue
        distance = analyzer_test.compute_bottleneck_distance(
            sampled[i], sampled[j], dimension=0
        )
        # Bottleneck distance is symmetric
        distance_matrix[i, j] = distance_matrix[j, i] = distance

print("\n✓ Distance matrix computed")
skipped = num_classes - len(sampled)
if skipped:
    print(f"  {skipped} class(es) had too few samples; their pairs are left as NaN")

In [None]:
# Render the inter-class distance matrix as an annotated heatmap and save it
print("Visualizing bottleneck distance matrix...\n")

class_names = test_dataset.classes
heatmap_kwargs = dict(
    annot=True,
    fmt='.3f',
    cmap='YlOrRd',
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={'label': 'Bottleneck Distance'},
)

plt.figure(figsize=(12, 10))
sns.heatmap(distance_matrix, **heatmap_kwargs)

plt.title('Inter-Class Bottleneck Distance Matrix\n(Untrained Model)', fontsize=14, pad=20)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Class', fontsize=12)
plt.tight_layout()
plt.savefig('test_output/distance_matrix_test.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Distance matrix visualized and saved")

## 9. Summary and Verification

In [None]:
# Print a closing report: components exercised, files written, and next steps
banner = "=" * 80

print("\n" + banner)
print("TESTING SUMMARY")
print(banner)
print("\n✓ All components tested successfully!\n")

print("Components verified:")
for line in (
    "  1. ✓ TopologicalAnalyzer - persistent homology computation",
    "  2. ✓ Bottleneck distance - topological similarity measurement",
    "  3. ✓ HRNetCIFAR model - forward pass and feature extraction",
    "  4. ✓ Data loading - CIFAR-10 dataset",
    "  5. ✓ TopologyAwareTraining - combined loss computation",
    "  6. ✓ Feature analysis - per-class topology",
    "  7. ✓ Distance matrix - inter-class comparison",
    "  8. ✓ Visualizations - persistence diagrams and heatmaps",
):
    print(line)

print("\nOutput files created in test_output/:")
for f in sorted(Path('test_output').glob('*')):
    print(f"  - {f.name}")

print("\n" + banner)
print("READY FOR TRAINING!")
print(banner)
print("\nTo start training with topological optimization:")
print("  python train_enhanced.py --dataset cifar10 --topology-weight 0.01")
print("\nOr run the quick start script:")
print("  ./quick_start.sh")

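## Optional: Quick Training Test (10 Batches)

Run this cell to exercise a complete training step (forward pass, topology-aware loss, backward pass, optimizer update) on 10 batches. This may take a few minutes. It updates the model's weights in place and reuses the CIFAR-10 test split, so treat it as a smoke test only.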

In [None]:
# Optional: run a few optimisation steps with the topology-aware loss.
# NOTE: this reuses test_loader (the CIFAR-10 *test* split) purely to exercise
# the training code path, and it updates `model`'s weights in place.
from itertools import islice

num_batches = 10
print(f"Testing {num_batches} training iterations...\n")

# Setup
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
topology_trainer = TopologyAwareTraining(topology_weight=0.01)

running_loss = 0.0
correct = 0
total = 0
batches_run = 0  # actual count; the loader may hold fewer than num_batches batches

# islice bounds the loop; `total` keeps the progress bar consistent with it
for images, labels in tqdm(islice(test_loader, num_batches),
                           total=min(num_batches, len(test_loader)),
                           desc='Training'):
    images, labels = images.to(device), labels.to(device)

    optimizer.zero_grad()

    # Forward pass
    output, features = model(images, return_features=True)

    # Compute loss with topology
    loss, loss_stats = topology_trainer.compute_combined_loss(
        output, labels, features, criterion
    )

    # Backward pass
    loss.backward()
    optimizer.step()

    # Statistics (accuracy is measured on predictions made before this step)
    running_loss += loss.item()
    _, predicted = output.max(1)
    total += labels.size(0)
    correct += predicted.eq(labels).sum().item()
    batches_run += 1

# Average over the batches actually processed; guard against an empty loader
avg_loss = running_loss / batches_run if batches_run else float('nan')
accuracy = 100. * correct / total if total else float('nan')

print(f"\n✓ Training test complete!")
print(f"  Average loss: {avg_loss:.4f}")
print(f"  Accuracy: {accuracy:.2f}%")
print(f"\n  (Note: Low accuracy is expected for untrained model)")