In [None]:
import fractions
import itertools as it
import typing

from nbmetalog import nbmetalog as nbm
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm

from pylib import hanoi


In [None]:
# Emit notebook execution metadata via nbmetalog so saved outputs can be tied
# to the run that produced them (NOTE(review): presumably environment/package
# details — confirm against the nbmetalog documentation).
nbm.print_metadata()


## Define Set Construction


In [None]:
def filter_retained(reserved_per_hanoi: int, cur_time: int) -> typing.Set[int]:
    """Reference construction: test every index in ``[0, cur_time]`` directly.

    Index ``t`` is kept when ``count(h, cur_time) - 1 - incidence(t)`` is below
    ``reserved_per_hanoi``, where ``h`` is the hanoi value at ``t``.
    NOTE(review): assuming ``count`` is the number of occurrences of ``h`` in
    ``[0, cur_time]`` and ``incidence`` is 0-based, this keeps the
    ``reserved_per_hanoi`` most recent occurrences of each hanoi value.
    """
    quota = int(reserved_per_hanoi)  # cast so numpy integer inputs act like int
    now = int(cur_time)

    kept = set()
    for t in range(now + 1):
        hanoi_value = hanoi.get_hanoi_value_at_index(t)
        total_through_now = hanoi.get_incidence_count_of_hanoi_value_through_index(
            hanoi_value, now
        )
        num_later = total_through_now - 1 - hanoi.get_hanoi_value_incidence_at_index(t)
        if num_later < quota:
            kept.add(t)
    return kept


In [None]:
def construct_retained_abstracted(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Build the retained set from the hanoi cadence/offset helpers.

    For every hanoi value ``h`` from 0 up to the maximum reached by
    ``cur_time``, step back ``reserved_per_hanoi`` slots of size
    ``cadence(h)`` from the latest index ``offset(h) + k * cadence(h)`` that is
    ``<= cur_time``. Negative candidates are discarded.
    """
    quota = int(reserved_per_hanoi)  # cast so numpy integer inputs act like int
    now = int(cur_time)

    retained = set()
    for h in range(hanoi.get_max_hanoi_value_through_index(now) + 1):
        cadence = hanoi.get_hanoi_value_index_cadence(h)
        offset = hanoi.get_hanoi_value_index_offset(h)
        latest_slot = (now - offset) // cadence
        for i in range(quota):
            candidate = cadence * (latest_slot - i) + offset
            if candidate >= 0:
                retained.add(candidate)
    return retained


In [None]:
def construct_retained_naive(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Closed-form variant with cadence ``2**(h + 1)`` and offset ``2**h - 1``.

    For each hanoi value ``h`` up to the maximum reached by ``cur_time``,
    collect ``reserved_per_hanoi`` indices ``offset + k * cadence`` counting
    down from the largest one ``<= cur_time``. Negative candidates are
    discarded.
    """
    quota = int(reserved_per_hanoi)  # cast so numpy integer inputs act like int
    now = int(cur_time)

    retained = set()
    for h in range(hanoi.get_max_hanoi_value_through_index(now) + 1):
        cadence = 1 << (h + 1)  # 2 ** (h + 1)
        offset = (1 << h) - 1  # 2 ** h - 1
        latest_slot = (now - offset) // cadence
        for i in range(quota):
            candidate = cadence * (latest_slot - i) + offset
            if candidate >= 0:
                retained.add(candidate)
    return retained


In [None]:
def construct_retained_distilled(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Simplified closed form over indices of the form ``2**h * k - 1``.

    For each ``h`` up to the maximum hanoi value reached by ``cur_time``,
    collect ``2 * reserved_per_hanoi`` such indices counting down from the
    largest one ``<= cur_time``. Negative candidates are discarded.
    """
    quota = int(reserved_per_hanoi)  # cast so numpy integer inputs act like int
    now = int(cur_time)

    retained = set()
    for h in range(hanoi.get_max_hanoi_value_through_index(now) + 1):
        stride = 1 << h  # 2 ** h
        top_slot = (now + 1) // stride
        for i in range(2 * quota):
            candidate = stride * (top_slot - i) - 1
            if candidate >= 0:
                retained.add(candidate)
    return retained


In [None]:
def calc_distilled2naive(h: int, i: int, T: int, n: int) -> int:
    """Re-derive one distilled-form index through the naive (hanoi) form.

    Starting from the distilled candidate ``2**h * ((T + 1) // 2**h - i) - 1``,
    look up its hanoi value ``H`` and rewrite the same index as
    ``2**(H + 1) * (base + shift) + 2**H - 1``. Here ``base`` is the latest
    naive-form slot for ``H`` at time ``T``, and ``shift`` is asserted to be an
    integer in ``(-n, 0]``. That makes the index one of the ``n`` naive-form
    candidates for ``H``.

    All arithmetic uses exact rationals (``fractions.Fraction``) rather than
    float division. The integrality checks therefore cannot be fooled by
    rounding, large ``T`` stays exact, and the result is an ``int`` as
    annotated.

    Parameters
    ----------
    h : int
        Distilled-form level (power-of-two stride exponent).
    i : int
        Distilled-form number of strides stepped back from the latest slot.
    T : int
        Current time.
    n : int
        reserved_per_hanoi; upper bound asserted on the naive-form step.

    Returns
    -------
    int
        The re-derived index, or -1 if the distilled candidate lies outside
        ``[0, T]``.

    Raises
    ------
    AssertionError
        If any invariant of the derivation fails.
    """
    distilled = 2**h * ((T + 1) // 2**h - i) - 1
    if distilled < 0 or distilled > T:
        return -1

    H = hanoi.get_hanoi_value_at_index(distilled)
    assert H >= h

    # Latest naive-form slot for hanoi value H at time T.
    base = (T - 2**H + 1) // 2 ** (H + 1)
    # Z = frac(((T - 2**H + 1) // 2**h) / 2**(H + 1 - h)), computed exactly.
    Z = fractions.Fraction((T - 2**H + 1) // 2**h, 2 ** (H + 1 - h)) - base
    assert 0 <= Z < 1

    shift = Z - fractions.Fraction(i, 2 ** (H - h + 1))
    assert shift.denominator == 1  # must be a whole number of level-H slots
    assert 0 <= -shift < n

    naive = 2 ** (H + 1) * (base + shift) + 2**H - 1
    return int(naive)


def construct_retained_distilled2naive(
    reserved_per_hanoi: int, cur_time: int
) -> typing.Set[int]:
    """Distilled construction with every candidate re-derived via the naive form.

    Evaluates ``calc_distilled2naive`` over the same ``(h, i)`` grid as the
    distilled construction and keeps results in ``[0, cur_time]``. That bound
    also drops the -1 sentinel for out-of-range candidates.
    """
    quota = int(reserved_per_hanoi)  # cast so numpy integer inputs act like int
    now = int(cur_time)

    candidates = (
        calc_distilled2naive(h, i, now, quota)
        for h in range(hanoi.get_max_hanoi_value_through_index(now) + 1)
        for i in range(2 * quota)
    )
    return {t for t in candidates if 0 <= t <= now}


## Visualize Constructed Sets


In [None]:
def plot_retained(
    retained: typing.Set[int], T: int, ax: typing.Optional[plt.Axes] = None
) -> plt.Axes:
    """Draw the retained indices as a rug plot with a dotted marker at ``T``.

    Parameters
    ----------
    retained : set of int
        Indices to mark along the x-axis.
    T : int
        Current time; drawn as a dotted vertical line and used to set the
        x-limits to ``[-1, T + 1]``.
    ax : matplotlib Axes, optional
        Axes to draw on. If omitted, a fresh 7x1-inch figure is created. This
        means repeated calls never stack onto pyplot's implicit "current"
        axes, and never resize a figure the caller owns.

    Returns
    -------
    matplotlib Axes
        The axes containing the plot.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 1))
    # Pass ``x=``/``ax=`` explicitly: rugplot's first positional parameter
    # differs across seaborn versions.
    sns.rugplot(x=list(retained), height=1.0, ax=ax)
    ax.yaxis.set_visible(False)
    ax.axvline(x=T, color="black", linewidth=2, ls=":")
    ax.set_xlim(-1, T + 1)
    return ax


In [None]:
# Show the distilled construction for a few (n, T) pairs: a rug plot of the
# retained indices, then their count and sorted values.
for n, T in ((1, 87), (5, 128), (3, 17), (4, 70)):
    retained = construct_retained_distilled(n, T)
    plot_retained(retained, T)
    plt.show()
    print(
        f"n={n}, T={T}, len(retained)={len(retained)}",
        sorted(retained),
        "",
        sep="\n",
    )


## Test Set Construction Equivalence


In [None]:
# Cross-check all five set constructions: first an exhaustive grid of small
# (n, T), then randomly sampled larger cases. Any mismatch raises with the
# offending (n, T) in the message.
GRID_SIZE = 100  # n and T each range over [0, GRID_SIZE)
NUM_RANDOM_CASES = 2000
RANDOM_N_HIGH = 500  # random n drawn from [0, RANDOM_N_HIGH)
RANDOM_T_RANGE = (100, 10001)  # random T drawn from [100, 10001)

# Draw n and T from ONE seeded generator. Two separate RandomState(seed=1)
# instances replay the same underlying bit stream. That correlates the n and
# T draws instead of sampling them independently.
equivalence_rng = np.random.RandomState(seed=1)
random_cases = zip(
    equivalence_rng.randint(RANDOM_N_HIGH, size=NUM_RANDOM_CASES),
    equivalence_rng.randint(*RANDOM_T_RANGE, size=NUM_RANDOM_CASES),
)

for n, T in tqdm(
    it.chain(it.product(range(GRID_SIZE), range(GRID_SIZE)), random_cases),
    total=GRID_SIZE * GRID_SIZE + NUM_RANDOM_CASES,
):
    assert (
        filter_retained(n, T)
        == construct_retained_abstracted(n, T)
        == construct_retained_naive(n, T)
        == construct_retained_distilled(n, T)
        == construct_retained_distilled2naive(n, T)
    ), f"set constructions disagree for n={n}, T={T}"
