In [1]:
import itertools
import time
import numpy as np

# Vocabulary size used by the embedding-parameter terms in this file
# (the `2 * VOCAB_SIZE * d` contributions to the total parameter count).
VOCAB_SIZE = 32000


def non_embed_params(d_model: int, num_layers: int) -> int:
    """Return the non-embedding parameter estimate `12 * num_layers * d_model**2`."""
    layer_factor = 12 * num_layers
    width_squared = d_model**2
    return layer_factor * width_squared


def embed_params(d_model: int) -> int:
    """Return the embedding parameter estimate `2 * VOCAB_SIZE * d_model`.

    NOTE(review): the factor of 2 presumably counts separate input and output
    embedding matrices — confirm against the model definition.
    """
    doubled_vocab = 2 * VOCAB_SIZE
    return doubled_vocab * d_model


def total_params(d_model: int, num_layers: int) -> int:
    """Return the sum of the non-embedding and embedding parameter estimates."""
    body = non_embed_params(d_model, num_layers)
    embeddings = embed_params(d_model)
    return body + embeddings


def chinchilla_n_for_c(c: int, tokens_per_param: float = 20) -> int:
    """Return the parameter count N implied by compute `c` at a fixed tokens-per-parameter ratio.

    Solves c = 6 * N * D with D = tokens_per_param * N, i.e.
    N = sqrt(c / (6 * tokens_per_param)), and truncates the result to an int.
    """
    n_squared = c / (6 * tokens_per_param)
    return int(np.sqrt(n_squared))


def get_shape_for_n(
    n_target: int, n_is_total: bool = False
) -> tuple[int, int, int, int, int, float]:
    """
    Find the transformer shape (d, L, h) whose parameter count is closest to n_target.

    If `n_is_total` is True, `n_target` should represent the *total* number of parameters.
    Otherwise, `n_target` should represent the number of *non-embedding* parameters.

    If `n_is_total` is True:
        Solve 12Ld^2 + 2Vd = n, keeping L in [2, 24], d in [64, 1024], h in [2, 16]
    Otherwise:
        Solve 12Ld^2 = n, keeping L in [2, 24], d in [64, 1024], h in [2, 16]

    Candidates are searched on a grid (d in steps of 32) and filtered by:
      * aspect ratio d / L in [48, 256] when n_target < 1e8, else [64, 256];
      * head dimension d / h being an integer in [16, 128] that is a multiple
        of 16 when n_target < 1e8, else a multiple of 64.

    The parameter count does not depend on h, so for the winning (d, L) the
    smallest h that passes the head-dimension filter is returned. Ties between
    different (d, L) pairs go to the first one in iteration order (d, then L).

    Returns:
        (d, L, h, n_star, err, err_pct), where n_star is the realized parameter
        count, err = |n_star - n_target| and err_pct = 100 * err / n_target.

    Raises:
        ValueError: if n_target is not positive, or if no shape satisfies the
            constraints.
    """
    if n_target <= 0:
        # err_pct divides by n_target; fail early with a clear message instead
        # of a ZeroDivisionError (or a meaningless negative percentage).
        raise ValueError(f"n_target must be positive, got {n_target}")

    # Both thresholds depend only on n_target, so compute them once up front.
    is_small_target = n_target < 1e8
    min_aspect_ratio = 48 if is_small_target else 64
    head_dim_multiple = 16 if is_small_target else 64

    best = None  # (d, L, h, n_star) of the closest candidate seen so far
    best_err = None

    for d, L, h in itertools.product(range(64, 1025, 32), range(2, 25), range(2, 17)):
        if not min_aspect_ratio <= d / L <= 256:
            continue

        # h must divide d exactly; a fractional head dimension is never valid.
        head_dim, leftover = divmod(d, h)
        if leftover or not 16 <= head_dim <= 128 or head_dim % head_dim_multiple:
            continue

        n_star = 12 * L * d**2
        if n_is_total:
            n_star += 2 * VOCAB_SIZE * d

        err = abs(n_star - n_target)
        # Strict "<" keeps the first candidate on ties. Comparing against None
        # explicitly avoids treating a legitimate falsy value as "unset".
        if best_err is None or err < best_err:
            best, best_err = (d, L, h, n_star), err

    if best is None:
        # Not reachable with the current grid and filters; guards against
        # future constraint changes that could leave no valid shape.
        raise ValueError(f"no transformer shape satisfies the constraints for n_target={n_target}")

    d, L, h, n_star = best
    err_pct = best_err * 100 / n_target

    return d, L, h, n_star, best_err, err_pct


# Compute budget for this run (presumably in FLOPs — confirm).
c = 1e15

# Parameter count implied by the budget at 20 tokens per parameter; it is
# treated as a *total* parameter count in the shape search below.
n_guess = chinchilla_n_for_c(c, tokens_per_param=20)

t0 = time.perf_counter()

d, L, h, n_realized, err, err_pct = get_shape_for_n(n_guess, n_is_total=True)
# Non-embedding share of the realized shape, reported alongside the total.
n_realized_non_embed = non_embed_params(d, L)
print(
    f"d: {d}, L: {L}, h: {h}, n_realized: {n_realized:,}, n_realized_non_embed: {n_realized_non_embed:,}, err: {err:,} err_pct: {err_pct:.2f}%"
)

# Measured after the summary print, so the elapsed time includes it.
elapsed = time.perf_counter() - t0
print(f"Time taken: {elapsed:.4f} seconds")

d: 96, L: 2, h: 2, n_realized: 6,365,184, n_realized_non_embed: 221,184, err: 3,478,433 err_pct: 120.50%
Time taken: 0.0014 seconds
