-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_all.py
43 lines (37 loc) · 1.42 KB
/
train_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# This script runs train.py for all yaml files in the CaptainCook4D folder.
import os
import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import click
# Default location of the CaptainCook4D training configs (relative to the CWD).
_DEFAULT_CONFIG_DIR = "./configs/CaptainCook4D"
# File extensions treated as YAML configs.
_YAML_SUFFIXES = (".yaml", ".yml")
# Seeds used when --more_seeds is given, to compute error bars across runs.
_ERROR_BAR_SEEDS = (42, 1337, 2024, 2025, 2026)


def _iter_config_files(config_dir):
    """Return the YAML config files directly inside *config_dir*, sorted by name.

    ``os.listdir`` yields entries in arbitrary order and includes non-YAML
    entries (sub-directories, READMEs, editor/OS files); filtering and sorting
    makes the set and order of launched runs predictable.
    """
    return sorted(
        entry
        for entry in Path(config_dir).iterdir()
        if entry.is_file() and entry.suffix.lower() in _YAML_SUFFIXES
    )


def _build_commands(config_path, more_seeds):
    """Return the ``train.py`` argument lists to run for one config file."""
    # Use the current interpreter so runs share this script's environment,
    # and pass an argument list (no shell) so unusual file names are safe.
    base = [sys.executable, "train.py", "--config", str(config_path)]
    if not more_seeds:
        return [base + ["--log"]]
    # One run per seed; results across seeds are used for error bars.
    return [base + ["--seed", str(seed), "--log"] for seed in _ERROR_BAR_SEEDS]


def _run_training(config_path, more_seeds):
    """Run all training commands for one config sequentially.

    Returns a list of ``(command, returncode)`` tuples for runs that exited
    with a non-zero status (empty when every run succeeded).
    """
    print(f"Running training for {config_path.name}")
    failures = []
    for command in _build_commands(config_path, more_seeds):
        result = subprocess.run(command, check=False)
        if result.returncode != 0:
            failures.append((command, result.returncode))
    print(f"Finished training for {config_path.name}")
    return failures


@click.command()
@click.option("--more_seeds", is_flag=True, help="Use multiple seeds for error bars.")
@click.option(
    "--max_workers",
    default=8,
    show_default=True,
    type=click.IntRange(min=1),
    help="Maximum number of configs trained concurrently.",
)
@click.option(
    "--config_dir",
    default=_DEFAULT_CONFIG_DIR,
    show_default=True,
    type=click.Path(exists=True, file_okay=False),
    help="Directory containing the YAML training configs.",
)
def main(more_seeds, max_workers, config_dir):
    """Run train.py for every YAML config in CONFIG_DIR, several configs at a time.

    Each config is trained in its own worker thread. Within one config, the
    per-seed runs (with --more_seeds) execute sequentially. The command exits
    with status 1 if any training run fails.
    """
    print("more_seeds:", more_seeds)
    print("Running train_all.py")

    config_files = _iter_config_files(config_dir)
    if not config_files:
        print(f"No YAML configs found in {config_dir}")
        return

    # A real pool keeps up to max_workers configs running at all times,
    # instead of waiting for a whole batch to finish before starting the next.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        per_config_failures = list(
            pool.map(lambda path: _run_training(path, more_seeds), config_files)
        )

    failures = [failure for batch in per_config_failures for failure in batch]
    if failures:
        for command, returncode in failures:
            print(f"FAILED (exit code {returncode}): {' '.join(command)}")
        sys.exit(1)
# Script entry point: click parses the command line and dispatches to main().
if "__main__" == __name__:
    main()