In [1]:
import itertools
import shutil
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Iterable, Literal, Self

import jax
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from jax import random
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from heidelberg_v01 import (
    load_data,
    load_datasets,
    plot_error,
    plot_spikes,
    plot_traces,
    run,
    run_example,
)
from spikegd.theta import ThetaNeuron


In [2]:
# Load the SHD datasets from the local "data" directory. `datasets` is read as a
# global by `run_theta` and `run_theta_ensemble` below.
datasets = load_datasets("data", verbose=True)

Loading data from h5 file
Loading audio filenames
Finished loading SHD
Loading data from h5 file
Loading audio filenames
Finished loading SHD


In [3]:
def run_theta(config: dict, data_loaders=None):
    """
    Train a network of Theta neurons configured by ``config``.

    Args:
        config: Training configuration. The neuron parameters are read from its
            "tau", "I0" and "eps" entries; the whole dict is forwarded to ``run``.
        data_loaders: Pre-built data loaders. When omitted they are created with
            ``load_data(datasets, config)`` from the notebook-global ``datasets``.

    Returns:
        The ``(metrics, perf_metrics)`` pair produced by ``run``.

    See the docstring of ``run`` and the accompanying article for details.
    """
    loaders = data_loaders if data_loaders is not None else load_data(datasets, config)
    neuron = ThetaNeuron(config["tau"], config["I0"], config["eps"])
    metrics, perf_metrics = run(neuron, loaders, config, progress_bar="script")
    return metrics, perf_metrics

def run_theta_ensemble(config: dict, samples: int = 1, data_loaders=None) -> dict:
    """
    Train ``samples`` independently seeded Theta networks with the same config.

    Per-sample seeds are drawn from a JAX PRNG key built from ``config["seed"]``
    (0 if absent), so the ensemble is reproducible and a different config seed
    yields a different ensemble. The data loaders are built once and shared by
    all samples.

    Args:
        config: Training configuration (see ``run_theta``). Its "seed" entry is
            the base seed; each sample receives a copy with "seed" replaced.
        samples: Number of ensemble members; must be at least 1.
        data_loaders: Pre-built data loaders; created with
            ``load_data(datasets, config)`` when omitted.

    Returns:
        A dict with the structure of ``metrics | perf_metrics`` from ``run_theta``
        where every leaf is stacked along a new leading axis of length ``samples``.

    Raises:
        ValueError: If ``samples`` is smaller than 1 (stacking needs at least
            one result).
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    base_key = random.PRNGKey(config.get("seed", 0))
    # maxval is exclusive, so seeds lie in [0, 2**32 - 2].
    sample_seeds = random.randint(
        base_key, (samples,), 0, jnp.uint32(2**32 - 1), dtype=jnp.uint32
    )

    # Load data once if not provided; all samples train on the same loaders.
    if data_loaders is None:
        data_loaders = load_data(datasets, config)

    metrics_list = []
    for sample_seed in sample_seeds:
        sample_config = {**config, "seed": sample_seed}
        metrics, perf_metrics = run_theta(sample_config, data_loaders)
        metrics_list.append(metrics | perf_metrics)

    return jax.tree.map(lambda *leaves: jnp.stack(leaves), *metrics_list)

In [4]:
def summarize_ensemble_metrics(ensemble_metrics: dict, epoch: int = -1) -> dict:
    """
    Reduce stacked ensemble metrics to ``<name>_mean`` / ``<name>_std`` floats.

    Keys starting with "perf." are not epoch-specific and are reduced over the
    whole array. "p_init" and "p_end" are skipped. Every other metric is first
    indexed with ``[:, epoch]`` and then reduced. The std is the population
    std (``jnp.std`` default, ddof=0).

    Assumes axis 0 is the ensemble axis (as produced by ``run_theta_ensemble``)
    and axis 1 indexes epochs -- TODO confirm against ``run``.
    """
    skipped = ("p_init", "p_end")
    summary = {}
    for name, stacked in ensemble_metrics.items():
        is_perf = name.startswith("perf.")
        if not is_perf and name in skipped:
            continue
        selected = stacked if is_perf else stacked[:, epoch]
        summary[f"{name}_mean"] = float(jnp.mean(selected))
        summary[f"{name}_std"] = float(jnp.std(selected))
    return summary


In [5]:
# Hyperparameter grid: tuple values are grid axes (configs_from_grid takes their
# Cartesian product); every other value is held constant across all configs.
# With the axes below the grid expands to 3*4*2*4*3*3 = 864 configs.
# NOTE: GridTrial rejects trial indices above 999, so a single grid version
# (data file) can record at most 999 trials.
config_grid = {
    "seed": 0,
    # Neuron
    "tau": 6 / np.pi,
    "I0": 5 / 4,
    "eps": 1e-6,
    # Network
    # "Nin" is not set here: configs_from_grid derives it as Nin_data * Nt.
    "Nin_virtual": (12, 16, 20),  # #Virtual input neurons = N_bin - 1
    "Nhidden": (40, 60, 80, 100),
    "Nlayer": (2, 3),  # Number of layers
    "Nout": 20,
    "w_scale": 0.5,  # Scaling factor of initial weights
    # Trial
    "T": 2.0,
    "K": (50, 100, 150, 200),  # Maximal number of simulated ordinary spikes
    "dt": 0.001,  # Step size used to compute state traces
    # Training
    "gamma": 1e-2,
    "Nbatch": 1000,
    "lr": 4e-3,
    "tau_lr": 1e2,
    "beta1": 0.9,
    "beta2": 0.999,
    "p_flip": (0.0, 0.02, 0.04),
    "Nepochs": 10,
    "Ntrain": None,  # Number of training samples
    # SHD Quantization (Nin = Nin_data * Nt, see configs_from_grid)
    "Nt": (8, 12, 16),
    "Nin_data": 700,  # Number of neurons (input channels) in the SHD dataset
}


def configs_from_grid(config_grid: dict):
    """
    Expand a config grid into the list of all concrete configs.

    Tuple values are treated as axes and combined via their Cartesian product
    (in key order, last key varying fastest); any other value is used as-is in
    every config. Each resulting config additionally gets
    ``"Nin" = Nin_data * Nt``.

    Raises:
        KeyError: If a produced config lacks "Nin_data" or "Nt".
    """
    names = list(config_grid)
    axes = [value if isinstance(value, tuple) else (value,) for value in config_grid.values()]

    configs = []
    for combo in itertools.product(*axes):
        config = dict(zip(names, combo))
        config["Nin"] = config["Nin_data"] * config["Nt"]
        configs.append(config)
    return configs


def format_timestamp(t: float):
    """Render a POSIX timestamp in local time as ``YYYY-mm-dd_HH-MM-SS_ffffff``."""
    moment = datetime.fromtimestamp(t)
    return f"{moment:%Y-%m-%d_%H-%M-%S_%f}"


def parse_timestamp(t: str):
    """Parse a local-time ``YYYY-mm-dd_HH-MM-SS_ffffff`` string back to a POSIX timestamp."""
    parsed = datetime.strptime(t, "%Y-%m-%d_%H-%M-%S_%f")
    return parsed.timestamp()


def get_grid_data_filename(version: int):
    """Name of the YAML file for grid-data ``version``, zero-padded to at least two digits."""
    return "grid_data_{:02d}.yaml".format(version)


@dataclass
class GridData:
    version: int
    created_at: str
    const: dict[str, Any] | None = None
    trials: dict[int, "GridTrial"] = field(default_factory=dict)

    def to_dict(self):
        dict_ = asdict(self)
        dict_["trials"] = dict(trial.to_item() for trial in self.trials.values())
        return dict_

    @classmethod
    def from_dict(cls, dict_: dict):
        dict_["trials"] = {
            k: GridTrial.from_item((k, v)) for k, v in dict_["trials"].items()
        }
        return cls(**dict_)

    def save(self):
        filename = get_grid_data_filename(self.version)
        with open(filename, "w") as f:
            yaml.dump(self.to_dict(), f, sort_keys=False)

    @classmethod
    def load_or_create(cls, version: int) -> Self:
        path = Path(get_grid_data_filename(version))

        if not path.exists():
            data = cls(version=version, created_at=format_timestamp(time.time()))
            data.save()
            return data
        
        # File exists -> create backup and load data
        shutil.copy(path, path.with_stem(f"{path.stem}_{time.strftime("%Y%m%d-%H%M%S")}"))

        with path.open("r") as f:
            dict_ = yaml.safe_load(f)

        data = cls.from_dict(dict_)

        assert data.version == version, f"Version mismatch: {data.version} != {version}"

        return data

    def create_trial_config(self, config: dict[str, Any]) -> dict[str, Any]:
        if self.const is None:
            self.const = config.copy()
            return {}

        trial_config = {}

        for key, const_value in list(self.const.items()):
            new_value = config.get(key)

            if const_value != new_value:
                # Previously constant key is different now -> remove from const
                self.const.pop(key, None)

                # Add key with constant value to all existing trial configs
                for trial in self.trials.values():
                    assert (
                        key not in trial.config
                    ), f"Key {key} already in trial config of trial {trial.index}."
                    trial.config[key] = const_value

        for key, new_value in config.items():
            const_value = self.const.get(key)

            if new_value != const_value:
                # Add key with new value to trial config
                trial_config[key] = new_value

        return trial_config

    def get_next_trial_index(self):
        return max((trial.index for trial in self.trials.values()), default=0) + 1

    def start_trial(self, config: dict[str, Any], n_samples: int) -> "GridTrial":
        trial_config = self.create_trial_config(config)

        trial = GridTrial(
            index=self.get_next_trial_index(),
            config=trial_config,
            n_samples=n_samples,
            started_at=format_timestamp(time.time()),
        )
        self.trials[trial.index] = trial
        return trial

    def find_trial_by_config(self, config: dict[str, Any]):
        for trial in self.trials.values():
            for k, v in config.items():
                trial_value = trial.config.get(k)
                if trial_value is None and self.const is not None:
                    trial_value = self.const.get(k)

                if v != trial_value:
                    break
            else:
                return trial


@dataclass
class GridTrial:
    index: int
    config: dict[str, Any]
    n_samples: int
    started_at: str
    finished_at: str | None = None
    duration: float | None = None
    metrics: dict[str, Any] = field(default_factory=dict)
    error: str | None = None

    def __post_init__(self):
        if self.index > 999:
            raise ValueError("Trial indices above 999 are not supported.")

    def to_item(self):
        return self.index, asdict(self)

    @classmethod
    def from_item(cls, item: tuple[int, dict[str, Any]]):
        index, dict_ = item
        assert isinstance(index, int), f"Invalid trial index: {index}"
        assert isinstance(dict_, dict), f"Invalid trial data: {dict_}"

        trial = cls(**dict_)
        assert trial.index == index, f"Trial index mismatch: {trial.index} != {index}"
        return trial

    def finish(self, metrics: dict[str, Any], error: str | None = None):
        self.metrics = metrics
        self.error = error
        end_time = time.time()
        self.finished_at = format_timestamp(end_time)
        self.duration = end_time - parse_timestamp(self.started_at)


def print_dict(d: dict, value_format=""):
    """
    Print one ``key value`` line per entry, keys left-aligned in 25 columns.

    ``value_format`` is a format spec (e.g. ".3f") applied to every value.
    """
    for key, value in d.items():
        print(format(key, "<25"), format(value, value_format))


def filter_dict[K, V](
    d: dict[K, V], predicate: Callable[[K, V], bool] | Iterable[K]
) -> dict[K, V]:
    if isinstance(predicate, Iterable):
        keys = predicate
        predicate = lambda k, v: k in keys

    return {k: v for k, v in d.items() if predicate(k, v)}


def run_theta_grid(config_grid: dict, version: int, n_samples_per_trial=1):
    """
    Run (or resume) a grid search over Theta-neuron configurations.

    Every config produced by ``configs_from_grid(config_grid)`` is trained with
    ``run_theta_ensemble`` using ``n_samples_per_trial`` seeds and summarized
    with ``summarize_ensemble_metrics``. Results go to the ``GridData`` file for
    ``version``, which is saved after every trial. Re-running the call therefore
    resumes the search: configs that already have a trial in the file (finished
    or failed) are skipped. Exceptions raised while training a config are
    recorded in that trial's ``error`` field instead of aborting the search.

    Args:
        config_grid: Mapping of config keys to values; tuple values are grid
            axes, all other values are constant.
        version: Selects the grid data file (``grid_data_<version>.yaml``).
        n_samples_per_trial: Ensemble size per config.
    """
    configs = configs_from_grid(config_grid)
    const_keys = [k for k, v in config_grid.items() if not isinstance(v, tuple)]
    varying_keys = [k for k, v in config_grid.items() if isinstance(v, tuple)]

    data = GridData.load_or_create(version)
    print_dict(
        {
            "variables": ", ".join(varying_keys),
            "samples per trial": n_samples_per_trial,
            "configs": len(configs),
        }
    )
    print("========== CONSTANTS ==========")
    print_dict(filter_dict(config_grid, const_keys))
    print()

    for config in configs:
        # Look the config up first so the header shows the index this config
        # actually has (existing trial) or will get (new trial).
        trial = data.find_trial_by_config(config)
        trial_index = trial.index if trial is not None else data.get_next_trial_index()
        print(f"========== TRIAL {trial_index:03d} ==========")
        print_dict(filter_dict(config, varying_keys))

        if trial is not None:
            print(f"This config has already been used in trial {trial.index}.")
            if trial.n_samples != n_samples_per_trial:
                # The skip check ignores ensemble size; make mismatches visible.
                print(
                    f"Note: trial {trial.index} used {trial.n_samples} samples, "
                    f"{n_samples_per_trial} requested."
                )
            print()
            continue

        trial = data.start_trial(config, n_samples_per_trial)

        try:
            ensemble_metrics = run_theta_ensemble(config, samples=n_samples_per_trial)
        except Exception as e:
            # Record the failure and continue with the next config.
            trial.finish({}, repr(e))
            print_dict({"error": repr(e)})
        else:
            metrics = summarize_ensemble_metrics(ensemble_metrics)
            trial.finish(metrics)
            print_dict(filter_dict(metrics, ("acc_mean", "acc_std")), ".3f")

        print()

        # Persist after every trial so an interrupted run can be resumed.
        data.save()

In [6]:
# Run (or resume) grid version 1 with 3 seeds per config. Results are saved to
# grid_data_01.yaml after every trial; configs already recorded there are skipped.
run_theta_grid(config_grid, version=1, n_samples_per_trial=3)

variables                 Nin_virtual, Nhidden, Nlayer, K, Nbatch, lr, p_flip, Nt
samples per trial         3
configs                   13824
seed                      0
tau                       1.909859317102744
I0                        1.25
eps                       1e-06
Nout                      20
w_scale                   0.5
T                         2.0
dt                        0.001
gamma                     0.01
tau_lr                    100.0
beta1                     0.9
beta2                     0.999
Nepochs                   10
Ntrain                    None
Nin_data                  700


Nin_virtual               12
Nhidden                   40
Nlayer                    2
K                         50
Nbatch                    100
lr                        0.001
p_flip                    0.0
Nt                        8
