# Lagrangians

> The physics of associative memories is captured in the **Lagrangian** operation of neurons

We begin with the lagrangian, the fundamental building block of any neuron layer.

In [1]:
#| default_exp lagrangians

In [2]:
#| export 
import jax.numpy as jnp
import jax
import numpy as np
from fastcore.test import *
import functools as ft

In [3]:
#| hide
from nbdev.showdoc import *

## Functional interface

Here we define our lagrangian functions, which can be thought of as the integral (antiderivative) of common activation functions in Deep Learning literature: each activation function is recovered as the gradient of its lagrangian. All lagrangians are (potentially parameterized) functions of the form:

$$\mathcal{L}(x;\ldots) \mapsto \mathbb{R}$$

where $x$ can be a tensor of arbitrary shape. It is important that our Lagrangians be convex and differentiable.

We want to rely on JAX's autograd to automatically differentiate our lagrangians into activation functions. In certain cases (e.g., `lagr_sigmoid` and `lagr_tanh`), autodiff will create a numerically unstable activation function. We follow JAX's [documentation guidelines](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) to define `custom_jvp`s to fix this behavior.

In [4]:
#| export
def lagr_identity(x):
    """The lagrangian of the identity activation -- half the sum of the squared entries of `x`"""
    squares = jnp.power(x, 2)
    return 1 / 2 * squares.sum()

def lagr_repu(x, 
              n): # Degree of the polynomial in the power unit
    """Lagrangian of the Rectified Power Unit: `1 / n` times the sum of the rectified input raised to the power `n`"""
    # Negative entries are zeroed before being raised to the power
    rectified = jnp.maximum(x, 0)
    total = jnp.power(rectified, n).sum()
    return 1 / n * total

def lagr_relu(x):
    """Lagrangian of the Rectified Linear Unit, i.e. `lagr_repu` evaluated at degree 2"""
    return lagr_repu(x, n=2)

def lagr_softmax(x,
                 beta=1.0, # Inverse temperature
                 axis=-1): # Dimension over which to apply logsumexp
    """The lagrangian of the softmax -- the logsumexp, scaled by `1 / beta` and summed over the remaining entries"""
    # `keepdims=True` leaves a singleton along `axis`; the final sum collapses everything to a scalar
    lse = jax.nn.logsumexp(beta * x, axis=axis, keepdims=True)
    return jnp.sum(1/beta * lse)

def lagr_exp(x, 
             beta=1.0): # Inverse temperature
    """Lagrangian of the exponential activation function, as in [Demircigil et al.](https://arxiv.org/abs/1702.01929)"""
    total = jnp.exp(beta * x).sum()
    return 1 / beta * total

The lagrangians of the `sigmoid` and the `tanh` are numerically less stable, so we define custom gradients for them.

In [5]:
#| export
@jax.custom_jvp
def _lagr_sigmoid(x, 
                  beta=1.0, # Inverse temperature
                  scale=1.0): # Amount to stretch the range of the sigmoid's lagrangian
    """Elementwise lagrangian of the tempered sigmoid: `scale / beta * log(1 + exp(beta * x))`"""
    # `softplus(z) = log(1 + exp(z))` is evaluated without forming `exp(z)`,
    # which would overflow to `inf` for large `beta * x`
    return scale / beta * jax.nn.softplus(beta * x)

def tempered_sigmoid(x, 
                     beta=1.0, # Inverse temperature
                     scale=1.0): # Amount to stretch the range of the sigmoid
    """The basic sigmoid, but with a scaling factor"""
    return scale / (1 + jnp.exp(-beta * x))

@_lagr_sigmoid.defjvp
def _lagr_sigmoid_jvp(primals, tangents):
    """Derivative rule of `_lagr_sigmoid` with respect to `x`, `beta` and `scale`

    The derivative w.r.t. `x` is the manually defined `tempered_sigmoid`. The tangents of
    `beta` and `scale` are propagated as well, so that they receive non-zero gradients.
    """
    x, beta, scale = primals
    x_dot, beta_dot, scale_dot = tangents
    # Integer-valued `beta`/`scale` arrive with `float0` tangents, which cannot enter
    # arithmetic; they contribute nothing to the output tangent
    beta_dot, scale_dot = (0.0 if getattr(t, "dtype", None) == jax.dtypes.float0 else t
                           for t in (beta_dot, scale_dot))
    primal_out = _lagr_sigmoid(x, beta, scale)
    activation = tempered_sigmoid(x, beta=beta, scale=scale) # dL/dx, manually defined sigmoid
    dbeta = (x * activation - primal_out) / beta # dL/dbeta
    dscale = jax.nn.softplus(beta * x) / beta # dL/dscale
    tangent_out = activation * x_dot + dbeta * beta_dot + dscale * scale_dot
    return primal_out, tangent_out

def lagr_sigmoid(x, 
                 beta=1.0, # Inverse temperature
                 scale=1.0): # Amount to stretch the range of the sigmoid's lagrangian
    """The lagrangian of the sigmoid activation function, summed over all entries of `x`"""
    return _lagr_sigmoid(x, beta=beta, scale=scale).sum()

In [6]:
#|hide
# The gradient of the sigmoid's lagrangian w.r.t. `x` must equal the tempered sigmoid
x = np.random.randn(4, 20)
beta = 0.2
scale = 1.3
sigmoid_lagrangian = ft.partial(lagr_sigmoid, beta=beta, scale=scale)
test_eq(tempered_sigmoid(x, beta=beta, scale=scale), jax.grad(sigmoid_lagrangian)(x))



In [7]:
#| export
@jax.custom_jvp
def _lagr_tanh(x, 
               beta=1.0): # Inverse temperature
    """Elementwise lagrangian of the tanh: `1 / beta * log(cosh(beta * x))`"""
    z = beta * x
    # `log(cosh(z)) = logaddexp(z, -z) - log(2)`; this form never evaluates `cosh(z)`,
    # which would overflow to `inf` for large `|z|`
    return 1 / beta * (jnp.logaddexp(z, -z) - jnp.log(2.0))

@_lagr_tanh.defjvp
def _lagr_tanh_defjvp(primals, tangents):
    """Derivative rule of `_lagr_tanh` with respect to `x` and `beta`

    The derivative w.r.t. `x` is the manually defined `tanh(beta * x)`. The tangent of
    `beta` is propagated as well, so that it receives a non-zero gradient.
    """
    x, beta = primals
    x_dot, beta_dot = tangents
    # An integer-valued `beta` arrives with a `float0` tangent, which cannot enter
    # arithmetic; it contributes nothing to the output tangent
    if getattr(beta_dot, "dtype", None) == jax.dtypes.float0:
        beta_dot = 0.0
    primal_out = _lagr_tanh(x, beta)
    activation = jnp.tanh(beta * x) # dL/dx
    dbeta = (x * activation - primal_out) / beta # dL/dbeta
    tangent_out = activation * x_dot + dbeta * beta_dot
    return primal_out, tangent_out

def lagr_tanh(x, 
              beta=1.0): # Inverse temperature
    """The lagrangian of the tanh activation function, summed over all entries of `x`"""
    return _lagr_tanh(x, beta).sum()

In [8]:
#| hide
# Run nbdev's export step for this notebook (cells marked `#| export` target the `lagrangians` module)
import nbdev; nbdev.nbdev_export()

## Parameterized Lagrangians

It is beneficial to consider lagrangians as modules with their own learnable parameters.

In [9]:
#| export
import treex as tx
from dataclasses import dataclass
from typing import *

In [10]:
#| export
class LIdentity(tx.Module):
    """Parameter-free lagrangian module whose activation function is the identity function"""

    def __init__(self):
        pass

    def __call__(self, x):
        # Delegates to the functional interface
        return lagr_identity(x)

class LRepu(tx.Module):
    """Lagrangian module whose activation function is the rectified polynomial unit of degree `n`"""
    n: float = 2

    # `n` needs a default value to work with layer creation
    def __init__(self,
                 n=2.): # The degree of the RePU. By default, set to the ReLU configuration
        self.n = n

    def __call__(self, x):
        degree = self.n
        return lagr_repu(x, degree)
    
class LRelu(tx.Module):
    """Parameter-free lagrangian module whose activation function is the rectified linear unit"""

    def __init__(self):
        pass

    def __call__(self, x):
        # Delegates to the functional interface
        return lagr_relu(x)
    
class LSigmoid(tx.Module):
    """Lagrangian whose activation function is the sigmoid
    
    Parameters:
    
    - beta
    - scale
    """
    beta: Union[float, jnp.ndarray] = tx.Parameter.node(default=1.0)
    scale: float = tx.Parameter.node(default=1.0)
    min_beta: float = 1e-6
    
    def __init__(self, 
                 beta=1., # Inverse temperature
                 scale=1., # Amount to stretch the sigmoid.
                 min_beta=1e-6): # Minimal accepted value of beta. For energy dynamics, it is important that beta be positive.
        self.beta = beta
        self.scale = scale
        self.min_beta = min_beta

    def __call__(self, x):
        # `beta` is bounded below by `min_beta` before the lagrangian is evaluated
        beta = jnp.clip(self.beta, self.min_beta)
        return lagr_sigmoid(x, beta=beta, scale=self.scale)
    
class LSoftmax(tx.Module):
    """Lagrangian module whose activation function is the softmax
    
    Parameters:
    
    - beta
    """
    beta: Union[float, jnp.ndarray] = tx.Parameter.node(default=1.0)
    axis: int = -1
    min_beta: float = 1e-6

    def __init__(self, 
                 beta=1., # Inverse temperature
                 axis=-1, # Axis over which to apply the softmax
                 min_beta=1e-6): # Minimal accepted value of beta. For energy dynamics, it is important that beta be positive.
        self.beta = beta
        self.axis = axis
        self.min_beta = min_beta

    def __call__(self, x):
        # `beta` is bounded below by `min_beta` before the lagrangian is evaluated
        beta = jnp.clip(self.beta, self.min_beta)
        return lagr_softmax(x, beta=beta, axis=self.axis)
    

class LExp(tx.Module):
    """Lagrangian module whose activation function is the exponential function
    
    Parameters:

    - beta
    """
    beta: Union[float, jnp.ndarray] = tx.Parameter.node(default=1.0)
    min_beta: float = 1e-6

    def __init__(self, 
                 beta=1., # Inverse temperature, for the sharpness of the exponent
                 min_beta=1e-6): # Minimal accepted value of beta. For energy dynamics, it is important that beta be positive.
        self.beta = beta
        self.min_beta = min_beta

    def __call__(self, x):
        # `beta` is bounded below by `min_beta` before the lagrangian is evaluated
        beta = jnp.clip(self.beta, self.min_beta)
        return lagr_exp(x, beta=beta)
    
class LTanh(tx.Module):
    """Lagrangian module whose activation function is the tanh
    
    Parameters:
    
    - beta
    """
    beta: Union[float, jnp.ndarray] = tx.Parameter.node(default=1.0)
    min_beta: float = 1e-6

    def __init__(self, 
                 beta=1., # Inverse temperature, for the sharpness of the tanh
                 min_beta=1e-6): # Minimal accepted value of beta. For energy dynamics, it is important that beta be positive.
        self.beta = beta
        self.min_beta = min_beta

    def __call__(self, x):
        # `beta` is bounded below by `min_beta` before the lagrangian is evaluated
        beta = jnp.clip(self.beta, self.min_beta)
        return lagr_tanh(x, beta=beta)

In [11]:
#| hide
# Run nbdev's export step for this notebook (cells marked `#| export` target the `lagrangians` module)
import nbdev; nbdev.nbdev_export()