# jaxify

> The core functionality of `javiche`: defines the `@jaxit()` decorator.

In [None]:
#| default_exp jaxify

In [None]:
#| exporti
from ceviche import jacobian
import numpy as np
import jax
import jax.numpy as jnp
from typing import List
from functools import lru_cache, wraps

In [None]:
#| export
# from https://gist.github.com/Susensio/61f4fee01150caaac1e10fc5f005eb75
def np_cache(*cache_args, **cache_kwargs):
    """LRU cache implementation for functions whose arguments may be numpy (or jax) arrays.

    Every positional or keyword argument that is a `np.ndarray` or `jax.Array`
    is replaced by a hashable key built from its shape, dtype and raw bytes.
    Arrays with equal contents therefore hit the same cache entry, while arrays
    that differ only in dtype or shape do not. The wrapped function always
    receives fresh, writable numpy copies of its array arguments; all other
    arguments are passed through unchanged (and must be hashable).
    `cache_args` / `cache_kwargs` are forwarded to `functools.lru_cache`
    (e.g. `maxsize`).

    Notes:
      - Cached return values are shared between cache hits; do not mutate them.
      - Object-dtype arrays raise `TypeError`: their bytes are object pointers,
        not values, so they cannot be keyed this way.

    >>> array = np.array([[1, 2, 3], [4, 5, 6]])
    >>> @np_cache(maxsize=256)
    ... def multiply(array, factor):
    ...     print("Calculating...")
    ...     return factor*array
    >>> multiply(array, 2)
    Calculating...
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply(array, 2)
    array([[ 2,  4,  6],
           [ 8, 10, 12]])
    >>> multiply.cache_info()
    CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)

    """
    def decorator(function):
        # Private marker so an encoded array can never be confused with a
        # tuple the caller passed in.
        array_tag = object()

        def to_key(value):
            """Replace an array by a hashable (tag, shape, dtype, bytes) key."""
            if isinstance(value, jax.Array):
                value = np.asarray(value)
            if not isinstance(value, np.ndarray):
                return value
            if value.dtype.hasobject:
                raise TypeError("np_cache does not support object-dtype arrays")
            # tobytes() emits C-order data even for non-contiguous views, which
            # matches the reshape in from_key.
            return (array_tag, value.shape, value.dtype, value.tobytes())

        def from_key(value):
            """Rebuild a numpy array from a key produced by `to_key`."""
            if not (isinstance(value, tuple) and len(value) == 4
                    and value[0] is array_tag):
                return value
            _, shape, dtype, data = value
            if not data:
                # Empty arrays: build directly instead of parsing an empty buffer.
                return np.empty(shape, dtype=dtype)
            # frombuffer views immutable bytes; copy to hand out a writable array.
            return np.frombuffer(data, dtype=dtype).reshape(shape).copy()

        @lru_cache(*cache_args, **cache_kwargs)
        def cached_wrapper(*args, **kwargs):
            return function(
                *[from_key(arg) for arg in args],
                **{name: from_key(arg) for name, arg in kwargs.items()},
            )

        @wraps(function)
        def wrapper(*args, **kwargs):
            return cached_wrapper(
                *[to_key(arg) for arg in args],
                **{name: to_key(arg) for name, arg in kwargs.items()},
            )

        # copy lru_cache attributes over too
        wrapper.cache_info = cached_wrapper.cache_info
        wrapper.cache_clear = cached_wrapper.cache_clear

        return wrapper

    return decorator

In [None]:
#| export
def jaxit(
  mode: str='reverse', #the mode used to calculate the jacobian using `ceviche`
  argnums: List[int]=[0], #the argument indices this function should be differentiable against (a single int is also accepted)
  cache: bool = False #memoize evaluations of the wrapped function on its (array) arguments via `np_cache`
  ):
  """
    Make a function that internally uses autograd (e.g. `ceviche`) compatible with jax gradient calculations.

    The wrapped function becomes a `jax.custom_jvp`. Its tangent rule multiplies
    the `ceviche` jacobian with respect to each argument in `argnums` by that
    argument's flattened tangent and sums the contributions. Arguments that are
    not listed in `argnums` are treated as non-differentiable (their tangents
    are ignored).

    Attention: only a single output variable is supported (a scalar or one
    array, not a tuple/pytree of outputs).
  """
  # Normalize once: accept a bare int like jax does, and copy so the shared
  # default list is never aliased by the closures below.
  argnums = (argnums,) if isinstance(argnums, int) else tuple(argnums)
  caching_decorator = np_cache() if cache else (lambda fn: fn)

  def inner(function):
    # One ceviche jacobian function per entry of `argnums`, in the same order.
    grad_fns = [jacobian(function, mode=mode, argnum=i) for i in argnums]

    @jax.custom_jvp
    @caching_decorator
    def jaxed(*args):
      return function(*args)

    # Note: `defjvp` registers the function it receives, so a decorator placed
    # above it would never be used by jax; the jvp rule is left uncached.
    @jaxed.defjvp
    def jaxed_jvp(primals, tangents):
      # Go through `jaxed` so the (optionally cached) primal evaluation is reused.
      primals_out = jaxed(*primals)
      as_np = [np.asarray(prim) for prim in primals]
      out_shape = jnp.shape(primals_out)
      tangent_out = jnp.zeros(out_shape)
      # Pair each jacobian function with its own argument index; `grad_fns` is
      # positional, so indexing it by the argument index breaks e.g. argnums=[1].
      for grad_fn, i in zip(grad_fns, argnums):
        jac = jnp.asarray(grad_fn(*as_np))
        contribution = jnp.dot(jac, jnp.ravel(tangents[i]))
        # JAX requires the tangent output to match the primal output's shape.
        tangent_out = tangent_out + jnp.reshape(contribution, out_shape)
      return primals_out, tangent_out
    return jaxed
  return inner


In [None]:
#| hide
import nbdev

nbdev.nbdev_export()  # nbdev's export step: writes the `#| export` cells to the library module