# Base Class

In this section, we are going to talk about:

- the ``Base`` class for the BrainPy ecosystem,
- ``Collector`` to facilitate variable collection and manipulation.

In [1]:
import brainpy as bp
import brainpy.math as bm

bp.math.set_platform('cpu')

## Base

The foundation of BrainPy is [brainpy.Base](../apis/auto/generated/brainpy.base.Base.rst). A Base instance is an object which has variables and methods. All methods in the Base object can be [JIT compiled](./compilation.ipynb) or [automatically differentiated](./differentiation.ipynb). In other words, any **class object** that is to be JIT compiled or automatically differentiated must inherit from ``brainpy.Base``. 

A Base object can have many variables, child Base objects, integrators, and methods. For example, let's implement a [FitzHugh-Nagumo neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.FHN.html). 

In [2]:
class FHN(bp.Base):
  def __init__(self, num, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
    super(FHN, self).__init__(name=name)

    # parameters
    self.num = num
    self.a = a
    self.b = b
    self.tau = tau
    self.Vth = Vth

    # variables
    self.V = bm.Variable(bm.zeros(num))
    self.w = bm.Variable(bm.zeros(num))
    self.spike = bm.Variable(bm.zeros(num, dtype=bool))

    # integral
    self.integral = bp.odeint(method='rk4', f=self.derivative)

  def derivative(self, V, w, t, Iext):
    dw = (V + self.a - self.b * w) / self.tau
    dV = V - V * V * V / 3 - w + Iext
    return dV, dw

  def update(self, _t, _dt, x):
    V, w = self.integral(self.V, self.w, _t, x)
    self.spike[:] = bm.logical_and(V > self.Vth, self.V <= self.Vth)
    self.w[:] = w
    self.V[:] = V

Note this model has three variables: ``self.V``, ``self.w``, and ``self.spike``. It also has an integrator ``self.integral``. 
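``bp.odeint(method='rk4', f=self.derivative)`` turns ``derivative`` into a function that advances ``V`` and ``w`` by one time step. As a rough illustration of what a classical fourth-order Runge-Kutta step computes for these equations, here is a plain NumPy sketch. It is not the integrator BrainPy generates: the helper names ``fhn_derivative`` and ``rk4_step`` are hypothetical, ``Iext`` is assumed constant over the step, the time argument is dropped because the right-hand side does not depend on ``t``, and ``dt = 0.1`` is an arbitrary choice.

```python
import numpy as np

# Hypothetical helper: same right-hand side as FHN.derivative (t is unused there).
def fhn_derivative(V, w, Iext, a=0.7, b=0.8, tau=12.5):
    dV = V - V * V * V / 3 - w + Iext
    dw = (V + a - b * w) / tau
    return dV, dw

# Hypothetical helper: one classical RK4 step, holding Iext constant over dt.
def rk4_step(V, w, Iext, dt=0.1):
    k1V, k1w = fhn_derivative(V, w, Iext)
    k2V, k2w = fhn_derivative(V + 0.5 * dt * k1V, w + 0.5 * dt * k1w, Iext)
    k3V, k3w = fhn_derivative(V + 0.5 * dt * k2V, w + 0.5 * dt * k2w, Iext)
    k4V, k4w = fhn_derivative(V + dt * k3V, w + dt * k3w, Iext)
    V_new = V + dt / 6 * (k1V + 2 * k2V + 2 * k3V + k4V)
    w_new = w + dt / 6 * (k1w + 2 * k2w + 2 * k3w + k4w)
    return V_new, w_new

# Mirror FHN.update: integrate, then flag upward crossings of Vth as spikes.
Vth = 1.9
V, w = np.zeros(10), np.zeros(10)
V_new, w_new = rk4_step(V, w, Iext=1.0)
spike = np.logical_and(V_new > Vth, V <= Vth)
V, w = V_new, w_new
```

The spike rule compares the new membrane potential with the old one, exactly as ``update`` does with ``V`` and ``self.V``, so a spike is only reported on the step where ``Vth`` is crossed from below.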

### Naming system

Every Base object has a unique name. You can specify the name when you instantiate a Base class; reusing a name that is already taken raises an error. 

In [3]:
FHN(10, name='X').name

'X'

In [4]:
FHN(10, name='Y').name

'Y'

In [5]:
try:
    FHN(10, name='Y').name
except Exception as e:
    print(type(e).__name__, ':', e)

UniqueNameError : In BrainPy, each object should have a unique name. However, we detect that <__main__.FHN object at 0x00000224FA317BB0> has a used name "Y".


When you instantiate a Base class without specifying a name, BrainPy assigns one automatically. The generated name follows the rule ``class_name + number_of_instances``, for example ``FHN0``, ``FHN1``, etc.

In [6]:
FHN(10).name

'FHN0'

In [7]:
FHN(10).name

'FHN1'

Therefore, in BrainPy every object, however insignificant, can be accessed by its unique name.
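The naming rules above can be mimicked with a class-level registry of used names plus a per-class counter. The following is a minimal pure-Python sketch of that idea, not BrainPy's implementation; ``NamedObject``, ``FHNLike``, and this local ``UniqueNameError`` class are hypothetical stand-ins.

```python
class UniqueNameError(Exception):
    pass

class NamedObject:
    """Hypothetical stand-in for bp.Base's naming logic (not BrainPy's code)."""
    _used_names = set()   # every name handed out so far, across all classes
    _counters = {}        # per-class instance counters for auto-generated names

    def __init__(self, name=None):
        if name is None:
            cls_name = type(self).__name__
            index = NamedObject._counters.get(cls_name, 0)
            NamedObject._counters[cls_name] = index + 1
            name = f'{cls_name}{index}'   # class_name + number_of_instances
        if name in NamedObject._used_names:
            raise UniqueNameError(f'{self!r} has a used name "{name}".')
        NamedObject._used_names.add(name)
        self.name = name

class FHNLike(NamedObject):
    pass

x = FHNLike(name='X')
first, second = FHNLike(), FHNLike()
print(x.name, first.name, second.name)   # X FHNLike0 FHNLike1
try:
    FHNLike(name='X')
except UniqueNameError as e:
    print(type(e).__name__)               # UniqueNameError
```

In this sketch the counter only advances for unnamed instances, which matches the notebook output above, where ``FHN0`` is generated after ``X`` and ``Y`` were taken.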

### Collection functions

Three important collection functions are implemented for each Base object. Specifically, they are:

- ``nodes()``: collects all Base instances, i.e., the node itself and all of its child nodes.
- ``ints()``: collects all integrators defined in the Base node and in its child nodes. 
- ``vars()``: collects all variables defined in the Base node and in its child nodes. 

All integrators can be collected through one method ``Base.ints()``. The result container is a [Collector](../apis/auto/generated/brainpy.base.Collector.rst). 

In [8]:
fhn = FHN(10)

In [9]:
ints = fhn.ints()

ints

{'FHN2.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa309370>}

In [10]:
type(ints)

brainpy.base.collector.Collector

Similarly, all variables in a Base object can be collected through ``Base.vars()``. The returned container is a [TensorCollector](../apis/auto/generated/brainpy.base.TensorCollector.rst) (a subclass of ``Collector``). 

In [11]:
vars = fhn.vars()

vars

{'FHN2.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN2.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'FHN2.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}

In [12]:
type(vars)

brainpy.base.collector.TensorCollector

All nodes in the model can also be collected through one method ``Base.nodes()``. The result container is an instance of [Collector](../apis/auto/generated/brainpy.base.Collector.rst). 

In [13]:
nodes = fhn.nodes()

nodes  # note: integrator is also a node

{'RK44': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa309370>,
 'FHN2': <__main__.FHN at 0x224fa317a60>}

In [14]:
type(nodes)

brainpy.base.collector.Collector
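Conceptually, these collection functions walk the attributes of a node and recurse into child nodes. The sketch below shows that traversal with absolute keys in plain Python. It is an assumption-laden simplification, not BrainPy's code: ``Node``, ``Variable``, ``Integrator``, and ``Neuron`` are hypothetical stand-ins, plain dicts replace ``Collector``, and there is no protection against reference cycles.

```python
class Variable:
    """Hypothetical stand-in for bm.Variable: just wraps a value."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical stand-in for bp.Base (a sketch, not BrainPy's code)."""
    _counters = {}

    def __init__(self, name=None):
        if name is None:
            cls_name = type(self).__name__
            index = Node._counters.get(cls_name, 0)
            Node._counters[cls_name] = index + 1
            name = f'{cls_name}{index}'
        self.name = name

    def nodes(self):
        # Absolute path: each key is simply the node's unique name.
        found = {self.name: self}
        for attr in self.__dict__.values():
            if isinstance(attr, Node):
                found.update(attr.nodes())   # recurse into child nodes
        return found

    def vars(self):
        # Absolute path: each key is "<node name>.<attribute name>".
        found = {}
        for node in self.nodes().values():
            for attr_name, attr in node.__dict__.items():
                if isinstance(attr, Variable):
                    found[f'{node.name}.{attr_name}'] = attr
        return found

class Integrator(Node):
    pass

class Neuron(Node):
    def __init__(self, name=None):
        super().__init__(name)
        self.V = Variable([0.0, 0.0])
        self.w = Variable([0.0, 0.0])
        self.integral = Integrator()   # an integrator is itself a node

n = Neuron()
print(sorted(n.nodes()))   # ['Integrator0', 'Neuron0']
print(sorted(n.vars()))    # ['Neuron0.V', 'Neuron0.w']
```

An ``ints()`` analogue could be written the same way by filtering the collected nodes for ``Integrator`` instances.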

Now, let's make a more complicated model by using the previously defined model ``FHN``. 

In [15]:
class FeedForwardCircuit(bp.Base):
    def __init__(self, num1, num2, w=0.1, a=0.7, b=0.8, tau=12.5, Vth=1.9, name=None):
        super(FeedForwardCircuit, self).__init__(name=name)
        
        self.pre = FHN(num1, a=a, b=b, tau=tau, Vth=Vth)
        self.post = FHN(num2, a=a, b=b, tau=tau, Vth=Vth)
        
        conn = bm.ones((num1, num2), dtype=bool)
        self.conn = bm.fill_diagonal(conn, False) * w

    def update(self, _t, _dt, x):
        self.pre.update(_t, _dt, x)
        x2 = self.pre.spike @ self.conn
        self.post.update(_t, _dt, x2)

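This model ``FeedForwardCircuit`` defines two layers, each modeled as a FitzHugh-Nagumo population (``FHN``). The first layer projects to the second through the weight matrix ``self.conn``, which connects every pair of neurons with strength ``w`` except the pairs on the diagonal (``i == j``). The input to the second layer is therefore the first layer's spike vector multiplied by this weight matrix. 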
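To make the input computation concrete, here is a small NumPy sketch of the same connection matrix and one input step. It assumes ``bm.fill_diagonal`` behaves like NumPy's ``np.fill_diagonal`` on a non-square matrix (only entries ``(i, i)`` with ``i < min(num1, num2)`` are set); the choice of which presynaptic neurons fire is made up for illustration.

```python
import numpy as np

num1, num2, w = 8, 5, 0.1
conn = np.ones((num1, num2), dtype=bool)
np.fill_diagonal(conn, False)   # NumPy fills in place and returns None
conn = conn * w                 # bool * float -> float weights (0.0 on the diagonal)

pre_spike = np.zeros(num1, dtype=bool)
pre_spike[[0, 3]] = True        # illustrative: pre neurons 0 and 3 fire

# Input to each post neuron: sum of weights from the pre neurons that spiked.
x2 = pre_spike.astype(float) @ conn
print(x2)   # [0.1 0.2 0.2 0.1 0.2]
```

Post neurons 0 and 3 receive only 0.1 because their diagonal partners (pre neurons 0 and 3) are the ones that fired.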

In [16]:
net = FeedForwardCircuit(8, 5)

We can retrieve all integrators in the network with ``.ints()``:

In [17]:
net.ints()

{'FHN3.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'FHN4.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

Or, retrieve all variables by ``.vars()``:

In [18]:
net.vars()

{'FHN3.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN3.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'FHN3.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'FHN4.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

Or, retrieve all nodes (instances of Base class) with ``.nodes()``:

In [19]:
net.nodes()

{'FHN3': <__main__.FHN at 0x224fb8780a0>,
 'FHN4': <__main__.FHN at 0x224fa3173a0>,
 'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>,
 'FeedForwardCircuit0': <__main__.FeedForwardCircuit at 0x224fb878070>}

If we only care about one subtype of class, we can filter the collected nodes with ``subset()``:

In [20]:
net.nodes().subset(bp.ode.ODEIntegrator)

{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

#### Absolute path

It is worth noting that there are two ways to access variables, integrators, and nodes: by "absolute" path and by "relative" path. The default is the absolute path. 

"Absolute" path means that all keys in the resulting Collector (e.g., of ``Base.nodes()``) have the format ``key = node_name [+ field_name]``. 

**.nodes() example 1**: In the above ``fhn`` instance, there are two nodes: "fhn" and its integrator "fhn.integral".

In [21]:
fhn.integral.name, fhn.name

('RK44', 'FHN2')

Calling ``.nodes()`` returns their names and instances. 

In [22]:
fhn.nodes().keys()

dict_keys(['RK44', 'FHN2'])

**.nodes() example 2**: In the above ``net`` instance, there are five nodes:

In [23]:
net.pre.name, net.post.name, net.pre.integral.name, net.post.integral.name, net.name

('FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0')

Calling ``.nodes()`` also returns the names and instances of all models. 

In [24]:
net.nodes().keys()

dict_keys(['FHN3', 'FHN4', 'RK45', 'RK46', 'FeedForwardCircuit0'])

**.vars() example 1**: In the above ``fhn`` instance, there are three variables: "V", "w" and "spike". Calling ``.vars()`` returns a dict of `<node_name + var_name, var_value>`. 

In [25]:
fhn.vars()

{'FHN2.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN2.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'FHN2.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}

**.vars() example 2**: This also applies in the ``net`` instance:

In [26]:
net.vars()

{'FHN3.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN3.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'FHN3.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'FHN4.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'FHN4.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

#### Relative path

Variables, integrators, and nodes can also be accessed by relative path. For example, the ``pre`` instance in ``net`` can be accessed as:

In [27]:
net.pre

<__main__.FHN at 0x224fb8780a0>

Relative path preserves the dependence relationship. For example, all nodes retrieved from the perspective of ``net`` are:

In [28]:
net.nodes(method='relative')

{'': <__main__.FeedForwardCircuit at 0x224fb878070>,
 'pre': <__main__.FHN at 0x224fb8780a0>,
 'post': <__main__.FHN at 0x224fa3173a0>,
 'pre.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'post.integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

In contrast, the nodes retrieved starting from ``net.pre`` are:

In [29]:
net.pre.nodes('relative')

{'': <__main__.FHN at 0x224fb8780a0>,
 'integral': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>}

Variables can also be retrieved relative to a model. For example, all variables accessible relative to ``net`` are:

In [30]:
net.vars('relative')

{'pre.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'pre.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False], dtype=bool)),
 'pre.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'post.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'post.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'post.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

And the variables accessible relative to ``net.post`` are:

In [31]:
net.post.vars('relative')

{'V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}
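Relative keys are simply the attribute path from the starting node, with ``''`` standing for the starting node itself. The sketch below reproduces that key scheme in plain Python. ``Node``, ``Variable``, ``nodes_relative()``, and ``vars_relative()`` are hypothetical names chosen for illustration (BrainPy exposes this through ``nodes('relative')`` and ``vars('relative')``), and the traversal order here is depth-first, so it may differ from BrainPy's output order.

```python
class Variable:
    """Hypothetical stand-in for bm.Variable."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical stand-in for bp.Base (a sketch, not BrainPy's code)."""
    def nodes_relative(self, prefix=''):
        # Key = attribute path from the starting node ('' for the node itself).
        found = {prefix: self}
        for attr_name, attr in self.__dict__.items():
            if isinstance(attr, Node):
                path = f'{prefix}.{attr_name}' if prefix else attr_name
                found.update(attr.nodes_relative(path))
        return found

    def vars_relative(self):
        found = {}
        for path, node in self.nodes_relative().items():
            for attr_name, attr in node.__dict__.items():
                if isinstance(attr, Variable):
                    key = f'{path}.{attr_name}' if path else attr_name
                    found[key] = attr
        return found

class Integrator(Node):
    pass

class Neuron(Node):
    def __init__(self):
        self.V = Variable(0.0)
        self.integral = Integrator()

class Circuit(Node):
    def __init__(self):
        self.pre = Neuron()
        self.post = Neuron()

net = Circuit()
print(sorted(net.nodes_relative()))       # ['', 'post', 'post.integral', 'pre', 'pre.integral']
print(sorted(net.pre.nodes_relative()))   # ['', 'integral']
print(sorted(net.post.vars_relative()))   # ['V']
```

Because the keys depend only on attribute names, the same object gets different keys depending on where the traversal starts, which is the dependence relationship described above.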

#### Elements in containers

One drawback of collection functions is that they don't look for elements in *list*, *dict* or any container structure. 

In [32]:
class ATest(bp.Base):
    def __init__(self):
        super(ATest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10), 'b': FHN(5)}

In [33]:
t1 = ATest()

The above class defines a list of variables and a dict of child nodes. However, they cannot be retrieved by the collection functions ``vars()`` and ``nodes()``. 

In [34]:
t1.vars()

{}

In [35]:
t1.nodes()

{'ATest0': <__main__.ATest at 0x224fa309430>}

To solve this problem, BrainPy provides ``implicit_vars`` and ``implicit_nodes`` (both instances of ``dict``) to hold variables and nodes stored in container structures; they are filled with ``register_implicit_vars()`` and ``register_implicit_nodes()``. Any variable registered in ``implicit_vars``, and any integrator or node registered in ``implicit_nodes``, can be retrieved by the collection functions. Let's try it.

In [36]:
class AnotherTest(bp.Base):
    def __init__(self):
        super(AnotherTest, self).__init__()
        
        self.all_vars = [bm.Variable(bm.zeros(5)), bm.Variable(bm.ones(6)),]
        self.sub_nodes = {'a': FHN(10, name='T1'), 'b': FHN(5, name='T2')}
        
        # both registration methods expect a dict
        self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})
        self.register_implicit_nodes({k: v for k, v in self.sub_nodes.items()})

In [37]:
t2 = AnotherTest()

In [38]:
# This model has two "FHN" instances, each "FHN" instance has one integrator. 
# Therefore, there are five Base objects. 

t2.nodes()

{'T1': <__main__.FHN at 0x224fb8a51c0>,
 'T2': <__main__.FHN at 0x224fb8a5c70>,
 'RK49': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb9c7190>,
 'RK410': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb9c7400>,
 'AnotherTest0': <__main__.AnotherTest at 0x224fb8a5250>}

In [39]:
# This model has five Base objects (seen above), 
# each FHN node has three variables, 
# moreover, this model has two implicit variables.

t2.vars()

{'T1.V': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'T1.spike': Variable(DeviceArray([False, False, False, False, False, False, False, False,
                       False, False], dtype=bool)),
 'T1.w': Variable(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'T2.V': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'T2.spike': Variable(DeviceArray([False, False, False, False, False], dtype=bool)),
 'T2.w': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'AnotherTest0.v0': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'AnotherTest0.v1': Variable(DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32))}
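Why are list and dict attributes missed, and why does registration help? A collector that only checks whether each attribute *is* a Variable never looks inside containers; an explicit registry gives it a second place to look. The pure-Python sketch below shows this for variables only; ``Node``, ``Variable``, and ``Holder`` are hypothetical stand-ins, keys are relative (no node-name prefix) for brevity, and the same idea would apply to ``implicit_nodes``.

```python
class Variable:
    """Hypothetical stand-in for bm.Variable."""
    def __init__(self, value):
        self.value = value

class Node:
    """Hypothetical stand-in for bp.Base (a sketch, not BrainPy's code)."""
    def __init__(self):
        self.implicit_vars = {}

    def register_implicit_vars(self, variables):
        if not isinstance(variables, dict):
            raise TypeError('must be a dict')
        self.implicit_vars.update(variables)

    def vars(self):
        found = {}
        # 1) Only attributes that ARE Variables are picked up; a list or dict
        #    attribute holding Variables is not looked into.
        for attr_name, attr in self.__dict__.items():
            if isinstance(attr, Variable):
                found[attr_name] = attr
        # 2) Explicitly registered variables are added under their given keys.
        found.update(self.implicit_vars)
        return found

class Holder(Node):
    def __init__(self, register):
        super().__init__()
        self.all_vars = [Variable(0.0), Variable(1.0)]
        if register:
            self.register_implicit_vars({f'v{i}': v for i, v in enumerate(self.all_vars)})

print(Holder(register=False).vars())          # {}
print(sorted(Holder(register=True).vars()))   # ['v0', 'v1']
```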

### Saving and loading

Because ``Base.vars()`` returns a Python dictionary (a [Collector](#Collector)), the variables can easily be saved, updated, altered, and restored, which adds a great deal of modularity to BrainPy models. Therefore, each Base object has standard exporting and loading methods, ``Base.save_states()`` and ``Base.load_states()`` (for more details, see [Saving and Loading](../tutorial_simulation/save_and_load.ipynb)). 

#### Save

```python
Base.save_states(PATH, [vars])
```

Model exporting in BrainPy supports several common file formats, including 

- HDF5: ``.h5``, ``.hdf5``
- ``.npz`` (NumPy file format)
- ``.pkl`` (Python's `pickle` utility)
- ``.mat`` (Matlab file format)

In [40]:
net.save_states('./data/net.h5')

In [41]:
net.save_states('./data/net.pkl')

In [42]:
# An unknown file format causes an error

try:
    net.save_states('./data/net.xxx')
except Exception as e:
    print(type(e).__name__, ":", e)

BrainPyError : Unknown file format: ./data/net.xxx. We only supports ['.h5', '.hdf5', '.npz', '.pkl', '.mat']


#### Load

```python
Base.load_states(PATH)
```

In [43]:
net.load_states('./data/net.h5')

In [44]:
net.load_states('./data/net.pkl')
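Since the states are just a mapping from names to arrays, a round trip through the ``.npz`` format can be sketched with NumPy alone. This only illustrates the idea; it is not the code behind ``save_states()``/``load_states()``, and the on-disk layout BrainPy uses may differ. The ``states`` dict below is hypothetical data shaped like the output of ``net.vars()``.

```python
import os
import tempfile

import numpy as np

# Hypothetical state dict, shaped like the output of net.vars() (absolute keys).
states = {
    'FHN3.V': np.zeros(8),
    'FHN3.w': np.zeros(8),
    'FHN3.spike': np.zeros(8, dtype=bool),
}

path = os.path.join(tempfile.mkdtemp(), 'net.npz')
np.savez(path, **states)            # one array per key

with np.load(path) as data:         # NpzFile supports the context-manager protocol
    restored = {key: data[key] for key in data.files}

for key, value in restored.items():
    states[key][...] = value        # write back in place, like loading into Variables
```

Writing back with ``[...] =`` updates the existing arrays rather than rebinding the dict entries, which is closer to how loaded values would be assigned into existing variables.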

## Collector

Collection functions return a ``brainpy.Collector``. This class is a dictionary that maps names to elements, and it has some useful methods. 

### ``subset()``

``Collector.subset(cls)`` returns the elements that are instances of the given ``cls``. For example, ``Base.nodes()`` returns all instances of the Base class; if you are only interested in one type, like ``ODEIntegrator``, you can use:

In [45]:
net.nodes().subset(bp.ode.ODEIntegrator)

{'RK45': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fa358790>,
 'RK46': <brainpy.integrators.ode.explicit_rk.RK4 at 0x224fb878d90>}

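Internally, ``Collector.subset(cls)`` iterates over all elements in the collection and keeps those whose type matches the given ``cls``. As the output above shows, subclasses match too: the ``RK4`` integrators are kept when filtering by ``ODEIntegrator``. 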
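A minimal sketch of such a filter is a dict subclass whose ``subset()`` keeps the entries passing an ``isinstance`` check. This ``Collector`` and the classes ``ODEIntegrator``, ``RK4``, and ``Neuron`` are hypothetical stand-ins, not BrainPy's classes.

```python
class Collector(dict):
    """Hypothetical minimal Collector (a sketch, not BrainPy's implementation)."""
    def subset(self, cls):
        # isinstance() also matches subclasses, e.g. RK4 for ODEIntegrator.
        return type(self)((key, value) for key, value in self.items()
                          if isinstance(value, cls))

class ODEIntegrator:
    pass

class RK4(ODEIntegrator):
    pass

class Neuron:
    pass

nodes = Collector({'RK45': RK4(), 'RK46': RK4(), 'FHN3': Neuron()})
print(sorted(nodes.subset(ODEIntegrator)))   # ['RK45', 'RK46']
```

Returning ``type(self)(...)`` keeps the result a ``Collector``, so filters can be chained.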

### ``unique()``

It is common in machine learning for weights to be shared among several objects, or for the same weight to be reachable through different dependency paths. Collection functions of Base therefore often return a collection in which the same value has multiple keys, and these duplicates are not excluded automatically. However, it is important not to apply an operation twice or more to the same element (e.g., applying gradients to update weights). 

Therefore, Collector provides the method ``Collector.unique()`` to handle this automatically. ``Collector.unique()`` returns a copy of the collection in which all elements are unique. 

In [46]:
class ModelA(bp.Base):
    def __init__(self):
        super(ModelA, self).__init__()
        self.a = bm.Variable(bm.zeros(5))

        
class SharedA(bp.Base):
    def __init__(self, source):
        super(SharedA, self).__init__()
        self.source = source
        self.a = source.a  # shared variable
        
        
class Group(bp.Base):
    def __init__(self):
        super(Group, self).__init__()
        self.A = ModelA()
        self.A_shared = SharedA(self.A)

g = Group()

In [47]:
g.vars('relative')  # the same Variable can be accessed by three paths

{'A.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'A_shared.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32)),
 'A_shared.source.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

In [48]:
g.vars('relative').unique()  # only return a unique path

{'A.a': Variable(DeviceArray([0., 0., 0., 0., 0.], dtype=float32))}

In [49]:
g.nodes('relative')  # "ModelA" is accessed twice

{'': <__main__.Group at 0x224fb9e8130>,
 'A': <__main__.ModelA at 0x224fb9e8040>,
 'A_shared': <__main__.SharedA at 0x224fb9e8280>,
 'A_shared.source': <__main__.ModelA at 0x224fb9e8040>}

In [50]:
g.nodes('relative').unique()

{'': <__main__.Group at 0x224fb9e8130>,
 'A': <__main__.ModelA at 0x224fb9e8040>,
 'A_shared': <__main__.SharedA at 0x224fb9e8280>}
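One way to implement such deduplication is to keep the first key seen for each distinct object, comparing by identity. The sketch below assumes identity-based comparison (BrainPy's exact criterion is not shown here); ``Collector`` is a hypothetical stand-in and plain lists play the role of Variables.

```python
class Collector(dict):
    """Hypothetical minimal Collector (a sketch, not BrainPy's implementation)."""
    def unique(self):
        # Compare by identity (id), not by ==: distinct arrays may hold equal
        # values, and == on arrays is elementwise rather than a single bool.
        seen = set()
        result = type(self)()
        for key, value in self.items():
            if id(value) not in seen:
                seen.add(id(value))
                result[key] = value     # the first key seen for an object wins
        return result

shared = [0.0] * 5          # plays the role of one shared Variable
other = [0.0] * 5           # equal values, but a different object
paths = Collector({'A.a': shared, 'A_shared.a': shared,
                   'A_shared.source.a': shared, 'B.b': other})
print(list(paths.unique()))   # ['A.a', 'B.b']
```

``other`` survives even though its contents equal ``shared``, because only repeated references to the same object are dropped.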

### ``update()``

Collector is a dict, but it catches potential conflicts during assignment. Both bracket assignment (``collector[key] = value``) and ``Collector.update()`` check whether an existing key would be mapped to a different value; if so, an error is raised. 

In [51]:
tc = bp.Collector({'a': bm.zeros(10)})

tc

{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}

In [52]:
try:
    tc['a'] = bm.zeros(1)  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)

ValueError : Name "a" conflicts: same name for [0.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].


In [53]:
try:
    tc.update({'a': bm.ones(1)})  # same key "a", different tensor
except Exception as e:
    print(type(e).__name__, ":", e)

ValueError : Name "a" conflicts: same name for [1.] and [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.].
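Such a guard can be sketched as a dict subclass that checks in ``__setitem__`` and routes ``update()`` through it (``dict.update`` does not call ``__setitem__`` on its own). Treating "a different value" as "a different object" is an assumption of this sketch, and ``Collector`` here is a hypothetical stand-in rather than BrainPy's class.

```python
class Collector(dict):
    """Hypothetical conflict-checking dict (a sketch, not BrainPy's implementation)."""
    def __setitem__(self, key, value):
        # Assumption: "different value" means a different object (identity check).
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
        super().__setitem__(key, value)

    def update(self, other=(), **kwargs):
        # dict.update() does not call __setitem__, so route every item through it.
        for key, value in dict(other, **kwargs).items():
            self[key] = value

zeros = [0.0] * 3
tc = Collector({'a': zeros})
tc['a'] = zeros                  # same object under the same key: accepted
try:
    tc.update({'a': [1.0]})      # same key, different object: rejected
except ValueError as e:
    print(type(e).__name__, ':', e)
```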


### ``replace()``

To replace the value stored under an existing key, use ``Collector.replace(key, new_value)``. 

In [54]:
tc

{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))}

In [55]:
tc.replace('a', bm.ones(3))

tc

{'a': JaxArray(DeviceArray([1., 1., 1.], dtype=float32))}
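A replace operation has to get past the conflict check. One plausible way, sketched below under that assumption, is to drop the old binding and then assign again; ``Collector`` is a hypothetical stand-in, and whether BrainPy raises for a missing key is not confirmed here (this sketch does, via ``dict.pop``).

```python
class Collector(dict):
    """Hypothetical minimal Collector (a sketch, not BrainPy's implementation)."""
    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name "{key}" conflicts.')
        super().__setitem__(key, value)

    def replace(self, key, new_value):
        # Drop the old binding first (KeyError if the key is absent),
        # so the conflict check in __setitem__ no longer applies.
        self.pop(key)
        self[key] = new_value

tc = Collector({'a': [0.0] * 10})
tc.replace('a', [1.0] * 3)
print(tc)   # {'a': [1.0, 1.0, 1.0]}
```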

### ``__add__()``

Two Collectors can be merged with the ``+`` operator. 

In [56]:
a = bp.Collector({'a': bm.zeros(10)})
b = bp.Collector({'b': bm.ones(10)})

a + b

{'a': JaxArray(DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)),
 'b': JaxArray(DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))}
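Merging can be sketched as copying the left operand and adding the right operand's items into the copy, so neither input is modified. Whether BrainPy's ``+`` applies the same conflict check as ``update()`` is an assumption of this sketch, and ``Collector`` is a hypothetical stand-in.

```python
class Collector(dict):
    """Hypothetical minimal Collector (a sketch, not BrainPy's implementation)."""
    def __add__(self, other):
        # Build a new collector; neither operand is modified.
        merged = type(self)(self)
        for key, value in other.items():
            # Assumption: a key bound to two different objects is a conflict.
            if key in merged and merged[key] is not value:
                raise ValueError(f'Name "{key}" conflicts.')
            merged[key] = value
        return merged

a = Collector({'a': [0.0] * 3})
b = Collector({'b': [1.0] * 3})
c = a + b
print(sorted(c))   # ['a', 'b']
```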

## TensorCollector

``TensorCollector`` is a subclass of ``Collector`` that is specialized for collecting tensors. 