# Hierarchical Data Generator Analysis

This notebook demonstrates the generation and analysis of hierarchical sparse data using two different tree configurations:

1. **Simple Hierarchy**: A basic two-level tree structure with hierarchical dependencies
2. **Exclusive Groups**: A tree structure with mutually exclusive children 

We'll generate synthetic datasets for both configurations and create comprehensive visualizations including:
- Feature activation patterns
- Data vector representations  
- Feature direction vectors
- Statistical properties and dependencies

The analysis showcases how different tree structures create distinct patterns in the generated data, which can be used to test Sparse Autoencoders (SAEs) under various hierarchical constraints.

## 1. Import Required Libraries

In [3]:
import json
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the project's `data_generator` module importable. The repository root can
# be overridden with the MATRYOSHKA_SAES_ROOT environment variable; the default
# is the original author's checkout location.
REPO_ROOT = os.environ.get(
    "MATRYOSHKA_SAES_ROOT", "/Users/kkumbier/github/matryoshka-saes/"
)
# Guard so re-running this cell does not keep appending the same entry.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)
from data_generator import HierarchicalDataGenerator  # noqa: E402

# Seed every RNG the notebook (or the generator) may draw from, for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in Jupyter

<torch._C.Generator at 0x1185e2f90>

In [4]:
# Report which interpreter backs this kernel, then the NumPy and Torch
# versions it resolves (each printed right after its import, in order).
import sys

print(sys.executable)

import numpy as np

print("NumPy:", np.__version__)

import torch

print("Torch:", torch.__version__)

/usr/local/Caskroom/miniconda/base/envs/mech_int/bin/python
NumPy: 1.26.4
Torch: 2.2.2


In [6]:
# Show the on-disk location of torch and numpy (one per line) to rule out a
# shadowing installation.
import torch
import numpy as np

print(torch.__file__, np.__file__, sep="\n")

/usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/site-packages/torch/__init__.py
/usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/site-packages/numpy/__init__.py


## 2. Load Tree Configurations

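We'll load the generator parameter files for the two configurations we want to analyze. Each file supplies the keyword arguments for `HierarchicalDataGenerator` (model dimension, noise level, feature correlation, random seed, ...), including a `tree_config` path that points to the tree definition itself:
- **Simple Hierarchy**: A basic hierarchical structure with standard parameters
- **Exclusive Groups**: A configuration with mutually exclusive feature groups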

In [5]:
# Directory holding the generator parameter files. Override with the
# TREE_PARAMS_DIR environment variable; the default is the original author's
# checkout location.
tree_params_dir = os.environ.get(
    "TREE_PARAMS_DIR", "/Users/kkumbier/github/matryoshka-saes/tree_params"
)

# Configuration name -> parameter file. Insertion order determines the order in
# which later cells iterate over and compare the configurations.
CONFIG_FILES = {
    "exclusive_groups": "exclusive_params.json",
    "simple_hierarchy": "simple_params.json",
}

# Each file holds the keyword arguments passed to HierarchicalDataGenerator(**config).
configs = {}
for config_name, filename in CONFIG_FILES.items():
    with open(os.path.join(tree_params_dir, filename), "r") as f:
        configs[config_name] = json.load(f)

print("Loaded configurations:")
print(json.dumps(configs, indent=2))

Loaded configurations:
{
  "exclusive_groups": {
    "tree_config": "/Users/kkumbier/github/matryoshka-saes/tree_params/exclusive_groups.json",
    "d_model": 128,
    "feature_correlation": 0.1,
    "noise_level": 0.02,
    "orthogonal_features": true,
    "feature_scale_variation": 0.1,
    "random_seed": 123
  },
  "simple_hierarchy": {
    "tree_config": "/Users/kkumbier/github/matryoshka-saes/tree_params/simple_hierarchy.json",
    "d_model": 256,
    "feature_correlation": 0.0,
    "noise_level": 0.0,
    "orthogonal_features": true,
    "feature_scale_variation": 0.05,
    "random_seed": 42
  }
}


## 3. Generate Datasets for Both Configurations

Now we'll create HierarchicalDataGenerator instances for each configuration and generate sample datasets to analyze.

In [8]:

# Smoke test: build a tiny dataset from the exclusive-groups parameters and
# fetch its first item.
# NOTE(review): the recorded run of this cell raised
# "ValueError: probabilities do not sum to 1"; the exclusive_groups tree
# configuration presumably needs its probabilities checked — TODO confirm
# where the error originates.
exclusive_params = configs["exclusive_groups"]
generator = HierarchicalDataGenerator(**exclusive_params)
dataset = generator.create_dataset(batch_size=5, num_batches=2, device="cpu")
dataset[0]

ValueError: probabilities do not sum to 1

In [1]:
import sys
import os
import importlib.machinery
import importlib.util
import torch
import numpy as np
import matplotlib
import json

# Environment diagnostics: interpreter, library versions and locations, the
# module search path, and a scan for packages present in more than one
# sys.path entry (a common source of mismatched C extensions).
print("Python:", sys.version)
print("Executable:", sys.executable)
print("Torch:", torch.__version__, torch.__file__)
print("NumPy:", np.__version__, np.__file__)
print("Matplotlib:", matplotlib.__version__, matplotlib.__file__)
print("JSON:", json.__file__)
print("Current working directory:", os.getcwd())
print("sys.path:")
for p in sys.path:
    print(p)
print("\nChecking for duplicate C extensions...")


def find_duplicate_modules(module_name):
    """Return every distinct file that provides ``module_name`` on ``sys.path``.

    Each ``sys.path`` entry is searched on its own with
    ``importlib.machinery.PathFinder.find_spec``, so copies shadowed by an
    earlier entry are reported as well. (``importlib.util.find_spec(name, [p])``
    cannot do this: its second argument is the *package* used to resolve
    relative names, and it always searches the full ``sys.path`` or returns
    the cached spec of an already-imported module.)

    Args:
        module_name: Top-level module name, e.g. ``"numpy"``.

    Returns:
        Origin paths in ``sys.path`` order, without duplicates; empty if the
        module is not found on any entry.
    """
    found = []
    for p in sys.path:
        try:
            spec = importlib.machinery.PathFinder.find_spec(module_name, [p])
        except Exception:
            # Best effort: an unreadable or malformed path entry should not
            # abort the scan of the remaining entries.
            continue
        if spec and spec.origin and spec.origin not in found:
            found.append(spec.origin)
    return found


for mod in ["torch", "numpy", "matplotlib"]:
    paths = find_duplicate_modules(mod)
    print(f"{mod} found at:")
    for path in paths:
        print("  ", path)

Python: 3.11.13 | packaged by conda-forge | (main, Jun  4 2025, 14:48:01) [Clang 18.1.8 ]
Executable: /usr/local/Caskroom/miniconda/base/envs/mech_int/bin/python
Torch: 2.2.2 /usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/site-packages/torch/__init__.py
NumPy: 1.26.4 /usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/site-packages/numpy/__init__.py
Matplotlib: 3.10.5 /usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/site-packages/matplotlib/__init__.py
JSON: /usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/json/__init__.py
Current working directory: /Users/kkumbier/github/matryoshka-saes/notebooks
sys.path:
/Users/kkumbier/github/matryoshka-saes/notebooks
/Users/kkumbier/github/persisters/scripts
/Users/kkumbier/github/als_coach
/usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python311.zip
/usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11
/usr/local/Caskroom/miniconda/base/envs/mech_int/lib/python3.11/lib-dyn

In [None]:
# Create a generator and a dataset for every loaded configuration, and record
# the feature-direction attributes each generator exposes (used in Section 6).
generators = {}
datasets = {}
feature_directions = {}

for name, config in configs.items():
    print(f"\nGenerating dataset for {name}...")

    generator = HierarchicalDataGenerator(**config)
    generators[name] = generator

    # NOTE(review): the smoke-test cell (In [8]) calls create_dataset with
    # batch_size/num_batches/device and indexes the result, whereas this cell
    # unpacks the default call into (X, feature_activations). Confirm which
    # return contract data_generator actually implements.
    X, feature_activations = generator.create_dataset()

    datasets[name] = {
        'X': X,
        'feature_activations': feature_activations,
        'generator': generator,
    }

    level_shapes = [f'Level {i}: {act.shape}' for i, act in enumerate(feature_activations)]
    print(f"  Generated data shape: {X.shape}")
    print(f"  Number of feature types: {len(feature_activations)}")
    print(f"  Feature activation shapes: {level_shapes}")

    # Section 6 iterates over feature_directions; leaving it empty makes that
    # section silently produce no output. getattr(..., None) keeps this safe
    # for generators lacking these attributes — Section 6 skips None entries.
    feature_directions[name] = {
        'orthogonal': getattr(generator, 'orthogonal_directions', None),
        'correlated': getattr(generator, 'correlated_directions', None),
    }

print("\nDataset generation complete!")

## 4. Data Visualization: Generated Dataset Heatmaps

Let's visualize the generated data matrices to understand the overall structure and patterns in our hierarchical datasets.

In [None]:
# Heatmap of the first samples of each generated dataset (one panel per
# configuration), followed by summary statistics of the full data.
N_HEATMAP_SAMPLES = 50

if not datasets:
    print("No datasets to plot; run the dataset generation cell first.")
else:
    # squeeze=False always returns a 2-D array of Axes, even for one dataset.
    fig, axes = plt.subplots(1, len(datasets), figsize=(8 * len(datasets), 6), squeeze=False)

    for ax, (name, data) in zip(axes[0], datasets.items()):
        X = data['X'].detach().cpu().numpy() if torch.is_tensor(data['X']) else np.asarray(data['X'])
        X_shown = X[:N_HEATMAP_SAMPLES]

        # RdBu_r is a diverging colormap: use symmetric limits so white is exactly 0.
        vmax = float(np.abs(X_shown).max()) if X_shown.size else 0.0
        vmax = vmax if vmax > 0 else 1.0
        im = ax.imshow(X_shown.T, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax, interpolation='nearest')
        ax.set_title(f'{name.replace("_", " ").title()} - Generated Data\n(First {N_HEATMAP_SAMPLES} samples)', fontsize=12, fontweight='bold')
        ax.set_xlabel('Sample Index')
        ax.set_ylabel('Feature Dimension')
        fig.colorbar(im, ax=ax)

    fig.tight_layout()
    plt.show()

# Summary statistics over all samples of each dataset.
print("\nData Statistics:")
print("=" * 50)
for name, data in datasets.items():
    X = data['X'].detach().cpu().numpy() if torch.is_tensor(data['X']) else np.asarray(data['X'])
    n_nonzero = np.count_nonzero(X)
    pct_nonzero = 100 * n_nonzero / X.size if X.size else 0.0
    print(f"\n{name.upper()}:")
    print(f"  Shape: {X.shape}")
    print(f"  Mean: {X.mean():.4f}")
    print(f"  Std: {X.std():.4f}")
    print(f"  Min: {X.min():.4f}")
    print(f"  Max: {X.max():.4f}")
    print(f"  Non-zero elements: {n_nonzero} / {X.size} ({pct_nonzero:.2f}%)")

## 5. Feature Activation Analysis

Now let's examine the hierarchical feature activations to understand how features are activated at different levels of the tree structure.

In [None]:
# For each configuration, plot the per-level feature activations of the first
# samples and report how many features are active (non-zero) per sample.
N_ACTIVATION_SAMPLES = 50

for name, data in datasets.items():
    feature_activations = data['feature_activations']

    print(f"\n{name.upper()} - Feature Activation Analysis:")
    print("=" * 60)

    n_levels = len(feature_activations)
    if n_levels == 0:
        print("  No feature activations found!")
        print()
        continue

    # squeeze=False always returns a 2-D array of Axes, even for a single level.
    fig, axes = plt.subplots(1, n_levels, figsize=(5 * n_levels, 6), squeeze=False)

    for level, (ax, activations) in enumerate(zip(axes[0], feature_activations)):
        if torch.is_tensor(activations):
            activations = activations.detach().cpu().numpy()

        im = ax.imshow(activations[:N_ACTIVATION_SAMPLES].T, aspect='auto', cmap='Blues', interpolation='nearest')
        ax.set_title(f'Level {level} Activations\n({activations.shape[1]} features)', fontweight='bold')
        ax.set_xlabel('Sample Index')
        ax.set_ylabel('Feature Index')
        fig.colorbar(im, ax=ax)

        # Count non-zero entries instead of summing raw values: if activations
        # carry magnitudes rather than 0/1 flags, a plain sum/mean is neither a
        # count nor a probability. For binary activations the result is unchanged.
        is_active = activations != 0
        n_active = int(is_active.sum())
        pct_active = 100 * n_active / activations.size if activations.size else 0.0

        print(f"  Level {level}:")
        print(f"    Shape: {activations.shape}")
        print(f"    Active features per sample (mean): {is_active.sum(axis=1).mean():.2f}")
        print(f"    Activation probability per feature: {is_active.mean(axis=0).mean():.4f}")
        print(f"    Non-zero activations: {n_active} / {activations.size} ({pct_active:.2f}%)")

    fig.tight_layout()
    plt.show()
    print()

## 6. Feature Direction Analysis

Let's analyze the feature directions stored for each generator — the orthogonal directions and, where available, the correlated ones — to understand the geometric structure of our feature space.

In [None]:
# Inspect the feature directions stored for each configuration: a heatmap of
# the direction vectors, their pairwise cosine similarity, their norms and,
# for the orthogonal set, a mean |dot product| orthogonality check.


def _to_numpy(array):
    """Return ``array`` as a NumPy array, moving torch tensors to CPU first."""
    if torch.is_tensor(array):
        return array.detach().cpu().numpy()
    return np.asarray(array)


def _cosine_similarity_matrix(vectors):
    """Return the pairwise cosine similarity between the rows of ``vectors``.

    Zero-norm rows are mapped to zero vectors, so their similarities are 0
    rather than NaN.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit = np.divide(
        vectors, norms,
        out=np.zeros(vectors.shape, dtype=float),
        where=norms > 0,
    )
    return unit @ unit.T


def _analyze_directions(name, dirs, kind, similarity_title, magnitude_label,
                        check_orthogonality=False):
    """Plot and summarize one set of feature directions.

    Args:
        name: Configuration name, used in the plot title.
        dirs: 2-D array whose rows are treated as feature directions, i.e.
            (n_features, input_dim). NOTE(review): assumed layout — confirm
            against data_generator.
        kind: Label for the direction set ("orthogonal" / "correlated").
        similarity_title: Title of the cosine-similarity panel.
        magnitude_label: Prefix of the printed norm statistics.
        check_orthogonality: If True, also print the mean absolute dot
            product over all pairs of distinct rows.
    """
    print(f"{kind.capitalize()} directions shape: {dirs.shape}")

    fig, (ax_dirs, ax_sim) = plt.subplots(1, 2, figsize=(12, 5))

    # Rows (features) on the y-axis, columns (input dimensions) on the x-axis,
    # so the image agrees with its axis labels.
    im1 = ax_dirs.imshow(dirs, aspect='auto', cmap='RdBu_r', interpolation='nearest')
    ax_dirs.set_title(f'{name.replace("_", " ").title()}\n{kind.capitalize()} Feature Directions', fontweight='bold')
    ax_dirs.set_xlabel('Input Dimension')
    ax_dirs.set_ylabel('Feature Index')
    fig.colorbar(im1, ax=ax_dirs)

    # Cosine similarity rather than np.corrcoef: Pearson correlation
    # mean-centres each row first, so exactly orthogonal directions can still
    # show non-zero correlation.
    im2 = ax_sim.imshow(_cosine_similarity_matrix(dirs), cmap='RdBu_r', vmin=-1, vmax=1, interpolation='nearest')
    ax_sim.set_title(similarity_title, fontweight='bold')
    ax_sim.set_xlabel('Feature Index')
    ax_sim.set_ylabel('Feature Index')
    fig.colorbar(im2, ax=ax_sim)

    fig.tight_layout()
    plt.show()

    norms = np.linalg.norm(dirs, axis=1)
    print(f"  {magnitude_label} (mean ± std): {norms.mean():.4f} ± {norms.std():.4f}")

    if check_orthogonality:
        # Upper triangle (i < j) of the Gram matrix: every distinct pair's dot product.
        gram = dirs @ dirs.T
        rows, cols = np.triu_indices(dirs.shape[0], k=1)
        if rows.size:
            mean_dot = np.abs(gram[rows, cols]).mean()
            print(f"  Mean absolute dot product (orthogonality check): {mean_dot:.6f} (closer to 0 = more orthogonal)")


if not feature_directions:
    print("No feature directions stored; nothing to analyze.")

for name, directions in feature_directions.items():
    print(f"\n{name.upper()} - Feature Direction Analysis:")
    print("=" * 60)

    ortho_dirs = directions.get('orthogonal')
    if ortho_dirs is not None:
        _analyze_directions(
            name, _to_numpy(ortho_dirs), kind='orthogonal',
            similarity_title='Feature Direction Cosine Similarity\n(Should be near-orthogonal)',
            magnitude_label='Direction magnitudes',
            check_orthogonality=True,
        )

    corr_dirs = directions.get('correlated')
    if corr_dirs is not None:
        _analyze_directions(
            name, _to_numpy(corr_dirs), kind='correlated',
            similarity_title='Correlated Direction Cosine Similarity',
            magnitude_label='Correlated direction magnitudes',
        )

    print()

## 7. Comparative Analysis: Exclusive Groups vs Simple Hierarchy

Let's compare the two configurations directly to understand their differences in structure and behavior.

In [None]:
# Compare the first two configurations side by side: generator parameters,
# data statistics, per-level activation sparsity, and heatmaps of every dataset.
print("COMPARATIVE ANALYSIS")
print("=" * 60)

config_names = list(configs.keys())
assert len(config_names) >= 2, "Need at least two configurations to compare."
name1, name2 = config_names[0], config_names[1]
# Column headers derive from the configuration names so they always match the
# values printed beneath them.
title1 = name1.replace("_", " ").title()
title2 = name2.replace("_", " ").title()


def _sparsity_pct(array):
    """Return the percentage of exactly-zero entries (0.0 for an empty array)."""
    if array.size == 0:
        return 0.0
    return 100 * (1 - np.count_nonzero(array) / array.size)


# --- Generator parameters ---------------------------------------------------
print("\nConfiguration Comparison:")
print(f"{'Property':<25} {title1:<20} {title2:<20}")
print("-" * 65)

# Take property names from the loaded parameter files themselves so the table
# reflects what was actually configured; tree_config is a long path, so omit it.
properties = [
    key for key in dict.fromkeys([*configs[name1], *configs[name2]])
    if key != 'tree_config'
]
for prop in properties:
    val1 = configs[name1].get(prop, 'N/A')
    val2 = configs[name2].get(prop, 'N/A')
    print(f"{prop:<25} {str(val1):<20} {str(val2):<20}")

# --- Data characteristics ---------------------------------------------------
print("\nData Characteristics:")
print(f"{'Metric':<25} {title1:<20} {title2:<20}")
print("-" * 65)

X1_raw, X2_raw = datasets[name1]['X'], datasets[name2]['X']
data1 = X1_raw.detach().cpu().numpy() if torch.is_tensor(X1_raw) else np.asarray(X1_raw)
data2 = X2_raw.detach().cpu().numpy() if torch.is_tensor(X2_raw) else np.asarray(X2_raw)

metrics = {
    'Data shape': [str(data1.shape), str(data2.shape)],
    'Mean activation': [f"{data1.mean():.4f}", f"{data2.mean():.4f}"],
    'Std activation': [f"{data1.std():.4f}", f"{data2.std():.4f}"],
    'Sparsity (% zeros)': [f"{_sparsity_pct(data1):.1f}%", f"{_sparsity_pct(data2):.1f}%"],
}
for metric, values in metrics.items():
    print(f"{metric:<25} {values[0]:<20} {values[1]:<20}")

# --- Feature activation patterns -------------------------------------------
# Wide enough for "Shape: (n, m), Sparsity: xx.x%" so columns stay aligned.
LEVEL_COL_WIDTH = 40
print("\nFeature Activation Patterns:")
print(f"{'Level':<10} {title1:<{LEVEL_COL_WIDTH}} {title2:<{LEVEL_COL_WIDTH}}")
print("-" * (10 + 2 * LEVEL_COL_WIDTH + 2))

acts1 = datasets[name1]['feature_activations']
acts2 = datasets[name2]['feature_activations']


def _describe_level(acts, level):
    """Return a one-line shape/sparsity summary of ``acts[level]``, or a placeholder."""
    if level >= len(acts):
        return "No activations"
    act = acts[level]
    act = act.detach().cpu().numpy() if torch.is_tensor(act) else np.asarray(act)
    return f"Shape: {act.shape}, Sparsity: {_sparsity_pct(act):.1f}%"


for level in range(max(len(acts1), len(acts2))):
    desc1 = _describe_level(acts1, level)
    desc2 = _describe_level(acts2, level)
    print(f"{level:<10} {desc1:<{LEVEL_COL_WIDTH}} {desc2:<{LEVEL_COL_WIDTH}}")

# --- Side-by-side visualization ---------------------------------------------
print("\nSide-by-side Data Visualization:")
N_COMPARE_SAMPLES = 100
n_panels = len(datasets)
# squeeze=False keeps axes 2-D (rows x datasets) regardless of the dataset count.
fig, axes = plt.subplots(2, n_panels, figsize=(8 * n_panels, 12), squeeze=False)

for idx, (name, data) in enumerate(datasets.items()):
    title = name.replace("_", " ").title()

    # Top row: raw data. Symmetric limits centre the diverging colormap on 0.
    X = data['X'].detach().cpu().numpy() if torch.is_tensor(data['X']) else np.asarray(data['X'])
    X_shown = X[:N_COMPARE_SAMPLES]
    vmax = float(np.abs(X_shown).max()) if X_shown.size else 0.0
    vmax = vmax if vmax > 0 else 1.0
    im = axes[0, idx].imshow(X_shown.T, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax, interpolation='nearest')
    axes[0, idx].set_title(f'{title}\nGenerated Data (First {N_COMPARE_SAMPLES} samples)', fontweight='bold')
    axes[0, idx].set_xlabel('Sample Index')
    axes[0, idx].set_ylabel('Feature Dimension')
    fig.colorbar(im, ax=axes[0, idx])

    # Bottom row: level-0 feature activations, if any.
    if len(data['feature_activations']) > 0:
        activations = data['feature_activations'][0]
        if torch.is_tensor(activations):
            activations = activations.detach().cpu().numpy()
        im = axes[1, idx].imshow(activations[:N_COMPARE_SAMPLES].T, aspect='auto', cmap='Blues', interpolation='nearest')
        axes[1, idx].set_title(f'{title}\nLevel 0 Feature Activations', fontweight='bold')
        axes[1, idx].set_xlabel('Sample Index')
        axes[1, idx].set_ylabel('Feature Index')
        fig.colorbar(im, ax=axes[1, idx])
    else:
        axes[1, idx].text(0.5, 0.5, 'No feature activations\navailable',
                          ha='center', va='center', transform=axes[1, idx].transAxes, fontsize=12)
        axes[1, idx].set_title(f'{title}\nNo Feature Activations', fontweight='bold')

fig.tight_layout()
plt.show()