In [11]:
import _pickle as pickle
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
import os
import seaborn as sns
import pandas as pd
import timeit


from itertools import product

from utils import set_size, pgf_with_latex

# Document width in points; handed to the imported set_size helper when
# figure sizes are computed.
doc_width_pt = 452.9679

# Select the seaborn colorblind style first, then layer the rcParams imported
# from utils (pgf_with_latex) on top so that they take precedence.
plt.style.use("seaborn-v0_8-colorblind")
plt.rcParams.update(pgf_with_latex)

# ERM

In [None]:
# Sample budgets: (2**e)**2 for e = 6..17, stored as float64.
# The integer dtype is pinned to int64 because the largest value is 2**34,
# which silently wraps around when the platform default integer is 32-bit
# (e.g. Windows with NumPy < 2).
n_samples = np.floor((2 ** np.arange(6, 18, dtype=np.int64)) ** 2)
n_context = 8

# Probabilities i / (i + 1) for odd i in 1..9, followed by two values close
# to 1 appended explicitly.
p_high = np.arange(1, 10, 2)
p_high = p_high / (p_high + 1)
p_high  = np.concatenate((p_high, np.array([0.99, 0.999])))

print(n_samples)
print(p_high)

In [3]:
# Expected counts on a (len(n_samples), len(p_high)) grid: the sample budgets
# become a column and broadcast against the probability vector.
n_high = np.floor(n_samples[:, np.newaxis] * p_high)
n_low = np.floor(n_samples[:, np.newaxis] * (1 - p_high))

In [4]:
# IW error model: iw_constant / sqrt(count), applied to both count grids.
iw_constant = 10
error_iw_high, error_iw_low = (
    iw_constant / np.sqrt(counts) for counts in (n_high, n_low)
)

In [5]:
# Both probabilities share the (n_context - 1) / n_context term and differ
# only in the final p / n_context contribution.
shared_term = (n_context - 1) / n_context
p_relevant_high = shared_term + p_high / n_context
p_relevant_low = shared_term + (1 - p_high) / n_context

# Scale the count grids by those probabilities and round down.
n_relevant_high = np.floor(n_high * p_relevant_high)
n_relevant_low = np.floor(n_low * p_relevant_low)

In [None]:
# Show both arrays as the cell output (low first, then high).
p_relevant_low, p_relevant_high

In [7]:
# IC error model: a floor of (1 - p_relevant) plus ic_constant / sqrt(count).
ic_constant = 1e-5
error_ic_high = ic_constant / np.sqrt(n_relevant_high) + (1 - p_relevant_high)

# The low-frequency denominator is bounded below by 1e-7 so that a zero count
# cannot produce a division by zero.
sqrt_n_relevant_low = np.clip(np.sqrt(n_relevant_low), 1e-7, np.inf)
error_ic_low = ic_constant / sqrt_n_relevant_low + (1 - p_relevant_low)

In [None]:
# Pair the transposed IC and IW errors along a new last axis
# (index 0 = IC, index 1 = IW) and show the result.
np.stack((error_ic_high.T, error_iw_high.T), axis=-1)

In [None]:
# Pair the transposed IC and IW errors along a new last axis
# (index 0 = IC, index 1 = IW) and show the result.
np.stack((error_ic_low.T, error_iw_low.T), axis=-1)

In [10]:
# For every grid entry pick the index of the smaller error:
# 0 when the IC error is lowest (ties included), 1 when the IW error is.
# The transpose makes the first axis follow p_high.
alpha_high = np.stack((error_ic_high, error_iw_high), axis=-1).argmin(axis=-1).T
alpha_low = np.stack((error_ic_low, error_iw_low), axis=-1).argmin(axis=-1).T

In [None]:
# Show the argmin indices: 0 where the IC error is smaller, 1 where the IW error is.
alpha_high

In [None]:
# Show the argmin indices: 0 where the IC error is smaller, 1 where the IW error is.
alpha_low

In [None]:
# First row of n_high: the counts for the first n_samples entry, one per p_high value.
n_high[0]

In [None]:
# Show the probability values used for the sweep.
p_high

In [None]:
# Show how many sample budgets are in the sweep.
n_samples.shape

In [None]:
# One dashed curve per p_high value. Every curve is lifted by 0.01 per index
# so that overlapping 0/1 step curves stay distinguishable.
log2_n_samples = np.log2(n_samples)
for p_high_i, (alpha_row, curr_p_high) in enumerate(zip(alpha_high, p_high)):
    plt.plot(
        log2_n_samples,
        alpha_row + p_high_i * 0.01,
        label="{:.3f}".format(curr_p_high),
        marker="x",
        linestyle="--",
        alpha=1.0,
    )

plt.title("High freq.")
plt.xlabel("$\\log_2$ num. samples")
plt.ylabel("$\\alpha$")
plt.legend()

In [None]:
# One dashed curve per p_high value, labelled with the complementary
# probability 1 - p_high. Every curve is lifted by 0.01 per index so that
# overlapping 0/1 step curves stay distinguishable.
log2_n_samples = np.log2(n_samples)
for p_low_i, (alpha_row, curr_p_high) in enumerate(zip(alpha_low, p_high)):
    plt.plot(
        log2_n_samples,
        alpha_row + p_low_i * 0.01,
        label="{:.3f}".format(1 - curr_p_high),
        marker="x",
        linestyle="--",
        alpha=1.0,
    )

plt.title("Low freq.")
plt.xlabel("$\\log_2$ num. samples")
plt.ylabel("$\\alpha$")
plt.legend()

# ERM 2

IW Predictor:
$$
  R_{{D_x}}(\hat{g}) \leq \min_{y^*} R_{{D_x}}(y^*) + \mathcal{O}\left( \sqrt{ \frac{\log(2\lvert \mathcal{X} \rvert / \delta)}{N_x} } \right)
$$

IC Predictor:
For $L$ contexts and $k$ irrelevant contexts:
$$
  \frac{2k}{k + (L - k)\exp{(4)}} \leq CE(h(\hat{x}), y) \leq \log \frac{L}{L - k}
$$

With $k = L$ (every context is irrelevant),
$$
  CE(h(\hat{x}), y) = \log \frac{1}{\varepsilon}
$$
where $\varepsilon > 0$ is the minimum probability for all classes

In [37]:
# Number of contexts.
L = 8

# Fold the noise rate onto its complement so that label_noise ends up holding
# the larger of the two probabilities.
label_noise = 0.01
label_noise = label_noise if label_noise >= 1 - label_noise else 1 - label_noise

IW

In [None]:
# Entropy (in nats) of the two-outcome distribution (label_noise, 1 - label_noise).
best_error = -(
    label_noise * np.log(label_noise)
    + (1 - label_noise) * np.log(1 - label_noise)
)
best_error

In [39]:
# Sample budgets: (2**e)**2 for e = 6..17, stored as float64.
# The integer dtype is pinned to int64 because the largest value is 2**34,
# which silently wraps around when the platform default integer is 32-bit
# (e.g. Windows with NumPy < 2).
n_samples = np.floor((2 ** np.arange(6, 18, dtype=np.int64)) ** 2)

In [40]:
ic_constant = 10

In [41]:
iw_errors = best_error + ic_constant * np.sqrt(1 / n_samples)

IC

In [42]:
# Minimum class probability (the epsilon of the IC bound); log(1 / ic_min)
# fills the last row of ic_errors.
ic_min = 0.01

In [43]:
# One row per k in 0..L; columns are (lower bound, upper bound).
ic_errors = np.zeros((L + 1, 2))

# Row k = L (the last row) gets log(1 / ic_min) in both columns.
ic_errors[L] = np.log(1 / ic_min)

In [44]:
# Fill rows k = 0..L-1 with the IC bounds; row L is left untouched by this loop.
for k in np.arange(L):
    # Lower bound: 2k / (k + (L - k) * e^4).
    ic_errors[k, 0] = 2 * k / (k + (L - k) * np.exp(4))
    # Upper bound: log(L / (L - k)).
    ic_errors[k, 1] = np.log(L / (L - k))

Plot

In [None]:
# Panel grid: two rows, with one panel per k in 0..L-1 (assumes L is even).
num_rows = 2
num_cols = L // 2
label_fontsize = 8

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=set_size(doc_width_pt, 0.95, (num_rows, num_cols), use_golden_ratio=True),
    layout="constrained",
)

# Loop-invariant inputs, computed once instead of once per panel. The shared
# y-limit covers the largest IC lower bound and the largest IW bound.
log2_n_samples = np.log2(n_samples)
y_upper = max(np.max(ic_errors[:L, 0]), np.max(iw_errors)) + 0.1

for k in np.arange(L):
    row_i = k // num_cols
    col_i = k % num_cols
    ax = axes[row_i, col_i]

    # Only the IC lower bound (column 0 of ic_errors) is drawn; the upper
    # bound in column 1 is not plotted. Labels are set on the first panel
    # only so the shared figure legend lists each curve once.
    ax.axhline(ic_errors[k, 0], linestyle="--", color="red", label="IC Lower Bound" if k == 0 else "")
    ax.plot(log2_n_samples, iw_errors, label="IW Upper Bound" if k == 0 else "")
    ax.set_ylim(-0.01, y_upper)

    ax.set_title("$k = {}$".format(k), fontsize=label_fontsize)
    # Major x ticks at multiples of 5; each axis gets its own locator instance.
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=5.0))

    # Keep x ticks on the bottom row and y ticks on the first column only.
    if row_i < num_rows - 1:
        ax.set_xticks([])
    if col_i > 0:
        ax.set_yticks([])


fig.supxlabel("$N_x$ (in $\\log_2$)", fontsize=label_fontsize)
fig.supylabel("Loss", fontsize=label_fontsize)
fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=4,
    borderaxespad=0.0,
    frameon=True,
    fontsize=label_fontsize,
)

plt.savefig("toy_example-errors.pdf", dpi=600, format="pdf", bbox_inches="tight")