In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

from typing import List

In [None]:
import regret_bounds

In [None]:
def plot_miscalibration_adjustment(ms: List[float], max_wr=10, wr_step=0.5, ax=None):
    """Plot ``C + R * WR`` against ``R`` for a sweep of WR values and mark the best ``R``.

    ``C`` is ``regret_bounds.get_miscalibration_adjustment(ms, R)`` evaluated on a grid of
    ``R`` from 0 to ``1.5 * max(1/ms)``. The legend labels this ``C_R``.
    For each ``WR`` in ``np.arange(0, max_wr, wr_step)`` one line is drawn, shaded from the
    ``Greens_r`` colormap by ``WR / max_wr``. The ``R`` minimising the same expression over
    the candidate set ``{0} ∪ {1/m for m in ms}`` is marked with a small red dot.
    A colorbar spanning ``[0, max_wr]`` maps line colour back to WR.

    Args:
        ms: values forwarded to ``regret_bounds.get_miscalibration_adjustment``.
            Any array-like of non-zero numbers (``1/ms`` is taken).
        max_wr: exclusive upper end of the WR sweep. Also the top of the colorbar.
        wr_step: spacing between consecutive WR values (default 0.5).
        ax: matplotlib Axes to draw on; defaults to the current pyplot Axes.

    Raises:
        ValueError: if ``ms`` is empty, or ``max_wr`` / ``wr_step`` is not positive.
    """
    ms = np.array(ms)
    if ms.size == 0:
        raise ValueError("ms must contain at least one value")
    if max_wr <= 0:
        raise ValueError("max_wr must be positive")
    if wr_step <= 0:
        raise ValueError("wr_step must be positive")
    if ax is None:
        ax = plt.gca()

    rs = 1 / ms
    R = np.linspace(0, rs.max() * 1.5, 100)
    C = regret_bounds.get_miscalibration_adjustment(ms, R)

    # Only 0 and the 1/m points are searched for the best R. This assumes the minimum of
    # C + R*WR lies at one of them (e.g. C piecewise linear with kinks at 1/m).
    # NOTE(review): confirm against regret_bounds.
    candidates = np.concatenate([[0], rs])
    # Loop-invariant: evaluate the adjustment at the candidates once, not once per WR.
    candidate_C = regret_bounds.get_miscalibration_adjustment(ms, candidates)

    cmap = cm.Greens_r
    # Same mapping as the line colours (WR / max_wr), so the colorbar reads in WR units.
    norm = plt.Normalize(vmin=0, vmax=max_wr)

    for i, wr in enumerate(np.arange(0, max_wr, wr_step)):
        first = i == 0  # label only the first line/point so each appears once in the legend
        ax.plot(R, C + R * wr, linewidth=1, c=cmap(wr / max_wr),
                label=r'$R \mathrm{WR} + C_R$' if first else None)
        candidate_scores = candidate_C + candidates * wr
        best = candidate_scores.argmin()
        ax.scatter(candidates[best], candidate_scores[best], c='red', s=3, zorder=10,
                   label='Best $R$' if first else None)

    # The mappable is not attached to any Axes, so tell colorbar which Axes to take space from.
    ax.figure.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, label='WR')
    ax.set_ylim(bottom=0)
    ax.set_xlabel('$R$')
    ax.set_ylabel('Implied bound')
    ax.set_title(str(ms))
    ax.legend()

In [None]:
# Three identical entries, all equal to 1.
ms = np.asarray([1, 1, 1])
plot_miscalibration_adjustment(ms=ms)

In [None]:
# Two entries: one slightly above 1, one equal to 1.
ms = np.asarray([1.2, 1])
plot_miscalibration_adjustment(ms=ms)

In [None]:
# A single entry of 2 alongside three entries of 1.
ms = np.asarray([2] + [1] * 3)
plot_miscalibration_adjustment(ms=ms)

In [None]:
# Five entries spread on both sides of 1, in descending order.
ms = np.asarray([1.5, 1.1, 1, 0.9, 0.8])
plot_miscalibration_adjustment(ms=ms)

In [None]:
# Nine descending entries from 2 down to 0.7; the WR sweep is widened to max_wr=12.
ms = np.asarray([2, 1.4, 1.3, 1.2, 1.1, 1, 0.9, 0.8, 0.7])
plot_miscalibration_adjustment(ms=ms, max_wr=12)