# 06. Performance Optimization: Hyperparameter Search

This notebook performs a **Grid Search** to optimize the hyperparameters of the **Triglial Reservoir** (3GSNN) for MNIST classification.

## Parameters to Tune
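1.  **Reservoir Size (`hidden_dim`)**: 100 vs 200 (larger is usually better, but slower).
2.  **Spectral Radius (`spectral_radius`)**: A multiplicative scale applied to the initialized recurrent weights. This is a cheap proxy for the true spectral radius, which would require an eigen-decomposition. It controls the dynamic regime (stable vs. chaotic).
3.  **Microglia Pruning Threshold (`pruning_threshold`)**: Controls sparsity.
4.  **Astrocyte Target Rate (`target_rate`)**: Controls homeostatic activity levels.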

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import itertools
import time

from pyngn.reservoir import TriglialReservoir

%matplotlib inline

## 1. Load Data (Subset)
Using small subsets (the first 500 training images and the first 100 test images) to keep the search fast.

In [None]:
# Seed every RNG the search depends on up front: model initialisation,
# Poisson spike sampling and the training-set shuffle. Otherwise the
# grid-search ranking changes from one Restart & Run All to the next.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),
    # Mean/std normalisation constants commonly used for MNIST.
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Take the first N images of each split so one configuration stays cheap.
# Raise these sizes for a more reliable ranking of configurations.
subset_size = 500
test_subset_size = 100
train_subset = Subset(train_data, range(min(subset_size, len(train_data))))
test_subset = Subset(test_data, range(min(test_subset_size, len(test_data))))

# A dedicated, seeded generator makes the shuffle order reproducible across
# kernel restarts, independent of how many global RNG draws happen elsewhere.
train_loader = DataLoader(train_subset, batch_size=1, shuffle=True,
                          generator=torch.Generator().manual_seed(SEED))
test_loader = DataLoader(test_subset, batch_size=1, shuffle=False)

def poisson_encode(image, time_steps=50, gain=10.0, generator=None):
    """Encode an image as a Bernoulli ("Poisson") spike train.

    At every time step each pixel fires independently with probability
    ``clamp(pixel * gain * 0.1, 0, 1)``. With the default ``gain=10.0`` this
    is the pixel value clamped to [0, 1]. Pixels <= 0 (e.g. normalised
    background) never fire, and pixels >= 1 fire at every step.

    Args:
        image: Tensor of any shape; it is flattened to a 1-D pixel vector.
        time_steps: Number of simulation steps (rows of the output).
        gain: Multiplier on the pixel value, applied before the fixed 0.1 scaling.
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` draws from the global RNG.

    Returns:
        Float tensor of shape ``(time_steps, image.numel())`` containing only
        0.0 / 1.0, on the same device as ``image``.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as transposes.
    pixels = image.reshape(-1)
    prob = torch.clamp(pixels * gain * 0.1, 0, 1)
    # Sample on the input's device so GPU images do not hit a device mismatch.
    noise = torch.rand(time_steps, pixels.shape[0],
                       generator=generator, device=pixels.device)
    return (noise < prob.unsqueeze(0)).float()

## 2. Evaluation Function

In [None]:
def evaluate_model(hidden_dim, spectral_radius, pruning_threshold, target_rate, seed=None):
    """Build a TriglialReservoir, fit its readout and return test accuracy.

    Uses the notebook-level ``train_loader`` / ``test_loader`` (batch_size=1)
    and ``poisson_encode``.

    Args:
        hidden_dim: Reservoir size passed to ``TriglialReservoir``.
        spectral_radius: Factor that multiplies the freshly initialised
            ``model.recurrent_weights``. This is a plain rescale, not a true
            spectral-radius normalisation (no eigen-decomposition).
        pruning_threshold: Forwarded as ``micro_params['pruning_threshold']``.
        target_rate: Forwarded as ``astro_params['target_rate']``.
        seed: If given, torch and NumPy are seeded before the model is built,
            so every configuration sees the same randomness. NumPy is seeded
            only in case the library draws from it (unverified).

    Returns:
        Fraction of test samples whose argmax prediction equals the label.

    Raises:
        ValueError: If ``train_loader`` or ``test_loader`` yields no samples.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # 784 = flattened 28x28 MNIST image, 10 = number of digit classes.
    model = TriglialReservoir(784, hidden_dim, 10,
                              dt=1.0, max_delay=10,
                              astro_params={'target_rate': target_rate},
                              micro_params={'pruning_threshold': pruning_threshold})

    # Cheap proxy for spectral radius: rescale the initial recurrent weights.
    # A true normalisation would first divide by the largest |eigenvalue|.
    with torch.no_grad():
        model.recurrent_weights *= spectral_radius

    # Collect reservoir states for the training subset. No gradients are
    # needed here, so no_grad avoids keeping an autograd graph per sample.
    X_train_states = []
    Y_train_targets = []
    with torch.no_grad():
        for img, label in train_loader:
            spikes = poisson_encode(img)
            _, state = model(spikes, return_state=True)
            X_train_states.append(state.squeeze(0))
            # One-hot target; indexing with `label` assumes batch_size=1.
            target = torch.zeros(10)
            target[label] = 1.0
            Y_train_targets.append(target)

    if not X_train_states:
        raise ValueError("train_loader yielded no samples")

    X_train = torch.stack(X_train_states)
    Y_train = torch.stack(Y_train_targets)
    # fit() is deliberately left outside no_grad in case the readout relies on autograd.
    model.readout.fit(X_train, Y_train)

    correct = 0
    total = 0
    with torch.no_grad():
        for img, label in test_loader:
            prediction = model(poisson_encode(img))
            if torch.argmax(prediction).item() == label.item():
                correct += 1
            total += 1

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return correct / total

## 3. Grid Search Loop

In [None]:
# Parameter grid: every combination of the values below is evaluated once.
param_grid = {
    'hidden_dim': [100, 200],
    'spectral_radius': [1.0, 1.5],  # multiplicative scale on the initial recurrent weights
    'pruning_threshold': [0.1, 0.3, 0.5],
    'target_rate': [0.05, 0.1]
}

keys = list(param_grid)
combinations = list(itertools.product(*param_grid.values()))

print(f"Starting Grid Search with {len(combinations)} combinations...")
# perf_counter is the monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()

results = []
for values in combinations:
    params = dict(zip(keys, values))
    print(f"Testing {params}...")

    acc = evaluate_model(**params)
    results.append({'params': params, 'accuracy': acc})

    print(f"  -> Accuracy: {acc*100:.2f}%")

elapsed = time.perf_counter() - start_time

# Pick the best run after the loop. max() keeps the first maximum (the same
# tie-breaking as a strict '>'). Unlike starting from best_acc = 0.0, it still
# reports parameters when every configuration scores 0%.
best_run = max(results, key=lambda r: r['accuracy'], default=None)
if best_run is None:
    best_acc, best_params = 0.0, {}
else:
    best_acc, best_params = best_run['accuracy'], best_run['params']

print(f"\nSearch Complete in {elapsed:.2f}s")
print(f"Best Accuracy: {best_acc*100:.2f}%")
print(f"Best Parameters: {best_params}")

## 4. Visualize Results

In [None]:
# Histogram of test accuracy across all grid-search configurations, in percent
# to match the printed results. The dashed line marks the best configuration.
accuracies = [r['accuracy'] for r in results]

fig, ax = plt.subplots()
ax.hist([a * 100 for a in accuracies], bins=10)
if accuracies:
    ax.axvline(max(accuracies) * 100, color='red', linestyle='--', label='best')
    ax.legend()
ax.set_title("Accuracy Distribution from Grid Search")
ax.set_xlabel("Accuracy (%)")
ax.set_ylabel("Count")
plt.show()