start_cluster.py (forked from weecology/EvergladesTools)
"""
Create a cluster of GPU nodes to perform parallel prediction of tiles
"""
import argparse
import sys
import socket
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, wait
import gc


def collect():
    """Run garbage collection, e.g. on every worker via client.run(collect)"""
    gc.collect()


def args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='Start a dask cluster for parallel prediction of tiles.')
    parser.add_argument('--debug',
                        help='Run local version without GPU',
                        action='store_true')
    parser.add_argument('--workers', help='Number of dask workers', type=int, default=4)
    parser.add_argument('--memory_worker', help='GB memory per worker', type=int, default=10)

    return parser.parse_args()


def find_tiles():
    """Read a yaml describing which sites to run"""
    pass
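

# A sketch of how find_tiles might be implemented, assuming the yaml is a simple
# mapping with a top-level "sites" list of tile paths. The schema and filename
# are hypothetical, for illustration only; the original leaves the function
# unimplemented.
def _find_tiles_sketch(path="sites.yml"):
    import yaml  # requires PyYAML

    with open(path) as f:
        config = yaml.safe_load(f)

    return config["sites"]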


def start_tunnel():
    """
    Print the ssh tunnel commands needed to view task progress in the dask dashboards
    """
    host = socket.gethostname()
    print("To tunnel into dask dashboard:")
    print("For GPU dashboard: ssh -N -L 8787:%s:8787 -l b.weinstein hpg2.rc.ufl.edu" % (host))
    print("For CPU dashboard: ssh -N -L 8781:%s:8781 -l b.weinstein hpg2.rc.ufl.edu" % (host))

    # Flush stdout so the instructions appear in the scheduler log immediately
    sys.stdout.flush()


def start(cpus=0, gpus=0, mem_size="10GB"):
    #################
    # Setup dask cluster
    #################
    if cpus > 0:
        # SLURM job args (CPU queue); job_extra/extra follow the older
        # dask_jobqueue argument names
        extra_args = [
            "--error=/orange/idtrees-collab/logs/dask-worker-%j.err", "--account=ewhite",
            "--output=/orange/idtrees-collab/logs/dask-worker-%j.out"
        ]
        cluster = SLURMCluster(processes=1,
                               queue='hpg2-compute',
                               cores=1,
                               memory=mem_size,
                               walltime='24:00:00',
                               job_extra=extra_args,
                               extra=['--resources cpu=1'],
                               scheduler_options={"dashboard_address": ":8781"},
                               local_directory="/orange/idtrees-collab/tmp/",
                               death_timeout=300)
        print(cluster.job_script())
        cluster.scale(cpus)

    if gpus > 0:
        # SLURM job args (GPU partition, one GPU per worker)
        extra_args = [
            "--error=/orange/idtrees-collab/logs/dask-worker-%j.err", "--account=ewhite",
            "--output=/orange/idtrees-collab/logs/dask-worker-%j.out", "--partition=gpu",
            "--gpus=1"
        ]
        cluster = SLURMCluster(processes=1,
                               cores=1,
                               memory=mem_size,
                               walltime='24:00:00',
                               job_extra=extra_args,
                               extra=['--resources gpu=1'],
                               scheduler_options={"dashboard_address": ":8787"},
                               local_directory="/orange/idtrees-collab/tmp/",
                               death_timeout=300)
        cluster.scale(gpus)

    # Connect a client to the cluster and print tunnel instructions from the
    # scheduler node
    dask_client = Client(cluster)
    dask_client.run_on_scheduler(start_tunnel)

    return dask_client
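

# A minimal usage sketch (hypothetical; the original module has no entrypoint):
# parse the CLI flags defined above and bring up a cluster accordingly. In debug
# mode, fall back to an in-process LocalCluster so no SLURM scheduler is needed.
if __name__ == "__main__":
    parsed = args()
    if parsed.debug:
        from dask.distributed import LocalCluster
        client = Client(LocalCluster(n_workers=parsed.workers))
    else:
        client = start(gpus=parsed.workers, mem_size="%sGB" % parsed.memory_worker)

    # Smoke test: submit one trivial task and block until it completes
    future = client.submit(sum, [1, 2, 3])
    wait([future])
    print("Cluster is up, test task returned %s" % future.result())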