Skip to content

Commit

Permalink
implement download checkpoint (#106)
Browse files Browse the repository at this point in the history
- mark the already-downloaded files, so that they will not be downloaded again.
- dflow supports skip_exists, but it requires that none of the artifacts are
tarred, which may not be favorable.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
3 people committed Dec 28, 2022
1 parent 85e6f2e commit 22e9b13
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 9 deletions.
3 changes: 2 additions & 1 deletion dpgen2/entrypoint/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def download(
wf_config : Optional[Dict] = {},
wf_keys : Optional[List] = None,
prefix : Optional[str] = None,
chk_pnt : bool = False,
):
wf_config = normalize_args(wf_config)

Expand All @@ -36,5 +37,5 @@ def download(

assert wf_keys is not None
for kk in wf_keys:
download_dpgen2_artifacts(wf, kk, prefix=prefix)
download_dpgen2_artifacts(wf, kk, prefix=prefix, chk_pnt=chk_pnt)
logging.info(f'step {kk} downloaded')
10 changes: 10 additions & 0 deletions dpgen2/entrypoint/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def main_parser() -> argparse.ArgumentParser:
parser_download.add_argument(
"-p","--prefix", type=str, help="the prefix of the path storing the download artifacts"
)
parser_download.add_argument(
"-n","--no-check-point", action='store_false',
help="if specified, download regardless whether check points exist."
)

##########################################
# watch
Expand Down Expand Up @@ -173,6 +177,10 @@ def main_parser() -> argparse.ArgumentParser:
parser_watch.add_argument(
"-p","--prefix", type=str, help="the prefix of the path storing the download artifacts"
)
parser_watch.add_argument(
"-n","--no-check-point", action='store_false',
help="if specified, download regardless whether check points exist."
)

# --version
parser.add_argument(
Expand Down Expand Up @@ -246,6 +254,7 @@ def main():
wfid, config,
wf_keys=args.keys,
prefix=args.prefix,
chk_pnt=args.no_check_point,
)
elif args.command == "watch":
with open(args.CONFIG) as fp:
Expand All @@ -257,6 +266,7 @@ def main():
frequency=args.frequency,
download=args.download,
prefix=args.prefix,
chk_pnt=args.no_check_point,
)
elif args.command is None:
pass
Expand Down
8 changes: 6 additions & 2 deletions dpgen2/entrypoint/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def update_finished_steps(
download : Optional[bool] = False,
watching_keys : Optional[List[str]] = None,
prefix : Optional[str] = None,
chk_pnt : bool = False,
):
wf_keys = wf.query_keys_of_steps()
wf_keys = matched_step_key(wf_keys, watching_keys)
Expand All @@ -44,7 +45,7 @@ def update_finished_steps(
for kk in diff_keys:
logging.info(f'steps {kk.ljust(50,"-")} finished')
if download :
download_dpgen2_artifacts(wf, kk, prefix=prefix)
download_dpgen2_artifacts(wf, kk, prefix=prefix, chk_pnt=chk_pnt)
logging.info(f'steps {kk.ljust(50,"-")} downloaded')
finished_keys = wf_keys
return finished_keys
Expand All @@ -55,8 +56,9 @@ def watch(
wf_config : Optional[Dict] = {},
watching_keys : Optional[List] = default_watching_keys,
frequency : float = 600.,
download : Optional[bool] = False,
download : bool = False,
prefix : Optional[str] = None,
chk_pnt : bool = False,
):
wf_config = normalize_args(wf_config)

Expand All @@ -73,6 +75,7 @@ def watch(
download=download,
watching_keys=watching_keys,
prefix=prefix,
chk_pnt=chk_pnt,
)
time.sleep(frequency)

Expand All @@ -84,6 +87,7 @@ def watch(
download=download,
watching_keys=watching_keys,
prefix=prefix,
chk_pnt=chk_pnt,
)
logging.info("well done")
elif status in ["Failed", "Error"]:
Expand Down
48 changes: 43 additions & 5 deletions dpgen2/utils/download_dpgen2_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
import numpy as np
from pathlib import Path

from dflow import (
Workflow,
)
from dpgen2.utils.dflow_query import(
get_iteration,
get_subkey,
)
from dflow import Workflow,download_artifact
from typing import (
Optional
)


class DownloadDefinition():
Expand Down Expand Up @@ -61,9 +67,10 @@ def add_output(


def download_dpgen2_artifacts(
wf,
key,
prefix = None,
wf : Workflow,
key : str,
prefix : Optional[str] = None,
chk_pnt : bool = False,
):
"""
download the artifacts of a step.
Expand Down Expand Up @@ -96,6 +103,31 @@ def download_dpgen2_artifacts(
raise RuntimeError(f'key {key} does not match any step')
step = step[0]

# download inputs
if len(input_def) == 0 or (chk_pnt and (mypath/subkey/'inputs'/'done').is_file()):
pass
else:
_dload_input_lower(step, mypath, key, subkey, input_def)
if chk_pnt:
(mypath/subkey/'inputs'/'done').touch()
# download outputs
if len(output_def) == 0 or (chk_pnt and (mypath/subkey/'outputs'/'done').is_file()):
pass
else:
_dload_output_lower(step, mypath, key, subkey, output_def)
if chk_pnt:
(mypath/subkey/'outputs'/'done').touch()

return


def _dload_input_lower(
step,
mypath,
key,
subkey,
input_def,
):
for kk in input_def.keys():
pref = mypath / subkey / 'inputs'
ksuff = input_def[kk]
Expand All @@ -111,6 +143,14 @@ def download_dpgen2_artifacts(
# NotImplementedError to be compatible with old versions of dflow
logging.warning(f'cannot download input artifact {kk} of {key}, it may be empty')


def _dload_output_lower(
step,
mypath,
key,
subkey,
output_def,
):
for kk in output_def.keys():
pref = mypath / subkey / 'outputs'
ksuff = output_def[kk]
Expand All @@ -125,5 +165,3 @@ def download_dpgen2_artifacts(
except (NotImplementedError, FileNotFoundError):
# NotImplementedError to be compatible with old versions of dflow
logging.warning(f'cannot download input artifact {kk} of {key}, it may be empty')

return
12 changes: 11 additions & 1 deletion tests/test_prep_run_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import unittest
import unittest, os, shutil
from pathlib import Path

from dpgen2.fp.gaussian import (
Expand Down Expand Up @@ -39,6 +39,9 @@ def test_prep_gaussian(self):
inputs=inputs,
)
assert Path(gaussian_input_name).exists()
for ii in ['task.log', 'task.gjf']:
if Path(ii).exists():
os.remove(ii)


class TestRunGaussian(unittest.TestCase):
Expand Down Expand Up @@ -66,3 +69,10 @@ def test_run_gaussian(self):
)
assert out_name == output
assert log_name == gaussian_output_name
for ii in [output]:
if Path(ii).exists():
shutil.rmtree(ii)
for ii in ['task.log', 'task.gjf']:
if Path(ii).exists():
os.remove(ii)

30 changes: 30 additions & 0 deletions tests/utils/test_dl_dpgen2_arti.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,36 @@ def test_fp_download(self, mocked_dl):
self.assertEqual(ii,jj)


@mock.patch('dpgen2.utils.download_dpgen2_artifacts.download_artifact')
def test_fp_download_chkpnt(self, mocked_dl):
    """Download the prep-run-fp step twice with ``chk_pnt=True``.

    After the first download the three expected ``download_artifact``
    calls must have been recorded and a ``done`` file must exist under
    both ``inputs`` and ``outputs``.  After the second download the
    recorded call list must still consist of exactly those three calls.
    """
    step_key = 'iter-000001--prep-run-fp'
    iter_dir = Path('iter-000001')
    inputs_dir = Path("iter-000001/prep-run-fp/inputs")
    outputs_dir = Path("iter-000001/prep-run-fp/outputs")

    # Start from a clean tree and pre-create the target directories.
    if iter_dir.exists():
        shutil.rmtree('iter-000001')
    for target in (inputs_dir, outputs_dir):
        target.mkdir(parents=True, exist_ok=True)

    expected = [
        mock.call("arti-confs", path=inputs_dir, skip_exists=True),
        mock.call("arti-logs", path=outputs_dir, skip_exists=True),
        mock.call("arti-labeled_data", path=outputs_dir, skip_exists=True),
    ]

    def check_recorded_calls():
        # The mock accumulates calls over the whole test, so the same
        # expectation is compared after each download.
        recorded = mocked_dl.call_args_list
        self.assertEqual(len(recorded), len(expected))
        for got, want in zip(recorded, expected):
            self.assertEqual(got, want)

    # First pass: artifacts are fetched and the markers are written.
    download_dpgen2_artifacts(Mockedwf(), step_key, None, chk_pnt=True)
    check_recorded_calls()
    self.assertTrue((inputs_dir / 'done').is_file())
    self.assertTrue((outputs_dir / 'done').is_file())

    # Second pass: the recorded calls must be unchanged.
    download_dpgen2_artifacts(Mockedwf(), step_key, None, chk_pnt=True)
    check_recorded_calls()

    # Remove everything the test created.
    if iter_dir.exists():
        shutil.rmtree('iter-000001')


@mock.patch('dpgen2.utils.download_dpgen2_artifacts.download_artifact')
def test_empty_download(self, mocked_dl):
download_dpgen2_artifacts(Mockedwf(), 'iter-000001--foo', None)
Expand Down

0 comments on commit 22e9b13

Please sign in to comment.