In [469]:
import jax
import jax.numpy as np
import jax.random as jr
from jax import tree_util

import inspect
from functools import wraps, partial

## Auto Batch Scratch Notebook

The `@auto_batch` decorator adds a keyword-only `_batched` argument to a function. By default `_batched=False`, and the function runs exactly as written. Passing `_batched=True` runs the function under `jax.vmap` instead. In that mode it maps over the leading axis of the arguments named in the decorator's `batched_args`, and all other arguments are shared across the batch.

This lets us switch between batched and unbatched behavior with minimal syntax. Open question: can we detect batching automatically instead of passing `_batched` explicitly?

In [485]:
def auto_batch(batched_args=("data", "posterior", "covariates", "metadata", "states")):
    """Decorator factory that adds a ``_batched`` switch to a function.

    The decorated function gains a keyword-only argument ``_batched``
    (default ``False``). With ``_batched=False`` it runs exactly as written.
    With ``_batched=True`` it runs under ``jax.vmap``. Every argument whose
    name is listed in ``batched_args`` is passed by keyword, so it is mapped
    over its leading axis. All remaining arguments are bound with
    ``functools.partial`` and shared across the batch.

    Args:
        batched_args: A parameter name, or an iterable of parameter names,
            to map over when ``_batched=True``. A single string is treated
            as one name. Names may also refer to extra keyword arguments
            collected by a ``**kwargs`` parameter of the decorated function.

    Returns:
        A decorator. The wrapped function exposes the batched names as the
        tuple attribute ``batched_args``.

    Raises:
        TypeError: When decorating a function that has positional-only or
            ``*args`` parameters, which cannot be forwarded by keyword.
    """
    # Normalize to a tuple of names. A bare string would otherwise be used
    # for substring membership tests ("x" in "xy" is True), silently
    # batching the wrong arguments.
    if isinstance(batched_args, str):
        batched_args = (batched_args,)
    batched_args = tuple(batched_args)

    def auto_batch_decorator(f):
        sig = inspect.signature(f)
        params = sig.parameters.values()

        # Every argument is forwarded by keyword (vmap maps keyword arguments
        # over axis 0). That is impossible for these parameter kinds, so fail
        # early with a clear message instead of at call time.
        unsupported = [p.name for p in params
                       if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                                     inspect.Parameter.VAR_POSITIONAL)]
        if unsupported:
            raise TypeError(
                f"auto_batch cannot forward positional-only or *args "
                f"parameters {unsupported} of {f.__name__!r} by keyword")

        # Name of the **kwargs parameter, if any. Its bound value is a dict
        # that must be expanded back into individual keywords.
        var_keyword = next((p.name for p in params
                            if p.kind is inspect.Parameter.VAR_KEYWORD), None)

        @wraps(f)
        def wrapper(*args, _batched=False, **kwargs):  # augment the function signature with _batched kwarg
            # Resolve positional, keyword and default arguments to parameter names.
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # Flatten into plain keyword arguments, expanding **kwargs.
            call_kwargs = {}
            for name, val in bound_args.arguments.items():
                if name == var_keyword:
                    call_kwargs.update(val)
                else:
                    call_kwargs[name] = val

            fixed_args = {name: val for name, val in call_kwargs.items()
                          if name not in batched_args}
            batch_args = {name: val for name, val in call_kwargs.items()
                          if name in batched_args}

            f_partial = partial(f, **fixed_args)  # TODO: consider jax.tree_util.Partial?

            # TODO: can we automatically batch (as the decorator name suggests)?
            # NOTE: can't use jax.lax.cond here: both branches would have to
            # accept the same arguments, but one expects batched inputs and
            # the other unbatched ones.
            if _batched:
                print("vmapping")
                return jax.vmap(f_partial)(**batch_args)
            else:
                print("atomic")
                return f_partial(**batch_args)

        # store this info inside the function object
        wrapper.batched_args = batched_args

        return wrapper
    return auto_batch_decorator

In [479]:
# let's see it in action: batch over x and y, share z across the batch

@auto_batch(batched_args=("x", "y"))
def f(x, y, z):
    """Return ``np.dot(x, y)`` offset by ``z``."""
    inner = np.dot(x, y)
    return inner + z

In [480]:
# check whether f has auto batching functionality: the decorator records
# the batched argument names on the wrapper as `batched_args`

print(hasattr(f, "batched_args"), f.batched_args, sep="\n")

True
('x', 'y')


In [481]:
# atomic (unbatched) args: single examples, equal to one row of the
# batched arrays built in the next cell. They are defined directly here so
# the notebook survives Restart & Run All (no dependency on later cells).

x = np.ones((10,))
y = np.zeros((10,))
z = np.ones((10,))

print(f(x, y, z).shape)

atomic
(10,)


In [482]:
# batched args: x and y carry a leading batch axis of size 32, while z
# (from the previous cell) is not in `batched_args` and is shared by every row

y_batched = np.zeros((32, 10))
x_batched = np.ones((32, 10))

print(f(x=x_batched, y=y_batched, z=z, _batched=True).shape)

vmapping
(32, 10)


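#### What about `jit`?

`_batched` selects a Python-level branch, so it must be fixed before tracing (here with `functools.partial`). The `print` calls are Python side effects, so they appear only once, while `jit` traces the function.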

In [483]:
# fix `_batched` before tracing; the "atomic" print is a Python side effect,
# so it shows up only once (during tracing) across the repeated calls
jitted_f = jax.jit(partial(f, _batched=False))

for call_idx in range(3):
    jitted_f(x, y, z)

atomic


In [484]:
# same experiment in vmapped mode: "vmapping" is printed only while tracing
jitted_f_batched = jax.jit(partial(f, _batched=True))

for batched_call_idx in range(3):
    jitted_f_batched(x_batched, y_batched, z)

vmapping
