From 6fba1d0225f830da09e048171b85c53309facde6 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Wed, 26 Jun 2024 13:39:45 -0700 Subject: [PATCH] modify the placeholder attr parser to handle dict/list types (#2181) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2181 # context * to handle complex ph_key for the placeholder like the following: ``` (Pdb) arg.op 'placeholder' (Pdb) arg.ph_key 'event_id_list_features_seqs[marketplace]' ``` * original workaround is to modify the `arg_info` in the `_start_data_dist` ``` (Pdb) forward.args [ArgInfo(input_attrs=['event_id_list_features_seqs[user_conv_ads_event]'], is_getitems=[False], name=None)] (Pdb) attr 'event_id_list_features_seqs[user_conv_ads_event]' ``` * according to the ph_key generation, it could be something like `A[key][idx]`. Differential Revision: D59074268 --- torchrec/distributed/train_pipeline/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 0e9631793..b0a51c549 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -678,9 +678,16 @@ def _get_node_args_helper( if child_node.op == "placeholder": if hasattr(child_node, "ph_key"): # pyre-ignore[16] - arg_info.input_attrs.insert(0, child_node.ph_key) - arg_info.is_getitems.insert(0, False) - arg_info.preproc_modules.insert(0, None) + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_keys = ph_key.split("[") + for key in ph_keys: + if "]" in key: + arg_info.input_attrs.append(key[:-1]) + arg_info.is_getitems.append(True) + else: + arg_info.input_attrs.append(key) + arg_info.is_getitems.append(False) else: # no-op arg_info.input_attrs.insert(0, "") @@ -1038,7 +1045,7 @@ def _rewrite_model( # noqa C901 ) if num_found == total_num_args: - logger.info(f"Module '{node.target}'' will be pipelined") + logger.info(f"Module '{node.target}' will be pipelined") child = sharded_modules[node.target] original_forwards.append(child.forward) child.forward = pipelined_forward(