# Neuron Layers

> Turning Lagrangians into building blocks

Fundamentally, a neuron layer is nothing more than a lagrangian function applied to data. In addition to its lagrangian, a neuron layer has:

- A `shape`
- A time constant `tau`
- A `bias` (optional) that we can view as the activation threshold of a neuron layer
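
To make the relationship between these pieces concrete, here is a minimal sketch in plain JAX that does not depend on `hamux`. The names `relu_lagrangian`, `activation`, and `energy` are hypothetical and chosen only for illustration. The sketch assumes the activation is the gradient of the lagrangian and that the energy has the form `g · x - L(x)`, with the optional bias subtracted from the state first.

```python
import jax
import jax.numpy as jnp

def relu_lagrangian(x):
    # Hypothetical lagrangian: a scalar function of the layer state
    return 0.5 * jnp.sum(jnp.maximum(x, 0.0) ** 2)

def activation(lagrangian, x, bias=None):
    # The activation is the gradient of the lagrangian at the bias-shifted state
    x2 = x if bias is None else x - bias
    return jax.grad(lagrangian)(x2)

def energy(lagrangian, x, bias=None):
    # Layer energy: <g, x2> - L(x2), evaluated at the bias-shifted state
    x2 = x if bias is None else x - bias
    g = activation(lagrangian, x, bias)
    return jnp.sum(g * x2) - lagrangian(x2)

x = jnp.array([-1.0, 0.5, 2.0])
g = activation(relu_lagrangian, x)  # matches relu(x) for this lagrangian
E = energy(relu_lagrangian, x)      # matches 0.5 * sum(relu(x)**2) for this lagrangian
```

With a bias `b`, the same functions evaluate everything at `x - b`, which is why the bias can be read as an activation threshold.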

In [17]:
#| default_exp layers

In [18]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [19]:
#| export
import jax
import jax.numpy as jnp
from typing import *
import treex as tx
from abc import ABC, abstractmethod
from flax import linen as nn
from hamux.lagrangians import *
import functools as ft
from fastcore.meta import delegates
from fastcore.utils import *
from fastcore.basics import *

In [20]:
#| export
class Layer(tx.Module):
    """The energy building block of any activation in our network that we want to hold state over time"""
    lagrangian: tx.Module 
    shape: Tuple
    tau: float
    use_bias: bool
    bias: jnp.ndarray = tx.Parameter.node(default=None)

    def __init__(self, 
                 lagrangian:tx.Module, # Describes the non-linearity
                 shape:Tuple[int], # Number and shape of neuron assembly
                 tau:float=1.0, # Time constant
                 use_bias:bool=False, # Add bias?
                 **kwargs): # Arguments passed to initialize the lagrangian
        self.lagrangian = lagrangian(**kwargs)
        self.shape = shape
        assert tau > 0.0, "Tau must be positive and non-zero"
        self.tau = tau
        self.use_bias = use_bias
        
    def activation(self, x):
        """Alias for `self.g`"""
        return self.g(x)

    def energy(self, x):
        """The predefined energy of a layer, defined for any lagrangian"""
        if self.initializing():
            if self.use_bias:
                self.bias = nn.initializers.normal(0.02)(tx.next_key(), self.shape)
        x2 = x - self.bias if self.use_bias else x  # Shift by the bias; `self.g` applies the same shift internally

        # When jitted, this is no slower than the optimized `@` vector multiplication
        return jnp.multiply(self.g(x), x2).sum() - self.lagrangian(x2)

    def g(self, x):
        """The derivative of the lagrangian is our activation (or gain) function `g`.
        
        Defined to operate over input states `x` of shape `self.shape`
        """
        if self.initializing():
            if self.use_bias:
                self.bias = nn.initializers.normal(0.02)(tx.next_key(), self.shape)
        x2 = x - self.bias if self.use_bias else x
        return jax.grad(self.lagrangian)(x2)

    def init_state(self, 
                   bs: int = None, # Batch size
                   rng=None): # If given, initialize states from a normal distribution with this key
        """Initialize the states of this layer, with correct shape.
        
        If `bs` is provided, return a tensor of shape `(bs, *self.shape)`; otherwise return a tensor of shape `self.shape`.
        By default, initialize layer state to all 0.
        """
        layer_shape = self.shape if bs is None else (bs, *self.shape)
        if rng is not None:
            return jax.random.normal(rng, layer_shape)
        return jnp.zeros(layer_shape)
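
The state shapes can be illustrated without `treex` or `hamux`. The sketch below is a standalone approximation, not the library's code: the free function `init_state` copies the shape logic above, `lagrangian` is a hypothetical log-sum-exp stand-in, and using `jax.vmap` to apply the per-sample gradient to a batched state is an assumption about how batching could be handled.

```python
import jax
import jax.numpy as jnp

shape = (4,)  # hypothetical layer shape

def init_state(shape, bs=None, rng=None):
    # Same shape logic as `Layer.init_state`, written as a free function
    layer_shape = shape if bs is None else (bs, *shape)
    if rng is not None:
        return jax.random.normal(rng, layer_shape)
    return jnp.zeros(layer_shape)

def lagrangian(x):
    # Hypothetical stand-in: log-sum-exp, whose gradient is the softmax
    return jax.nn.logsumexp(x)

g = jax.grad(lagrangian)  # defined for a single state of shape `shape`

x_single = init_state(shape)                                  # shape (4,), all zeros
x_batch = init_state(shape, bs=8, rng=jax.random.PRNGKey(0))  # shape (8, 4), random normal

g_single = g(x_single)          # activation of one state
g_batch = jax.vmap(g)(x_batch)  # map the per-sample activation over the batch axis
```

Mapping over the batch axis keeps each sample's lagrangian separate, which matters for lagrangians such as log-sum-exp that couple all units within one state.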

It is convenient to package commonly used lagrangians as their own kind of layer, as follows.

In [33]:
#| export
#| hide
def MakeLayer(lagrangian_factory):
    """Hack to make it easy to create new layers from `Layer` utility class.
    
    `delegates` modifies the signature for all Layers. We want a different signature for each type of layer.

    So we redefine a local version of layer and delegate that for type inference.
    """
    global Layer

    @delegates(lagrangian_factory, keep=True)
    class Layer(Layer):
        __doc__ = Layer.__doc__
        
    out = partialler(Layer, lagrangian_factory)
    out.__doc__ = Layer.__doc__

    return out

In [34]:
#| export

# For some reason, docstrings are not showing the new kwargs, and the docs for these are broken.
IdentityLayer = MakeLayer(LIdentity)
RepuLayer = MakeLayer(LRepu)
ReluLayer = MakeLayer(LRelu)
SoftmaxLayer = MakeLayer(LSoftmax)
SigmoidLayer = MakeLayer(LSigmoid)
TanhLayer = MakeLayer(LTanh)
ExpLayer = MakeLayer(LExp)

In [35]:
show_doc(SigmoidLayer)

---

### SigmoidLayer

>      SigmoidLayer (shape:Tuple[int], tau:float=1.0, use_bias:bool=False,
>                    beta=1.0, scale=1.0, min_beta=1e-06, **kwargs)

The energy building block of any activation in our network that we want to hold state over time

The utility we use to create these "convenience layers" is a bit hacky: it works by injecting the lagrangian, along with the arguments that lagrangian expects, into our `Layer` utility:
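
The core of that injection can be sketched in plain Python with `functools.partial`. This is a simplified illustration, not the `MakeLayer` implementation: `QuadraticLagrangian`, `SimpleLayer`, and `make_layer` are hypothetical names, and the signature rewriting that `delegates` performs is left out.

```python
import functools

class QuadraticLagrangian:
    """Hypothetical lagrangian with a single configuration argument."""
    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, x):
        return 0.5 * self.scale * sum(v * v for v in x)

class SimpleLayer:
    """Hypothetical stand-in for `Layer`."""
    def __init__(self, lagrangian, shape, tau=1.0, **kwargs):
        # Leftover keyword arguments are forwarded to the lagrangian factory
        self.lagrangian = lagrangian(**kwargs)
        self.shape = shape
        self.tau = tau

def make_layer(lagrangian_factory):
    # Bind the lagrangian factory as the first positional argument of the layer
    out = functools.partial(SimpleLayer, lagrangian_factory)
    out.__doc__ = SimpleLayer.__doc__
    return out

QuadraticLayer = make_layer(QuadraticLagrangian)

# `shape` and `tau` configure the layer; `scale` is passed through to the lagrangian
layer = QuadraticLayer((3,), tau=2.0, scale=4.0)
```

The caller never names the lagrangian explicitly; it only supplies the layer arguments plus whatever keyword arguments the bound lagrangian accepts.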

In [37]:
show_doc(MakeLayer)

---

[source](https://github.com/bhoov/hamux/blob/main/hamux/layers.py#L80){target="_blank" style="float:right; font-size:smaller"}

### MakeLayer

>      MakeLayer (lagrangian_factory)

Hack to make it easy to create new layers from `Layer` utility class.

`delegates` modifies the signature for all Layers. We want a different signature for each type of layer.

So we redefine a local version of layer and delegate that for type inference.

By doing this hack, we lose the ability to inspect docstrings.

In [38]:
#| hide
import nbdev; nbdev.nbdev_export()