9 changes: 6 additions & 3 deletions torchx/cli/cmd_run.py
@@ -15,7 +15,7 @@
import torchx.specs as specs
from pyre_extensions import none_throws
from torchx.cli.cmd_base import SubCommand
from torchx.runner import Runner, get_runner
from torchx.runner import Runner, config, get_runner
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
from torchx.specs.finder import (
ComponentNotFoundException,
@@ -53,6 +53,7 @@ def _parse_run_config(arg: str, scheduler_opts: specs.runopts) -> specs.RunConfi
option_type = option.opt_type
typed_value = _convert_to_option_type(value, option_type)
conf.set(key, typed_value)

return conf


@@ -114,7 +115,9 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
def _run(self, runner: Runner, args: argparse.Namespace) -> Optional[str]:
run_opts = get_runner().run_opts()
scheduler_opts = run_opts[args.scheduler]
scheduler_args = _parse_run_config(args.scheduler_args, scheduler_opts)
cfg = _parse_run_config(args.scheduler_args, scheduler_opts)
config.apply(scheduler=args.scheduler, cfg=cfg)

if len(args.conf_args) < 1:
none_throws(self._subparser).error(
"the following arguments are required: conf_file, conf_args"
@@ -129,7 +132,7 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> Optional[str]:
conf_file,
conf_args,
args.scheduler,
scheduler_args,
cfg,
dryrun=args.dryrun,
)
except (ComponentValidationException, ComponentNotFoundException) as e:
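With this change ``cmd_run`` builds the ``RunConfig`` from ``args.scheduler_args`` first and only then overlays ``.torchxconfig`` via ``config.apply``, so values given on the command line always win. A minimal sketch of that precedence (the ``[kubernetes]`` section and its keys are assumptions for illustration, not part of this PR):

    from torchx import specs
    from torchx.runner import config

    # values parsed from the CLI scheduler args land in the RunConfig first
    cfg = specs.RunConfig()
    cfg.set("namespace", "prod")

    # apply() then only fills in keys that are still missing, e.g. from a
    # $CWD/.torchxconfig with a [kubernetes] section defining namespace/queue
    config.apply(scheduler="kubernetes", cfg=cfg)

    assert cfg.get("namespace") == "prod"  # CLI value is not overridden
    queue = cfg.get("queue")               # taken from the file, if present
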
4 changes: 0 additions & 4 deletions torchx/runner/api.py
@@ -13,7 +13,6 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

from pyre_extensions import none_throws
from torchx.runner import config
from torchx.runner.events import log_event
from torchx.schedulers import get_schedulers
from torchx.schedulers.api import Scheduler
@@ -262,9 +261,6 @@ def dryrun(
)

cfg = cfg or RunConfig()
# TODO enable profiles - https://github.com/pytorch/torchx/issues/248
config.apply(scheduler=scheduler, cfg=cfg, profile="default")

sched = self._scheduler(scheduler)
sched._validate(app, scheduler)
dryrun_info = sched.submit_dryrun(app, cfg)
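``Runner.dryrun`` no longer loads ``.torchxconfig`` itself, so a programmatic caller that still wants file-based defaults would now apply them before handing the ``RunConfig`` to the runner — a sketch under that assumption, mirroring the CLI flow above:

    from torchx import specs
    from torchx.runner import config, get_runner

    cfg = specs.RunConfig()
    cfg.set("log_dir", "/tmp/my_logs")            # explicit value, never overridden
    config.apply(scheduler="local_cwd", cfg=cfg)  # fill gaps from $CWD/.torchxconfig

    runner = get_runner()
    # ...then pass cfg to runner.dryrun(...) / runner.run(...) as before
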
56 changes: 37 additions & 19 deletions torchx/runner/config.py
@@ -72,7 +72,7 @@ def dump(

::

[default.kubernetes.cfg]
[kubernetes]
namespace = default
queue = #FIXME (str)Volcano queue to schedule job in

@@ -89,7 +89,7 @@ def dump(
for sched_name in scheds:
sched = _get_scheduler(sched_name)

section = f"default.{sched_name}.cfg"
section = f"{sched_name}"
config.add_section(section)

for opt_name, opt in sched.run_opts():
@@ -114,33 +114,51 @@
config.write(f, space_around_delimiters=True)


def apply(scheduler: str, cfg: RunConfig, profile: str = "default") -> None:
def apply(scheduler: str, cfg: RunConfig, dirs: Optional[List[str]] = None) -> None:
"""
Loads .torchxconfig files from predefined locations according
to a load hierarchy and applies the loaded configs into the
given ``runcfg``. The load hierarchy is as follows (in order of precedence):
Loads the ``.torchxconfig`` INI file from each of the specified directories,
in the order given, and applies the run configs for the scheduler onto
the given ``cfg``.

#. ``runcfg`` given to this function
#. configs loaded from ``$HOME/.torchxconfig``
#. configs loaded from ``$CWD/.torchxconfig``
If ``dirs`` is not specified, then it looks for ``.torchxconfig`` in the
current working directory. A specified directory that does not contain
``.torchxconfig`` is ignored.

Note that load hierarchy does NOT overwrite, but rather adds.
That is, the configs already present in ``runcfg`` are not
overridden during the load.
Note that the configs already present in the given ``cfg`` take precedence
over the ones in the config file; only new configs are added. Likewise,
configs loaded from earlier directories in the list take precedence over later ones.

For instance, if ``cfg = {"foo": "bar"}`` and the config files are:

::

# dir_1/.torchxconfig
[local_cwd]
foo = baz
hello = world

# dir_2/.torchxconfig
[local_cwd]
hello = bob


Then after the method call, ``cfg = {"foo": "bar", "hello": "world"}``.
"""
lookup_dirs = [Path.home(), Path.cwd()]

for d in lookup_dirs:
configfile = d / ".torchxconfig"
if not dirs:
dirs = [str(Path.cwd())]

for d in dirs:
configfile = Path(d) / ".torchxconfig"
if configfile.exists():
log.info(f"loading configs from {configfile}")
with open(str(configfile), "r") as f:
load(scheduler, f, cfg, profile)
load(scheduler, f, cfg)


def load(scheduler: str, f: TextIO, cfg: RunConfig, profile: str = "default") -> None:
def load(scheduler: str, f: TextIO, cfg: RunConfig) -> None:
"""
loads the section ``[{profile}.scheduler_args.{scheduler}]`` from the given
loads the section ``[{scheduler}]`` from the given
configfile ``f`` (in .INI format) into the provided ``runcfg``, only adding
configs that are NOT currently in the given ``runcfg`` (e.g. does not
override existing values in ``runcfg``). If no section is found, does nothing.
@@ -151,7 +169,7 @@ def load(scheduler: str, f: TextIO, cfg: RunConfig, profile: str = "default") ->

runopts = _get_scheduler(scheduler).run_opts()

section = f"{profile}.{scheduler}.cfg"
section = f"{scheduler}"
if config.has_section(section):
for name, value in config.items(section):
if name in cfg.cfgs:
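The new ``dirs`` parameter makes the precedence rules above easy to exercise directly. A rough example, assuming ``dir_1`` and ``dir_2`` each contain a ``.torchxconfig`` with a ``[local_cwd]`` section (the file contents in the comments are hypothetical):

    from torchx import specs
    from torchx.runner.config import apply

    cfg = specs.RunConfig()
    cfg.set("log_dir", "/foo/bar")  # pre-existing key, never overridden

    # assumed files:
    #   dir_1/.torchxconfig -> [local_cwd] log_dir = /home/bob/logs, prepend_cwd = True
    #   dir_2/.torchxconfig -> [local_cwd] prepend_cwd = False
    apply(scheduler="local_cwd", cfg=cfg, dirs=["dir_1", "dir_2"])

    assert cfg.get("log_dir") == "/foo/bar"  # value already in cfg is kept
    assert cfg.get("prepend_cwd") is True    # earlier directory in the list wins
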
76 changes: 35 additions & 41 deletions torchx/runner/test/config_test.py
@@ -102,35 +102,27 @@ def run_opts(self) -> runopts:
return opts


_CONFIG = """[default.local_cwd.cfg]
_CONFIG = """[local_cwd]
log_dir = /home/bob/logs
prepend_cwd = True

[test.local_cwd.cfg]
log_dir = None
prepend_cwd = False

[alpha.local_cwd.cfg]
log_dir = /tmp/logs
"""

_CONFIG_INVALID = """[default.test.cfg]
_CONFIG_INVALID = """[test]
a_run_opt_that = does_not_exist
s = option_that_exists
"""

_TEAM_CONFIG = """[default.test.cfg]
_TEAM_CONFIG = """[test]
s = team_default
i = 50
f = 1.2
"""

_MY_CONFIG = """[default.test.cfg]
_MY_CONFIG = """[test]
s = my_default
i = 100
"""

PATH_HOME = "torchx.runner.config.Path.home"
PATH_CWD = "torchx.runner.config.Path.cwd"
TORCHX_GET_SCHEDULERS = "torchx.runner.config.get_schedulers"

@@ -159,45 +151,50 @@ def _write(self, filename: str, content: str) -> Path:

def test_load(self) -> None:
cfg = RunConfig()
load(profile="default", scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
load(scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
self.assertEqual("/home/bob/logs", cfg.get("log_dir"))
self.assertEqual(True, cfg.get("prepend_cwd"))

cfg = RunConfig()
load(profile="test", scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
self.assertEqual(None, cfg.get("log_dir"))
self.assertEqual(False, cfg.get("prepend_cwd"))

cfg = RunConfig()
load(profile="alpha", scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
self.assertEqual("/tmp/logs", cfg.get("log_dir"))
self.assertEqual(None, cfg.get("prepend_cwd"))

def test_no_override_load(self) -> None:
cfg = RunConfig()
cfg.set("log_dir", "/foo/bar")
cfg.set("debug", 1)

load(profile="test", scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
load(scheduler="local_cwd", f=StringIO(_CONFIG), cfg=cfg)
self.assertEqual("/foo/bar", cfg.get("log_dir"))
self.assertEqual(1, cfg.get("debug"))
self.assertEqual(False, cfg.get("prepend_cwd"))
self.assertEqual(True, cfg.get("prepend_cwd"))

@patch(
TORCHX_GET_SCHEDULERS,
return_value={"test": TestScheduler()},
)
def test_apply(self, _) -> None:
def test_apply_default(self, _) -> None:
with patch(PATH_CWD, return_value=Path(self.test_dir)):
with patch(PATH_HOME, return_value=Path(self.test_dir) / "home" / "bob"):
cfg = RunConfig()
cfg.set("s", "runtime_value")
cfg = RunConfig()
cfg.set("s", "runtime_value")

apply(scheduler="test", cfg=cfg)

apply(profile="default", scheduler="test", cfg=cfg)
self.assertEqual("runtime_value", cfg.get("s"))
self.assertEqual(50, cfg.get("i"))
self.assertEqual(1.2, cfg.get("f"))

self.assertEqual("runtime_value", cfg.get("s"))
self.assertEqual(100, cfg.get("i"))
self.assertEqual(1.2, cfg.get("f"))
@patch(
TORCHX_GET_SCHEDULERS,
return_value={"test": TestScheduler()},
)
def test_apply_dirs(self, _) -> None:
cfg = RunConfig()
cfg.set("s", "runtime_value")
apply(
scheduler="test",
cfg=cfg,
dirs=[str(Path(self.test_dir) / "home" / "bob"), self.test_dir],
)
self.assertEqual("runtime_value", cfg.get("s"))
self.assertEqual(100, cfg.get("i"))
self.assertEqual(1.2, cfg.get("f"))

def test_dump_invalid_scheduler(self) -> None:
with self.assertRaises(ValueError):
Expand All @@ -215,7 +212,7 @@ def test_dump_only_required(self, _) -> None:

cfg = RunConfig()
sfile.seek(0)
load(profile="default", scheduler="test", f=sfile, cfg=cfg)
load(scheduler="test", f=sfile, cfg=cfg)

self.assertFalse(cfg.cfgs)

@@ -226,7 +223,6 @@ def test_dump_only_required(self, _) -> None:
def test_load_invalid_runopt(self, _) -> None:
cfg = RunConfig()
load(
profile="default",
scheduler="test",
f=StringIO(_CONFIG_INVALID),
cfg=cfg,
@@ -241,7 +237,6 @@ def test_load_invalid_runopt(self, _) -> None:
def test_load_no_section(self) -> None:
cfg = RunConfig()
load(
profile="default",
scheduler="local_cwd",
f=StringIO(),
cfg=cfg,
@@ -250,9 +245,8 @@ def test_load_no_section(self) -> None:
self.assertFalse(cfg.cfgs)

load(
profile="default",
scheduler="local_cwd",
f=StringIO("[default.scheduler_args.local_cwd]\n"),
f=StringIO("[scheduler_args.local_cwd]\n"),
cfg=cfg,
)
# still empty
@@ -269,7 +263,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:
sfile.seek(0)

cfg = RunConfig()
load(profile="default", scheduler="test", f=sfile, cfg=cfg)
load(scheduler="test", f=sfile, cfg=cfg)

# all runopts in the TestScheduler have defaults, just check against those
for opt_name, opt in TestScheduler().run_opts():
@@ -282,11 +276,11 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:

sfile = StringIO()
dump(sfile)
print(sfile.getvalue())

for sched_name, sched in get_schedulers(session_name="_").items():
sfile.seek(0) # reset the file pos
cfg = RunConfig()
load(profile="default", scheduler=sched_name, f=sfile, cfg=cfg)
load(scheduler=sched_name, f=sfile, cfg=cfg)

for opt_name, _ in sched.run_opts():
self.assertTrue(opt_name in cfg.cfgs)