# Introduction to `DiscreteHandler`

In [1]:
import os
import sys
import yaml
from typing import Tuple

import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.join(os.path.abspath(""), ".."))

from infovar import DiscreteHandler, StandardGetter

## Context

Imagine you receive a box with two displays, each showing a numerical value. The box also has three knobs, each of which can be turned to increase or decrease its value.

As it happens, you're not the only one to have received such a box: 6 of your colleagues have received a similar one. There's just one detail that sets their boxes apart: on each of them, one or two knobs are hidden, which makes it impossible to read or turn these knobs. In total there are seven boxes, no two identical:
- 3 boxes with one of the three knobs hidden,
- 3 boxes with two of the three knobs hidden,
- your box, with all knobs visible.

If you turn one of the knobs and put your ear to the box, you will notice that the hidden knobs also turn, apparently at random. Another important detail: even on the box with all knobs visible, two identical configurations never give exactly the same values on the displays, although the values are generally quite close.

*TODO: insert images*

What you don't know is that these boxes have been sent to you by an impish statistician. The behavior of these boxes is actually governed by a simple non-deterministic mathematical formula:

$$ \begin{cases} y_1 = (x_1 - x_2)^2 + x_3 + \varepsilon_1, & \varepsilon_1 \sim \mathcal{N}(0, 0.05^2) \\ y_2 = \mathrm{e}^{x_3} + \varepsilon_2, & \varepsilon_2 \sim \mathcal{N}(0, 0.1^2) \end{cases} $$

where $x_i \in [-1, 1]$ is the value of knob number $i$, $y_j$ is the value shown on display number $j$, and $\mathcal{N}(\mu, \sigma^2)$ denotes a Gaussian distribution with mean $\mu$ and standard deviation $\sigma$.
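As a quick worked example (the knob setting is chosen arbitrarily for illustration), take $x = (0.5, -0.5, 0)$:

$$ y_1 = (0.5 - (-0.5))^2 + 0 + \varepsilon_1 = 1 + \varepsilon_1, \qquad y_2 = \mathrm{e}^{0} + \varepsilon_2 = 1 + \varepsilon_2 $$

Both displays would then read values close to 1. Note also that, according to this formula, $y_2$ depends on $x_3$ alone, while $x_1$ and $x_2$ affect $y_1$ only through their difference.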

In [2]:
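def function(
    x1: np.ndarray, x2: np.ndarray, x3: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simulate the two displays of the box.

    xi = value of knob i (between -1 and 1).
    Returns the noisy display values (y1, y2).
    """
    assert x1.shape == x2.shape == x3.shape
    assert (
        (np.abs(x1) <= 1).all() and (np.abs(x2) <= 1).all() and (np.abs(x3) <= 1).all()
    )

    # Display 1: squared difference of knobs 1 and 2, plus knob 3, Gaussian noise (std 0.05)
    y1 = (x1 - x2) ** 2 + x3 + np.random.normal(0, 0.05, x1.shape)
    # Display 2: exponential of knob 3, Gaussian noise (std 0.1), as in the formula above
    y2 = np.exp(x3) + np.random.normal(0, 0.1, x1.shape)
    return y1, y2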

Your goal, like that of each of your colleagues, is to quantify, as far as your box allows, the influence of the knobs on the values shown by the displays. To do so, each of you will record the values shown on the displays together with the known values of the knobs. The knob values will be sampled uniformly between -1 and 1.

## Getter

TODO

In [3]:
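# Sample the knob values uniformly in [-1, 1] and record the displays
n_samples = 10_000
x1 = np.random.uniform(-1, 1, n_samples)
x2 = np.random.uniform(-1, 1, n_samples)
x3 = np.random.uniform(-1, 1, n_samples)

y1, y2 = function(x1, x2, x3)

# Wrap the samples in a getter: input names, output names, then the matching data columns
getter = StandardGetter(
    ["x1", "x2", "x3"],
    ["y1", "y2"],
    np.column_stack((x1, x2, x3)),
    np.column_stack((y1, y2)),
)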

## Discrete handler

The `DiscreteHandler` is the tool that lets you statistically analyze the influence of the knobs on the values shown by the displays.
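This notebook does not show how `DiscreteHandler` computes its `"mi"` statistic. To give an intuition for what mutual information measures here, the following is a minimal plug-in estimator based on a 2D histogram, computing $I(X;Y) = \sum_{x,y} p(x,y)\log_2\frac{p(x,y)}{p(x)\,p(y)}$ in bits. It is a sketch, not `infovar`'s implementation; the name `mi_histogram` and the default of 20 bins are illustrative assumptions. The simulated display follows the formula for $y_2$ given above, with fresh variable names so as not to overwrite the notebook's data.

```python
import numpy as np


def mi_histogram(x: np.ndarray, y: np.ndarray, bins: int = 20) -> float:
    """Plug-in estimate of I(x; y) in bits (hypothetical helper, not infovar's API)."""
    # Joint counts over a bins x bins grid, normalized into a probability table
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    p_xy = counts / counts.sum()
    # Marginals obtained by summing the joint table along each axis
    p_x = p_xy.sum(axis=1, keepdims=True)
    p_y = p_xy.sum(axis=0, keepdims=True)
    # Only non-empty cells contribute (0 * log 0 is taken as 0)
    mask = p_xy > 0
    ratio = p_xy[mask] / (p_x @ p_y)[mask]
    return float(np.sum(p_xy[mask] * np.log2(ratio)))


# Fresh simulation of display 2 (y2 = exp(x3) + noise); x1 plays no role in it
rng = np.random.default_rng(0)
n = 10_000
knob1, knob3 = rng.uniform(-1, 1, (2, n))
display2 = np.exp(knob3) + rng.normal(0, 0.1, n)

print(f"I(x1; y2) ~ {mi_histogram(knob1, display2):.3f} bits")
print(f"I(x3; y2) ~ {mi_histogram(knob3, display2):.3f} bits")
```

Under the formula, $x_3$ should carry much more information about $y_2$ than $x_1$ does. Plug-in estimates are biased upward for finite samples, so an irrelevant knob typically gets a small positive value rather than exactly 0.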

In [4]:
handler = DiscreteHandler()

handler.set_path("handlers")
handler.set_getter(getter.get)

In [None]:
with open(os.path.join("handlers", "restrictions.yaml"), "r") as f:
    restrictions = yaml.safe_load(f)

handler.set_restrictions(restrictions)

In [None]:
handler.remove(["y1"])

In [None]:
print("Existing saves:")
for file in handler.get_existing_saves():
    print(file)

In [None]:
inputs_dict = {"restrictions": ["all"], "min_samples": 200, "statistics": ["mi"]}

for s in ["y1", "y2"]:
    for c in ["x1", "x2", "x3"]:
        handler.overwrite(c, s, inputs_dict)

In [None]:
handler.overwrite.__annotations__

In [None]:
print("Existing saves:")
for file in handler.get_existing_saves():
    print(file)

In [None]:
items = handler.read(["x1", "x2", "x3"], "y1", "all", iterable_x=True)
mis = [item["mi"]["value"] for item in items]

plt.figure(figsize=(6.4, 0.5 * 4.8))

plt.bar(np.arange(len(mis)), mis)
plt.xticks(np.arange(len(mis)), ["x1", "x2", "x3"])
plt.title("Mutual information (bits)")

plt.show()
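Each bar above measures one knob at a time, much as a colleague who can only see that knob would. Because $y_1$ depends on $x_1$ and $x_2$ through $(x_1 - x_2)^2$, reading the two knobs together reveals more than either alone; by the chain rule of mutual information, $I(x_1, x_2; y_1) \ge I(x_1; y_1)$. The sketch below checks this on fresh simulated data with a multidimensional histogram estimator. It is an illustration independent of `infovar`; the name `mi_histogram_joint` and the 10 bins per axis are assumptions, and histogram estimates degrade quickly as dimensions are added.

```python
import numpy as np


def mi_histogram_joint(xs: np.ndarray, y: np.ndarray, bins: int = 10) -> float:
    """Plug-in estimate of I(xs; y) in bits, xs of shape (n, d) (hypothetical helper)."""
    # One (d + 1)-dimensional histogram over the inputs and the output
    data = np.column_stack((xs, y))
    counts, _ = np.histogramdd(data, bins=bins)
    p = counts / counts.sum()
    # Marginal of the inputs (sum out y) and of y (sum out every input axis)
    p_x = p.sum(axis=-1, keepdims=True)
    p_y = p.sum(axis=tuple(range(p.ndim - 1)), keepdims=True)
    mask = p > 0
    denom = (p_x * p_y)[mask]
    return float(np.sum(p[mask] * np.log2(p[mask] / denom)))


# Fresh simulation of display 1 with all three knobs sampled uniformly in [-1, 1]
rng = np.random.default_rng(1)
n = 10_000
knobs = rng.uniform(-1, 1, (n, 3))
display1 = (knobs[:, 0] - knobs[:, 1]) ** 2 + knobs[:, 2] + rng.normal(0, 0.05, n)

single = mi_histogram_joint(knobs[:, [0]], display1)  # only x1 visible
joint = mi_histogram_joint(knobs[:, :2], display1)  # x1 and x2 read together
print(f"I(x1; y1) ~ {single:.3f} bits, I(x1, x2; y1) ~ {joint:.3f} bits")
```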

In [None]:
items = handler.read(["x1", "x2", "x3"], "y2", "all", iterable_x=True)
mis = [item["mi"]["value"] for item in items]

plt.figure(figsize=(6.4, 0.5 * 4.8))

plt.bar(np.arange(len(mis)), mis)
plt.xticks(np.arange(len(mis)), ["x1", "x2", "x3"])
plt.title("Mutual information (bits)")

plt.show()