In [2]:
import numpy as np
from itertools import product
from datetime import datetime
import os

# When True, every generated run receives the eu4w overrides below and the
# launch command targets a v3-8 TPU; when False, no shared overrides are
# applied and the launch command targets a v4-8 TPU.
is_eu4w = False

# Command-line overrides shared by every run in the eu4w setup. Keys are flag
# names and values are written into the command verbatim, so the list-valued
# entries carry their own quoting.
eu4w_base_args = {
    "data.train_urls": "\"['gs://euw4-stack-v2/20250122-flattened-shuffled/shard-{00..58}.jsonl.gz']\"",
    "data.validation_urls": "\"['gs://euw4-stack-v2/20250122-flattened-shuffled/shard-{59..60}.jsonl.gz']\"",
    "data.cache_dir": "gs://euw4-datacache/20250122-flattened-shuffled",
    "trainer.checkpointer.base_path": "gs://euw4-ckpt",
    "trainer.per_device_parallelism": 2,
    "trainer.per_device_eval_parallelism": 2,
    "model.upcast_attn": "true",
}

# Resolve the environment-dependent settings in one place.
if is_eu4w:
    all_args = eu4w_base_args
    tpu_type = "v3-8"
else:
    all_args = {}
    tpu_type = "v4-8"

# Training entry point invoked by every generated command.
script = "src/levanter/main/routed_lm.py"

def datetime_str():
    """Return the current local time as a compact ``YYMMDDHHMM`` string.

    The result is ten digits (two each for year, month, day, hour, minute)
    and is used to timestamp generated sweep file names.
    """
    return f"{datetime.now():%y%m%d%H%M}"

def make_runs(out_file, sweep_name, sweep_params, extra_attrs, group, i=0, sort_key=None, commented=False):
    """Write one launch command per grid point to ``out_file`` and return the next run index.

    Args:
        out_file: Writable text file-like object. Receives a ``##### group #####``
            header line followed by one command line per run.
        sweep_name: Name of the sweep; used in the wandb tags and run name.
        sweep_params: Dict mapping a flag name to either
            * an iterable of values, which is gridded against the other
              iterables via a Cartesian product, or
            * a callable taking the run's argument dict and returning the value
              for that flag (a per-run derived parameter).
        extra_attrs: Dict of fixed flags applied to every run. It is copied per
            run and never mutated.
        group: Group label; used in the header, the wandb tags and the run name.
        i: Index assigned to the first run. Incremented once per emitted run.
        sort_key: Name of the gridded parameter to order runs by (ascending,
            stable). Defaults to the first gridded parameter.
        commented: If True, each command line is prefixed with ``"# "``.

    Returns:
        The index following the last emitted run, suitable as ``i`` for the
        next call.

    Raises:
        ValueError: If ``sort_key`` is given but is not a gridded parameter.
    """
    # Split the spec: plain entries form the grid, callables are derived per run.
    derived = {name: fn for name, fn in sweep_params.items() if callable(fn)}
    grid = {name: values for name, values in sweep_params.items() if not callable(values)}
    grid_keys = list(grid)

    out_file.write(f"##### {group} #####\n")

    # product() over zero iterables yields a single empty tuple, so an empty
    # grid still produces exactly one run (extra_attrs plus derived values).
    combos = list(product(*grid.values()))
    if grid_keys or sort_key is not None:
        sort_idx = 0 if sort_key is None else grid_keys.index(sort_key)
        combos.sort(key=lambda combo: combo[sort_idx])

    for combo in combos:
        run_args = extra_attrs.copy()
        run_args.update(zip(grid_keys, combo))
        run_args["trainer.wandb.tags"] = f"'[{sweep_name}, {group}]'"
        run_args["trainer.wandb.name"] = f"{sweep_name}_{group}_{i}"
        # All callables are evaluated against the same snapshot before being
        # merged, so they see grid values and wandb fields but never each
        # other's results.
        run_args.update({name: fn(run_args) for name, fn in derived.items()})

        parts = [f"python {script} --config_path config/rlora.yaml"]
        parts.extend(f"--{name} {value}" for name, value in run_args.items())
        cmd = " ".join(parts)
        if commented:
            cmd = "# " + cmd
        out_file.write(cmd + "\n")
        i += 1
    return i

In [5]:
sweep_name = "pikachu"

# Base grid: 6 log-spaced learning rates (1e-2 down to 1e-5) crossed with
# 4 log-spaced step budgets between 50 and 500, rounded to the nearest ten.
sweep_params = {
    "optimizer.learning_rate": np.logspace(-2, -5, 6).round(6),
    "trainer.num_train_steps": np.logspace(np.log10(50), np.log10(500), 4).round(-1).astype(int)
}

# Routed model: 32 experts of rank 128, 4 active per token.
attrs = all_args | {
    "model.num_experts": 32,
    "model.expert_rank": 128,
    "model.top_k": 4,
    "model.expert_type": "mlp_glu",
    "model.expert_init": "mlp_zero_down",
    "model.expert_init_scale": 0.01,
}

# Baseline: one always-selected expert whose rank is top_k * expert_rank of
# the routed model (4 * 128 = 512).
baseline_attrs = attrs | {
    "model.num_experts": 1,
    "model.expert_rank": attrs["model.expert_rank"] * attrs["model.top_k"],
    "model.ident_expert_mask": "true",
    "model.top_k": 1,
}

# Every expert active: top_k equals the number of experts.
max_k_attrs = attrs | {"model.top_k": attrs["model.num_experts"]}

max_k_prefill_exp_attrs = max_k_attrs | {'model.prefill_expert': 'true'}

prefill_exp_attrs = attrs | {'model.prefill_expert': 'true'}

sigmoid_prefill_exp_attrs = prefill_exp_attrs | {'model.router_activation': 'sigmoid'}

# Narrow learning-rate sweep at a fixed 500-step budget.
tight_lr_params = {
    "optimizer.learning_rate": np.logspace(np.log10(1.58e-4), np.log10(1.5e-3), 10).round(6),
    "trainer.num_train_steps": [500]
}

# Router z-loss sweep: 6 learning rates x 6 loss weights at 110 steps.
# NOTE(review): "router_z_loss_weight" is the only key without a section prefix
# ("model.", "optimizer.", "trainer.") - confirm this is the intended flag name.
zloss_params = {
    "optimizer.learning_rate": np.logspace(np.log10(5e-3), np.log10(6e-5), 6).round(6),
    "trainer.num_train_steps": [110],
    "router_z_loss_weight": np.logspace(-1, -5, 6).round(6),
}
zloss_attrs = prefill_exp_attrs | {
    "trainer.abort_if_loss_above": 8.0,
}

# One row per group, in emission order:
# (sweep params, fixed attrs, group name, sort key, written commented-out?)
# Commented-out groups are still written and still consume run indices, so
# the indices of later groups stay stable.
# NOTE(review): "baseline_tightlr" is listed twice, so that group is written
# twice and consumes 20 indices instead of 10. Looks like a copy-paste slip,
# but removing it would renumber the "zloss" runs - confirm before changing.
run_groups = [
    (sweep_params, baseline_attrs, "baseline", "trainer.num_train_steps", True),
    (sweep_params, attrs, "seqmoe", "trainer.num_train_steps", True),
    (sweep_params, max_k_attrs, "maxk", "trainer.num_train_steps", True),
    (sweep_params, max_k_prefill_exp_attrs, "maxk_prefill_expert_v2", "trainer.num_train_steps", True),
    (sweep_params, prefill_exp_attrs, "prefill_expert_v2", "trainer.num_train_steps", True),
    (sweep_params, sigmoid_prefill_exp_attrs, "sigmoid_prefill_expert_v2", "trainer.num_train_steps", True),
    (tight_lr_params, prefill_exp_attrs, "prefill_expert_v2_tightlr", "optimizer.learning_rate", True),
    (tight_lr_params, baseline_attrs, "baseline_tightlr", "optimizer.learning_rate", True),
    (tight_lr_params, baseline_attrs, "baseline_tightlr", "optimizer.learning_rate", True),
    (zloss_params, zloss_attrs, "zloss", "optimizer.learning_rate", False),
]

os.makedirs("sweeps", exist_ok=True)
fname = os.path.abspath(f"sweeps/{sweep_name}_{datetime_str()}.cmd")
with open(fname, "w") as f:
    i = 0
    for group_params, group_attrs, group_name, group_sort_key, group_commented in run_groups:
        i = make_runs(f, sweep_name, group_params, group_attrs, group_name, i,
                      sort_key=group_sort_key, commented=group_commented)
print("Num pikachu runs", i)
print(f"python ./infra/launch_on_ray.py --tpu_type {tpu_type} {fname}")

Num pikachu runs 210
python ./infra/launch_on_ray.py --tpu_type v4-8 /Users/will/proj/levanter/infra/sweeps/pikachu_2502071559.cmd


In [7]:

sweep_name = "bulba"

# 10 log-spaced learning rates between 1.58e-4 and 1.5e-3.
sweep_params = {
    "optimizer.learning_rate": np.logspace(np.log10(1.58e-4), np.log10(1.5e-3), 10).round(6),
}

# Routed model with the prefill expert enabled, trained for a fixed 110 steps.
attrs = all_args | {
    "model.num_experts": 32,
    "model.expert_rank": 128,
    "model.top_k": 4,
    "model.expert_type": "mlp_glu",
    "model.expert_init": "mlp_zero_down",
    "model.expert_init_scale": 0.01,
    "model.prefill_expert": 'true',
    "trainer.num_train_steps": 110,
}

# Baseline: one always-selected expert of rank top_k * expert_rank (4 * 128),
# with the prefill expert switched off.
baseline_attrs = attrs | {
    "model.num_experts": 1,
    "model.expert_rank": attrs["model.expert_rank"] * attrs["model.top_k"],
    "model.ident_expert_mask": "true",
    "model.top_k": 1,
    "model.prefill_expert": "false",
}

sigmoid_attrs = attrs | {'model.router_activation': 'sigmoid'}

per_layer_attrs = attrs | {"model.route_each_layer": "true"}

# Expert-count sweep over the middle four learning rates (indices 4..7 of the
# ten above). The callables lower per-device parallelism from 4 to 1 once the
# expert count reaches 128.
more_experts_params = {
    "model.num_experts": [8, 16, 64, 128, 256],
    "optimizer.learning_rate": sweep_params["optimizer.learning_rate"][4:-2],
    "trainer.per_device_parallelism": lambda a: 4 if a["model.num_experts"] < 128 else 1,
    "trainer.per_device_eval_parallelism": lambda a: 4 if a["model.num_experts"] < 128 else 1,
}
print(more_experts_params)

# Granularity sweep: hold total (M) and active (A) rank budgets fixed while
# varying the number of experts E.
#   E * r = M
#   k * r = A
# Fix M, A, vary E, solve for r, k:
#   r = M / E
#   k = A / r = A * E / M
total_memory = attrs["model.num_experts"] * attrs["model.expert_rank"]
active_memory = attrs["model.top_k"] * attrs["model.expert_rank"]
# Same expert counts and learning rates as the expert sweep, but without the
# per-device parallelism overrides; rank and top_k are derived per run.
granularity_params = {
    "model.num_experts": [8, 16, 64, 128, 256],
    "optimizer.learning_rate": more_experts_params["optimizer.learning_rate"],
    "model.expert_rank": lambda a:  total_memory // a["model.num_experts"],
    "model.top_k": lambda a:  active_memory * a["model.num_experts"] // total_memory,
}

# One row per group, in emission order:
# (sweep params, fixed attrs, group name, written commented-out?)
# Commented-out groups are still written and still consume run indices.
run_groups = [
    (sweep_params, baseline_attrs, "baseline", True),
    (sweep_params, attrs, "seqmoe", True),
    (sweep_params, sigmoid_attrs, "seqmoe_sigmoid", True),
    (sweep_params, per_layer_attrs, "seqmoe_perlayer", True),
    (more_experts_params, attrs, "seqmoe_expert_sweep", False),
    (granularity_params, attrs, "seqmoe_granularity", False),
]

os.makedirs("sweeps", exist_ok=True)
fname = os.path.abspath(f"sweeps/{sweep_name}_{datetime_str()}.cmd")
with open(fname, "w") as f:
    i = 0
    for group_params, group_attrs, group_name, group_commented in run_groups:
        i = make_runs(f, sweep_name, group_params, group_attrs, group_name, i, commented=group_commented)
print(f"Num {sweep_name} runs", i)
print(f"python ./infra/launch_on_ray.py --tpu_type {tpu_type} {fname}")

{'model.num_experts': [8, 16, 64, 128, 256], 'optimizer.learning_rate': array([0.00043 , 0.000552, 0.000708, 0.00091 ]), 'trainer.per_device_parallelism': <function <lambda> at 0x1052ffba0>, 'trainer.per_device_eval_parallelism': <function <lambda> at 0x1052ff9c0>}
Num bulba runs 80
python ./infra/launch_on_ray.py --tpu_type v4-8 /Users/will/proj/levanter/infra/sweeps/bulba_2502071501.cmd


In [8]:
sweep_name = "charmander"

# 6 log-spaced learning rates from 1.5e-3 down to 1e-6, gridded independently
# for the main optimizer and the base-weights optimizer (6 x 6 = 36 points).
lrs = np.logspace(np.log10(1.5e-3), -6, 6).round(6)
sweep_params = {
    "optimizer.learning_rate": lrs,
    "full_ft_base_weights_optimizer.learning_rate": lrs,
}

# Routed model with full fine-tuning enabled and a separate adam optimizer
# configured for the base weights; 50 training steps.
attrs = all_args | {
    "model.num_experts": 32,
    "model.expert_rank": 128,
    "model.top_k": 4,
    "model.expert_type": "mlp_glu",
    "model.expert_init": "mlp_zero_down",
    "model.expert_init_scale": 0.01,
    "optimizer.warmup": 0.08,
    "full_ft_base_weights_optimizer": "\"{'type': 'adam'}\"",
    "full_ft_base_weights_optimizer.weight_decay": 0.0,
    "full_ft_base_weights_optimizer.warmup": 0.08,
    "trainer.num_train_steps": 50,
    "full_ft": "true",
}

# Baseline: one always-selected expert of rank top_k * expert_rank (4 * 128),
# with the prefill expert explicitly disabled.
baseline_attrs = attrs | {
    "model.num_experts": 1,
    "model.expert_rank": attrs["model.expert_rank"] * attrs["model.top_k"],
    "model.ident_expert_mask": "true",
    "model.top_k": 1,
    "model.prefill_expert": "false",
}

# (fixed attrs, group name), in emission order; both groups are written active.
run_groups = [
    (baseline_attrs, "baseline"),
    (attrs, "seqmoe"),
]

os.makedirs("sweeps", exist_ok=True)
fname = os.path.abspath(f"sweeps/{sweep_name}_{datetime_str()}.cmd")
with open(fname, "w") as f:
    i = 0
    for group_attrs, group_name in run_groups:
        i = make_runs(f, sweep_name, sweep_params, group_attrs, group_name, i)
print(f"Num {sweep_name} runs", i)
print(f"python ./infra/launch_on_ray.py --tpu_type {tpu_type} {fname}")

Num charmander runs 72
python ./infra/launch_on_ray.py --tpu_type v4-8 /Users/will/proj/levanter/infra/sweeps/charmander_2502071501.cmd
