From fa3d9dd47e23785e3e4ac330083016830452b4e3 Mon Sep 17 00:00:00 2001 From: Guest Date: Thu, 29 Feb 2024 14:34:06 -0500 Subject: [PATCH] fix symlinkat --- parallel-orch/trace_v2.py | 42 ++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/parallel-orch/trace_v2.py b/parallel-orch/trace_v2.py index d0696944..0f072ea4 100644 --- a/parallel-orch/trace_v2.py +++ b/parallel-orch/trace_v2.py @@ -10,8 +10,9 @@ # not handled: listxattr, llistxattr, getxattr, pivot_root, mount, umount2 # setxattr lsetxattr removexattr lremovexattr, fanotify_mark, renameat2, chroot, quotactl -# handled individually openat, open, chdir, clone, rename -# TODO: link, symlink, renameat, symlinkat +# TODO: link, symlink, renameat + +# handled individually openat, open, chdir, clone, rename, symlinkat r_first_path_set = set(['execve', 'stat', 'lstat', 'access', 'statfs', 'readlink', 'execve', 'getxattr', 'lgetxattr']) w_first_path_set = set(['mkdir', 'rmdir', 'truncate', 'creat', 'chmod', 'chown', @@ -105,6 +106,13 @@ def is_ret_err(ret: str): ret = ret.strip() return ret[0] == '-' +def get_ret_file_path(ret: str): + assert not is_ret_err(ret) + ret = ret.strip() + start = ret.find('<') + 1 + end = ret.rfind('>') + return ret[start:end] + def convert_absolute(cur_dir, path): if is_absolute(path): return path @@ -157,11 +165,12 @@ def handle_open_flag(flags): return 'w' def handle_open_common(total_path, flags, ret): - if handle_open_flag(flags) == 'r': - return RFile(total_path) if is_ret_err(ret): return RFile(total_path) - return WFile(total_path) + elif handle_open_flag(flags) == 'r': + return [RFile(total_path), RFile(get_ret_file_path(ret))] + else: + return [WFile(total_path), WFile(get_ret_file_path(ret))] def parse_openat(args, ret): if args.count(',') <= 2: @@ -183,7 +192,7 @@ def parse_open(pid, args, ret, ctx): return handle_open_common(total_path, flags, ret) def get_path_from_fd_path(args): - a0, a1, _ = args.split(sep=',', maxsplit=2) + a0, a1, *_ = args.split(sep=',', maxsplit=2) a1 = parse_string(a1) if len(a1) and a1[0] == '/': return a1 @@ -219,7 +228,11 @@ def parse_clone(pid, args, ret, ctx): flags = flags[len('flags='):] if has_clone_fs(flags): ctx.do_clone(pid, child) - + +def parse_symlinkat(pid, args, ret): + a0, rest = args.split(sep=',', maxsplit=1) + return parse_w_fd_path(rest, ret) + def parse_syscall(pid, syscall, args, ret, ctx): if syscall in r_first_path_set: return parse_r_first_path(pid, args, ret, ctx) @@ -237,6 +250,8 @@ def parse_syscall(pid, syscall, args, ret, ctx): return parse_w_fd_path(args, ret) elif syscall == 'rename': return parse_rename(pid, args, ret, ctx) + elif syscall == 'symlinkat': + return parse_symlinkat(pid, args, ret) elif syscall == 'clone': return parse_clone(pid, args, ret, ctx) elif syscall in ignore_set: @@ -302,14 +317,17 @@ def parse_and_gather_cmd_rw_sets(trace_object) -> Tuple[set, set]: write_set = set() for l in trace_object: try: - record = parse_line(l, ctx) + records = parse_line(l, ctx) except Exception: logging.debug(l) raise ValueError("error while parsing trace") - if type(record) is RFile and record.fname != '/dev/tty': - read_set.add(record.fname) - elif type(record) is WFile and record.fname != '/dev/tty': - write_set.add(record.fname) + if not isinstance(records, list): + records = [records] + for record in records: + if type(record) is RFile and record.fname != '/dev/tty': + read_set.add(record.fname) + elif type(record) is WFile and record.fname != '/dev/tty': + write_set.add(record.fname) return read_set, write_set def main(fname):