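# Fetch the Dataset

Collect the final validation loss of every finished ablation run from W&B and look up each run's active/total parameter counts in the model-size sheet.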

In [13]:
import wandb
import pandas as pd

def parse_count(s: "str | float") -> int:
    """Convert a human-readable count such as ``"146.4M"`` or ``"2.2B"`` to an int.

    Recognised suffixes (case-insensitive): ``K`` (1e3), ``M`` (1e6),
    ``B`` (1e9), ``T`` (1e12). Text without a suffix is parsed as a plain
    number. Non-string inputs (e.g. a cell pandas already read as a number)
    are converted with ``str()`` first.

    The mantissa is parsed with ``decimal.Decimal`` so the scaling is exact:
    ``int(float("1.001") * 1_000_000)`` evaluates to 1000999 because the float
    product lands just below the integer and ``int`` truncates.

    Raises:
        ValueError: if the text is not a number optionally followed by a suffix.
    """
    from decimal import Decimal, InvalidOperation

    multipliers = {"K": 10**3, "M": 10**6, "B": 10**9, "T": 10**12}
    text = str(s).strip().upper()
    scale = 1
    if text and text[-1] in multipliers:
        scale = multipliers[text[-1]]
        text = text[:-1]
    try:
        # Exact decimal arithmetic; int() then truncates toward zero like before.
        return int(Decimal(text) * scale)
    except InvalidOperation:
        # Keep the original contract: malformed input raises ValueError.
        raise ValueError(f"could not parse count: {s!r}") from None

api, data = wandb.Api(), []
# Model-size sheet: rows keyed by (num_layers, hidden_size) with human-readable
# "active_params" / "total_params" strings that parse_count converts to ints.
frontier = pd.read_csv("https://docs.google.com/spreadsheets/d/1sIr9HRwYbUXKzlskUTMorMa2A_cAzDwE0eUJnk-W1dQ/export?format=csv&gid=1059339506")

for run in api.runs("haok/flame-moe", {"group": {"$regex": "ablation"}}):
    # Use only runs that reached the "finished" state.
    if run.state != "finished":
        continue
    # The compute budget is the last "-"-separated token of the group name;
    # store it as a float so downstream cells get a numeric column.
    flops = float(run.group.rsplit("-", 1)[-1])
    loss = run.summary["lm loss validation"]
    num_layers, hidden_size = run.config["num_layers"], run.config["hidden_size"]
    selected = frontier[(frontier["num_layers"] == num_layers) & (frontier["hidden_size"] == hidden_size)]
    # Fail with a readable message instead of an opaque IndexError from .iloc[0].
    if selected.empty:
        raise ValueError(
            f"no row in the model-size sheet for num_layers={num_layers}, "
            f"hidden_size={hidden_size} (run group {run.group!r})"
        )
    # NOTE(review): if several rows share (num_layers, hidden_size), only the first is used.
    row = selected.iloc[0]
    active_params, total_params = parse_count(row["active_params"]), parse_count(row["total_params"])
    data.append((flops, active_params, total_params, loss))

df = pd.DataFrame(data, columns=["flops", "active_params", "total_params", "loss"])
df


Unnamed: 0,flops,active_params,total_params,loss
0,6e+18,82200000,239000000,3.365395
1,6e+18,219000000,995100000,3.332465
2,3e+19,33400000,72600000,3.39731
3,3e+19,82200000,239000000,3.135083
4,3e+19,37500000,100200000,3.307347
5,3e+19,182700000,747100000,3.034445
6,3e+19,354900000,1700000000,3.031785
7,3e+19,98400000,349200000,3.079536
8,3e+19,721200000,3800000000,3.035623
9,3e+19,219000000,995100000,3.007967


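# Fit the Scaling-Law Coefficients

Fit $L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$ to the runs above, where $N$ is the number of active parameters and the token count is approximated as $D \approx C / (6N)$ for a compute budget of $C$ FLOPs.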

In [3]:
from math import exp
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
from itertools import product
from scipy.optimize import minimize
import numpy as np

# Step 1: Define the scaling law model
def scaling_law(params, N, D):
    """Parametric loss L(N, D) = E + A / N**alpha + B / D**beta.

    ``params`` is unpacked as (E, A, alpha, B, beta). ``N`` and ``D`` may be
    scalars or NumPy arrays; the arithmetic is elementwise.
    """
    E, A, alpha, B, beta = params
    model_term = A / N ** alpha
    data_term = B / D ** beta
    return E + model_term + data_term

# Step 2: Define the loss function (MSE)
def mse_loss(params, N, D, targets):
    """Objective minimised by the fit: mean squared residual of ``scaling_law``.

    Args:
        params: (E, A, alpha, B, beta), forwarded to ``scaling_law``.
        N, D, targets: per-run active parameters, tokens, and observed losses.

    Returns:
        The mean of (prediction - target) ** 2 over all runs.
    """
    residuals = scaling_law(params, N, D) - targets
    return np.mean(residuals ** 2)

# Step 3: Prepare your data
# N: active parameters per run. D: training tokens, recovered from the compute
# budget with the C ≈ 6·N·D approximation.
N = df["active_params"].values
D = (df["flops"].astype(float) / (6 * df["active_params"])).values
targets = df["loss"].values

# Step 4: Grid of initial guesses. L-BFGS-B is a local optimizer, so every
# starting point is fitted separately and the best converged fit is kept.
E_range = [exp(0), exp(2), exp(4), exp(8), exp(16)]
A_range = [exp(0), exp(2), exp(4), exp(8), exp(16)]
alpha_range = [0, 0.25, 0.5, 0.75, 1, 1.25]
B_range = [exp(0), exp(2), exp(4), exp(8), exp(16)]
beta_range = [0, 0.25, 0.5, 0.75, 1, 1.25]
initial_guesses = list(product(E_range, A_range, alpha_range, B_range, beta_range))

# Step 5: Fit using L-BFGS-B from every starting point.
# NOTE(review): the fit is unbounded (E, A, B, alpha, beta may go negative) and
# has 5 free parameters for len(targets) runs; the saved table covers only two
# FLOP budgets, so the coefficients may be poorly identified — check stability.
best_result = None
lowest_mse = np.inf

# Scope warning suppression to the fit itself; the unbounded optimizer can step
# to parameter values where the power terms overflow.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for guess in tqdm(initial_guesses):
        result = minimize(
            mse_loss,
            x0=guess,
            args=(N, D, targets),
            method="L-BFGS-B",
        )
        if result.success:
            current_mse = mse_loss(result.x, N, D, targets)
            if current_mse < lowest_mse:
                lowest_mse = current_mse
                best_result = result

# Step 6: Show the results
# Test against None explicitly: OptimizeResult is a dict subclass, so plain
# truthiness would depend on its contents rather than on whether a fit exists.
if best_result is not None:
    E_opt, A_opt, alpha_opt, B_opt, beta_opt = best_result.x
    print(f"E     = {E_opt:.6f}")
    print(f"A     = {A_opt:.6f}")
    print(f"alpha = {alpha_opt:.6f}")
    print(f"B     = {B_opt:.6f}")
    print(f"beta  = {beta_opt:.6f}")
    print(f"MSE   = {lowest_mse:.6f}")
else:
    print("Optimization failed for all initial guesses.")


100%|██████████| 4500/4500 [00:27<00:00, 162.88it/s]

E     = 1.958528
A     = 123.248650
alpha = 0.257583
B     = 8886110.520508
beta  = 0.750334
MSE   = 0.006187





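# Find the Compute-Optimal Models

For each FLOP budget, pick the candidate configuration with the lowest loss predicted by the fitted scaling law.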

In [10]:
import pandas as pd

def parse_count(s: "str | float") -> int:
    """Convert a human-readable count such as ``"146.4M"`` or ``"2.2B"`` to an int.

    Recognised suffixes (case-insensitive): ``K`` (1e3), ``M`` (1e6),
    ``B`` (1e9), ``T`` (1e12). Text without a suffix is parsed as a plain
    number. Non-string inputs (e.g. a cell pandas already read as a number)
    are converted with ``str()`` first.

    The mantissa is parsed with ``decimal.Decimal`` so the scaling is exact:
    ``int(float("1.001") * 1_000_000)`` evaluates to 1000999 because the float
    product lands just below the integer and ``int`` truncates.

    Raises:
        ValueError: if the text is not a number optionally followed by a suffix.
    """
    from decimal import Decimal, InvalidOperation

    multipliers = {"K": 10**3, "M": 10**6, "B": 10**9, "T": 10**12}
    text = str(s).strip().upper()
    scale = 1
    if text and text[-1] in multipliers:
        scale = multipliers[text[-1]]
        text = text[:-1]
    try:
        # Exact decimal arithmetic; int() then truncates toward zero like before.
        return int(Decimal(text) * scale)
    except InvalidOperation:
        # Keep the original contract: malformed input raises ValueError.
        raise ValueError(f"could not parse count: {s!r}") from None

def scaling_law(N, D, params=(1.958528, 123.248650, 0.257583, 8886110.520508, 0.750334)):
    """Predicted loss L(N, D) = E + A / N**alpha + B / D**beta.

    Args:
        N: active parameter count(s); scalar or array/Series (elementwise math).
        D: training token count(s), combined elementwise with ``N``.
        params: (E, A, alpha, B, beta). The default is the coefficients printed
            by the fitting cell; pass e.g. ``best_result.x`` to use a fresh fit
            instead of the transcribed constants.

    NOTE(review): this shadows the fitting cell's ``scaling_law(params, N, D)``,
    which has a different signature; re-running only parts of that cell after
    this one would call the wrong function.
    """
    E, A, alpha, B, beta = params
    return E + A / (N ** alpha) + B / (D ** beta)

# Step 1: Load the candidate configurations (one row per architecture).
# Kept under its own name so the loop below never mutates the loaded sheet.
configs = pd.read_csv("https://docs.google.com/spreadsheets/d/1sIr9HRwYbUXKzlskUTMorMa2A_cAzDwE0eUJnk-W1dQ/export?format=csv&gid=599230821")
# Parameter counts do not depend on the budget: parse them once, outside the loop.
candidate_active_params = configs["active_params"].apply(parse_count)

# Step 2: Define budgets (training FLOPs)
budgets = [6e18, 1e19, 3e19, 6e19, 1e20, 3e20, 6e20]

# Step 3: For each budget, give every candidate D = C / (6 N) tokens, predict
# its loss with the fitted law, and keep the lowest-loss candidate.
merged = []
for budget in budgets:
    tokens = budget / (6 * candidate_active_params)
    scored = configs.assign(
        flops=budget,
        tokens=tokens,
        predicted_loss=scaling_law(candidate_active_params, tokens),
    )
    tops = scored.nsmallest(1, "predicted_loss")
    merged.append(tops)

# NOTE(review): rebinding `df` replaces the run table built in the first cell,
# so re-running the fitting cell after this one will fail; consider a new name.
# NOTE(review): in the saved output the 3e20 and 6e20 budgets select the same
# (largest) configuration — the candidate grid may be too small at the top end.
df = pd.concat(merged, ignore_index=True)
df

Unnamed: 0,num_layers,padded_vocab_size,hidden_size,ffn_hidden_size,moe_ffn_hidden_size,num_experts,moe_router_topk,active_params,total_params,flops,tokens,predicted_loss
0,6,50304,768,4104,528,64,8,146.4M,499.2M,6e+18,6830601000.0,3.301284
1,9,50304,768,4104,528,64,8,182.7M,747.1M,1e+19,9122423000.0,3.17497
2,15,50304,1024,5472,704,64,8,419.4M,2.2B,3e+19,11921790000.0,2.943788
3,12,50304,1536,8208,1056,64,8,721.2M,3.8B,6e+19,13865780000.0,2.821048
4,18,50304,1536,8208,1056,64,8,1.0B,5.8B,1e+20,16666670000.0,2.740888
5,27,50304,2048,10944,1408,64,8,2.5B,15.5B,3e+20,20000000000.0,2.592065
6,27,50304,2048,10944,1408,64,8,2.5B,15.5B,6e+20,40000000000.0,2.524841
