# Fetch the Dataset

In [None]:
import wandb
import pandas as pd

def parse_count(s: str) -> int:
    """Convert a human-readable count such as ``"1.5B"`` into an integer.

    Recognised suffixes (case-insensitive) are ``K`` (thousand), ``M``
    (million) and ``B`` (billion); a value without a suffix is parsed as a
    plain number. Any remaining fractional part is truncated toward zero.

    Args:
        s: The textual count. Non-string values (e.g. a numeric spreadsheet
            cell) are converted with ``str()`` first.

    Returns:
        The count as an ``int``.

    Raises:
        ValueError: If the numeric part cannot be parsed.
    """
    # Local import so the block stays self-contained.
    from decimal import Decimal, InvalidOperation

    multipliers = {"K": 1_000, "M": 1_000_000, "B": 1_000_000_000}

    text = str(s).strip().upper()
    multiplier = multipliers.get(text[-1:], 1)
    number = text[:-1] if multiplier != 1 else text

    try:
        # Decimal keeps the scaling exact. With binary floats,
        # 1.001 * 1000 == 1000.9999999999999, which int() truncates to 1000.
        return int(Decimal(number) * multiplier)
    except InvalidOperation:
        # Preserve the exception type that float() raised for malformed input.
        raise ValueError(f"could not parse count: {s!r}") from None

# Collect one (flops, active_params, total_params, loss) row per finished
# ablation run, joining each run with the frontier spreadsheet on its
# (num_layers, hidden_size) architecture.
api, data = wandb.Api(), []
frontier = pd.read_csv("https://docs.google.com/spreadsheets/d/1sIr9HRwYbUXKzlskUTMorMa2A_cAzDwE0eUJnk-W1dQ/export?format=csv&gid=1059339506")

for run in api.runs("haok/flame-moe", {"group": {"$regex": "ablation"}}):
    if run.state != "finished":
        continue

    # The last "-"-separated field of the group name is taken as the compute
    # budget; it stays a string here and is cast to float when fitting.
    flops = run.group.split("-").pop()

    try:
        loss = run.summary["lm loss validation"]
    except KeyError:
        # A finished run that never logged a validation loss cannot be used
        # for the fit; skip it instead of aborting the whole fetch.
        print(f"Skipping run {run.name}: no 'lm loss validation' in summary.")
        continue

    num_layers, hidden_size = run.config["num_layers"], run.config["hidden_size"]
    selected = frontier[(frontier["num_layers"] == num_layers) & (frontier["hidden_size"] == hidden_size)]
    if selected.empty:
        # Without this guard, iloc[0] raises an opaque IndexError.
        print(
            f"Skipping run {run.name}: no frontier row for "
            f"num_layers={num_layers}, hidden_size={hidden_size}."
        )
        continue

    # If several rows match, the first one is used.
    row = selected.iloc[0]
    active_params = parse_count(row["active_params"])
    total_params = parse_count(row["total_params"])
    data.append((flops, active_params, total_params, loss))

df = pd.DataFrame(data, columns=["flops", "active_params", "total_params", "loss"])
print(df.head(3))


# Optimize the Coefficients

In [None]:
from math import exp
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
from itertools import product
from scipy.optimize import minimize
import numpy as np

# Step 1: Define the scaling law model
def scaling_law(params, N, D):
    """Evaluate ``E + A / N**alpha + B / D**beta``.

    Args:
        params: Sequence of five coefficients ``(E, A, alpha, B, beta)``.
        N: Scalar or array used as the base of the first power-law term.
        D: Scalar or array used as the base of the second power-law term.

    Returns:
        The predicted value(s), following the usual broadcasting rules of
        the operand types.
    """
    offset, n_coeff, n_exp, d_coeff, d_exp = params
    n_term = n_coeff / (N ** n_exp)
    d_term = d_coeff / (D ** d_exp)
    return offset + n_term + d_term

# Step 2: Define the loss function (MSE)
def mse_loss(params, N, D, targets):
    """Return the mean squared error of ``scaling_law`` against ``targets``.

    Args:
        params: Coefficients forwarded unchanged to ``scaling_law``.
        N: First input forwarded to ``scaling_law``.
        D: Second input forwarded to ``scaling_law``.
        targets: Observed values to compare the predictions with.
    """
    residuals = scaling_law(params, N, D) - targets
    squared_errors = residuals ** 2
    return np.mean(squared_errors)

# Step 3: Build the arrays consumed by the fit
# N is the active-parameter count, D is flops / (6 * N), targets is the loss.
targets = df["loss"].values
N = df["active_params"].values
D = (df["flops"].astype(float) / (df["active_params"] * 6)).values

# Step 4: Build the grid of starting points
# E, A and B start from the same set of magnitudes and alpha and beta from the
# same set of exponents; every combination becomes one starting point.
_magnitudes = [exp(power) for power in (0, 2, 4, 8, 16)]
_exponents = [0, 0.25, 0.5, 0.75, 1, 1.25]

E_range = list(_magnitudes)
A_range = list(_magnitudes)
alpha_range = list(_exponents)
B_range = list(_magnitudes)
beta_range = list(_exponents)

initial_guesses = [
    combo
    for combo in product(E_range, A_range, alpha_range, B_range, beta_range)
]

# Step 5: Multi-start fit with L-BFGS-B
# Run the optimizer from every starting point and keep the converged result
# with the smallest MSE.
best_result, lowest_mse = None, np.inf

for start in tqdm(initial_guesses):
    candidate = minimize(mse_loss, x0=start, args=(N, D, targets), method="L-BFGS-B")
    if not candidate.success:
        continue
    candidate_mse = mse_loss(candidate.x, N, D, targets)
    if candidate_mse < lowest_mse:
        best_result, lowest_mse = candidate, candidate_mse

# Step 6: Report the best fit, or say that no start converged
if not best_result:
    print("Optimization failed for all initial guesses.")
else:
    E_opt, A_opt, alpha_opt, B_opt, beta_opt = best_result.x
    report_lines = [
        f"E     = {E_opt:.6f}",
        f"A     = {A_opt:.6f}",
        f"alpha = {alpha_opt:.6f}",
        f"B     = {B_opt:.6f}",
        f"beta  = {beta_opt:.6f}",
        f"MSE   = {lowest_mse:.6f}",
    ]
    for line in report_lines:
        print(line)


# Find the Models

In [None]:
import pandas as pd

def scaling_law(N, D):
    """Evaluate ``E + A / N**alpha + B / D**beta`` with fixed coefficients.

    The five coefficients are hard-coded constants in the function body.

    Args:
        N: Scalar or array used as the base of the first power-law term.
        D: Scalar or array used as the base of the second power-law term.
    """
    offset = -12.498690
    n_coeff, n_exp = 0.000002, -0.653283
    d_coeff, d_exp = 54.598254, 0.020899
    n_term = n_coeff / (N ** n_exp)
    d_term = d_coeff / (D ** d_exp)
    return offset + n_term + d_term

# NOTE(review): presumably the target compute budget in FLOPs for the model
# search in this cell — confirm against the code that consumes it.
flops = 1e19
