Skip to content

Commit

Permalink
Add a job_extra flag to the SGECluster (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lnaden authored and lesteve committed Apr 5, 2019
1 parent caeab64 commit 9781a3d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
1 change: 1 addition & 0 deletions dask_jobqueue/jobqueue.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobqueue:
walltime: '00:30:00'
extra: []
env-extra: []
job-extra: []
log-directory: null

resource-spec: null
Expand Down
7 changes: 7 additions & 0 deletions dask_jobqueue/sge.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class SGECluster(JobQueueCluster):
Request resources and specify job placement. Passed to `#$ -l` option.
walltime : str
Walltime for each worker job.
job_extra : list
List of other SGE options, for example -w e. Each option will be
prepended with the #$ prefix.
%(JobQueueCluster.parameters)s
Examples
Expand Down Expand Up @@ -58,6 +61,7 @@ def __init__(
project=None,
resource_spec=None,
walltime=None,
job_extra=None,
config_name="sge",
**kwargs
):
Expand All @@ -69,6 +73,8 @@ def __init__(
resource_spec = dask.config.get("jobqueue.%s.resource-spec" % config_name)
if walltime is None:
walltime = dask.config.get("jobqueue.%s.walltime" % config_name)
if job_extra is None:
job_extra = dask.config.get("jobqueue.%s.job-extra" % config_name)

super(SGECluster, self).__init__(config_name=config_name, **kwargs)

Expand All @@ -87,6 +93,7 @@ def __init__(
header_lines.append("#$ -e %(log_directory)s/")
header_lines.append("#$ -o %(log_directory)s/")
header_lines.extend(["#$ -cwd", "#$ -j y"])
header_lines.extend(["#$ %s" % arg for arg in job_extra])
header_template = "\n".join(header_lines)

config = {
Expand Down
39 changes: 36 additions & 3 deletions dask_jobqueue/tests/test_sge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

import pytest
from distributed import Client
from distributed.utils_test import loop # noqa: F401

from dask_jobqueue import SGECluster
import dask

from . import QUEUE_WAIT


@pytest.mark.env("sge") # noqa: F811
def test_basic(loop): # noqa: F811
@pytest.mark.env("sge")
def test_basic(loop):
with SGECluster(
walltime="00:02:00", cores=8, processes=4, memory="2GB", loop=loop
) as cluster:
Expand Down Expand Up @@ -70,3 +69,37 @@ def test_config_name_sge_takes_custom_config():
with dask.config.set({"jobqueue.sge-config-name": conf}):
with SGECluster(config_name="sge-config-name") as cluster:
assert cluster.name == "myname"


def test_job_script(tmpdir):
log_directory = tmpdir.strpath
with SGECluster(
cores=6,
processes=2,
memory="12GB",
queue="my-queue",
project="my-project",
walltime="02:00:00",
env_extra=["export MY_VAR=my_var"],
job_extra=["-w e", "-m e"],
log_directory=log_directory,
resource_spec="h_vmem=12G,mem_req=12G",
) as cluster:
job_script = cluster.job_script()
for each in [
"--nprocs 2",
"--nthreads 3",
"--memory-limit 6.00GB",
"-q my-queue",
"-P my-project",
"-l h_rt=02:00:00",
"export MY_VAR=my_var",
"#$ -w e",
"#$ -m e",
"#$ -e {}".format(log_directory),
"#$ -o {}".format(log_directory),
"-l h_vmem=12G,mem_req=12G",
"#$ -cwd",
"#$ -j y",
]:
assert each in job_script

0 comments on commit 9781a3d

Please sign in to comment.