In [None]:
# Overall compute budget for the sweep (presumably in FLOPs, like the
# compute scales in the plan below -- TODO confirm the unit).
# NOTE(review): not referenced anywhere in the visible part of this cell.
budget = 2e18

# High-level idea:
# - Assume batch_size of 128 (be more rigorous about this)
# - Find an optimal configuration (N_opt, LR_opt) at a small scale (a bunch of runs at 1e15)
# - Use that to predict the optimal configuration at the next scale
# - Refine the scaling law
# - Repeat until we hit our final scale (3e17)
# - Fit a scaling law between 3e16 and 3e17
# - Extrapolate to 1e19

# Want to find a and b s.t. N_opt(C) = a * C^b, and predict N_opt for 1e19

# Idea:
# 1. Find N_opt at the smallest scale, starting from the Hoffmann et al. (Chinchilla) estimate
#     a. Initial guess: N_opt(C_1) = 1.018e-01 * C_1^0.5 (so a = 1.018e-01, b = 0.5)
#     b. Train for 6 log-spaced N_1j points around this (it's cheap, and this is our first anchor)
#          i. If it doesn't look like the minimum is in this range, expand the range and try again
#     c. Choose a' = N_opt(C_1) / (C_1^b); keep b fixed until we can fit a new scaling law
# 2. Find N_opt at next scale
#     a. Use N_opt(C_2) = a' * C_2^b to get initial guess
#     b. Train for 6 log-spaced N_2j points around this
# 3. Then, iteratively:
#     a. Predict N_opt for the next slice (compute scale)
#     b. Explore around that point (log-spaced N_ij values?)
#     c. Fit scaling law going back 1 OOM

# Compute allotted to each scale of the sweep. Each entry is written as
# (multiplier) * (compute scale); the multiplier looks like a per-scale run
# count -- TODO confirm against the plan above.
slices = [
    10 * 1e15,
    5 * 3e15,
    5 * 6e15,
    5 * 1e16,
    5 * 3e16,
    5 * 6e16,
    4 * 1e17,
    4 * 3e17,
]

# Uniform multiplier applied to every slice.
n_per_slice = 1

# Scale the summed compute numerically. `n_per_slice * slices` would repeat
# the list instead of multiplying its entries: that only gives the right total
# for a non-negative integer multiplier and raises TypeError for a fractional one.
total_compute = n_per_slice * sum(slices)

print(total_compute)
