diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 0e9631793..b0a51c549 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -678,9 +678,16 @@ def _get_node_args_helper( if child_node.op == "placeholder": if hasattr(child_node, "ph_key"): # pyre-ignore[16] - arg_info.input_attrs.insert(0, child_node.ph_key) - arg_info.is_getitems.insert(0, False) - arg_info.preproc_modules.insert(0, None) + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_keys = ph_key.split("[") + for key in ph_keys: + if "]" in key: + arg_info.input_attrs.append(key[:-1]) + arg_info.is_getitems.append(True) + else: + arg_info.input_attrs.append(key) + arg_info.is_getitems.append(False) else: # no-op arg_info.input_attrs.insert(0, "") @@ -1038,7 +1045,7 @@ def _rewrite_model( # noqa C901 ) if num_found == total_num_args: - logger.info(f"Module '{node.target}'' will be pipelined") + logger.info(f"Module '{node.target}' will be pipelined") child = sharded_modules[node.target] original_forwards.append(child.forward) child.forward = pipelined_forward(