# A Gentle start: example

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [2]:
%config InlineBackend.figure_format = "retina"

In [3]:
# Problem dimensions
n_experts = 2 ** 5
n_timesteps = 100
true_expert_where = 10  # column at which the always-correct expert is inserted

# One random stream for the ground truth, another for the experts' guesses
key = jax.random.PRNGKey(314)
key_oracle, key_experts = jax.random.split(key)

# Ground-truth binary label for every timestep
oracle = jax.random.bernoulli(key_oracle, p=0.6, shape=(n_timesteps,))

# Random guesses (one column per expert), with an extra column that copies
# the oracle so that exactly one inserted expert is never wrong
experts = jnp.insert(
    jax.random.bernoulli(key_experts, p=0.6, shape=(n_timesteps, n_experts)),
    true_expert_where,
    oracle,
    axis=1,
)

In [4]:
def update_weights(y, weights, experts, n_errors):
    """
    Drop every expert whose vote disagrees with the observed label
    and count one more mistake.

    Parameters
    ----------
    y
        Observed label for the current timestep.
    weights
        Current weight of each expert.
    experts
        Vote of each expert for the current timestep.
    n_errors
        Number of mistakes counted so far.

    Returns
    -------
    tuple
        ``(new_weights, n_errors + 1)``. An expert keeps its weight only
        if its vote equals ``y``; an expert already at weight zero stays
        at zero because the mask is applied multiplicatively.
    """
    agrees_with_label = experts == y
    return agrees_with_label * weights, n_errors + 1
    

def step(state, xs):
    """
    Run one round of the majority-vote scheme (used as the body of
    `jax.lax.scan`).

    The prediction is the weight-averaged expert vote, rounded to a hard
    decision. When that prediction matches the label ``y`` the state is
    left untouched; when it is wrong, `update_weights` zeroes the weight
    of every expert that disagreed with ``y`` and increments the error
    counter.

    Parameters
    ----------
    state
        Tuple ``(weights, n_errors)`` carried between timesteps.
    xs
        Tuple ``(experts, y)``: the experts' votes and the true label
        for the current timestep.

    Returns
    -------
    tuple
        ``(new_state, out)`` where ``new_state`` is ``(weights, n_errors)``
        and ``out`` is ``{"weights": weights}`` (the post-round weights).
    """
    weights, n_errors = state
    experts, y = xs

    # Weighted fraction of votes, rounded to a hard decision
    majority_vote = jnp.round(jnp.sum(weights * experts) / jnp.sum(weights))
    prediction_is_right = majority_vote == y

    def keep_state(*_):
        # Correct prediction: nothing to learn from this round
        return weights, n_errors

    weights, n_errors = jax.lax.cond(
        prediction_is_right,
        keep_state,
        update_weights,
        y, weights, experts, n_errors,
    )

    return (weights, n_errors), {"weights": weights}

In [5]:
# Every expert (the random ones plus the inserted one) starts with weight 1
w_init = jnp.ones(n_experts + 1)
# Scan inputs: per-timestep expert votes paired with the true label
xs = (experts, oracle)
# Carry: current weights and the running mistake count
state_init = (w_init, 0)

final_state, hist = jax.lax.scan(step, state_init, xs)
w_last, n_errors = final_state
# Weights after each timestep, stacked along the leading axis
weights_hist = hist["weights"]

In [6]:
# Number of timesteps at which the weighted-majority prediction was wrong
n_errors

Array(4, dtype=int32, weak_type=True)

In [7]:
# Mistake bound of the majority-vote elimination scheme: at most
# log2(number of experts) errors. The pool holds the n_experts random
# experts plus the inserted always-correct one, hence the +1.
jnp.floor(jnp.log2(n_experts + 1))

Array(5., dtype=float32)

In [8]:
# Total weight after each timestep; weights only take the values 0 or 1,
# so this is the number of experts that have not been eliminated yet
weights_hist.sum(axis=1)

Array([33., 33., 16., 16.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  2.,
        2.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32)

## Imperfect expert

In [12]:
# Problem dimensions
n_experts = 2 ** 5
n_timesteps = 100
true_expert_where = 10  # column at which the reference expert is inserted

beta = 0.7

# Three random streams: ground truth, experts' guesses, and label noise
key = jax.random.PRNGKey(314)
key_oracle, key_experts, key_noise = jax.random.split(key, 3)

# Ground-truth binary label for every timestep
oracle = jax.random.bernoulli(key_oracle, p=0.5, shape=(n_timesteps,))

# Random guesses (one column per expert), with an extra column that
# copies the oracle
experts = jnp.insert(
    jax.random.bernoulli(key_experts, p=0.5, shape=(n_timesteps, n_experts)),
    true_expert_where,
    oracle,
    axis=1,
)

In [43]:
# Probability that the best expert gets any given timestep wrong
p_mistake = 0.1

# Timesteps at which the best expert errs
mistakes = jax.random.bernoulli(key_noise, p=p_mistake, shape=(n_timesteps,))

# XOR flips the oracle's label wherever a mistake was drawn
true_expert = oracle ^ mistakes