Skip to content

Commit

Permalink
save exploration_scheduler in global outputs and get exploration_sche…
Browse files Browse the repository at this point in the history
…duler from global outputs in the command 'dpgen2 status' (#129)

Signed-off-by: zjgemi <liuxin_zijian@163.com>

---------

Signed-off-by: zjgemi <liuxin_zijian@163.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
3 people committed Feb 6, 2023
1 parent 89f7c95 commit d7aa3aa
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
6 changes: 6 additions & 0 deletions dpgen2/flow/dpgen_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ def _loop(
executor=step_executor,
**step_config,
)
scheduler_step.template.outputs.parameters[
"exploration_scheduler"
].global_name = "exploration_scheduler"
steps.add(scheduler_step)

id_step = Step(
Expand Down Expand Up @@ -467,6 +470,9 @@ def _dpgen(
executor=step_executor,
**step_config,
)
scheduler_step.template.outputs.parameters[
"exploration_scheduler"
].global_name = "exploration_scheduler"
steps.add(scheduler_step)

id_step = Step(
Expand Down
20 changes: 15 additions & 5 deletions dpgen2/utils/dflow_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import re
from typing import (
Any,
Expand Down Expand Up @@ -50,20 +51,29 @@ def get_last_scheduler(
"""
get the output Scheduler of the last successful iteration
"""
outputs = wf.query_global_outputs()
if (
outputs is not None
and hasattr(outputs, "parameters")
and "exploration_scheduler" in outputs.parameters
):
return outputs.parameters["exploration_scheduler"].value

logging.warn("Exploration scheduler not found in the global outputs")
scheduler_keys_ = []
for ii in keys:
if get_subkey(ii) == "scheduler":
scheduler_keys_.append(ii)
wf_info = wf.query()
scheduler_steps = wf.query_step_by_key(scheduler_keys_)
scheduler_keys = []
for ii in scheduler_keys_:
if wf_info.get_step(key=ii)[0]["phase"] == "Succeeded":
scheduler_keys.append(ii)
for step in scheduler_steps:
if step["phase"] == "Succeeded":
scheduler_keys.append(step.key)
if len(scheduler_keys) == 0:
return None
else:
skey = sorted(scheduler_keys)[-1]
step = wf_info.get_step(key=skey)[0]
step = [step for step in scheduler_steps if step.key == skey][0]
return step.outputs.parameters["exploration_scheduler"].value


Expand Down
33 changes: 29 additions & 4 deletions tests/utils/test_dflow_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def __init__(self):


class MockedBar:
def __init__(self, xx):
def __init__(self, xx, kk):
self.key = kk
self.outputs = MockedFoo()
self.outputs.parameters["exploration_scheduler"].value = xx * 10

Expand All @@ -133,11 +134,11 @@ def __getitem__(self, key):

def _get_step(key=None):
if key == "init--scheduler":
return [MockedBar(2)]
return [MockedBar(2, key)]
elif key == "iter-0--scheduler":
return [MockedBar(0)]
return [MockedBar(0, key)]
elif key == "iter-1--scheduler":
return [MockedBar(1)]
return [MockedBar(1, key)]
else:
raise RuntimeError("unexpected key")

Expand All @@ -148,12 +149,29 @@ def get_step(self, key=None):


class MockedWF:
def __init__(
self,
none_global=True,
):
self.none_global = none_global

def query_step(self, key=None):
return _get_step(key)

def query(self):
return MockedWFInfo()

def query_global_outputs(self):
# mocked return None: non-global scheduler output
if self.none_global:
return None
else:
return MockedFoo()

def query_step_by_key(self, keys):
ret = [_get_step(kk)[0] for kk in keys]
return ret


class TestDflowQuery(unittest.TestCase):
def test_get_subkey(self):
Expand All @@ -180,6 +198,13 @@ def test_get_last_scheduler(self):
)
self.assertEqual(value, 10)

def test_get_last_scheduler(self):
value = get_last_scheduler(
MockedWF(none_global=False),
["iter-1--scheduler", "foo", "bar", "iter-0--scheduler", "init--scheduler"],
)
self.assertEqual(value, 10)

def test_get_all_schedulers(self):
value = get_all_schedulers(
MockedWF(),
Expand Down

0 comments on commit d7aa3aa

Please sign in to comment.