# Hyperbatch Performance Analysis

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# this sets up the Matplotlib interactive windows:
%matplotlib widget


`hyperbatch_efficiency` sums the batch size calculations from the modified samplers
and returns `total_samples_for_regular_batch / total_samples_for_hyperbatch`.

In [None]:
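# Assert that the schedule:
# - is an Iterable[tuple[current_size, appended_size | None, step_index]]
# - has consecutive step indices starting at 0, i.e.
#   list(map(lambda x: x[-1], schedule)) == list(range(len(schedule)))
# - reports, at every step, a current_size equal to the running size, which starts
#   at 1 and is doubled when appended_size is None, else grown by appended_size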
def batch_doubling_schedule_validate_and_final_size(schedule):
    """Replay a batch schedule, check it for consistency and return its final size.

    Each entry of ``schedule`` is a ``(expected_size, appended_size, step_index)``
    tuple. The running batch size starts at 1; an ``appended_size`` of ``None``
    doubles it, any other value is added to it.

    Raises:
        AssertionError: if a step index is not the entry's position in the
            schedule, or if an entry's expected size differs from the replayed
            running size.

    Returns:
        The running batch size after the last entry (1 for an empty schedule).
    """
    running_size = 1
    for position, (claimed_size, delta, step_index) in enumerate(schedule):
        # Step indices must count up from zero without gaps.
        assert position == step_index, f"current_index has unexpected value: {position} != {step_index}"
        # None is the "double the batch" marker; anything else is an increment.
        running_size = running_size * 2 if delta is None else running_size + delta
        assert claimed_size == running_size, f"expected_current_size != current_size: {claimed_size} != {running_size}"

    return running_size

# The doubling pattern has shape:
# d steps, 2x, d steps, 2x, .., 2x, d steps, leftover/2x, d steps
# K 2x’s/leftovers
# K*d + d = (K+1)*d = steps
# steps/(K + 1) = d
# zs = [0] * d
# If leftovers
# (zs + [None]) * (K-1) + zs + leftovers + zs
# Else
# (zs + [None]) * K + zs
# Note: this function asserts that
# - batch_doubling_schedule_validate_and_final_size's assertions hold
# - batch_doubling_schedule_validate_and_final_size(output) == batch_size
def batch_doubling_schedule(batch_size, num_steps):
    """Build the per-step batch-size schedule that grows a batch from 1 to ``batch_size``.

    The batch is doubled ``floor(log2(batch_size))`` times, with the doublings
    spread evenly over the steps, and the part of ``batch_size`` above the
    largest power of two (the "leftover") is appended in one final increment.
    Any steps remaining after that run at the full batch size.

    Args:
        batch_size: target batch size; must be a positive integer.
        num_steps: number of sampler steps to schedule.

    Returns:
        A list of ``num_steps`` tuples ``(current_batch_size, appended_size, step_index)``.
        ``appended_size`` is ``None`` on a doubling step, otherwise the number of
        samples added on that step (0 when nothing is added). When ``num_steps``
        is not greater than ``floor(log2(batch_size))`` a message is printed and
        a constant schedule ``(batch_size, 0, step_index)`` is returned instead.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        AssertionError: if the generated schedule fails validation, does not end
            at ``batch_size``, or does not contain ``num_steps`` entries.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # bit_length() is floor(log2(n)) + 1, so the number of doublings is one less.
    batch_size_log2 = batch_size.bit_length() - 1
    # Samples beyond the largest power of two <= batch_size (0 for exact powers of two).
    batch_size_leftover = batch_size - (1 << batch_size_log2)

    if num_steps <= batch_size_log2:
        # Not enough steps to fit every doubling: run every step at the full batch size.
        print("The number of steps must be greater than log2(batch_size) to use a Hyperbatch scheduler: disabling Hyperbatch functionality.")
        schedule = [(batch_size, 0, i) for i in range(num_steps)]
        assert len(schedule) == num_steps, f"len(schedule) != num_steps: {len(schedule)} != {num_steps}: {batch_size_log2} {schedule}"
        return schedule

    # Each doubling closes a group of `substep_length` steps; the leftover
    # increment closes one more group after the last doubling.
    substep_length = num_steps // (batch_size_log2 + 1)
    substeps = [0] * (substep_length - 1)
    increments = (substeps + [None]) * batch_size_log2 + substeps + [batch_size_leftover]
    # Steps that do not divide evenly into groups run at the final batch size.
    increments += [0] * (num_steps - len(increments))

    schedule = []
    current_batch_size = 1
    for i, appended_size in enumerate(increments):
        if appended_size is None:
            current_batch_size *= 2
        else:
            current_batch_size += appended_size
        schedule.append((current_batch_size, appended_size, i))

    final_size = batch_doubling_schedule_validate_and_final_size(schedule)
    assert batch_size == final_size, f"batch_size not equal to current_size: {batch_size} != {final_size} \n {schedule}"
    assert len(schedule) == num_steps, f"len(schedule) != num_steps: {len(schedule)} != {num_steps}: {schedule}"
    return schedule

# Example: show the schedule produced for a batch size of 8 over 7 steps.
print(batch_doubling_schedule(8, 7))

# Smoke test: build a schedule for every (batch_size, num_steps) pair in 1..9 so
# that the assertions inside batch_doubling_schedule run on each combination.
# Pairs with too few steps for their batch size take the fallback path, which
# prints a message instead of building a doubling schedule.
for i in range(1, 10):
    for j in range(1, 10):
        batch_doubling_schedule(i, j)

def hyperbatch_efficiency(batch_size, num_steps):
    """Return the ratio of regular-batch work to Hyperbatch work.

    Regular work is ``batch_size * num_steps`` (every step runs the full batch).
    Hyperbatch work is the sum of the per-step batch sizes given by
    ``batch_doubling_schedule(batch_size, num_steps)``.
    """
    schedule = batch_doubling_schedule(batch_size, num_steps)
    # Only the first tuple field (the batch size at that step) contributes to the work.
    hyperbatch_image_steps = sum(step_size for step_size, _, _ in schedule)
    regular_image_steps = float(batch_size * num_steps)
    return regular_image_steps / hyperbatch_image_steps

# Close every figure that is currently open before drawing a new one.
plt.close('all')

max_batch_size = 128
# NOTE(review): range() stops before max_batch_size, so 128 itself is not plotted — confirm intent.
batch_size_range = range(1, max_batch_size)

# Four stacked panels that share the batch-size axis, one per step count.
fig, axs = plt.subplots(4, 1, sharex=True, constrained_layout=True)
fig.set_size_inches(9.25, 5.25)
fig.suptitle('Hyperbatch efficiency at different step counts')

for i, ax in enumerate(axs):
    # Panels cover 20, 30, 40 and 50 steps.
    num_steps = 20 + 10 * i
    ax.set_title(f"{num_steps} steps", loc='left')
    ax.plot(
        batch_size_range,
        [hyperbatch_efficiency(size, num_steps) for size in batch_size_range],
    )
    ax.set_xlabel('Batch size')
    # xticks = range(1, max_batch_size, 10)
    # ax.set_xticks(xticks)
    ax.set_ylabel('Speedup')
