From 85b4ab9725a0d320bb5063418241346c4f388364 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 14 Jun 2024 12:44:41 -0700 Subject: [PATCH] OSS fix train_pipelines utils test (#2114) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2114 Differential Revision: D58594592 --- torchrec/distributed/tests/test_pt2.py | 36 ++++++++++++++++--- ...utils.py => test_train_pipelines_utils.py} | 2 +- 2 files changed, 33 insertions(+), 5 deletions(-) rename torchrec/distributed/train_pipeline/tests/{test_utils.py => test_train_pipelines_utils.py} (96%) diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index 3ea551853..90c8eade7 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -179,10 +179,37 @@ def get(self) -> int: def set(self, val): self.counter_ = val + @torch._library.register_fake_class("fbgemm::TensorQueue") + class FakeTensorQueue: + def __init__(self, queue, init_tensor): + self.queue = queue + self.init_tensor = init_tensor + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(flattened_ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + if len(self.queue) == 0: + return self.init_tensor + return self.queue.pop(0) + + def top(self): + if len(self.queue) == 0: + return self.init_tensor + return self.queue[0] + + def size(self): + return len(self.queue) + def tearDown(self): torch._library.fake_class_registry.deregister_fake_class( "fbgemm::AtomicCounter" ) + torch._library.fake_class_registry.deregister_fake_class("fbgemm::TensorQueue") super().tearDown() def _test_kjt_input_module( @@ -517,7 +544,7 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None: {}, strict=False, pre_dispatch=True, - ).run_decompositions() + ) ep.module()(kjt.values(), kjt.lengths()) @@ -556,7 +583,7 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None: {}, strict=False, pre_dispatch=True, - ).run_decompositions() + ) ep.module()(kjt.values(), kjt.lengths()) # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices) @@ -564,8 +591,9 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None: for n in ep.graph_module.graph.nodes: self.assertFalse("auto_functionalized" in str(n.name)) - # TODO: Fix Unflatten - # torch.export.unflatten(ep) + torch.export.unflatten(ep) + + ep(kjt.values(), kjt.lengths()) def test_maybe_compute_kjt_to_jt_dict(self) -> None: kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) diff --git a/torchrec/distributed/train_pipeline/tests/test_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py similarity index 96% rename from torchrec/distributed/train_pipeline/tests/test_utils.py rename to torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 26fc6178d..854423385 100644 --- a/torchrec/distributed/train_pipeline/tests/test_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -15,7 +15,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -class TestUtils(unittest.TestCase): +class TestTrainPipelineUtils(unittest.TestCase): def test_get_node_args_helper_call_module_kjt(self) -> None: graph = torch.fx.Graph() kjt_args = []