Skip to content

Commit

Permalink
Update local_context.py/ssh_context.py (#370)
Browse files Browse the repository at this point in the history
To support wildcards for backward_files and backward_common_files.
Now only support for local_context and ssh_context.
#371

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
3 people committed Oct 10, 2023
1 parent f984891 commit 7d89246
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 31 deletions.
54 changes: 50 additions & 4 deletions dpdispatcher/local_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,34 @@ def download(
for ii in submission.belonging_tasks:
local_job = os.path.join(self.local_root, ii.task_work_path)
remote_job = os.path.join(self.remote_root, ii.task_work_path)
flist = ii.backward_files
flist = []
for kk in ii.backward_files:
abs_flist_r = glob(os.path.join(remote_job, kk))
abs_flist_l = glob(os.path.join(local_job, kk))
if not abs_flist_r and not abs_flist_l:
if check_exists:
if mark_failure:
tag_file_path = os.path.join(
self.local_root,
ii.task_work_path,
"tag_failure_download_%s" % kk,
)
with open(tag_file_path, "w") as fp:
pass
else:
pass
else:
raise RuntimeError(
"cannot find download file " + os.path.join(remote_job, kk)
)
rel_flist = [
os.path.relpath(ii, start=remote_job) for ii in abs_flist_r
]
flist.extend(rel_flist)
if back_error:
flist += glob(os.path.join(remote_job, "error*"))
abs_flist = glob(os.path.join(remote_job, "error*"))
rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist]
flist.extend(rel_flist)
for jj in flist:
rfile = os.path.join(remote_job, jj)
lfile = os.path.join(local_job, jj)
Expand Down Expand Up @@ -198,9 +223,30 @@ def download(
pass
local_job = self.local_root
remote_job = self.remote_root
flist = submission.backward_common_files
flist = []
for kk in submission.backward_common_files:
abs_flist_r = glob(os.path.join(remote_job, kk))
abs_flist_l = glob(os.path.join(local_job, kk))
if not abs_flist_r and not abs_flist_l:
if check_exists:
if mark_failure:
tag_file_path = os.path.join(
self.local_root, "tag_failure_download_%s" % kk
)
with open(tag_file_path, "w") as fp:
pass
else:
pass
else:
raise RuntimeError(
"cannot find download file " + os.path.join(remote_job, kk)
)
rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist_r]
flist.extend(rel_flist)
if back_error:
flist += glob(os.path.join(remote_job, "error*"))
abs_flist = glob(os.path.join(remote_job, "error*"))
rel_flist = [os.path.relpath(ii, start=remote_job) for ii in abs_flist]
flist.extend(rel_flist)
for jj in flist:
rfile = os.path.join(remote_job, jj)
lfile = os.path.join(local_job, jj)
Expand Down
93 changes: 71 additions & 22 deletions dpdispatcher/ssh_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python

import fnmatch
import os
import pathlib
import shlex
Expand All @@ -10,6 +11,7 @@
import uuid
from functools import lru_cache
from glob import glob
from stat import S_ISDIR, S_ISREG
from typing import List

import paramiko
Expand Down Expand Up @@ -414,7 +416,7 @@ def __init__(
assert os.path.isabs(remote_root), "remote_root must be a abspath"
self.temp_remote_root = remote_root
self.remote_profile = remote_profile
self.remote_root = None
self.remote_root = ""

# self.job_uuid = None
self.clean_asynchronously = clean_asynchronously
Expand Down Expand Up @@ -634,6 +636,18 @@ def upload(
tar_compress=self.remote_profile.get("tar_compress", None),
)

def list_remote_dir(self, sftp, remote_dir, ref_remote_root, result_list):
for entry in sftp.listdir_attr(remote_dir):
remote_name = pathlib.PurePath(
os.path.join(remote_dir, entry.filename)
).as_posix()
st_mode = entry.st_mode
if S_ISDIR(st_mode):
self.list_remote_dir(sftp, remote_name, ref_remote_root, result_list)
elif S_ISREG(st_mode):
rel_remote_name = os.path.relpath(remote_name, start=ref_remote_root)
result_list.append(rel_remote_name)

def download(
self,
submission,
Expand All @@ -646,31 +660,66 @@ def download(
self.ssh_session.ensure_alive()
file_list = []
# for ii in job_dirs :
for task in submission.belonging_tasks:
for jj in task.backward_files:
file_name = pathlib.PurePath(
os.path.join(task.task_work_path, jj)
).as_posix()
for ii in submission.belonging_tasks:
remote_file_list = None
for jj in ii.backward_files:
if "*" in jj or "?" in jj:
if remote_file_list is not None:
abs_file_list = fnmatch.filter(remote_file_list, jj)
else:
remote_file_list = []
remote_job = pathlib.PurePath(
os.path.join(self.remote_root, ii.task_work_path)
).as_posix()
self.list_remote_dir(
self.sftp, remote_job, remote_job, remote_file_list
)

abs_file_list = fnmatch.filter(remote_file_list, jj)
rel_file_list = [
pathlib.PurePath(os.path.join(ii.task_work_path, kk)).as_posix()
for kk in abs_file_list
]

else:
rel_file_list = [
pathlib.PurePath(os.path.join(ii.task_work_path, jj)).as_posix()
]
if check_exists:
if self.check_file_exists(file_name):
file_list.append(file_name)
elif mark_failure:
with open(
os.path.join(
self.local_root,
task.task_work_path,
"tag_failure_download_%s" % jj,
),
"w",
) as fp:
for file_name in rel_file_list:
if self.check_file_exists(file_name):
file_list.append(file_name)
elif mark_failure:
with open(
os.path.join(
self.local_root,
ii.task_work_path,
"tag_failure_download_%s" % jj,
),
"w",
) as fp:
pass
else:
pass
else:
pass
else:
file_list.append(file_name)
file_list.extend(rel_file_list)
if back_error:
errors = glob(os.path.join(task.task_work_path, "error*"))
file_list.extend(errors)
if remote_file_list is not None:
abs_errors = fnmatch.filter(remote_file_list, "error*")
else:
remote_file_list = []
remote_job = pathlib.PurePath(
os.path.join(self.remote_root, ii.task_work_path)
).as_posix()
self.list_remote_dir(
self.sftp, remote_job, remote_job, remote_file_list
)
abs_errors = fnmatch.filter(remote_file_list, "error*")
rel_errors = [
pathlib.PurePath(os.path.join(ii.task_work_path, kk)).as_posix()
for kk in abs_errors
]
file_list.extend(rel_errors)
file_list.extend(submission.backward_common_files)
if len(file_list) > 0:
self._get_files(
Expand Down
16 changes: 13 additions & 3 deletions tests/sample_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_sample_task_dict(cls):
return task_dict

@classmethod
def get_sample_task_list(cls):
def get_sample_task_list(cls, backward_wildcard=False):
task1 = Task(
command="lmp -i input.lammps",
task_work_path="bct-1/",
Expand All @@ -109,6 +109,16 @@ def get_sample_task_list(cls):
backward_files=["log.lammps"],
)
task_list = [task1, task2, task3, task4]
if backward_wildcard:
task_wildcard = Task(
command="lmp -i input.lammps",
task_work_path="bct-backward_wildcard/",
forward_files=[],
backward_files=["test*/test*"],
outlog="wildcard.log",
errlog="wildcard.err",
)
task_list.append(task_wildcard)
return task_list

@classmethod
Expand All @@ -127,9 +137,9 @@ def get_sample_empty_submission(cls):
return empty_submission

@classmethod
def get_sample_submission(cls):
def get_sample_submission(cls, backward_wildcard=False):
submission = cls.get_sample_empty_submission()
task_list = cls.get_sample_task_list()
task_list = cls.get_sample_task_list(backward_wildcard=backward_wildcard)
submission.register_task_list(task_list)
submission.generate_jobs()
return submission
Expand Down
8 changes: 6 additions & 2 deletions tests/test_ssh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def setUpClass(cls):
cls.machine = Machine.load_from_dict(mdata)
except (SSHException, socket.timeout):
raise unittest.SkipTest("SSHException ssh cannot connect")
cls.submission = SampleClass.get_sample_submission()
cls.submission = SampleClass.get_sample_submission(backward_wildcard=True)
cls.submission.bind_machine(cls.machine)
cls.submission_hash = cls.submission.submission_hash
file_list = [
Expand All @@ -50,6 +50,8 @@ def setUpClass(cls):
"bct-3/log.lammps",
"bct-4/log.lammps",
"dir with space/file with space",
"bct-backward_wildcard/test456",
"bct-backward_wildcard/test123/test123",
]
for file in file_list:
cls.machine.context.sftp.mkdir(
Expand Down Expand Up @@ -187,7 +189,7 @@ def setUpClass(cls):
cls.machine = Machine.load_from_dict(mdata)
except (SSHException, socket.timeout):
raise unittest.SkipTest("SSHException ssh cannot connect")
cls.submission = SampleClass.get_sample_submission()
cls.submission = SampleClass.get_sample_submission(backward_wildcard=True)
cls.submission.bind_machine(cls.machine)
cls.submission_hash = cls.submission.submission_hash
file_list = [
Expand All @@ -196,6 +198,8 @@ def setUpClass(cls):
"bct-3/log.lammps",
"bct-4/log.lammps",
"dir with space/file with space",
"bct-backward_wildcard/test456",
"bct-backward_wildcard/test123/test123",
]
for file in file_list:
cls.machine.context.sftp.mkdir(
Expand Down

0 comments on commit 7d89246

Please sign in to comment.