# compile_from_dict.ipynb
***example of compiling a tensorflow-free model from a dictionary saved with `tf_to_dict`, using either numpy or jax***
___


## imports

In [None]:
import json
import numpy as np
import jax.numpy as jnp

from compile_from_dict import numpy_compile, jax_compile

## load in the dictionary made using `tf_to_dict`

In [None]:
model_name = 'pitchfork'
with open(f'models/{model_name}.json', 'r') as fp:
    model_dict = json.load(fp)
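json can only store plain python types (dicts, lists, numbers, strings), so any weights in a json file come back as nested lists rather than arrays. the sketch below shows that round trip on a tiny dictionary. its layout (a `"layers"` list with `"weights"`, `"bias"` and `"activation"` keys) is a hypothetical one made up for illustration, not necessarily the schema `tf_to_dict` actually writes:

```python
import json
import numpy as np

# hypothetical layout: a list of dense layers, each with a weight matrix,
# a bias vector and an activation name (NOT necessarily tf_to_dict's schema)
toy_dict = {
    "layers": [
        {"weights": np.ones((5, 4)).tolist(), "bias": [0.0] * 4, "activation": "relu"},
        {"weights": np.ones((4, 3)).tolist(), "bias": [0.0] * 3, "activation": "linear"},
    ]
}

# json stores the arrays as nested lists of floats
text = json.dumps(toy_dict)
loaded = json.loads(text)

# convert the nested lists back into arrays before doing any maths with them
weights = [np.asarray(layer["weights"]) for layer in loaded["layers"]]
print([w.shape for w in weights])  # [(5, 4), (4, 3)]
```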

## numpy_compile
let's take a look at how we use `numpy_compile` to interpret the `model_dict` dictionary made using `tf_to_dict`:

In [None]:
numpy_model = numpy_compile(model_dict)

done! `numpy_model` is now an object containing a flow of functions, written purely in numpy, that together represent a forward pass through the network
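to make "a flow of functions" concrete, here's a minimal sketch of what such an object *could* look like for a plain stack of dense layers: each layer becomes a closure, and `forward_pass` feeds the output of one into the next. `ToyNumpyModel`, `make_dense` and the activation table are hypothetical names for illustration only, not the internals of `numpy_compile`:

```python
import numpy as np

# hypothetical activation lookup; the real numpy_compile may support others
ACTIVATIONS = {
    "relu": lambda x: np.maximum(x, 0.0),
    "linear": lambda x: x,
}

def make_dense(weights, bias, activation):
    """Return a closure computing activation(x @ W + b) for one dense layer."""
    W = np.asarray(weights)
    b = np.asarray(bias)
    act = ACTIVATIONS[activation]

    def layer(x):
        return act(x @ W + b)

    return layer

class ToyNumpyModel:
    """Hypothetical stand-in for numpy_model: an ordered list of layer functions."""

    def __init__(self, layer_dicts):
        self.flow = [make_dense(**d) for d in layer_dicts]

    def forward_pass(self, x):
        # feed the output of each function into the next one
        for layer in self.flow:
            x = layer(x)
        return x

toy = ToyNumpyModel([
    {"weights": np.full((5, 4), 0.1), "bias": np.zeros(4), "activation": "relu"},
    {"weights": np.full((4, 3), 0.1), "bias": np.zeros(3), "activation": "linear"},
])
out = toy.forward_pass(np.array([[0.5, 0.5, 0.5, 0.5, 0.5]]))
print(out)
```

building closures once, up front, means the per-call work is just the numpy maths, with no dictionary lookups inside the forward pass.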

the example network, `pitchfork`, takes 5 inputs and predicts 3 outputs on one branch and 38 on the other. let's define a test point of arbitrary values and check that this is what happens:

In [None]:
numpy_single_point = np.array([[0.5,0.5,0.5,0.5,0.5]])

In [None]:
numpy_model.forward_pass(numpy_single_point)

nice! that seemed fast, but let's time it:

In [None]:
%%time
numpy_model.forward_pass(numpy_single_point)
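to see what "two branches" means for the shapes coming out of a forward pass, here's a tiny stand-alone numpy network with the same input and output sizes as `pitchfork`. the shared hidden layer of 16 units, the `tanh` activation and the random weights are all assumptions made for illustration, and the real `forward_pass` may return its branches in a different container:

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical sizes: 5 inputs -> shared hidden layer of 16 units,
# then two heads with 3 and 38 outputs (the hidden size is an assumption)
W_shared, b_shared = rng.normal(size=(5, 16)), np.zeros(16)
W_a, b_a = rng.normal(size=(16, 3)), np.zeros(3)
W_b, b_b = rng.normal(size=(16, 38)), np.zeros(38)

def toy_pitchfork(x):
    h = np.tanh(x @ W_shared + b_shared)  # shared trunk
    return h @ W_a + b_a, h @ W_b + b_b   # one output array per branch

out_a, out_b = toy_pitchfork(np.array([[0.5, 0.5, 0.5, 0.5, 0.5]]))
print(out_a.shape, out_b.shape)  # (1, 3) (1, 38)
```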

we typically want neural networks to predict on huge batches at once rather than just single points, though.

let's check that this functionality isn't lost when compiling from our dict, and time it:

In [None]:
numpy_many_points = np.full((100000,5), 0.5)

In [None]:
%%time
numpy_model.forward_pass(numpy_many_points)

also pretty fast!
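batching comes for free because a dense layer is just a matrix multiply plus a bias: an `(n, 5)` input times a `(5, k)` weight matrix gives `(n, k)`, and the bias broadcasts across all `n` rows. a minimal sketch with a single hypothetical layer of random weights:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(5, 3))  # hypothetical dense layer: 5 inputs -> 3 outputs
b = rng.normal(size=3)

def dense(x):
    # x has shape (batch, 5); the matmul maps it to (batch, 3) and b
    # broadcasts across the batch dimension
    return x @ W + b

single = np.array([[0.5, 0.5, 0.5, 0.5, 0.5]])
batch = np.full((100000, 5), 0.5)

out_single = dense(single)
out_batch = dense(batch)
print(out_single.shape, out_batch.shape)  # (1, 3) (100000, 3)

# every row of the batch matches the single-point result
print(np.allclose(out_batch, out_single))  # True
```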

however, we can often make this faster by using jax (which utilises the GPU where possible), and faster still if we then JIT compile the jax `forward_pass` function!
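whether jax can actually use a GPU depends on how it was installed. a quick way to check which backend and devices it sees on your machine:

```python
import jax

# the backend jax places arrays on by default, e.g. 'cpu', 'gpu' or 'tpu'
backend = jax.default_backend()
print(backend)

# every device jax can see on this machine
devices = jax.devices()
print(devices)
```

if this prints `cpu`, the jax cells below run on the CPU too, so any speed-up over numpy will come mostly from JIT compilation rather than from the hardware.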

## jax_compile
this time, let's try the same model but compiled entirely in jax - same process as before:

In [None]:
jax_model = jax_compile(model_dict)

the `jax_model` object is written entirely in jax.numpy, so we want to be careful to pass in jax arrays only: jax will accept numpy arrays, but it then has to convert them (and copy them onto the GPU, if there is one) on every call, which wastes valuable time!
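if your data starts out as numpy, one way to avoid paying that conversion on every call is to convert it once, up front. a minimal sketch using `jax.device_put` (or equivalently `jnp.asarray`); note that unless 64-bit mode is enabled, jax stores float64 input as float32:

```python
import numpy as np
import jax
import jax.numpy as jnp

numpy_points = np.full((100000, 5), 0.5)

# convert once, up front: device_put copies the array onto jax's default device
jax_points = jax.device_put(numpy_points)
print(type(jax_points), jax_points.dtype)

# jnp.asarray performs the same conversion
also_jax = jnp.asarray(numpy_points)
print(isinstance(also_jax, jax.Array))  # True
```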

let's define some test points like before:

In [None]:
jax_single_point = jnp.array([[0.5,0.5,0.5,0.5,0.5]])
jax_many_points = jnp.full((100000,5), 0.5)

and then we can perform and time a forward pass for one point:

In [None]:
%%time
jax_model.forward_pass(jax_single_point)

or for a batch of points:

In [None]:
%%time
jax_model.forward_pass(jax_many_points)

this may or may not be faster than the numpy version, depending on your machine (for example, whether jax has a GPU to run on).
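one thing to watch when timing jax: it dispatches work asynchronously, so a call can return before the result has actually been computed, and `%%time` may then under-report. calling `.block_until_ready()` on the output waits for the computation to finish. a sketch on a stand-alone toy function (`toy_forward` is a hypothetical stand-in, not `jax_model.forward_pass`):

```python
import time
import jax.numpy as jnp

x = jnp.full((1000, 1000), 0.5)

def toy_forward(x):
    # stand-in for a forward pass: a couple of matmuls and an activation
    return jnp.tanh(x @ x) @ x

toy_forward(x).block_until_ready()  # run once so one-off start-up costs aren't timed

start = time.perf_counter()
out = toy_forward(x)        # may return as soon as the work is dispatched
dispatched = time.perf_counter() - start
out.block_until_ready()     # wait for the result to actually be computed
finished = time.perf_counter() - start

print(f"dispatch: {dispatched:.4f}s, finished: {finished:.4f}s")
```

if `forward_pass` returns a single array, appending `.block_until_ready()` to the timed line in the cells above gives a fairer number; if it returns several arrays (one per branch), recent jax versions also provide `jax.block_until_ready`, which accepts a tuple or list of arrays.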

one way we can usually speed this up is by JIT compiling the flow of functions in `jax_model`:

## jax_compile.jit_forward_pass
as before, we compile our model from the dictionary:

In [None]:
jax_model = jax_compile(model_dict)

and define some jax-friendly test points:

In [None]:
jax_single_point = jnp.array([[0.5,0.5,0.5,0.5,0.5]])
jax_many_points = jnp.full((100000,5), 0.5)

now let's try using the JIT compiled version of `forward_pass`, and time it:

In [None]:
%%time
jax_model.jit_forward_pass(jax_single_point)

oh dear! this is (probably) slower than your `numpy_model.forward_pass(numpy_single_point)` or `jax_model.forward_pass(jax_single_point)` cells!

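actually, this is entirely expected because of the way JIT compilation works: the first time the jitted function sees a **specific input shape** (and dtype), jax traces and compiles the whole flow of functions for that shape, and this one-off compilation overhead dominates the timing. later calls with the same shape reuse the cached compiled version.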
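we can watch this per-shape compilation happen. a plain python side effect inside a jitted function only runs while jax traces it, so counting those runs counts compilations. `toy_forward` here is a hypothetical stand-in; this is not necessarily how `jit_forward_pass` is built internally:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def toy_forward(x):
    global trace_count
    # this python line runs only while jax traces (i.e. compiles) the function,
    # not when an already-compiled version is reused
    trace_count += 1
    return jnp.tanh(x * 2.0).sum(axis=1)

jit_toy_forward = jax.jit(toy_forward)

single = jnp.full((1, 5), 0.5)
batch = jnp.full((1000, 5), 0.5)

jit_toy_forward(single)  # new shape (1, 5): traced and compiled
jit_toy_forward(single)  # same shape: cached version reused
jit_toy_forward(batch)   # new shape (1000, 5): traced and compiled again
jit_toy_forward(batch)   # same shape: cached version reused
print(trace_count)  # 2
```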

***this is an important point: JIT compiling might not help us (and may in fact slow things down) if our batch sizes change dynamically, since every new batch size triggers a fresh compilation!***
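one common workaround (not something this notebook itself does) is to pad every batch up to a fixed size, so the jitted function only ever sees one shape and compiles once. a minimal sketch, where `BUCKET`, `toy_forward` and `predict_padded` are hypothetical names and the padding rows are simply zeros that get trimmed off afterwards:

```python
import jax
import jax.numpy as jnp

BUCKET = 1024  # hypothetical fixed batch size that we compile for once
trace_count = 0

@jax.jit
def toy_forward(x):
    global trace_count
    trace_count += 1  # runs only when jax traces, i.e. compiles, the function
    return jnp.tanh(x * 2.0).sum(axis=1)

def predict_padded(x):
    """Pad an (n, 5) batch to a multiple of BUCKET rows, run in chunks, then trim."""
    n = x.shape[0]
    padded_n = -(-n // BUCKET) * BUCKET                  # round n up to a multiple of BUCKET
    x_padded = jnp.pad(x, ((0, padded_n - n), (0, 0)))  # zero rows appended at the end
    outs = [toy_forward(x_padded[i:i + BUCKET]) for i in range(0, padded_n, BUCKET)]
    return jnp.concatenate(outs)[:n]                     # drop the padding rows

shapes = {}
for n in (3, 700, 2500):  # three different batch sizes...
    x = jnp.full((n, 5), 0.5, dtype=jnp.float32)
    shapes[n] = predict_padded(x).shape

print(shapes)       # {3: (3,), 700: (700,), 2500: (2500,)}
print(trace_count)  # 1  (...but only one compilation)
```

the trade-off is some wasted work on the padding rows, which is usually cheap compared with recompiling for every new batch size.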

let's see whether we do better now that we've compiled the single point pass:

In [None]:
%%time
jax_model.jit_forward_pass(jax_single_point)

nice!

what about the batch of many points? this is a new input shape, so we'll run and time one cell to trigger compilation:

In [None]:
%%time
jax_model.jit_forward_pass(jax_many_points)

and then again to time the compiled version:

In [None]:
%%time
jax_model.jit_forward_pass(jax_many_points)

this should be much faster than the compiling call, and usually faster than the un-jitted `forward_pass` too!
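if you know your batch shape in advance, jax's ahead-of-time API can also do the compilation up front, before any timed or latency-sensitive call, instead of paying for it on the first call. a sketch on a hypothetical stand-in function; the compiled object only accepts inputs of the exact shape and dtype it was compiled for, and raises an error otherwise:

```python
import jax
import jax.numpy as jnp

def toy_forward(x):
    # stand-in for the real forward pass
    return jnp.tanh(x * 2.0).sum(axis=1)

batch = jnp.full((100000, 5), 0.5, dtype=jnp.float32)

# compile ahead of time for exactly this shape and dtype
compiled_forward = jax.jit(toy_forward).lower(batch).compile()

out = compiled_forward(batch)  # no tracing or compilation happens here
out.block_until_ready()
print(out.shape)  # (100000,)
```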