In [2]:
from cs336_scaling.common import (
    get_chinchilla_power_law_n_for_c,
    get_chinchilla_n_for_c,
    pick_candidates_around_n,
    get_shape_given_n,
    print_predicted_shapes,
    get_chinchilla_lr_for_n,
    get_shape_for_n_custom,
)

from cs336_scaling.constants import BATCH_SIZE, FLOPS_BUDGET
from cs336_scaling.training_api import get_loss, sync_api_state, get_total_flops_used
from cs336_scaling.analyze import best_run_where, fit_quadratic, find_optimal_params, plot_runs, get_all_runs

import numpy as np

In [3]:
def print_flops_stats(total_flops_used: float | None = None):
    """Print the FLOPs consumed so far, in absolute terms and as a percentage.

    Args:
        total_flops_used: FLOPs figure to report. When omitted, the value is
            obtained from ``get_total_flops_used()``.

    The percentage is computed relative to ``FLOPS_BUDGET``.
    """
    # Only consult the helper when the caller did not hand us a figure.
    total_flops_used = (
        get_total_flops_used() if total_flops_used is None else total_flops_used
    )
    divider = "-" * 100

    print(f"Total FLOPs Used: {total_flops_used:.2e}")
    print(f"Total FLOPs Used (%): {total_flops_used * 100 / FLOPS_BUDGET:.3f}")
    print(divider)


def print_stats_at_c(
    c: float,
    all_runs: list[dict] | None = None,
    best_n: int | None = None,
    sync_api: bool = False,
    best_lr_per_n: bool = True,
):
    if all_runs is None and sync_api is False:
        raise ValueError("Must provide either all_runs or sync_api")

    if sync_api:
        all_runs, total_flops_used = sync_api_state()

    runs_at_c = [run for run in all_runs if run["train_flops"] == c]

    if len(runs_at_c) == 0:
        print(f"No runs found for C = {c:.0e}")
        return

    runs_to_fit = runs_at_c

    if best_lr_per_n:
        group_by_n = group_by(runs_at_c, "est_n_non_embedding")
        for n in group_by_n:
            group_by_n[n] = sorted(group_by_n[n], key=lambda x: x["loss"])[0]

        runs_to_fit = list(group_by_n.values())

    if best_n is not None:
        runs_to_fit = sorted(runs_at_c, key=lambda x: x["loss"])[:best_n]

    fit_fn = fit_quadratic(runs_to_fit)
    coeffs = fit_fn.coeffs
    best_run = min(runs_at_c, key=lambda x: x["loss"])
    best_run_n = best_run["est_n_non_embedding"]
    best_run_n_total = best_run["est_n_total"]
    chinchilla_lr = get_chinchilla_lr_for_n(best_run_n_total)
    pred_opt_n = find_optimal_params(fit_fn)
    pred_opt_loss = fit_fn(np.log(pred_opt_n))

    print(f"Best for C = {c:.0e}:")

    print(f"- Non-embedding params: {best_run_n:.2e} ({best_run_n} params)")
    print(f"- LR: {best_run['learning_rate']:.3e} (chinchilla: {chinchilla_lr:.3e})")
    print(f"- Loss: {best_run['loss']:.5f}")
    print(f"- Total params: {best_run_n_total:.2e} ({best_run_n_total} params)")
    print(f"- Embedding ratio: {best_run['est_embed_ratio']:.2f}")
    print(
        f"- Tokens: {best_run['est_tokens']:.2e} ({int(best_run['est_tokens'])} tokens)"
    )
    print(f"- Tokens per param (D/N): {best_run['est_tokens_per_param']:.3f}")
    print(f"- Aspect ratio: {best_run['d_model'] / best_run['num_layers']:.2f}")
    print(best_run)
    print("-" * 100)

    print(f"Quadratic fit for all runs at {c:.0e} FLOPs:")
    print(f"{coeffs[0]:.4f}x^2 + {coeffs[1]:.4f}x + {coeffs[2]:.4f}")

    print(f"Pred. optimal params: {pred_opt_n:.2e}")
    print(f"Pred. loss at optimal params: {pred_opt_loss:.5f}")

    print("-" * 100)

In [3]:
def sweep_n_at_c(c: float, n_guess: int, factor: float = 5, n_candidates: int = 5, dry_run: bool = False):
    """Sweep model sizes around ``n_guess`` at compute budget ``c``.

    Candidate sizes come from ``pick_candidates_around_n``. Each one is turned
    into a concrete shape with ``get_shape_given_n`` and trained with the
    learning rate returned by ``get_chinchilla_lr_for_n`` for that shape's
    parameter count.

    Args:
        c: Compute budget passed to ``get_loss`` as ``train_flops``.
        n_guess: Parameter count to centre the candidates on.
        factor: Forwarded to ``pick_candidates_around_n``.
        n_candidates: Forwarded to ``pick_candidates_around_n``.
        dry_run: If true, print what would be run instead of calling
            ``get_loss``. The closing stats, FLOPs report and plot still run.
    """
    n_values = pick_candidates_around_n(
        n_guess, factor=factor, n_candidates=n_candidates, round_to_int=True
    )
    print_predicted_shapes(ns=n_values, c=c)

    for n in n_values:
        d_model, num_layers, num_heads, n_star = get_shape_given_n(n)
        lr = get_chinchilla_lr_for_n(n_star)

        if dry_run:
            print(f"Would run with N={n:.2e}, lr={lr:.3e} at C={c:.0e}")
            continue

        get_loss(
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            batch_size=BATCH_SIZE,
            learning_rate=lr,
            train_flops=c,
        )

    # Re-sync from the API so the report reflects any runs just submitted.
    print_stats_at_c(c, sync_api=True)
    print_flops_stats()
    plot_runs()


def sweep_lr_at_cn(
    c: float,
    n: int,
    lr_guess: float,
    factor: float = 3,
    n_candidates: int = 5,
    dry_run: bool = False,
    lr_min: float = 1e-4,
    lr_max: float = 1e-3,
):
    """Sweep learning rates around ``lr_guess`` for a fixed size and budget.

    Candidates come from ``pick_candidates_around_n`` (without integer
    rounding) and are restricted to ``[lr_min, lr_max]``. The model shape is
    derived once from ``n`` and reused for every candidate.

    Args:
        c: Compute budget passed to ``get_loss`` as ``train_flops``.
        n: Parameter count used to pick the model shape.
        lr_guess: Learning rate to centre the candidates on.
        factor: Forwarded to ``pick_candidates_around_n``.
        n_candidates: Forwarded to ``pick_candidates_around_n``.
        dry_run: If true, print what would be run instead of calling
            ``get_loss``. The closing stats, FLOPs report and plot still run.
        lr_min: Smallest learning rate kept (inclusive).
        lr_max: Largest learning rate kept (inclusive). The defaults are
            presumably the range the training API accepts — TODO confirm.
    """
    lr_candidates = [
        lr
        for lr in pick_candidates_around_n(
            lr_guess, factor=factor, n_candidates=n_candidates, round_to_int=False
        )
        if lr_min <= lr <= lr_max
    ]

    if not lr_candidates:
        # Make an empty sweep visible rather than silently doing nothing.
        print(
            f"No learning-rate candidates in [{lr_min:.1e}, {lr_max:.1e}] "
            f"around lr_guess={lr_guess:.3e}"
        )

    # The shape depends only on n, so compute it once rather than per candidate.
    d_model, num_layers, num_heads, _ = get_shape_given_n(n)

    for lr in lr_candidates:
        if dry_run:
            print(f"Would run with lr={lr:.3e}, N={n:.2e} at C={c:.0e}")
            continue

        get_loss(
            d_model=d_model,
            num_layers=num_layers,
            num_heads=num_heads,
            batch_size=BATCH_SIZE,
            learning_rate=lr,
            train_flops=c,
        )

    # Re-sync from the API so the report reflects any runs just submitted.
    print_stats_at_c(c, sync_api=True)
    print_flops_stats()
    plot_runs()

In [5]:
# INITIAL RUNS CENTERED AROUND HOFFMAN/2 GUESS (running one at a time to save FLOPs)
# This cell builds two candidate grids for C_1, then reports on whatever runs
# already exist at C_1. The get_loss call itself is currently commented out.
C_1 = 1e16

# --- Grid 1: centred on 5% of the power-law prediction for C_1 ---
N_1_guess = int(
    get_chinchilla_power_law_n_for_c(C_1) * 0.05
)  # scaled down chinchilla prediction
# sweep_n_at_c(C_1, N_1_guess, factor=5, n_candidates=5, dry_run=True)

print(f"N_1_guess: {N_1_guess:.2e} ({N_1_guess})")

candidates = pick_candidates_around_n(
    N_1_guess, factor=9, n_candidates=9, round_to_int=True
)

# for n in candidates:

# Ratio between the two largest candidates.
# NOTE(review): `ratio` is never read again in this cell.
ratio = candidates[-1] / candidates[-2]

# NOTE(review): the "index 2" label looks stale — in the recorded output table
# N_1_guess (5.09e+05) appears at idx 4. Confirm before relying on it.
candidate = N_1_guess  # index 2
# candidate = candidates[1] # better
# candidate = candidates[0] # better

print(f"CANDIDATE: {candidate:.2e} ({candidate})")

print_predicted_shapes(ns=candidates, get_shape_fn=get_shape_given_n, c=C_1)

# ================================

print("-" * 100)

# --- Grid 2: centred on the shape picked for a total-parameter target ---
# The target is computed with tokens_per_param=20 and treated as a TOTAL
# parameter count (n_is_total=True).
N_1_total_guess = get_chinchilla_n_for_c(C_1, tokens_per_param=20)
d, L, h, n_star_total, *_ = get_shape_for_n_custom(N_1_total_guess, n_is_total=True)

# NOTE: this overwrites N_1_guess, candidates, ratio and candidate from Grid 1.
N_1_guess = 12 * L * d**2  # Realised non-embedding parameter count
print(f"N_1_guess: {N_1_guess:.2e} ({N_1_guess})")

candidates = pick_candidates_around_n(
    N_1_guess, factor=9, n_candidates=9, round_to_int=True
)

ratio = candidates[-1] / candidates[-2]
candidate = N_1_guess  # index 2 (NOTE(review): label copied from Grid 1 — verify)
# candidate = candidates[1] # better
# candidate = candidates[0] # better

print(f"CANDIDATE: {candidate:.2e} ({candidate})")

print_predicted_shapes(ns=candidates, get_shape_fn=get_shape_given_n, c=C_1)

# ================================

# --- Single run for `candidate` (disabled) ---
# Uncomment the shape/lr lines and the get_loss call below to actually train.

# d, L, h, n_star = get_shape_given_n(candidate)
# lr = get_chinchilla_lr_for_n(n_star)
# bs = BATCH_SIZE

# Placeholder result used while the real get_loss call is commented out.
res = {"loss": 0.0, "total_flops_used": 0}

# res = get_loss(
#     d_model=d,
#     num_layers=L,
#     num_heads=h,
#     batch_size=bs,
#     learning_rate=lr,
#     train_flops=C_1,
# )

# Print the raw response when it has no (or a falsy) loss. With the placeholder
# above the loss is 0.0, so this branch always fires.
if not res.get("loss"):
    print(res)

# --- Report on the runs that exist at C_1 ---
all_runs = get_all_runs(sync_api=True)
runs_at_c = [run for run in all_runs if run["train_flops"] == C_1]

# runs_at_c is already filtered to C_1; print_stats_at_c applies the same
# filter again internally.
print_stats_at_c(c=C_1, all_runs=runs_at_c)

print_flops_stats()
# Only plot when at least one run matched C_1.
if len(runs_at_c) > 0:
    plot_runs(runs_at_c)

N_1_guess: 5.09e+05 (509000)
CANDIDATE: 5.09e+05 (509000)
Computed shapes:
  idx	     c	       n	  n_star	     err	err_pct	   tok/n	   d	  L	  d/L	  h	  pred_lr	embed_ratio	   tokens	tok/n_star
    0	 1e+16	5.66e+04	5.53e+04	1.26e+03	  2.23%	  170.42	  48	  2	24.00	  3	 4.09e-04	     0.98	 5.33e+08	 9637.98
    1	 1e+16	9.80e+04	9.83e+04	3.47e+02	  0.35%	   94.74	  64	  2	32.00	  4	 3.86e-04	     0.98	 3.97e+08	 4042.20
    2	 1e+16	1.70e+05	1.54e+05	1.61e+04	  9.47%	   59.93	  80	  2	40.00	  5	 3.69e-04	     0.97	 3.16e+08	 2057.55
    3	 1e+16	2.94e+05	3.01e+05	7.18e+03	  2.44%	   29.88	 112	  2	56.00	  7	 3.45e-04	     0.96	 2.23e+08	  741.20
    4	 1e+16	5.09e+05	4.98e+05	1.13e+04	  2.23%	   17.66	 144	  2	72.00	  9	 3.28e-04	     0.95	 1.72e+08	  344.77
    5	 1e+16	8.82e+05	8.85e+05	3.12e+03	  0.35%	    9.60	 192	  2	96.00	 12	 3.10e-04	     0.93	 1.27e+08	  143.01
    6	 1e+16	1.53e+06	1.57e+06	4.59e+04	  3.00%	    5.17	 256	  2	128.00	  8	 2.92e-04	     0.91	 9.28e+07	   59.01


In [None]:
# Refinement sweep over N at C_1 (factor=2), centred on the best run so far.
# The centre is the non-embedding parameter count of the best recorded run at
# C_1, which beat the quadratic fit's predicted optimum.
# dry_run=True: no training runs are submitted by the sweep.
N_1_opt_so_far = best_run_where("train_flops", C_1)["est_n_non_embedding"]
sweep_n_at_c(
    c=C_1,
    n_guess=N_1_opt_so_far,
    factor=2,
    n_candidates=5,
    dry_run=True,
)

In [None]:
# Learning-rate exploration at fixed N for C_1.
# N is pinned to the new best run at C_1 (better than the quadratic fit's
# predicted optimum); 884736 == 12 * 2 * 192**2.
# The sweep is centred on the reference LR for that N.
# dry_run=True: no training runs are submitted by the sweep.
N_1_opt_so_far = 884_736
lr_guess = get_chinchilla_lr_for_n(N_1_opt_so_far)
sweep_lr_at_cn(
    c=C_1,
    n=N_1_opt_so_far,
    lr_guess=lr_guess,
    factor=4,
    n_candidates=11,
    dry_run=True,
)

In [None]:
# Learning-rate refinement at fixed N for C_1.
# Narrower sweep (factor=1.5) centred on the best learning rate found so far
# for N_1_opt_so_far, taken from the new best run at C_1 (better than the
# quadratic fit's predicted optimum).
# dry_run=True: no training runs are submitted by the sweep.
N_1_opt_so_far = 884_736
lr_opt_so_far = 5.391e-04
sweep_lr_at_cn(
    c=C_1,
    n=N_1_opt_so_far,
    lr_guess=lr_opt_so_far,
    factor=1.5,
    n_candidates=5,
    dry_run=True,
)