# Tutorial 02 — Train a SOEN Model (ORIGINAL - Control Experiment)

**This notebook uses the ORIGINAL model spec with `J_1_to_2.learnable: false`**

---

## ⚠️ NOISE CONFIGURATION: DISABLED

> **This tutorial runs with NO NOISE INJECTION (ideal conditions).**
>
> | Parameter | Value | Description |
> |-----------|-------|-------------|
> | `phi` | **0.0** | Flux noise (thermal fluctuations) |
> | `s` | **0.0** | State noise (integration errors) |
> | `g` | **0.0** | Source function noise |
> | `bias_current` | **0.0** | Bias current noise |
> | `j` | **0.0** | Connection weight noise |
>
> To enable hardware-realistic noise, modify the `noise` sections in:
> `training/test_models/model_specs/1D_5D_2D_PulseNetSpec.yaml`

---
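Before relying on the "noise disabled" claim, it can help to confirm that the spec file really contains only zeros. The following is a minimal sketch, not part of `soen_toolkit`: the helper `nonzero_noise` and the layout of `example_spec` are hypothetical, and the real YAML may be organized differently. The only assumption the helper makes is that noise values are numbers nested somewhere under a key named `noise`.

```python
def nonzero_noise(node, path="", inside_noise=False):
    """Collect (path, value) pairs for non-zero numbers nested under any 'noise' key."""
    hits = []
    if isinstance(node, dict):
        for key, value in node.items():
            child = f"{path}.{key}" if path else str(key)
            hits += nonzero_noise(value, child, inside_noise or key == "noise")
    elif isinstance(node, list):
        for i, item in enumerate(node):
            hits += nonzero_noise(item, f"{path}[{i}]", inside_noise)
    elif inside_noise and isinstance(node, (int, float)) and not isinstance(node, bool):
        if node != 0:
            hits.append((path, node))
    return hits


# Hypothetical spec layout, for illustration only (not the real file contents).
example_spec = {
    "layers": [
        {"layer_id": 1, "noise": {"phi": 0.0, "s": 0.0, "g": 0.0, "bias_current": 0.0}},
        {"layer_id": 2, "noise": {"phi": 0.01, "s": 0.0}},
    ],
    "connections": [{"name": "J_1_to_2", "noise": {"j": 0.0}}],
}

hits = nonzero_noise(example_spec)
print(hits)  # [('layers[1].noise.phi', 0.01)]
```

To apply it to the real spec, one option is to load the file with PyYAML's `yaml.safe_load` and pass the resulting dictionary to the helper; an empty list would mean every noise value found is zero.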

This notebook serves as a control experiment. The training setup is:
- A higher learning rate (0.001 vs 0.0001)
- More epochs (50 vs 10)
- The same optimizer and scheduler settings

Despite the higher learning rate and the longer schedule, the model **will NOT learn**, because gradient flow is blocked at the output connection.

### Expected Result
- Accuracy should remain ~50% (random guessing)
- Loss should stay near ln(2) ≈ 0.693
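
The ln(2) figure is the cross-entropy of a two-class prediction that assigns 50% to each class. A short check using only the standard library (the helper `cross_entropy` is written here for illustration and is not the loss used by the trainer):

```python
import math


def cross_entropy(probs, target):
    """Cross-entropy (in nats) of one prediction against an integer class label."""
    return -math.log(probs[target])


# A classifier that has learned nothing assigns 50% to each class.
uniform = [0.5, 0.5]

# The loss is the same whichever class is the true one.
loss_class_0 = cross_entropy(uniform, 0)
loss_class_1 = cross_entropy(uniform, 1)

print(f"loss = {loss_class_0:.3f}")  # loss = 0.693
```

A training loss that hovers around this value therefore indicates predictions no better than a coin flip.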

### Compare With
Run `02_train_a_model.ipynb`, which uses the FIXED config with `J_1_to_2.learnable: true`,
to see proper training.

### ML Task Overview
Binary classification on time-series inputs:
- **Class 0**: Single pulse
- **Class 1**: Two distinct pulses
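
To make the two classes concrete, here is a small synthetic sketch. It is not the generator used for the HDF5 dataset: the Gaussian pulse shape, the width, the pulse positions, and the helper names `make_signal` and `count_pulses` are all assumptions chosen for illustration. Only the sequence length of 64 is taken from the dataset filename.

```python
import math


def gaussian_pulse(t, center, width=2.0):
    """Unit-height Gaussian bump centered at `center`."""
    return math.exp(-0.5 * ((t - center) / width) ** 2)


def make_signal(centers, seq_len=64):
    """Sum one pulse per center and clip the result to [0, 1]."""
    return [min(1.0, sum(gaussian_pulse(t, c) for c in centers)) for t in range(seq_len)]


def count_pulses(signal, threshold=0.5):
    """Count upward crossings of the threshold (a hand-written reference rule)."""
    crossings = 0
    for prev, cur in zip(signal, signal[1:]):
        if prev < threshold <= cur:
            crossings += 1
    return crossings


one_pulse = make_signal([20])        # would be labeled Class 0
two_pulses = make_signal([20, 44])   # would be labeled Class 1

print(count_pulses(one_pulse), count_pulses(two_pulses))  # 1 2
```

The threshold counter is the kind of rule the network has to discover from data: the label depends on the number of distinct pulses in the sequence, not on the value at any single time step.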

**Imports**

In [None]:
# Setup: Ensure soen_toolkit is importable
import sys
from pathlib import Path

# Add src directory to path if running from notebook location
notebook_dir = Path.cwd()
for parent in [notebook_dir] + list(notebook_dir.parents):
    candidate = parent / "src"
    if (candidate / "soen_toolkit").exists():
        sys.path.insert(0, str(candidate))
        break

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import glob

from soen_toolkit.training.trainers.experiment import run_from_config

**Visualize the Dataset**

Let's look at examples from each class to understand the task.

In [None]:
# ============================================================================
# VISUALIZE DATASET: One-pulse vs Two-pulse classification
# ============================================================================

def visualize_dataset(data_path="training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5", n_examples=4):
    """Visualize examples from each class in the dataset."""
    
    with h5py.File(data_path, 'r') as f:
        data = np.array(f['train']['data'])
        labels = np.array(f['train']['labels'])
    
    print(f"Dataset shape: {data.shape} (N samples, T timesteps, D features)")
    print(f"Labels shape: {labels.shape}")
    print(f"Class distribution: {np.bincount(labels)}")
    
    # Find examples of each class
    class_0_idx = np.where(labels == 0)[0][:n_examples]
    class_1_idx = np.where(labels == 1)[0][:n_examples]
    
    # squeeze=False keeps axes 2-D, so axes[row, col] also works when n_examples == 1
    fig, axes = plt.subplots(2, n_examples, figsize=(3*n_examples, 5), squeeze=False)
    fig.suptitle("Input Signals: One-Pulse (Class 0) vs Two-Pulse (Class 1)", fontsize=12, fontweight='bold')
    
    # Plot Class 0 (single pulse)
    for i, idx in enumerate(class_0_idx):
        axes[0, i].plot(data[idx, :, 0], 'b-', linewidth=1.5)
        axes[0, i].set_title(f"Sample {idx}", fontsize=10)
        axes[0, i].set_ylim(-0.1, 1.1)
        axes[0, i].grid(True, alpha=0.3)
        if i == 0:
            axes[0, i].set_ylabel("Class 0\n(One Pulse)", fontsize=10)
    
    # Plot Class 1 (two pulses)
    for i, idx in enumerate(class_1_idx):
        axes[1, i].plot(data[idx, :, 0], 'r-', linewidth=1.5)
        axes[1, i].set_title(f"Sample {idx}", fontsize=10)
        axes[1, i].set_ylim(-0.1, 1.1)
        axes[1, i].grid(True, alpha=0.3)
        if i == 0:
            axes[1, i].set_ylabel("Class 1\n(Two Pulses)", fontsize=10)
        axes[1, i].set_xlabel("Time step")
    
    plt.tight_layout()
    plt.show()
    
    return data, labels

# Visualize the dataset
train_data, train_labels = visualize_dataset()

**Training (ORIGINAL CONFIG - Control Experiment)**

This uses the ORIGINAL model spec with `J_1_to_2.learnable: false`.
Even with 50 epochs and a higher learning rate, accuracy should stay at ~50%.

In [None]:
# Launch training with ORIGINAL (broken) config
# Uses: 1D_5D_2D_PulseNetSpec.yaml with J_1_to_2.learnable = FALSE
# Expected: ~50% accuracy (random guessing) due to blocked gradients
run_from_config("training/training_configs/pulse_net_orig_long.yaml", script_dir=Path.cwd())

**Visualize Predictions**

After training, let's see how the model performs. Expected: ~50% accuracy.
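
With only 8 samples, the measured accuracy is noisy even for a model that guesses at random. The sketch below works out how often pure guessing lands on either side of the 0.6 threshold used in the summary printout. It assumes independent, balanced coin-flip predictions; `prob_at_most` is a helper written for this illustration.

```python
from math import comb


def prob_at_most(k, n, p=0.5):
    """P(X <= k) for X ~ Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k + 1))


n_samples = 8
# accuracy < 0.6 on 8 samples means at most 4 correct, since 5/8 = 0.625
p_below = prob_at_most(4, n_samples)

print(f"P(accuracy <  0.6 | random guessing) = {p_below:.3f}")      # 0.637
print(f"P(accuracy >= 0.6 | random guessing) = {1 - p_below:.3f}")  # 0.363
```

Roughly one run in three would show 5 or more correct out of 8 by chance alone, so a single 8-sample result is weak evidence either way. Passing a larger `n_samples` to `visualize_predictions` gives a more reliable estimate.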

In [None]:
# ============================================================================
# VISUALIZE PREDICTIONS: Show model predictions on test samples
# ============================================================================

def visualize_predictions(n_samples=8):
    """Load trained model and visualize predictions on test data."""
    
    from soen_toolkit.core.model_yaml import build_model_from_yaml
    
    # Find the latest checkpoint
    ckpt_patterns = [
        "training/temp/**/checkpoints/**/*.ckpt",
        "training/temp/**/*.ckpt",
    ]
    
    all_ckpts = []
    for pattern in ckpt_patterns:
        all_ckpts.extend(glob.glob(pattern, recursive=True))
    
    if not all_ckpts:
        print("No checkpoint found. Run training first.")
        return
    
    latest_ckpt = max(all_ckpts, key=lambda x: Path(x).stat().st_mtime)
    print(f"Loading checkpoint: {latest_ckpt}")
    
    # Load model architecture and weights - using ORIGINAL spec
    model_path = Path("training/test_models/model_specs/1D_5D_2D_PulseNetSpec.yaml")
    model = build_model_from_yaml(model_path)
    
    # Load trained weights
    ckpt = torch.load(latest_ckpt, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    
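    # Remove 'model.' prefix if present (from Lightning wrapper)
    prefix = 'model.'
    clean_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    
    # strict=False tolerates key mismatches, so report them instead of assuming success
    try:
        result = model.load_state_dict(clean_state_dict, strict=False)
        print("Model weights loaded")
        if result.missing_keys:
            print(f"  Missing keys (left at initial values): {result.missing_keys}")
        if result.unexpected_keys:
            print(f"  Unexpected keys (ignored): {result.unexpected_keys}")
    except Exception as e:
        print(f"Warning loading weights: {e}")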
    
    model.eval()
    
    # Load evaluation data
    data_path = Path("training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5")
    with h5py.File(data_path, 'r') as f:
        # Use validation set if available, otherwise train
        split = 'val' if 'val' in f else 'train'
        test_data = np.array(f[split]['data'][:n_samples])
        test_labels = np.array(f[split]['labels'][:n_samples])
    n_samples = len(test_data)  # the split may hold fewer samples than requested
    
    # Run inference
    with torch.no_grad():
        x = torch.tensor(test_data, dtype=torch.float32)
        output, all_states = model(x)
        
        # Apply max pooling over time (like training)
        if output.dim() == 3:
            pooled = output.max(dim=1)[0]  # [batch, num_classes]
        else:
            pooled = output
        
        # Get predictions
        probs = torch.softmax(pooled, dim=1)
        predictions = torch.argmax(probs, dim=1).numpy()
        confidence = probs.max(dim=1)[0].numpy()
    
    # Visualize
    n_cols = min(4, n_samples)
    n_rows = (n_samples + n_cols - 1) // n_cols
    
    # squeeze=False always returns a 2-D array of axes, including the 1x1 case
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.5*n_cols, 3*n_rows), squeeze=False)
    
    fig.suptitle("Model Predictions (ORIGINAL - Expected ~50% Accuracy)", fontsize=14, fontweight='bold')
    
    class_names = ["One Pulse", "Two Pulses"]
    
    for i in range(n_samples):
        row, col = i // n_cols, i % n_cols
        ax = axes[row, col]
        
        # Plot input signal
        ax.plot(test_data[i, :, 0], 'b-', linewidth=1.5, alpha=0.8)
        ax.set_ylim(-0.1, 1.1)
        ax.grid(True, alpha=0.3)
        
        # Determine if prediction is correct
        is_correct = predictions[i] == test_labels[i]
        color = 'green' if is_correct else 'red'
        symbol = '✓' if is_correct else '✗'
        
        # Title with prediction info
        true_label = class_names[test_labels[i]]
        pred_label = class_names[predictions[i]]
        
        ax.set_title(
            f"{symbol} Pred: {pred_label} ({confidence[i]:.1%})\nTrue: {true_label}",
            fontsize=9,
            color=color,
            fontweight='bold' if not is_correct else 'normal'
        )
        
        if col == 0:
            ax.set_ylabel("Signal")
        if row == n_rows - 1:
            ax.set_xlabel("Time step")
    
    # Hide empty subplots
    for i in range(n_samples, n_rows * n_cols):
        row, col = i // n_cols, i % n_cols
        axes[row, col].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    accuracy = (predictions == test_labels).mean()
    print(f"\n{'='*50}")
    print("PREDICTION SUMMARY (ORIGINAL CONFIG)")
    print(f"{'='*50}")
    print(f"Accuracy on {n_samples} samples: {accuracy:.1%}")
    print(f"Correct: {(predictions == test_labels).sum()}/{n_samples}")
    
    if accuracy < 0.6:
        print("\n⚠️  As expected, accuracy is near chance (~50%).")
        print("   This is consistent with blocked gradient flow.")
        print("   Run 02_train_a_model.ipynb for the fixed version.")

# Visualize predictions
visualize_predictions(n_samples=8)

**Gradient Flow Diagnostic**

Run this to confirm gradients are blocked at `J_1_to_2`.
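
The diagnostic relies on a general PyTorch behavior: a parameter with `requires_grad=False` keeps `.grad` as `None` after `backward()`. The toy sketch below shows how to list such parameters. It assumes that `learnable: false` in the spec maps to `requires_grad=False` on the corresponding parameter, which is not confirmed here; the two-layer `toy` model and the helper `split_by_trainability` are stand-ins for illustration, not SOEN code.

```python
import torch


def split_by_trainability(model):
    """Return (trainable, frozen) lists of parameter names."""
    trainable, frozen = [], []
    for name, param in model.named_parameters():
        (trainable if param.requires_grad else frozen).append(name)
    return trainable, frozen


# Toy stand-in with the same 1 -> 5 -> 2 layer sizes; the second map is frozen.
toy = torch.nn.Sequential(
    torch.nn.Linear(1, 5, bias=False),
    torch.nn.Linear(5, 2, bias=False),
)
toy[1].weight.requires_grad_(False)  # assumed analogue of `learnable: false`

trainable, frozen = split_by_trainability(toy)
print("trainable:", trainable)
print("frozen:   ", frozen)
```

Running the same helper on the model built from the original spec would show which parameters the optimizer is allowed to update at all, independent of any gradient values.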

In [None]:
# ============================================================================
# DIAGNOSTIC: Check gradient flow through the ORIGINAL model
# ============================================================================

def check_gradient_flow():
    """Test if gradients flow through the ORIGINAL SOEN model."""
    
    from soen_toolkit.core.model_yaml import build_model_from_yaml
    
    # Load ORIGINAL model spec
    model_path = Path("training/test_models/model_specs/1D_5D_2D_PulseNetSpec.yaml")
    model = build_model_from_yaml(model_path)
    model.train()
    
    # Load a small batch of data
    data_path = Path("training/datasets/soen_seq_task_one_or_two_pulses_seq64.hdf5")
    with h5py.File(data_path, 'r') as f:
        x = torch.tensor(f['train']['data'][:8], dtype=torch.float32)
        y = torch.tensor(f['train']['labels'][:8], dtype=torch.long)
    
    print("=" * 70)
    print("GRADIENT FLOW DIAGNOSTIC (ORIGINAL MODEL)")
    print("=" * 70)
    
    # Forward pass
    x.requires_grad_(True)
    output, all_states = model(x)
    
    # Apply time pooling
    if output.dim() == 3:
        pooled = output.max(dim=1)[0]
    else:
        pooled = output
    
    # Compute loss and backward
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(pooled, y)
    loss.backward()
    
    # Check gradients
    print(f"\n{'Parameter':<45} {'Grad Norm':<15} {'Has Grad':<10}")
    print("-" * 70)
    
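    frozen_params = []   # requires_grad=False: excluded from training by the spec
    blocked_params = []  # trainable, but backward() produced no gradient
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            has_grad = "✓" if grad_norm > 1e-10 else "✗ (zero)"
            print(f"{name:<45} {grad_norm:<15.8f} {has_grad}")
        elif not param.requires_grad:
            print(f"{name:<45} {'None':<15} ✗ (frozen)")
            frozen_params.append(name)
        else:
            print(f"{name:<45} {'None':<15} ✗ (no grad)")
            blocked_params.append(name)
    
    print("\n" + "=" * 70)
    print("DIAGNOSIS")
    print("=" * 70)
    
    no_grad_params = frozen_params + blocked_params
    if no_grad_params:
        print(f"⚠️  {len(no_grad_params)} parameters have NO gradients:")
        for name in frozen_params:
            print(f"   - {name} (frozen: requires_grad=False)")
        for name in blocked_params:
            print(f"   - {name} (trainable, but received no gradient)")
        print("\n→ These parameters are never updated by the optimizer.")
        print("→ Expected for the original spec, where J_1_to_2.learnable = false")
    else:
        print("✓ All parameters have gradients (unexpected for original model)")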

# Run diagnostic
check_gradient_flow()

---

## Conclusion

This control experiment demonstrates that:

1. **More epochs don't help** when gradients are blocked
2. **A higher learning rate doesn't help** when the relevant weights never receive an update
3. **The issue is architectural** (frozen output connection), not a matter of hyperparameters

To see proper training, run `02_train_a_model.ipynb`, which uses `J_1_to_2.learnable: true`.