Skip to content

Commit

Permalink
Replace the scheduler in the old workflow by the new one (#114)
Browse files Browse the repository at this point in the history
When resubmitting, replace the scheduler of the old workflow with the new
one. Otherwise the workflow follows the old schedule exactly and there is
no opportunity to update the schedule.

Also update the report printing: print the trust levels and whether the
iteration has converged.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang committed Jan 14, 2023
1 parent 1261939 commit fe0fa83
Show file tree
Hide file tree
Showing 13 changed files with 720 additions and 56 deletions.
6 changes: 5 additions & 1 deletion dpgen2/entrypoint/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def main_parser() -> argparse.ArgumentParser:
"-l", "--list", action='store_true', help="list the Steps of the existing workflow."
)
parser_resubmit.add_argument(
"--reuse", type=str, nargs='+', default=None, help="specify which Steps to reuse."
"-u", "--reuse", type=str, nargs='+', default=None, help="specify which Steps to reuse."
)
parser_resubmit.add_argument(
"-k", "--keep-schedule", action='store_true', help="if set then keep schedule of the old workflow. otherwise use the schedule defined in the input file"
)
parser_resubmit.add_argument(
"-o", "--old-compatible", action='store_true', help="compatible with old-style input script used in dpgen2 < 0.0.6."
Expand Down Expand Up @@ -241,6 +244,7 @@ def main():
list_steps=args.list,
reuse=args.reuse,
old_style=args.old_compatible,
replace_scheduler=(not args.keep_schedule),
)
elif args.command == "status":
with open(args.CONFIG) as fp:
Expand Down
1 change: 1 addition & 0 deletions dpgen2/entrypoint/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
)
from dpgen2.utils.dflow_query import (
get_last_scheduler,
get_all_schedulers,
)
from typing import (
Optional, Dict, Union, List,
Expand Down
90 changes: 87 additions & 3 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import glob, dpdata, os, pickle
import glob, dpdata, os, pickle, logging, copy
from pathlib import Path
from dflow import (
InputParameter,
Expand Down Expand Up @@ -82,6 +82,7 @@
workflow_config_from_dict,
matched_step_key,
bohrium_config_from_dict,
get_subkey,
)
from dpgen2.utils.step_config import normalize as normalize_step_dict
from dpgen2.entrypoint.common import (
Expand Down Expand Up @@ -388,10 +389,76 @@ def workflow_concurrent_learning(
return dpgen_step


def get_scheduler_ids(
    reuse_step,
):
    """Return the positions of the scheduler steps inside ``reuse_step``.

    A step counts as a scheduler step when ``get_subkey(step.key, 1)``
    equals ``"scheduler"``.

    Parameters
    ----------
    reuse_step
        Sequence of steps; each item must expose a ``key`` attribute.

    Returns
    -------
    list of int
        Indices into ``reuse_step``, in order of appearance.  The list is
        empty (and a warning is logged) when no scheduler step is found.

    Raises
    ------
    AssertionError
        If the keys of the selected steps are not in sorted order.
    """
    scheduler_ids = [
        pos
        for pos, step in enumerate(reuse_step)
        if get_subkey(step.key, 1) == "scheduler"
    ]

    # the callers rely on the last entry being the latest scheduler step,
    # so the keys are required to already be in sorted order
    scheduler_keys = [reuse_step[pos].key for pos in scheduler_ids]
    assert scheduler_keys == sorted(scheduler_keys), \
        "The scheduler keys are not properly sorted"

    if not scheduler_ids:
        logging.warning(
            "No scheduler found in the workflow, "
            "does not do any replacement."
        )
    return scheduler_ids

def update_reuse_step_scheduler(
    reuse_step,
    scheduler_new,
):
    """Replace the scheduler stored in the last scheduler step.

    The step at the last index returned by ``get_scheduler_ids`` gets its
    ``"exploration_scheduler"`` output parameter overwritten with
    ``scheduler_new``.  When there is no scheduler step, nothing is changed.

    Parameters
    ----------
    reuse_step
        Sequence of steps to be reused; modified in place.
    scheduler_new
        The scheduler object to store in the last scheduler step.

    Returns
    -------
    The same ``reuse_step`` object that was passed in.
    """
    found = get_scheduler_ids(reuse_step)
    if found:
        # only the most recent scheduler step carries the state to replace
        last_scheduler_step = reuse_step[found[-1]]
        last_scheduler_step.modify_output_parameter(
            "exploration_scheduler", scheduler_new)
    return reuse_step

def copy_scheduler_plans(
    scheduler_new,
    scheduler_old,
):
    """Replay the history of ``scheduler_old`` on ``scheduler_new``.

    The reports recorded by each stage of the old scheduler are fed, stage
    by stage, to ``scheduler_new.plan_next_iteration`` so that the new
    scheduler reaches the same point of the schedule as the old one.

    Parameters
    ----------
    scheduler_new
        The scheduler to be advanced; modified in place.
    scheduler_old
        The scheduler whose stage reports are replayed.

    Returns
    -------
    ``scheduler_new``, returned unchanged when the old scheduler has no
    stage schedulers.

    Raises
    ------
    RuntimeError
        If the new scheduler has fewer stages than the old one, or if the
        current stage of the new scheduler differs from the index of the
        old stage being replayed.
    """
    old_stages = scheduler_old.stage_schedulers
    if len(old_stages) == 0:
        return scheduler_new
    if len(scheduler_new.stage_schedulers) < len(old_stages):
        raise RuntimeError(
            'The new scheduler has less stages than the old scheduler, '
            'scheduler copy is not supported.'
        )
    # the scheduler_old is planned. mimic the init call of the scheduler
    if scheduler_old.get_iteration() > -1:
        scheduler_new.plan_next_iteration()
    for ii, old_stage in enumerate(old_stages):
        if old_stage.next_iteration() <= 0:
            # this stage was never planned by the old scheduler, so there
            # is nothing left to replay
            break
        if ii != scheduler_new.get_stage():
            raise RuntimeError(
                f'The stage {scheduler_new.get_stage()} of the new '
                f'scheduler does not match '
                f'the stage {ii} of the old scheduler, '
                f'which should not happen'
            )
        for report in old_stage.get_reports():
            scheduler_new.plan_next_iteration(report)
        # the old stage finished but the replay did not finish the new
        # stage: force it so that both schedulers stay on the same stage
        if old_stage.complete() and \
           (not scheduler_new.stage_schedulers[ii].complete()):
            scheduler_new.force_stage_complete()
    return scheduler_new

def submit_concurrent_learning(
wf_config,
reuse_step = None,
old_style = False,
reuse_step : Optional[List[Step]] = None,
old_style : bool = False,
replace_scheduler: bool = False,
):
# normalize args
wf_config = normalize_args(wf_config)
Expand All @@ -401,6 +468,21 @@ def submit_concurrent_learning(
context = global_config_workflow(wf_config, do_lebesgue=do_lebesgue)

dpgen_step = workflow_concurrent_learning(wf_config, old_style=old_style)

if reuse_step is not None and replace_scheduler:
scheduler_new = copy.deepcopy(dpgen_step.inputs.parameters['exploration_scheduler'].value)
idx_old = get_scheduler_ids(reuse_step)[-1]
scheduler_old = reuse_step[idx_old].inputs.parameters['exploration_scheduler'].value
scheduler_new = copy_scheduler_plans(scheduler_new, scheduler_old)
exploration_report = reuse_step[idx_old].inputs.parameters['exploration_report'].value
# plan next
# hack! trajs is set to None...
conv, lmp_task_grp, selector = scheduler_new.plan_next_iteration(exploration_report, trajs=None)
# update output of the scheduler step
reuse_step[idx_old].modify_output_parameter("converged", conv,)
reuse_step[idx_old].modify_output_parameter("exploration_scheduler", scheduler_new,)
reuse_step[idx_old].modify_output_parameter("lmp_task_grp", lmp_task_grp,)
reuse_step[idx_old].modify_output_parameter("conf_selector", selector,)

wf = Workflow(name="dpgen", context=context)
wf.add(dpgen_step)
Expand Down Expand Up @@ -449,6 +531,7 @@ def resubmit_concurrent_learning(
list_steps = False,
reuse = None,
old_style = False,
replace_scheduler = False,
):
wf_config = normalize_args(wf_config)

Expand All @@ -474,6 +557,7 @@ def resubmit_concurrent_learning(
wf_config,
reuse_step=reuse_step,
old_style=old_style,
replace_scheduler=replace_scheduler,
)

return wf
48 changes: 34 additions & 14 deletions dpgen2/exploration/report/report_trust_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@
from dflow.python import FatalError

class ExplorationReportTrustLevels(ExplorationReport):
# class attrs
spaces = [8, 8, 8, 10, 10, 10]
fmt_str = ' '.join([f'%{ii}s' for ii in spaces])
fmt_flt = '%.4f'
header_str = '#' + fmt_str % ('stage', 'id_stg.', 'iter.', 'accu.', 'cand.', 'fail.')

def __init__(
self,
trust_level,
Expand All @@ -26,6 +20,21 @@ def __init__(
self.v_level = ( (self.trust_level.level_v_lo is not None) and \
(self.trust_level.level_v_hi is not None) )

print_tuple = ('stage', 'id_stg.', 'iter.',
'accu.', 'cand.', 'fail.',
'lvl_f_lo', 'lvl_f_hi',
)
spaces = [8, 8, 8, 10, 10, 10, 10, 10]
if self.v_level:
print_tuple += ('v_lo', 'v_hi',)
spaces += [10, 10]
print_tuple += ('cvged',)
spaces += [8]
self.fmt_str = ' '.join([f'%{ii}s' for ii in spaces])
self.fmt_flt = '%.4f'
self.header_str = '#' + self.fmt_str % print_tuple


def clear(
self,
):
Expand Down Expand Up @@ -186,7 +195,7 @@ def _get_candidates(

def print_header(self) -> str:
r"""Print the header of report"""
return ExplorationReportTrustLevels.header_str
return self.header_str

def print(
self,
Expand All @@ -195,12 +204,23 @@ def print(
iter_idx : int,
) -> str:
r"""Print the report"""
fmt_str = ExplorationReportTrustLevels.fmt_str
fmt_flt = ExplorationReportTrustLevels.fmt_flt
ret = ' ' + fmt_str % (
str(stage_idx), str(idx_in_stage), str(iter_idx),
fmt_flt%(self.accurate_ratio()),
fmt_flt%(self.candidate_ratio()),
fmt_flt%(self.failed_ratio()),
fmt_str = self.fmt_str
fmt_flt = self.fmt_flt
print_tuple = (
str(stage_idx), str(idx_in_stage), str(iter_idx),
fmt_flt%(self.accurate_ratio()),
fmt_flt%(self.candidate_ratio()),
fmt_flt%(self.failed_ratio()),
fmt_flt%(self.trust_level.level_f_lo),
fmt_flt%(self.trust_level.level_f_hi),
)
if self.v_level:
print_tuple += (
fmt_flt%(self.trust_level.level_v_lo),
fmt_flt%(self.trust_level.level_v_hi),
)
print_tuple += (
str(self.converged()),
)
ret = ' ' + fmt_str % print_tuple
return ret
13 changes: 13 additions & 0 deletions dpgen2/exploration/scheduler/convergence_check_stage_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,18 @@ def __init__(
self.complete_ = False
self.reports = []

def get_reports(self):
    """Return ``self.reports``, the list of reports held by this stage scheduler."""
    return self.reports

def complete(self):
    """Return the completion flag ``self.complete_`` of this stage."""
    return self.complete_

def force_complete(self):
    """Mark this stage as complete by setting ``self.complete_`` to True."""
    self.complete_ = True

def next_iteration(self):
    """Return ``self.nxt_iter``.

    NOTE(review): presumably the index, within this stage, of the next
    iteration to be planned -- confirm against where ``nxt_iter`` is updated.
    """
    return self.nxt_iter

def converged(self):
return self.conv

Expand All @@ -44,6 +53,10 @@ def plan_next_iteration(
report : Optional[ExplorationReport] = None,
trajs : Optional[List[Path]] = None,
) -> Tuple[bool, Optional[ExplorationTaskGroup], Optional[ConfSelector]] :
if self.complete():
raise FatalError(
'Cannot plan because the stage has completed.'
)
if report is None:
stg_complete = False
self.conv = stg_complete
Expand Down
54 changes: 50 additions & 4 deletions dpgen2/exploration/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
):
self.stage_schedulers = []
self.cur_stage = 0
self.iteration = -1
self.complete_ = False

def add_stage_scheduler(
Expand Down Expand Up @@ -66,7 +65,15 @@ def get_iteration(self):
Iteration index increase when `self.plan_next_iteration` returns valid `lmp_task_grp` and `conf_selector` for the next iteration.
"""
return self.iteration
tot_iter = -1
for idx,ii in enumerate(self.stage_schedulers):
if ii.complete():
# the last plan is not used because the stage
# is found converged
tot_iter += ii.next_iteration() - 1
else:
tot_iter += ii.next_iteration()
return tot_iter

def complete(self):
"""
Expand All @@ -75,6 +82,20 @@ def complete(self):
"""
return self.complete_

def force_stage_complete(self):
    """
    Force the current stage to complete and move on.

    The current stage scheduler is marked complete and ``cur_stage`` is
    advanced by one.  If another stage remains, ``plan_next_iteration`` is
    called without a report; otherwise the whole scheduler is marked
    complete.
    """
    current = self.stage_schedulers[self.cur_stage]
    current.force_complete()
    self.cur_stage += 1

    if self.cur_stage >= len(self.stage_schedulers):
        # all stages complete
        self.complete_ = True
        return

    # goes to next stage
    self.plan_next_iteration()

def plan_next_iteration(
self,
report : Optional[ExplorationReport] = None,
Expand Down Expand Up @@ -109,7 +130,7 @@ def plan_next_iteration(
)
except FatalError as e:
raise FatalError(f'stage {self.cur_stage}: ' + str(e))

if stg_complete:
self.cur_stage += 1
if self.cur_stage < len(self.stage_schedulers):
Expand All @@ -120,7 +141,6 @@ def plan_next_iteration(
self.complete_ = True
return True, None, None,
else :
self.iteration += 1
return stg_complete, lmp_task_grp, conf_selector


Expand Down Expand Up @@ -188,6 +208,32 @@ def _print_prev_summary(self, prev_stg_idx):
else:
return None


def print_last_iteration(self, print_header=False):
    """Format the report of the last entry given by ``get_stage_of_iterations``.

    Parameters
    ----------
    print_header
        When true, the header line of the report is emitted before the
        report line itself.

    Returns
    -------
    str
        ``"No finished iteration found\\n"`` when there is no iteration;
        otherwise the newline-joined lines (optional header, the report
        line, and a final ``# All stages converged`` line when the
        scheduler is complete), terminated by a newline.
    """
    schedulers = self.stage_schedulers
    stage_idx, idx_in_stage, iter_idx = self.get_stage_of_iterations()

    n_iter = np.size(iter_idx)
    if n_iter == 0:
        return "No finished iteration found\n"

    # position of the last entry; also passed as the global iteration index
    last = n_iter - 1
    last_stage = stage_idx[last]
    last_in_stage = idx_in_stage[last]
    report = schedulers[last_stage].reports[last_in_stage]

    lines = []
    if print_header:
        lines.append(report.print_header())
    lines.append(report.print(last_stage, last_in_stage, last))
    if self.complete():
        lines.append('# All stages converged')

    # the trailing empty item makes the output end with a newline
    lines.append('')
    return '\n'.join(lines)



def print_convergence(self):
ret = []
stages = self.stage_schedulers
Expand Down

0 comments on commit fe0fa83

Please sign in to comment.