# PyTorch Optimizer Comparison on SOEN SingleDendrite Model

This notebook compares different PyTorch optimizers on a **real SOEN model** with `SingleDendrite` neurons.

## Model Architecture
- **Input Layer**: 1D flux input (`Input` layer type, non-physical)
- **Hidden Layer**: 3 SingleDendrite neurons (physical SOEN neurons)
- **Output Layer**: 2D readout (`Input` layer type, non-dynamic; one unit per class)

## Task
Binary classification: distinguish single-pulse vs double-pulse input signals.

## Optimizers Compared
1. **SGD** - Vanilla Stochastic Gradient Descent
2. **SGD + Momentum** - SGD with momentum
3. **Adam** - Adaptive Moment Estimation
4. **AdamW** - Adam with decoupled weight decay
5. **RMSprop** - Root Mean Square Propagation
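For reference, the update rules behind these optimizers can be written out by hand. The sketch below is not part of the SOEN toolkit. It applies each rule to a toy quadratic loss $L(p) = \tfrac{1}{2}\lVert p \rVert^2$, whose gradient is simply $p$, and checks the result against `torch.optim`. It assumes PyTorch defaults for everything not listed: no dampening or Nesterov for SGD, `alpha=0.99` and `centered=False` for RMSprop, and `betas=(0.9, 0.999)`, `eps=1e-8` for Adam/AdamW. The helper names `manual_updates` and `torch_updates` are illustrative.

```python
import torch

def manual_updates(name, p0, lr, steps=5, momentum=0.9, betas=(0.9, 0.999),
                   alpha=0.99, eps=1e-8, weight_decay=0.01):
    """Apply `steps` hand-written updates of the named rule to L(p) = 0.5 * ||p||^2."""
    p = p0.clone()
    buf = torch.zeros_like(p)   # SGD momentum buffer
    m = torch.zeros_like(p)     # Adam first moment
    v = torch.zeros_like(p)     # Adam / RMSprop second moment
    for t in range(1, steps + 1):
        g = p.clone()           # gradient of 0.5 * ||p||^2 is p itself
        if name == "sgd":
            p = p - lr * g
        elif name == "momentum":
            buf = momentum * buf + g
            p = p - lr * buf
        elif name in ("adam", "adamw"):
            if name == "adamw":
                p = p * (1 - lr * weight_decay)   # decoupled weight decay
            m = betas[0] * m + (1 - betas[0]) * g
            v = betas[1] * v + (1 - betas[1]) * g * g
            m_hat = m / (1 - betas[0] ** t)       # bias correction
            v_hat = v / (1 - betas[1] ** t)
            p = p - lr * m_hat / (v_hat.sqrt() + eps)
        elif name == "rmsprop":
            v = alpha * v + (1 - alpha) * g * g
            p = p - lr * g / (v.sqrt() + eps)
        else:
            raise ValueError(f"unknown optimizer: {name!r}")
    return p

def torch_updates(opt_class, kwargs, p0, steps=5):
    """Same toy problem, stepped with a real torch.optim optimizer."""
    p = p0.clone().requires_grad_(True)
    opt = opt_class([p], **kwargs)
    for _ in range(steps):
        opt.zero_grad()
        loss = 0.5 * (p ** 2).sum()
        loss.backward()
        opt.step()
    return p.detach()

CASES = {
    "sgd": (torch.optim.SGD, {"lr": 0.1}),
    "momentum": (torch.optim.SGD, {"lr": 0.1, "momentum": 0.9}),
    "adam": (torch.optim.Adam, {"lr": 0.1}),
    "adamw": (torch.optim.AdamW, {"lr": 0.1, "weight_decay": 0.01}),
    "rmsprop": (torch.optim.RMSprop, {"lr": 0.1}),
}

p0 = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
for name, (cls, kw) in CASES.items():
    diff = (torch_updates(cls, kw, p0) - manual_updates(name, p0, lr=kw["lr"])).abs().max()
    print(f"{name:<8} max |torch - manual| = {diff.item():.1e}")
```

All differences should sit at floating-point round-off level. The Adam branch makes the per-parameter normalization by $\sqrt{\hat v}$ explicit, which is why Adam's step size stays near `lr` whatever the scale of the raw gradient.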

---

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from soen_toolkit.core import (
    ConnectionConfig,
    LayerConfig,
    SimulationConfig,
    SOENModelCore,
)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

## 1. Generate Synthetic Pulse Data

Create a binary classification dataset:
- **Class 0**: Single pulse
- **Class 1**: Double pulse
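The two classes differ only in how many times the input rises above a threshold, which suggests the task should be easy for a handful of neurons. As a non-SOEN sanity check, the sketch below counts rising edges above half the pulse amplitude. It hard-codes the same pulse positions, width (5), amplitude (0.2) and noise level (0.02) as the generator in the next cell; `count_pulses` is a hypothetical helper, not part of `soen_toolkit`.

```python
import torch

def count_pulses(signal: torch.Tensor, threshold: float) -> int:
    """Count rising edges: steps where the signal goes from <= threshold to > threshold."""
    above = signal > threshold
    rising = above[1:] & ~above[:-1]
    # A pulse that is already high at t = 0 also counts as one pulse
    return int(rising.sum().item()) + int(above[0].item())

seq_len, width, amp, noise_std = 50, 5, 0.2, 0.02

# Clean templates matching the generator: one pulse at seq_len // 4,
# two pulses at seq_len // 5 and 3 * seq_len // 5
single = torch.zeros(seq_len)
single[seq_len // 4 : seq_len // 4 + width] = amp
double = torch.zeros(seq_len)
double[seq_len // 5 : seq_len // 5 + width] = amp
double[3 * seq_len // 5 : 3 * seq_len // 5 + width] = amp

torch.manual_seed(0)
labels = [i % 2 for i in range(100)]   # alternating classes, as in generate_pulse_data
correct = 0
for label in labels:
    noisy = (double if label == 1 else single) + noise_std * torch.randn(seq_len)
    predicted = count_pulses(noisy, threshold=amp / 2) - 1   # 1 pulse -> class 0, 2 pulses -> class 1
    correct += int(predicted == label)

accuracy = correct / len(labels)
print(f"Threshold-counting baseline accuracy: {accuracy:.0%}")
```

With the threshold five noise standard deviations away from both 0 and the pulse amplitude, this baseline should be almost always correct, so differences between optimizers below mostly reflect how easily each one finds working J weights rather than task difficulty.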

In [None]:
def generate_pulse_data(n_samples=100, seq_len=50, noise_std=0.02):
    """
    Generate single-pulse vs double-pulse classification data.
    
    Returns:
        X: [n_samples, seq_len, 1] - input flux sequences
        y: [n_samples] - labels (0=single, 1=double)
    """
    X = []
    y = []
    
    pulse_width = 5
    pulse_amplitude = 0.2
    
    for i in range(n_samples):
        signal = torch.zeros(seq_len)
        
        if i % 2 == 0:  # Single pulse (Class 0)
            start = seq_len // 4
            signal[start:start+pulse_width] = pulse_amplitude
            y.append(0)
        else:  # Double pulse (Class 1)
            start1 = seq_len // 5
            start2 = 3 * seq_len // 5
            signal[start1:start1+pulse_width] = pulse_amplitude
            signal[start2:start2+pulse_width] = pulse_amplitude
            y.append(1)
        
        # Add noise
        signal += noise_std * torch.randn(seq_len)
        X.append(signal.unsqueeze(-1))  # [seq_len, 1]
    
    X = torch.stack(X)  # [n_samples, seq_len, 1]
    y = torch.tensor(y, dtype=torch.long)
    
    return X, y

# Generate data
N_SAMPLES = 100
SEQ_LEN = 50
X_data, y_data = generate_pulse_data(N_SAMPLES, SEQ_LEN)

print(f"Input shape: {X_data.shape}")
print(f"Labels shape: {y_data.shape}")
print(f"Class distribution: {torch.bincount(y_data)}")

# Visualize samples
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    idx = 2 * (i % 3) + i // 3  # Top row: single pulses (0, 2, 4); bottom row: double pulses (1, 3, 5)
    ax.plot(X_data[idx, :, 0].numpy())
    ax.set_title(f"Sample {idx}: {'Single' if y_data[idx]==0 else 'Double'} Pulse")
    ax.set_xlabel('Time step')
    ax.set_ylabel('Flux')
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 2. Build SOEN Model with SingleDendrite Neurons

Architecture: 1D → 3D (SingleDendrite) → 2D (Output)

- **Learnable parameters**: Connection weights (J matrices)
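- **SingleDendrite dynamics**: $\dfrac{ds}{dt} = \gamma^{+}\, g(\phi) - \gamma^{-} s$, with $\gamma^{\pm}$ set by `gamma_plus`/`gamma_minus` and $g$ chosen by `source_func`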
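To get a feel for what `solver="FE"` (presumably forward Euler) does with this equation, here is a standalone sketch that integrates $s_{t+1} = s_t + \Delta t\,(\gamma^{+} g(\phi_t) - \gamma^{-} s_t)$ using the `dt`, `gamma_plus` and `gamma_minus` values from `build_soen_model`. The source function `toy_source` is a hypothetical smooth stand-in, not the toolkit's `Heaviside_fit_state_dep` (whose name suggests it also depends on the state). The flux is assumed to be simply `J * x` with a fixed, hypothetical weight of 0.2 and no `phi_offset`.

```python
import torch

def toy_source(phi: torch.Tensor) -> torch.Tensor:
    # HYPOTHETICAL stand-in for the toolkit's source function g(phi):
    # non-negative, zero at phi = 0, saturating for large |phi|
    return torch.tanh(10.0 * phi.abs())

def euler_single_dendrite(phi_seq, dt=100.0, gamma_plus=0.00013, gamma_minus=1e-8):
    """Forward-Euler integration of ds/dt = gamma_plus * g(phi) - gamma_minus * s.

    phi_seq: [seq_len] flux sequence. Returns [seq_len + 1]; index 0 is s(0) = 0.
    """
    s = torch.zeros(())
    history = [s]
    for phi_t in phi_seq:
        s = s + dt * (gamma_plus * toy_source(phi_t) - gamma_minus * s)
        history.append(s)
    return torch.stack(history)

seq_len, width, amp, J = 50, 5, 0.2, 0.2   # J: assumed input-to-hidden weight
single = torch.zeros(seq_len)
single[12:12 + width] = amp
double = torch.zeros(seq_len)
double[10:10 + width] = amp
double[30:30 + width] = amp

s_single = euler_single_dendrite(J * single)
s_double = euler_single_dendrite(J * double)
ratio = s_double[-1] / s_single[-1]
print(f"final s: single={s_single[-1].item():.4f}, "
      f"double={s_double[-1].item():.4f}, ratio={ratio.item():.3f}")
```

With $\gamma^{-}\Delta t = 10^{-6}$ per step, the leak is negligible over 50 steps, so under these assumptions $s$ behaves almost like an integrator of $g(\phi)$ and a double pulse leaves roughly twice the final state of a single pulse.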

In [None]:
def build_soen_model(hidden_dim=3, output_dim=2, dt=100.0):
    """
    Build a simple SOEN model for binary classification.
    
    Architecture: 1 (input) → hidden_dim (SingleDendrite) → output_dim (readout)
    """
    sim_cfg = SimulationConfig(
        dt=dt,
        input_type="state",  # Input directly as state
        track_phi=False,
        track_power=False,
    )
    
    # Layer 0: Input (1D flux input)
    layer0 = LayerConfig(
        layer_id=0,
        layer_type="Input",
        description="Input flux",
        params={"dim": 1},
    )
    
    # Layer 1: Hidden SingleDendrite neurons
    layer1 = LayerConfig(
        layer_id=1,
        layer_type="SingleDendrite",
        description="Hidden SOEN neurons",
        params={
            "dim": hidden_dim,
            "solver": "FE",
            "source_func": "Heaviside_fit_state_dep",
            "phi_offset": 0.02,
            "bias_current": 1.98,
            "gamma_plus": 0.00013,
            "gamma_minus": 1e-8,
        },
    )
    
    # Layer 2: Output (readout)
    layer2 = LayerConfig(
        layer_id=2,
        layer_type="Input",  # Non-dynamic readout
        description="Output readout",
        params={"dim": output_dim},
    )
    
    layers = [layer0, layer1, layer2]
    
    # Connection 0→1: Input to hidden (learnable)
    conn01 = ConnectionConfig(
        from_layer=0,
        to_layer=1,
        connection_type="all_to_all",
        learnable=True,
        params={
            "init": "uniform",
            "min": 0.1,
            "max": 0.24,
            "constraints": {"min": -0.24, "max": 0.24},
        },
    )
    
    # Connection 1→2: Hidden to output (learnable)
    conn12 = ConnectionConfig(
        from_layer=1,
        to_layer=2,
        connection_type="all_to_all",
        learnable=True,
        params={
            "init": "uniform",
            "min": -0.5,
            "max": 0.5,
        },
    )
    
    connections = [conn01, conn12]
    
    model = SOENModelCore(
        sim_config=sim_cfg,
        layers_config=layers,
        connections_config=connections,
    )
    
    return model

# Build and inspect model
test_model = build_soen_model()
print("Model Structure:")
print(f"  Layers: {[l.dim for l in test_model.layers]}")
print(f"  Trainable parameters: {sum(p.numel() for p in test_model.parameters() if p.requires_grad)}")

# List trainable parameters
print("\nTrainable Parameters:")
for name, param in test_model.named_parameters():
    if param.requires_grad:
        print(f"  {name}: {param.shape}")

In [None]:
# Visualize the model architecture
test_model.visualize(show_descriptions=True, theme="modern")

## 3. Define Training Loop

Train the SOEN model and track loss/accuracy over epochs.
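The loop below rescales gradients with `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` before each optimizer step. A minimal sketch of what that call does, on two hand-made gradients: it computes one global L2 norm across all parameters, returns that pre-clipping norm, and, if it exceeds `max_norm`, scales every gradient by the same factor so the update direction is preserved.

```python
import torch

# Two parameters with hand-set gradients; their global L2 norm is sqrt(3^2 + 4^2) = 5
w1 = torch.nn.Parameter(torch.zeros(3))
w2 = torch.nn.Parameter(torch.zeros(4))
w1.grad = torch.tensor([3.0, 0.0, 0.0])
w2.grad = torch.tensor([0.0, 4.0, 0.0, 0.0])

# Returns the total norm measured *before* clipping
total_norm = torch.nn.utils.clip_grad_norm_([w1, w2], max_norm=1.0)

# After clipping, the global norm is (just under) 1.0 and the 3:4 ratio is unchanged
clipped_norm = torch.sqrt(w1.grad.pow(2).sum() + w2.grad.pow(2).sum())
print(f"before: {total_norm.item():.3f}, after: {clipped_norm.item():.3f}, "
      f"ratio: {(w1.grad[0] / w2.grad[1]).item():.3f}")
```

Because the whole gradient vector is scaled together, clipping directly shrinks plain SGD's step, while its effect is partly undone by the per-parameter normalization in Adam and RMSprop.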

In [None]:
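def train_soen_model(optimizer_class, optimizer_kwargs, X_data, y_data,
                     n_epochs=50, batch_size=16, seed=0, verbose=False):
    """
    Train a SOEN model with the specified optimizer.
    
    The global RNG is re-seeded on every call so that each optimizer starts
    from the same initial weights and the same shuffling sequence.
    
    Returns:
        Dict with per-epoch losses and training accuracies, final values,
        and the trained model
    """
    # Re-seed for a fair, reproducible comparison across optimizers
    torch.manual_seed(seed)
    
    # Build fresh model
    model = build_soen_model()
    model.train()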
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)
    
    # Tracking
    losses = []
    accuracies = []
    
    n_samples = len(X_data)
    
    for epoch in range(n_epochs):
        # Shuffle data
        perm = torch.randperm(n_samples)
        X_shuffled = X_data[perm]
        y_shuffled = y_data[perm]
        
        epoch_loss = 0.0
        correct = 0
        total = 0
        
        for i in range(0, n_samples, batch_size):
            X_batch = X_shuffled[i:i+batch_size]
            y_batch = y_shuffled[i:i+batch_size]
            
            optimizer.zero_grad()
            
            # Forward pass through SOEN model
            # Returns (final_history, all_histories)
            final_history, _ = model(X_batch)
            
            # Time pooling: take mean over sequence
            # final_history shape: [batch, seq_len+1, output_dim]
            output = final_history[:, 1:, :].mean(dim=1)  # [batch, output_dim]
            
            # Compute loss
            loss = criterion(output, y_batch)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            
            epoch_loss += loss.item() * len(X_batch)
            _, predicted = output.max(1)
            correct += (predicted == y_batch).sum().item()
            total += len(y_batch)
        
        avg_loss = epoch_loss / n_samples
        accuracy = correct / total
        losses.append(avg_loss)
        accuracies.append(accuracy)
        
        if verbose and (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={accuracy:.2%}")
    
    return {
        'losses': losses,
        'accuracies': accuracies,
        'final_loss': losses[-1],
        'final_accuracy': accuracies[-1],
        'model': model,
    }
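`train_soen_model` builds a fresh model on every call, so its initial J weights (presumably drawn from PyTorch's global RNG) and the `torch.randperm` shuffles depend on whatever state the RNG is in at that moment. Unless the RNG is re-seeded per call, each optimizer in the comparison starts from a different initialization. The sketch below shows the effect with a hypothetical `build_toy_model`, a plain `nn.Linear` standing in for `build_soen_model`.

```python
import torch
import torch.nn as nn

def build_toy_model() -> nn.Module:
    # HYPOTHETICAL stand-in for build_soen_model(): any module with random initial weights
    return nn.Linear(1, 3)

# One global seed, consecutive builds: the second model starts from different weights
torch.manual_seed(42)
first = build_toy_model()
second = build_toy_model()
print("same init without re-seeding:", torch.equal(first.weight, second.weight))

def build_seeded(seed: int) -> nn.Module:
    torch.manual_seed(seed)   # reset the RNG right before drawing initial weights
    return build_toy_model()

third, fourth = build_seeded(0), build_seeded(0)
print("same init with re-seeding:   ", torch.equal(third.weight, fourth.weight))
```

With only 100 samples, a single run per optimizer is noisy; a natural extension is to repeat each configuration over several seeds and report the mean and spread of the final accuracy.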

## 4. Compare Optimizers on SOEN Training

Train the SingleDendrite network with different optimizers.

In [None]:
# Optimizer configurations
OPTIMIZERS = {
    'SGD (lr=0.01)': (torch.optim.SGD, {'lr': 0.01}),
    'SGD (lr=0.1)': (torch.optim.SGD, {'lr': 0.1}),
    'SGD + Momentum (lr=0.01)': (torch.optim.SGD, {'lr': 0.01, 'momentum': 0.9}),
    'Adam (lr=0.01)': (torch.optim.Adam, {'lr': 0.01}),
    'Adam (lr=0.001)': (torch.optim.Adam, {'lr': 0.001}),
    'AdamW (lr=0.01, wd=0.01)': (torch.optim.AdamW, {'lr': 0.01, 'weight_decay': 0.01}),  # 0.01 is PyTorch's default, made explicit
    'RMSprop (lr=0.01)': (torch.optim.RMSprop, {'lr': 0.01}),
}

N_EPOCHS = 50
BATCH_SIZE = 16

# Train with each optimizer
results = {}
for name, (opt_class, opt_kwargs) in OPTIMIZERS.items():
    print(f"Training with {name}...")
    results[name] = train_soen_model(
        opt_class, opt_kwargs, X_data, y_data,
        n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, verbose=False
    )
    print(f"  Final: Loss={results[name]['final_loss']:.4f}, "
          f"Acc={results[name]['final_accuracy']:.2%}")

print("\nTraining complete!")

## 5. Visualize Training Curves

In [None]:
# Plot loss and accuracy curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors = plt.cm.tab10(np.linspace(0, 1, len(results)))

# Loss curves
ax1 = axes[0]
for (name, res), color in zip(results.items(), colors):
    ax1.plot(res['losses'], label=name, linewidth=2, color=color)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Cross-Entropy Loss')
ax1.set_title('Training Loss (SOEN SingleDendrite Model)')
ax1.legend(loc='upper right', fontsize=9)
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2 = axes[1]
for (name, res), color in zip(results.items(), colors):
    ax2.plot(res['accuracies'], label=name, linewidth=2, color=color)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Accuracy (SOEN SingleDendrite Model)')
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random')
ax2.legend(loc='lower right', fontsize=9)  # Called after axhline so 'Random' appears in the legend
ax2.set_ylim([0.4, 1.05])
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Log-Scale Loss Comparison

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

for (name, res), color in zip(results.items(), colors):
    ax.plot(res['losses'], label=name, linewidth=2, color=color)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss (log scale)', fontsize=12)
ax.set_yscale('log')
ax.set_title('Training Loss Comparison (Log Scale)', fontsize=14)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, which='both')

plt.tight_layout()
plt.show()

## 7. Convergence Speed Analysis

In [None]:
def epochs_to_accuracy(accuracies, threshold):
    """Return number of epochs to reach accuracy threshold."""
    for i, acc in enumerate(accuracies):
        if acc >= threshold:
            return i + 1
    return None

# Analyze convergence
thresholds = [0.6, 0.7, 0.8, 0.9, 0.95]

print("Epochs to reach accuracy threshold:")
print("=" * 80)
header = f"{'Optimizer':<25}" + "".join(f"{'Acc>=' + format(t, '.0%'):<12}" for t in thresholds)  # Same 12-char columns as the rows
print(header)
print("-" * 80)

for name, res in results.items():
    row = f"{name:<25}"
    for thresh in thresholds:
        epochs = epochs_to_accuracy(res['accuracies'], thresh)
        row += f"{str(epochs) if epochs else 'N/A':<12}"
    print(row)
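`epochs_to_accuracy` reports the first epoch whose training accuracy reaches the threshold, so a single lucky epoch counts even if accuracy drops again afterwards. A hypothetical stricter variant, `epochs_to_sustained_accuracy`, requires the threshold to hold for `patience` consecutive epochs. The accuracy list below is made-up illustrative input, not a result from this notebook.

```python
def epochs_to_sustained_accuracy(accuracies, threshold, patience=3):
    """Return the 1-based epoch that starts the first run of `patience` consecutive
    epochs with accuracy >= threshold, or None if that never happens."""
    run = 0  # length of the current streak of epochs at or above threshold
    for epoch, acc in enumerate(accuracies, start=1):
        run = run + 1 if acc >= threshold else 0
        if run == patience:
            return epoch - patience + 1   # first epoch of the streak
    return None

# Illustrative accuracy curve (made-up numbers): one early spike, then a steady climb
accs = [0.50, 0.92, 0.60, 0.91, 0.93, 0.95, 0.97]
print(epochs_to_sustained_accuracy(accs, 0.9))   # 4 (a first-crossing rule would report 2)
```

Calling it in place of `epochs_to_accuracy` in the table loop would make the convergence table less sensitive to epoch-to-epoch noise in training accuracy.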

## 8. Summary Statistics

In [None]:
import pandas as pd

# Create summary table
summary_data = []
for name, res in results.items():
    # Find best accuracy epoch
    best_acc_epoch = np.argmax(res['accuracies']) + 1
    best_acc = max(res['accuracies'])
    
    summary_data.append({
        'Optimizer': name,
        'Final Loss': f"{res['final_loss']:.4f}",
        'Final Accuracy': f"{res['final_accuracy']:.1%}",
        'Best Accuracy': f"{best_acc:.1%}",
        'Best Acc Epoch': best_acc_epoch,
        'Loss @ Epoch 10': f"{res['losses'][9]:.4f}" if len(res['losses']) > 9 else 'N/A',
    })

df = pd.DataFrame(summary_data)
print("\n" + "=" * 90)
print("SOEN OPTIMIZER COMPARISON SUMMARY")
print("=" * 90)
print("\nModel: 1D → 3D (SingleDendrite) → 2D")
print("Task: Binary pulse classification")
print(f"Epochs: {N_EPOCHS}, Batch size: {BATCH_SIZE}")
print(f"Samples: {N_SAMPLES}\n")
print(df.to_string(index=False))
print("=" * 90)

## 9. Visualize Best Model's Output

Show how the trained model processes single vs double pulse inputs.

In [None]:
# Find best performing optimizer (ties in final accuracy broken by lower final loss)
best_name = max(results, key=lambda n: (results[n]['final_accuracy'], -results[n]['final_loss']))
best_model = results[best_name]['model']
print(f"Best optimizer: {best_name} (Acc={results[best_name]['final_accuracy']:.1%})")

# Test on sample inputs
best_model.eval()
with torch.no_grad():
    # Get one sample of each class
    single_pulse_idx = (y_data == 0).nonzero()[0].item()
    double_pulse_idx = (y_data == 1).nonzero()[0].item()
    
    test_inputs = torch.stack([X_data[single_pulse_idx], X_data[double_pulse_idx]])
    test_labels = torch.tensor([0, 1])
    
    final_hist, all_hist = best_model(test_inputs)

# Plot state trajectories
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for row, (label, name) in enumerate([(0, 'Single Pulse'), (1, 'Double Pulse')]):
    # Input
    axes[row, 0].plot(test_inputs[row, :, 0].numpy())
    axes[row, 0].set_title(f'{name}: Input Flux')
    axes[row, 0].set_xlabel('Time step')
    axes[row, 0].set_ylabel('Flux')
    axes[row, 0].grid(True, alpha=0.3)
    
    # Hidden layer states (SingleDendrite neurons)
    hidden_states = all_hist[1][row, 1:, :].numpy()  # [seq_len, hidden_dim]
    for n in range(hidden_states.shape[1]):
        axes[row, 1].plot(hidden_states[:, n], label=f'Neuron {n}')
    axes[row, 1].set_title(f'{name}: SingleDendrite States (s)')
    axes[row, 1].set_xlabel('Time step')
    axes[row, 1].set_ylabel('State (s)')
    axes[row, 1].legend()
    axes[row, 1].grid(True, alpha=0.3)
    
    # Output layer
    output_states = all_hist[2][row, 1:, :].numpy()  # [seq_len, output_dim]
    axes[row, 2].plot(output_states[:, 0], label='Class 0 (Single)')
    axes[row, 2].plot(output_states[:, 1], label='Class 1 (Double)')
    axes[row, 2].set_title(f'{name}: Output Layer')
    axes[row, 2].set_xlabel('Time step')
    axes[row, 2].set_ylabel('Output')
    axes[row, 2].legend()
    axes[row, 2].grid(True, alpha=0.3)

plt.suptitle(f'Best Model ({best_name}) - State Trajectories', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 10. Key Observations

### SOEN-Specific Training Insights:

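| Optimizer | Expected characteristics (check against the curves above) |
|-----------|-----------------------------------------------------------|
| **SGD (lr=0.01)** | Same fixed step size for every weight; stable but may need many epochs |
| **SGD (lr=0.1)** | Larger steps converge faster when stable, but risk overshooting through the nonlinear SingleDendrite response |
| **SGD + Momentum** | A velocity term (momentum 0.9) smooths noisy minibatch gradients and accelerates progress along consistent directions |
| **Adam** | Per-parameter steps normalized by the running gradient RMS, so step size is less sensitive to gradient scale; a common default |
| **AdamW** | Adam plus decoupled weight decay (0.01), which shrinks the weights toward zero at every step |
| **RMSprop** | Per-parameter RMS normalization like Adam, but without first-moment averaging or bias correction |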

### SOEN Training Considerations:
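- **Gradient clipping** (`max_norm=1.0` here) guards against large gradients from backpropagating through all 50 time steps of the SingleDendrite dynamics
- **Learning rate** must be tuned for the SingleDendrite dynamics; compare the two SGD runs (0.01 vs 0.1) and the two Adam runs (0.01 vs 0.001)
- **Time pooling** strategy (mean, max, final) affects gradient flow: mean pooling spreads the gradient over every time step, while max or final pooling routes it through a single step
- **Connection weight constraints** (here, input-to-hidden weights limited to [-0.24, 0.24]) bound the reachable weights and so reshape the optimization landscape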
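The training loop pools the readout with `final_history[:, 1:, :].mean(dim=1)`. A minimal sketch of the three pooling options, assuming the same `[batch, seq_len + 1, dim]` history layout with the initial state at index 0, uses autograd on random data to count how many time steps receive gradient under each option. `pool_over_time` is a hypothetical helper, not a toolkit function.

```python
import torch

def pool_over_time(history: torch.Tensor, mode: str = "mean") -> torch.Tensor:
    """Reduce a [batch, seq_len + 1, dim] history to [batch, dim], dropping the initial state."""
    states = history[:, 1:, :]
    if mode == "mean":
        return states.mean(dim=1)
    if mode == "max":
        return states.amax(dim=1)
    if mode == "final":
        return states[:, -1, :]
    raise ValueError(f"unknown pooling mode: {mode!r}")

torch.manual_seed(0)
history = torch.randn(2, 6, 3, requires_grad=True)   # batch=2, seq_len=5, dim=3

steps_with_grad = {}
for mode in ("mean", "max", "final"):
    history.grad = None
    pool_over_time(history, mode).sum().backward()
    # Count time steps (excluding index 0) that received any gradient for sample 0, unit 0
    steps_with_grad[mode] = int((history.grad[0, 1:, 0] != 0).sum().item())
    print(f"{mode:<5}: gradient reaches {steps_with_grad[mode]} of {history.shape[1] - 1} time steps")
```

Mean pooling gives every step a gradient weight of $1/T$, while max and final pooling send the whole gradient through one step, which can make training more sensitive to when the response peaks or ends.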

### Physical Mapping:
- Each `SingleDendrite` unit corresponds to **one physical neuron** on SOEN hardware
- This 3-neuron hidden layer model would map to 3 physical neurons on a SOEN chip

In [None]:
print("Notebook complete!")
print(f"\nBest performing optimizer: {best_name}")
print(f"Final accuracy: {results[best_name]['final_accuracy']:.1%}")