In [33]:
import wandb
import pandas as pd

from notebooks.utils import WANDB_PROJECT

In [34]:
merge_keys = ["batch_size", "lr", "mixing_shift", "layers", "hidden_size", "beta2"]

In [35]:
scaffold = """ray job submit --no-wait --address "{address}" --working-dir . --runtime-env-json='{{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}}' -- TPU_VERSION={tpu_version} TPU_POD_COUNT={tpu_count} python"""

zone_to_address = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}


def format_resume_command(row):
    base_cmd = row["command"]
    if pd.isna(base_cmd):
        return None
    wandb_id = row["wandb_id"]
    tpu_zone = row["tpu_zone"]
    tpu_version = row["tpu_version"]
    tpu_count = int(row["tpu_count"])
    max_training_steps = int(row["target_steps"])
    if "--max_training_steps" in base_cmd:
        parts = base_cmd.split("--max_training_steps", 1)
        parts = [parts[0].strip(), f"--max_training_steps {max_training_steps}", parts[1].strip().split(" ", 1)[1]]
        base_cmd = " ".join(parts).strip()
    if "--resume_wandb_id" not in base_cmd and not pd.isna(wandb_id):
        base_cmd = f"{base_cmd} --resume_wandb_id {wandb_id}"
    if "--resume_wandb_id" not in base_cmd:
        return None
    address = zone_to_address.get(tpu_zone, "<unknown>")
    return scaffold.format(address=address, tpu_version=tpu_version, tpu_count=tpu_count) + " " + base_cmd.strip()

def format_command(row):
    noise_label = {
        -1000.0: "noise_mask",
        -2.0: "noise_low_uniform",
        0.0: "noise_balanced",
        2.0: "noise_high_uniform",
        1000.0: "noise_uniform",
    }.get(row["mixing_shift"], "noise_unknown")

    batch_size = int(row["batch_size"])

    args = [
        ("--batch_size", batch_size),
        ("--lr", float(row["lr"])),
        ("--max_training_steps", row["target_steps"]),
        ("--num_layers", int(row["layers"])),
        ("--hidden_size", int(row["hidden_size"])),
        ("--num_attn_heads", int(row["hidden_size"]) // 64 if row["hidden_size"] <= 1024 else int(row["hidden_size"]) // 128),
        ("--max_seq_len", 2048),
        ("--cooldown_steps", 0.0),
        ("--beta2", float(row["beta2"])),
        ("--no-break_on_nan",),
        ("--wandb_name", f"gidd-L{int(row['layers'])}-D{int(row['hidden_size'])}-H{int(row['hidden_size'])//64}-N2048-bs={int(row['batch_size'])}-lr={float(row['lr'])}-{noise_label}"),
        ("--hybrid_mixing_shift", float(row["mixing_shift"])),
        ("--wandb_tags", f"crit_bs_l{int(row['layers'])}_d{int(row['hidden_size'])},{noise_label}"),
    ]
    # if not pd.isna(row["wandb_id"]):
    #     args.append(("--resume_wandb_id", row["wandb_id"]))
    return "python main_ray.py " + " ".join(" ".join(map(str, arg)) for arg in args)

In [36]:
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

dt_local = datetime(2025, 9, 13, 21, 0, tzinfo=ZoneInfo("Europe/Zurich"))
dt_utc = dt_local.astimezone(timezone.utc)

runs = wandb.Api().runs(
    path=WANDB_PROJECT,
    filters={
        "$and": [
            {"$or": [
                {"tags": "crit_bs_l8_d512"},
                {"tags": "crit_bs_l10_d640"},
                {"tags": "crit_bs_l12_d768"},
                {"tags": "crit_bs_l16_d1024"},
                {"tags": "crit_bs_l20_d1536"},
                {"tags": "scaling_laws"},
            ]},
            {"created_at": {"$gte": dt_utc.isoformat()}}
        ],
    },
    order="+created_at",
    per_page=1000,
)
len(runs)

743

In [37]:
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import re


patterns = {
    "batch_size": re.compile(r"--batch_size (\d+)"),
    "lr": re.compile(r"--lr ([\d.]+)"),
    "layers": re.compile(r"--num_layers (\d+)"),
    "hidden_size": re.compile(r"--hidden_size (\d+)"),
    "beta2": re.compile(r"--beta2 ([\d.]+)"),
    "mixing_shift": re.compile(r"--hybrid_mixing_shift ([\-\d.]+)"),
    "tpu_zone": re.compile(r"TPU_ZONE=([a-z0-9\-]+)"),
    "tpu_version": re.compile(r"TPU_VERSION=([a-z0-9\-]+)"),
    "tpu_count": re.compile(r"TPU_POD_COUNT=(\d+)"),
}

def parse_command(command):
    parsed = {}
    for key, pattern in patterns.items():
        match = pattern.search(command)
        if match:
            val = match.group(1)
            try:
                parsed[key] = float(val) if '.' in val else int(val)
            except ValueError:
                parsed[key] = val
    return parsed

addresses = {
    "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
    "europe-west4-a": "$RAY_EU_WEST4_A",
    "us-central1-a": "$RAY_US_CENTRAL1_A",
    "us-east1-d": "$RAY_US_EAST1_D",
    "us-east5-a": "$RAY_US_EAST5_A",
    "us-east5-b": "$RAY_US_EAST5_B",
}

# for address in addresses:
#     output = subprocess.run(["zsh", "-lc", f"ray job list --address {address}"], capture_output=True, text=True)
#     print(output)
all_rows = []
with ThreadPoolExecutor(max_workers=len(addresses)) as executor:
    futures = {}
    for tpu_zone, address in addresses.items():
        cmd = f"source ~/.zshrc && ray job list --address {address}"
        futures[executor.submit(subprocess.run, cmd, shell=True, capture_output=True, text=True, executable="/bin/zsh")] = tpu_zone
    for future in as_completed(futures):
        tpu_zone = futures[future]
        output = future.result()
        # print(output)
        lines = output.stdout.splitlines()[1:]
        rows = [parse_command(x) for x in lines if "RUNNING" in x]
        for row in rows:
            row["tpu_zone"] = tpu_zone
            row["is_running"] = True
        all_rows.extend(rows)

running_df = pd.DataFrame(all_rows)
running_df

In [38]:
def hparams_by_bs(layers):
    hparams = {
        8: ([0.1, 0.2, 0.3], [0.99], 80000),
        16: ([0.2, 0.3, 0.5], [0.99], 80000),
        32: ([0.2, 0.3, 0.5], [0.99], 80000),
        64: ([0.3, 0.5, 1.0], [0.99], 80000),
        128: ([0.3, 0.5, 1.0], [0.99], 80000),
        256: ([0.5, 1.0], [0.98], 80000),
        512: ([0.5, 1.0, 2.0], [0.98], 40000 if layers < 20 else 80000),
        # 1024: ([1.0, 2.0], [0.98], 20000),
    }
    if layers >= 20:
        hparams[1024] = ([1.0, 2.0], [0.98], 40000)
    return hparams
target_mixing_shifts = [-1000.0, -2.0, 0.0, 2.0, 1000.0]
target_sizes = [(8, 512), (10, 640), (12, 768), (16, 1024), (20, 1536)]

rows = [
    {
        "batch_size": bs,
        "lr": lr,
        "mixing_shift": ms,
        "layers": l,
        "hidden_size": d,
        "beta2": b2,
        "target_steps": steps,
    }
    for l, d in target_sizes
    for bs, (lrs, b2s, steps) in hparams_by_bs(l).items() for lr in lrs for b2 in b2s
    for ms in target_mixing_shifts
]
needed_df = pd.DataFrame(rows)
needed_df["is_needed"] = True
len(needed_df)

510

In [39]:
import tqdm.auto as tqdm

wandb_rows = []
for run in tqdm.tqdm(runs):
    cfg = run.config
    state = run.state
    wandb_rows.append({
        "batch_size": cfg.get("batch_size"),
        "lr": cfg.get("lr"),
        "mixing_shift": cfg.get("hybrid_mixing_shift"),
        "layers": cfg.get("num_layers"),
        "hidden_size": cfg.get("hidden_size"),
        "beta2": cfg.get("beta2"),
        "status": state,
        "wandb_id": run.id,
        "ray_id": cfg.get("ray_job_id"),
        "step": run.summary.get("_step"),
        # "is_done": run.summary.get("_step") is not None and run.summary.get("_step") >= step_targets.get(cfg.get("batch_size"), 1e6),
        "name": run.name,
        "tpu_version": cfg.get("tpu_version"),
        "tpu_zone": cfg.get("tpu_zone"),
        "tpu_count": cfg.get("tpu_pod_count"),
        "command": cfg.get("command"),
    })

wandb_df = pd.DataFrame(wandb_rows).dropna(how="any")

  0%|          | 0/743 [00:00<?, ?it/s]

In [None]:
existing_df = wandb_df
if len(running_df) > 0:
    existing_df = existing_df.merge(running_df, how="outer", on=merge_keys + ["tpu_zone", "tpu_version"], suffixes=("", "_running"))
else:
    existing_df["tpu_count_running"] = None
    existing_df["is_running"] = False
existing_df.loc[existing_df["status"].isna(), "status"] = "pending"
# existing_df.loc[existing_df["is_done"].isna(), "is_done"] = False
existing_df.loc[existing_df["tpu_count"].isna(), "tpu_count"] = existing_df.loc[existing_df["tpu_count"].isna(), "tpu_count_running"].astype(int)
existing_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running
0,16.0,0.1,0,8.0,512.0,0.99,finished,yjkfkoa1,02000000,100000.0,gidd-L8-D512-H8-N2048-bs=16-lr=0.1-noise_balanced,v6e-8,europe-west4-a,1,main_ray.py --batch_size 16 --lr 0.1 --max_tra...,,False
1,16.0,0.5,0,8.0,512.0,0.99,finished,yvu8cfvq,05000000,100000.0,gidd-L8-D512-H8-N2048-bs=16-lr=0.5-noise_balanced,v6e-8,europe-west4-a,1,main_ray.py --batch_size 16 --lr 0.5 --max_tra...,,False
2,16.0,0.3,0,8.0,512.0,0.99,finished,vr4hlb9a,03000000,100000.0,gidd-L8-D512-H8-N2048-bs=16-lr=0.3-noise_balanced,v6e-8,europe-west4-a,1,main_ray.py --batch_size 16 --lr 0.3 --max_tra...,,False
3,16.0,0.2,0,8.0,512.0,0.99,finished,fcn0t3to,04000000,100000.0,gidd-L8-D512-H8-N2048-bs=16-lr=0.2-noise_balanced,v6e-8,europe-west4-a,1,main_ray.py --batch_size 16 --lr 0.2 --max_tra...,,False
4,16.0,0.1,-1000,8.0,512.0,0.99,finished,q2pv9rty,72000000,100000.0,gidd-L8-D512-H8-N2048-bs=16-lr=0.1-noise_mask,v6e-8,asia-northeast1-b,1,main_ray.py --batch_size 16 --lr 0.1 --max_tra...,,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
738,512.0,1.0,2,16.0,1024.0,0.98,finished,qe4h8zgv,3e000000,40000.0,gidd-L16-D1024-H16-N2048-bs=512-lr=1.0-noise_h...,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,,False
739,512.0,1.0,1000,16.0,1024.0,0.98,finished,uonm8g0t,3f000000,40000.0,gidd-L16-D1024-H16-N2048-bs=512-lr=1.0-noise_u...,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,,False
740,512.0,2.0,-1000,16.0,1024.0,0.98,finished,22vahdn7,40000000,40000.0,gidd-L16-D1024-H16-N2048-bs=512-lr=2.0-noise_mask,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,,False
741,512.0,2.0,-2,16.0,1024.0,0.98,finished,gasebctv,41000000,40000.0,gidd-L16-D1024-H16-N2048-bs=512-lr=2.0-noise_l...,v5p-32,us-east5-a,1,main_ray.py --batch_size 512 --micro_batch_siz...,,False


In [45]:
df.loc[df["target_steps"].isna()]

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,name,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False,False
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False,False
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,gidd-L16-D1024-H16-N2048-bs=8-lr=0.05-noise_mask,v6e-8,europe-west4-a,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False,False
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,gidd-L8-D512-H8-N2048-bs=8-lr=0.05-noise_low_u...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False,False
4,8.0,0.05,-2.0,12.0,768.0,0.99,finished,7nd2xgmu,4f000000,100000.0,gidd-L12-D768-H12-N2048-bs=8-lr=0.05-noise_low...,v6e-8,us-east5-b,1,main_ray.py --batch_size 8 --lr 0.05 --max_tra...,,False,,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529,256.0,0.30,1000.0,16.0,1024.0,0.98,crashed,ojuz7tsk,27000000,7500.0,gidd-L16-D1024-H16-N2048-bs=256-lr=0.3-noise_u...,v6e-16,europe-west4-a,4,main_ray.py --batch_size 256 --lr 0.3 --max_tr...,,False,,False,False
510,256.0,0.30,-1000.0,16.0,1024.0,0.98,crashed,bnzcdk4l,26000000,7150.0,gidd-L16-D1024-H16-N2048-bs=256-lr=0.3-noise_mask,v6e-16,europe-west4-a,4,main_ray.py --batch_size 256 --lr 0.3 --max_tr...,,False,,False,False
525,256.0,0.30,2.0,16.0,1024.0,0.98,crashed,e3n1n2d2,1a000000,6350.0,gidd-L16-D1024-H16-N2048-bs=256-lr=0.3-noise_h...,v6e-16,us-east5-b,4,main_ray.py --batch_size 256 --lr 0.3 --max_tr...,,False,,False,False
228,32.0,0.25,0.0,20.0,1536.0,0.99,crashed,yzf73zxt,3b000000,3950.0,gidd-L20-D1536-H12-bs=32-lr=0.25-noise_balanced,v6e-8,europe-west4-a,2,main_ray.py --batch_size 32 --lr 0.25 --max_tr...,,False,,False,False


In [50]:
df = existing_df.merge(needed_df, how="outer")
df.loc[:, "is_needed"] = df["is_needed"].fillna(False)

df.loc[df["target_steps"].isna(), "target_steps"] = 80000
df["is_done"] = False
df.loc[df["step"] >= df["target_steps"], "is_done"] = True

df.loc[df["is_running"].isna(), "is_running"] = False
df = df.sort_values(["is_done", "is_running", "step"], ascending=[False, False, False])
df = df.drop_duplicates(subset=merge_keys, keep="first", inplace=False)
df.loc[(df["status"] == "crashed") & df["is_needed"] & (df["step"].isna() | (df["step"] < 100)), "status"] = "missing"
df.loc[df["is_needed"] & df["status"].isna(), "status"] = "missing"
# df.loc[~df["wandb_id"].isna(), "status"] = df.loc[~df["wandb_id"].isna(), "status"].fillna("N/A")
df["resume_command"] = df.apply(format_resume_command, axis=1)
df["command"] = df.apply(format_command, axis=1)
df.loc[df["is_done"], "status"] = "finished"
df

  df.loc[:, "is_needed"] = df["is_needed"].fillna(False)


Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,...,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done,resume_command
0,8.0,0.05,-1000.0,8.0,512.0,0.99,finished,41x399b0,14000000,100000.0,...,v6e-8,europe-west4-a,1,python main_ray.py --batch_size 8 --lr 0.05 --...,,False,80000.0,False,True,"ray job submit --no-wait --address ""$RAY_EU_WE..."
1,8.0,0.05,-1000.0,12.0,768.0,0.99,finished,s0u7gktl,19000000,100000.0,...,v6e-8,europe-west4-a,1,python main_ray.py --batch_size 8 --lr 0.05 --...,,False,80000.0,False,True,"ray job submit --no-wait --address ""$RAY_EU_WE..."
2,8.0,0.05,-1000.0,16.0,1024.0,0.99,finished,aoycq16e,1e000000,100000.0,...,v6e-8,europe-west4-a,1,python main_ray.py --batch_size 8 --lr 0.05 --...,,False,80000.0,False,True,"ray job submit --no-wait --address ""$RAY_EU_WE..."
3,8.0,0.05,-2.0,8.0,512.0,0.99,finished,j9sznhsd,4a000000,100000.0,...,v6e-8,us-east5-b,1,python main_ray.py --batch_size 8 --lr 0.05 --...,,False,80000.0,False,True,"ray job submit --no-wait --address ""$RAY_US_EA..."
4,8.0,0.05,-2.0,12.0,768.0,0.99,finished,7nd2xgmu,4f000000,100000.0,...,v6e-8,us-east5-b,1,python main_ray.py --batch_size 8 --lr 0.05 --...,,False,80000.0,False,True,"ray job submit --no-wait --address ""$RAY_US_EA..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
695,1024.0,2.00,-1000.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
696,1024.0,2.00,-2.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
698,1024.0,2.00,0.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
699,1024.0,2.00,2.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,


In [51]:
df.loc[df["is_needed"] == True, "is_done"].value_counts()

is_done
True     336
False    174
Name: count, dtype: int64

In [52]:
notdone_df = df[(df["is_needed"]) & (~df["is_done"])]
notdone_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,...,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done,resume_command
333,64.0,0.3,1000.0,20.0,1536.0,0.99,crashed,cl3rs8d7,1b000000,74450.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
320,64.0,0.3,0.0,20.0,1536.0,0.99,crashed,o9g43r2z,1a000000,66900.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
534,256.0,0.5,-1000.0,16.0,1024.0,0.98,crashed,l57aw9ji,19000000,57450.0,...,v6e-16,us-east5-b,4,python main_ray.py --batch_size 256 --lr 0.5 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_US_EA..."
317,64.0,0.3,0.0,10.0,640.0,0.99,crashed,yeoig5q7,15000000,54600.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
582,256.0,1.0,2.0,16.0,1024.0,0.98,finished,kl046drk,2a000000,52400.0,...,v6e-16,europe-west4-a,4,python main_ray.py --batch_size 256 --lr 1.0 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_EU_WE..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
695,1024.0,2.0,-1000.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
696,1024.0,2.0,-2.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
698,1024.0,2.0,0.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,
699,1024.0,2.0,2.0,20.0,1536.0,0.98,missing,,,,...,,,,python main_ray.py --batch_size 1024 --lr 2.0 ...,,False,40000.0,True,False,


In [53]:
interrupted_df = df.loc[(df["is_done"] == False) & (df["is_running"] != True) & (df["is_needed"]) & (df["status"] != "missing")]
interrupted_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,...,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done,resume_command
333,64.0,0.3,1000.0,20.0,1536.0,0.99,crashed,cl3rs8d7,1b000000,74450.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
320,64.0,0.3,0.0,20.0,1536.0,0.99,crashed,o9g43r2z,1a000000,66900.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
534,256.0,0.5,-1000.0,16.0,1024.0,0.98,crashed,l57aw9ji,19000000,57450.0,...,v6e-16,us-east5-b,4,python main_ray.py --batch_size 256 --lr 0.5 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_US_EA..."
317,64.0,0.3,0.0,10.0,640.0,0.99,crashed,yeoig5q7,15000000,54600.0,...,v6e-8,asia-northeast1-b,2,python main_ray.py --batch_size 64 --lr 0.3 --...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_ASIA_..."
582,256.0,1.0,2.0,16.0,1024.0,0.98,finished,kl046drk,2a000000,52400.0,...,v6e-16,europe-west4-a,4,python main_ray.py --batch_size 256 --lr 1.0 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_EU_WE..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
653,512.0,1.0,1000.0,12.0,768.0,0.98,crashed,yytqoild,0e000000,10800.0,...,v6e-64,europe-west4-a,1,python main_ray.py --batch_size 512 --lr 1.0 -...,,False,40000.0,True,False,"ray job submit --no-wait --address ""$RAY_EU_WE..."
546,256.0,0.5,0.0,16.0,1024.0,0.98,crashed,yx9ds0os,17000000,10400.0,...,v6e-64,europe-west4-a,1,python main_ray.py --batch_size 256 --lr 0.5 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_EU_WE..."
625,512.0,1.0,-1000.0,12.0,768.0,0.98,crashed,ws67nyvz,12000000,9550.0,...,v6e-64,europe-west4-a,1,python main_ray.py --batch_size 512 --lr 1.0 -...,,False,40000.0,True,False,"ray job submit --no-wait --address ""$RAY_EU_WE..."
569,256.0,1.0,-2.0,16.0,1024.0,0.98,crashed,76xixgci,0a000000,5300.0,...,v6e-64,us-east5-b,1,python main_ray.py --batch_size 256 --lr 1.0 -...,,False,80000.0,True,False,"ray job submit --no-wait --address ""$RAY_US_EA..."


In [54]:
inprogress_df = df.loc[(df["is_done"] == False) & (df["is_running"] == True) & (df["is_needed"])]
inprogress_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,...,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done,resume_command


In [55]:
missing_df = df.loc[(df["status"] == "missing") & (df["is_needed"])]
missing_df

Unnamed: 0,batch_size,lr,mixing_shift,layers,hidden_size,beta2,status,wandb_id,ray_id,step,...,tpu_version,tpu_zone,tpu_count,command,tpu_count_running,is_running,target_steps,is_needed,is_done,resume_command
16,8.0,0.1,-1000.0,10.0,640.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
19,8.0,0.1,-1000.0,20.0,1536.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
21,8.0,0.1,-2.0,10.0,640.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
24,8.0,0.1,-2.0,20.0,1536.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
26,8.0,0.1,0.0,10.0,640.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
29,8.0,0.1,0.0,20.0,1536.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
31,8.0,0.1,2.0,10.0,640.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
34,8.0,0.1,2.0,20.0,1536.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
36,8.0,0.1,1000.0,10.0,640.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,
39,8.0,0.1,1000.0,20.0,1536.0,0.99,missing,,,,...,,,,python main_ray.py --batch_size 8 --lr 0.1 --m...,,False,80000.0,True,False,


In [56]:
print(interrupted_df.groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
8.0         10.0    640.0           5
            20.0    1536.0          5
16.0        10.0    640.0          10
            20.0    1536.0         10
32.0        10.0    640.0          15
            20.0    1536.0         15
64.0        10.0    640.0          12
            16.0    1024.0          1
            20.0    1536.0         14
128.0       10.0    640.0          10
            16.0    1024.0          1
            20.0    1536.0          1
256.0       10.0    640.0           3
            16.0    1024.0          6
512.0       10.0    640.0          12
            12.0    768.0           8
            20.0    1536.0         15
dtype: int64


In [57]:
print(inprogress_df.groupby(["batch_size", "layers", "hidden_size"]).size())

Series([], dtype: int64)


In [58]:
print(missing_df.groupby(["batch_size", "layers", "hidden_size"]).size())

batch_size  layers  hidden_size
8.0         10.0    640.0           5
            20.0    1536.0          5
128.0       10.0    640.0           5
            20.0    1536.0          5
512.0       10.0    640.0           1
1024.0      20.0    1536.0         10
dtype: int64


In [59]:
print(interrupted_df.loc[(interrupted_df["tpu_zone"] != "us-east5-a")].groupby(["tpu_zone", "batch_size"]).size())

tpu_zone           batch_size
asia-northeast1-b  8.0            4
                   16.0           7
                   32.0          11
                   64.0          12
                   128.0          2
europe-west4-a     8.0            6
                   16.0          13
                   32.0           8
                   64.0           4
                   128.0          3
                   256.0          4
                   512.0         10
us-central1-a      128.0          1
                   256.0          1
us-east1-d         32.0           4
                   64.0           3
                   128.0          2
us-east5-b         32.0           7
                   64.0           8
                   128.0          4
                   256.0          4
                   512.0         10
dtype: int64


In [60]:
print("\n".join(interrupted_df.loc[interrupted_df["tpu_zone"] == "us-east5-a"].sort_values(["batch_size", "layers"])["resume_command"].values))

ray job submit --no-wait --address "$RAY_US_EAST5_A" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v5p-64 TPU_POD_COUNT=1 python main_ray.py --batch_size 512 --micro_batch_size 256 --lr 1.0 --max_training_steps 80000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.98 --no-break_on_nan --wandb_name gidd-L20-D1536-H24-N2048-bs=512-lr=1.0-noise_low_uniform --hybrid_mixing_shift -2.0 --wandb_tags crit_bs_l20_d1536,noise_low_uniform --resume_wandb_id 13o58qup
ray job submit --no-wait --address "$RAY_US_EAST5_A" --working-dir . --runtime-env-json='{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}' -- TPU_VERSION=v5p-64 TPU_POD_COUNT=1 python main_ray.py --batch_size 512 --micro_batch_size 256 --lr 0.5 --max_training_steps 80000 --num_layers 20 --hidden_size 1536 --num_attn_heads 12 --max_seq_len 2048 --coold

In [None]:
# print("\n".join(df.loc[(df["status"] == "missing") & (df["batch_size"] == 256)]["command"].values))
print("\n".join(interrupted_df.loc[interrupted_df["batch_size"] > 1].sort_values(["batch_size", "layers"])["command"].values))

python main_ray.py --batch_size 8 --lr 0.3 --max_training_steps 80000.0 --num_layers 10 --hidden_size 640 --num_attn_heads 10 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L10-D640-H10-N2048-bs=8-lr=0.3-noise_mask --hybrid_mixing_shift -1000.0 --wandb_tags crit_bs_l10_d640,noise_mask
python main_ray.py --batch_size 8 --lr 0.3 --max_training_steps 80000.0 --num_layers 10 --hidden_size 640 --num_attn_heads 10 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L10-D640-H10-N2048-bs=8-lr=0.3-noise_low_uniform --hybrid_mixing_shift -2.0 --wandb_tags crit_bs_l10_d640,noise_low_uniform
python main_ray.py --batch_size 8 --lr 0.3 --max_training_steps 80000.0 --num_layers 10 --hidden_size 640 --num_attn_heads 10 --max_seq_len 2048 --cooldown_steps 0.0 --beta2 0.99 --no-break_on_nan --wandb_name gidd-L10-D640-H10-N2048-bs=8-lr=0.3-noise_balanced --hybrid_mixing_shift 0.0 --wandb_tags crit_bs_l10_d640,noise_balanced
p

In [None]:
# scaffold = """ray job submit --no-wait --address "{address}" --working-dir . --runtime-env-json='{{"py_modules":["../EasyDeL/easydel","../eformer/eformer","../orbax/checkpoint/orbax"]}}' -- TPU_VERSION={tpu_version} TPU_POD_COUNT={tpu_count} python"""

# zone_to_address = {
#     "asia-northeast1-b": "$RAY_ASIA_NORTHEAST1_B",
#     "europe-west4-a": "$RAY_EU_WEST4_A",
#     "us-central1-a": "$RAY_US_CENTRAL1_A",
#     "us-east1-d": "$RAY_US_EAST1_D",
#     "us-east5-a": "$RAY_US_EAST5_A",
#     "us-east5-b": "$RAY_US_EAST5_B",
# }

# def format_resume_command(row):
#     base_cmd = row["command"]
#     wandb_id = row["wandb_id"]
#     tpu_zone = row["tpu_zone"]
#     tpu_version = row["tpu_version"]
#     tpu_count = row["tpu_count"]
#     if "--resume_wandb_id" not in base_cmd and not pd.isna(wandb_id):
#         base_cmd = f"{base_cmd} --resume_wandb_id {wandb_id}"
#     address = zone_to_address.get(tpu_zone, "$RAY_US_EAST1_D")
#     return scaffold.format(address=address, tpu_version=tpu_version, tpu_count=tpu_count) + " " + base_cmd.strip()

# commands = df.loc[(df["status"] == "crashed") & (df["is_running"] != True) & (df["is_needed"])].apply(format_resume_command, axis=1)

# with open("resume_commands.txt", "w") as f:
#     f.write("\n".join(commands.values))
# print("\n".join(commands.values))