In [1]:
import wandb
import pandas as pd

from notebooks.utils import WANDB_PROJECT

In [2]:
merge_keys = ["batch_size", "lr", "mixing_shift", "layers", "hidden_size", "beta2", "tpu_zone", "tpu_version"]

In [3]:
def format_command(row):
    # print(row)
    "main_ray.py --batch_size 256 --lr 0.3 --max_training_steps 100000 --num_layers 16 --hidden_size 1024 --num_attn_heads 16 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L16-D1024-H16-N2048-bs=256-lr=0.3-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l16_d1024,noise_mask"
    if not pd.isna(row["command"]) and not pd.isna(row["wandb_id"]):
        print(row["command"])
        if "--resume_wandb_id" in row["command"]:
            return row["command"]
        else:
            return f"{row['command']} --resume_wandb_id {row['wandb_id']}"
    noise_label = {
        -1000.0: "noise_mask",
        -2.0: "noise_low_uniform",
        0.0: "noise_balanced",
        2.0: "noise_high_uniform",
        1000.0: "noise_uniform",
    }.get(row["mixing_shift"], "noise_unknown")

    batch_size = int(row["batch_size"])
    if batch_size < 512:
        max_training_steps = 50000
    elif batch_size < 1024:
        max_training_steps = 40000
    elif batch_size < 2048:
        max_training_steps = 20000
    else:
        max_training_steps = 10000

    args = [
        ("--batch_size", batch_size),
        ("--lr", float(row["lr"])),
        ("--max_training_steps", max_training_steps),
        ("--num_layers", int(row["layers"])),
        ("--hidden_size", int(row["hidden_size"])),
        ("--num_attn_heads", int(row["hidden_size"]) // 64 if row["hidden_size"] <= 1024 else int(row["hidden_size"]) // 128),
        ("--max_seq_len", 2048),
        ("--cooldown_steps", 0.0),
        ("--beta2", float(row["beta2"])),
        ("--no-break_on_nan",),
        ("--wandb_name", f"gidd-L{int(row['layers'])}-D{int(row['hidden_size'])}-H{int(row['hidden_size'])//64}-N2048-bs={int(row['batch_size'])}-lr={float(row['lr'])}-{noise_label}"),
        ("--hybrid_mixing_shift", float(row["mixing_shift"])),
        ("--wandb_tags", f"crit_bs_l{int(row['layers'])}_d{int(row['hidden_size'])},{noise_label}"),
    ]
    return "main_ray.py " + " ".join(" ".join(map(str, arg)) for arg in args)

In [4]:
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

dt_local = datetime(2025, 9, 13, 21, 0, tzinfo=ZoneInfo("Europe/Zurich"))
dt_utc = dt_local.astimezone(timezone.utc)

runs = wandb.Api().runs(
    path=WANDB_PROJECT,
    filters={
        "$and": [
            {"$or": [
                {"tags": "crit_bs_l8_d512"},
                {"tags": "crit_bs_l10_d640"},
                {"tags": "crit_bs_l12_d768"},
                {"tags": "crit_bs_l16_d1024"},
                {"tags": "crit_bs_l20_d1536"},
                {"tags": "scaling_laws"},
            ]},
            {"created_at": {"$gte": dt_utc.isoformat()}}
        ],
    },
    order="+created_at",
    per_page=1000,
)
len(runs)

718

In [5]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import re


patterns = {
    "batch_size": re.compile(r"--batch_size (\d+)"),
    "lr": re.compile(r"--lr ([\d.]+)"),
    "layers": re.compile(r"--num_layers (\d+)"),
    "hidden_size": re.compile(r"--hidden_size (\d+)"),
    "beta2": re.compile(r"--beta2 ([\d.]+)"),
    "mixing_shift": re.compile(r"--hybrid_mixing_shift ([\-\d.]+)"),
    "tpu_zone": re.compile(r"TPU_ZONE=([a-z0-9\-]+)"),
    "tpu_version": re.compile(r"TPU_VERSION=([a-z0-9\-]+)"),
    "tpu_count": re.compile(r"TPU_POD_COUNT=(\d+)"),
}

def parse_command(command):
    parsed = {}
    for key, pattern in patterns.items():
        match = pattern.search(command)
        if match:
            val = match.group(1)
            try:
                parsed[key] = float(val) if '.' in val else int(val)
            except ValueError:
                parsed[key] = val
    return parsed

addresses = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}

# for address in addresses:
#     output = subprocess.run(["zsh", "-lc", f"ray job list --address {address}"], capture_output=True, text=True)
#     print(output)
all_rows = []
with ThreadPoolExecutor(max_workers=len(addresses)) as executor:
    futures = {}
    for tpu_zone, address in addresses.items():
        cmd = f"source ~/.zshrc && ray job list --address {address}"
        futures[executor.submit(subprocess.run, cmd, shell=True, capture_output=True, text=True, executable="/bin/zsh")] = tpu_zone
    for future in as_completed(futures):
        tpu_zone = futures[future]
        output = future.result()
        # print(output)
        lines = output.stdout.splitlines()[1:]
        rows = [parse_command(x) for x in lines if "RUNNING" in x]
        for row in rows:
            row["tpu_zone"] = tpu_zone
            row["is_running"] = True
        all_rows.extend(rows)

running_df = pd.DataFrame(all_rows)
running_df

Unnamed: 0,batch_size,lr,layers,hidden_size,beta2,mixing_shift,tpu_version,tpu_count,tpu_zone,is_running
0,512,1.0,12,768,0.98,-1000.0,v6e-64,1,europe-west4-a,True
1,512,2.0,12,768,0.98,2.0,v6e-64,1,europe-west4-a,True
2,512,1.0,12,768,0.98,1000.0,v6e-64,1,europe-west4-a,True
3,512,1.0,12,768,0.98,0.0,v6e-64,1,europe-west4-a,True
4,512,2.0,10,640,0.98,-2.0,v6e-64,1,europe-west4-a,True
5,256,0.3,16,1024,0.98,0.0,v6e-64,1,europe-west4-a,True
6,512,1.0,10,640,0.98,-1000.0,v6e-64,1,europe-west4-a,True
7,512,2.0,12,768,0.98,-2.0,v6e-64,1,europe-west4-a,True
8,512,2.0,10,640,0.98,2.0,v6e-64,1,europe-west4-a,True
9,512,0.5,10,640,0.98,-2.0,v6e-64,1,europe-west4-a,True


In [26]:
hparams_by_bs = {
    8: ([0.2, 0.3], [0.99], 50000),
    16: ([0.2, 0.3, 0.5], [0.99], 50000),
    32: ([0.2, 0.3, 0.5], [0.99], 50000),
    64: ([0.3, 0.5, 1.0], [0.99], 50000),
    128: ([0.5, 1.0], [0.99], 50000),
    256: ([0.5, 1.0], [0.98], 50000),
    512: ([0.5, 1.0, 2.0], [0.98], 40000),
    # 1024: ([1.0, 2.0], [0.98], 20000),
}
target_mixing_shifts = [-1000.0, -2.0, 0.0, 2.0, 1000.0]
target_sizes = [(8, 512), (10, 640), (12, 768), (16, 1024), (20, 1536)]

rows = [
    {
        "batch_size": bs,
        "lr": lr,
        "mixing_shift": ms,
        "layers": l,
        "hidden_size": d,
        "beta2": b2,
        "target_steps": steps,
    }
    for bs, (lrs, b2s, steps) in hparams_by_bs.items() for lr in lrs for b2 in b2s
    for ms in target_mixing_shifts
    for l, d in target_sizes
]
len(rows)

450

In [27]:
import tqdm.auto as tqdm

wandb_rows = []
needed_df = pd.DataFrame(rows)
needed_df["is_needed"] = True
for run in tqdm.tqdm(runs):
    cfg = run.config
    state = run.state
    wandb_rows.append({
        "batch_size": cfg.get("batch_size"),
        "lr": cfg.get("lr"),
        "mixing_shift": cfg.get("hybrid_mixing_shift"),
        "layers": cfg.get("num_layers"),
        "hidden_size": cfg.get("hidden_size"),
        "beta2": cfg.get("beta2"),
        "status": state,
        "wandb_id": run.id,
        "ray_id": cfg.get("ray_job_id"),
        "step": run.summary.get("_step"),
        "is_done": run.summary.get("_step") is not None and run.summary.get("_step") >= 50000,
        "name": run.name,
        "tpu_version": cfg.get("tpu_version"),
        "tpu_zone": cfg.get("tpu_zone"),
        "tpu_count": cfg.get("tpu_pod_count"),
        "command": cfg.get("command"),
    })

wandb_df = pd.DataFrame(wandb_rows).dropna(how="any")

  0%|          | 0/718 [00:00<?, ?it/s]

In [28]:
wandb_df.groupby("status").size()

status
crashed      84
finished    550
running      14
dtype: int64

In [29]:
existing_df = wandb_df.merge(running_df, how="outer", on=merge_keys, suffixes=("", "_running"))
existing_df.loc[existing_df["status"].isna(), "status"] = "pending"
existing_df.loc[existing_df["is_done"].isna(), "is_done"] = False
existing_df.loc[existing_df["tpu_count"].isna(), "tpu_count"] = existing_df.loc[existing_df["tpu_count"].isna(), "tpu_count_running"].astype(int)
existing_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,True,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,True,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,True,gidd-L16-D1024-H16-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,True,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_low_u...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
4,8.0,0.05,-2.0,12.0,768.0,0.99,finished,7nd2xgmu,4f000000,100000.0,True,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_low...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
645,512.0,2.00,1000.0,20.0,1536.0,0.98,finished,e20iiyum,06000000,20000.0,False,gidd-L20-D1536-H24-N2048-bs=512-lr=2.0-noise_u...,v5p-64,us-east5-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,,
646,1024.0,1.00,0.0,8.0,512.0,0.97,crashed,mgywim12,29000000,33500.0,False,gidd-L8-D512-H8-N2048-bs=1024-lr=1.0-noise_bal...,v5p-64,us-central1-a,1,main_ray.py --batch_size 1024 --lr 1.0 --max_t...,1.0,True
647,1024.0,2.00,0.0,8.0,512.0,0.97,crashed,7uw3u08p,03000000,60600.0,True,gidd-L8-D512-H8-N2048-bs=1024-lr=2.0-noise_bal...,v6e-256,us-east1-d,1,main_ray.py --batch_size 1024 --lr 2.0 --max_t...,1.0,True
648,2048.0,1.00,1000.0,28.0,3584.0,0.98,pending,,,,False,,v5p-512,us-east5-a,1,,1.0,True


In [30]:
df = existing_df.merge(needed_df, how="outer")
df.loc[:, "is_needed"] = df["is_needed"].fillna(False)
df.loc[df["is_done"].isna(), "is_done"] = False
df.loc[df["step"] >= df["target_steps"], "is_done"] = True
df.loc[df["is_running"].isna(), "is_running"] = False
df = df.sort_values(["is_done", "is_running", "step"], ascending=[False, False, False])
df = df.drop_duplicates(subset=merge_keys, keep="first", inplace=False)
df.loc[(df["status"] == "crashed") & df["is_needed"] & (df["step"].isna() | (df["step"] < 100)), "status"] = "missing"
df.loc[df["is_needed"], "status"] = df.loc[df["is_needed"], "status"].fillna("missing")
# df.loc[~df["wandb_id"].isna(), "status"] = df.loc[~df["wandb_id"].isna(), "status"].fillna("N/A")
df.loc[df["command"].isna(), "command"] = df.loc[df["command"].isna()].apply(format_command, axis=1)
df.loc[df["is_done"], "status"] = "finished"
df

  df.loc[:, "is_needed"] = df["is_needed"].fillna(False)


Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed
654,1024.0,2.00,0.0,8.0,512.0,0.97,finished,7uw3u08p,03000000,60600.0,True,gidd-L8-D512-H8-N2048-bs=1024-lr=2.0-noise_bal...,v6e-256,us-east1-d,1,main_ray.py --batch_size 1024 --lr 2.0 --max_t...,1.0,True,,False
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,True,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,True,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,True,gidd-L16-D1024-H16-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,True,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_low_u...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
342,64.0,0.50,2.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 64 --lr 0.5 --max_tra...,,False,50000.0,True
512,256.0,0.50,-2.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 256 --lr 0.5 --max_tr...,,False,50000.0,True
616,512.0,1.00,2.0,12.0,768.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True
622,512.0,1.00,1000.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True


In [31]:
df["status"].value_counts()

status
finished    562
crashed      38
running      14
missing       7
pending       2
Name: count, dtype: int64

In [32]:
df.loc[df["is_running"]].groupby("tpu_zone").size()

tpu_zone
asia-northeast1-b     5
europe-west4-a       14
us-central1-a         2
us-east1-d            1
us-east5-a            2
us-east5-b            4
dtype: int64

In [33]:
print(df.loc[df["is_running"]].groupby(["tpu_zone", "status"]).size())

tpu_zone           status  
asia-northeast1-b  running     5
europe-west4-a     crashed     9
                   running     5
us-central1-a      crashed     2
us-east1-d         finished    1
us-east5-a         pending     2
us-east5-b         running     4
dtype: int64


In [34]:
print(df.loc[(df["status"] != "crashed") & (df["is_running"] == True)].groupby(["tpu_zone", "status"]).size())

tpu_zone           status  
asia-northeast1-b  running     5
europe-west4-a     running     5
us-east1-d         finished    1
us-east5-a         pending     2
us-east5-b         running     4
dtype: int64


In [35]:
print(df.loc[(df["status"] == "crashed") & (df["is_running"] != True)].groupby(["tpu_zone", "status"]).size())

tpu_zone           status 
asia-northeast1-b  crashed     4
europe-west4-a     crashed    13
us-central1-a      crashed     1
us-east5-a         crashed     1
us-east5-b         crashed     8
dtype: int64


In [36]:
df.loc[df["status"] == "running"]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed
327,64.0,0.5,-2.0,10.0,640.0,0.99,running,rnw6qe8w,4000000.0,39300.0,False,gidd-L10-D640-H10-N2048-bs=64-lr=0.5-noise_low...,v6e-8,asia-northeast1-b,1,main_ray.py --batch_size 64 --lr 0.5 --max_tra...,1.0,True,50000.0,True
354,64.0,1.0,-1000.0,20.0,1536.0,0.99,running,1ud79n66,6000000.0,28200.0,False,gidd-L20-D1536-H24-N2048-bs=64-lr=1.0-noise_mask,v6e-16,asia-northeast1-b,1,main_ray.py --batch_size 64 --lr 1.0 --max_tra...,1.0,True,50000.0,True
376,64.0,1.0,1000.0,10.0,640.0,0.99,running,ofmggyla,8000000.0,26350.0,False,gidd-L10-D640-H10-N2048-bs=64-lr=1.0-noise_uni...,v6e-8,asia-northeast1-b,1,main_ray.py --batch_size 64 --lr 1.0 --max_tra...,1.0,True,50000.0,True
330,64.0,0.5,-2.0,20.0,1536.0,0.99,running,a4eplc7h,7000000.0,19950.0,False,gidd-L20-D1536-H24-N2048-bs=64-lr=0.5-noise_lo...,v6e-16,asia-northeast1-b,1,main_ray.py --batch_size 64 --lr 0.5 --max_tra...,1.0,True,50000.0,True
639,512.0,2.0,0.0,10.0,640.0,0.98,running,x4q5avyu,2000000.0,19700.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=2.0-noise_ba...,v6e-64,us-east5-b,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True
610,512.0,1.0,0.0,12.0,768.0,0.98,running,zmc3k1ox,11000000.0,17050.0,False,gidd-L12-D768-H12-N2048-bs=512-lr=1.0-noise_ba...,v6e-64,europe-west4-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True
627,512.0,2.0,-1000.0,10.0,640.0,0.98,running,2n0tuldl,3000000.0,17050.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=2.0-noise_mask,v6e-64,us-east5-b,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True
568,512.0,0.5,-1000.0,10.0,640.0,0.98,running,bxq7e999,4000000.0,13400.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=0.5-noise_mask,v6e-64,us-east5-b,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True
574,512.0,0.5,-2.0,10.0,640.0,0.98,running,ma6ejr03,13000000.0,11200.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=0.5-noise_lo...,v6e-64,europe-west4-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True
609,512.0,1.0,0.0,10.0,640.0,0.98,running,wahmf7yo,9000000.0,11000.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=1.0-noise_ba...,v6e-64,europe-west4-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,1.0,True,40000.0,True


In [37]:
df.loc[(df["status"] == "crashed") & (df["is_running"] != True) & (df["is_needed"])]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed
546,256.0,1.0,0.0,10.0,640.0,0.98,crashed,rfs43olx,0f000000,49150.0,False,gidd-L10-D640-H10-N2048-bs=256-lr=1.0-noise_ba...,v6e-64,us-east5-b,1,main_ray.py --batch_size 256 --lr 1.0 --max_tr...,,False,50000.0,True
464,128.0,1.0,1000.0,20.0,1536.0,0.99,crashed,mt3r8esv,3f000000,41800.0,False,gidd-L20-D1536-H24-N2048-bs=128-lr=1.0-noise_u...,v5p-32,us-east5-a,1,main_ray.py --batch_size 128 --lr 1.0 --max_tr...,,False,50000.0,True
558,256.0,1.0,1000.0,10.0,640.0,0.98,crashed,go0n2aqz,05000000,39850.0,False,gidd-L10-D640-H10-N2048-bs=256-lr=1.0-noise_un...,v6e-64,us-east5-b,1,main_ray.py --batch_size 256 --lr 1.0 --max_tr...,,False,50000.0,True
299,64.0,0.3,-2.0,10.0,640.0,0.99,crashed,mgu7ufrv,16000000,35000.0,False,gidd-L10-D640-H10-bs=64-lr=0.3-noise_low_uniform,v6e-16,europe-west4-a,1,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,50000.0,True
649,512.0,2.0,1000.0,10.0,640.0,0.98,crashed,coaecup7,0a000000,19450.0,False,gidd-L10-D640-H10-N2048-bs=512-lr=2.0-noise_un...,v6e-64,us-east5-b,1,main_ray.py --batch_size 512 --micro_batch_siz...,,False,40000.0,True
302,64.0,0.3,-2.0,20.0,1536.0,0.99,crashed,rcmycsp8,18000000,7000.0,False,gidd-L20-D1536-H12-bs=64-lr=0.3-noise_low_uniform,v6e-16,europe-west4-a,1,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,50000.0,True
315,64.0,0.3,1000.0,10.0,640.0,0.99,crashed,cjp76b8r,0a000000,2950.0,False,gidd-L10-D640-H10-bs=64-lr=0.3-noise_uniform,v6e-16,asia-northeast1-b,1,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,50000.0,True
313,64.0,0.3,2.0,20.0,1536.0,0.99,crashed,vxicz3g2,19000000,2400.0,False,gidd-L20-D1536-H12-bs=64-lr=0.3-noise_high_uni...,v6e-16,europe-west4-a,1,main_ray.py --batch_size 64 --lr 0.3 --max_tra...,,False,50000.0,True
262,32.0,0.5,0.0,10.0,640.0,0.99,crashed,gesgj4gz,1a000000,2350.0,False,gidd-L10-D640-H10-N2048-bs=32-lr=0.5-noise_bal...,v6e-8,europe-west4-a,1,main_ray.py --batch_size 32 --lr 0.5 --max_tra...,,False,50000.0,True
543,256.0,1.0,-2.0,16.0,1024.0,0.98,crashed,vflzemw2,05000000,1300.0,False,gidd-L16-D1024-H16-N2048-bs=256-lr=1.0-noise_l...,v6e-64,us-east5-b,1,main_ray.py --batch_size 256 --lr 1.0 --max_tr...,,False,50000.0,True


In [38]:
df.loc[df["status"] == "missing"]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed
168,16.0,0.5,2.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 16 --lr 0.5 --max_tra...,,False,50000.0,True
254,32.0,0.5,-1000.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 32 --lr 0.5 --max_tra...,,False,50000.0,True
342,64.0,0.5,2.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 64 --lr 0.5 --max_tra...,,False,50000.0,True
512,256.0,0.5,-2.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 256 --lr 0.5 --max_tr...,,False,50000.0,True
616,512.0,1.0,2.0,12.0,768.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True
622,512.0,1.0,1000.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True
628,512.0,2.0,-1000.0,12.0,768.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 2.0 --max_tr...,,False,40000.0,True


In [39]:
print(df.loc[(df["status"] == "crashed") & (df["is_running"] != True) & (df["is_needed"])].groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
32.0        10.0    640.0          1
64.0        10.0    640.0          2
            20.0    1536.0         2
128.0       20.0    1536.0         1
256.0       10.0    640.0          2
            16.0    1024.0         1
512.0       10.0    640.0          1
dtype: int64


In [40]:
print(df.loc[(df["status"] == "crashed") & (df["is_running"] == True) & (df["is_needed"])].groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
128.0       20.0    1536.0         1
256.0       10.0    640.0          1
            16.0    1024.0         1
512.0       10.0    640.0          4
            12.0    768.0          2
dtype: int64


In [41]:
print(df.loc[df["status"] == "missing"].groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
16.0        20.0    1536.0         1
32.0        20.0    1536.0         1
64.0        20.0    1536.0         1
256.0       10.0    640.0          1
512.0       10.0    640.0          1
            12.0    768.0          2
dtype: int64


In [42]:
df.loc[(df["status"] == "missing")]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,is_done,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed
168,16.0,0.5,2.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 16 --lr 0.5 --max_tra...,,False,50000.0,True
254,32.0,0.5,-1000.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 32 --lr 0.5 --max_tra...,,False,50000.0,True
342,64.0,0.5,2.0,20.0,1536.0,0.99,missing,,,,False,,,,,main_ray.py --batch_size 64 --lr 0.5 --max_tra...,,False,50000.0,True
512,256.0,0.5,-2.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 256 --lr 0.5 --max_tr...,,False,50000.0,True
616,512.0,1.0,2.0,12.0,768.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True
622,512.0,1.0,1000.0,10.0,640.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 1.0 --max_tr...,,False,40000.0,True
628,512.0,2.0,-1000.0,12.0,768.0,0.98,missing,,,,False,,,,,main_ray.py --batch_size 512 --lr 2.0 --max_tr...,,False,40000.0,True


In [43]:
# print("\n".join(df.loc[(df["status"] == "missing") & (df["batch_size"] == 256)]["command"].values))
print("\n".join(df.loc[(df["status"] == "missing") & (df["is_running"] != True) & (df["is_needed"]) & (df["batch_size"] > 1)].sort_values(["batch_size", "layers"])["command"].values))

main_ray.py --batch_size 16 --lr 0.5 --max_training_steps 50000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L20-D1536-H24-N2048-bs=16-lr=0.5-noise_high_uniform --hybrid_mixing_shift 2.0 --wandb_tags crit_bs_l20_d1536,noise_high_uniform
main_ray.py --batch_size 32 --lr 0.5 --max_training_steps 50000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L20-D1536-H24-N2048-bs=32-lr=0.5-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l20_d1536,noise_mask
main_ray.py --batch_size 64 --lr 0.5 --max_training_steps 50000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L20-D1536-H24-N2048-bs=64-lr=0.5-noise_high_uniform --hybrid_mixing_shift 2.0 --wandb_tags crit_bs_l20_d1536,noise_high_uniform
main

In [44]:
scaffold = """ray job submit --no-wait --address "{address}" --working-dir . --runtime-env-json='{{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}}' -- TPU_VERSION={tpu_version} TPU_POD_COUNT={tpu_count} python"""

zone_to_address = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}

def format_resume_command(row):
    base_cmd = row["command"]
    wandb_id = row["wandb_id"]
    tpu_zone = row["tpu_zone"]
    tpu_version = row["tpu_version"]
    tpu_count = row["tpu_count"]
    if "--resume_wandb_id" not in base_cmd and not pd.isna(wandb_id):
        base_cmd = f"{base_cmd} --resume_wandb_id {wandb_id}"
    address = zone_to_address.get(tpu_zone, "$RAY_US_EAST1_D")
    return scaffold.format(address=address, tpu_version=tpu_version, tpu_count=tpu_count) + " " + base_cmd.strip()

commands = df.loc[(df["status"] == "crashed") & (df["is_running"] != True) & (df["is_needed"])].apply(format_resume_command, axis=1)

with open("resume_commands.txt", "w") as f:
    f.write("\n".join(commands.values))
print("\n".join(commands.values))

ray job submit --no-wait --address "$RAY_US_EAST5_B" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v6e-64 TPU_POD_COUNT=1 python main_ray.py --batch_size 256 --lr 1.0 --max_training_steps 50000 --num_layers 10 --hidden_size 640 --num_attn_heads 10 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L10-D640-H10-N2048-bs=256-lr=1.0-noise_balanced --hybrid_mixing_shift 0.0 --wandb_tags crit_bs_l10_d640,noise_balanced --resume_wandb_id rfs43olx
ray job submit --no-wait --address "$RAY_US_EAST5_A" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v5p-32 TPU_POD_COUNT=1 python main_ray.py --batch_size 128 --lr 1.0 --max_training_steps 50000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_nam