# Muscle Arm RNN Control

A biologically-inspired neural network controller for a MuJoCo muscle-driven arm.

## Architecture Overview

### Neural Network Modules (`models/modules/`)
- **SensoryModule**: Biologically-inspired proprioceptive neurons
  - Type Ia: Velocity-sensitive (muscle spindle primary)
  - Type II: Length-sensitive (muscle spindle secondary)  
  - Type Ib: Force-sensitive (Golgi tendon organ)
- **TargetEncoder**: Gaussian-tuned spatial grid for target position
- **RNNCore / MLPCore**: Main processing module (RNNCore has a recurrent layer; MLPCore is feedforward)
- **MotorModule**: Alpha MNs + Gamma static/dynamic outputs
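
The three afferent types differ only in which mechanical quantity drives them. The sketch below is a minimal rate model written for illustration, not the `SensoryModule` implementation. The gains, the rest length, and the way the gamma outputs scale spindle sensitivity (gamma dynamic raising Ia velocity sensitivity, gamma static raising II length sensitivity, a simplified version of the physiological picture) are all assumptions.

```python
import numpy as np


def afferent_rates(length, velocity, force, gamma_static, gamma_dynamic,
                   rest_length=1.0, k_ia=1.0, k_ii=1.0, k_ib=1.0):
    """Hypothetical rate model for the three proprioceptive afferent types.

    All inputs are arrays of shape (num_muscles,). Gains, rest length, and the
    multiplicative gamma coupling are illustrative assumptions.
    """
    # Type Ia (spindle primary): stretch velocity, sensitivity scaled by gamma dynamic
    ia = np.maximum(k_ia * velocity * (1.0 + gamma_dynamic), 0.0)
    # Type II (spindle secondary): stretch beyond rest length, scaled by gamma static
    ii = np.maximum(k_ii * (length - rest_length) * (1.0 + gamma_static), 0.0)
    # Type Ib (Golgi tendon organ): muscle force, no gamma input
    ib = np.maximum(k_ib * force, 0.0)
    return ia, ii, ib
```

Rectification reflects that firing rates cannot be negative, so in this sketch a shortening muscle silences its Ia afferent.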

### Controller-Level Features (`models/controllers.py`)
- **Monosynaptic Stretch Reflex Arcs**: Type Ia/II → Alpha MN connections
  - Implemented at controller level (not in motor module)
  - Allows different controllers to have different reflex configurations
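
One plausible way to express a controller-level reflex arc is an additive term on top of the cortical alpha command. The function below is a hypothetical sketch: its name, the additive form, and the [0, 1] clip are assumptions. It is consistent with the forward pass in Section 5 reporting both `alpha_cortical` (before reflex) and `alpha` (final output), and with `reflex_bias=False` in the config.

```python
import numpy as np


def apply_stretch_reflex(alpha_cortical, ia, ii, w_ia, w_ii):
    """Add Ia/II reflex drive to the cortical alpha command (hypothetical sketch).

    alpha_cortical, ia, ii: arrays of shape (num_muscles,)
    w_ia, w_ii: weight matrices of shape (num_muscles, num_muscles),
    rows = alpha MN, columns = afferent. No bias term (reflex_bias=False).
    """
    alpha = alpha_cortical + w_ia @ ia + w_ii @ ii
    # Activations are assumed to live in [0, 1], so clip the summed drive
    return np.clip(alpha, 0.0, 1.0)
```

With diagonal weight matrices each afferent excites only its own muscle's motor neurons (a homonymous reflex); off-diagonal entries would model heteronymous connections to other muscles.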

### Environment (`envs/`)
- **plant.py**: MuJoCo physics interface with XML parsing
- **reaching.py**: Center-out reaching task with phased trials

### Training Methods
1. **CMA-ES**: Black-box evolutionary optimization
2. **Distillation**: MLP teacher → RNN student knowledge transfer

In [None]:
# Setup path
import sys
sys.path.insert(0, '.')  # Adjust if notebook is not in project root

# Core imports
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
import pickle

%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 6]

## 1. Model Information

Parse the MuJoCo XML and list the model's joints, muscles, and sensors, along with the network dimensions derived from them.

In [None]:
from envs.plant import parse_mujoco_xml, get_model_dimensions

XML_PATH = 'mujoco/arm.xml'

parsed = parse_mujoco_xml(XML_PATH)
dims = get_model_dimensions(parsed)

print(f"Model: {parsed.model_name}")
print(f"Timestep: {parsed.timestep}s")
print(f"\nJoints ({parsed.num_joints}):")
for j in parsed.joints:
    print(f"  - {j.name}: {j.joint_type}, range={j.range}")

print(f"\nMuscles ({parsed.num_muscles}):")
for m in parsed.muscles:
    print(f"  - {m.name}: force={m.force}N")

print(f"\nSensors ({parsed.num_sensors}):")
for s in parsed.sensors:
    print(f"  - {s.name}: {s.sensor_type}")

print(f"\nNetwork Dimensions:")
for k, v in dims.items():
    print(f"  {k}: {v}")

## 2. Sensor Calibration

Run random episodes to gather sensor statistics for normalization.
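
A common way to turn calibration rollouts into normalization is per-channel z-scoring. This sketch assumes the statistics are stored under `'mean'` and `'std'` keys; the keys `calibrate_sensors` actually returns are whatever the cell below prints, and it may use a different scheme (for example min/max scaling).

```python
import numpy as np


def compute_sensor_stats(episodes):
    """episodes: list of arrays, each (num_steps, num_sensors) of raw readings."""
    data = np.concatenate(episodes, axis=0)
    return {'mean': data.mean(axis=0), 'std': data.std(axis=0)}


def normalize(obs, stats, eps=1e-8):
    # z-score each sensor channel; eps keeps constant channels finite
    return (obs - stats['mean']) / (stats['std'] + eps)
```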

In [None]:
from envs.plant import calibrate_sensors
from envs.reaching import ReachingEnv

# Run calibration (takes ~30 seconds)
print("Calibrating sensors...")
sensor_stats = calibrate_sensors(XML_PATH, num_episodes=50, max_steps=200)

print("\nSensor Statistics:")
for k, v in sensor_stats.items():
    print(f"  {k}: {v}")

# Save for later use
with open('sensor_stats.pkl', 'wb') as f:
    pickle.dump(sensor_stats, f)
print("\nSaved to sensor_stats.pkl")

## 3. Test Environment

The environment provides the raw target XYZ position (not encoded).
Target encoding is done by the controller's TargetEncoder module.
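
A Gaussian-tuned grid turns a 3-D position into a population code: each unit responds most strongly when the target is near its preferred location. The sketch below mirrors the `grid_size=4, sigma=0.25` settings used in Section 4, but the workspace bounds (`lo`, `hi`), tiling only the XY plane, and ignoring Z are assumptions, not details taken from `TargetEncoder`.

```python
import numpy as np


def gaussian_grid_encode(targets, grid_size=4, sigma=0.25, lo=-0.5, hi=0.5):
    """Encode (batch, 3) XYZ targets as (batch, grid_size**2) Gaussian activations."""
    centers_1d = np.linspace(lo, hi, grid_size)
    cx, cy = np.meshgrid(centers_1d, centers_1d, indexing='xy')
    centers = np.stack([cx.ravel(), cy.ravel()], axis=1)  # (grid_size**2, 2)
    xy = np.asarray(targets, dtype=float)[:, :2]           # drop Z (assumed planar task)
    sq_dist = ((xy[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))
```

A target exactly on a grid center drives that unit to 1.0, and neighbouring units respond in proportion to their distance, which gives the network a smooth code instead of three raw coordinates.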

In [None]:
# Create environment
env = ReachingEnv(XML_PATH, sensor_stats=sensor_stats)

print(f"Observation space: {env.observation_space.shape}")
print(f"  - Proprioceptive: {env.num_muscles * 3} (length, velocity, force per muscle)")
print(f"  - Target: 3 (raw XYZ position)")
print(f"Action space: {env.action_space.shape}")
print(f"  - Alpha MN activations only")

# Run a test episode with random actions
obs, info = env.reset()
print(f"\nInitial info: {info}")

total_reward = 0
for step in range(200):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    
    if step % 50 == 0:
        print(f"Step {step}: phase={info['phase']}, distance={info['distance_to_target']:.3f}")
    
    if terminated or truncated:
        break

print(f"\nEpisode finished after {step+1} steps, total reward: {total_reward:.2f}")
env.close()

## 4. Explore Modules

Let's look at the individual neural network modules.

In [None]:
from models.modules import SensoryModule, MotorModule, TargetEncoder, RNNCore

num_muscles = parsed.num_muscles
num_target_units = 16  # 4x4 grid

# Sensory module
sensory = SensoryModule(num_muscles=num_muscles, use_bias=True)
print(f"SensoryModule:")
print(f"  - Type Ia (velocity): {num_muscles} units")
print(f"  - Type II (length): {num_muscles} units")
print(f"  - Type Ib (force): {num_muscles} units")
print(f"  - Output: {num_muscles * 3} total")

# Target encoder (converts XYZ to Gaussian grid)
target_encoder = TargetEncoder(grid_size=4, sigma=0.25)
print(f"\nTargetEncoder:")
print(f"  - Grid: 4x4 = 16 units")
print(f"  - Sigma: 0.25")

# Test encoding a target
test_target = torch.tensor([[0.1, 0.2, 0.0]])
encoded = target_encoder.encode(test_target)
print(f"  - Input: {test_target.shape}")
print(f"  - Output: {encoded.shape}")

# Motor module (cortical outputs only - reflex is at controller level)
motor = MotorModule(input_size=32, num_muscles=num_muscles)
print(f"\nMotorModule (cortical pathway):")
print(f"  - Alpha MN: {num_muscles} outputs")
print(f"  - Gamma static: {num_muscles} outputs")
print(f"  - Gamma dynamic: {num_muscles} outputs")
print(f"\nNote: Reflex arcs (Ia→Alpha, II→Alpha) are at controller level")

## 5. Create Controller

Initialize the RNN controller and inspect its architecture.
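
`count_parameters()` most likely sums the element counts of the trainable tensors (an assumption). The arithmetic below shows where those counts come from for the two layer types involved; the RNN formula follows the parameter layout of a single-layer `torch.nn.RNN` (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`).

```python
def linear_params(in_features, out_features, bias=True):
    # Weight matrix (out, in) plus an optional bias vector (out,)
    return out_features * in_features + (out_features if bias else 0)


def rnn_params(input_size, hidden_size, bias=True):
    # Single-layer Elman RNN: weight_ih (H, I), weight_hh (H, H),
    # and two bias vectors b_ih, b_hh of size H each
    h = hidden_size
    return h * input_size + h * h + (2 * h if bias else 0)
```

With `rnn_hidden_size=32`, the recurrent matrix alone contributes 32 × 32 = 1,024 weights, so hidden size quickly dominates the count. The input size of the RNN core in this codebase isn't shown, so the controller's printed total will differ from any single formula here.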

In [None]:
from core.config import ModelConfig
from models.controllers import RNNController, MLPController, create_controller

# Create model config
config = ModelConfig(
    num_muscles=parsed.num_muscles,
    num_sensors=parsed.num_sensors,
    num_target_units=16,  # 4x4 grid
    rnn_hidden_size=32,
    rnn_type='rnn',
    # Bias settings
    proprioceptive_bias=True,
    target_encoding_bias=True,
    output_bias=True,
    reflex_bias=False,  # Reflexes typically don't have bias
)

print("Config:")
print(f"  num_muscles: {config.num_muscles}")
print(f"  num_sensors: {config.num_sensors}")
print(f"  num_target_units: {config.num_target_units}")
print(f"  rnn_hidden_size: {config.rnn_hidden_size}")

# Create RNN controller
rnn_controller = RNNController(config)
print(f"\nRNN Controller: {rnn_controller.count_parameters():,} parameters")

# Create MLP controller (for comparison)
mlp_controller = MLPController(config)
print(f"MLP Controller: {mlp_controller.count_parameters():,} parameters")

# Show reflex weights location
print("\nReflex weights (at controller level):")
for name, param in rnn_controller.named_parameters():
    if 'to_alpha' in name:
        print(f"  - {name}: {param.shape}")

In [None]:
# Test forward pass
batch_size = 1
obs_dim = config.num_sensors + 3  # sensors + raw XYZ target
dummy_obs = torch.randn(batch_size, obs_dim)

rnn_controller.eval()
rnn_controller.init_hidden(batch_size, torch.device('cpu'))

with torch.no_grad():
    action, hidden, info = rnn_controller(dummy_obs)

print(f"Input observation: {dummy_obs.shape}")
print(f"Output action (alpha MN): {action.shape}")
print(f"\nInfo dict keys: {list(info.keys())}")
print(f"  - alpha: {info['alpha'].shape} (final output)")
print(f"  - alpha_cortical: {info['alpha_cortical'].shape} (before reflex)")
print(f"  - gamma_static: {info['gamma_static'].shape}")
print(f"  - gamma_dynamic: {info['gamma_dynamic'].shape}")
print(f"  - sensory_outputs: {list(info['sensory_outputs'].keys())}")
print(f"  - rnn_hidden: {info['rnn_hidden'].shape}")

## 6. Training Method 1: CMA-ES

CMA-ES (Covariance Matrix Adaptation Evolution Strategy) is a black-box optimization method that evolves network weights without computing gradients.

**Advantages:**
- Works with non-differentiable objectives
- Good exploration of parameter space
- Parallelizable across CPUs

**Note**: Set `use_multiprocessing=False` for Jupyter compatibility, or run from command line for better performance.
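
To make the sample/evaluate/update loop concrete, here is a deliberately simplified evolution strategy. It keeps CMA-ES's log-rank recombination weights but omits covariance matrix and step-size adaptation, so it is an illustration of the loop `run_cmaes_training` presumably runs, not its implementation. Full CMA-ES libraries such as pycma expose the same ask/tell structure. The `popsize=32` and `sigma=0.1` defaults mirror the cell below.

```python
import numpy as np


def simple_es(loss_fn, x0, sigma=0.1, popsize=32, generations=200, seed=0):
    """Minimize loss_fn with a (mu/mu_w, lambda) evolution strategy.

    Simplification: fixed isotropic step size, no covariance or step-size
    adaptation, so this is NOT full CMA-ES.
    """
    rng = np.random.default_rng(seed)
    mean = np.asarray(x0, dtype=float)
    mu = popsize // 2
    # Log-rank recombination weights (as in CMA-ES), normalized to sum to 1
    weights = np.log(mu + 0.5) - np.log(np.arange(1, mu + 1))
    weights /= weights.sum()
    best_x, best_loss = mean.copy(), loss_fn(mean)
    for _ in range(generations):
        # "ask": sample a population around the current mean
        pop = mean + sigma * rng.standard_normal((popsize, mean.size))
        # "tell": evaluate each candidate (e.g. by running episodes); no gradients
        losses = np.array([loss_fn(x) for x in pop])
        order = np.argsort(losses)
        if losses[order[0]] < best_loss:
            best_x, best_loss = pop[order[0]].copy(), losses[order[0]]
        # Move the mean to the weighted average of the best mu candidates
        mean = weights @ pop[order[:mu]]
    return mean, best_x, best_loss
```

In the notebook's setting, `loss_fn` would load the flat vector into the controller's weights and return the negative episode reward (an assumed sign convention), which is why no gradients are required.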

In [None]:
from training.train_cmaes import run_cmaes_training

# CMA-ES Training
cmaes_results = run_cmaes_training(
    xml_path=XML_PATH,
    output_dir='outputs/cmaes',
    num_generations=100,        # Increase for better results (500+ recommended)
    population_size=32,
    sigma_init=0.1,
    use_multiprocessing=False,  # Set False for Jupyter
    calibration_episodes=30,
    save_checkpoint_every=25,
    inspection_every=25,        # Generate inspection plots periodically
)

print(f"\n{'='*50}")
print(f"CMA-ES Training Complete!")
print(f"{'='*50}")
print(f"Best fitness: {cmaes_results['best_fitness']:.2f}")
print(f"Generations: {cmaes_results['generations']}")
print(f"Time: {cmaes_results['total_time']:.1f}s")

## 7. Training Method 2: Distillation

Distillation training uses a two-phase approach:
1. Train an MLP "teacher" network using behavioral cloning
2. Train an RNN "student" to imitate the teacher's behavior

**Advantages:**
- Leverages gradient-based optimization (faster convergence)
- Teacher can use larger capacity
- Student RNN learns temporal dependencies
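
The core of the student phase is a supervised imitation loss: the student sees the same observations as the teacher and is trained to reproduce the teacher's actions. The sketch below keeps only that idea. The teacher is a random, fixed two-layer MLP, the student is a linear map rather than an RNN (which would need backpropagation through time), and the loss is mean squared error. All of these are illustrative assumptions, not the configuration of `run_distillation_training`.

```python
import numpy as np

rng = np.random.default_rng(0)
num_states, obs_dim, num_muscles = 500, 6, 4

# Hypothetical teacher: a fixed two-layer MLP mapping observations to activations
w1 = rng.standard_normal((16, obs_dim)) * 0.5
w2 = rng.standard_normal((num_muscles, 16)) * 0.5


def teacher(obs):
    return 1.0 / (1.0 + np.exp(-(np.tanh(obs @ w1.T) @ w2.T)))


def distill_linear_student(obs, lr=0.5, steps=500):
    """Fit a linear student to teacher actions by gradient descent on MSE."""
    targets = teacher(obs)                   # teacher labels; the teacher is frozen
    w = np.zeros((targets.shape[1], obs.shape[1]))
    b = np.zeros(targets.shape[1])
    losses = []
    for _ in range(steps):
        err = obs @ w.T + b - targets        # (N, M) imitation error
        losses.append(float(np.mean(err ** 2)))
        scale = 2.0 / err.size               # derivative of the mean squared error
        w -= lr * scale * err.T @ obs
        b -= lr * scale * err.sum(axis=0)
    return w, b, losses


obs = rng.standard_normal((num_states, obs_dim))
w, b, losses = distill_linear_student(obs)
```

Because this student is linear and the loss is quadratic, gradient descent decreases the loss monotonically; a recurrent student gives up that convexity in exchange for carrying state across time steps.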

In [None]:
from training.train_distillation import run_distillation_training

# Distillation Training
distill_results = run_distillation_training(
    xml_path=XML_PATH,
    output_dir='outputs/distillation',
    teacher_epochs=50,          # MLP teacher training
    student_epochs=100,         # RNN student training
    calibration_episodes=30,
)

print(f"\n{'='*50}")
print(f"Distillation Training Complete!")
print(f"{'='*50}")
print(f"Teacher success rate: {distill_results.get('teacher_success_rate', 0):.1%}")
print(f"Student success rate: {distill_results.get('student_success_rate', 0):.1%}")

## 8. Load and Compare Both Controllers

Load both trained controllers and compare their performance.

In [None]:
from utils.visualization import load_controller, evaluate_controller

# Paths to trained controllers
CMAES_PATH = 'outputs/cmaes/best_controller_final.pt'
DISTILL_PATH = 'outputs/distillation/student_rnn.pt'

# Load CMA-ES controller
cmaes_controller, cmaes_config, cmaes_ckpt = load_controller(CMAES_PATH)
cmaes_controller.eval()
print(f"CMA-ES Controller: {cmaes_controller.count_parameters():,} params")
print(f"  Fitness: {cmaes_ckpt.get('fitness', 'N/A')}")
print(f"  Generation: {cmaes_ckpt.get('generation', 'N/A')}")

# Load Distillation controller
distill_controller, distill_config, distill_ckpt = load_controller(DISTILL_PATH)
distill_controller.eval()
print(f"\nDistillation Controller: {distill_controller.count_parameters():,} params")

# Use the CMA-ES run's sensor stats for both controllers (each run calibrates separately, so stats may differ slightly)
sensor_stats_path = Path(CMAES_PATH).parent / 'sensor_stats.pkl'
with open(sensor_stats_path, 'rb') as f:
    sensor_stats = pickle.load(f)
print(f"\nLoaded sensor stats from {sensor_stats_path}")

## 9. Quantitative Comparison

Evaluate both controllers on the same set of episodes.
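
The summary table reads `success_rate`, `mean_reward`, and `std_reward` for each label. A hypothetical helper that produces those numbers from per-episode results could look like this; how `compare_controllers` defines success, and whether it reports the population or sample standard deviation, are assumptions (`np.std` defaults to the population form).

```python
import numpy as np


def summarize_episodes(rewards, successes):
    """Aggregate per-episode totals into the metrics printed in the summary table."""
    rewards = np.asarray(rewards, dtype=float)
    successes = np.asarray(successes, dtype=bool)
    return {
        'success_rate': float(successes.mean()),
        'mean_reward': float(rewards.mean()),
        'std_reward': float(rewards.std()),
    }
```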

In [None]:
from utils.visualization import compare_controllers

# Side-by-side comparison
comparison = compare_controllers(
    controller_paths=[CMAES_PATH, DISTILL_PATH],
    labels=['CMA-ES', 'Distillation'],
    xml_path=XML_PATH,
    sensor_stats=sensor_stats,
    num_episodes=50
)

# Summary table
labels = ['CMA-ES', 'Distillation']
print("\n" + "=" * 60)
print("COMPARISON SUMMARY")
print("=" * 60)
print(f"{'Metric':<25} {labels[0]:>15} {labels[1]:>15}")
print("-" * 60)
for metric, key, fmt in [('Success Rate', 'success_rate', '.1%'),
                         ('Mean Reward', 'mean_reward', '.2f'),
                         ('Std Reward', 'std_reward', '.2f')]:
    values = " ".join(f"{comparison[label][key]:>15{fmt}}" for label in labels)
    print(f"{metric:<25} {values}")

In [None]:
import numpy as np

# Compare Ia→Alpha and II→Alpha reflex weights between the two controllers
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

def plot_reflex_weights(controller, row, label):
    """Plot Ia→Alpha (left column) and II→Alpha (right column) weights on one row."""
    for name, param in controller.named_parameters():
        if 'weight' not in name:
            continue
        if 'Ia_to_alpha' in name:
            col, pathway = 0, 'Ia'
        elif 'II_to_alpha' in name:
            col, pathway = 1, 'II'
        else:
            continue
        w = param.detach().cpu().numpy()
        # Symmetric color limits so white means zero on the diverging colormap
        vmax = max(float(np.abs(w).max()), 1e-8)
        im = axes[row, col].imshow(w, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='auto')
        axes[row, col].set_title(f'{label}: {pathway} → Alpha')
        plt.colorbar(im, ax=axes[row, col])

plot_reflex_weights(cmaes_controller, 0, 'CMA-ES')
plot_reflex_weights(distill_controller, 1, 'Distillation')

for ax in axes.flat:
    ax.set_xlabel('Sensory Neuron (muscle)')
    ax.set_ylabel('Alpha MN (muscle)')

plt.suptitle('Stretch Reflex Connections Comparison', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 10. Episode Recording and Analysis

Record episodes from both controllers and compare their behavior.
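
A simple quantitative complement to the plots is the per-muscle RMS difference between the two controllers' motor traces. This is a hypothetical helper: it assumes trajectories are `(steps, muscles)` arrays like `cmaes_trajectory['alpha']`, and it compares only the steps both episodes reached, since they may terminate at different times.

```python
import numpy as np


def per_muscle_rms_difference(traj_a, traj_b):
    """RMS difference between two (steps, muscles) activation traces.

    Only the overlapping prefix of the two episodes is compared.
    """
    a, b = np.asarray(traj_a, float), np.asarray(traj_b, float)
    n = min(len(a), len(b))
    return np.sqrt(np.mean((a[:n] - b[:n]) ** 2, axis=0))
```

For example, `per_muscle_rms_difference(cmaes_trajectory['alpha'], distill_trajectory['alpha'])` gives one number per muscle; large values point to muscles the two controllers recruit differently.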

In [None]:
from utils.visualization import record_episode

# Record episodes with same seed for fair comparison
SEED = 42

# CMA-ES episode
cmaes_trajectory = record_episode(
    controller=cmaes_controller,
    xml_path=XML_PATH,
    sensor_stats=sensor_stats,
    max_steps=300,
    seed=SEED,
)
cmaes_reward = sum(cmaes_trajectory['rewards'])
print(f"CMA-ES Episode: {len(cmaes_trajectory['rewards'])} steps, reward={cmaes_reward:.2f}")

# Distillation episode
distill_trajectory = record_episode(
    controller=distill_controller,
    xml_path=XML_PATH,
    sensor_stats=sensor_stats,
    max_steps=300,
    seed=SEED,
)
distill_reward = sum(distill_trajectory['rewards'])
print(f"Distillation Episode: {len(distill_trajectory['rewards'])} steps, reward={distill_reward:.2f}")

In [None]:
# Compare motor outputs (alpha MN, gamma static, gamma dynamic) side by side
fig, axes = plt.subplots(3, 2, figsize=(16, 12), sharex=True)

# CMA-ES motor outputs
for i in range(cmaes_trajectory['alpha'].shape[1]):
    axes[0, 0].plot(cmaes_trajectory['alpha'][:, i], label=f'M{i+1}')
axes[0, 0].set_ylabel('Alpha MN')
axes[0, 0].set_title('CMA-ES Controller')
axes[0, 0].legend(loc='upper right')

for i in range(cmaes_trajectory['gamma_static'].shape[1]):
    axes[1, 0].plot(cmaes_trajectory['gamma_static'][:, i])
axes[1, 0].set_ylabel('Gamma Static')

for i in range(cmaes_trajectory['gamma_dynamic'].shape[1]):
    axes[2, 0].plot(cmaes_trajectory['gamma_dynamic'][:, i])
axes[2, 0].set_ylabel('Gamma Dynamic')
axes[2, 0].set_xlabel('Step')

# Distillation motor outputs
for i in range(distill_trajectory['alpha'].shape[1]):
    axes[0, 1].plot(distill_trajectory['alpha'][:, i], label=f'M{i+1}')
axes[0, 1].set_title('Distillation Controller')
axes[0, 1].legend(loc='upper right')

for i in range(distill_trajectory['gamma_static'].shape[1]):
    axes[1, 1].plot(distill_trajectory['gamma_static'][:, i])

for i in range(distill_trajectory['gamma_dynamic'].shape[1]):
    axes[2, 1].plot(distill_trajectory['gamma_dynamic'][:, i])
axes[2, 1].set_xlabel('Step')

plt.suptitle('Motor Output Comparison (Same Target)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Compare sensory processing
fig, axes = plt.subplots(3, 2, figsize=(16, 10), sharex=True)

# CMA-ES sensory
axes[0, 0].plot(cmaes_trajectory['sensory_Ia'])
axes[0, 0].set_ylabel('Type Ia\n(velocity)')
axes[0, 0].set_title('CMA-ES Sensory Activity')

axes[1, 0].plot(cmaes_trajectory['sensory_II'])
axes[1, 0].set_ylabel('Type II\n(length)')

axes[2, 0].plot(cmaes_trajectory['sensory_Ib'])
axes[2, 0].set_ylabel('Type Ib\n(force)')
axes[2, 0].set_xlabel('Step')

# Distillation sensory
axes[0, 1].plot(distill_trajectory['sensory_Ia'])
axes[0, 1].set_title('Distillation Sensory Activity')

axes[1, 1].plot(distill_trajectory['sensory_II'])

axes[2, 1].plot(distill_trajectory['sensory_Ib'])
axes[2, 1].set_xlabel('Step')

plt.suptitle('Proprioceptive Sensory Neuron Activity', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 11. Network Activity Visualization

Visualize the neural network as a spatial diagram with colored units representing activation levels.
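
Colored-unit diagrams usually map each unit's activation onto a diverging color scale so that both sign and magnitude are visible. The sketch below is a hypothetical, dependency-free version of that mapping (blue for negative, white for zero, red for positive), similar in spirit to the `RdBu_r` colormap used for the reflex heatmaps; `record_and_save` may use a different colormap and range.

```python
import numpy as np


def activation_to_rgb(values, vmax=1.0):
    """Map activations to RGB on a blue-white-red scale (blue < 0 < red).

    Values are clipped to [-vmax, vmax]; 0 maps to white.
    """
    t = np.clip(np.asarray(values, dtype=float) / vmax, -1.0, 1.0)
    rgb = np.ones(t.shape + (3,))
    pos, neg = t > 0, t < 0
    # Positive activations: fade green and blue, leaving red
    rgb[pos, 1] -= t[pos]
    rgb[pos, 2] -= t[pos]
    # Negative activations: fade red and green, leaving blue
    rgb[neg, 0] += t[neg]
    rgb[neg, 1] += t[neg]
    return rgb
```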

In [None]:
from utils.episode_recorder import record_and_save

# Record CMA-ES with full network visualization
print("Recording CMA-ES episode with network visualization...")
cmaes_data = record_and_save(
    controller=cmaes_controller,
    xml_path=XML_PATH,
    sensor_stats=sensor_stats,
    output_dir='outputs/cmaes/visualizations',
    max_steps=300,
    seed=42,
    fps=30,
)
print(f"  Episode length: {len(cmaes_data.rewards)} steps")
print(f"  Total reward: {sum(cmaes_data.rewards):.2f}")
print(f"  Saved to: outputs/cmaes/visualizations/")

In [None]:
# Record Distillation with full network visualization
print("Recording Distillation episode with network visualization...")
distill_data = record_and_save(
    controller=distill_controller,
    xml_path=XML_PATH,
    sensor_stats=sensor_stats,
    output_dir='outputs/distillation/visualizations',
    max_steps=300,
    seed=42,
    fps=30,
)
print(f"  Episode length: {len(distill_data.rewards)} steps")
print(f"  Total reward: {sum(distill_data.rewards):.2f}")
print(f"  Saved to: outputs/distillation/visualizations/")

## 12. Display Network Activity Frames

Show single frames from the network visualization to see the spatial layout of neural activity.
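
The cell below picks fixed steps (50, 150, 250) and reports whichever phase was active there. To show one frame per phase instead, the first step of each phase can be read from the per-step info dicts. This hypothetical helper assumes each entry of `data.infos` is a dict with a `'phase'` key, which is how `show_frame` reads it.

```python
def phase_onsets(infos):
    """Return {phase: first step index} from a list of per-step info dicts."""
    onsets = {}
    for step, info in enumerate(infos):
        phase = info.get('phase')
        if phase is not None and phase not in onsets:
            onsets[phase] = step
    return onsets
```

Once `show_frame` is defined, `for phase, step in phase_onsets(cmaes_data.infos).items(): show_frame(cmaes_data, step, f'CMA-ES: {phase}')` displays the onset of every phase in order.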

In [None]:
# Display frames at fixed steps; each title reports the phase recorded at that step
def show_frame(data, step, title):
    """Show a single frame from a recorded episode."""
    frames = getattr(data, 'combined_frames', None)
    if frames is None or len(frames) == 0:
        print(f"No combined frames recorded for '{title}'")
        return
    frame_idx = min(step, len(frames) - 1)
    infos = getattr(data, 'infos', None) or []
    phase = infos[frame_idx].get('phase', 'unknown') if frame_idx < len(infos) else 'unknown'
    plt.figure(figsize=(16, 6))
    plt.imshow(frames[frame_idx])
    plt.axis('off')
    plt.title(f'{title} - Step {step} (Phase: {phase})', fontsize=12)
    plt.show()

# Show key moments from CMA-ES episode
print("CMA-ES Controller:")
show_frame(cmaes_data, 50, "CMA-ES: Pre-movement")
show_frame(cmaes_data, 150, "CMA-ES: Mid-reach")
show_frame(cmaes_data, 250, "CMA-ES: Near target")

## 13. Full Checkpoint Inspection

Generate comprehensive inspection plots for a trained controller.

In [None]:
from utils.visualization import inspect_checkpoint

# Full inspection of CMA-ES controller
print("="*60)
print("CMA-ES Controller Inspection")
print("="*60)
inspect_checkpoint(
    checkpoint_path=CMAES_PATH,
    xml_path=XML_PATH,
    output_dir='outputs/cmaes/inspection',
    num_episodes=3,
    max_steps=300,
    show=True
)

## 14. Training History Visualization

Plot training curves to see convergence behavior.
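
Two views usually make convergence easier to read than raw per-generation values: the running best and a trailing moving average. The schema of `history.json` isn't shown here, so this sketch assumes a per-generation list of fitness values has already been extracted and that higher fitness is better; if the history stores a loss, swap `np.maximum.accumulate` for `np.minimum.accumulate`.

```python
import numpy as np


def convergence_curves(fitness, window=10):
    """Running best and trailing moving average of per-generation fitness."""
    fitness = np.asarray(fitness, dtype=float)
    running_best = np.maximum.accumulate(fitness)
    # Trailing mean over up to `window` generations (shorter at the start)
    csum = np.cumsum(np.insert(fitness, 0, 0.0))
    idx = np.arange(1, len(fitness) + 1)
    start = np.maximum(idx - window, 0)
    moving_avg = (csum[idx] - csum[start]) / (idx - start)
    return running_best, moving_avg
```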

In [None]:
from utils.visualization import plot_training_curves

# Plot CMA-ES training history
cmaes_history_path = 'outputs/cmaes/history.json'
if Path(cmaes_history_path).exists():
    plot_training_curves(cmaes_history_path, show=True)
else:
    print(f"Training history not found at {cmaes_history_path}")
    print("Run training first to generate history.")

## 15. Summary and Conclusions

### Key Takeaways

| Aspect | CMA-ES | Distillation |
|--------|--------|--------------|
| **Training** | Black-box, no gradients | Gradient-based, two-phase |
| **Speed** | Slower (evolutionary) | Faster (supervised) |
| **Parallelization** | Excellent (CPU) | Limited (GPU preferred) |
| **Reflex Learning** | Learns from scratch | Learns from teacher |

### Generated Outputs

After running this notebook, you'll have:
- `outputs/cmaes/` - CMA-ES training results, checkpoints, visualizations
- `outputs/distillation/` - Distillation training results
- Network activity videos showing real-time neural processing
- Episode summary plots with phase-shaded regions
- Reflex connection heatmaps

### Command Line Usage

For faster training with multiprocessing:

```bash
# CMA-ES training
python run.py train mujoco/arm.xml --method cmaes \
    --generations 500 --population 64 --workers 8 \
    --inspection-every 25

# Distillation training  
python run.py train mujoco/arm.xml --method distillation \
    --teacher-epochs 100 --student-epochs 200

# Visualization
python run.py visualize mujoco/arm.xml outputs/cmaes/best_controller_final.pt
```