In [1]:
import jax 
from jax.config import config as jax_config
jax_config.update("jax_enable_x64", True)
import jax.numpy as np
from jax import make_jaxpr
import numpy as onp
import objax
from objax.zoo.dnnet import DNNet
from objax.functional import tanh
from objax.functional.loss import mean_squared_error
from timeit import default_timer as timer
import batchjax

# Introduction

In this notebook we showcase an example of `batchjax` and its three usecases/modes which we will refer to as:

    - loop
    - objax
    - batched
   
This examples revolves around wanting to train multiple, independent, neural neworks simultaneously. This is a simple example only to demonstrate `batchjax`.
    

# Data generation functions

In [2]:
def get_data(N, seed):
    """ Generates a noisy sin curve with N observations. """
    onp.random.seed(seed)
    x = onp.linspace(0, 1, N)
    X = x[:, None]

    # Construct output with random input shift and additive Gaussian noise
    y = onp.sin((x+onp.random.randn(1))*10) + 0.01*onp.random.randn(N)
    Y = 0.8*y[:, None]

    return X, X, Y


# Neural Network Objects

In [3]:
class NN(objax.Module):
    """ Simple fully connected Neural Network model wrapper """
    def __init__(self, X, Y, layer_size):
        self.model = DNNet(layer_sizes=layer_size, activation=tanh)
        self.X = objax.StateVar(np.array(X))
        self.Y = objax.StateVar(np.array(Y))

    def objective(self):
        return mean_squared_error(
            self.model(self.X.value),
            self.Y.value,
            keep_axis=None
        )

    def predict(self, XS):
        return self.model(XS)

class NNList(objax.Module):
    """ Wrapper around a NN to add suport for multiple Neural networks. """
    def __init__(self, m_list: list, batch_type):
        self.P = len(m_list)

        if batch_type == batchjax.BatchType.BATCHED:
            self.m_list = batchjax.Batched(m_list)
        else:
            self.m_list = objax.ModuleList(m_list)

        self.batch_type = batch_type

    def objective(self):
        # Use batchjax to batch across each neural network
        obj_arr = batchjax.batch_or_loop(
            lambda x: x.objective(),
            inputs = [self.m_list],
            axes=[0],
            dim = self.P,
            out_dim = 1,
            batch_type = self.batch_type
        )

        return np.sum(obj_arr)

# Demonstration of different modes

In [4]:
def train(m):
    # Train
    start = timer()
    onp.random.seed(0)
    opt = objax.optimizer.Adam(m.vars())
    lr = 1e-3
    epochs = 500
    gv = objax.GradValues(m.objective, m.vars())


    breakpoint()
    @objax.Function.with_vars(m.vars() + gv.vars() + opt.vars())
    def train_op():
        g, v = gv()  # returns gradients, loss
        opt(lr, g)
        return v


    train_op = objax.Jit(train_op)  # Compile train_op to make it run faster.


    loss_arr = []
    for i in range(epochs):
        v = train_op()
        loss_arr.append(v)

    end = timer()
    
    time_taken = end-start
    final_loss = loss_arr[-1]
    
    return time_taken, final_loss

In [5]:
num_models = 50
data = [
    get_data(200, p) for p in range(num_models)
]

## Loop mode

In [6]:
def loop_mode_demonstration(P):
    """ P: the number of indepdent neural networks """
    
    # Construct all independent neural networks
    model_list = [
        NN(data[p][1], data[p][2], [1, 128, 1]) for p in range(P)
    ]   
    
    # List Wrapper
    m = NNList(model_list, batchjax.BatchType.LOOP)
    
    time_taken, final_loss = train(m)
    
    print('Time taken: ', time_taken)
    print('Final loss: ', final_loss)
    
    
loop_mode_demonstration(num_models)



Time taken:  40.210479251
Final loss:  [DeviceArray(14.65966419, dtype=float64)]


## Objax Mode

In [7]:
def objax_mode_demonstration(P):
    """ P: the number of indepdent neural networks """
    
    # Construct all independent neural networks
    model_list = [
        NN(data[p][1], data[p][2], [1, 128, 1]) for p in range(P)
    ]   
    
    # List Wrapper
    m = NNList(model_list, batchjax.BatchType.OBJAX)
    
    time_taken, final_loss = train(m)
    
    print('Time taken: ', time_taken)
    print('Final loss: ', final_loss)
    
    
objax_mode_demonstration(num_models)

Time taken:  19.850679743
Final loss:  [DeviceArray(14.65667522, dtype=float64)]


## Batched mode

In [8]:
def batched_mode_demonstration(P):
    """ P: the number of indepdent neural networks """
    
    # Construct all independent neural networks
    model_list = [
        NN(data[p][1], data[p][2], [1, 128, 1]) for p in range(P)
    ]   
    
    # List Wrapper
    m = NNList(model_list, batchjax.BatchType.BATCHED)
    
    time_taken, final_loss = train(m)
    
    print('Time taken: ', time_taken)
    print('Final loss: ', final_loss)
    
    
batched_mode_demonstration(num_models)

Time taken:  4.074664240000004
Final loss:  [DeviceArray(14.49094453, dtype=float64)]


# Understanding the difference between the modes

To see the different modes result in different run times we can look at the compiled HLO code. 

In loop mode a native python loop is use to iterate over every neural network object, hence the computational graph contains the operations for each independent neural network.

In objax mode each neural network is effectively stacked into single object before batching and then unpacked after.

In batched mode, the objax.ModuleList is replace by a Batched objax, which effectively 'pre-stacks' the objects into a single one. This removes a lot of the broadcasting that is required in objax mode HOWEVER this does change computational graph, as the individual neural networks are replaced by a new object with stacked variables and hence should only be used when fully understood.


In [9]:
model_list = [
    NN(data[p][1], data[p][2], [1, 128, 1]) for p in range(2)
]   

## Looped mode HLO code

Below is the HLO generated code. You can see that the same code is repeated for each neural nework.

In [10]:
m_looped = NNList(model_list, batchjax.BatchType.BATCHED)
make_jaxpr(m_looped.objective)()

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f64[2,128][39m b[35m:f64[2,1,128][39m c[35m:f64[2,1][39m d[35m:f64[2,128,1][39m e[35m:f64[2,200,1][39m f[35m:f64[2,200,1][39m; . [34m[22m[1mlet
    [39m[22m[22mg[35m:f64[2,128][39m = copy a
    h[35m:f64[2,1,128][39m = copy b
    i[35m:f64[2,1][39m = copy c
    j[35m:f64[2,128,1][39m = copy d
    k[35m:f64[2,200,1][39m = copy e
    l[35m:f64[2,200,1][39m = copy f
    m[35m:f64[2,200,128][39m = dot_general[
      dimension_numbers=(((2,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] k h
    n[35m:f64[2,1,128][39m = broadcast_in_dim[
      broadcast_dimensions=(0, 2)
      shape=(2, 1, 128)
    ] g
    o[35m:f64[2,200,128][39m = add m n
    p[35m:f64[2,200,128][39m = tanh o
    q[35m:f64[2,200,1][39m = dot_general[
      dimension_numbers=(((2,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] p j
    r[35m:f64[2,1,1][39m = broadcast

## Objax mode HLO code

Below is the HLO generated code. After a lot of broadcasting you can see that the objective code is only repeated once.

In [11]:
m_objax = NNList(model_list, batchjax.BatchType.OBJAX)
make_jaxpr(m_objax.objective)()

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f64[128][39m b[35m:f64[128][39m c[35m:f64[1,128][39m d[35m:f64[1,128][39m e[35m:f64[1][39m f[35m:f64[1][39m g[35m:f64[128,1][39m
    h[35m:f64[128,1][39m i[35m:f64[200,1][39m j[35m:f64[200,1][39m k[35m:f64[200,1][39m l[35m:f64[200,1][39m; . [34m[22m[1mlet
    [39m[22m[22mm[35m:f64[1,128][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 128)] a
    n[35m:f64[1,128][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 128)] b
    o[35m:f64[2,128][39m = concatenate[dimension=0] m n
    p[35m:f64[1,1,128][39m = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 128)
    ] c
    q[35m:f64[1,1,128][39m = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 1, 128)
    ] d
    r[35m:f64[2,1,128][39m = concatenate[dimension=0] p q
    s[35m:f64[1,1][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 1)] e
    t[35m:f64[1,1][39m = broadcast_in_dim[

## Batched mode HLO code

Below is the HLO generated code. Similarily to objax mode the code to compute the objective function is only repeated once however there is now much less broadcasting.

In [12]:
m_batched = NNList(model_list, batchjax.BatchType.BATCHED)
make_jaxpr(m_batched.objective)()

{ [34m[22m[1mlambda [39m[22m[22ma[35m:f64[2,128][39m b[35m:f64[2,1,128][39m c[35m:f64[2,1][39m d[35m:f64[2,128,1][39m e[35m:f64[2,200,1][39m f[35m:f64[2,200,1][39m; . [34m[22m[1mlet
    [39m[22m[22mg[35m:f64[2,128][39m = copy a
    h[35m:f64[2,1,128][39m = copy b
    i[35m:f64[2,1][39m = copy c
    j[35m:f64[2,128,1][39m = copy d
    k[35m:f64[2,200,1][39m = copy e
    l[35m:f64[2,200,1][39m = copy f
    m[35m:f64[2,200,128][39m = dot_general[
      dimension_numbers=(((2,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] k h
    n[35m:f64[2,1,128][39m = broadcast_in_dim[
      broadcast_dimensions=(0, 2)
      shape=(2, 1, 128)
    ] g
    o[35m:f64[2,200,128][39m = add m n
    p[35m:f64[2,200,128][39m = tanh o
    q[35m:f64[2,200,1][39m = dot_general[
      dimension_numbers=(((2,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] p j
    r[35m:f64[2,1,1][39m = broadcast