# Feature Toggle in JAX

## Importing dependencies

In [None]:
import jax
from jax import jit
from jax import numpy as jnp
from functools import partial
from chex import dataclass

## Defining JIT-compiled functions

In [None]:
@jit
def add(a, b):
    """Return the element-wise sum of ``a`` and ``b`` (JIT-compiled)."""
    return jnp.add(a, b)

@jit
def mul(a, b):
    """Return the matrix product ``a @ b`` (JIT-compiled).

    Despite the name, this is a matrix multiplication, not an element-wise
    product.
    """
    return a @ b

@jit
def sub(a, b):
    """Return the element-wise difference ``a - b`` (JIT-compiled)."""
    return jnp.subtract(a, b)

## Initializing input data

In [None]:
from jax import random

# Derive two independent subkeys from one fixed seed so the inputs are
# reproducible across runs.
rng1, rng2 = random.split(random.PRNGKey(14), 2)

# Two 1000x1000 matrices of standard-normal samples.
a = random.normal(rng1, (1000, 1000))
b = random.normal(rng2, (1000, 1000))

## Toggling based on the keys (structure) of the `Features` pytree

In [None]:
@dataclass
class Features:
    """Boolean feature flags selecting which ops the toggled pipeline runs.

    Instances are passed to ``jit`` as a static argument, which requires
    them to be hashable and to define ``__eq__``. JAX re-uses a compiled
    function only for static arguments that compare equal to a cached one.
    """

    enable_add: bool
    enable_mul: bool
    enable_sub: bool

    def __hash__(self):
        # Hash only the boolean flags. They are plain Python bools, so the
        # tuple is hashable (a jax array field would not be).
        return hash((self.enable_add, self.enable_mul, self.enable_sub))

    def __eq__(self, other):
        # Defer to Python's default comparison for foreign types instead of
        # failing on ``other[key]`` with a KeyError/TypeError.
        if not isinstance(other, Features):
            return NotImplemented
        # chex dataclasses are mappable, so iterating ``self`` yields the
        # field names. ``all`` stops at the first differing flag.
        return all(self[key] == other[key] for key in self)

In [None]:
import configparser

def load_config(config_file: str, section: str = "Features") -> "Features":
    """Read feature flags from an INI file and return them as ``Features``.

    Each flag is read with ``getboolean`` and defaults to ``True`` when its
    key is absent from the section, so features are opt-out.

    Args:
        config_file: Path to the INI file.
        section: Name of the section holding the flags. Defaults to
            ``"Features"``.

    Returns:
        A ``Features`` instance with ``enable_add``, ``enable_sub`` and
        ``enable_mul`` populated.

    Raises:
        FileNotFoundError: If ``config_file`` could not be read.
        KeyError: If the file has no ``[section]`` section.
        ValueError: If a flag's value is not a recognised boolean string.
    """
    config = configparser.ConfigParser()
    # ``ConfigParser.read`` silently skips files it cannot open and returns
    # the list of files it did parse. Without this check, a missing file
    # only surfaces later as an opaque ``KeyError`` on the section lookup.
    if not config.read(config_file):
        raise FileNotFoundError(
            f"Could not read feature config file: {config_file!r}"
        )

    flags = config[section]
    return Features(
        enable_add=flags.getboolean("enable_add", True),
        enable_sub=flags.getboolean("enable_sub", True),
        enable_mul=flags.getboolean("enable_mul", True),
    )


In [None]:
# Load the feature flags that drive the toggled pipeline below.
features = load_config(config_file="features.ini")

## Pipeline with feature toggle

In [None]:
@partial(jit, static_argnums=2)
def simulate_with_toggle(a, b, features):
    """Sum the outputs of the ops that ``features`` enables.

    ``features`` is a static argument, so its flags are ordinary Python
    bools at trace time. Disabled ops are never traced and contribute a
    literal ``0`` to the sum. Each distinct set of flags gets its own trace.

    If every flag is disabled, the traced result is the Python int ``0``
    rather than an array shaped like the inputs.
    """
    # (flag, op) pairs in the order the ops are evaluated: add, mul, sub.
    add_out, mul_out, sub_out = (
        op(a, b) if enabled else 0
        for enabled, op in (
            (features.enable_add, add),
            (features.enable_mul, mul),
            (features.enable_sub, sub),
        )
    )
    return add_out + mul_out + sub_out

## Statically defined pipeline without feature toggle

In [None]:
@jit
def simulate(a, b):
    """Hard-coded pipeline ``mul(a, b) + sub(a, b)`` with no toggles.

    This is the baseline that ``simulate_with_toggle`` is compared against.
    The match presumes the loaded config disables ``enable_add`` and enables
    the other two flags. NOTE(review): confirm against features.ini.
    """
    return mul(a, b) + sub(a, b)

## Comparing outputs

In [None]:
# Run both pipelines on the same inputs (toggled first, then the baseline)
# and check that they agree numerically.
tout, sout = simulate_with_toggle(a, b, features), simulate(a, b)

jnp.allclose(tout, sout)

## Benchmarks

In [None]:
# %timeit simulate(a, b)

In [None]:
# %timeit simulate_with_toggle(a, b, features)

## Generating jaxprs

In [None]:
# Trace the toggled pipeline. ``features`` (argument 2) is marked static here
# as well, so its flags stay concrete Python bools while tracing.
print(jax.make_jaxpr(simulate_with_toggle, static_argnums=(2,))(a, b, features))

In [None]:
# Trace the hard-coded baseline so its jaxpr can be compared with the
# toggled pipeline's.
print(
    jax.make_jaxpr(simulate)(a, b)
)