# 02 - Scaling law predictions

Use a Chinchilla-style scaling law (Hoffmann et al., 2022) to estimate cross-entropy loss, perplexity, and a rough per-token "accuracy" proxy for our ~36M-parameter model (35,763,840 parameters in the ParrotLLM config).

**Formula**

We assume the loss follows `L(N, D) = L_∞ + (N / N₀)^-α + (D / D₀)^-β`.
- `N` = model parameters
- `D` = tokens seen during training
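- `L_∞ = 1.69`, `α = 0.34`, `β = 0.28` match the parametric fit reported by Hoffmann et al. (their `E`, `α`, `β`). `N₀ = 6.7e3` and `D₀ = 1.8e5` do *not* come from that paper: it writes the law as `E + A / N^α + B / D^β` with `A = 406.4` and `B = 410.7`, which corresponds to `N₀ = A^(1/α) ≈ 4.7e7` and `D₀ = B^(1/β) ≈ 2.2e9`. The much smaller reference scales used here make every prediction below far more optimistic than the published fit.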
- Perplexity = `exp(L)` (with `L` in nats), and a rough per-token "accuracy" proxy = `exp(-L) = 1 / perplexity`
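Hoffmann et al. (2022) report their parametric fit as `L(N, D) = E + A / N^α + B / D^β` with `E = 1.69`, `A = 406.4`, `B = 410.7`, `α = 0.34`, `β = 0.28`. Because `A / N^α = (N / A^(1/α))^-α`, the reference-scale form above reproduces the paper only when `N₀ = A^(1/α)` and `D₀ = B^(1/β)`. The minimal sketch below computes those scales and compares both parameterizations at the Chinchilla-optimal token count for our model; the helper names `loss_ab_form` and `loss_ref_form` are illustrative, not part of the notebook. Keep in mind that the paper's losses are measured on its own training data and tokenizer, so neither prediction transfers exactly to our setup.

```python
# Published parametric fit from Hoffmann et al. (2022), written as
# L(N, D) = E + A / N**alpha + B / D**beta.
E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28

# A / N**alpha == (N / N0)**(-alpha) exactly when N0 = A**(1 / alpha);
# likewise B / D**beta == (D / D0)**(-beta) when D0 = B**(1 / beta).
N0 = A ** (1 / ALPHA)
D0 = B ** (1 / BETA)


def loss_ab_form(n: float, d: float) -> float:
    """Loss in the paper's A/B form (illustrative helper)."""
    return E + A / n**ALPHA + B / d**BETA


def loss_ref_form(n: float, d: float, n0: float, d0: float) -> float:
    """Loss in the notebook's reference-scale form (illustrative helper)."""
    return E + (n / n0) ** (-ALPHA) + (d / d0) ** (-BETA)


n = 35_763_840   # ParrotLLM parameter count
d = n * 20       # Chinchilla-style 20 tokens per parameter
print(f"N0 = {N0:.3g}, D0 = {D0:.3g}")
print(f"published A/B fit:        {loss_ab_form(n, d):.2f}")
print(f"notebook N0=6.7e3, D0=1.8e5: {loss_ref_form(n, d, 6.7e3, 1.8e5):.2f}")
```

With the published `A` and `B`, the Chinchilla-optimal prediction comes out at roughly 4.15 nats instead of 1.84, so the choice of `N₀` and `D₀` drives most of the optimism in the tables that follow.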

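> The "accuracy" here is `exp(-L)`, the geometric mean of the probability the model assigns to the correct next token (equivalently `1 / perplexity`). By Jensen's inequality it is at most the expected (arithmetic-mean) probability, and it is neither top-1 token accuracy nor a sequence-level metric, so treat it only as an aid to intuition.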
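A quick numeric check with four made-up per-token probabilities (illustrative values, not model outputs) shows how these quantities relate: `exp(-loss)` equals the geometric mean of the probabilities and `1 / perplexity`, and it is smaller than their arithmetic mean, which is the expected probability of the correct token.

```python
from math import exp, log

# Hypothetical probabilities a model assigns to the correct next token
# at four positions (illustrative values, not from a real run).
probs = [0.5, 0.25, 0.8, 0.1]

loss = -sum(log(p) for p in probs) / len(probs)  # mean cross-entropy in nats
perplexity = exp(loss)
proxy = exp(-loss)                                # the notebook's "token_accuracy"
geo_mean = exp(sum(log(p) for p in probs) / len(probs))
arith_mean = sum(probs) / len(probs)              # expected probability of the correct token

print(f"loss={loss:.3f}  ppl={perplexity:.3f}  exp(-loss)={proxy:.3f}")
print(f"geometric mean={geo_mean:.3f}  arithmetic mean={arith_mean:.3f}")
```

Here the product of the probabilities is 0.01, so the geometric mean is `0.01^(1/4) ≈ 0.316`, while the arithmetic mean is 0.4125.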

In [4]:
from dataclasses import dataclass
from math import exp


def format_number(value: float) -> str:
    """Pretty-print helper for large numbers."""
    if value >= 1e9:
        return f"{value/1e9:.2f}B"
    if value >= 1e6:
        return f"{value/1e6:.2f}M"
    if value >= 1e3:
        return f"{value/1e3:.2f}K"
    return f"{value:.0f}"


@dataclass
class ChinchillaScalingLaw:
    alpha: float = 0.34
    beta: float = 0.28
    loss_floor: float = 1.69
    n_ref: float = 6.7e3
    d_ref: float = 1.8e5

    def predict(self, params: float, tokens: float) -> dict:
        if params <= 0 or tokens <= 0:
            raise ValueError("params and tokens must be > 0")
        n_term = (params / self.n_ref) ** (-self.alpha)
        d_term = (tokens / self.d_ref) ** (-self.beta)
        loss = self.loss_floor + n_term + d_term
        return {
            "loss": loss,
            "perplexity": exp(loss),
            "token_accuracy": exp(-loss),
        }


def summarize_predictions(params: float, token_scenarios: dict[str, float]) -> list[dict]:
    law = ChinchillaScalingLaw()
    rows = []
    for label, tokens in token_scenarios.items():
        pred = law.predict(params, tokens)
        rows.append({
            "scenario": label,
            "tokens": tokens,
            "loss": pred["loss"],
            "perplexity": pred["perplexity"],
            "token_accuracy": pred["token_accuracy"],
        })
    return rows


In [5]:
PARAM_COUNT = 35_763_840  # ParrotLLM config
SECONDS_PER_RUN = 23 * 3600  # compute budget is <=24h; we plan for 23h of training
TOKEN_SCENARIOS = {
    "Chinchilla optimal (20 tokens/param)": PARAM_COUNT * 20,
    # Throughput scenarios: estimated training throughput (tokens/s) x run length
    "Conservative 23h throughput": 50_000 * SECONDS_PER_RUN,   # 50k tokens/s
    "Moderate 23h throughput": 100_000 * SECONDS_PER_RUN,      # 100k tokens/s
    "Optimistic 23h throughput": 150_000 * SECONDS_PER_RUN,    # 150k tokens/s
}

rows = summarize_predictions(PARAM_COUNT, TOKEN_SCENARIOS)

header = f"{'Scenario':40s} | {'Tokens':>12s} | {'Loss':>6s} | {'PPL':>8s} | {'Token Acc.':>11s}"
print(header)
print("-" * len(header))
for row in rows:
    # Width specs match the header so the columns line up.
    print(
        f"{row['scenario']:40s} | "
        f"{format_number(row['tokens']):>12s} | "
        f"{row['loss']:>6.2f} | "
        f"{row['perplexity']:>8.2f} | "
        f"{row['token_accuracy']:>11.3f}"
    )


Scenario                                 |       Tokens |   Loss |      PPL |  Token Acc.
-----------------------------------------------------------------------------------------
Chinchilla optimal (20 tokens/param)     |      715.28M |   1.84 |     6.31 |       0.158
Conservative 23h throughput              |        4.14B |   1.80 |     6.07 |       0.165
Moderate 23h throughput                  |        8.28B |   1.79 |     6.01 |       0.166
Optimistic 23h throughput                |       12.42B |   1.79 |     5.98 |       0.167

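The throughput scenarios can also be read as tokens per parameter and total training compute. The sketch below assumes the scenario rates are sustained training throughput in tokens per second and uses the common `C ≈ 6·N·D` approximation for training FLOPs; the `results` dictionary and its keys are illustrative names, not part of the notebook.

```python
PARAM_COUNT = 35_763_840
SECONDS_PER_RUN = 23 * 3600

# Assumed sustained training throughputs in tokens per second
# (same values as the scenarios above; none of them are measured).
throughputs = {"conservative": 50_000, "moderate": 100_000, "optimistic": 150_000}

results = {}
for label, tok_per_s in throughputs.items():
    tokens = tok_per_s * SECONDS_PER_RUN
    results[label] = {
        "tokens": tokens,
        "tokens_per_param": tokens / PARAM_COUNT,   # data seen per parameter
        "train_flops": 6 * PARAM_COUNT * tokens,    # C ≈ 6·N·D approximation
    }

for label, r in results.items():
    print(
        f"{label:12s} {r['tokens'] / 1e9:6.2f}B tokens  "
        f"{r['tokens_per_param']:6.1f} tokens/param  {r['train_flops']:.2e} FLOPs"
    )
```

All three throughput scenarios land far above the ~20 tokens/param Chinchilla heuristic (roughly 116, 232, and 347 tokens per parameter), so the 23h runs are heavily data-rich for a model of this size.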

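To see where each predicted loss comes from, the sketch below splits it into the floor `L_∞`, the model-size term, and the data term, using the same constants as `ChinchillaScalingLaw`. It re-declares those constants so it runs on its own; the scenario labels and the `breakdown` dictionary are illustrative.

```python
# Same constants as ChinchillaScalingLaw in this notebook.
L_INF, ALPHA, BETA = 1.69, 0.34, 0.28
N0, D0 = 6.7e3, 1.8e5
PARAMS = 35_763_840

scenarios = {
    "chinchilla_20x": PARAMS * 20,
    "conservative": 50_000 * 23 * 3600,
    "moderate": 100_000 * 23 * 3600,
    "optimistic": 150_000 * 23 * 3600,
}

n_term = (PARAMS / N0) ** (-ALPHA)  # depends only on model size, so it is shared
breakdown = {}
for label, tokens in scenarios.items():
    d_term = (tokens / D0) ** (-BETA)
    loss = L_INF + n_term + d_term
    breakdown[label] = (n_term, d_term, loss)
    print(f"{label:>14s}: model term {n_term:.4f}  data term {d_term:.4f}  loss {loss:.4f}")
```

Under these constants nearly all of the predicted loss is the floor: tripling the token budget from 4.14B to 12.42B shrinks the data term by only about 0.016 nats, while the model term stays fixed at about 0.054.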
In [6]:
def predict_for_custom_inputs(params: float, tokens: float) -> None:
    law = ChinchillaScalingLaw()
    pred = law.predict(params, tokens)
    print(f"Params: {params:,.0f} | Tokens: {tokens:,.0f}")
    print(f"Loss: {pred['loss']:.3f}")
    print(f"Perplexity: {pred['perplexity']:.2f}")
    print(f"Token accuracy: {pred['token_accuracy']:.3f}")


# Example: tweak params/tokens here before running the cell
predict_for_custom_inputs(params=40_000_000, tokens=12_000_000_000)


Params: 40,000,000 | Tokens: 12,000,000,000
Loss: 1.787
Perplexity: 5.97
Token accuracy: 0.168
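The custom-input cell runs the law forward. Going the other way, a hypothetical helper `tokens_for_target_loss` (not part of the notebook) can solve the same formula for `D`, giving `D = D₀ · (target − L_∞ − (N / N₀)^-α)^(-1/β)`. This minimal sketch reuses the notebook's constants and returns infinity when the target is unreachable at the given model size.

```python
from math import inf

# Same constants as ChinchillaScalingLaw in this notebook.
L_INF, ALPHA, BETA = 1.69, 0.34, 0.28
N0, D0 = 6.7e3, 1.8e5


def tokens_for_target_loss(params: float, target_loss: float) -> float:
    """Hypothetical helper: tokens needed to reach target_loss at a fixed model size.

    Solves L_INF + (params / N0)**-ALPHA + (tokens / D0)**-BETA == target_loss
    for tokens. Returns inf when the target is at or below the loss this model
    size would reach with unlimited data.
    """
    gap = target_loss - L_INF - (params / N0) ** (-ALPHA)
    if gap <= 0:
        return inf
    return D0 * gap ** (-1 / BETA)


print(f"{tokens_for_target_loss(35_763_840, 1.80):,.0f} tokens for loss 1.80")
print(tokens_for_target_loss(35_763_840, 1.74))  # below this model's limit
```

Under the notebook's constants, a 1.80 target needs about 5.3B tokens, while 1.74 is out of reach for the 35.8M-parameter model because `L_∞` plus the model term already sum to about 1.744.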
