# Lecture 2: JIT compilation and Physics Analysis

<br>

We saw how JAX has a numpy-like user API for performing array-based computing.

JAX can apply composable function transformations to our Python/NumPy code and use the XLA compiler to perform fast computations on CPU, GPU or TPU.

<br>

$$\text{Python} \rightarrow \text{Intermediate Representation} \rightarrow \text{Transformations}$$

<br>

<br>

In [3]:
# Let's try a simple example - we already saw this in the last lecture

import jax
import jax.numpy as jnp
from jax import jit
from jax import make_jaxpr

def fn(tuple_arr):
    
    sum_val = jnp.sum(tuple_arr**2 - tuple_arr**3 - tuple_arr)
    
    return sum_val


print(make_jaxpr(fn)(jnp.ones(100)))


{ lambda ; a:f32[100]. let
    b:f32[100] = integer_pow[y=2] a
    c:f32[100] = integer_pow[y=3] a
    d:f32[100] = sub b c
    e:f32[100] = sub d a
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }


<br>

<br>

<br>
<br>
<div>
<center>
<img src="figures/Py_IR_Trans.png" width="1000"/>
</center>
</div>
<br>
<br>

<br>

<br>

Examples of composable transformations include automatic differentiation (`grad`), JIT compilation (`jit`), automatic vectorization (`vmap`) and parallelization across multiple devices (`pmap`).

We explored automatic differentiation in the last lecture.

In this lecture, we will explore JIT compilation as well as find out a bit about how JAX works under the hood with these intermediate representations and transformations.
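Here is a minimal sketch of what "composable" means in practice. It uses a scalar version of the polynomial in `fn` above; the function `f` and the input values are illustrative choices, not from the lecture. `jax.grad` differentiates `f`, `jax.vmap` maps that derivative over a batch, and `jax.jit` compiles the combined result.

```python
import jax
import jax.numpy as jnp

def f(x):
    # scalar version of the polynomial used in fn above
    return x ** 2 - x ** 3 - x

df = jax.grad(f)               # derivative: 2x - 3x^2 - 1
df_batched = jax.vmap(df)      # evaluate the derivative for every element of a batch
df_fast = jax.jit(df_batched)  # compile the composed transformation with XLA

xs = jnp.array([0.0, 1.0, 2.0])
print(df_fast(xs))  # [-1. -2. -9.]
```

Each transformation takes a function and returns a new function, so they can be stacked in any order that makes sense for the problem.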

<br>

<br>

### JIT Compilation

<br>

JIT stands for Just-In-Time compilation, as opposed to AOT (Ahead-Of-Time) compilation.

As the name suggests, the code is compiled $\textit{just in time}$: when the function is first called with inputs of a given shape and type, rather than ahead of execution.


<br>

<br>

During JIT compilation, JAX traces the function into a sequence of primitive `lax` operations and hands them to XLA, which optimizes them (for example by fusing operations together) and generates efficient executable code for CPU, GPU or TPU.


<br>
<br>
<div>
<center>
<img src="figures/JAX_workflow.png" width="1000"/>
    </center>
</div>
<br>
<br>

<br>

<br>

Once a function has been JIT-compiled, JAX caches the resulting XLA code so that it can be re-used in subsequent calls. 

<br>
<br>

$$\textbf{This is where the power of JIT compilation comes in} \\ \textbf{ - after an initial compilation phase,} \\ \textbf{the subsequent calls to the JIT-compiled function are super fast!}$$ 

<br>

In [76]:
import jax
import jax.numpy as jnp
from jax import jit

def fn(tuple_arr):
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

%timeit -r1 -n1 fn(jnp.ones(100)).block_until_ready()

fn_compiled = jit(fn)
%timeit -r1 -n1 fn_compiled(jnp.ones(100)).block_until_ready()

1.47 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
56.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

As mentioned before, JIT compilation only pays off when we intend to call a function many times, for example when minimizing a loss function by repeatedly evaluating the loss and its gradient.

During the first call, JAX traces and compiles the function, which takes some time; hence the slower timing of the first JIT-compiled call above.
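The loss-minimization use case can be sketched with a toy straight-line fit. The model, synthetic data, learning rate and step count below are all illustrative assumptions. The point is that `jax.jit(jax.value_and_grad(loss))` is compiled on the first iteration, and every later iteration (inputs of unchanged shape and type) reuses the cached executable.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # mean squared error of a straight-line model y = w[0] * x + w[1]
    pred = w[0] * x + w[1]
    return jnp.mean((pred - y) ** 2)

# Compile loss and gradient together once; the loop then reuses the cached executable
loss_and_grad = jax.jit(jax.value_and_grad(loss))

x = jnp.linspace(0.0, 1.0, 50)
y = 3.0 * x + 0.5  # synthetic, noise-free data

w = jnp.zeros(2)
learning_rate = 0.5
for _ in range(500):
    value, grad = loss_and_grad(w, x, y)  # same shapes every call: no recompilation
    w = w - learning_rate * grad

print(w)  # approaches [3.0, 0.5]
```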

<br>

In [79]:
# Let's try again

%timeit -r1 -n1 fn(jnp.ones(100)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(100)).block_until_ready()

1.64 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
490 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

In [80]:
# Let's try again with a different sized array

%timeit -r1 -n1 fn(jnp.ones(1000)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(1000)).block_until_ready()

1.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
42.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

In [83]:
# Let's re-do the same computation with an array of the same shape as before

%timeit -r1 -n1 fn(jnp.ones(1000)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(1000)).block_until_ready()

2.45 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
560 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

To better understand what's going on, let's look at what JAX does under the hood.

<br>

In [85]:
import jax
import jax.numpy as jnp
from jax import jit

# Decorating with @jit is equivalent to writing fn = jit(fn)

@jit
def fn(tuple_arr):
    
    print(tuple_arr)   # Let's print out the input we passed to the function
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)


print(fn(jnp.ones(100)))


Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/1)>
-100.0


<br>

JAX first takes the shape and data type of the input and uses them to create an abstract object called `ShapedArray`, which has a specific data type and shape.

But it has no value!

<br>

In [86]:
print(jax.make_jaxpr(fn)(jnp.ones(100)))

Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=1/1)>
{ lambda ; a:f32[100]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[100]. let
          d:f32[100] = integer_pow[y=2] c
          e:f32[100] = integer_pow[y=3] c
          f:f32[100] = sub d e
          g:f32[100] = sub f c
          h:f32[] = reduce_sum[axes=(0,)] g
        in (h,) }
      name=fn
    ] a
  in (b,) }


<br>

Then, it passes this `ShapedArray` object through the primitive `lax` operations to build a computation graph for the function. This part of the process is known as $\textit{tracing}$.

Note that the `ShapedArray` object has a specific shape and type, `a:f32[100]`, but that is all it has: it doesn't store any values.
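One way to see that tracing only uses the shape and type, never the values, is the following sketch. It assumes the function has no value-dependent Python logic, which is true for `fn` here. Tracing it with `jnp.ones` and `jnp.zeros` of the same shape should give identical jaxprs, while a different shape gives a new one.

```python
import jax
import jax.numpy as jnp

def fn(tuple_arr):
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

# Two inputs with the same shape and dtype but different values...
jaxpr_ones = jax.make_jaxpr(fn)(jnp.ones(3))
jaxpr_zeros = jax.make_jaxpr(fn)(jnp.zeros(3))
# ...produce exactly the same trace, because only the abstract ShapedArray is traced
print(str(jaxpr_ones) == str(jaxpr_zeros))  # True

# A different shape gives a different trace (f32[5] instead of f32[3])
print(jax.make_jaxpr(fn)(jnp.ones(5)))
```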

<br>

Once JAX has built the computation graph of `lax` operations, it hands it to XLA, which optimizes and compiles it for the target device (CPU, GPU or TPU).

The compiled executable is then cached.

On repeated calls with inputs of the same type and shape, but any values, JAX reuses the cached executable, which gives very fast execution times.
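The caching rule can be observed directly with a small sketch. It deliberately uses a Python side effect, incrementing a hypothetical global counter `trace_count`, which only runs while JAX traces the function. This is exactly the kind of impure code discussed below under Pure Functions, and it is used here purely for illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often Python actually runs the function body

@jax.jit
def fn(tuple_arr):
    global trace_count
    trace_count += 1  # Python side effect: only executes while JAX is tracing
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

fn(jnp.ones(100))   # first call: trace + compile
fn(jnp.zeros(100))  # same shape/dtype, new values: cached executable, no trace
fn(jnp.ones(1000))  # new shape: trace + compile again
print(trace_count)  # 2
```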

<br>

<br>

#### Key Summary of JIT compilation

<br>

By default JAX executes operations one at a time, in sequence.

We saw an example of this in the last lecture with our purposely badly written function and how JAX traced it.

<br>

In [1]:
# Let's look at the computation graph JAX builds for our expression
import jax
import jax.numpy as jnp

def fnc_jax(x1, x2):
    
    return (jnp.divide(x1,x2) - jnp.exp(x2))*(jnp.sin(jnp.divide(x1,x2)) + jnp.divide(x1,x2) - jnp.exp(x2))

jax.make_jaxpr(fnc_jax)(1.0,1.0)

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = div a b
    d:f32[] = exp b
    e:f32[] = sub c d
    f:f32[] = div a b
    g:f32[] = sin f
    h:f32[] = div a b
    i:f32[] = add g h
    j:f32[] = exp b
    k:f32[] = sub i j
    l:f32[] = mul e k
  in (l,) }

<br>

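With just-in-time (JIT) compilation, the whole sequence of operations can be optimized together and run at once. For example, XLA can compute the repeated `div` and `exp` only once and fuse the remaining operations into a single fused computation.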

Let's see this in action!

<br>

In [2]:
# Here we use JAX's ahead-of-time (AOT) API to lower and compile the function, so we can inspect the optimized program - yes, JAX also supports AOT compilation when necessary

from jax import jit

print(jax.__version__)

print(jit(fnc_jax).lower(1., 1.).compile().as_text())

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


0.4.6
HloModule jit_fnc_jax, entry_computation_layout={(f32[],f32[])->f32[]}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.2: f32[], param_1.4: f32[]) -> f32[] {
  %param_0.2 = f32[] parameter(0)
  %param_1.4 = f32[] parameter(1)
  %exponential.0 = f32[] exponential(f32[] %param_1.4), metadata={op_name="jit(fnc_jax)/jit(main)/exp" source_file="/tmp/ipykernel_42962/377062428.py" source_line=8}
  %subtract.1 = f32[] subtract(f32[] %param_0.2, f32[] %exponential.0), metadata={op_name="jit(fnc_jax)/jit(main)/sub" source_file="/tmp/ipykernel_42962/377062428.py" source_line=8}
  %sine.0 = f32[] sine(f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/sin" source_file="/tmp/ipykernel_42962/377062428.py" source_line=8}
  %add.0 = f32[] add(f32[] %sine.0, f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/add" source_file="/tmp/ipykernel_42962/377062428.py" source_line=8}
  %subtract.0 = f32[] subtract(f32[] %add.0, f32[] %exponential.0), meta

<br>

Not all JAX code can be JIT compiled, because JIT compilation requires array shapes to be static and known at compile time.

As we saw before, JIT compilation traces a `ShapedArray` object with a fixed shape and type, so the shapes of all intermediate and output arrays must be determined by the input shapes alone, not by the input values.

<br>

In [4]:
@jit
def fn_bad(tuple_arr):
    
    sum_arr = tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr
    
    sum_arr = sum_arr[sum_arr>=0.0]          # Remove negative values from array
    
    return jnp.sum(sum_arr)       


# The shape of sum_arr[sum_arr >= 0.0] depends on the values in the array, so it is not known at compile time

print(fn_bad(jnp.ones(100)))

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[100])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
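One common workaround keeps the array shape fixed and zeros out the unwanted entries instead of removing them. The sketch below does this with `jnp.where`, a standard `jax.numpy` function, though not the only possible fix. The sum is then the same as summing only the non-negative values. The name `fn_masked` is illustrative.

```python
import jax.numpy as jnp
from jax import jit

@jit
def fn_masked(tuple_arr):
    sum_arr = tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr
    # Keep the array shape fixed: replace negative entries by 0 instead of dropping them
    sum_arr = jnp.where(sum_arr >= 0.0, sum_arr, 0.0)
    return jnp.sum(sum_arr)

print(fn_masked(jnp.ones(100)))                # 0.0, since every entry is -1
print(fn_masked(jnp.array([0.0, -1.0, 0.5])))  # 3.0
```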

<br>

<br>

### Things to Keep in Mind when using JAX and JIT

<br>

JAX provides the `jax.numpy` wrapper to mimic the more familiar NumPy interface for users.

Under the hood, however, JAX performs its computations using the more powerful, but stricter, `jax.lax` API.

We just explored this with some examples.

<br>

<br>

But `jax.numpy` cannot always be used as a drop-in replacement for `numpy`.

For example, unlike NumPy arrays, JAX arrays are always immutable.

<br>

In [98]:
import jax.numpy as jnp
import numpy as np

arr_jnp = jnp.array([1,2,3,4])

print(type(arr_jnp))

arr_np = np.array([1,2,3,4])

print(type(arr_np))

arr_np[1] = 0
print(arr_np)

arr_jnp[1] = 0



<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
[1 0 3 4]


TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

<br>

<br>

In [99]:
# For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:

arr_jnp = arr_jnp.at[1].set(0)  # .at[].set() returns an updated copy, which we re-assign
print(arr_jnp)


[1 0 3 4]
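Besides `set`, the `.at[]` property offers other update methods such as `add` and `multiply`, and it accepts slices as well as single indices. The sketch below uses illustrative values and shows that each call returns a new array, leaving the original untouched. According to the JAX documentation, inside JIT-compiled code XLA can often perform such updates in place when the original array is not used afterwards.

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])

set_arr = arr.at[1].set(0)         # element 1 replaced by 0
add_arr = arr.at[2].add(10)        # 10 added to element 2
mul_arr = arr.at[::2].multiply(2)  # elements 0 and 2 doubled

print(arr.tolist())      # [1, 2, 3, 4]  (the original is never modified)
print(set_arr.tolist())  # [1, 0, 3, 4]
print(add_arr.tolist())  # [1, 2, 13, 4]
print(mul_arr.tolist())  # [2, 2, 6, 4]
```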


<br>

<br>

### Pure Functions

JAX transformations, including JIT compilation, are designed to be used only on functionally pure Python functions.

What does this mean? A pure function always returns the same result when invoked with the same inputs, and has no side effects such as printing, or reading and modifying global variables.

<br>

In [2]:
from jax import jit
import jax.numpy as jnp

def fn(tuple_arr):
    
    print(tuple_arr)   # This is a side-effect
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)


# The side effect (print) appears during the first run, while JAX traces the function
print ("First call: ", jit(fn)(jnp.ones(100)))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
print ("Second call: ", jit(fn)(jnp.ones(100)))




Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/2)>
First call:  -100.0
Second call:  -100.0


<br>

In [12]:
# Global variables read inside the function are captured as constants when JIT traces and caches the computation

bias = 5.0

def fn(tuple_arr):
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)+bias

# JAX captures the value of the global during the first run
print ("First call: ", jit(fn)(jnp.ones(100)))

bias = bias + 10.0  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(fn)(jnp.ones(100)))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
print ("Third call, different type: ", jit(fn)(jnp.ones(100, dtype='int32')))

First call:  -95.0
Second call:  -95.0
Third call, different type:  -85.0
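Following the rule of thumb below, a sketch of the fix is to pass `bias` as an explicit argument instead of reading a global. Because `bias` is then a traced input, changing its value should not require recompilation, and the new value is always used.

```python
import jax.numpy as jnp
from jax import jit

@jit
def fn(tuple_arr, bias):
    # bias is an explicit input, so it is traced like tuple_arr instead of frozen as a constant
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr) + bias

x = jnp.ones(100)
print("bias = 5.0: ", fn(x, 5.0))   # -95.0
print("bias = 15.0:", fn(x, 15.0))  # -85.0, the new value is used without retracing
```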


<br>

$\textbf{General rule of thumb:}$ 

$$\textbf{All the input data should be passed through the function parameters,} \\ 
\textbf{all the results should be output through the function results.}$$

<br>

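This also avoids a related issue with large constants: arrays captured from outside the function are embedded in the compiled program as constants, and XLA may spend a long time constant-folding them during compilation.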

For more $\textit{sharp edges}$ to be careful of, read https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

<br>

<br>

## Physics Analysis

<br>

As an application, let us continue with the $H \rightarrow WW \rightarrow 2\ell2\nu$ analysis we explored in the ML introduction lectures.

<br>
<br>
<div>
<center>
<img src="figures/WW_Feynman.png" width="750"/>
    </center>
</div>
<br>
<br>

<br>

<br>

<img align="right" src="figures/Discriminant_Prelim.png" width="500"/>

<br>

We trained an ML discriminant to separate the signal process $H \rightarrow WW \rightarrow 2\ell2\nu$ from the various backgrounds.

<br>

We will use the binned histogram of this discriminant to perform statistical inference on the signal strength $\mu$, which scales the expected signal yield: $N_\text{tot} = \mu S + B$.

<br>

Using the discriminant, we can calculate the full likelihood as a function of the parameter of interest (POI) $\mu$.

<br>

We approximate the probability of observing $N_{obs}$ events in bin $i$ of the discriminant by a Poisson distribution whose mean $\nu$ is the MC expected yield in that bin.

<br>

This gives us an expression for the full likelihood, inclusive of all bins:

<br>

$$\mathcal{L} (\mu | \mathcal{D}_{obs}) = \prod_{i}^\text{bins} \frac{e^{-\nu (\mu)} \nu (\mu)^{N_{obs}}}{N_{obs} !}$$
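The Poisson likelihood above can be sketched in JAX as a negative log-likelihood, since minimizing $-\log \mathcal{L}$ is equivalent to maximizing $\mathcal{L}$. This sketch assumes the simple model $\nu = \mu S + B$ in each bin and uses `jax.scipy.stats.poisson.logpmf`. The three-bin histograms are made-up toy numbers, not the lecture's discriminant.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import poisson

def nll(mu, signal, background, observed):
    # expected events per bin: nu_i(mu) = mu * S_i + B_i
    nu = mu * signal + background
    # -log L(mu) = -sum_i log Poisson(N_obs,i | nu_i)
    return -jnp.sum(poisson.logpmf(observed, nu))

# Hypothetical toy histograms (3 bins), not the lecture's discriminant
signal = jnp.array([1.0, 2.0, 3.0])
background = jnp.array([10.0, 5.0, 2.0])
observed = jnp.array([11.0, 8.0, 4.0])

print(nll(1.0, signal, background, observed))
print(jax.grad(nll)(1.0, signal, background, observed))  # d(-log L)/d(mu)
```

Because the NLL is an ordinary JAX function, its derivative with respect to $\mu$ comes for free from `jax.grad`.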



<br>

<br>

Systematic uncertainties arise in physics analyses whenever the predictions of any of the simulation steps are themselves uncertain.

Uncertainties that arise in the detector simulation are normally categorised as experimental uncertainties. Uncertainty in any of the other processing steps is usually categorised as a theoretical uncertainty.

<br>
<br>
<br>
<div>
<center>
<img src="figures/Unc_chart.png" width="750"/>
    </center>
</div>
<br>
<br>
<br>

Theoretical and experimental uncertainties are grouped together as $\textit{systematic uncertainties}$.

<br>

<br>

The effect of systematic uncertainties is described in the probability model using so-called constrained nuisance parameters (NPs), denoted here by $\theta$:

$$\mathcal{L}_\text{phys} (\mu, \theta | \mathcal{D}_{obs}) = \prod_{i}^\text{bins} \frac{e^{-\nu (\mu, \theta)} \nu (\mu, \theta)^{N_{obs}} }{N_{obs} !}$$

<br>


The detailed probability model for the auxiliary measurement is approximated as a Gaussian, with central values and errors estimated using either experimental data or theoretical assumptions. 

The constraint term arising from the subsidiary measurement in the likelihood is written simply as (after some convenient normalizations):

$$\mathcal{L}_\text{constraint} (\theta) = \prod_{i}^{N_\text{syst}} \frac{1}{\sqrt{2\pi}} \exp \left(-\frac{\theta_i^2}{2}\right)$$

The full $\theta$-parametrized likelihood is then:


$$\mathcal{L} (\mu, \theta | \mathcal{D}_{obs}) = \mathcal{L}_\text{phys} (\mu, \theta | \mathcal{D}_{obs}) \cdot \mathcal{L}_\text{constraint} (\theta)$$
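The following is a sketch of the full negative log-likelihood $-\log \mathcal{L}_\text{phys} - \log \mathcal{L}_\text{constraint}$ for a single nuisance parameter. The text above does not specify how $\nu$ depends on $\theta$, so this example assumes a hypothetical 10% normalisation uncertainty on the background, $\nu = \mu S + B(1 + 0.1\,\theta)$, with toy histograms. `jax.scipy.stats.norm.logpdf` supplies the unit-Gaussian constraint.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

# Hypothetical toy inputs (not the lecture's histograms)
signal = jnp.array([1.0, 2.0, 3.0])
background = jnp.array([10.0, 5.0, 2.0])
observed = jnp.array([11.0, 8.0, 4.0])
bkg_unc = 0.1  # assumed 10% uncertainty on the background normalisation

def full_nll(params):
    mu, theta = params[0], params[1]
    # assumed model: theta = +1 shifts the background up by one standard deviation
    nu = mu * signal + background * (1.0 + bkg_unc * theta)
    nll_phys = -jnp.sum(poisson.logpmf(observed, nu))  # -log L_phys
    nll_constraint = -norm.logpdf(theta)               # -log L_constraint, unit Gaussian
    return nll_phys + nll_constraint

# One compiled function returning -log L and its gradient w.r.t. (mu, theta)
full_nll_and_grad = jax.jit(jax.value_and_grad(full_nll))

value, grad = full_nll_and_grad(jnp.array([1.0, 0.0]))
print(value, grad)
```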

<br>

<br>

In order to do a hypothesis test using the likelihoods, we must first define a test statistic $t_\mu$. 

In a typical physics analysis, the profile log-likelihood ratio between the hypothesis being tested (fixed $\mu$) and the best-fit hypothesis is chosen as the test statistic.

<br>

$$t_\mu = -2 \log \frac{\mathcal{L} (\mu, \hat{\hat{\theta}} | \mathcal{D}_{obs})}{\mathcal{L} (\hat{\mu}, \hat{\theta} | \mathcal{D}_{obs})}$$

<br>

where $\hat{\mu}$, $\hat{\theta}$ are the global best-fit values of the parameters $\mu$ and $\theta$ that maximize the likelihood:

$$\hat{\mu}, \hat{\theta} = \underset{\mu, \theta}{\operatorname{argmax}}\mathcal{L} (\mu, \theta | \mathcal{D}_{obs})$$

<br>

and $\hat{\hat{\theta}}$ is the local best-fit value of $\theta$ for a fixed value of parameter $\mu$ being scanned:

$$\hat{\hat{\theta}} = \underset{\theta}{\operatorname{argmax}}\mathcal{L} (\mu_\text{fix}, \theta | \mathcal{D}_{obs})$$

<br>

The best-fit value of the parameter of interest (POI) $\mu$ given an observed dataset can be calculated by minimizing the profile log-likelihood ratio test statistic:

$$\mu_\text{MLE} = \underset{\mu}{\operatorname{argmin}} t_\mu $$

<br>

This quantity $\mu_\text{MLE}$ is called the maximum likelihood estimator (MLE) of the POI $\mu$.

And of course, this is yet another problem that can be solved using gradient-based optimizations!
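The whole procedure can be sketched with gradient-based optimization in JAX: a global fit for $\hat{\mu}, \hat{\theta}$, profiling $\hat{\hat{\theta}}$ at fixed $\mu$, then computing $t_\mu$. This is a toy. The model is the same hypothetical one-NP example as before. The optimizer is a plain Newton's method with a fixed 20 iterations and no line search. That is our choice for brevity, not what Minuit does, and it is adequate only because this toy likelihood is smooth and convex.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

# Hypothetical toy model: 3 bins, one nuisance parameter theta on the background normalisation
signal = jnp.array([1.0, 2.0, 3.0])
background = jnp.array([10.0, 5.0, 2.0])
observed = jnp.array([11.0, 8.0, 4.0])

def nll(mu, theta):
    # -log L(mu, theta) = -log L_phys - log L_constraint
    nu = mu * signal + background * (1.0 + 0.1 * theta)
    return -jnp.sum(poisson.logpmf(observed, nu)) - norm.logpdf(theta)

def nll_vec(params):
    return nll(params[0], params[1])

@jax.jit
def global_fit(params):
    # Newton's method on (mu, theta) with the exact gradient and Hessian from JAX
    def newton_step(i, p):
        g = jax.grad(nll_vec)(p)
        h = jax.hessian(nll_vec)(p)
        return p - jnp.linalg.solve(h, g)
    return jax.lax.fori_loop(0, 20, newton_step, params)

@jax.jit
def profile_theta(mu):
    # Newton's method on theta only, with mu held fixed: gives theta-hat-hat(mu)
    dnll = jax.grad(nll, argnums=1)
    d2nll = jax.grad(dnll, argnums=1)
    def newton_step(i, theta):
        return theta - dnll(mu, theta) / d2nll(mu, theta)
    return jax.lax.fori_loop(0, 20, newton_step, jnp.zeros_like(mu))

mu_hat, theta_hat = global_fit(jnp.array([1.0, 0.0]))
nll_min = nll(mu_hat, theta_hat)

def t_mu(mu):
    # t_mu = -2 log [ L(mu, theta_hathat) / L(mu_hat, theta_hat) ]
    return 2.0 * (nll(mu, profile_theta(mu)) - nll_min)

mu_scan = jnp.linspace(0.0, 2.0, 5)
t_scan = jax.vmap(t_mu)(mu_scan)
print("mu_hat =", mu_hat)
print("t_mu   =", t_scan)
```

$\mu_\text{MLE}$ is the point where the scanned $t_\mu$ reaches its minimum of zero; in a real analysis the scan would be much finer and the fits would be done by a robust minimizer.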

<br>

<br>

There is a very popular tool, traditionally used in HEP, that can perform this optimization and find the best-fit values using numerical differentiation techniques.

This tool is called Minuit (https://en.wikipedia.org/wiki/MINUIT).

For our purpose, we will be working with its Python-friendly interface, iMinuit (https://iminuit.readthedocs.io/en/stable/index.html).

<br>

<br>

While the tool uses numerical differentiation by default, we can take advantage of the accelerated array-based computing capabilities of JAX. iMinuit also accepts a user-supplied gradient function, which JAX can compute exactly with automatic differentiation.

In order to do the fitting optimization, iMinuit calls the test statistic function many times (often thousands of times in a full analysis) until it finds the best-fit value.

This is where we can take advantage of JIT compilation for an accelerated computation.
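Here is a sketch of how the JAX pieces could be wrapped for an external minimizer. The objective and its autodiff gradient are JIT-compiled once. Thin wrappers then convert between JAX arrays and the NumPy arrays and Python floats that minimizers typically exchange. The toy model is the same hypothetical one as above, and `fcn` and `fcn_grad` are illustrative names.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson

# Same hypothetical toy model as before: parameters are (mu, theta)
signal = jnp.array([1.0, 2.0, 3.0])
background = jnp.array([10.0, 5.0, 2.0])
observed = jnp.array([11.0, 8.0, 4.0])

def nll_vec(params):
    mu, theta = params[0], params[1]
    nu = mu * signal + background * (1.0 + 0.1 * theta)
    return -jnp.sum(poisson.logpmf(observed, nu)) - norm.logpdf(theta)

nll_jit = jax.jit(nll_vec)             # compiled once, reused on every call
grad_jit = jax.jit(jax.grad(nll_vec))  # exact gradient via autodiff, also compiled

def fcn(params):
    # Minimizers typically pass NumPy arrays in and expect a Python float back;
    # casting to a fixed dtype keeps the input type constant, so the cache is reused
    return float(nll_jit(jnp.asarray(params, dtype=jnp.float32)))

def fcn_grad(params):
    return np.asarray(grad_jit(jnp.asarray(params, dtype=jnp.float32)))

x0 = np.array([1.0, 0.0])
print(fcn(x0), fcn_grad(x0))
```

With iMinuit installed, these callables could then be passed along the lines of `Minuit(fcn, x0, grad=fcn_grad)`, with `errordef` set to `Minuit.LIKELIHOOD` since `fcn` is a negative log-likelihood. Please check the iminuit documentation for the exact options in your version.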

<br>

We shall explore this and more in today's exercise.