Commit 961e06c

add script to launch jobs on toolkit
1 parent 56ec144 commit 961e06c

File tree

1 file changed: +104 -0 lines changed


scripts/launch_eval.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
import os
from datetime import datetime
from pathlib import Path
from subprocess import run
import sys


JOB_COUNT = 0
MODEL_BATCH_SIZE = {
    "bigcode/large-model": 20,
    "huggyllama/llama-7b": 20,
    "huggyllama/llama-13b": 10,
    "huggyllama/llama-30b": 2,
    "Salesforce/codegen-2B-mono": 10,
    "Salesforce/codegen-16B-multi": 8,
    "Salesforce/codegen-16B-mono": 8
}


def get_gen_args(task, model):
    if task in ["humaneval", "humaneval-unstripped"]:
        batch_size = MODEL_BATCH_SIZE.get(model, 50)
        gen_args = f"--max_length_generation 1024 --n_samples 100 --temperature 0.2 --top_p 0.95 --batch_size {batch_size}"
        # batch_size = 1
        # gen_args = f"--max_length_generation 1024 --n_samples 1 --do_sample False --batch_size {batch_size}"
    elif "perturbed-humaneval" in task:
        batch_size = 1
        gen_args = f"--max_length_generation 1024 --n_samples 1 --do_sample False --batch_size {batch_size}"
    else:
        raise ValueError(f"{task} and {model}")
    return gen_args


def main(model_name, model_revision, task):
    global JOB_COUNT
    now = datetime.now()
    dt_string = now.strftime("%Y_%m_%d_%H_%M_%S")
    num_gpu = 4

    model_id = model_name.split("/")[-1].lower()  # for job-name
    model_revision_arg = f"--revision {model_revision}" if model_revision else ""

    gen_args = get_gen_args(task, model_name)

    # launch_command = "python main.py"
    multi_gpu = f"--multi_gpu --num_processes={num_gpu}" if num_gpu > 1 else ""
    launch_command = f"accelerate env && accelerate launch {multi_gpu} main.py"
    output_path = Path(f"/home/toolkit_tmp/evaluation/bigcode/{model_name}/{model_revision}/{task}-100_samples/evaluation_results.json")  # ADJUST

    if "greedy" in str(output_path) or "--do_sample False" in gen_args:
        assert "greedy" in str(output_path) and "--do_sample False" in gen_args

    generations_path = output_path.with_name("generations.json")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # TF-flags for DS-1000
    job_command = f"""cd /app/bigcode-evaluation-harness && pwd && \
        TF_FORCE_GPU_ALLOW_GROWTH=true \
        TF_CPP_MIN_LOG_LEVEL=3 \
        {launch_command} \
        --precision bf16 \
        --model {model_name} {model_revision_arg} \
        --trust_remote_code \
        --use_auth_token \
        --tasks {task} \
        {gen_args} \
        --seed 0 \
        --allow_code_execution \
        --metric_output_path {output_path} \
        --save_generations \
        --save_generations_path {generations_path} \
        """
    toolkit_command = [
        "eai", "job", "submit",
        # "--image", "volatile-registry.console.elementai.com/snow.raymond/bigcode-evaluation-harness:latest-3months",
        # "--image", "volatile-registry.console.elementai.com/snow.raymond/bigcode-evaluation-harness:custom_transformers",
        "--image", "volatile-registry.console.elementai.com/snow.raymond/bigcode-evaluation-harness:raymond_patch-3months",
        "--restartable",
        "--name", f"{task.replace('-', '_')}__{model_id.replace('-', '_').replace('.', '_')}_{JOB_COUNT}__{dt_string}",
        "--data", "snow.raymond.home_tmp:/home/toolkit_tmp",  # ADJUST
        "--data", "snow.code_llm.transformers_cache:/transformers_cache",
        "--env", "HOME=/home/toolkit_tmp",
        "--env", "HF_HOME=/transformers_cache",
        "--cpu", "16",
        "--mem", str(150),
        "--gpu", str(num_gpu),
        "--gpu-mem", "32",
        "--", "bash", "-c",
        job_command
    ]
    JOB_COUNT += 1

    run(toolkit_command)


if __name__ == "__main__":
    model_name = "bigcode/sc2-1b-ablations"

    # Branch-name or commit-id
    # model_revision = ""
    model_revision = "repo_context_Random_8k_8k_vocab-114688_freq_1e6"

    task = "humaneval"

    # for model_revision in []:
    main(model_name, model_revision, task)
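
The script is driven entirely by the hardcoded values in the __main__ block, so launching jobs means editing model_name, model_revision, and task (and the paths marked ADJUST), then running the file directly; the commented-out loop hints at sweeping several revisions. A minimal usage sketch, where the revision names are hypothetical placeholders rather than real branches:

# Run a single job after adjusting the hardcoded values:
#   python scripts/launch_eval.py
#
# Sketch of a revision sweep using the commented-out loop in __main__:
for model_revision in ["rev_a", "rev_b"]:  # hypothetical revision names
    main(model_name, model_revision, task)  # submits one toolkit job per revision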
