diff --git a/dpdispatcher/ssh_context.py b/dpdispatcher/ssh_context.py index 1c00fdf1..771609cd 100644 --- a/dpdispatcher/ssh_context.py +++ b/dpdispatcher/ssh_context.py @@ -169,10 +169,12 @@ class SSHContext (object): def __init__ (self, local_root, ssh_session, + clean_asynchronously=False, ): assert(type(local_root) == str) self.temp_local_root = os.path.abspath(local_root) self.job_uuid = None + self.clean_asynchronously = clean_asynchronously # self.job_uuid = job_uuid # if job_uuid: # self.job_uuid=job_uuid @@ -203,7 +205,8 @@ def from_jdata(cls, jdata): ssh_session = SSHSession(**input) ssh_context = SSHContext( local_root=local_root, - ssh_session=ssh_session + ssh_session=ssh_session, + clean_asynchronously=jdata.get('clean_asynchronously', False), ) return ssh_context @@ -296,8 +299,11 @@ def download(self, def block_checkcall(self, cmd, + asynchronously=False, retry=0) : self.ssh_session.ensure_alive() + if asynchronously: + cmd = "nohup %s >/dev/null &" % cmd stdin, stdout, stderr = self.ssh_session.exec_command(('cd %s ;' % self.remote_root) + cmd) exit_status = stdout.channel.recv_exit_status() if exit_status != 0: @@ -307,7 +313,7 @@ def block_checkcall(self, (exit_status, cmd, self.job_uuid, stderr.read().decode('utf-8'))) dlog.warning("Sleep 60 s and retry the command...") time.sleep(60) - return self.block_checkcall(cmd, retry=retry+1) + return self.block_checkcall(cmd, asynchronously=asynchronously, retry=retry+1) print('debug:self.remote_root, cmd', self.remote_root, cmd) raise RuntimeError("Get error code %d in calling %s through ssh with job: %s . message: %s" % (exit_status, cmd, self.job_uuid, stderr.read().decode('utf-8'))) @@ -322,7 +328,7 @@ def block_call(self, def clean(self) : self.ssh_session.ensure_alive() - self.sftp._rmtree(sftp, self.remote_root) + self._rmtree(self.remote_root) def write_file(self, fname, write_str): self.ssh_session.ensure_alive() @@ -367,17 +373,18 @@ def kill(self, cmd_pipes) : self.block_checkcall('kill -15 %s' % cmd_pipes['pid']) - def _rmtree(self, sftp, remotepath, level=0, verbose = False): - for f in sftp.listdir_attr(remotepath): - rpath = os.path.join(remotepath, f.filename) - if stat.S_ISDIR(f.st_mode): - self._rmtree(sftp, rpath, level=(level + 1)) - else: - rpath = os.path.join(remotepath, f.filename) - if verbose: dlog.info('removing %s%s' % (' ' * level, rpath)) - sftp.remove(rpath) - if verbose: dlog.info('removing %s%s' % (' ' * level, remotepath)) - sftp.rmdir(remotepath) + def _rmtree(self, remotepath, verbose = False): + """Remove the remote path.""" + # The original implementation method removes files one by one using sftp. + # If the latency of the remote server is high, it is very slow. + # Thus, it's better to use system's `rm` to remove a directory, which may + # save a lot of time. + if verbose: + dlog.info('removing %s' % remotepath) + # In some supercomputers, it's very slow to remove large numbers of files + # (e.g. directory containing trajectory) due to bad I/O performance. + # So an asynchronously option is provided. + self.block_checkcall('rm -rf %s' % remotepath, asynchronously=self.clean_asynchronously) def _put_files(self, files,