diff --git a/dpdispatcher/ssh_context.py b/dpdispatcher/ssh_context.py index 796bac71..f655d31f 100644 --- a/dpdispatcher/ssh_context.py +++ b/dpdispatcher/ssh_context.py @@ -502,6 +502,14 @@ def bind_submission(self, submission): self.block_checkcall( f"mv {shlex.quote(old_remote_root)} {shlex.quote(self.remote_root)}" ) + elif ( + old_remote_root is not None + and old_remote_root != self.remote_root + and self.check_file_exists(old_remote_root) + and not len(self.ssh_session.sftp.listdir(old_remote_root)) + ): + # if the new directory exists and the old directory does not contain files, then move the old directory + self._rmtree(old_remote_root) sftp = self.ssh_session.ssh.open_sftp() try: diff --git a/tests/test_ssh_context.py b/tests/test_ssh_context.py index 6434f8d9..83ca23f2 100644 --- a/tests/test_ssh_context.py +++ b/tests/test_ssh_context.py @@ -36,6 +36,7 @@ def setUpClass(cls): "key_filename": "/root/.ssh/id_rsa", }, } + cls.mdata = mdata try: cls.machine = Machine.load_from_dict(mdata) except (SSHException, socket.timeout): @@ -113,6 +114,53 @@ def test_empty_transfer(self): ) submission.run_submission() + def test_recover(self): + """Test recover from a previous submission.""" + machine = Machine.load_from_dict(self.machine.serialize()) + resources = Resources.load_from_dict( + { + "number_node": 1, + "cpu_per_node": 1, + "gpu_per_node": 0, + "queue_name": "?", + "group_size": 1, + } + ) + task = Task( + command="touch times && echo 1 >> times && test $(wc -l < times) -gt 3 && echo done", + task_work_path="./", + forward_files=[], + backward_files=[], + outlog="out.txt", + ) + + submission = Submission( + work_base="./", + machine=machine, + resources=resources, + forward_common_files=[], + backward_common_files=[], + task_list=[task], + ) + try: + submission.run_submission() + except RuntimeError: + # expected to fail, try again + # reinit machine to test machine recover + machine = Machine.load_from_dict(self.mdata) + resources = Resources.load_from_dict(resources.serialize()) + task = Task.deserialize(task.serialize()) + + submission = Submission( + work_base="./", + machine=machine, + resources=resources, + forward_common_files=[], + backward_common_files=[], + task_list=[task], + ) + submission.run_submission() + def test_download(self): self.context.download(self.__class__.submission)