v0.1.10

@chaoming0625 released this 10 Jun 02:14
cdd8f94

This release adds two forward-mode second-order optimizers, SOFO and SOFOScan, to braintools.optim, and hardens the braintools.cogtask task engine so that conditional combinators, categorical labels, and metadata batching behave correctly under brainstate.transform.jit and brainstate.transform.vmap2. The package now ships inline type information (PEP 561). Test coverage rises to ~92%, and several latent trainer and visualize bugs uncovered along the way are fixed. Documentation links and assets are migrated to the new brainx.chaobrain.com host.

Highlights

  • New optimizers SOFO and SOFOScan: Second-Order Forward-mode Optimization for feedforward and recurrent models. Both build a Generalised Gauss-Newton matrix in a random tangent subspace from forward-mode JVPs and apply the resulting direction through the standard optax update path, so learning-rate schedules, momentum, weight decay, and gradient clipping continue to work unchanged.
  • Hardened cogtask dispatch: Switch now works in both eager and traced execution, While fails loudly on unsupported traced conditions, and a new num_classes parameter decouples categorical-head sizing from num_outputs.
  • PEP 561 typing: braintools ships a py.typed marker and inline annotations on its public API, so downstream static type checkers consume its types directly.

Added

braintools.optim — forward-mode second-order optimizers

  • SOFO: Second-Order Forward-mode Optimization for a model with the signature model(inputs) -> predictions, paired with loss_fn(predictions, targets). It samples random tangent vectors, takes forward-mode JVPs through the model and loss, builds a damped Generalised Gauss-Newton system in the random subspace, solves it, and projects the solution back to parameter space. Supports 'mse' and 'ce' loss forms, a configurable tangent_size and damping, momentum / nesterov, decoupled weight_decay, and norm/value gradient clipping.
  • SOFOScan: a recurrent variant for a stateful one-step cell rnn_cell(latent, inputs) -> (new_latent, output). The cell is scanned over the input sequence with brainstate.transform.scan, and forward-mode JVPs propagate the tangents through lax.scan, accumulating the Gauss-Newton matrix over every (timestep, batch) sample before a single solve. Both optimizers are exported from braintools.optim and documented in the API reference.
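
The subspace Gauss-Newton step described for SOFO can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the braintools implementation: the function name sofo_direction, the flat parameter vector, the model(params, inputs) calling convention, and the restriction to an MSE loss are all choices made here for brevity.

```python
import jax
import jax.numpy as jnp


def sofo_direction(model, params, inputs, targets, key, tangent_size=8, damping=1e-3):
    """Illustrative subspace Gauss-Newton direction for 0.5 * mean squared error.

    Hypothetical helper: `params` is a flat vector and `model(params, inputs)`
    returns predictions with the same shape as `targets`.
    """
    # Sample K random tangent directions in parameter space, one per row.
    tangents = jax.random.normal(key, (tangent_size, params.shape[0]))

    def push_forward(v):
        # Forward-mode JVP: predictions and the directional derivative J @ v.
        return jax.jvp(lambda p: model(p, inputs), (params,), (v,))

    preds, jv = jax.vmap(push_forward)(tangents)
    residual = (preds[0] - targets).reshape(-1)   # predictions are identical per tangent
    jv = jv.reshape(tangent_size, -1)             # (K, number of output entries)

    n = residual.shape[0]
    ggn = jv @ jv.T / n                           # (K, K) Gauss-Newton matrix in the subspace
    grad = jv @ residual / n                      # (K,) gradient projected onto the subspace
    coeffs = jnp.linalg.solve(ggn + damping * jnp.eye(tangent_size), grad)
    return tangents.T @ coeffs                    # project the solution back to parameter space
```

A descent step would then be params - lr * sofo_direction(...). In the release itself the direction is applied through the optax update path, so the learning rate, momentum, and clipping are handled there rather than by hand.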

braintools.cogtask

  • Task categorical sizing: a new num_classes argument, decoupled from num_outputs, sizes categorical output heads independently of the raw output dimension.
  • Task feature ergonomics: Task now accepts a lone Feature in place of a FeatureSet, and requires features to be supplied whenever phases are given.
  • Task time step: Task and make_task accept an optional dt argument. When set, it is pinned around trial generation via brainstate.environ.context, so phase durations and buffer sizes are computed against that dt and the reported dt stays consistent regardless of the ambient environment. When omitted, the ambient brainstate.environ.get_dt() is used (unchanged behaviour).

Typing

  • PEP 561 support: a braintools/py.typed marker is shipped via package data, and the top-level public API — spike bitwise ops, spike encoders (with implicit-Optional defaults fixed), tree utilities, and _misc helpers — now carries resolvable inline annotations.

Changed

  • cogtask conditional dispatch:
    • Switch uses dual-mode packed dispatch: in eager mode it looks up the concrete key in cases, and under jit / vmap it runs a lax.switch over the ordered branches. It also coerces 0-d concrete array keys (e.g. ctx.rng.choice(...) selectors), so eager sample_trial no longer raises unhashable type: 'ArrayImpl'.
    • While raises a clear NotImplementedError for data-dependent (traced) conditions under jit / vmap instead of surfacing a cryptic TracerBoolConversionError.
  • Documentation links: chaobrain-ecosystem documentation URLs (brainstate, brainunit, braincell, brainmass, brainevent, braintrace, braintools, and related packages) were rewritten from *.readthedocs.io to the new brainx.chaobrain.com host, stripping /latest, /en/latest, and /en/stable path prefixes and ?badge=latest query strings. Third-party ReadTheDocs links are left intact.
  • README logo: the project logo is now served from brainx.chaobrain.com as WebP instead of a raw GitHub asset.
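
The dual-mode dispatch described for Switch, and the early failure described for While, can be sketched as follows. This is an illustration only, not the cogtask code: the function names switch and require_concrete are hypothetical, and the traced path assumes the key is an integer index into the branches in insertion order.

```python
import jax
import jax.numpy as jnp
from jax import lax


def switch(key, cases, *operands):
    """Dispatch on `key` over a dict mapping hashable keys to branch functions."""
    if isinstance(key, jax.core.Tracer):
        # Traced: the value is unknown at trace time, so all branches are staged.
        # Assumption: `key` indexes the branches in dict insertion order.
        return lax.switch(key, list(cases.values()), *operands)
    if isinstance(key, jax.Array) and key.ndim == 0:
        # Concrete 0-d array: arrays are unhashable, so convert to a Python scalar.
        key = key.item()
    return cases[key](*operands)


def require_concrete(cond):
    """Return bool(cond), failing early with a readable error for traced values."""
    if isinstance(cond, jax.core.Tracer):
        raise NotImplementedError(
            "data-dependent loop conditions are not supported under jit / vmap"
        )
    return bool(cond)
```

For example, with cases = {0: f, 1: g}, the call switch(jnp.asarray(1), cases, x) runs g eagerly, while jax.jit(lambda k, v: switch(k, cases, v))(1, x) goes through lax.switch. Under tracing, every branch must return outputs of the same shape and dtype.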

Fixed

  • braintools.cogtask:
    • Categorical labels that are statically out of range for the declared num_classes are now validated and rejected up front.
    • Packed-mode phases expose phase_start / phase_end before on_enter runs, matching the contract already provided in fixed-length mode.
    • String leaves are dropped from batched metadata so return_meta works correctly under brainstate.transform.vmap2.
    • Minor fixes to the input encoder and the working-memory task library.
    • Added regression tests covering all of the above.
  • braintools.trainer:
    • LightningModule.device no longer raises TypeError: 'set' object is not subscriptable on array-backed parameters; Array.devices() returns a set, which is now handled correctly (#92).
    • ModelCheckpoint saves through braintools.file.msgpack_save instead of the msgpack_from_state_dict restore helper, so checkpoints are actually written (#95).
  • braintools.visualize:
    • animate_2D reshapes the (num_step, num_neuron) values to the (height, width) grid before drawing the first frame, fixing a pcolor crash on the initial step (#93).
    • correlation_matrix(method='kendall') builds the correlation matrix pairwise over feature columns instead of passing a 2-D array to kendalltau (#94).
    • remove_axis uses ax.spines instead of the non-existent ax.spine, which previously raised AttributeError (#96).
    • create_neural_colormap / brain_colormaps register with force=True, making them idempotent rather than raising on re-use (#97).
    • roc_curve / precision_recall_curve resolve np.trapezoid when available (falling back to np.trapz), fixing an AttributeError on NumPy >= 2.4 where np.trapz was removed (#99).
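
The np.trapezoid fallback in the last fix follows a common compatibility pattern. A minimal sketch, assuming a hypothetical helper named auc rather than the actual braintools code:

```python
import numpy as np

# Prefer np.trapezoid (added in NumPy 2.0). Because `or` short-circuits,
# np.trapz is only looked up on releases that lack np.trapezoid.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def auc(x, y):
    """Area under the curve y(x) by the trapezoidal rule."""
    return float(_trapezoid(y, x))
```

Resolving the function once at import time keeps the call sites free of version checks.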

Infrastructure

  • Publish workflow: reads the package version directly from braintools/_version.py (the single source of truth) and verifies that the release tag matches before publishing.
  • Docs deployment: the push: main trigger was removed; documentation is now deployed only on a GitHub release (released) or via a manual workflow_dispatch.
  • Type-check workflow: a new Type Check workflow runs mypy over the annotated public surface, backed by a [tool.mypy] configuration and a type-check optional-dependency group.
  • Test coverage: new test suites cover the previously-untested trainer, visualize, file, and surrogate modules, raising overall coverage to ~92%. CI runs pytest with --cov and uploads results to Codecov, and the README carries a coverage badge. tqdm and rich were added to the testing extra so the progress-bar callback tests run in CI.
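
The tag check performed by the publish workflow can be pictured with a few lines of Python. This is a sketch under assumptions, not the workflow itself: it assumes braintools/_version.py contains a plain __version__ = "..." assignment and that release tags are the version with an optional leading "v". The helper names are hypothetical.

```python
import re


def read_version(version_file_text):
    """Extract the version string from the text of a _version.py file."""
    match = re.search(
        r"""^__version__\s*=\s*["']([^"']+)["']""",
        version_file_text,
        re.MULTILINE,
    )
    if match is None:
        raise ValueError("no __version__ assignment found")
    return match.group(1)


def tag_matches(tag, version):
    """True when the release tag (e.g. "v0.1.10") names the package version."""
    return tag.removeprefix("v") == version
```

Parsing the file as text avoids importing the package, and its dependencies, inside the publishing job.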

Full changelog: v0.1.9...v0.1.10