# Registry

> Easily create preconfigured models and prediction functions on a HAM

We create very simple helper functions to instantiate HAMs with particular architectural choices. Inspired by [`timm`](https://github.com/rwightman/pytorch-image-models).

A HAM is a fundamentally general-purpose architecture: it is an Associative Memory, and it is up to the user to extract the desired information from the system. Hence, every registered model must return both the `ham` architecture and a `fwd` function that uses that architecture to accomplish a task.

In [None]:
#| default_exp registry

In [None]:
#| export 
import hamux as hmx
from typing import *
import functools as ft
from fastcore.utils import *
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import treex as tx
from einops import rearrange

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *
import warnings
import os

In [None]:
#| hide
# Silence all warnings (presumably to keep the rendered notebook output clean)
warnings.simplefilter('ignore')
# Hide all CUDA devices so the examples in this notebook run on CPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""

## The Registry

In [None]:
#| export
__MODELS = {}  # Registry: factory-function name -> factory function

def register_model(fgen:Callable): # Function that returns a HAM with desired config
    """Register `fgen`, a factory that builds a model configuration (e.g. returns `(ham, fwd)`).

    The name of the function (`fgen.__name__`) acts as the retrieval key for `create_model`
    and must be unique across models: registering another function under an existing name
    silently replaces the earlier one. `fgen` is returned unchanged, so this works as a decorator."""
    __MODELS[fgen.__name__] = fgen
    return fgen

class _UnregisteredModelError(KeyError, AssertionError):
    """Raised by `create_model` when asked for a name that was never registered.

    Subclasses `KeyError` (a failed registry lookup) and `AssertionError` (what `create_model`
    used to raise through `assert`), so handlers written for either keep working."""

def create_model(mname:str, # Retrieve this stored model name
                 *args, # Passed to retrieved factory function
                 **kwargs): # Passed to retrieved factory function
    """Retrieve the model name from all registered models, passing `args` and `kwargs` to the factory function

    Raises a `KeyError` (which is also an `AssertionError`) if `mname` has not been registered."""
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`, which would
    # otherwise turn the friendly message into a bare KeyError from the dict lookup.
    try:
        fgen = __MODELS[mname]
    except KeyError:
        raise _UnregisteredModelError(f"Model '{mname}' has not been registered") from None
    return fgen(*args, **kwargs)

def named_partial(f, *args, new_name=None, order=None, **kwargs):
    """Return `functools.partial(f, *args, **kwargs)` that still looks like a named function.

    The partial gets `f`'s docstring and a `__name__` (`new_name` when given, otherwise
    `f.__name__`), so it can be handed to `register_model`, which keys on `__name__`.
    An explicit `order` is stored as `.order`; otherwise `f.order` is copied when it exists.
    """
    wrapped = ft.partial(f, *args, **kwargs)
    wrapped.__doc__ = f.__doc__
    wrapped.__name__ = f.__name__ if new_name is None else new_name
    if order is None:
        if hasattr(f, 'order'):
            wrapped.order = f.order
    else:
        wrapped.order = order
    return wrapped

We can now register a model as follows:

In [None]:
@register_model
def example_classical_hn(img_shape:Tuple, # Shape of the (already vectorized) input layer
            label_shape:Tuple[int], # Shape of the label layer, e.g. (NLABELS,)
            nhid:int=1000, # Number of hidden units in the single hidden layer
            depth:int=4, # Default number of iterations to run the Hopfield Network prediction function
            dt:float=0.4, # Default step size of the system
           ): 
    """Create a 2-layer classical Hopfield Network applied on vectorized inputs and a function showing how to use it

    Returns `(ham, fwd)`. The HAM joins a tanh input layer to a softmax label layer through one dense
    synapse whose hidden layer uses the `LRelu` Lagrangian; `fwd` reads predictions out of it."""
    ham = hmx.HAM(
        [hmx.TanhLayer(img_shape), hmx.SoftmaxLayer(label_shape)],  # layers
        [hmx.DenseMatrixSynapseWithHiddenLayer(nhid, hidden_lagrangian=hmx.lagrangians.LRelu())],  # synapses
        [((0, 1), 0)],  # connections: layers 0 and 1 are joined by synapse 0
    )

    def fwd(model, x, depth=depth, dt=dt, rng=None):
        """A pure function to extract desired information from the configured HAM, applied on batched inputs.

        Clamps layer 0 to `x`, evolves the remaining layers for `depth` steps of size `dt`,
        and returns the activation of the last (label) layer."""
        states = model.init_states(x.shape[0], rng=rng)
        states[0] = jnp.array(x)

        # int8 masks mirroring the states: zeros on layer 0 keep the input clamped, ones let the other layers evolve
        masks = jtu.tree_map(lambda s: jnp.ones_like(s, dtype=jnp.int8), states)
        masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)

        for _ in range(depth):
            states = model.step(states, model.vupdates(states), dt=dt, masks=masks)

        # The label layer is a softmax layer, so this is a vector of probabilities per input
        return model.layers[-1].g(states[-1])

    return ham, fwd

The model that we just created comes with a default `fwd` function that predicts label probabilities after 4 steps (feel free to write any other function that extracts a layer's state or activation at any point in time).  

In [None]:
img_shape = (32,32); bs = 12
model, fwd = create_model("example_classical_hn", img_shape=img_shape, label_shape=(10,))
# Initialize parameters; the returned initial states are not needed here
_, model = model.init_states_and_params(jax.random.PRNGKey(0))
x = jnp.ones((bs, *img_shape))  # dummy batch of `bs` inputs
# One probability vector per input, shape (bs, 10)
probs = fwd(model, x); probs.shape

2022-12-13 16:24:58.481170: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


(12, 10)

For simple classification pipelines, our `fwd` functions are all quite similar, so we create some helper functions to reuse throughout the rest of our model configurations.

In [None]:
#| export
def simple_fwd(model:hmx.HAM, # HAM where layer[0] is the image input and layer[-1] are the labels
               x: jnp.ndarray, # Starting point for clamped layer[0]
               depth: int, # Number of iterations for which to run the model
               dt: float, # Step size through time
               rng: Optional[jnp.ndarray]=None): # If provided, initialize states to random instead of 0
    """A simple version of the forward function for showing in the paper.

    Clamps layer 0 to `x`, lets every other layer evolve for `depth` steps of size `dt`,
    and returns the activation `g` of the last (label) layer.
    All time constants `tau` are set to be 1 in our architecture, but this is variable
    """
    states = model.init_states(x.shape[0], rng=rng)
    states[0] = jnp.array(x)

    # One int8 mask per layer: ones let a layer evolve, zeros on layer 0 keep the input clamped
    masks = jtu.tree_map(lambda s: jnp.ones_like(s, dtype=jnp.int8), states)
    masks[0] = jnp.zeros_like(masks[0], dtype=jnp.int8)

    for _ in range(depth):
        states = model.step(states, model.vupdates(states), dt=dt, masks=masks)

    # The label layer is a softmax layer, so its activation is a probability vector
    return model.layers[-1].g(states[-1])

def fwd_vec(model:hmx.HAM, # HAM where layer[0] is the image input and layer[-1] are the labels
            x: jnp.ndarray, # Starting point for clamped layer[0]
            depth: int, # Number of iterations for which to run the model
            dt: float, # Step size through time
            rng: Optional[jnp.ndarray]=None): # If provided, initialize states to random instead of 0
    """`simple_fwd` for models whose input layer is a flat vector.

    The trailing channel/height/width axes of `x` are flattened into a single axis
    before the input layer is clamped to it."""
    flat = rearrange(x, "... c h w -> ... (c h w)")
    return simple_fwd(model, flat, depth, dt, rng)

def fwd_conv(model:hmx.HAM, # HAM where layer[0] is the image input and layer[-1] are the labels
             x: jnp.ndarray, # Channels-first images, (..., C, H, W)
             depth: int, # Number of iterations for which to run the model
             dt: float, # Step size through time
             rng: Optional[jnp.ndarray]=None): # If provided, initialize states to random instead of 0
    """`simple_fwd` for models whose input layer is a channels-last image.

    `x` arrives channels-first (PyTorch convention, `(..., C, H, W)`) and is moved to
    channels-last `(..., H, W, C)` before the input layer is clamped to it. The number of
    channels is left unchanged (e.g. 1 for MNIST, 3 for CIFAR)."""
    x = rearrange(x, "... c h w -> ... h w c")
    return simple_fwd(model, x, depth, dt, rng)


## Model Registry

### 2 Layer HN

In [None]:
#| export
@register_model
def hn(hidden_lagrangian:tx.Module, # Lagrangian of the synapse's hidden layer, e.g. `hmx.lagrangians.LRelu()`
       img_shape: Tuple, # Shape of image input to model
       label_shape: Tuple, # Shape of label probabilities, typically (NLABELS,)
       nhid:int=1000, # Number of units in hidden layer
       do_norm:bool=False, # If True, enforce that all weights are standardized
       *,
       depth:int=4, # Default number of iterations run by the returned `fwd`
       dt:float=0.4): # Default step size used by the returned `fwd`
    """Create a 2-layer Hopfield Network that is intended to be applied on vectorized inputs.

    Whether this behaves as a classical HN or a Dense Associative Memory is decided by
    `hidden_lagrangian`. Returns `(ham, fwd)`; `fwd(model, x)` flattens the trailing (C, H, W)
    axes of `x` (see `fwd_vec`) and runs `depth` steps of size `dt`."""
    layers = [
        hmx.TanhLayer(img_shape),
        hmx.SoftmaxLayer(label_shape),
    ]

    synapses = [
        hmx.DenseMatrixSynapseWithHiddenLayer(nhid, hidden_lagrangian=hidden_lagrangian, do_norm=do_norm),
    ]

    connections = [
        ((0, 1), 0),  # layers 0 and 1 joined by synapse 0
    ]

    ham = hmx.HAM(layers, synapses, connections)

    forward = ft.partial(fwd_vec, depth=depth, dt=dt)

    return ham, forward

# One registered `hn` variant per hidden-layer Lagrangian. `register_model` returns its argument,
# so each variant is registered and bound to a module-level name in a single statement.
hn_relu = register_model(named_partial(hn, hmx.lagrangians.LRelu(), new_name="hn_relu"))
hn_repu5 = register_model(named_partial(hn, hmx.lagrangians.LRepu(n=5), new_name="hn_repu5"))
hn_softmax = register_model(named_partial(hn, hmx.lagrangians.LSoftmax(), new_name="hn_softmax"))

@register_model
def hn_relu_mnist(nhid:int=1000): # Number of units in the single hidden layer
    """Vectorized HN on flattened MNIST"""
    # 28*28 single-channel pixels flattened into one vector; 10 classes
    n_pixels, n_labels = 784, 10
    return hn_relu(img_shape=(n_pixels,), label_shape=(n_labels,), nhid=nhid)

@register_model
def hn_relu_cifar(nhid:int=6000): # Number of units in the single hidden layer
    """Vectorized HN on flattened CIFAR10"""
    # 3*32*32 pixels flattened into one vector; 10 classes
    n_pixels, n_labels = 3072, 10
    return hn_relu(img_shape=(n_pixels,), label_shape=(n_labels,), nhid=nhid)

@register_model
def hn_repu5_mnist(nhid:int=1000): # Number of units in the single hidden layer
    """Vectorized DAM (hidden Lagrangian `LRepu(n=5)`) on flattened MNIST"""
    return hn_repu5(img_shape=(784,), label_shape=(10,), nhid=nhid)

@register_model
def hn_repu5_cifar(nhid:int=6000): # Number of units in the single hidden layer
    """Vectorized DAM (hidden Lagrangian `LRepu(n=5)`) on flattened CIFAR10"""
    return hn_repu5(img_shape=(3072,), label_shape=(10,), nhid=nhid)

@register_model
def hn_softmax_mnist(nhid:int=1000): # Number of units in the single hidden layer
    """Vectorized HN (hidden Lagrangian `LSoftmax`, standardized weights) on flattened MNIST"""
    return hn_softmax(img_shape=(784,), label_shape=(10,), nhid=nhid, do_norm=True)

@register_model
def hn_softmax_cifar(nhid:int=6000): # Number of units in the single hidden layer
    """Vectorized HN (hidden Lagrangian `LSoftmax`, standardized weights) on flattened CIFAR10"""
    return hn_softmax(img_shape=(3072,), label_shape=(10,), nhid=nhid, do_norm=True)


These models can now be instantiated by their strings:

In [None]:
xcifar = jnp.ones((1,3, 32,32)) # Per pytorch convention, NCHW (a batch of one image)
xmnist = jnp.ones((1,1,28,28)) # Per pytorch convention, NCHW (a batch of one image)

# Extra positional args are forwarded to the `hn` factory: hidden Lagrangian, img_shape, label_shape
exhn, exhn_fwd = create_model("hn", hmx.lagrangians.LExp(), (32,32,3), (10,))
_, exhn = exhn.init_states_and_params(jax.random.PRNGKey(22))
exhn_fwd(exhn, xcifar)

Array([[0.31536135, 0.08483113, 0.0897951 , 0.02981309, 0.03241062,
        0.03071734, 0.03666373, 0.00281298, 0.03433144, 0.3432633 ]],      dtype=float32)

In [None]:
#| hide

# Additional tests for the registry: every `hn_*` variant must build, initialize, and map a
# channels-first batch to a (batch, 10) output.

def _check_hn(mname, x, *args):
    "Build registered model `mname` with `args`, run its default `fwd` on `x`, and check the output shape"
    ham, fwd = create_model(mname, *args)
    _, ham = ham.init_states_and_params(jax.random.PRNGKey(22))
    test_eq(fwd(ham, x).shape, (x.shape[0], 10))

# Relu, Repu5 and Softmax variants (the Softmax section previously re-tested Repu5 by mistake)
for lagr in ["relu", "repu5", "softmax"]:
    _check_hn(f"hn_{lagr}", xcifar, (32,32,3), (10,))
    _check_hn(f"hn_{lagr}_mnist", xmnist)
    _check_hn(f"hn_{lagr}_cifar", xcifar)

## Simple Convolution

In [None]:
#| export
@register_model
def conv_ham(s1:Tuple, # Shape (H, W, C) of the channels-last image input layer
             s2:Tuple, # Shape of the first convolutional feature layer
             s3:Tuple, # Shape of the second convolutional feature layer
             pool_type:str, # Pooling used by both conv synapses, e.g. "avg" or "max"
             nhid:int=1000, # Number of hidden units in the dense synapse feeding the labels
             *,
             label_shape:Tuple=(10,), # Shape of the label probabilities, (NLABELS,)
             depth:int=7, # Default number of iterations run by the returned `fwd`
             dt:float=0.3): # Default step size used by the returned `fwd`
    """Create a 4-layer convolutional HAM: image -> conv features -> conv features -> labels.

    Consecutive layers are joined by two pooled convolutional synapses and a final dense
    synapse with a hidden layer. Returns `(ham, fwd)`, where `fwd(model, x)` takes
    channels-first images (see `fwd_conv`) and runs `depth` steps of size `dt`."""
    layers = [
        hmx.TanhLayer(s1, tau=1.0),
        hmx.TanhLayer(s2, tau=1.0),
        hmx.TanhLayer(s3, tau=1.0),
        hmx.SoftmaxLayer(label_shape, tau=1.0),
    ]
    synapses = [
        hmx.ConvSynapseWithPool(
            (4, 4),
            strides=(2, 2),
            padding=(2, 2),
            pool_window=(2, 2),
            pool_stride=(2, 2),
            pool_type=pool_type,
        ),
        hmx.ConvSynapseWithPool(
            (3, 3),
            strides=(1, 1),
            padding=(0, 0),
            pool_window=(2, 2),
            pool_stride=(2, 2),
            pool_type=pool_type,
        ),
        hmx.DenseMatrixSynapseWithHiddenLayer(nhid),
    ]
    # Each entry is ((layer_a, layer_b), synapse_index)
    connections = [
        ((0, 1), 0),  # image      - features 1
        ((1, 2), 1),  # features 1 - features 2
        ((2, 3), 2),  # features 2 - labels
    ]

    ham = hmx.HAM(layers, synapses, connections)

    forward = ft.partial(fwd_conv, depth=depth, dt=dt)
    return ham, forward


@register_model
def conv_ham_avgpool_mnist(nhid:int=1000): # Number of hidden units in the dense synapse feeding the labels
    """`conv_ham` with average pooling for 28x28x1 (MNIST) images"""
    return conv_ham((28, 28, 1), (7, 7, 64), (2, 2, 128), pool_type="avg", nhid=nhid)


@register_model
def conv_ham_maxpool_mnist(nhid:int=1000): # Number of hidden units in the dense synapse feeding the labels
    """`conv_ham` with max pooling for 28x28x1 (MNIST) images"""
    return conv_ham((28, 28, 1), (7, 7, 64), (2, 2, 128), pool_type="max", nhid=nhid)


@register_model
def conv_ham_avgpool_cifar(nhid:int=1000): # Number of hidden units in the dense synapse feeding the labels
    """`conv_ham` with average pooling for 32x32x3 (CIFAR) images"""
    return conv_ham((32, 32, 3), (8, 8, 90), (3, 3, 180), pool_type="avg", nhid=nhid)


@register_model
def conv_ham_maxpool_cifar(nhid:int=1000): # Number of hidden units in the dense synapse feeding the labels
    """`conv_ham` with max pooling for 32x32x3 (CIFAR) images"""
    return conv_ham((32, 32, 3), (8, 8, 90), (3, 3, 180), pool_type="max", nhid=nhid)

In [None]:
# Build the default average-pooling conv HAM for CIFAR and run it on the dummy NCHW CIFAR batch
model, fwd = create_model("conv_ham_avgpool_cifar")
_, model = model.init_states_and_params(jax.random.PRNGKey(0))
fwd(model, xcifar)

### Energy Version of Attention

We now introduce a simple model for energy-based attention.

In [None]:
#| export
@register_model
def energy_attn(s1:Tuple, # Shape (H, W, C) of the channels-last image input layer
                s2:Tuple, # Shape of the hidden feature layer (layer 1) the attention synapses act on
                nheads_self:int, # Number of heads in the attention synapse from layer 1 to itself
                nheads_cross:int, # Number of heads in the attention synapse between labels and layer 1
                *,
                label_shape:Tuple=(10,), # Shape of the label probabilities, (NLABELS,)
                depth:int=5, # Default number of iterations run by the returned `fwd`
                dt:float=0.4): # Default step size used by the returned `fwd`
    """Create a simple energy-based attention HAM: image -> hidden features -> labels.

    A convolutional synapse maps the image onto layer 1. One attention synapse connects the
    labels to layer 1 (`nheads_cross` heads), and another connects layer 1 to itself
    (`nheads_self` heads). Returns `(ham, fwd)`, where `fwd(model, x)` takes channels-first
    images (see `fwd_conv`) and runs `depth` steps of size `dt`."""
    layers = [
        hmx.TanhLayer(s1, tau=1.0),
        hmx.TanhLayer(s2, tau=1.0, use_bias=True),
        hmx.SoftmaxLayer(label_shape, tau=1.0),
    ]

    synapses = [
        hmx.ConvSynapse((4, 4), strides=(4, 4), padding=(0, 0)),
        hmx.AttentionSynapse(num_heads=nheads_cross, zspace_dim=64, stdinit=0.002),
        hmx.AttentionSynapse(num_heads=nheads_self, zspace_dim=64, stdinit=0.002),
    ]

    # Each entry is ((layer_a, layer_b), synapse_index)
    connections = [
        ((0, 1), 0),  # image and layer 1 via the conv synapse
        ((2, 1), 1),  # labels and layer 1 via the cross-attention synapse
        ((1, 1), 2),  # layer 1 with itself via the self-attention synapse
    ]
    ham = hmx.HAM(layers, synapses, connections)
    forward = ft.partial(fwd_conv, depth=depth, dt=dt)

    return ham, forward


@register_model
def energy_attn_mnist():
    """`energy_attn` for 28x28x1 (MNIST) images: 7x7x128 hidden layer, 4 self- and 2 cross-attention heads"""
    return energy_attn(
        (28, 28, 1),
        (7, 7, 128),
        nheads_self=4,
        nheads_cross=2,
    )


@register_model
def energy_attn_cifar():
    """`energy_attn` for 32x32x3 (CIFAR) images: 8x8x224 hidden layer, 4 self- and 2 cross-attention heads"""
    return energy_attn(
        (32, 32, 3),
        (8, 8, 224),
        nheads_self=4,
        nheads_cross=2,
    )

In [None]:
#| hide
# Run nbdev's export so the `#| export` cells are written to the library module(s)
import nbdev; nbdev.nbdev_export()