# BrainState Utility Toolkit

The `brainstate.util` package bundles helpers for collections, structured
PyTrees, pretty-printing, and runtime hygiene. This notebook walks through
the most frequently used APIs with runnable examples.

Sections:

1. Scheduling and naming helpers
2. Memory housekeeping
3. Managing collections with `DictManager`
4. Configuration access via `DotDict`
5. Dictionary utilities (`merge_dicts`, `flatten_dict`, `unflatten_dict`)
6. Structured PyTrees with `util.struct`
7. Filtering nested objects
8. Pretty PyTree containers

In [None]:
from typing import Any

import jax
import jax.numpy as jnp

from brainstate import util
from brainstate.util import (
    DictManager,
    DotDict,
    clear_buffer_memory,
    flatten_dict,
    merge_dicts,
    split_total,
    unflatten_dict,
)

from brainstate.util import struct, filter as util_filter

## 1. Scheduling and naming helpers

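`split_total` returns an integer share of `total`: a float `fraction` is read
as a proportion of the total, while an integer is taken as an absolute count.
`get_unique_name` keeps thread-local counters so repeated calls stay unique
without manual bookkeeping. The numeric suffix continues from the names
already issued in the session, which is why the output below starts at
`layer39` rather than `layer0`.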

In [32]:
epochs = split_total(total=120, fraction=0.25)
override = split_total(total=120, fraction=30)
print('fractional schedule:', epochs)
print('absolute schedule:', override)

names = [util.get_unique_name('layer') for _ in range(3)]
scoped = [util.get_unique_name('block', prefix='encoder_') for _ in range(2)]
print('names:', names)
print('scoped names:', scoped)

fractional schedule: 30
absolute schedule: 30
names: ['layer39', 'layer40', 'layer41']
scoped names: ['encoder_block0', 'encoder_block1']
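
The two behaviours above can be sketched in plain Python. This is a minimal illustration written for this notebook, not the `brainstate` implementation: the names `split_total_sketch` and `unique_name_sketch`, the validation rules, and the rounding via `int()` are all assumptions.

```python
import threading
from collections import defaultdict


def split_total_sketch(total: int, fraction) -> int:
    """Hypothetical stand-in: an int is an absolute count, a float a proportion."""
    if isinstance(fraction, bool):
        raise TypeError("fraction must be an int or a float, not a bool")
    if isinstance(fraction, int):
        # Absolute count: pass it through after a range check.
        if not 0 <= fraction <= total:
            raise ValueError("absolute count must lie in [0, total]")
        return fraction
    # Fractional quota: scale the total and truncate to an integer.
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fractional quota must lie in [0, 1]")
    return int(total * fraction)


# Each thread sees its own `counters` attribute on this object.
_local = threading.local()


def unique_name_sketch(base: str, prefix: str = "") -> str:
    """Hypothetical stand-in: one counter per (thread, prefixed base name)."""
    counters = getattr(_local, "counters", None)
    if counters is None:
        counters = _local.counters = defaultdict(int)
    name = f"{prefix}{base}"
    index = counters[name]
    counters[name] += 1
    return f"{name}{index}"


print(split_total_sketch(120, 0.25))  # 30
print(split_total_sketch(120, 30))    # 30
print([unique_name_sketch("layer") for _ in range(3)])
```

Because the counters live in `threading.local()`, names generated in one thread never advance the sequence seen by another, at the cost of uniqueness holding only within a thread.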


## 2. Memory housekeeping

`clear_buffer_memory` releases cached device buffers and compilation
artifacts between experiments. Passing `array=False` leaves live arrays
untouched, so arrays created elsewhere in this notebook remain valid.

In [33]:
clear_buffer_memory(array=False)
print('Cleared JAX compilation caches and triggered GC.')

Cleared JAX compilation caches and triggered GC.


## 3. Managing collections with `DictManager`

`DictManager` extends the standard mapping interface with filters, splits,
combination operators, and JAX PyTree support.

In [34]:
modules = DictManager({
    'encoder': {'params': 32},
    'decoder': {'params': 45},
    'dropout': 0.1,
})
print('original:', modules)

# Filter only submodules (dict instances)
submods = modules.subset(dict)
print('subset:', submods)

# Split by type: dict entries vs everything else
dicts, remainder = modules.split(dict)
print('split dicts:', dicts)
print('split remainder:', remainder)

# Map over values to extract parameter counts
param_counts = submods.map_values(lambda layer: layer['params'])
param_counts

original: DictManager({'encoder': {'params': 32}, 'decoder': {'params': 45}, 'dropout': 0.1})
subset: DictManager({'encoder': {'params': 32}, 'decoder': {'params': 45}})
split dicts: DictManager({'encoder': {'params': 32}, 'decoder': {'params': 45}})
split remainder: DictManager({'dropout': 0.1})


DictManager({'encoder': 32, 'decoder': 45})
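
For readers who want to see what `subset`, `split`, and `map_values` amount to, the same operations can be written over a plain `dict`. This is a rough sketch under the assumption that filtering is done by `isinstance`; the helper names are hypothetical and the real `DictManager` methods may accept richer filters.

```python
from typing import Any, Callable


def subset_by_type(mapping: dict, kind: type) -> dict:
    """Keep only the entries whose value is an instance of `kind`."""
    return {key: value for key, value in mapping.items() if isinstance(value, kind)}


def split_by_type(mapping: dict, kind: type) -> tuple:
    """Partition a mapping into (entries matching `kind`, everything else)."""
    matched = subset_by_type(mapping, kind)
    rest = {key: value for key, value in mapping.items() if key not in matched}
    return matched, rest


def map_values(mapping: dict, fn: Callable[[Any], Any]) -> dict:
    """Apply `fn` to every value while keeping the keys unchanged."""
    return {key: fn(value) for key, value in mapping.items()}


modules = {
    'encoder': {'params': 32},
    'decoder': {'params': 45},
    'dropout': 0.1,
}
dicts, remainder = split_by_type(modules, dict)
print(dicts)
print(remainder)
print(map_values(dicts, lambda layer: layer['params']))
```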

## 4. Configuration access via `DotDict`

`DotDict` lets you treat nested dictionaries like lightweight objects while
preserving conversion back to standard dicts when needed.

In [35]:
config = DotDict({
    'model': {
        'layers': 4,
        'hidden': 256,
    },
    'training': {
        'lr': 3e-4,
        'scheduler': {'warmup_steps': 500},
    },
})

print('hidden units:', config.model.hidden)
config.training.dropout = 0.2
print('with dropout:', config.training.dropout)

round_trip = config.to_dict()
round_trip

hidden units: 256
with dropout: 0.2


{'model': {'layers': 4, 'hidden': 256},
 'training': {'lr': 0.0003,
  'scheduler': {'warmup_steps': 500},
  'dropout': 0.2}}
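
The attribute-access pattern itself is small enough to sketch. The class below is a hypothetical `AttrDict` written for illustration, not `DotDict` itself; it assumes that nested plain dicts are wrapped on assignment and unwrapped again by `to_dict`.

```python
class AttrDict(dict):
    """Hypothetical minimal attribute-access dict (not brainstate's DotDict)."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            self[key] = value

    def __setitem__(self, key, value):
        # Wrap nested plain dicts so chained attribute access keeps working.
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def to_dict(self):
        # Recursively convert back to plain dicts.
        return {
            key: value.to_dict() if isinstance(value, AttrDict) else value
            for key, value in self.items()
        }


config = AttrDict({'model': {'hidden': 256}, 'training': {'lr': 3e-4}})
config.training.dropout = 0.2
print(config.model.hidden)  # 256
print(config.to_dict())
```

One design consequence of this approach is that keys shadowed by real `dict` attributes (for example `items` or `keys`) are only reachable through item access.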

## 5. Dictionary utilities

`merge_dicts` combines dictionaries and can merge nested dictionaries
recursively, with values from later dictionaries taking precedence.
`flatten_dict` and `unflatten_dict` convert between nested and dotted-key
representations, which is useful for logging or CLI overrides.

In [36]:
base = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}
override = {'optimizer': {'lr': 5e-4}, 'seed': 1234}
merged = merge_dicts(base, override)
print('merged:', merged)

flat = flatten_dict(merged)
print('flattened:', flat)
unflatten_dict(flat)

merged: {'optimizer': {'lr': 0.0005, 'beta1': 0.9}, 'seed': 1234}
flattened: {'optimizer.lr': 0.0005, 'optimizer.beta1': 0.9, 'seed': 1234}


{'optimizer': {'lr': 0.0005, 'beta1': 0.9}, 'seed': 1234}
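
The CLI-override use case can be illustrated with a self-contained sketch. The three helpers below are hypothetical re-implementations written for this example (dot separator and "later value wins" are assumptions taken from the output above), not the library functions.

```python
def merge_recursive(base: dict, override: dict) -> dict:
    """Return a new dict; nested dicts are merged, other values are replaced."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_recursive(merged[key], value)
        else:
            merged[key] = value
    return merged


def flatten(nested: dict, sep: str = '.', parent: str = '') -> dict:
    """Collapse nested dicts into a single level with dotted keys."""
    flat = {}
    for key, value in nested.items():
        full = f"{parent}{sep}{key}" if parent else str(key)
        if isinstance(value, dict) and value:
            flat.update(flatten(value, sep, full))
        else:
            flat[full] = value
    return flat


def unflatten(flat: dict, sep: str = '.') -> dict:
    """Rebuild the nested structure from dotted keys."""
    nested = {}
    for dotted, value in flat.items():
        *parents, leaf = dotted.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


base = {'optimizer': {'lr': 1e-3, 'beta1': 0.9}}
# Overrides as they might arrive from a command line, already parsed to values.
cli_overrides = {'optimizer.lr': 5e-4, 'seed': 1234}
merged = merge_recursive(base, unflatten(cli_overrides))
print(merged)
print(flatten(merged))
```

Note that dotted keys are ambiguous when an original key already contains the separator, so this representation suits configs whose keys are simple identifiers.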

## 6. Structured PyTrees with `util.struct`

The `struct` submodule provides Flax-style data structures. The
`dataclass` decorator registers classes as PyTrees, while `FrozenDict`
provides immutable mappings compatible with JAX transformations.

In [37]:
@struct.dataclass
class LayerConfig:
    weight: jax.Array
    bias: jax.Array
    name: str = struct.field(pytree_node=False, default='layer')

cfg = LayerConfig(weight=jnp.ones((2, 2)), bias=jnp.zeros(2))
print(cfg)

cfg2 = cfg.replace(weight=jnp.full((2, 2), 3.0))
print('updated weight:', cfg2.weight)

flat_leaves, _ = jax.tree_util.tree_flatten(cfg)
print('pytree leaves:', [leaf.shape for leaf in flat_leaves])

frozen = struct.freeze({'encoder': jnp.arange(3)})
print('frozen dict:', frozen)
print('unfrozen:', struct.unfreeze(frozen))

LayerConfig(weight=Array([[1., 1.],
       [1., 1.]], dtype=float32), bias=Array([0., 0.], dtype=float32), name='layer')
updated weight: [[3. 3.]
 [3. 3.]]
pytree leaves: [(2, 2), (2,)]
frozen dict: FrozenDict({
  'encoder': Array([0, 1, 2], dtype=int32)
})
unfrozen: {'encoder': Array([0, 1, 2], dtype=int32)}
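
To make "registered as a PyTree" more concrete, the sketch below shows the flatten/unflatten contract in plain Python, with no JAX import. `jax.tree_util.register_pytree_node_class` expects a class exposing `tree_flatten` and a `tree_unflatten` classmethod of this shape. How `struct.dataclass` implements it internally is not shown here; the class `LayerSketch` and its `_static` tuple are hypothetical.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class LayerSketch:
    weight: list
    bias: list
    name: str = "layer"

    # Hypothetical marker: fields listed here are static metadata, not leaves.
    # It has no annotation, so the dataclass machinery does not treat it as a field.
    _static = ("name",)

    def tree_flatten(self):
        # Children are the values a transformation may trace or replace.
        children = tuple(
            getattr(self, f.name) for f in fields(self) if f.name not in self._static
        )
        # Auxiliary data travels alongside unchanged.
        aux = tuple(getattr(self, name) for name in self._static)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        dynamic = [f.name for f in fields(cls) if f.name not in cls._static]
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(cls._static, aux))
        return cls(**kwargs)


cfg = LayerSketch(weight=[[1.0, 1.0], [1.0, 1.0]], bias=[0.0, 0.0])
children, aux = cfg.tree_flatten()
print(len(children), aux)  # 2 ('layer',)

rebuilt = LayerSketch.tree_unflatten(aux, children)
print(rebuilt == cfg)  # True

# Immutable update, analogous in spirit to `cfg.replace(...)` above.
updated = replace(cfg, bias=[1.0, 1.0])
print(updated.bias)  # [1.0, 1.0]
```

This mirrors the cell above, where `name` is declared with `pytree_node=False` and only `weight` and `bias` appear among the flattened leaves.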


## 7. Filtering nested objects

`brainstate.util.filter` turns declarative filters, such as a tag string or a
type, into predicates that are called with a `(path, value)` pair. Combine
tag, type, and path checks when traversing parameter trees.

In [38]:
class Module:
    def __init__(self, tag: str | None, kind: str):
        self.tag = tag
        self.kind = kind
        self.params = jnp.arange(2)

model_tree = {
    'encoder': Module(tag='trainable', kind='linear'),
    'decoder': Module(tag='frozen', kind='linear'),
    'head': Module(tag='trainable', kind='mlp'),
}

tag_filter = util_filter.to_predicate('trainable')
type_filter = util_filter.OfType(Module)
combined = util_filter.All(type_filter, util_filter.WithTag('trainable'))

def collect(tree: dict[str, Any], predicate) -> dict[str, Any]:
    out = {}
    for key, value in tree.items():
        if predicate((key,), value):
            out[key] = value
    return out

trainable_modules = collect(model_tree, tag_filter)
both = collect(model_tree, combined)
print('trainable keys:', tuple(trainable_modules.keys()))
print('trainable Modules:', tuple(both.keys()))

trainable keys: ('encoder', 'head')
trainable Modules: ('encoder', 'head')
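
The predicate style used above composes naturally. The following is a small stand-alone sketch of tag, type, and path checks combined with boolean combinators; every helper name here (`with_tag`, `of_type`, `path_contains`, `all_of`, `negate`, `select`) is hypothetical and only imitates the `(path, value)` calling convention shown in the cell.

```python
class Tagged:
    """Tiny stand-in for a module carrying a `tag` attribute."""

    def __init__(self, tag):
        self.tag = tag


def with_tag(tag):
    # Matches values whose `tag` attribute equals the requested tag.
    return lambda path, value: getattr(value, "tag", None) == tag


def of_type(kind):
    # Matches values that are instances of `kind`.
    return lambda path, value: isinstance(value, kind)


def path_contains(key):
    # Matches entries whose path tuple includes `key`.
    return lambda path, value: key in path


def all_of(*predicates):
    # Logical AND over predicates sharing the (path, value) signature.
    return lambda path, value: all(p(path, value) for p in predicates)


def negate(predicate):
    # Logical NOT of a single predicate.
    return lambda path, value: not predicate(path, value)


def select(tree, predicate):
    """Return the top-level keys of `tree` accepted by `predicate`."""
    return [key for key, value in tree.items() if predicate((key,), value)]


tree = {
    'encoder': Tagged('trainable'),
    'decoder': Tagged('frozen'),
    'head': Tagged('trainable'),
}
trainable_not_head = all_of(
    of_type(Tagged), with_tag('trainable'), negate(path_contains('head'))
)
print(select(tree, trainable_not_head))  # ['encoder']
```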


## 8. Pretty PyTree containers

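`NestedDict`, `FlattedDict`, and `PrettyList` bring readable reprs plus PyTree
semantics. `flat_mapping` and `nest_mapping` convert between the nested form
and a flat form keyed by path tuples. Use them to explore checkpoints or log
structured configs.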

In [39]:
from brainstate.util import NestedDict, flat_mapping, nest_mapping, PrettyList

state = NestedDict({
    'encoder': {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)},
    'decoder': {'weight': jnp.eye(2)},
})
print(state)

flat_state = flat_mapping(state)
print('flat keys:', list(flat_state.keys()))

round_trip = nest_mapping(flat_state)
print('round-trip equal:', round_trip == state)

history = PrettyList([{'loss': 0.8}, {'loss': 0.42}])
print(history)

{
  'encoder': {
    'weight': Array([[1., 1.],
           [1., 1.]], dtype=float32),
    'bias': Array([0., 0.], dtype=float32)
  },
  'decoder': {
    'weight': Array([[1., 0.],
           [0., 1.]], dtype=float32)
  }
}
flat keys: [('encoder', 'weight'), ('encoder', 'bias'), ('decoder', 'weight')]
round-trip equal: True
[
  {
    'loss': 0.8
  },
  {
    'loss': 0.42
  }
]
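
The flat keys printed above are tuples rather than dotted strings. A short sketch shows why that matters: tuple paths stay unambiguous even when a key itself contains a dot. The functions `flatten_paths` and `nest_paths` are hypothetical helpers for illustration, not `flat_mapping` and `nest_mapping`.

```python
def flatten_paths(nested: dict, prefix: tuple = ()) -> dict:
    """Flatten nested dicts into a single dict keyed by path tuples."""
    flat = {}
    for key, value in nested.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_paths(value, path))
        else:
            flat[path] = value
    return flat


def nest_paths(flat: dict) -> dict:
    """Rebuild the nested dict from path-tuple keys."""
    nested = {}
    for path, value in flat.items():
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested


# With dotted-string keys these two entries would collide on 'a.b'.
state = {'a.b': 1, 'a': {'b': 2}}
flat = flatten_paths(state)
print(flat)  # {('a.b',): 1, ('a', 'b'): 2}
print(nest_paths(flat) == state)  # True
```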


## Summary

- Use `split_total` to size schedules and `get_unique_name` to generate
  collision-free names across experiments.
- Reach for `DictManager` and `DotDict` to manage nested collections.
- Combine configs with `merge_dicts`, and convert between nested and flat
  forms with `flatten_dict` and `unflatten_dict`.
- Wrap structured data using `util.struct` and leverage filter/pretty utilities
  when exploring PyTrees.