In [1]:
import wandb
import pandas as pd

from notebooks.utils import WANDB_PROJECT

In [2]:
# Key columns used to join W&B runs with live Ray jobs and to de-duplicate
# the combined run table.
merge_keys = [
    "batch_size",
    "lr",
    "mixing_shift",
    "layers",
    "hidden_size",
    "beta2",
    "tpu_zone",
    "tpu_version",
]

In [3]:
def format_command(row):
    """Build the ``main_ray.py`` command line for one row of the run table.

    If the row carries both a recorded ``command`` and a ``wandb_id``, the
    recorded command is reused and ``--resume_wandb_id <id>`` is appended
    unless the flag is already present. Otherwise a fresh command is assembled
    from the row's hyperparameters.

    Example of a generated command::

        main_ray.py --batch_size 256 --lr 0.3 --max_training_steps 100000
        --num_layers 16 --hidden_size 1024 --num_attn_heads 16 --max_seq_len 2048
        --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan
        --wandb_name gidd-L16-D1024-H16-N2048-bs=256-lr=0.3-noise_mask
        --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l16_d1024,noise_mask

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with the keys
            ``command``, ``wandb_id``, ``batch_size``, ``lr``, ``mixing_shift``,
            ``layers``, ``hidden_size`` and ``beta2``.

    Returns:
        The command as a single string.

    Raises:
        ValueError: If a fresh command is needed but an integer-valued field
            (``batch_size``, ``layers``, ``hidden_size``) is NaN.
    """
    recorded = row["command"]
    if not pd.isna(recorded) and not pd.isna(row["wandb_id"]):
        if "--resume_wandb_id" in recorded:
            return recorded
        return f"{recorded} --resume_wandb_id {row['wandb_id']}"

    # Merged DataFrame rows hold these as floats (e.g. 1536.0); normalise once
    # so every flag and the run name render the same integer values.
    batch_size = int(row["batch_size"])
    lr = float(row["lr"])
    layers = int(row["layers"])
    hidden_size = int(row["hidden_size"])
    beta2 = float(row["beta2"])
    mixing_shift = float(row["mixing_shift"])

    # 64-wide heads up to hidden_size 1024, 128-wide heads above that. The same
    # head count is used for --num_attn_heads and for the "H<n>" part of the
    # run name so the two can never disagree.
    head_dim = 64 if hidden_size <= 1024 else 128
    num_heads = hidden_size // head_dim

    noise_label = {
        -1000.0: "noise_mask",
        -2.0: "noise_low_uniform",
        0.0: "noise_balanced",
        2.0: "noise_high_uniform",
        1000.0: "noise_uniform",
    }.get(mixing_shift, "noise_unknown")

    wandb_name = (
        f"gidd-L{layers}-D{hidden_size}-H{num_heads}-N2048"
        f"-bs={batch_size}-lr={lr}-{noise_label}"
    )
    args = [
        ("--batch_size", batch_size),
        ("--lr", lr),
        ("--max_training_steps", 100000),
        ("--num_layers", layers),
        ("--hidden_size", hidden_size),
        ("--num_attn_heads", num_heads),
        ("--max_seq_len", 2048),
        ("--cooldown_steps", 0.0),
        ("--beta2", beta2),
        ("--no-break_on_nan",),
        ("--wandb_name", wandb_name),
        ("--hybrid_mixing_shift", mixing_shift),
        ("--wandb_tags", f"crit_bs_l{layers}_d{hidden_size},{noise_label}"),
    ]
    return "main_ray.py " + " ".join(" ".join(map(str, arg)) for arg in args)

In [4]:
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

# Only runs created at or after this local cutoff are fetched; the cutoff is
# converted to UTC before it is placed in the W&B filter.
dt_local = datetime(year=2025, month=9, day=13, hour=21, minute=0, tzinfo=ZoneInfo("Europe/Zurich"))
dt_utc = dt_local.astimezone(timezone.utc)

# A run qualifies if it carries at least one of these tags.
sweep_tags = ["crit_bs_l8_d512", "crit_bs_l12_d768", "crit_bs_l16_d1024", "scaling_laws"]
run_filters = {
    "$and": [
        {"$or": [{"tags": tag} for tag in sweep_tags]},
        {"created_at": {"$gte": dt_utc.isoformat()}},
    ],
}

api = wandb.Api()
runs = api.runs(path=WANDB_PROJECT, filters=run_filters, order="+created_at", per_page=1000)
len(runs)

510

In [5]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import re


# Regex source for each field extracted from a job's command line. Every
# pattern has exactly one capture group holding the value.
_pattern_sources = {
    "batch_size": r"--batch_size (\d+)",
    "lr": r"--lr ([\d.]+)",
    "layers": r"--num_layers (\d+)",
    "hidden_size": r"--hidden_size (\d+)",
    "beta2": r"--beta2 ([\d.]+)",
    "mixing_shift": r"--hybrid_mixing_shift ([\-\d.]+)",
    "tpu_zone": r"TPU_ZONE=([a-z0-9\-]+)",
    "tpu_version": r"TPU_VERSION=([a-z0-9\-]+)",
    "tpu_count": r"TPU_POD_COUNT=(\d+)",
}
patterns = {field: re.compile(source) for field, source in _pattern_sources.items()}


def parse_command(command):
    """Extract the known fields from a command string.

    Each pattern in ``patterns`` is searched for in ``command``. A captured
    value containing a ``.`` is converted to ``float``, any other value to
    ``int``; if that conversion fails the raw string is kept. Fields whose
    pattern does not match are omitted from the result.

    Args:
        command: Text to search (e.g. one line of job-listing output).

    Returns:
        Dict mapping field name to the parsed value.
    """
    parsed = {}
    for field, regex in patterns.items():
        found = regex.search(command)
        if found is None:
            continue
        raw = found.group(1)
        try:
            value = float(raw) if "." in raw else int(raw)
        except ValueError:
            # Non-numeric captures such as zone or TPU version names.
            value = raw
        parsed[field] = value
    return parsed

# Ray cluster address for each TPU zone. The values are shell variable
# references that zsh expands when the command runs (presumably defined in
# ~/.zshrc, which is sourced below -- confirm).
addresses = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}

# Query every cluster concurrently and collect one row per job line that
# contains "RUNNING".
all_rows = []
with ThreadPoolExecutor(max_workers=len(addresses)) as executor:
    futures = {}
    for tpu_zone, address in addresses.items():
        cmd = f"source ~/.zshrc && ray job list --address {address}"
        futures[executor.submit(subprocess.run, cmd, shell=True, capture_output=True, text=True, executable="/bin/zsh")] = tpu_zone
    for future in as_completed(futures):
        tpu_zone = futures[future]
        output = future.result()
        if output.returncode != 0:
            # A failed listing would otherwise look like "no running jobs" in
            # this zone, and live jobs could later be queued for resume.
            print(
                f"WARNING: `ray job list` failed for {tpu_zone} "
                f"(exit code {output.returncode}): {output.stderr.strip()}"
            )
        # The first stdout line is skipped (presumably a header -- confirm).
        lines = output.stdout.splitlines()[1:]
        rows = [parse_command(x) for x in lines if "RUNNING" in x]
        for row in rows:
            # The zone comes from the cluster that was queried, not the command text.
            row["tpu_zone"] = tpu_zone
            row["is_running"] = True
        all_rows.extend(rows)

running_df = pd.DataFrame(all_rows)
running_df

Unnamed: 0,batch_size,lr,layers,hidden_size,beta2,mixing_shift,tpu_version,tpu_count,tpu_zone,is_running
0,1024,2.0,8,512,0.97,0.0,v6e-256,1,us-east1-d,True
1,256,0.5,12,768,0.98,-1000.0,v6e-32,1,us-east1-d,True
2,256,1.0,16,1024,0.98,0.0,v6e-64,1,us-east1-d,True
3,256,0.3,12,768,0.98,2.0,v6e-32,1,us-east1-d,True
4,256,0.3,16,1024,0.98,0.0,v6e-64,1,europe-west4-a,True
5,128,0.4,10,640,0.99,1000.0,v6e-8,4,europe-west4-a,True
6,128,0.4,20,1536,0.99,-2.0,v6e-8,4,europe-west4-a,True
7,32,0.25,10,640,0.99,0.0,v6e-8,2,europe-west4-a,True
8,512,0.5,8,512,0.95,0.0,v5p-32,1,us-central1-a,True
9,1024,1.0,8,512,0.97,0.0,v5p-64,1,us-central1-a,True


In [6]:
def _expand_grid(hparams, mixing_shifts, sizes):
    """Expand a sweep specification into one dict per run configuration.

    Args:
        hparams: Maps batch size to a ``(learning_rates, beta2_values)`` pair.
        mixing_shifts: Mixing-shift values to cross with every entry.
        sizes: ``(layers, hidden_size)`` pairs to cross with every entry.

    Returns:
        List of dicts, iterating batch size, then lr, then beta2, then mixing
        shift, then model size (outermost to innermost).
    """
    grid = []
    for bs, (lrs, b2s) in hparams.items():
        for lr in lrs:
            for b2 in b2s:
                for ms in mixing_shifts:
                    for n_layers, width in sizes:
                        grid.append({
                            "batch_size": bs,
                            "lr": lr,
                            "mixing_shift": ms,
                            "layers": n_layers,
                            "hidden_size": width,
                            "beta2": b2,
                        })
    return grid


# First sweep: several learning rates per batch size on three model sizes.
hparams_by_bs = {
    8: ([0.05, 0.1, 0.2, 0.3], [0.99]),
    16: ([0.1, 0.2, 0.3, 0.5], [0.99]),
    32: ([0.1, 0.2, 0.3, 0.5], [0.99]),
    64: ([0.2, 0.3, 0.5, 1.0], [0.99]),
    128: ([0.3, 0.5, 1.0], [0.99]),
    256: ([0.3, 0.5, 1.0], [0.98]),
    512: ([0.3, 0.5, 1.0], [0.98]),
}
target_mixing_shifts = [-1000.0, -2.0, 0.0, 2.0, 1000.0]
target_sizes = [(8, 512), (12, 768), (16, 1024)]

rows = _expand_grid(hparams_by_bs, target_mixing_shifts, target_sizes)

# ---------------------------
# Second sweep: a single learning rate per batch size on two other model sizes.

hparams_by_bs = {
    8: ([0.2], [0.99]),
    16: ([0.2], [0.99]),
    32: ([0.25], [0.99]),
    64: ([0.35], [0.99]),
    128: ([0.4], [0.99]),
    256: ([0.5], [0.98]),
    512: ([0.6], [0.98]),
}
target_mixing_shifts = [-1000.0, -2.0, 0.0, 2.0, 1000.0]
target_sizes = [(10, 640), (20, 1536)]

rows += _expand_grid(hparams_by_bs, target_mixing_shifts, target_sizes)
len(rows)

445

In [7]:
import tqdm.auto as tqdm

def _run_to_row(run):
    """Flatten one W&B run into a record of hyperparameters and run metadata."""
    cfg = run.config
    state = run.state
    return {
        "batch_size": cfg.get("batch_size"),
        "lr": cfg.get("lr"),
        "mixing_shift": cfg.get("hybrid_mixing_shift"),
        "layers": cfg.get("num_layers"),
        "hidden_size": cfg.get("hidden_size"),
        "beta2": cfg.get("beta2"),
        "status": state,
        "wandb_id": run.id,
        "ray_id": cfg.get("ray_job_id"),
        "step": run.summary.get("_step"),
        "name": run.name,
        "tpu_version": cfg.get("tpu_version"),
        "tpu_zone": cfg.get("tpu_zone"),
        "tpu_count": cfg.get("tpu_pod_count"),
        "command": cfg.get("command"),
    }


# Target grid, flagged so its rows stay identifiable after later merges.
needed_df = pd.DataFrame(rows).assign(is_needed=True)

wandb_rows = [_run_to_row(run) for run in tqdm.tqdm(runs)]

# Any run with a missing value in one of the fields above is discarded.
wandb_df = pd.DataFrame(wandb_rows).dropna(how="any")

  0%|          | 0/510 [00:00<?, ?it/s]

In [8]:
# Outer-join W&B runs with live Ray jobs. Overlapping non-key columns from the
# Ray side get the "_running" suffix.
existing_df = wandb_df.merge(running_df, how="outer", on=merge_keys, suffixes=("", "_running"))

# Rows without a W&B status are labelled as pending.
no_status = existing_df["status"].isna()
existing_df.loc[no_status, "status"] = "pending"

# Where W&B gave no pod count, fall back to the one parsed from the Ray job.
no_count = existing_df["tpu_count"].isna()
existing_df.loc[no_count, "tpu_count"] = existing_df.loc[no_count, "tpu_count_running"].astype(int)
existing_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,gidd-L16-D1024-H16-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_low_u...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
4,8.0,0.05,-2.0,12.0,768.0,0.99,finished,7nd2xgmu,4f000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_low...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
457,512.0,2.00,0.0,8.0,512.0,0.98,finished,ngqox211,14000000,100000.0,gidd-L8-D512-H8-N2048-bs=512-lr=2.0-noise_bala...,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --lr 2.0 --max_tr...,,
458,512.0,2.00,2.0,8.0,512.0,0.98,finished,clu5ynkz,1f000000,100000.0,gidd-L8-D512-H8-N2048-bs=512-lr=2.0-noise_high...,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --lr 2.0 --max_tr...,,
459,1024.0,0.80,,40.0,5120.0,0.95,pending,,,,,v5p-1024,us-east5-a,1,,1.0,True
460,1024.0,1.00,0.0,8.0,512.0,0.97,crashed,mgywim12,29000000,21850.0,gidd-L8-D512-H8-N2048-bs=1024-lr=1.0-noise_bal...,v5p-64,us-central1-a,1,main_ray.py --batch_size 1024 --lr 1.0 --max_t...,1.0,True


In [9]:
# Combine what exists (W&B + Ray) with the target grid. No key is given, so
# pandas joins on every column name the two frames have in common.
df = existing_df.merge(needed_df, how="outer")

# Rows that did not come from the grid have NaN here. Comparing with True gives
# a genuine boolean column without relying on fillna() downcasting an object
# column, which is what triggered the warning on the previous version of this line.
df["is_needed"] = df["is_needed"].eq(True)
df["is_done"] = df["status"] == "finished"
df.loc[df["is_running"].isna(), "is_running"] = False

# Among duplicate configurations keep the best candidate: finished first, then
# running, then the one with the highest step.
df = df.sort_values(["is_done", "is_running", "step"], ascending=[False, False, False])
df = df.drop_duplicates(subset=merge_keys, keep="first")

# A needed run that crashed with no recorded step, or before step 100, is
# treated as missing rather than crashed.
crashed_early = (df["status"] == "crashed") & df["is_needed"] & (df["step"].isna() | (df["step"] < 100))
df.loc[crashed_early, "status"] = "missing"
df.loc[df["is_needed"], "status"] = df.loc[df["is_needed"], "status"].fillna("missing")

# Generate a launch command for every row that has none. The guard avoids
# calling apply(axis=1) on an empty selection, where the result is not a
# Series that can be assigned to a single column.
no_command = df["command"].isna()
if no_command.any():
    df.loc[no_command, "command"] = df.loc[no_command].apply(format_command, axis=1)
df

  df.loc[:, "is_needed"] = df["is_needed"].fillna(False)


Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,is_needed,is_done
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,True,True
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,True,True
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,gidd-L16-D1024-H16-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,True,True
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_low_u...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,True,True
4,8.0,0.05,-2.0,12.0,768.0,0.99,finished,7nd2xgmu,4f000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_low...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
497,512.0,1.00,0.0,16.0,1024.0,0.98,missing,,,,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,True,False
499,512.0,1.00,2.0,12.0,768.0,0.98,missing,,,,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,True,False
500,512.0,1.00,2.0,16.0,1024.0,0.98,missing,,,,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,True,False
504,512.0,1.00,1000.0,12.0,768.0,0.98,missing,,,,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,True,False


In [10]:
# Number of runs per status label.
status_counts = df["status"].value_counts()
status_counts

status
finished    361
missing      53
crashed      42
running      26
pending       3
Name: count, dtype: int64

In [11]:
# Rows flagged as running by the Ray job listing, counted per TPU zone.
live_runs = df.loc[df["is_running"]]
live_runs.groupby("tpu_zone").size()

tpu_zone
asia-northeast1-b     3
europe-west4-a        4
us-central1-a         4
us-east1-d            4
us-east5-a           14
us-east5-b            5
dtype: int64

In [12]:
# Same rows, broken down by zone and W&B status.
live_by_zone_status = df.loc[df["is_running"]].groupby(["tpu_zone", "status"]).size()
print(live_by_zone_status)

tpu_zone           status 
asia-northeast1-b  running     3
europe-west4-a     crashed     1
                   running     3
us-central1-a      crashed     2
                   running     2
us-east1-d         running     4
us-east5-a         pending     2
                   running    12
us-east5-b         crashed     2
                   pending     1
                   running     2
dtype: int64


In [13]:
# Running rows whose status is anything other than "crashed".
live_not_crashed = (df["status"] != "crashed") & (df["is_running"] == True)
print(df.loc[live_not_crashed].groupby(["tpu_zone", "status"]).size())

tpu_zone           status 
asia-northeast1-b  running     3
europe-west4-a     running     3
us-central1-a      running     2
us-east1-d         running     4
us-east5-a         pending     2
                   running    12
us-east5-b         pending     1
                   running     2
dtype: int64


In [14]:
# Crashed rows with no running Ray job, per zone.
crashed_idle = (df["status"] == "crashed") & (df["is_running"] != True)
print(df.loc[crashed_idle].groupby(["tpu_zone", "status"]).size())

tpu_zone           status 
asia-northeast1-b  crashed     7
europe-west4-a     crashed    14
us-central1-a      crashed     1
us-east5-b         crashed    15
dtype: int64


In [15]:
# Rows whose status is "running".
has_running_status = df["status"] == "running"
df.loc[has_running_status]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,is_needed,is_done
463,512.0,0.5,-2.0,12.0,768.0,0.98,running,pzk26tpw,25000000,98250.0,gidd-L12-D768-H12-N2048-bs=512-lr=0.5-noise_lo...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
467,512.0,0.5,0.0,12.0,768.0,0.98,running,6yccys5s,26000000,98250.0,gidd-L12-D768-H12-N2048-bs=512-lr=0.5-noise_ba...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
391,256.0,0.3,2.0,12.0,768.0,0.98,running,iv2ypo5q,09000000,93850.0,gidd-L12-D768-H12-N2048-bs=256-lr=0.3-noise_hi...,v6e-32,us-east1-d,1,main_ray.py --batch_size 256 --lr 0.3 --max_tr...,1.0,True,True,False
399,256.0,0.5,-1000.0,12.0,768.0,0.98,running,tmvuv129,07000000,89150.0,gidd-L12-D768-H12-N2048-bs=256-lr=0.5-noise_mask,v6e-32,us-east1-d,1,main_ray.py --batch_size 256 --lr 0.5 --max_tr...,1.0,True,True,False
465,512.0,0.5,0.0,8.0,512.0,0.95,running,bg12ue86,27000000,86550.0,gidd-L8-D512-H8-N2048-bs=512-lr=0.5-noise_bala...,v5p-32,us-central1-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,False,False
460,512.0,0.5,-1000.0,16.0,1024.0,0.98,running,tvmgz7z9,22000000,83650.0,gidd-L16-D1024-H16-N2048-bs=512-lr=0.5-noise_mask,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
471,512.0,0.5,2.0,16.0,1024.0,0.98,running,7mpx66js,29000000,71650.0,gidd-L16-D1024-H16-N2048-bs=512-lr=0.5-noise_h...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
459,512.0,0.5,-1000.0,12.0,768.0,0.98,running,13x51cgd,23000000,68950.0,gidd-L12-D768-H12-N2048-bs=512-lr=0.5-noise_mask,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
474,512.0,0.5,1000.0,12.0,768.0,0.98,running,g5cb0w76,2a000000,68850.0,gidd-L12-D768-H12-N2048-bs=512-lr=0.5-noise_un...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False
470,512.0,0.5,2.0,12.0,768.0,0.98,running,n5j1z3dt,28000000,68700.0,gidd-L12-D768-H12-N2048-bs=512-lr=0.5-noise_hi...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --lr 0.5 --max_tr...,1.0,True,True,False


In [16]:
# Crashed rows that have no running Ray job.
crashed_idle = (df["status"] == "crashed") & (df["is_running"] != True)
df.loc[crashed_idle]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,is_needed,is_done
170,32.0,0.25,-1000.0,10.0,640.0,0.99,crashed,g9kdbc98,35000000,94500.0,gidd-L10-D640-H10-bs=32-lr=0.25-noise_mask,v6e-8,europe-west4-a,2,main_ray.py --batch_size 32 --lr 0.25 --max_tr...,,False,True,False
254,64.0,0.3,1000.0,20.0,1536.0,0.99,crashed,cl3rs8d7,1b000000,74450.0,gidd-L20-D1536-H12-bs=64-lr=0.3-noise_uniform,v6e-8,asia-northeast1-b,2,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,False,False
243,64.0,0.3,0.0,20.0,1536.0,0.99,crashed,o9g43r2z,1a000000,66900.0,gidd-L20-D1536-H12-bs=64-lr=0.3-noise_balanced,v6e-8,asia-northeast1-b,2,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,False,False
240,64.0,0.3,0.0,10.0,640.0,0.99,crashed,yeoig5q7,15000000,54600.0,gidd-L10-D640-H10-bs=64-lr=0.3-noise_balanced,v6e-8,asia-northeast1-b,2,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,False,False
242,64.0,0.3,0.0,16.0,1024.0,0.99,crashed,0ur979k3,21000000,52350.0,gidd-L16-D1024-H16-N2048-bs=64-lr=0.3-noise_ba...,v6e-16,us-east5-b,1,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,True,False
303,128.0,0.3,-2.0,16.0,1024.0,0.99,crashed,judac7vi,43000000,51100.0,gidd-L16-D1024-H16-N2048-bs=128-lr=0.3-noise_l...,v6e-16,us-east5-b,1,main_ray.py --batch_size 128 --lr 0.3 --max_tr...,,False,True,False
331,128.0,0.5,0.0,8.0,512.0,0.99,crashed,l0ifyy46,11000000,47750.0,gidd-L8-D512-H8-N2048-bs=128-lr=0.5-noise_bala...,v6e-16,us-east5-b,1,main_ray.py --batch_size 128 --lr 0.5 --max_tr...,,False,True,False
379,256.0,0.3,-1000.0,12.0,768.0,0.98,crashed,ws40hhw0,14000000,47600.0,gidd-L12-D768-H12-N2048-bs=256-lr=0.3-noise_mask,v6e-32,us-east5-b,1,main_ray.py --batch_size 256 --lr 0.3 --max_tr...,,False,True,False
356,128.0,1.0,2.0,16.0,1024.0,0.99,crashed,v6fs2deb,3d000000,39350.0,gidd-L16-D1024-H16-N2048-bs=128-lr=1.0-noise_h...,v6e-16,us-east5-b,1,main_ray.py --batch_size 128 --lr 1.0 --max_tr...,,False,True,False
178,32.0,0.25,1000.0,10.0,640.0,0.99,crashed,qh7r3wcr,37000000,37550.0,gidd-L10-D640-H10-bs=32-lr=0.25-noise_uniform,v6e-8,europe-west4-a,2,main_ray.py --batch_size 32 --lr 0.25 --max_tr...,,False,True,False


In [17]:
# Rows whose status is "missing".
is_missing = df["status"] == "missing"
df.loc[is_missing]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,is_needed,is_done
173,32.0,0.25,-2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 32 --lr 0.25 --max_tr...,,False,True,False
255,64.0,0.35,-1000.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
256,64.0,0.35,-1000.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
257,64.0,0.35,-2.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
258,64.0,0.35,-2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
259,64.0,0.35,0.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
260,64.0,0.35,0.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
261,64.0,0.35,2.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
262,64.0,0.35,2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
263,64.0,0.35,1000.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False


In [18]:
# Crashed rows with no running Ray job, counted per batch size and model size.
crashed_idle = (df["status"] == "crashed") & (df["is_running"] != True)
print(df.loc[crashed_idle].groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
32.0        10.0    640.0          3
            12.0    768.0          1
            20.0    1536.0         3
64.0        10.0    640.0          3
            16.0    1024.0         2
            20.0    1536.0         4
128.0       8.0     512.0          1
            10.0    640.0          3
            12.0    768.0          1
            16.0    1024.0         2
            20.0    1536.0         2
256.0       12.0    768.0          6
            16.0    1024.0         5
512.0       8.0     512.0          1
dtype: int64


In [24]:
# Missing rows counted per batch size and model size.
is_missing = df["status"] == "missing"
print(df.loc[is_missing].groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
32.0        20.0    1536.0          1
64.0        10.0    640.0           5
            20.0    1536.0          5
256.0       10.0    640.0           5
            16.0    1024.0          2
            20.0    1536.0          5
512.0       10.0    640.0           5
            12.0    768.0          10
            16.0    1024.0         10
            20.0    1536.0          5
dtype: int64


In [20]:
# Rows whose status is "missing" (same selection as the earlier "missing" cell).
is_missing = df["status"] == "missing"
df.loc[is_missing]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,is_needed,is_done
173,32.0,0.25,-2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 32 --lr 0.25 --max_tr...,,False,True,False
255,64.0,0.35,-1000.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
256,64.0,0.35,-1000.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
257,64.0,0.35,-2.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
258,64.0,0.35,-2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
259,64.0,0.35,0.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
260,64.0,0.35,0.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
261,64.0,0.35,2.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
262,64.0,0.35,2.0,20.0,1536.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False
263,64.0,0.35,1000.0,10.0,640.0,0.99,missing,,,,,,,,main_ray.py --batch_size 64 --lr 0.35 --max_tr...,,False,True,False


In [21]:
# Print the launch commands for missing runs at batch size 512 with more than
# 8 layers, ordered by learning rate. Adjust the mask to print another subset
# (e.g. batch_size == 256).
missing_large = (df["status"] == "missing") & (df["batch_size"] == 512) & (df["layers"] > 8)
missing_commands = df.loc[missing_large].sort_values("lr")["command"].values
print("\n".join(missing_commands))

main_ray.py --batch_size 512 --lr 0.3 --max_training_steps 100000 --num_layers 12 --hidden_size 768 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L12-D768-H12-N2048-bs=512-lr=0.3-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l12_d768,noise_mask
main_ray.py --batch_size 512 --lr 0.3 --max_training_steps 100000 --num_layers 16 --hidden_size 1024 --num_attn_heads 16 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L16-D1024-H16-N2048-bs=512-lr=0.3-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l16_d1024,noise_mask
main_ray.py --batch_size 512 --lr 0.3 --max_training_steps 100000 --num_layers 12 --hidden_size 768 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L12-D768-H12-N2048-bs=512-lr=0.3-noise_low_uniform --hybrid_mixing_shift -2.0 --wandb_tags crit_bs_l12_d768,noise_low_uniform
main_ray.py --

In [22]:
# Prefix of every submission: the Ray CLI call, runtime environment, and the
# TPU environment variables, ending right before the script path.
scaffold = """ray job submit --no-wait --address "{address}" --working-dir . --runtime-env-json='{{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}}' -- TPU_VERSION={tpu_version} TPU_POD_COUNT={tpu_count} python"""

# Ray cluster address (a shell variable reference) for each TPU zone.
zone_to_address = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}


def _normalize_pod_count(value):
    """Return integral numeric values as ``int``; leave anything else untouched."""
    try:
        as_float = float(value)
    except (TypeError, ValueError):
        return value
    # NaN and infinities are not integral, so they pass through unchanged.
    return int(as_float) if as_float.is_integer() else value


def format_resume_command(row):
    """Build the full ``ray job submit`` line that resumes the run in ``row``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with the keys
            ``command``, ``wandb_id``, ``tpu_zone``, ``tpu_version`` and
            ``tpu_count``. ``command`` must be a string.

    Returns:
        ``scaffold`` filled in for the row's zone and TPU settings, followed by
        the row's command. ``--resume_wandb_id <wandb_id>`` is appended when the
        command lacks that flag and ``wandb_id`` is not NaN.
    """
    base_cmd = row["command"]
    wandb_id = row["wandb_id"]
    tpu_zone = row["tpu_zone"]
    tpu_version = row["tpu_version"]
    # pandas stores counts as floats once a column contains NaN; without this
    # the command would read e.g. "TPU_POD_COUNT=2.0".
    tpu_count = _normalize_pod_count(row["tpu_count"])
    if "--resume_wandb_id" not in base_cmd and not pd.isna(wandb_id):
        base_cmd = f"{base_cmd} --resume_wandb_id {wandb_id}"
    # NOTE(review): unknown zones silently fall back to the us-east1-d cluster.
    address = zone_to_address.get(tpu_zone, "$RAY_US_EAST1_D")
    return scaffold.format(address=address, tpu_version=tpu_version, tpu_count=tpu_count) + " " + base_cmd.strip()

# Resume every crashed run that has no running Ray job, skipping zone us-east5-b.
needs_resume = (
    (df["status"] == "crashed")
    & (df["is_running"] != True)
    & (df["tpu_zone"] != "us-east5-b")
)
commands = df.loc[needs_resume].apply(format_resume_command, axis=1)

# Write one command per line to a file and echo the same text.
with open("resume_commands.txt", "w") as f:
    resume_text = "\n".join(commands.values)
    f.write(resume_text)
print(resume_text)

ray job submit --no-wait --address "$RAY_EU_WEST4_A" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v6e-8 TPU_POD_COUNT=2 python main_ray.py --batch_size 32 --lr 0.25 --max_training_steps 100000 --num_layers 10 --hidden_size 640 --num_attn_heads 10 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L10-D640-H10-bs=32-lr=0.25-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags scaling_laws,l10_d640,noise_mask --resume_wandb_id g9kdbc98
ray job submit --no-wait --address "$RAY_ASIA_NORTHEAST1_B" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v6e-8 TPU_POD_COUNT=2 python main_ray.py --batch_size 64 --lr 0.3 --max_training_steps 100000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_na