diff --git a/torchrec/modules/tests/test_feature_processor_.py b/torchrec/modules/tests/test_feature_processor_.py
index 150b35bed..036d28822 100644
--- a/torchrec/modules/tests/test_feature_processor_.py
+++ b/torchrec/modules/tests/test_feature_processor_.py
@@ -61,23 +61,6 @@ def test_populate_weights(self) -> None:
             weighted_features.lengths(), weighted_features_gm_script.lengths()
         )
 
-    # TODO: this test is not being run
-    # pyre-ignore
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "Not enough GPUs, this test requires at least one GPU",
-    )
-    def test_rematerialize_from_meta(self) -> None:
-        pw = PositionWeightedModule(max_feature_length=10, device=torch.device("meta"))
-        self.assertTrue(pw.position_weight.is_meta)
-
-        # Re-materialize on cuda
-        init_parameters(pw, torch.device("cuda"))
-        self.assertTrue(not pw.position_weight.is_meta)
-        torch.testing.assert_close(
-            pw.position_weight, torch.ones_like(pw.position_weight)
-        )
-
 
 class PositionWeightedCollectionModuleTest(unittest.TestCase):
     def test_populate_weights(self) -> None:
@@ -133,26 +116,6 @@ def test_populate_weights(self) -> None:
             empty_fp_kjt.length_per_key(), empty_fp_kjt_gm_script.length_per_key()
         )
 
-    # TODO: this test is not being run
-    # pyre-ignore
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "Not enough GPUs, this test requires at least one GPU",
-    )
-    def test_rematerialize_from_meta(self) -> None:
-        pwmc = PositionWeightedModuleCollection(
-            max_feature_lengths={"f1": 10, "f2": 10},
-            device=torch.device("meta"),
-        )
-        self.assertTrue(all(param.is_meta for param in pwmc.position_weights.values()))
-
-        # Re-materialize on cuda
-        init_parameters(pwmc, torch.device("cuda"))
-        for key, param in pwmc.position_weights.items():
-            self.assertTrue(not param.is_meta)
-            self.assertTrue(pwmc.position_weights_dict[key] is param)
-            torch.testing.assert_close(param, torch.ones_like(param))
-
     # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 0,