In [1]:
#!/usr/bin/env python
import os
import re
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
import matplotlib.pyplot as plt
import optuna
from mpl_toolkits.mplot3d import Axes3D  # required for 3D plotting
from torch.utils.tensorboard import SummaryWriter

# =============================================================================
# 1. Fix Random Seeds for Reproducibility
# =============================================================================
# Seed NumPy's global RNG first, then PyTorch's, with the same fixed value so
# that repeated runs draw identical random numbers.
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(42)

# =============================================================================
# 2. Define Helper Function to Create a Unique Output Folder
# =============================================================================
def create_unique_folder(base_dir):
    """
    Creates a unique folder inside base_dir with the naming pattern:
      trial_N_date_yyyy_mm_dd_hh_mm,
    where N is the highest number found among existing "trial_N_date_*"
    folders plus one (or 1 when there are none).

    base_dir (including missing parent directories) is created if needed.
    If the chosen folder name already exists by the time it is created
    (e.g. another run started concurrently), N is incremented until a
    fresh folder can be created, so the returned folder is never shared.

    Args:
        base_dir (str): Directory that holds all trial folders.

    Returns:
        str: Path of the newly created folder.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(base_dir, exist_ok=True)

    # Find the highest trial number among existing trial directories.
    trial_pattern = re.compile(r"trial_(\d+)_date_")
    highest = 0
    for entry in os.listdir(base_dir):
        match = trial_pattern.match(entry)
        if match and os.path.isdir(os.path.join(base_dir, entry)):
            highest = max(highest, int(match.group(1)))
    next_trial = highest + 1

    now_str = datetime.now().strftime("%Y_%m_%d_%H_%M")
    while True:
        folder_name = f"trial_{next_trial}_date_{now_str}"
        unique_folder = os.path.join(base_dir, folder_name)
        try:
            # No exist_ok here: an existing path means the name is taken.
            os.makedirs(unique_folder)
        except FileExistsError:
            next_trial += 1
        else:
            return unique_folder

# =============================================================================
# 3. Generate Synthetic Input Signal and Plot It
# =============================================================================
def generate_synthetic_input(output_folder):
    """
    Generates a 1D synthetic impulse signal of 1000 time steps: the value is
    1.0 at time step 1 and 0.0 at every other time step.

    Saves a plot of the input signal as 'synthetic_input.png' inside
    output_folder (overwriting any previous file of that name).

    Args:
        output_folder (str): Existing directory in which the plot is saved.

    Returns:
        Tensor of shape (1, 1000, 1) and dtype float32
        (batch dimension, time steps, feature dimension).
    """
    num_steps = 1000
    data = np.zeros(num_steps, dtype=np.float32)
    data[1] = 1.0  # single unit impulse at index 1

    # Use an explicit figure so that exactly this figure is saved and closed,
    # independent of whatever pyplot considers the "current" figure.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(data, marker='o', linestyle='-')
    ax.set_title('Synthetic Input Signal')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Value')
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(output_folder, 'synthetic_input.png'))
    plt.close(fig)

    # Return tensor with batch dimension and feature dimension
    return torch.tensor(data).unsqueeze(0).unsqueeze(-1)

# =============================================================================
# 4. Define the Spiking Reservoir Experiment Model
# =============================================================================
class SpikingReservoirExp(nn.Module):
    """
    Spiking reservoir experiment: a single input LIF neuron feeds, through a
    linear layer, a recurrent layer of LIF neurons (the reservoir). The model
    is only simulated, never trained; forward() returns NumPy records.
    """

    def __init__(self, threshold, beta_reservoir, reservoir_size=10, device='cpu',
                 spectral_radius=1.5):
        """
        Args:
            threshold (float): Threshold (membrane potential) for reservoir LIF neurons.
            beta_reservoir (float): β value for reservoir LIF neurons.
            reservoir_size (int): Number of reservoir neurons (default: 10).
            device (str): 'cpu' or 'cuda'
            spectral_radius (float): Spectral radius the recurrent weight
                matrix is rescaled to after initialization (default: 1.5).
        """
        super().__init__()
        self.device = device
        self.reservoir_size = reservoir_size

        # NOTE: the layers below are created in a fixed order so that the
        # seeded RNG yields the same initial weights from run to run.

        # Input layer: one linear neuron mapping 1 input to 1.
        # NOTE(review): this weight keeps PyTorch's default random
        # initialization. With a small input amplitude the input neuron may
        # never reach its threshold of 1.0, in which case the reservoir would
        # be driven only by its recurrent layer (including that layer's bias,
        # if snnTorch creates one) - confirm this is intended.
        self.input_fc = nn.Linear(1, 1, bias=False)
        # Input LIF neuron (fixed parameters; not tuned here).
        self.input_lif = snn.Leaky(beta=0.9,
                                   threshold=1.0,
                                   spike_grad=None,
                                   reset_mechanism='zero',
                                   reset_delay=False)

        # Reservoir: a linear mapping from 1 to reservoir_size.
        self.reservoir_fc = nn.Linear(1, reservoir_size, bias=False)
        # Reservoir LIF layer with hyperparameters to be tuned.
        self.reservoir_lif = snn.RLeaky(beta=beta_reservoir,
                                        linear_features=reservoir_size,
                                        threshold=threshold,
                                        spike_grad=None,
                                        reset_mechanism='zero',
                                        reset_delay=False,
                                        all_to_all=True)

        # Initialize the recurrent weight matrix with positive uniform values,
        # then rescale it in place so that its largest eigenvalue magnitude
        # equals spectral_radius.
        with torch.no_grad():
            W = self.reservoir_lif.recurrent.weight
            nn.init.uniform_(W, a=0.1, b=1.0)
            current_radius = torch.linalg.eigvals(W).abs().max()
            W.mul_(spectral_radius / current_radius)

        self.to(self.device)

    def forward(self, x):
        """
        Simulate the reservoir dynamics one time step at a time.

        The simulation runs without gradient tracking: the results are
        returned as NumPy arrays, so no gradient could flow back anyway, and
        this avoids building an autograd graph across all time steps.

        Args:
            x (tensor): Input tensor of shape (batch_size, time_steps, 1).

        Returns:
            avg_firing_rate (float): Average firing rate (over neurons and time).
            spike_record (np.array): Reservoir spike record of shape (time_steps, batch_size, reservoir_size).
            mem_record (np.array): Reservoir membrane potential record (same shape as spike_record).
        """
        batch_size, time_steps, _ = x.shape
        x = x.to(self.device)

        # Initialize neuron states.
        input_mem = torch.zeros(batch_size, 1, device=self.device)
        reservoir_mem = torch.zeros(batch_size, self.reservoir_size, device=self.device)
        reservoir_spk = torch.zeros(batch_size, self.reservoir_size, device=self.device)

        spike_record = []
        mem_record = []

        with torch.no_grad():
            for t in range(time_steps):
                x_t = x[:, t, :]  # shape: (batch_size, 1)
                input_current = self.input_fc(x_t)
                input_spk, input_mem = self.input_lif(input_current, input_mem)
                reservoir_current = self.reservoir_fc(input_spk)
                # The previous step's spikes and membrane potential are fed
                # back in as the recurrent state.
                reservoir_spk, reservoir_mem = self.reservoir_lif(reservoir_current,
                                                                  reservoir_spk,
                                                                  reservoir_mem)
                spike_record.append(reservoir_spk.cpu().numpy())
                mem_record.append(reservoir_mem.cpu().numpy())

        spike_record = np.array(spike_record)  # shape: (time_steps, batch_size, reservoir_size)
        mem_record = np.array(mem_record)        # same shape
        # Convert the NumPy scalar to a plain Python float, as documented.
        avg_firing_rate = float(spike_record.mean())

        return avg_firing_rate, spike_record, mem_record

    def get_recurrent_weights(self):
        """Return the recurrent weight matrix as a numpy array."""
        return self.reservoir_lif.recurrent.weight.detach().cpu().numpy()

# =============================================================================
# 5. Define the Optuna Objective Function (Using a Grid Search)
# =============================================================================
def objective(trial):
    """
    Optuna objective: simulate one reservoir configuration on the synthetic
    input and return its average firing rate.

    The spike record, membrane potential record and recurrent weight matrix
    are attached to the trial as user attributes ("spike_record",
    "mem_record", "weights") for later plotting. These are NumPy arrays, so
    NOTE(review): this presumably only works with an in-memory study -
    confirm before switching to a persistent storage backend.

    Args:
        trial (optuna.trial.Trial): Trial supplying "threshold" and
            "beta_reservoir". If the user attribute "output_folder" is set,
            the synthetic input plot is saved there; otherwise in ".".

    Returns:
        float: Average firing rate over all reservoir neurons and time steps.
    """
    # Hyperparameters to tune. The lower bound of "threshold" is 0.1 so that
    # the range covers the grid searched in the main script (0.1 ... 2.0);
    # grid values outside the declared range are flagged by Optuna.
    threshold = trial.suggest_float("threshold", 0.1, 5.0)
    beta_reservoir = trial.suggest_float("beta_reservoir", 0.01, 0.99)

    device = 'cpu'
    model = SpikingReservoirExp(threshold, beta_reservoir, reservoir_size=10, device=device)
    model.eval()

    # Use the unique output folder for saving the synthetic input plot.
    # (The caller stores the folder path in trial.user_attrs beforehand.)
    output_folder = trial.user_attrs.get("output_folder")
    if output_folder is None:
        output_folder = "."
    x = generate_synthetic_input(output_folder)  # shape: (1, 1000, 1)

    avg_firing_rate, spike_record, mem_record = model(x)

    # Save extra information in the trial for later plotting.
    trial.set_user_attr("spike_record", spike_record)
    trial.set_user_attr("mem_record", mem_record)
    trial.set_user_attr("weights", model.get_recurrent_weights())

    # Return a plain Python float rather than a NumPy scalar.
    return float(avg_firing_rate)


In [2]:

# =============================================================================
# 6. Main Function: Run Grid Search, Save Plots, Hyperparameters, and Log to TensorBoard
# =============================================================================
if __name__ == '__main__':
    # Runs a full grid search over (threshold, beta_reservoir), then writes
    # summary plots, detailed plots for one representative trial, and
    # TensorBoard logs into a freshly created output folder.

    # Base directory for all output results.
    base_output_dir = "results"
    output_folder = create_unique_folder(base_output_dir)
    print(f"Output folder: {output_folder}")

    # Initialize TensorBoard writer with the unique output folder.
    writer = SummaryWriter(log_dir=output_folder)

    def save_and_log(fig, filename, tag):
        """Save fig as filename in the output folder, log it to TensorBoard under tag, and close it."""
        fig.savefig(os.path.join(output_folder, filename))
        writer.add_figure(tag, fig)
        plt.close(fig)

    # Pass the output folder path to each trial via user_attrs.
    def objective_with_folder(trial):
        trial.set_user_attr("output_folder", output_folder)
        return objective(trial)

    # try/finally guarantees the TensorBoard writer is flushed and closed
    # even if the search or one of the plots raises.
    try:
        # Define grid search values for hyperparameters.
        threshold_values = np.linspace(0.1, 2.0, 20).tolist()  # n discrete values
        beta_values = np.linspace(0.01, 0.99, 20).tolist()       # m discrete values
        search_space = {
            "threshold": threshold_values,
            "beta_reservoir": beta_values
        }

        # Use Optuna's GridSampler; one trial per grid point.
        sampler = optuna.samplers.GridSampler(search_space)
        study = optuna.create_study(sampler=sampler, direction="maximize")
        study.optimize(objective_with_folder, n_trials=len(threshold_values) * len(beta_values))

        # ---------------------------------------------------------------------
        # (A) Extract Results for 3D Surface Plot and Heatmap
        # ---------------------------------------------------------------------
        trials = study.trials
        # Map each grid value to its index once instead of searching the
        # lists for every trial.
        threshold_index = {value: idx for idx, value in enumerate(threshold_values)}
        beta_index = {value: idx for idx, value in enumerate(beta_values)}
        # Rows follow threshold_values, columns follow beta_values.
        firing_rate_grid = np.zeros((len(threshold_values), len(beta_values)))
        for trial in trials:
            if trial.value is None:
                # Trial without a result; leave its grid cell at 0.
                continue
            i = threshold_index[trial.params["threshold"]]
            j = beta_index[trial.params["beta_reservoir"]]
            firing_rate_grid[i, j] = trial.value

        # 1. 3D Surface Plot of the Average Firing Rate.
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')
        X, Y = np.meshgrid(beta_values, threshold_values)
        surf = ax.plot_surface(X, Y, firing_rate_grid, cmap='viridis', edgecolor='none')
        ax.set_xlabel('Beta Reservoir')
        ax.set_ylabel('Threshold')
        ax.set_zlabel('Avg Firing Rate')
        ax.set_title('3D Surface Plot of Avg Firing Rate')
        fig.colorbar(surf, shrink=0.5, aspect=5)
        plt.tight_layout()
        save_and_log(fig, '3d_surface_plot.png', "3D_Surface_Plot")

        # 2. Heatmap of the Average Firing Rate.
        fig = plt.figure(figsize=(8, 6))
        plt.imshow(firing_rate_grid, origin='lower', aspect='auto',
                   extent=[min(beta_values), max(beta_values),
                           min(threshold_values), max(threshold_values)],
                   cmap='viridis')
        plt.colorbar(label='Avg Firing Rate')
        plt.xlabel('Beta Reservoir')
        plt.ylabel('Threshold')
        plt.title('Heatmap of Avg Firing Rate')
        plt.tight_layout()
        save_and_log(fig, 'heatmap_firing_rate.png', "Heatmap_Firing_Rate")

        # ---------------------------------------------------------------------
        # (B) Choose a Representative Trial for Detailed Plots
        # ---------------------------------------------------------------------
        # Use the trial at the middle grid point; fall back to the first
        # trial if that grid point is not among the trials.
        rep_threshold = threshold_values[len(threshold_values)//2]
        rep_beta = beta_values[len(beta_values)//2]
        rep_trial = None
        for trial in trials:
            if trial.params["threshold"] == rep_threshold and trial.params["beta_reservoir"] == rep_beta:
                rep_trial = trial
                break
        if rep_trial is None:
            rep_trial = trials[0]

        # Save the hyperparameters of the representative trial to a file.
        hyperparams_file = os.path.join(output_folder, "hyperparameters.txt")
        with open(hyperparams_file, "w", encoding="utf-8") as f:
            f.write("Representative Trial Hyperparameters:\n")
            for key, value in rep_trial.params.items():
                f.write(f"{key}: {value}\n")

        # Retrieve recorded data from the representative trial and remove
        # the batch dimension (batch_size=1).
        spike_record = rep_trial.user_attrs["spike_record"][:, 0, :]  # (time_steps, reservoir_size)
        mem_record = rep_trial.user_attrs["mem_record"][:, 0, :]      # (time_steps, reservoir_size)
        weights = rep_trial.user_attrs["weights"]
        time_steps = spike_record.shape[0]
        num_neurons = spike_record.shape[1]

        # ---------------------------------------------------------------------
        # (C) Detailed Plots for the Representative Trial
        # ---------------------------------------------------------------------
        # 1. Spike Raster Plot.
        fig = plt.figure(figsize=(10, 6))
        for neuron in range(num_neurons):
            spike_times = np.where(spike_record[:, neuron] > 0)[0]
            plt.scatter(spike_times, np.full_like(spike_times, neuron), s=10)
        plt.xlabel('Time Step')
        plt.ylabel('Neuron Index')
        plt.title('Spike Raster Plot')
        plt.yticks(range(num_neurons))
        plt.grid(True)
        plt.tight_layout()
        save_and_log(fig, 'spike_raster_plot.png', "Spike_Raster_Plot")

        # 2. Membrane Potential Traces.
        fig = plt.figure(figsize=(10, 6))
        for neuron in range(mem_record.shape[1]):
            plt.plot(mem_record[:, neuron], label=f'Neuron {neuron}')
        plt.xlabel('Time Step')
        plt.ylabel('Membrane Potential')
        plt.title('Membrane Potential Traces')
        plt.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
        plt.grid(True)
        plt.tight_layout()
        save_and_log(fig, 'membrane_potential_traces.png', "Membrane_Potential_Traces")

        # 3. Inter-Spike Interval (ISI) Histogram, pooled over all neurons.
        all_intervals = []
        for neuron in range(num_neurons):
            spike_times = np.where(spike_record[:, neuron] > 0)[0]
            if len(spike_times) > 1:
                all_intervals.extend(np.diff(spike_times))
        fig = plt.figure(figsize=(8, 6))
        plt.hist(all_intervals, bins=30, color='skyblue', edgecolor='black')
        plt.xlabel('Inter-Spike Interval (Time Steps)')
        plt.ylabel('Count')
        plt.title('Inter-Spike Interval Histogram')
        plt.grid(True)
        plt.tight_layout()
        save_and_log(fig, 'isi_histogram.png', "ISI_Histogram")

        # 4. Recurrent Weight Matrix Heatmap.
        # The matrix is the weight of a linear layer, laid out as
        # (outputs, inputs): rows (y axis) are the receiving (post-synaptic)
        # neurons and columns (x axis) the sending (pre-synaptic) neurons.
        fig = plt.figure(figsize=(6, 5))
        plt.imshow(weights, cmap='inferno', aspect='auto')
        plt.colorbar(label='Weight Value')
        plt.xlabel('Pre-synaptic Neuron')
        plt.ylabel('Post-synaptic Neuron')
        plt.title('Recurrent Weight Matrix Heatmap')
        plt.tight_layout()
        save_and_log(fig, 'weight_matrix_heatmap.png', "Weight_Matrix_Heatmap")

        # 5. Eigenvalues of the Recurrent Weight Matrix in the complex plane.
        eigenvalues = np.linalg.eigvals(weights)
        fig = plt.figure(figsize=(8, 6))
        plt.scatter(eigenvalues.real, eigenvalues.imag, c='purple', edgecolors='k')
        plt.xlabel('Real Part')
        plt.ylabel('Imaginary Part')
        plt.title('Eigenvalues of the Recurrent Weight Matrix')
        plt.grid(True)
        plt.axhline(0, color='black', linewidth=0.5)
        plt.axvline(0, color='black', linewidth=0.5)
        plt.tight_layout()
        save_and_log(fig, 'eigenvalues.png', "Eigenvalues")

        # Log weight histogram to TensorBoard.
        writer.add_histogram("Recurrent_Weights", weights, global_step=0)

        # Log the average firing rate from the representative trial.
        writer.add_scalar("Representative_Avg_Firing_Rate", rep_trial.value, global_step=0)

        # Log hyperparameters using TensorBoard's hparams API.
        writer.add_hparams(rep_trial.params, {"hparam/avg_firing_rate": rep_trial.value})
    finally:
        writer.close()

    print("All plots and logs have been saved.")


[I 2025-02-13 13:10:40,053] A new study created in memory with name: no-name-00e4d64c-4adf-41b6-9208-654708328a24


Output folder: results/trial_15_date_2025_02_13_13_10


[I 2025-02-13 13:10:40,409] Trial 0 finished with value: 0.0 and parameters: {'threshold': 1.3, 'beta_reservoir': 0.3194736842105263}. Best is trial 0 with value: 0.0.
[I 2025-02-13 13:10:40,742] Trial 1 finished with value: 0.0 and parameters: {'threshold': 0.9999999999999999, 'beta_reservoir': 0.7836842105263158}. Best is trial 0 with value: 0.0.
[I 2025-02-13 13:10:41,039] Trial 2 finished with value: 0.9993000030517578 and parameters: {'threshold': 0.2, 'beta_reservoir': 0.8868421052631579}. Best is trial 2 with value: 0.9993000030517578.
[I 2025-02-13 13:10:41,383] Trial 3 finished with value: 0.0 and parameters: {'threshold': 1.7, 'beta_reservoir': 0.47421052631578947}. Best is trial 2 with value: 0.9993000030517578.
[I 2025-02-13 13:10:41,676] Trial 4 finished with value: 0.0 and parameters: {'threshold': 0.7, 'beta_reservoir': 0.6289473684210526}. Best is trial 2 with value: 0.9993000030517578.
[I 2025-02-13 13:10:41,975] Trial 5 finished with value: 0.9990000128746033 and para

All plots and logs have been saved.
