## Theorem 2.3: Exponentially-weighted forecaster with time-varying potential

* $\ell: \cdot\times\cdot \to [0, 1]$
* $L_{e,t} = \sum_{\tau=1}^T \ell(f_{e,\tau}, y_\tau)$
* $\eta_t = \sqrt{a (\log E) / t}$

* $w_{e, t} = \exp(-\eta_t L_{e, t-1})$
* $W_t = \sum_{e=1}^E w_{e,t}$

In [16]:
import jax
import jax.numpy as jnp
from plgx import datasets

In [4]:
%load_ext autoreload
%autoreload 2

In [11]:
def loss(yhat, y):
    return (yhat - y) ** 2

In [46]:
key = jax.random.PRNGKey(314)
n_experts = 5
n_timesteps = 20

In [47]:
oracle, experts = datasets.bern_oracle_beta_forecasters(key, n_experts, n_timesteps)

## Building weight terms

In [48]:
a = 8
eta_array = jnp.sqrt(a * jnp.log(n_experts) / jnp.arange(1, n_timesteps + 1))

In [49]:
losses_experts = jax.vmap(loss)(experts, oracle).cumsum(axis=0)
losses_experts_shift = jnp.roll(losses_experts, 1)
losses_experts_shift = losses_experts_shift.at[0].set(jnp.zeros(n_experts))

weights_experts = jnp.exp(- eta_array[:, None] * losses_experts_shift)
W = jnp.sum(weights_experts, axis=1)

## A value for subterms in (C)

In [77]:
eta_init = eta_array[0]
eta_final = jnp.sqrt(a * jnp.log(n_experts) / (n_timesteps + 1))

In [83]:
# wk0,0
v0 = jnp.log(weights_experts[0].max() / W[0]) / eta_init

# wkT,T
vT = jnp.log(weights_experts[-1].max() / W[-1]) / eta_final

In [85]:
vT - v0

Array(0.37117782, dtype=float32)

In [86]:
-v0

Array(0.4485337, dtype=float32)

In [60]:
W[0]

Array(5., dtype=float32)

In [58]:
weights_experts.max(axis=1) / W

Array([0.2       , 0.4469792 , 0.4409532 , 0.73271656, 0.7792208 ,
       0.74584234, 0.42947865, 0.43380028, 0.47606495, 0.69641376,
       0.7185959 , 0.84227335, 0.8290735 , 0.8161625 , 0.81379384,
       0.90331763, 0.89483076, 0.90217525, 0.94261026, 0.94122565],      dtype=float32)

In [59]:
weights_experts.min(axis=1) / W

Array([2.0000000e-01, 3.5347193e-02, 6.9975574e-03, 3.3681134e-03,
       1.2704872e-03, 2.1273864e-03, 1.8919660e-03, 7.6288939e-04,
       3.6385976e-04, 2.4733072e-04, 3.6934300e-04, 5.9761707e-04,
       2.8899414e-04, 3.8005877e-04, 1.9462570e-04, 2.8152476e-04,
       3.5490154e-04, 1.9150849e-04, 2.5072301e-04, 3.0837773e-04],      dtype=float32)