# IRED Code Exploration

This notebook explores the IRED (Iterative Reasoning through Energy Diffusion) codebase to understand:
- Optimization loop structure and location
- State variable representation and tensor shapes
- Energy computation functions and their inputs/outputs
- Convergence and success logic
- Evaluation vs training modes
- Key functions to hook for trajectory logging

```python
import sys
import os
sys.path.append('../external/ired')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
```

## 1. Key Model Architecture Analysis

### Energy-Based Models (EBMs)
The core energy computation happens in several EBM classes:

1. **Basic EBM** (`models.py:164-216`): 
   - Input: concatenated input and output tensors, time embedding
   - Energy computation: `output.pow(2).sum(dim=-1)` (L2 norm squared)
   - Architecture: MLP with time-modulated features

2. **SudokuEBM** (`models.py:328-439`):
   - Input: 9x9x9 tensor (spatial + digit channels)
   - Energy: `output.pow(2).sum(dim=1).sum(dim=1).sum(dim=1)` (squared output summed over the channel axis and both spatial axes, leaving one energy per sample)
   - Constraint handling: Built-in Sudoku constraint evaluation

3. **GraphEBM** (`models.py:507-577`):
   - For connectivity and planning tasks
   - Uses graph neural networks (GNN) for spatial reasoning
   - Energy: L2 distance between predicted and target graph states
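The basic EBM and SudokuEBM reductions differ only in which axes they collapse. The following shape check on random tensors illustrates this. The output shapes (`(4, 16)` for the MLP, `(4, 9, 9, 9)` for the Sudoku network) are assumptions chosen for illustration, not values read from `models.py`:

```python
import torch

# Dummy network outputs; shapes are assumptions for illustration only.
batch_size = 4
mlp_out = torch.randn(batch_size, 16)          # basic EBM: (batch, hidden)
conv_out = torch.randn(batch_size, 9, 9, 9)    # Sudoku EBM: (batch, channels, 9, 9)

# Basic EBM: squared L2 norm of the output vector -> one energy per sample
energy_basic = mlp_out.pow(2).sum(dim=-1)

# SudokuEBM: square, then sum over channels, rows and columns.
# Each .sum(dim=1) removes axis 1, so three calls collapse (C, H, W).
energy_sudoku = conv_out.pow(2).sum(dim=1).sum(dim=1).sum(dim=1)

print(energy_basic.shape, energy_sudoku.shape)  # torch.Size([4]) torch.Size([4])

# The chained sums equal a single reduction over dims 1-3
assert torch.allclose(energy_sudoku, conv_out.pow(2).sum(dim=(1, 2, 3)))
```

Both variants return a `(batch_size,)` tensor, so downstream code can treat energies uniformly across domains.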

## 2. Optimization Loop Location and Structure

**Primary optimization loop**: `diffusion_lib/denoising_diffusion_pytorch_1d.py:373-406` (`opt_step` method)

Key characteristics:
- **Gradient descent on energy landscape**: `energy, grad = self.model(inp, img, t, return_both=True)`
- **Step size control**: `img_new = img - extract(self.opt_step_size, t, grad.shape) * grad * sf`
- **Energy-based step acceptance**: Rejects steps that increase energy
- **Constraint handling**: Masks for fixed/known variables
- **Clamping**: Values clamped to valid ranges based on diffusion schedule

### Iteration Structure:
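```python
for i in range(step):  # step=5 by default
    energy, grad = self.model(inp, img, t, return_both=True)
    img_new = img - extract(self.opt_step_size, t, grad.shape) * grad * sf
    # (masking of known variables and clamping are applied here)
    energy_new = self.model(inp, img_new, t, return_energy=True)
    # Energies are per-sample, so reject bad steps elementwise; a plain
    # `if energy_new > energy` fails for batch sizes > 1
    bad_step = (energy_new > energy).reshape(-1)  # (batch_size,) bool
    img_new[bad_step] = img[bad_step]  # keep old state for rejected samples
    img = img_new.detach() if eval else img_new
```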
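The following self-contained toy version of the same accept/reject loop replaces the learned EBM with a simple quadratic energy. The energy function, the step sizes, and the `toy_opt_step` name are illustrative and do not come from the IRED code. The third sample's step size is large enough that every step would increase its energy, so all of its updates are rejected:

```python
import torch

def energy_fn(x, target):
    # Toy per-sample energy: squared distance to a target, shape (batch,)
    return (x - target).pow(2).sum(dim=-1)

def toy_opt_step(x, target, step_size=0.1, steps=5):
    """Gradient descent with per-sample rejection of energy-increasing steps."""
    for _ in range(steps):
        x = x.detach().requires_grad_(True)
        energy = energy_fn(x, target)
        grad = torch.autograd.grad(energy.sum(), x)[0]
        with torch.no_grad():
            x_new = x - step_size * grad
            bad_step = energy_fn(x_new, target) > energy  # (batch,) bool
            x_new[bad_step] = x[bad_step]                  # keep old state where rejected
        x = x_new
    return x.detach()

torch.manual_seed(0)
target = torch.zeros(3, 2)
x0 = torch.randn(3, 2)
# Per-sample step sizes; for this energy any step size above 1.0 overshoots
step_size = torch.tensor([[0.1], [0.1], [1.5]])
x_final = toy_opt_step(x0, target, step_size=step_size)
print(energy_fn(x0, target))
print(energy_fn(x_final, target))  # first two shrink, third is unchanged
```

For this energy the update scales the offset from the target by `1 - 2 * step_size`. A step size of 0.1 therefore shrinks the energy by a factor of 0.64 per step. A step size of 1.5 would quadruple it, so the rejection rule freezes that sample.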

## 3. State Variable Representation

State representations vary by domain:

### Continuous Tasks (Addition, Inverse, LowRank):
- **Shape**: `(batch_size, feature_dim)` where feature_dim varies by task
- **Type**: `torch.float32` continuous values
- **Range**: Typically normalized to `[-1, 1]` or `[0, 1]`

### Sudoku:
- **Shape**: `(batch_size, 729)` = `(batch_size, 9*9*9)`
- **Representation**: One-hot encoded, reshaped to `(batch_size, 9, 9, 9)`
- **Channels**: 9 channels for digits 1-9 at each of 81 positions
- **Values**: `[-1, +1]` (centered one-hot)
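Here is a minimal sketch of the centered one-hot encoding and its inverse. It assumes a `(row, col, digit)` flattening order and a fully filled grid. IRED's actual axis order, and how it represents blank cells, may differ. `encode_sudoku` and `decode_sudoku` are hypothetical helpers:

```python
import numpy as np

def encode_sudoku(grid):
    """Centered one-hot encoding of a filled 9x9 grid with digits 1-9.

    Assumed layout: (row, col, digit), flattened to 729 values in {-1, +1}.
    """
    grid = np.asarray(grid)
    one_hot = np.eye(9, dtype=np.float32)[grid - 1]   # (9, 9, 9)
    return (2.0 * one_hot - 1.0).reshape(-1)          # (729,)

def decode_sudoku(state):
    """Map a (729,) state back to digits via argmax over the digit axis."""
    return state.reshape(9, 9, 9).argmax(axis=-1) + 1

# A valid solved grid built from a standard shifting pattern
base = np.array([[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)]
                 for r in range(9)])
state = encode_sudoku(base)
print(state.shape)               # (729,)
print(decode_sudoku(state)[0])   # [1 2 3 4 5 6 7 8 9]
```

Decoding with `argmax` also works for intermediate, non-binary optimization states, which is useful when inspecting logged trajectories.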

### Graph Tasks (Connectivity, Planning):
- **Shape**: `(batch_size, n_nodes, node_features)` or flattened versions
- **Type**: Mixed continuous/discrete depending on task
- **Planning**: Often includes spatial coordinates and connectivity matrices

## 4. Energy Computation Functions

### Key Energy Functions to Hook:

1. **`DiffusionWrapper.forward()`** (`models.py:799-812`):
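   ```python
   def forward(self, inp, opt_out, t, return_energy=False, return_both=False):
       # Make the candidate output differentiable so the energy gradient can be taken
       opt_out.requires_grad_(True)
       # Concatenate conditioning input and candidate output (axis assumed: last dim)
       opt_variable = torch.cat([inp, opt_out], dim=-1)
       energy = self.ebm(opt_variable, t)
       if return_energy:
           return energy
       # Gradient of the summed energy w.r.t. the output; create_graph=True
       # keeps the gradient itself differentiable for training
       opt_grad = torch.autograd.grad([energy.sum()], [opt_out], create_graph=True)[0]
       if return_both:
           return energy, opt_grad
       return opt_grad
   ```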

2. **`GaussianDiffusion1D.opt_step()`** (`diffusion_lib/denoising_diffusion_pytorch_1d.py:373-406`):
   - **Hook point for trajectories**: This is where we can log `(img, energy, grad, step_i)`
   - **Input state**: `img` tensor (current optimization state)
   - **Energy**: `energy` tensor holding one energy value per batch element (not a single scalar)
   - **Gradient**: `grad` tensor (same shape as img)

3. **Individual EBM classes** for domain-specific energy:
   - `SudokuEBM.forward()`: Sudoku constraint satisfaction
   - `GraphEBM.forward()`: Graph reasoning tasks
   - `EBM.forward()`: General continuous optimization
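The `return_energy` / `return_both` pattern in `DiffusionWrapper.forward()` can be reproduced in isolation. In this minimal sketch, `ToyEBM` and `ToyWrapper` are hypothetical stand-ins (a small MLP with a made-up time feature), not IRED classes. Summing the energies before calling `torch.autograd.grad` is valid because each sample's energy depends only on its own output. The gradient of the sum is therefore the stacked per-sample gradient:

```python
import torch
import torch.nn as nn

class ToyEBM(nn.Module):
    """Stand-in energy model: MLP on [inp, out] plus a scalar time feature."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2 * dim + 1, 32), nn.SiLU(), nn.Linear(32, 8))

    def forward(self, x, t):
        h = torch.cat([x, t.float().unsqueeze(-1)], dim=-1)
        return self.net(h).pow(2).sum(dim=-1)     # (batch,) energies

class ToyWrapper(nn.Module):
    """Mimics the return_energy / return_both convention of DiffusionWrapper."""
    def __init__(self, ebm):
        super().__init__()
        self.ebm = ebm

    def forward(self, inp, opt_out, t, return_energy=False, return_both=False):
        with torch.enable_grad():                 # autograd.grad needs grad mode
            if not opt_out.requires_grad:
                opt_out.requires_grad_(True)
            energy = self.ebm(torch.cat([inp, opt_out], dim=-1), t)
            if return_energy:
                return energy
            grad = torch.autograd.grad([energy.sum()], [opt_out], create_graph=True)[0]
        return (energy, grad) if return_both else grad

torch.manual_seed(0)
model = ToyWrapper(ToyEBM(dim=4))
inp, out = torch.randn(2, 4), torch.randn(2, 4)
t = torch.tensor([10, 10])
energy, grad = model(inp, out, t, return_both=True)
print(energy.shape, grad.shape)  # torch.Size([2]) torch.Size([2, 4])
```

The per-sample argument holds only when the energy network has no cross-batch interactions (for example, batch normalization in training mode would break it).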

## 5. Convergence and Success Logic

### Energy-Based Convergence:
- **Step rejection**: Steps that increase energy are rejected
- **Fixed iterations**: Most tasks use a fixed number of optimization steps (typically 5)
- **No explicit convergence threshold**: Relies on diffusion schedule and step budget

### Success Evaluation (Task-Specific):
- **Sudoku**: Check constraint satisfaction (row/column/box uniqueness)
- **Connectivity**: Verify graph connectivity properties
- **Planning**: Check if goal state is reached
- **Continuous**: Distance to target within tolerance
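A row/column/box uniqueness check for Sudoku could look like the following minimal sketch. It assumes the predicted state has already been decoded to a 9x9 grid of digits (for example, by `argmax` over the digit channel). `is_valid_sudoku` is a hypothetical helper, not the IRED dataset's evaluation code:

```python
import numpy as np

def is_valid_sudoku(grid):
    """Return True if a 9x9 grid of digits 1-9 satisfies all Sudoku constraints."""
    grid = np.asarray(grid)
    if grid.shape != (9, 9):
        return False
    digits = np.arange(1, 10)
    rows_ok = all(np.array_equal(np.sort(row), digits) for row in grid)
    cols_ok = all(np.array_equal(np.sort(col), digits) for col in grid.T)
    # (box_row, in_row, box_col, in_col) -> (box_row, box_col, in_row, in_col)
    boxes = grid.reshape(3, 3, 3, 3).transpose(0, 2, 1, 3).reshape(9, 9)
    boxes_ok = all(np.array_equal(np.sort(box), digits) for box in boxes)
    return rows_ok and cols_ok and boxes_ok

base = np.array([[(3 * (r % 3) + r // 3 + c) % 9 + 1 for c in range(9)]
                 for r in range(9)])
bad = base.copy()
bad[0, 0], bad[0, 1] = bad[0, 1], bad[0, 0]   # row stays valid, columns break
print(is_valid_sudoku(base), is_valid_sudoku(bad))  # True False
```

A full success check would also verify that the given clues are preserved in the predicted grid.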

### Key Evaluation Points:
1. **During training**: `supervise_energy_landscape` mode adds a contrastive energy loss
2. **During sampling**: `p_sample_loop()` method for full trajectory generation
3. **Task evaluation**: Domain-specific success metrics in dataset classes

## 6. Training vs Evaluation Modes

### Training Mode:
- **File**: `diffusion_lib/denoising_diffusion_pytorch_1d.py:705` (forward method)
- **Supervision**: Both a denoising loss and energy-landscape supervision
- **Energy landscape loss**: Contrastive estimation with data vs noise samples
- **Gradient computation**: Full backpropagation through optimization steps
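One common way to implement "contrastive estimation with data vs noise samples" is a softmax over negated energies, with the data sample treated as the correct class. This is a generic sketch under that assumption and not necessarily IRED's exact loss. How the negatives are generated (for example, by corrupting data samples with noise) is not covered here, and `contrastive_energy_loss` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def contrastive_energy_loss(energy_pos, energy_neg):
    """Softmax-style contrastive loss over energies.

    energy_pos: (batch,) energies of data samples
    energy_neg: (batch, k) energies of k negative (corrupted) samples
    Logits are -energy and the data sample is class 0, so minimizing the
    loss pushes data energy down relative to the negatives.
    """
    logits = -torch.cat([energy_pos.unsqueeze(-1), energy_neg], dim=-1)  # (batch, k+1)
    target = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, target)

good = contrastive_energy_loss(torch.tensor([0.0, 0.0]), torch.tensor([[5.0], [5.0]]))
bad = contrastive_energy_loss(torch.tensor([5.0, 5.0]), torch.tensor([[0.0], [0.0]]))
print(good.item() < bad.item())  # True
```

With one negative, the loss reduces to `log(1 + exp(E_pos - E_neg))`. It is small when data energy is well below the negative's and grows roughly linearly once the ordering is reversed.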

### Evaluation Mode:
- **File**: `diffusion_lib/denoising_diffusion_pytorch_1d.py:408` (`p_sample_loop` method)
- **Trajectory generation**: Full diffusion process from noise to solution
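- **No gradient tracking across steps**: The optimization steps call `.detach()` on the updated state, so no computation graph is carried between steps. This saves memory but does not make sampling deterministic, since sampling still starts from random noise.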
- **Metrics**: Task-specific success rates and solution quality

## 7. Key Functions to Hook for Trajectory Logging

### Primary Hook Point:
**`GaussianDiffusion1D.opt_step()`** - Lines 373-406
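```python
def opt_step(self, inp, img, t, mask, data_cond, step=5, eval=True, sf=1.0, detach=True):
    # HOOK HERE: log the initial state
    trajectory_states = [img.detach().cpu().numpy()]
    trajectory_energies = []
    trajectory_grads = []

    # autograd.grad needs grad mode enabled, e.g. if sampling runs under torch.no_grad()
    with torch.enable_grad():
        for i in range(step):
            energy, grad = self.model(inp, img, t, return_both=True)
            # HOOK HERE: log (energy, grad) at step i, before the update
            trajectory_energies.append(energy.detach().cpu().numpy())
            trajectory_grads.append(grad.detach().cpu().numpy())

            img_new = img - extract(self.opt_step_size, t, grad.shape) * grad * sf
            # ... masking, clamping and energy-based step rejection as in the original ...
            img = img_new.detach() if eval else img_new

            # HOOK HERE: log the accepted state (after rejection), not the proposal
            trajectory_states.append(img.detach().cpu().numpy())

    # Hypothetical attribute for retrieving the trajectory after sampling
    self.last_trajectory = (trajectory_states, trajectory_energies, trajectory_grads)
    return img
```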

### Secondary Hook Points:
1. **`p_sample_loop()`**: Full diffusion trajectory (multiple timesteps)
2. **Model-specific `forward()` methods**: Domain-specific energy computations
3. **Training loop**: `Trainer1D.train()` for training dynamics
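As an alternative to editing `opt_step()` in the library source, the model's `forward` can be patched on the instance so that every `return_both=True` call is recorded. This minimal sketch assumes the `forward(inp, opt_out, t, return_energy=..., return_both=...)` convention described above. `attach_energy_recorder` and `DummyModel` are hypothetical. It records every caller that requests both outputs, not only `opt_step()`:

```python
import torch
import torch.nn as nn

def attach_energy_recorder(model, records):
    """Patch `model.forward` on this instance so return_both calls are logged.

    Returns the original bound forward so it can be restored later.
    """
    original_forward = model.forward

    def recording_forward(inp, opt_out, t, return_energy=False, return_both=False):
        out = original_forward(inp, opt_out, t,
                               return_energy=return_energy, return_both=return_both)
        if return_both:
            energy, grad = out
            records.append({
                "state": opt_out.detach().cpu().clone(),
                "energy": energy.detach().cpu().clone(),
                "grad": grad.detach().cpu().clone(),
            })
        return out

    # Instance attributes shadow the class method, and nn.Module.__call__
    # dispatches to self.forward, so every model(...) call goes through here.
    model.forward = recording_forward
    return original_forward

class DummyModel(nn.Module):
    """Stand-in with the same calling convention; energy = ||x||^2."""
    def forward(self, inp, opt_out, t, return_energy=False, return_both=False):
        energy = opt_out.pow(2).sum(dim=-1)
        if return_energy:
            return energy
        grad = 2 * opt_out
        return (energy, grad) if return_both else grad

records = []
model = DummyModel()
original = attach_energy_recorder(model, records)
x = torch.ones(2, 3)
model(None, x, None, return_both=True)    # recorded
model(None, x, None, return_energy=True)  # not recorded
print(len(records))  # 1
model.forward = original  # restore the unpatched behaviour
```

Patching the instance rather than wrapping it in a new module keeps all other attributes of the model reachable unchanged.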

### Logging Data Structure:
```python
trajectory_data = {
    'states': np.array(states),          # (n_steps + 1, batch_size, state_dim), includes initial state
    'energies': np.array(energies),      # (n_steps, batch_size)
    'gradients': np.array(gradients),    # (n_steps, batch_size, state_dim)
    'timesteps': np.array(timesteps),    # (batch_size,) diffusion timestep
    'success': np.array(success_flags),  # (batch_size,) task completion
    'metadata': {
        'task': 'sudoku',  # or 'connectivity', 'planning', etc.
        'problem_difficulty': difficulty_scores,
        'convergence_step': final_steps
    }
}
```
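The per-step lists collected by a logging hook can be stacked into this layout with a small helper. The following is a minimal sketch on fake data. `build_trajectory_data` is hypothetical, and the metadata fields other than `task` are omitted. `states` has one more entry than `energies` because the initial state is logged before the first step:

```python
import numpy as np

def build_trajectory_data(states, energies, gradients, timesteps, success_flags, task):
    """Stack per-step lists into the trajectory_data layout (hypothetical helper).

    states:     n_steps + 1 arrays of shape (batch_size, state_dim), initial state first
    energies:   n_steps arrays of shape (batch_size,) or (batch_size, 1)
    gradients:  n_steps arrays of shape (batch_size, state_dim)
    """
    return {
        "states": np.stack(states),                                       # (n_steps + 1, B, D)
        "energies": np.stack([np.asarray(e).reshape(-1) for e in energies]),  # (n_steps, B)
        "gradients": np.stack(gradients),                                 # (n_steps, B, D)
        "timesteps": np.asarray(timesteps),                               # (B,)
        "success": np.asarray(success_flags, dtype=bool),                 # (B,)
        "metadata": {"task": task},
    }

rng = np.random.default_rng(0)
B, D, n_steps = 4, 729, 5
states = [rng.standard_normal((B, D)) for _ in range(n_steps + 1)]
energies = [rng.random((B, 1)) for _ in range(n_steps)]
grads = [rng.standard_normal((B, D)) for _ in range(n_steps)]
data = build_trajectory_data(states, energies, grads, np.full(B, 10),
                             np.zeros(B, dtype=bool), task="sudoku")
print(data["states"].shape, data["energies"].shape)  # (6, 4, 729) (5, 4)
```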

## 8. Next Steps for Trajectory Collection

1. **Instrument `opt_step()` method**: Add logging hooks in the optimization loop
2. **Choose target domain**: Sudoku is well-structured and interpretable
3. **Run evaluation**: Generate trajectories on test problems
4. **Save trajectory data**: Store as `.npz` files for manifold analysis
5. **Preprocess for dimensionality reduction**: Flatten state tensors appropriately
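Steps 4 and 5 can be sketched with NumPy alone. The sketch saves to a compressed `.npz` file, reloads, flattens Sudoku-shaped states to one row per (step, sample) pair, and projects with PCA computed via SVD. PCA here is only a stand-in for whichever dimensionality-reduction method is eventually chosen, the data is random, and the helper names are hypothetical:

```python
import os
import tempfile
import numpy as np

def save_trajectories(path, states, energies):
    """Save trajectories as a compressed .npz archive."""
    np.savez_compressed(path, states=states, energies=energies)

def load_flat_states(path):
    """Load states and flatten (n_steps + 1, B, *state_shape) -> (N, D).

    Rows are step-major: row i is step i // B of sample i % B.
    """
    with np.load(path) as archive:
        states = archive["states"]
    n_steps_plus_one, batch_size = states.shape[:2]
    return states.reshape(n_steps_plus_one * batch_size, -1), batch_size

def pca_project(points, k=2):
    """Project rows onto the top-k principal components via SVD."""
    centered = points - points.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:k].T

rng = np.random.default_rng(0)
states = rng.standard_normal((6, 4, 9, 9, 9))   # fake (n_steps + 1, B, 9, 9, 9) states
energies = rng.random((5, 4))
path = os.path.join(tempfile.mkdtemp(), "sudoku_traj.npz")
save_trajectories(path, states, energies)
flat, batch_size = load_flat_states(path)
proj = pca_project(flat)
print(flat.shape, proj.shape)  # (24, 729) (24, 2)
```

Keeping the step-major row order makes it easy to color projected points by optimization step or to redraw per-sample paths.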