# HAM

> Assembling layers and synapses into a single system governed by an energy function

We have now provided the two primary components: `layers` and `synapses`. This module assembles those together into a single network that is governed by an energy function.

In [2]:
#| default_exp ham

In [8]:
#| export 
import jax
import jax.numpy as jnp
from typing import *
import treex as tx
from hamux.layers import Layer
from hamux.synapses import Synapse
import jax.tree_util as jtu
from hamux.utils import pytree_save, pytree_load, to_pickleable, align_with_state_dict
import pickle
import functools as ft
from fastcore.utils import *

In [45]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

## HAM System

We connect `layers` and `synapses` in a [hypergraph](https://en.wikipedia.org/wiki/Hypergraph) to describe the energy function. A hypergraph is a generalization of the familiar graph in that edges (synapses) can connect multiple nodes (layers). This graph -- complete with the operations of the synapses and the activation behavior of the layers -- fully defines the energy function for a given collection of neuron states.

In [6]:
#| export
class HAM(tx.Module):
    """A hypergraph whose nodes are neuron layers and whose hyperedges are synapses.

    `connections` wires the two together: every entry is a pair
    `(layer_indices, synapse_index)` naming the layers whose activations are
    passed to the synapse stored at `synapse_index`.
    """
    # Neuron layers; one state per layer
    layers: List[Layer]
    # Synapses referenced (by index) from `connections`
    synapses: List[Synapse]
    # Pairs of (tuple of layer indices, synapse index)
    connections: List[Tuple[Tuple, int]]

    def __init__(self, layers, synapses, connections):
        self.layers, self.synapses, self.connections = layers, synapses, connections

    @property
    def n_layers(self):
        """Number of layers in the network."""
        return len(self.layers)

    @property
    def n_synapses(self):
        """Number of synapses in the network."""
        return len(self.synapses)

    @property
    def n_connections(self):
        """Number of (layers, synapse) connections in the hypergraph."""
        return len(self.connections)

    @property
    def layer_taus(self):
        """The `tau` attribute of every layer, in layer order."""
        return [lyr.tau for lyr in self.layers]

As is typical for JAX frameworks, the parameters of HAMs need to be initialized. Unlike other machine learning libraries, the *states* of each *layer* -- that is, the dynamical variables of our system -- also need to be initialized. The notation $\mathbf{x}$ indicates the collection of all states from each layer, and $x^\alpha$ indicates that we are referring to the state of the layer at index $\alpha$ in our collection.

We provide this functionality with the following helper functions:

In [13]:
@patch
def init_states(self:HAM, 
                bs=None, # Batch size of the states to initialize, if needed
                rng=None): # RNG seed for random initialization of the states, if non-zero initializations are desired
    """Build one initial state per layer, returned as a list in layer order.

    When `rng` is given it is split into one key per layer, so each layer's
    `init_state` receives its own key; otherwise each layer's `init_state`
    is called without a key.
    """
    if rng is None:
        return [lyr.init_state(bs) for lyr in self.layers]
    per_layer_keys = jax.random.split(rng, self.n_layers)
    return [lyr.init_state(bs, rng=k) for lyr, k in zip(self.layers, per_layer_keys)]

@patch
def init_states_and_params(self:HAM, 
                           param_key, # RNG seed for random initialization of the parameters
                           bs=None, # Batch size of the states to initialize, if needed
                           state_key=None): # RNG seed for random initialization of the states, if non-zero initializations are desired
    """Initialize both the layer states and the network parameters.

    Returns a `(states, params)` pair, where `params` is whatever `self.init`
    returns after being run through the `energy` method.
    """
    # Parameter initialization does not need a batch size, so use unbatched default states
    unbatched_states = self.init_states()
    initialized = self.init(param_key, unbatched_states, call_method="energy")
    return self.init_states(bs, rng=state_key), initialized

In [22]:
from hamux.layers import *
from hamux.synapses import *
import hamux.lagrangians as lag
import numpy as np

In [35]:
# Two layers: an identity layer of shape (2,) at index 0 and a ReLU layer of shape (3,) at index 1
layers = [
    IdentityLayer((2,)),
    ReluLayer((3,))
]

# A single dense synapse, at index 0
synapses = [
    DenseSynapse()
]

# Each connection is (tuple of layer indices, synapse index): synapse 0 joins layers 0 and 1
connections = [
    ((0,1), 0)
]
# Unpack the initialized states (`xs`) and the initialized model (`ham`); separate keys seed parameters and states
xs, ham = HAM(layers, synapses, connections).init_states_and_params(jax.random.PRNGKey(0), state_key=jax.random.PRNGKey(1))

In [36]:
print([np.array(x) for x in xs]) # The dynamic variables: one state array per layer, converted to NumPy for display

[array([-0.27703857,  1.351606  ], dtype=float32), array([0.511158  , 2.276133  , 0.23958689], dtype=float32)]


In [37]:
print([s.W for s in ham.synapses]) # The parameters: the `W` attribute of each synapse

[DeviceArray([[0.01198053, 0.00434429, 0.01321206],
             [0.00065335, 0.0243299 , 0.02388163]], dtype=float32)]


The energy of our whole system is well defined:

$$E_\text{system}(\mathbf{x}) = E_\text{layers}(\mathbf{x}) + E_\text{synapses}(\mathbf{g}(\mathbf{x}))$$

where $\mathbf{x}$ is a collection of the states of our system, and $\mathbf{g}(\mathbf{x})$ is an identically shaped collection of the corresponding activations of our system. Then, for any instance at time $t$ we can compute the energy as a function of the states.

In [38]:
@patch
def activations(self:HAM, 
                xs:jnp.ndarray): # Collection of states for each layer
    """Map each layer's state through that layer's `g`, returning a list of activations in layer order"""
    return [self.layers[idx].g(state) for idx, state in enumerate(xs)]

@patch
def layer_energy(self:HAM,
                 xs:jnp.ndarray): # Collection of states for each layer
    """The summed contribution of every layer to the energy of the HAM"""
    # Evaluate each layer's energy on its own state, then reduce to a single scalar
    per_layer = [self.layers[idx].energy(state) for idx, state in enumerate(xs)]
    return jnp.sum(jnp.stack(per_layer))

@patch
def synapse_energy(self:HAM,
                   gs:jnp.ndarray): # Collection of activations of each layer
    """The summed contribution of every synapse to the energy of the HAM"""
    # For each connection, hand the activations of the connected layers
    # (in the order listed) to the synapse that the connection names.
    per_connection = [
        self.synapses[syn_idx].energy(*[gs[layer_idx] for layer_idx in layer_idxs])
        for layer_idxs, syn_idx in self.connections
    ]
    return jnp.sum(jnp.stack(per_connection))

@patch
def energy(self:HAM,
           xs:jnp.ndarray): # Collection of states for each layer
    """Total energy of the HAM: layer energy of the states plus synapse energy of their activations"""
    layer_activations = self.activations(xs)
    return self.layer_energy(xs) + self.synapse_energy(layer_activations)

In [60]:
# Energy contributed by the layers alone; the trailing `E_L` shows the value as the cell output
E_L = ham.layer_energy(xs); E_L

DeviceArray(3.7015278, dtype=float32)

In [43]:
# Synapse energy is a function of the activations, so convert states to activations first
gs = ham.activations(xs)
E_S = ham.synapse_energy(gs); E_S

DeviceArray(-0.07772133, dtype=float32)

In [46]:
# Sanity check: the total energy should equal the layer part plus the synapse part
test_eq(ham.energy(xs), E_L+E_S)

The update rule for each of the layer states is simply defined as follows:

$$\tau \frac{dx^\alpha}{dt} = -\frac{dE_\text{system}}{dg^\alpha}$$

JAX is wonderful. Autograd does this accurately and efficiently for us.

In [57]:
@patch
def dEdg(self:HAM, 
         xs:jnp.ndarray): # Collection of states for each layer
    """Calculate the gradient of system energy wrt. the activations

    We use an important mathematical property of the Legendre transform as a shortcut: dE_layer / dg = x.
    Only the synapse energy is therefore differentiated with autograd; the layer
    contribution is added directly as the states `xs`.

    Returns a pytree with the same structure as `xs`.
    """
    # Differentiate at the activations of the states that were passed in
    gs = self.activations(xs)
    synapse_grads = jax.grad(self.synapse_energy)(gs)
    return jtu.tree_map(lambda x, s: x + s, xs, synapse_grads)

@patch
def updates(self:HAM,
            xs:jnp.ndarray): # Collection of states for each layer
    """The update direction for each layer's state: the negated `dEdg`, leaf by leaf"""
    energy_grads = self.dEdg(xs)
    return jtu.tree_map(lambda grad: -grad, energy_grads)

Finally, we implement a simple, discrete step function, though more advanced optimizations from the JAX ecosystem (e.g., [optax](https://github.com/deepmind/optax)) can easily be used.

In [58]:
@patch
def step(self:HAM,
    xs: List[jnp.ndarray], # Collection of current states for each layer
    updates: List[jnp.ndarray], # Collection of update directions for each state
    dt: float = 0.1, # Stepsize to take in direction of updates
    masks: Optional[List[jnp.ndarray]] = None, # Boolean mask, 0 if clamped neuron, and 1 elsewhere. A pytree identical to `xs`. Optional.
):
    """Advance every state by one discrete step along `updates`.

    Each layer moves by `dt / tau` times its update, where `tau` is that
    layer's own time constant. When `masks` is given, the update is also
    multiplied elementwise by the mask, so entries masked with 0 are unchanged.
    """
    # Per-layer step size: dt divided by that layer's tau
    step_sizes = [dt / tau for tau in self.layer_taus]
    if masks is None:
        return jtu.tree_map(lambda x, u, alpha: x + alpha * u, xs, updates, step_sizes)
    return jtu.tree_map(lambda x, u, m, alpha: x + alpha * u * m, xs, updates, masks, step_sizes)

It is particularly useful if all of these functions can be applied to a batched collection of states, something JAX makes particularly easy through its [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) functionality. We prefix vectorized versions of the above methods with a `v`.

In [None]:
@patch
def _statelist_batch_axes(self:HAM):
    """Build the `in_axes` argument for `vmap`: axis 0 for each layer's state.

    The result is a 1-tuple, because the vectorized methods take a single
    positional argument (the list of states), holding one `0` per layer.
    """
    return ([0] * self.n_layers,)
    
@patch
def vactivations(self:HAM, 
                 xs: List[jnp.ndarray]): # Collection of states for each layer
    """`activations` mapped over the leading (batch) axis of every state"""
    batched_activations = jax.vmap(self.activations, in_axes=self._statelist_batch_axes())
    return batched_activations(xs)

@patch
def venergy(self:HAM, 
            xs: List[jnp.ndarray]): # Collection of states for each layer
    """`energy` mapped over the leading (batch) axis of every state"""
    batched_energy = jax.vmap(self.energy, in_axes=self._statelist_batch_axes())
    return batched_energy(xs)

@patch
def vdEdg(self:HAM, 
          xs: List[jnp.ndarray]): # Collection of states for each layer
    """`dEdg` mapped over the leading (batch) axis of every state"""
    batched_dEdg = jax.vmap(self.dEdg, in_axes=self._statelist_batch_axes())
    return batched_dEdg(xs)

@patch
def vupdates(self:HAM,
             xs: List[jnp.ndarray]): # Collection of states for each layer
    """`updates` mapped over the leading (batch) axis of every state"""
    batched_updates = jax.vmap(self.updates, in_axes=self._statelist_batch_axes())
    return batched_updates(xs)

Provide examples here

Finally, some helper functions to save and load this model during training. (Needs a better way to save and load the state dict.)

In [61]:
@patch
def load_state_dict(self:HAM, 
                    state_dict:Any): # The dictionary of all parameters, saved by `save_state_dict`
    """Populate a HAM from `state_dict` and return it.

    If this module is not yet initialized, it is first initialized with a
    fixed key and a batch size of 1. In that case the returned object is the
    initialized module, which need not be the object the method was called on.
    The "connections", "layers" and "synapses" entries of `state_dict` are
    then applied.
    """
    model = self
    if not model.initialized:
        # Only the initialized module is needed here; the states are discarded
        _, model = model.init_states_and_params(jax.random.PRNGKey(0), 1)
    model.connections = state_dict["connections"]
    model.layers = align_with_state_dict(model.layers, state_dict["layers"])
    model.synapses = align_with_state_dict(model.synapses, state_dict["synapses"])
    return model

@patch
def save_state_dict(self:HAM, 
                    fname:Union[str, Path], # Filename of checkpoint to save
                    overwrite:bool=True): # Overwrite an existing file of the same name?
    """Write this module's dictionary form to `fname`.

    Every leaf of `self.to_dict()` is passed through `to_pickleable` before
    the result is handed to `pytree_save`.
    """
    pickleable_state = jtu.tree_map(to_pickleable, self.to_dict())
    pytree_save(pickleable_state, fname, overwrite=overwrite)

@patch
def load_ckpt(self:HAM, 
              ckpt_f:Union[str, Path]): # Filename of checkpoint to load
    """Read a pickled state dict from `ckpt_f` and load it via `load_state_dict`.

    NOTE: the file is read with `pickle.load`, which can execute arbitrary
    code while unpickling; only load checkpoints from trusted sources.
    """
    with open(ckpt_f, "rb") as ckpt_file:
        loaded_state = pickle.load(ckpt_file)
    return self.load_state_dict(loaded_state)

In [7]:
#| hide
# Run nbdev's export step. NOTE(review): presumably this writes the `#| export` cells to the module named by `#| default_exp ham` -- confirm against the nbdev docs
import nbdev; nbdev.nbdev_export()