# Feature Toggle in JAX

## Importing dependencies

In [None]:
import jax
from jax import jit
from jax import numpy as jnp
from functools import partial
from chex import dataclass

### Defining the pipeline steps (plain Python functions, traced inside the jitted pipelines)

In [None]:
def step1(a, b):
    """Pipeline step 1: return ``a + b`` (element-wise for array inputs)."""
    total = a + b
    return total

def step2(a, b):
    """Pipeline step 2: return ``a * b`` (element-wise for array inputs)."""
    product = a * b
    return product

def step3(a, b):
    """Pipeline step 3: return ``a - b`` (element-wise for array inputs)."""
    difference = a - b
    return difference

## Initializing input data

In [None]:
from jax import random

# One fixed seed split into two subkeys, so `a` and `b` are reproducible
# but drawn independently of each other.
rng1, rng2 = random.split(random.PRNGKey(14), 2)
a, b = (random.normal(key, (1000, 1000)) for key in (rng1, rng2))

## Toggling steps with a hashable `Features` config (passed to `jit` as a static argument)

In [None]:
@dataclass
class Features:
    """Flags selecting which pipeline steps run.

    Instances are passed to ``jax.jit`` as a static argument, so they must be
    hashable, and ``__eq__`` must agree with ``__hash__``. Both are derived
    from the same ``_key()`` tuple so they cannot drift apart.
    """

    enable_step1: bool  # run step1
    enable_step2: bool  # run step2
    enable_step3: bool  # run step3

    def _key(self):
        """Return the flag values as a tuple; the single source for eq/hash."""
        return (self.enable_step1, self.enable_step2, self.enable_step3)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Only compare against other Features. Returning NotImplemented lets
        # Python fall back to its default handling (e.g. `features == None`
        # is False) instead of raising on a subscript of a foreign object. It
        # also stops a plain dict with the same keys from comparing equal
        # while hashing differently.
        if not isinstance(other, Features):
            return NotImplemented
        return self._key() == other._key()

In [None]:
import configparser

def load_config(config_file: str) -> Features:
    """Read step toggles from the ``[Features]`` section of an INI file.

    Each ``enable_stepN`` key is parsed with ``getboolean`` and defaults to
    ``True`` when absent, so an empty ``[Features]`` section enables every step.

    Args:
        config_file: Path to the INI file.

    Returns:
        A ``Features`` instance with one flag per pipeline step.

    Raises:
        FileNotFoundError: if ``config_file`` could not be read.
        KeyError: if the file has no ``[Features]`` section.
        ValueError: if a flag value is not a recognised boolean string.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read silently skips files it cannot open and returns the
    # list of files it did read. Without this check, a missing file would
    # surface later as a confusing KeyError: 'Features'.
    if not config.read(config_file, encoding="utf-8"):
        raise FileNotFoundError(f"Could not read feature config file: {config_file!r}")

    if not config.has_section("Features"):
        raise KeyError(f"'Features' section missing from {config_file!r}")
    section = config["Features"]

    flags = {
        name: section.getboolean(name, True)
        for name in ("enable_step1", "enable_step2", "enable_step3")
    }
    return Features(**flags)


In [None]:
# Load the step toggles from disk and show which steps are enabled.
features = load_config(config_file="features.ini")
print(features)

## Pipeline with feature toggle

In [None]:
@partial(jit, static_argnums=2)
def pipeline_with_toggle(a, b, features):
    """Sum the outputs of the pipeline steps enabled in ``features``.

    ``features`` is a static argument: its flags are plain Python bools at
    trace time, so disabled steps are never traced. Each distinct
    ``Features`` value is compiled separately.

    Args:
        a, b: Inputs passed unchanged to every enabled step.
        features: ``Features`` instance selecting which steps run.

    Returns:
        The sum of the enabled steps' outputs. When every step is disabled,
        an all-zero array with the broadcast shape and result dtype of
        ``a`` and ``b``.
    """
    toggled_steps = (
        (features.enable_step1, step1),
        (features.enable_step2, step2),
        (features.enable_step3, step3),
    )
    outputs = [step(a, b) for enabled, step in toggled_steps if enabled]

    if not outputs:
        # Keep the output shape/dtype independent of the toggles instead of
        # returning a bare integer scalar 0.
        shape = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(b))
        return jnp.zeros(shape, dtype=jnp.result_type(a, b))

    # Start from the first real output rather than a Python 0, so no extra
    # `add` with a constant zero is traced. The jaxpr then contains only the
    # enabled steps and the additions combining them.
    output = outputs[0]
    for step_out in outputs[1:]:
        output = output + step_out
    return output

## Statically defined pipeline without feature toggle

In [None]:
@jit
def pipeline(a, b):
    """Hand-written pipeline that applies only step2 and step3 and sums them.

    Serves as the reference the toggled pipeline is compared against.
    """
    # step1 is intentionally left out (presumably disabled in features.ini).
    return step2(a, b) + step3(a, b)

## Comparing outputs

In [None]:
# Run both variants on the same inputs; the toggled pipeline is expected to
# match the hand-written one.
tout, sout = pipeline_with_toggle(a, b, features), pipeline(a, b)

jnp.allclose(tout, sout)

## Benchmarks

In [None]:
# %timeit pipeline(a, b)

In [None]:
# %timeit pipeline_with_toggle(a, b, features)

## Generating jaxprs

In [None]:
# Trace the toggled pipeline with `features` static: its flags are resolved
# while tracing, so only the enabled steps appear in the printed jaxpr.
print(
    jax.make_jaxpr(pipeline_with_toggle, static_argnums=(2,))(
        a, b, features
    )
)

In [None]:
# Jaxpr of the hand-written pipeline, for comparison with the toggled one.
print(
    jax.make_jaxpr(pipeline)(a, b)
)