# Compilation Graph Visualization

`brainscale` uses intermediate representation (IR) analysis to extract the dependencies between neuron states, synaptic connections, and model parameters. After a model has been compiled with `.compile_graph()`, calling the `.show_graph()` method displays the compiled computation graph — the groups of hidden states and the weight parameters associated with each group — which gives deeper insight into the computational structure and interdependencies within the neural model.

In [1]:
import brainstate
import brainunit as u
import brainscale
import jax


## Single-Layer Network

This example defines a simple single-layer LIF (Leaky Integrate-and-Fire) network with recurrent connections and an output layer. The network consists of LIF neurons, a Delta projection layer, and a LeakyRateReadout component.


In [2]:
class LIF_Delta_Net(brainstate.nn.Module):
    """Recurrent LIF population fed through a delta projection, followed by a leaky readout.

    The module owns three sub-modules:

    * ``neu`` -- a ``brainscale.nn.LIF`` population of ``n_rec`` neurons,
    * ``syn`` -- a ``brainstate.nn.DeltaProj`` whose ``comm`` is a
      ``brainscale.nn.Linear`` mapping ``n_in + n_rec`` inputs onto ``neu``,
    * ``out`` -- a ``brainscale.nn.LeakyRateReadout`` from ``n_rec`` to ``n_out``.

    Parameters
    ----------
    n_in, n_rec, n_out
        Number of input channels, recurrent neurons and readout units.
    tau_mem
        Passed to the LIF population as ``tau``.
    tau_o
        Passed to the readout as ``tau``.
    V_th
        Passed to the LIF population as ``V_th``.
    spk_fun
        Surrogate spike function handed to the LIF population.
    spk_reset
        Reset mode string handed to the LIF population.
    rec_scale, ff_scale
        Scales given to the Kaiming-normal initializers of the recurrent and
        feed-forward parts of the projection weight.
    """

    def __init__(
        self,
        n_in, n_rec, n_out,
        tau_mem=5. * u.ms,
        tau_o=5. * u.ms,
        V_th=1. * u.mV,
        spk_fun=brainstate.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        rec_scale: float = 1.,
        ff_scale: float = 1.,
    ):
        super().__init__()

        # Recurrent spiking population.
        self.neu = brainscale.nn.LIF(
            n_rec,
            tau=tau_mem,
            spk_fun=spk_fun,
            spk_reset=spk_reset,
            V_th=V_th,
        )

        # One weight matrix for both pathways: feed-forward rows first, then
        # recurrent rows, matching the concatenation order used in ``update``.
        recurrent_init = brainstate.init.KaimingNormal(rec_scale, unit=u.mV)
        feedforward_init = brainstate.init.KaimingNormal(ff_scale, unit=u.mV)
        feedforward_w = feedforward_init([n_in, n_rec])
        recurrent_w = recurrent_init([n_rec, n_rec])
        stacked_w = u.math.concatenate([feedforward_w, recurrent_w], axis=0)

        # Linear map over the concatenated (input, recurrent) spikes, delivered to ``neu``.
        projection = brainscale.nn.Linear(
            n_in + n_rec, n_rec,
            w_init=stacked_w,
            b_init=brainstate.init.ZeroInit(unit=u.mV),
        )
        self.syn = brainstate.nn.DeltaProj(comm=projection, post=self.neu)

        # Readout on top of the recurrent population.
        self.out = brainscale.nn.LeakyRateReadout(
            in_size=n_rec,
            out_size=n_out,
            tau=tau_o,
            w_init=brainstate.init.KaimingNormal(),
        )

    def update(self, spk):
        """Run one step: project the spikes onto ``neu``, update it, and return the readout.

        ``spk`` is concatenated along the last axis with the spikes currently
        reported by ``neu``; the projection is called only for its effect on
        ``neu`` (its return value is discarded).
        """
        own_spikes = self.neu.get_spike()
        projection_input = u.math.concatenate([spk, own_spikes], axis=-1)
        self.syn(projection_input)
        return self.out(self.neu())

In [3]:
# Build the single-layer network and display its compiled graph.
# Everything runs inside an environment context whose time step is dt = 0.1 ms.
with brainstate.environ.context(dt=0.1 * u.ms):
    net = LIF_Delta_Net(n_in=10, n_rec=20, n_out=5)
    # Initialize the states of every sub-module before compiling.
    brainstate.nn.init_all_states(net)
    # Wrap the network with brainscale's D_RTRL.
    model = brainscale.D_RTRL(net)
    # Compile with an example input of length 10 (matches n_in=10).
    model.compile_graph(brainstate.random.rand(10))
    # Print the hidden groups and the weights associated with them.
    model.show_graph()

The hidden groups are:

   Group 0: [('out', 'r')]
   Group 1: [('neu', 'V')]


The weight parameters which are associated with the hidden states are:

   Weight 0: ('syn', 'comm', 'weight_op')  is associated with hidden group 1
   Weight 1: ('out', 'weight_op')  is associated with hidden group 0





## Multi-Layer Network

This example demonstrates a basic multi-layer GIF (Generalized Integrate-and-Fire) network with stacked projections and synaptic connections.


In [4]:
class GIF(brainstate.nn.Neuron):
    """Generalized integrate-and-fire neuron with a single adaptation current ``I2``.

    The two state variables, ``V`` and ``I2``, are created in ``init_state`` as
    ``brainscale.ETraceState`` objects.

    Parameters
    ----------
    size
        Population size, forwarded to ``brainstate.nn.Neuron``.
    V_rest, V_th_inf, R, tau, tau_I2, A2
        Model parameters; each is expanded with ``brainstate.init.param``
        against ``self.varshape``.
    V_initializer, I2_initializer
        Initializers used by ``init_state`` for ``V`` and ``I2``.
    spike_fun
        Surrogate spike function, forwarded to the base class as ``spk_fun``.
    spk_reset
        Reset mode string forwarded to the base class.
    name
        Optional module name forwarded to the base class.
    """

    def __init__(
        self, size,
        V_rest=0. * u.mV,
        V_th_inf=1. * u.mV,
        R=1. * u.ohm,
        tau=20. * u.ms,
        tau_I2=50. * u.ms,
        A2=0. * u.mA,
        V_initializer=brainstate.init.ZeroInit(unit=u.mV),
        I2_initializer=brainstate.init.ZeroInit(unit=u.mA),
        spike_fun=brainstate.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(size, name=name, spk_fun=spike_fun, spk_reset=spk_reset)

        def as_param(value):
            # Every model parameter is mandatory and shaped like the population.
            return brainstate.init.param(value, self.varshape, allow_none=False)

        # Model parameters.
        self.V_rest = as_param(V_rest)
        self.V_th_inf = as_param(V_th_inf)
        self.R = as_param(R)
        self.tau = as_param(tau)
        self.tau_I2 = as_param(tau_I2)
        self.A2 = as_param(A2)

        # State initializers, consumed later by ``init_state``.
        self._V_initializer = V_initializer
        self._I2_initializer = I2_initializer

    def init_state(self):
        """Create the ``V`` and ``I2`` states as ``ETraceState`` objects."""
        # To use the model for online learning, the state variables have to be
        # initialized as ETraceState.
        initial_v = brainstate.init.param(self._V_initializer, self.varshape)
        self.V = brainscale.ETraceState(initial_v)
        initial_i2 = brainstate.init.param(self._I2_initializer, self.varshape)
        self.I2 = brainscale.ETraceState(initial_i2)

    def update(self, x=0.):
        """Advance ``I2`` and ``V`` by one step and return the standardized distance to threshold.

        ``x`` is added to the updated adaptation current to form the input of
        the membrane equation.  The return value is
        ``jax.nn.standardize`` applied to the magnitude of ``V - V_th_inf``.
        """
        # If a spike was emitted at the previous step, reset the membrane
        # potential and the adaptation current.  The spike is detached so that
        # no gradient flows through the reset.
        prev_spike = jax.lax.stop_gradient(self.get_spike())
        v_after_reset = self.V.value - self.V_th_inf * prev_spike
        i2_after_reset = self.I2.value - self.A2 * prev_spike

        def i2_derivative(i2):
            return - i2 / self.tau_I2

        def v_derivative(v, Iext):
            return (- v + self.V_rest + self.R * Iext) / self.tau

        # Update the states (adaptation current first; the membrane equation
        # uses the freshly updated current).
        next_i2 = brainstate.nn.exp_euler_step(i2_derivative, i2_after_reset)
        next_v = brainstate.nn.exp_euler_step(v_derivative, v_after_reset, x + next_i2)
        self.I2.value = next_i2
        self.V.value = next_v

        # Output.
        distance_to_threshold = self.V.value - self.V_th_inf
        return jax.nn.standardize(u.get_magnitude(distance_to_threshold))

    def get_spike(self, V=None):
        """Return the surrogate spike for ``V`` (defaults to the current membrane state)."""
        if V is None:
            V = self.V.value
        return self.spk_fun((V - self.V_th_inf) / self.V_th_inf)


class GifLayer(brainstate.nn.Module):
    """One recurrent layer: linear projection -> exponential synapse -> GIF neurons.

    Parameters
    ----------
    n_in
        Number of feed-forward inputs.
    n_rec
        Number of recurrent GIF neurons.
    ff_scale, rec_scale
        Scales given to the Kaiming-normal initializers of the feed-forward and
        recurrent parts of the projection weight.
    tau_neu
        Passed to the GIF population as ``tau``.
    tau_syn
        Passed to the exponential synapse as ``tau``.
    tau_I2
        Used (multiplied by 1.5) as the second bound of the uniform draw that
        produces the per-neuron ``tau_I2`` of the GIF population.
    A2
        Passed to the GIF population as ``A2``.
    """

    def __init__(
        self,
        n_in: int,
        n_rec: int,
        ff_scale: float = 1.,
        rec_scale: float = 1.,
        tau_neu: float = 5. * u.ms,
        tau_syn: float = 5. * u.ms,
        tau_I2: float = 5. * u.ms,
        A2=1. * u.mA,
    ):
        super().__init__()

        # Sizes.
        self.n_in = n_in
        self.n_rec = n_rec

        # Initialize the weights: feed-forward rows first, then recurrent rows,
        # matching the concatenation order used in ``update``.
        feedforward_init = brainstate.init.KaimingNormal(ff_scale, unit=u.mA)
        recurrent_init = brainstate.init.KaimingNormal(rec_scale, unit=u.mA)
        feedforward_w = feedforward_init((n_in, n_rec))
        recurrent_w = recurrent_init((n_rec, n_rec))
        stacked_w = u.math.concatenate([feedforward_w, recurrent_w], axis=0)

        # Model layers.
        self.ir2r = brainscale.nn.Linear(
            n_in + n_rec, n_rec,
            w_init=stacked_w,
            b_init=brainstate.init.ZeroInit(unit=u.mA),
        )
        self.exp = brainscale.nn.Expon(
            n_rec,
            tau=tau_syn,
            g_initializer=brainstate.init.ZeroInit(unit=u.mA),
        )

        # Per-neuron adaptation time constants.
        # NOTE(review): with the default tau_I2 (5 ms) the second bound,
        # tau_I2 * 1.5 = 7.5 ms, lies below the first bound of 100 ms -- confirm
        # the intended sampling range.
        adaptation_tau = brainstate.random.uniform(100. * u.ms, tau_I2 * 1.5, n_rec)
        self.r = GIF(
            n_rec,
            V_rest=0. * u.mV,
            V_th_inf=1. * u.mV,
            A2=A2,
            tau=tau_neu,
            tau_I2=adaptation_tau,
        )

    def update(self, spikes):
        """Project ``spikes`` plus the layer's own spikes, filter them, and step the neurons.

        Returns whatever ``self.r`` returns for the filtered projection output.
        """
        own_spikes = self.r.get_spike()
        projection_input = u.math.concatenate([spikes, own_spikes], axis=-1)
        projected = self.ir2r(projection_input)
        return self.r(self.exp(projected))


class GifNet(brainstate.nn.Module):
    """Stack of ``GifLayer`` blocks followed by a ``LeakyRateReadout``.

    Parameters
    ----------
    n_in
        Number of inputs to the first layer.
    n_rec
        Sizes of the recurrent layers, in order; each must be positive.
    n_out
        Number of readout units.
    tau_o
        Passed to the readout as ``tau``.
    """

    def __init__(
        self,
        n_in: int,
        n_rec: list,
        n_out: int,
        tau_o: float = 5. * u.ms,
    ):
        super().__init__()

        # Each layer takes the previous layer's width as its input size.
        fan_in = n_in
        self.layers = []
        for width in n_rec:
            assert width > 0, "n_rec should be a list of positive integers."
            self.layers.append(GifLayer(fan_in, width))
            fan_in = width

        # The readout consumes the output of the last layer (or the raw input
        # size when ``n_rec`` is empty).
        self.out = brainscale.nn.LeakyRateReadout(
            in_size=fan_in,
            out_size=n_out,
            tau=tau_o,
            w_init=brainstate.init.KaimingNormal(),
        )

    def update(self, x):
        """Feed ``x`` through every layer in order and return the readout output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.out(hidden)

In [5]:
# Build a three-layer GIF network (20 neurons per layer) and display its compiled graph.
# Everything runs inside an environment context whose time step is dt = 0.1 ms.
with brainstate.environ.context(dt=0.1 * u.ms):
    net2 = GifNet(n_in=10, n_rec=[20, 20, 20], n_out=5)
    # Initialize the states of every sub-module before compiling.
    brainstate.nn.init_all_states(net2)
    # Wrap the network with brainscale's D_RTRL.
    model = brainscale.D_RTRL(net2)
    # Compile with an example input of length 10 (matches n_in=10).
    model.compile_graph(brainstate.random.rand(10))
    # Print the hidden groups and the weights associated with them.
    model.show_graph()

The hidden groups are:

   Group 0: [('layers', 2, 'r', 'I2'), ('layers', 2, 'r', 'V'), ('layers', 2, 'exp', 'g')]
   Group 1: [('layers', 0, 'r', 'I2'), ('layers', 0, 'r', 'V'), ('layers', 0, 'exp', 'g')]
   Group 2: [('layers', 1, 'r', 'I2'), ('layers', 1, 'r', 'V'), ('layers', 1, 'exp', 'g')]
   Group 3: [('out', 'r')]


The weight parameters which are associated with the hidden states are:

   Weight 0: ('layers', 0, 'ir2r', 'weight_op')  is associated with hidden group 1
   Weight 1: ('layers', 1, 'ir2r', 'weight_op')  is associated with hidden group 2
   Weight 2: ('layers', 2, 'ir2r', 'weight_op')  is associated with hidden group 0
   Weight 3: ('out', 'weight_op')  is associated with hidden group 3






## Multi-Layer Convolutional Neural Network

This example demonstrates a multi-layer convolutional architecture built from `brainscale` components, showing how convolutional operations can be integrated with spiking neuron models in a hierarchical structure.

In [6]:
class ConvSNN(brainstate.nn.Module):
    """
    Convolutional spiking network example.

    The stages, in the order they are applied by ``update``:

    1. ``layer1``: Conv2d -> LayerNorm -> IF -> MaxPool2d
    2. ``layer2``: Conv2d -> LayerNorm -> IF
    3. ``layer3``: MaxPool2d -> Flatten
    4. ``layer4``: Linear -> IF
    5. ``layer5``: LeakyRateReadout

    Parameters
    ----------
    in_size
        Input size handed to the first convolution.
    out_sze
        Output size handed to the readout.
    tau_v
        Passed to every IF layer as ``tau``.
    tau_o
        Passed to the readout as ``tau``.
    v_th
        Passed to every IF layer as ``V_th``.
    n_channel
        Second positional argument of both convolutions; the linear layer has
        ``n_channel * 4 * 4`` outputs.
    ff_wscale
        Scale of the Xavier-normal (convolution) and Kaiming-normal (linear)
        weight initializers.
    """

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        out_sze: brainstate.typing.Size,
        tau_v: float = 2.0,
        tau_o: float = 10.,
        v_th: float = 1.0,
        n_channel: int = 32,
        ff_wscale: float = 40.0,
    ):
        super().__init__()

        # Keyword bundles shared by the layers below; none of the weighted
        # layers uses a bias.
        conv_kwargs = {
            'w_init': brainstate.init.XavierNormal(scale=ff_wscale),
            'b_init': None,
        }
        dense_kwargs = {
            'w_init': brainstate.init.KaimingNormal(scale=ff_wscale),
            'b_init': None,
        }
        neuron_kwargs = {
            'V_th': v_th,
            'tau': tau_v,
            'spk_fun': brainstate.surrogate.Arctan(),
            'V_initializer': brainstate.init.ZeroInit(),
            'R': 1.,
        }

        # Stage 1.  The ``.desc()`` entries are presumably instantiated by
        # Sequential with the preceding layer's output size -- TODO confirm.
        first_conv = brainscale.nn.Conv2d(in_size, n_channel, kernel_size=3, padding=1, **conv_kwargs)
        self.layer1 = brainstate.nn.Sequential(
            first_conv,
            brainscale.nn.LayerNorm.desc(),
            brainscale.nn.IF.desc(**neuron_kwargs),
            # 2x2 pooling with stride 2; the resulting spatial size depends on ``in_size``.
            brainstate.nn.MaxPool2d.desc(kernel_size=2, stride=2),
        )

        # Stage 2.
        second_conv = brainscale.nn.Conv2d(
            self.layer1.out_size, n_channel, kernel_size=3, padding=1, **conv_kwargs
        )
        self.layer2 = brainstate.nn.Sequential(
            second_conv,
            brainscale.nn.LayerNorm.desc(),
            brainscale.nn.IF.desc(**neuron_kwargs),
        )

        # Stage 3: second 2x2 / stride-2 pooling, then flatten.
        second_pool = brainstate.nn.MaxPool2d(kernel_size=2, stride=2, in_size=self.layer2.out_size)
        self.layer3 = brainstate.nn.Sequential(
            second_pool,
            brainstate.nn.Flatten.desc(),
        )

        # Stage 4: fully connected spiking layer.
        dense = brainscale.nn.Linear(self.layer3.out_size, n_channel * 4 * 4, **dense_kwargs)
        self.layer4 = brainstate.nn.Sequential(
            dense,
            brainscale.nn.IF.desc(**neuron_kwargs),
        )

        # Stage 5: readout.
        self.layer5 = brainscale.nn.LeakyRateReadout(self.layer4.out_size, out_sze, tau=tau_o)

    def update(self, x):
        """Pipe ``x`` through ``layer1`` ... ``layer5`` with the ``>>`` operator and return the result."""
        # x.shape = [B, H, W, C]
        return (
            x
            >> self.layer1
            >> self.layer2
            >> self.layer3
            >> self.layer4
            >> self.layer5
        )

In [7]:
# Build the convolutional SNN and display its compiled graph.
# The time step here is the plain number 0.1 (no unit attached).
with brainstate.environ.context(dt=0.1):
    # NOTE: rebinds ``net2`` (and ``model`` below), replacing the objects from the previous cell.
    net2 = ConvSNN((34, 34, 2), 10)
    # Initialize the states of every sub-module before compiling.
    brainstate.nn.init_all_states(net2)
    # Wrap the network with brainscale's D_RTRL.
    model = brainscale.D_RTRL(net2)
    # Compile with an example input of shape (34, 34, 2), matching in_size.
    model.compile_graph(brainstate.random.random((34, 34, 2)))
    # Print the hidden groups and the weights associated with them.
    model.show_graph()

The hidden groups are:

   Group 0: [('layer2', 'layers', 2, 'V')]
   Group 1: [('layer5', 'r')]
   Group 2: [('layer1', 'layers', 2, 'V')]
   Group 3: [('layer4', 'layers', 1, 'V')]


The weight parameters which are associated with the hidden states are:

   Weight 0: ('layer1', 'layers', 0, 'weight_op')  is associated with hidden group 2
   Weight 1: ('layer1', 'layers', 1, 'weight')  is associated with hidden group 2
   Weight 2: ('layer2', 'layers', 0, 'weight_op')  is associated with hidden group 0
   Weight 3: ('layer2', 'layers', 1, 'weight')  is associated with hidden group 0
   Weight 4: ('layer4', 'layers', 0, 'weight_op')  is associated with hidden group 3
   Weight 5: ('layer5', 'weight_op')  is associated with hidden group 1



