# BrainState Mixin System


This tutorial explains the mixin utilities that ship with `brainstate`. After working through the examples you will:

- Understand what a mixin is and when to use one.
- Reuse behaviours by inheriting from `brainstate.mixin.Mixin`.
- Capture reusable constructor presets with `ParamDesc` and `ParamDescriber`.
- Express rich type expectations with `JointTypes` and `OneOfTypes`.
- Control runtime behaviour with the built-in mode mixins such as `Training`, `Batching`, and `JointMode`.


In [1]:
import datetime
from dataclasses import dataclass

import jax.numpy as jnp

import brainstate
from brainstate import mixin


## What is a mixin?

A *mixin* is a lightweight class that contributes behaviour (methods or attributes) without forcing a rigid inheritance hierarchy.
In BrainState every mixin inherits from `brainstate.mixin.Mixin`, signalling that the class
provides optional behaviour and should not define its own `__init__`.
Mixins are usually combined with a core component such as `brainstate.nn.Module` through multiple inheritance: the shared helper lives in one class, and every host that needs it lists the mixin among its bases.


In [2]:
class LoggingMixin(mixin.Mixin):
    """Attach timestamped logging to any class without touching its constructor."""

    def log(self, message: str) -> None:
        stamp = datetime.datetime.now().strftime('%H:%M:%S')
        print(f'[LOG {stamp}] {self.__class__.__name__}: {message}')


class Accumulator(brainstate.nn.Module, LoggingMixin):
    """Simple module that reuses the logging helper."""

    def __init__(self):
        super().__init__()
        self.total = 0.0

    def add(self, value):
        self.total += float(value)
        self.log(f'updated running total to {self.total:.2f}')
        return self.total


acc = Accumulator()
_ = acc.add(1.25)
_ = acc.add(2.75)


[LOG 23:58:20] Accumulator: updated running total to 1.25
[LOG 23:58:20] Accumulator: updated running total to 4.00


### Design tips

- A mixin should only provide behaviour; avoid introducing new required constructor arguments.
- Keep mixins focused. Several small mixins compose better than a single, opinionated base class.
- Document expectations about host classes (e.g. attributes a mixin reads or writes).
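
The last tip can be made concrete by having a mixin state, and check, what it needs from its host. The sketch below uses plain Python classes so that it runs without BrainState installed; in a BrainState project the mixin would inherit from `brainstate.mixin.Mixin` instead. `ClampMixin`, `Gain`, and the `limit` attribute are hypothetical names chosen for illustration.

```python
class ClampMixin:
    """Clamp values to the range [-limit, limit].

    Host requirement: the host class must define a numeric ``limit`` attribute.
    """

    def clamp(self, value: float) -> float:
        # Fail early with a clear message if the host breaks the contract.
        if not hasattr(self, 'limit'):
            raise AttributeError(
                f'{type(self).__name__} must define `limit` to use ClampMixin'
            )
        return max(-self.limit, min(self.limit, value))


class Gain(ClampMixin):
    """Hypothetical host: scales an input, then clamps the result."""

    def __init__(self, scale: float, limit: float):
        self.scale = scale
        self.limit = limit  # the attribute that ClampMixin reads

    def __call__(self, value: float) -> float:
        return self.clamp(self.scale * value)


gain = Gain(scale=2.0, limit=3.0)
print(gain(1.0))  # 2.0, inside the limit
print(gain(5.0))  # 3.0, clamped to the limit
```

The mixin adds no constructor arguments of its own; the host owns `limit`, and the docstring plus the early check make that dependency visible.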


## Parameter descriptors with `ParamDesc`

`ParamDesc` lets a class record constructor presets for later reuse.
The `desc()` class method stores the provided arguments in a `ParamDescriber`. Calling that describer later
instantiates a new object, and keyword arguments passed at call time override the stored ones.


In [3]:
class DenseBlock(mixin.ParamDesc):
    """Toy layer that records its configuration for inspection."""

    def __init__(self, in_features: int, out_features: int, *, activation: str = 'relu'):
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

    def summary(self) -> str:
        return f'{self.activation} dense {self.in_features} → {self.out_features}'


encoder_block = DenseBlock.desc(256, 128, activation='gelu')
decoder_block = DenseBlock.desc(128, 64, activation='relu')

print(encoder_block().summary())
print(encoder_block(activation='relu').summary())  # override at call time
print(decoder_block().summary())


gelu dense 256 → 128
relu dense 256 → 128
relu dense 128 → 64


Each describer exposes an `identifier` built from the target class and the stored positional and keyword arguments.
The identifier is hashable, which suits caching systems: `descriptor.identifier` is safe to use as a dictionary key.


In [4]:
print(encoder_block.identifier)


(<class '__main__.DenseBlock'>, (256, 128), {'activation': 'gelu'})


### Using `ParamDescriber` directly

If you want to describe classes that do not inherit from `ParamDesc`, you can work with
`ParamDescriber` manually.


In [5]:
@dataclass
class OptimConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


adam_template = mixin.ParamDescriber(OptimConfig, lr=1e-3, beta1=0.95)
opt_a = adam_template()
opt_b = adam_template(lr=5e-4)  # override a keyword

print(opt_a)
print(opt_b)


OptimConfig(lr=0.001, beta1=0.95, beta2=0.999)
OptimConfig(lr=0.0005, beta1=0.95, beta2=0.999)
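
For intuition, the standard library's `functools.partial` follows a similar bind-now, override-later pattern. The comparison below is only an analogy and says nothing about how `ParamDescriber` is implemented; it repeats the `OptimConfig` dataclass so that it runs on its own.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class OptimConfig:
    lr: float
    beta1: float = 0.9
    beta2: float = 0.999


# Bind the preset once...
adam_partial = partial(OptimConfig, lr=1e-3, beta1=0.95)

# ...then instantiate with or without call-time keyword overrides.
opt_a = adam_partial()
opt_b = adam_partial(lr=5e-4)

print(opt_a)  # OptimConfig(lr=0.001, beta1=0.95, beta2=0.999)
print(opt_b)  # OptimConfig(lr=0.0005, beta1=0.95, beta2=0.999)
```

What the describer adds on top of this pattern, as shown earlier in this tutorial, is the `identifier` attribute for use as a cache key.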


## Type combinators: `JointTypes` and `OneOfTypes`

BrainState ships two helpers that make intent explicit when a value must satisfy multiple interfaces
or just one of several options:

- `JointTypes[A, B, ...]` behaves like an intersection — an instance must satisfy *all* listed types.
- `OneOfTypes[A, B, ...]` behaves like a union — an instance may satisfy *any* listed type.


In [6]:
class Persistable:
    def save(self):
        raise NotImplementedError


class Visualisable:
    def plot(self):
        raise NotImplementedError


class Report(Persistable, Visualisable):
    def save(self):
        return 'saved to disk'

    def plot(self):
        return 'rendering preview'


FullFeatureType = mixin.JointTypes[Persistable, Visualisable]
OptionalNumber = mixin.OneOfTypes[int, float, type(None)]

report = Report()
print(isinstance(report, FullFeatureType))
print(isinstance(3.14, OptionalNumber), isinstance(None, OptionalNumber))


True
True True
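
The cell above only exercises the passing cases. To show the intended semantics, including failures, here is a plain-Python sketch of the same two checks. `satisfies_all`, `satisfies_any`, and `Snapshot` are hypothetical helpers written for this illustration and are not part of BrainState.

```python
def satisfies_all(value, *types) -> bool:
    """Intersection-style check: value must be an instance of every type."""
    return all(isinstance(value, t) for t in types)


def satisfies_any(value, *types) -> bool:
    """Union-style check: value must be an instance of at least one type."""
    return isinstance(value, types)


class Persistable: ...
class Visualisable: ...
class Report(Persistable, Visualisable): ...
class Snapshot(Persistable): ...


print(satisfies_all(Report(), Persistable, Visualisable))    # True
print(satisfies_all(Snapshot(), Persistable, Visualisable))  # False: not Visualisable
print(satisfies_any(None, int, float, type(None)))           # True
print(satisfies_any('3.14', int, float, type(None)))         # False: a str
```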


## Mode mixins for runtime behaviour

Mode objects capture the *context* in which a computation runs, and code queries that context with `mode.has(SomeMode)`.
`Mode` is the lightweight base class; the built-ins `Training` and `Batching` cover common runtime switches,
and `JointMode` combines several modes into one.


In [7]:
class ToyPipeline:
    """A tiny pipeline that responds to different mode configurations."""

    def __init__(self):
        self.mode: mixin.Mode = mixin.Mode()

    def set_mode(self, *modes: mixin.Mode):
        if not modes:
            self.mode = mixin.Mode()
        elif len(modes) == 1:
            self.mode = modes[0]
        else:
            self.mode = mixin.JointMode(*modes)

    def forward(self, values):
        x = jnp.asarray(values, dtype=jnp.float32)
        if self.mode.has(mixin.Training):
            x = x + 0.1  # constant offset standing in for training-only noise or dropout
        if self.mode.has(mixin.Batching):
            batch = self.mode.batch_size
            x = x.reshape((batch, -1)).mean(axis=1)
        return x


pipeline = ToyPipeline()
print('default', pipeline.forward(jnp.arange(4.0)))

pipeline.set_mode(mixin.Training())
print('training', pipeline.forward(jnp.arange(4.0)))

pipeline.set_mode(mixin.Training(), mixin.Batching(batch_size=2))
print('joint', pipeline.forward(jnp.arange(4.0)))
print('joint exposes batch size:', pipeline.mode.batch_size)


default [0. 1. 2. 3.]
training [0.1 1.1 2.1 3.1]
joint [0.6 2.6]
joint exposes batch size: 2


The joint mode exposes the attributes of its members, so accessing `pipeline.mode.batch_size` works even
though the current mode is a `JointMode` instance.
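
One common way to get this kind of attribute forwarding in Python is the `__getattr__` hook. The sketch below illustrates the idea with stand-in classes; every `Sketch*` name is hypothetical, and nothing here is a claim about how BrainState's `JointMode` is actually written.

```python
class SketchMode:
    """Minimal stand-in for a mode object (not BrainState's implementation)."""

    def has(self, mode_type) -> bool:
        return isinstance(self, mode_type)


class SketchTraining(SketchMode):
    pass


class SketchBatching(SketchMode):
    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size


class SketchJointMode(SketchMode):
    """Combine several modes and forward attribute lookups to them."""

    def __init__(self, *modes: SketchMode):
        self.modes = tuple(modes)

    def has(self, mode_type) -> bool:
        # A joint mode "has" a type if any member is an instance of it.
        return any(isinstance(m, mode_type) for m in self.modes)

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, so attributes
        # stored on the joint mode itself (such as `modes`) are unaffected.
        for m in self.__dict__.get('modes', ()):
            if hasattr(m, name):
                return getattr(m, name)
        raise AttributeError(name)


joint = SketchJointMode(SketchTraining(), SketchBatching(batch_size=2))
print(joint.has(SketchTraining), joint.has(SketchBatching))  # True True
print(joint.batch_size)  # 2, found on the SketchBatching member
```

In this sketch the first member that defines an attribute wins, so two members exposing the same attribute name would need an explicit precedence rule.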


## Putting it together

When you combine these mixin tools you can:

1. Add reusable behaviour (logging, validation, metrics) without disturbing core module hierarchies.
2. Parameterise component templates and reuse them safely through descriptors.
3. Encode clear expectations about inputs or collaborators via `JointTypes`/`OneOfTypes`.
4. Toggle runtime semantics with mode objects instead of ad-hoc boolean flags.


### Next steps

- Audit your own modules for behaviours that could live in a mixin.
- Wrap frequently reused constructor arguments with `ParamDesc`.
- Adopt mode objects in your training scripts to centralise feature flags (e.g. evaluation vs training).
- Explore `brainstate.mixin.not_implemented` to clearly mark unsupported operations.
