# Fetch the Dataset

In [None]:
import wandb
import pandas as pd

def parse_count(s: "str | float") -> int:
    """Parse a human-readable count such as ``"354.9M"`` or ``"1.7B"`` into an int.

    The text may end in a case-insensitive suffix ``K`` (1e3), ``M`` (1e6),
    ``B`` (1e9) or ``T`` (1e12); surrounding whitespace is ignored. Non-string
    values (e.g. a number pandas already parsed) are converted with ``str()``
    first. Any fractional part left after scaling is truncated toward zero.

    Raises:
        ValueError: if the text is not a number with an optional suffix.
    """
    from decimal import Decimal, InvalidOperation

    multipliers = {
        "K": 1_000,
        "M": 1_000_000,
        "B": 1_000_000_000,
        "T": 1_000_000_000_000,
    }
    text = str(s).strip().upper()
    scale = 1
    if text and text[-1] in multipliers:
        text, scale = text[:-1], multipliers[text[-1]]
    try:
        # Decimal keeps the mantissa exact: float("1.005") * 1000 is
        # 1004.999..., which int() would truncate to 1004 instead of 1005.
        return int(Decimal(text) * scale)
    except InvalidOperation:
        raise ValueError(f"invalid count: {s!r}") from None

api, data = wandb.Api(), []
frontier = pd.read_csv("https://docs.google.com/spreadsheets/d/1sIr9HRwYbUXKzlskUTMorMa2A_cAzDwE0eUJnk-W1dQ/export?format=csv&gid=1059339506")

# Collect one (flops, active_params, total_params, loss) row per finished ablation run.
for run in api.runs("haok/flame-moe", {"group": {"$regex": "ablation"}}):
    if run.state != "finished":
        continue
    # The compute budget is the last "-"-separated token of the group name; parse it
    # here so the "flops" column is numeric rather than a string.
    flops = float(run.group.split("-")[-1])
    loss = run.summary["lm loss validation"]
    num_layers, hidden_size = run.config["num_layers"], run.config["hidden_size"]
    # NOTE(review): architectures are matched on (num_layers, hidden_size) only and the
    # first matching row wins — confirm that pair is unique in the frontier sheet.
    selected = frontier[(frontier["num_layers"] == num_layers) & (frontier["hidden_size"] == hidden_size)]
    if selected.empty:
        raise LookupError(
            f"run {run.name!r}: no frontier row with "
            f"num_layers={num_layers}, hidden_size={hidden_size}"
        )
    row = selected.iloc[0]
    active_params, total_params = parse_count(row["active_params"]), parse_count(row["total_params"])
    data.append((flops, active_params, total_params, loss))

df = pd.DataFrame(data, columns=["flops", "active_params", "total_params", "loss"])
df


# Optimize the Coefficients

In [None]:
import warnings
from itertools import product
from math import exp

import numpy as np
from scipy.optimize import minimize
from tqdm import tqdm

# The multi-start fit below probes extreme initial guesses, so numpy can emit
# RuntimeWarnings (e.g. invalid values in np.log, overflow in **) for many starts.
# Silence only that category: a blanket "ignore" would also hide unrelated warnings
# in every cell executed afterwards in this kernel.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Step 1: Define the scaling law model
def scaling_law(params, N, D):
    """Predict loss as ``E + A / N**alpha + B / D**beta``.

    ``params`` unpacks as ``(E, A, alpha, B, beta)``. ``N`` and ``D`` may be
    scalars or numpy arrays; the arithmetic is elementwise.
    """
    irreducible, model_coef, model_exp, data_coef, data_exp = params
    model_term = model_coef / (N**model_exp)
    data_term = data_coef / (D**data_exp)
    return irreducible + model_term + data_term

# Step 2: Define the loss function
def mse_loss(params, N, D, targets):
    """Return the mean squared error of ``scaling_law(params, N, D)`` against ``targets``."""
    residuals = scaling_law(params, N, D) - targets
    return np.mean(residuals ** 2)

def huber_loss(params, N, D, targets, delta=1e-3):
    """Mean Huber loss of the log-space residual ``log(targets) - log(prediction)``.

    Residuals with magnitude up to ``delta`` are penalised quadratically
    (``0.5 * r**2``), larger ones linearly (``delta * (|r| - delta / 2)``).
    Non-positive predictions make ``np.log`` return NaN, which propagates
    into the returned mean.
    """
    residual = np.log(targets) - np.log(scaling_law(params, N, D))
    magnitude = np.abs(residual)
    quadratic = 0.5 * residual**2
    linear = delta * (magnitude - 0.5 * delta)
    return np.mean(np.where(magnitude <= delta, quadratic, linear))

# Step 3: Prepare your data
N = df["active_params"].values
# Training tokens implied by each compute budget under the C ≈ 6·N·D approximation.
D = (df["flops"].astype(float) / (6 * df["active_params"])).values
targets = df["loss"].values

# Step 4: Set an initial guess
# Multi-start grid over (E, A, alpha, B, beta): 5 * 6 * 5 * 6 * 5 = 4500 starting points.
E_range = [exp(-1), exp(-0.5), exp(0), exp(0.5), exp(1.0)]
A_range = [exp(0), exp(5), exp(10), exp(15), exp(20), exp(25)]
alpha_range = [0, 0.5, 1, 1.5, 2]
B_range = [exp(0), exp(5), exp(10), exp(15), exp(20), exp(25)]
beta_range = [0, 0.5, 1, 1.5, 2]
initial_guesses = list(product(E_range, A_range, alpha_range, B_range, beta_range))

# Step 5: Fit using L-BFGS-B, keeping the successful run with the lowest Huber objective.
# NOTE(review): without `jac`, L-BFGS-B uses forward differences with an absolute step
# `eps` (default 1e-8). For A/B starts near exp(20)-exp(25) that step is below float
# resolution, so those coordinates may barely move; fitting log-parameters would avoid
# this — consider before trusting the coefficients.
best_result = None
lowest_mse = np.inf  # holds the Huber objective despite the name

for guess in tqdm(initial_guesses):
    result = minimize(huber_loss, x0=guess, args=(N, D, targets), method="L-BFGS-B")
    if not result.success:
        continue
    current_mse = huber_loss(result.x, N, D, targets)
    # NaN objectives (e.g. from non-positive predictions) compare False and are skipped.
    if current_mse < lowest_mse:
        lowest_mse = current_mse
        best_result = result

# Step 6: Show the results
# `minimize` returns an OptimizeResult (a dict subclass): test identity, not truthiness.
if best_result is not None:
    E_opt, A_opt, alpha_opt, B_opt, beta_opt = best_result.x
    print(f"E     = {E_opt:.6f}")
    print(f"A     = {A_opt:.6f}")
    print(f"alpha = {alpha_opt:.6f}")
    print(f"B     = {B_opt:.6f}")
    print(f"beta  = {beta_opt:.6f}")
    # With delta=1e-3 the objective is roughly 1e-3 * |log residual|, so fixed-point
    # ".6f" would print mostly zeros; scientific notation keeps the significant digits.
    print(f"Loss  = {lowest_mse:.6e}")
else:
    print("Optimization failed for all initial guesses.")


# Find the Models

In [7]:
import pandas as pd

def parse_count(s: "str | float") -> int:
    """Parse a human-readable count such as ``"354.9M"`` or ``"1.7B"`` into an int.

    The text may end in a case-insensitive suffix ``K`` (1e3), ``M`` (1e6),
    ``B`` (1e9) or ``T`` (1e12); surrounding whitespace is ignored. Non-string
    values (e.g. a number pandas already parsed) are converted with ``str()``
    first. Any fractional part left after scaling is truncated toward zero.

    Raises:
        ValueError: if the text is not a number with an optional suffix.
    """
    from decimal import Decimal, InvalidOperation

    multipliers = {
        "K": 1_000,
        "M": 1_000_000,
        "B": 1_000_000_000,
        "T": 1_000_000_000_000,
    }
    text = str(s).strip().upper()
    scale = 1
    if text and text[-1] in multipliers:
        text, scale = text[:-1], multipliers[text[-1]]
    try:
        # Decimal keeps the mantissa exact: float("1.005") * 1000 is
        # 1004.999..., which int() would truncate to 1004 instead of 1005.
        return int(Decimal(text) * scale)
    except InvalidOperation:
        raise ValueError(f"invalid count: {s!r}") from None

# Fitted (E, A, alpha, B, beta) coefficients from the "Optimize the Coefficients" cell.
MSE_COEFFICIENTS = (2.133594, 65.254366, 0.226854, 485165195.409790, 0.949495)
HUBER_COEFFICIENTS = (2.242056, 148.412995, 0.279724, 3269017.372472, 0.715505)

def scaling_law(N, D, params=HUBER_COEFFICIENTS):
    """Predict loss ``E + A / N**alpha + B / D**beta`` from fitted coefficients.

    Args:
        N: active parameter count(s); a scalar, numpy array or pandas Series.
        D: training token count(s), elementwise-compatible with ``N``.
        params: ``(E, A, alpha, B, beta)``. Defaults to the Huber fit; pass
            ``MSE_COEFFICIENTS`` to use the MSE fit instead.

    NOTE: this rebinds the name ``scaling_law``, which the fitting cell defines
    as ``scaling_law(params, N, D)`` (different argument order). Re-run that
    cell's definitions before fitting again.
    """
    E, A, alpha, B, beta = params
    return E + A / (N**alpha) + B / (D**beta)

# Step 1: Load the candidate configurations and parse their active parameter counts once
candidates = pd.read_csv("https://docs.google.com/spreadsheets/d/1sIr9HRwYbUXKzlskUTMorMa2A_cAzDwE0eUJnk-W1dQ/export?format=csv&gid=599230821")
n_active = candidates["active_params"].apply(parse_count)

# Step 2: Define budgets, each paired with the DCLM scale label it is reported under,
# so the labels can never drift out of sync with the budgets.
# (An earlier sweep used 6e18, 1e19, 3e19, 6e19, 1e20, 3e20, 6e20, 1e21, 3e21, 6e21.)
dclm_scales = [
    ("400M-1x", 2e19),
    ("400M-4x", 8e19),
    ("1B-1x", 2.4e20),
    ("1B-3x", 7.2e20),
    ("3B-1x", 9.4e20),
    ("1B-5x", 1.2e21),
    ("7B-1x", 5.7e21),
    ("7B-2x", 1.1e22),
]
budgets = [budget for _, budget in dclm_scales]

# Step 3: For each budget, spend it all on tokens (D = C / 6N) and keep the candidate
# with the lowest predicted loss (first row wins on ties).
merged = []
for budget in budgets:
    tokens = budget / (6 * n_active)
    scored = candidates.assign(
        FLOPs=budget,
        tokens=tokens,
        predicted_loss=scaling_law(n_active, tokens),
    )
    merged.append(scored.nsmallest(1, "predicted_loss"))

# NOTE(review): this rebinds `df`, replacing the ablation runs loaded in the first cell.
df = pd.concat(merged, ignore_index=True)
df["DCLM scale"] = [label for label, _ in dclm_scales]
leading = ["DCLM scale", "FLOPs"]
df = df[leading + [col for col in df.columns if col not in leading]]
df


Unnamed: 0,DCLM scale,FLOPs,num_layers,padded_vocab_size,hidden_size,ffn_hidden_size,moe_ffn_hidden_size,num_experts,moe_router_topk,active_params,total_params,tokens,predicted_loss
0,400M-1x,2e+19,12,50304,1024,5472,704,64,8,354.9M,1.7B,9392317000.0,3.083589
1,400M-4x,8e+19,18,50304,1536,8208,1056,64,8,1.0B,5.8B,13333330000.0,2.879024
2,1B-1x,2.4e+20,21,50304,2048,10944,1408,64,8,2.0B,12.0B,20000000000.0,2.752687
3,1B-3x,7.2e+20,27,50304,2048,10944,1408,64,8,2.5B,15.5B,48000000000.0,2.665371
4,3B-1x,9.4e+20,27,50304,2048,10944,1408,64,8,2.5B,15.5B,62666670000.0,2.652439
5,1B-5x,1.2e+21,27,50304,2048,10944,1408,64,8,2.5B,15.5B,80000000000.0,2.642575
6,7B-1x,5.7e+21,27,50304,2048,10944,1408,64,8,2.5B,15.5B,380000000000.0,2.607853
7,7B-2x,1.1e+22,27,50304,2048,10944,1408,64,8,2.5B,15.5B,733333300000.0,2.601495
