In [1]:
import itertools
import shutil
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, Self

import jax
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from jax import random
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from heidelberg_v01 import (
    load_data,
    load_datasets,
    plot_error,
    plot_spikes,
    plot_traces,
    run,
    run_example,
)
from hyperparam_scan_util import computed, scan_grid, vary
from spikegd.theta import ThetaNeuron


In [2]:
# Load the datasets once into a notebook-level global. `run_theta` and
# `run_theta_ensemble` read this global whenever no data loaders are passed in.
# The logged output shows SHD being loaded from an h5 file twice
# (presumably the train and test splits -- TODO confirm against `load_datasets`).
datasets = load_datasets("data", verbose=True)

Loading data from h5 file
Loading audio filenames
Finished loading SHD
Loading data from h5 file
Loading audio filenames
Finished loading SHD


In [3]:
def run_theta(config: dict, data_loaders=None):
    """
    Train a network of Theta neurons for a single configuration.

    Parameters
    ----------
    config : dict
        Run configuration. The keys "tau", "I0" and "eps" are read here to build
        the neuron; the whole dict is also forwarded to `load_data` and `run`.
    data_loaders : optional
        Pre-built data loaders. When None, they are created from the
        notebook-level `datasets` via `load_data(datasets, config)`.

    Returns
    -------
    tuple
        `(metrics, perf_metrics)` exactly as produced by `run`; see the
        docstring of `run` for their contents.
    """
    loaders = load_data(datasets, config) if data_loaders is None else data_loaders

    # ThetaNeuron takes its parameters positionally in the order (tau, I0, eps).
    neuron_params = [config[name] for name in ("tau", "I0", "eps")]
    theta_neuron = ThetaNeuron(*neuron_params)

    metrics, perf_metrics = run(theta_neuron, loaders, config, progress_bar="script")
    return metrics, perf_metrics

In [4]:
def summarize_ensemble_metrics(ensemble_metrics: dict, Nepochs: int) -> dict:
    """
    Reduce stacked per-run metrics to mean/std summaries over the ensemble.

    Parameters
    ----------
    ensemble_metrics : dict
        Maps metric names to arrays. The keys "p_init" and "p_end" are ignored.
        Keys starting with "perf." are reduced over all of their entries.
        Every other key must hold an array whose axis 1 has `Nepochs + 1`
        entries (index 0 is the initial state, epochs are counted from 1);
        the statistics for one epoch are taken over `value[:, epoch]`.
    Nepochs : int
        Number of training epochs.

    Returns
    -------
    dict
        * `<key>_mean` / `<key>_std` for every "perf." key,
        * `<key>_{init,final,min,max}_{mean,std}` and `<key>_{min,max}_epoch`
          for every per-epoch key, where min/max refer to the epoch with the
          smallest/largest ensemble mean,
        * `"epochs"`: a list of `Nepochs + 1` dicts holding `<key>_mean` and
          `<key>_std` for each epoch.

    Raises
    ------
    ValueError
        If a per-epoch array does not have `Nepochs + 1` entries along axis 1.

    Notes
    -----
    NaN means are ignored when searching for the min/max epoch as long as at
    least one epoch has a non-NaN mean. Previously a NaN mean at epoch 0 stuck
    as both the minimum and the maximum, because every later comparison
    against NaN is false. If all means are NaN, epoch 0 is reported.
    """

    def _supersedes(candidate: float, incumbent, is_better) -> bool:
        # Decide whether `candidate` replaces the current extremum `incumbent`.
        if incumbent is None:
            return True
        if np.isnan(incumbent):
            # A NaN incumbent gives way to the first real number only, so the
            # all-NaN case keeps pointing at epoch 0.
            return not np.isnan(candidate)
        # A NaN candidate never wins: comparisons with NaN are False.
        return is_better(candidate, incumbent)

    metrics: dict = {}
    # epoch 0 is the initial state, other epochs are counted from 1
    epoch_metrics = [{} for _ in range(Nepochs + 1)]

    for key, value in ensemble_metrics.items():
        if key in ("p_init", "p_end"):
            continue

        if key.startswith("perf."):
            # Global (not per-epoch) quantity: reduce over everything.
            metrics[f"{key}_mean"] = float(jnp.mean(value))
            metrics[f"{key}_std"] = float(jnp.std(value))
            continue

        if value.shape[1] != Nepochs + 1:
            raise ValueError(
                f"Expected {Nepochs + 1} (Nepochs + 1) values, "
                f"got {value.shape[1]} in {key}"
            )

        min_mean = max_mean = None
        min_mean_epoch = max_mean_epoch = None

        for epoch in range(Nepochs + 1):
            mean = float(jnp.mean(value[:, epoch]))
            std = float(jnp.std(value[:, epoch]))

            epoch_dict = epoch_metrics[epoch]
            epoch_dict[f"{key}_mean"] = mean
            epoch_dict[f"{key}_std"] = std

            # Strict comparisons keep the earliest epoch on ties.
            if _supersedes(mean, min_mean, lambda new, old: new < old):
                min_mean, min_mean_epoch = mean, epoch
            if _supersedes(mean, max_mean, lambda new, old: new > old):
                max_mean, max_mean_epoch = mean, epoch

        # Also store init, final, min and max values for convenience
        metrics[f"{key}_init_mean"] = epoch_metrics[0][f"{key}_mean"]
        metrics[f"{key}_init_std"] = epoch_metrics[0][f"{key}_std"]

        metrics[f"{key}_final_mean"] = epoch_metrics[Nepochs][f"{key}_mean"]
        metrics[f"{key}_final_std"] = epoch_metrics[Nepochs][f"{key}_std"]

        # The loop above ran at least once (epoch 0 was just read), so both
        # extremum epochs are always set here.
        metrics[f"{key}_min_epoch"] = min_mean_epoch
        metrics[f"{key}_min_mean"] = min_mean
        metrics[f"{key}_min_std"] = epoch_metrics[min_mean_epoch][f"{key}_std"]

        metrics[f"{key}_max_epoch"] = max_mean_epoch
        metrics[f"{key}_max_mean"] = max_mean
        metrics[f"{key}_max_std"] = epoch_metrics[max_mean_epoch][f"{key}_std"]

    metrics["epochs"] = epoch_metrics

    return metrics

def run_theta_ensemble(config: dict, data_loaders=None) -> dict:
    """
    Train `Nsamples` independently seeded networks and summarize their metrics.

    Parameters
    ----------
    config : dict
        Run configuration. "Nepochs" is required; "seed" (default 0) is the base
        seed from which the per-run seeds are drawn and "Nsamples" (default 1)
        is the number of runs.
    data_loaders : optional
        Pre-built data loaders shared by all runs. When None, they are created
        once from the notebook-level `datasets` via `load_data`.

    Returns
    -------
    dict
        Output of `summarize_ensemble_metrics` applied to the per-run metrics
        stacked along a new leading axis.
    """
    base_seed = config.get("seed", 0)
    n_runs = config.get("Nsamples", 1)
    n_epochs = config["Nepochs"]

    # Derive one uint32 seed per ensemble member from the base seed.
    run_seeds = random.randint(
        random.PRNGKey(base_seed),
        (n_runs,),
        0,
        jnp.uint32(2**32 - 1),
        dtype=jnp.uint32,
    )

    # load data once if not provided
    if data_loaders is None:
        data_loaders = load_data(datasets, config)

    per_run = []
    for run_seed in run_seeds:
        run_metrics, run_perf = run_theta({**config, "seed": run_seed}, data_loaders)
        # Merge both dicts; on a key clash the perf metric wins.
        per_run.append(run_metrics | run_perf)

    # Stack matching leaves of all runs along a new leading (run) axis.
    stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_run)

    return summarize_ensemble_metrics(stacked, n_epochs)

In [5]:
# Hyperparameter grid handed to `scan_grid`. Plain values are the same in every
# config; `vary(...)` entries are scanned and `computed(...)` entries are derived
# from other keys (both helpers come from `hyperparam_scan_util`).
# The logged scan output reports 864 configs for this grid, i.e. the
# 3 * 4 * 2 * 4 * 3 * 3 combinations of the varied keys.
config_grid = {
    # Base seed; `run_theta_ensemble` derives the per-run seeds from it
    "seed": 0,
    # Neuron (passed to ThetaNeuron as tau, I0, eps)
    "tau": 6 / np.pi,
    "I0": 5 / 4,
    "eps": 1e-6,
    # Network
    # "Nin" is not fixed here: it must equal Nin_data * Nt (Nin_data = 700 for SHD)
    # and is computed in the "SHD Quantization" section below.
    "Nin_virtual": vary(12, 16, 20),  # Number of virtual input neurons = N_bin - 1
    "Nhidden": vary(40, 60, 80, 100),
    "Nlayer": vary(2, 3),  # Number of layers
    "Nout": 20,
    "w_scale": 0.5,  # Scaling factor of initial weights
    # Trial
    "T": 2.0,
    "K": vary(50, 100, 150, 200),  # Maximal number of simulated ordinary spikes
    "dt": 0.001,  # Step size used to compute state traces
    # Training
    "gamma": 1e-2,
    "Nbatch": 1000,
    "lr": 4e-3,
    "tau_lr": 1e2,
    "beta1": 0.9,
    "beta2": 0.999,
    "p_flip": vary(0.0, 0.02, 0.04),
    "Nepochs": 10,  # Per-epoch metrics then have Nepochs + 1 entries (epoch 0 = initial state)
    "Ntrain": None,  # Number of training samples
    # SHD Quantization
    "Nt": vary(8, 12, 16),
    "Nin_data": 700,
    "Nin": computed(lambda Nin_data, Nt: Nin_data * Nt),  # e.g. Nt = 8 -> Nin = 5600
    # Ensemble
    "Nsamples": 3,  # Number of independently seeded runs per config
}

In [6]:
# Run the ensemble training for every config in the grid and report the best
# ensemble-mean accuracy per config. According to the logged output, configs
# that already have an error-free trial are skipped under "recompute_if_error".
scan_grid(
    run_theta_ensemble,
    config_grid,
    version=1,
    show_metrics=("acc_max_epoch", "acc_max_mean", "acc_max_std"),
    if_trial_exists="recompute_if_error",
)

varying keys              Nin_virtual, Nhidden, Nlayer, K, p_flip, Nt, Nin
configs                   864
seed                      0
tau                       1/6 pi^-1
I0                        1.25
eps                       1e-06
Nout                      20
w_scale                   0.5
T                         2.0
dt                        0.001
gamma                     0.01
Nbatch                    1000
lr                        0.004
tau_lr                    100.0
beta1                     0.9
beta2                     0.999
Nepochs                   10
Ntrain                    None
Nin_data                  700
Nsamples                  3


Nin_virtual               12
Nhidden                   40
Nlayer                    2
K                         50
p_flip                    0.0
Nt                        8
Nin                       5600
This config has already been used in trial 1 and had no error. Skipping.
Nin_virtual               16
Nhidden                   40
Nlay

100%|██████████| 10/10 [00:12<00:00,  1.30s/it]




100%|██████████| 10/10 [00:10<00:00,  1.02s/it]
100%|██████████| 10/10 [00:10<00:00,  1.04s/it]


acc_max_epoch             10
acc_max_mean              0.17445455491542816
acc_max_std               0.11381981521844864

Remaining: 852 configs (34304.5s, ETA: 03:13:22)

Nin_virtual               12
Nhidden                   40
Nlayer                    3
K                         50
p_flip                    0.0
Nt                        8
Nin                       5600
Starting trial 13.


100%|██████████| 10/10 [00:10<00:00,  1.02s/it]




100%|██████████| 10/10 [00:08<00:00,  1.20it/s]
  0%|          | 0/10 [00:00<?, ?it/s]