In [None]:
import itertools as it
import typing

from nbmetalog import nbmetalog as nbm
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm

from pylib import hanoi


In [None]:
# Print notebook/environment metadata via nbmetalog at the top of the run
# (presumably so rendered outputs carry provenance information).
nbm.print_metadata()


## Define Set Construction


In [None]:
def filter_retained(reserved_per_hanoi: int, cur_time: int) -> typing.Set[int]:
    """Reference construction: brute-force filter over every index so far.

    Walks indices ``0, 1, ..., cur_time`` in ascending order and keeps an
    index when ``hanoi.get_hanoi_value_incidence_at_index`` for it is
    strictly less than ``reserved_per_hanoi``.

    Args:
        reserved_per_hanoi: Incidence threshold ``n``.
        cur_time: Last index ``T`` considered (inclusive).

    Returns:
        The set of retained indices. The closed-form constructions in this
        notebook are checked for equality against this result.
    """
    retained: typing.Set[int] = set()
    for index in range(cur_time + 1):
        incidence = hanoi.get_hanoi_value_incidence_at_index(index)
        if incidence < reserved_per_hanoi:
            retained.add(index)
    return retained


In [None]:
def construct_retained_abstracted(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Build the retained set from the hanoi cadence/offset helpers.

    For each hanoi value ``h`` in ``0 .. hanoi.get_max_hanoi_value_through_index(cur_time)``,
    enumerates candidates ``cadence(h) * i + offset(h)`` for
    ``i = 0 .. reserved_per_hanoi - 1`` (cadence/offset from
    ``hanoi.get_hanoi_value_index_cadence`` / ``..._offset``) and keeps those
    that do not exceed ``cur_time``.

    Args:
        reserved_per_hanoi: Number of candidates ``n`` generated per hanoi value.
        cur_time: Upper bound ``T`` (inclusive) on retained indices.

    Returns:
        The set of retained indices.
    """
    num_hanoi_values = hanoi.get_max_hanoi_value_through_index(cur_time) + 1
    retained: typing.Set[int] = set()
    for hanoi_value in range(num_hanoi_values):
        for rank in range(reserved_per_hanoi):
            candidate = (
                hanoi.get_hanoi_value_index_cadence(hanoi_value) * rank
                + hanoi.get_hanoi_value_index_offset(hanoi_value)
            )
            if candidate <= cur_time:
                retained.add(candidate)
    return retained


In [None]:
def construct_retained_naive(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Build the retained set with hard-coded powers of two.

    Candidates are ``2 ** (h + 1) * i + 2 ** h - 1`` (a cadence of
    ``2 ** (h + 1)`` and an offset of ``2 ** h - 1``), written here in the
    equivalent odd-multiple form ``2 ** h * (2 * i + 1) - 1``. ``h`` runs over
    ``0 .. hanoi.get_max_hanoi_value_through_index(cur_time)`` and ``i`` over
    ``0 .. reserved_per_hanoi - 1``. Candidates above ``cur_time`` are dropped.

    Args:
        reserved_per_hanoi: Number of candidates ``n`` generated per ``h``.
        cur_time: Upper bound ``T`` (inclusive) on retained indices.

    Returns:
        The set of retained indices.
    """
    max_hanoi_value = hanoi.get_max_hanoi_value_through_index(cur_time)
    candidates = (
        2**h * (2 * i + 1) - 1
        for h in range(max_hanoi_value + 1)
        for i in range(reserved_per_hanoi)
    )
    return {t for t in candidates if t <= cur_time}


In [None]:
def construct_retained_distilled(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Build the retained set as ``{i * 2**h - 1 : 1 <= i <= 2n}``.

    Here ``n = reserved_per_hanoi`` and ``h`` ranges over
    ``0 .. hanoi.get_max_hanoi_value_through_index(cur_time)``. Every
    multiplier ``i`` is enumerated, both even and odd. Coincident values
    (e.g. ``i=2, h=0`` and ``i=1, h=1``) simply collapse in the set.
    Only candidates ``<= cur_time`` are kept.

    Args:
        reserved_per_hanoi: ``n``; multipliers run from 1 through ``2 * n``.
        cur_time: Upper bound ``T`` (inclusive) on retained indices.

    Returns:
        The set of retained indices.
    """
    max_hanoi_value = hanoi.get_max_hanoi_value_through_index(cur_time)
    # For integers: i * 2**h - 1 <= cur_time  <=>  i * 2**h <= cur_time + 1.
    limit = cur_time + 1
    retained: typing.Set[int] = set()
    for h in range(max_hanoi_value + 1):
        scale = 2**h
        retained.update(
            multiplier * scale - 1
            for multiplier in range(1, 2 * reserved_per_hanoi + 1)
            if multiplier * scale <= limit
        )
    return retained


## Visualize Constructed Sets


In [None]:
def plot_retained(
    retained: typing.Set[int],
    T: int,
    ax: typing.Optional[plt.Axes] = None,
) -> plt.Axes:
    """Draw a one-row rug plot marking each retained time point.

    Args:
        retained: Time points to mark.
        T: Current time; the x-axis spans ``[-1, T + 1]``.
        ax: Axes to draw on. When omitted, a new 7x1-inch figure is created,
            so repeated calls never draw on top of whatever axes happen to
            be "current" in pyplot's global state.

    Returns:
        The axes holding the plot.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 1))
    # Sorted only for a deterministic draw order; the set itself is unordered.
    sns.rugplot(sorted(retained), height=1.0, ax=ax)
    ax.yaxis.set_visible(False)
    # Dotted reference line at t = 0.
    ax.axvline(x=0, color="black", linewidth=2, ls=":")
    ax.set_xlim(-1, T + 1)
    return ax


In [None]:
# Plot and list the retained set for a few hand-picked
# (reserved_per_hanoi, cur_time) combinations.
for n, T in (
    (1, 87),
    (5, 128),
    (3, 17),
    (4, 70),
):
    retained = filter_retained(n, T)
    plot_retained(retained, T)
    plt.show()
    # Summary line, sorted members, then a blank separator line.
    print(
        f"n={n}, T={T}, len(retained)={len(retained)}",
        sorted(retained),
        "",
        sep="\n",
    )


## Test Set Construction Equivalence


In [None]:
# Exhaustively check that every closed-form construction reproduces the
# brute-force reference set over a grid of (reserved_per_hanoi, cur_time).
NUM_RESERVED_PER_HANOI = 100  # n ranges over 0 .. NUM_RESERVED_PER_HANOI - 1
NUM_CUR_TIME = 100  # T ranges over 0 .. NUM_CUR_TIME - 1

for n, T in tqdm(
    it.product(range(NUM_RESERVED_PER_HANOI), range(NUM_CUR_TIME)),
    total=NUM_RESERVED_PER_HANOI * NUM_CUR_TIME,
):
    expected = filter_retained(n, T)
    for construct in (
        construct_retained_abstracted,
        construct_retained_naive,
        construct_retained_distilled,
    ):
        actual = construct(n, T)
        # Report exactly which construction and parameters disagree, and how.
        assert actual == expected, (
            f"{construct.__name__}(n={n}, T={T}) disagrees with "
            f"filter_retained: missing={sorted(expected - actual)}, "
            f"extra={sorted(actual - expected)}"
        )
