diff --git a/examples/retrieval/two_tower_train.py b/examples/retrieval/two_tower_train.py
index 1aee84cfe..141c1ca7d 100644
--- a/examples/retrieval/two_tower_train.py
+++ b/examples/retrieval/two_tower_train.py
@@ -158,7 +158,11 @@ def train(
             compute_device=device.type,
         ),
     ).collective_plan(
-        module=two_tower_model, sharders=sharders, pg=dist.GroupMember.WORLD
+        module=two_tower_model,
+        sharders=sharders,
+        # pyre-fixme[6]: For 3rd param expected `ProcessGroup` but got
+        # `Optional[ProcessGroup]`.
+        pg=dist.GroupMember.WORLD,
     )
     model = DistributedModelParallel(
         module=two_tower_train_task,
diff --git a/torchrec/distributed/tests/test_model_parallel.py b/torchrec/distributed/tests/test_model_parallel.py
index 2380c5eb7..4d9462856 100644
--- a/torchrec/distributed/tests/test_model_parallel.py
+++ b/torchrec/distributed/tests/test_model_parallel.py
@@ -732,6 +732,8 @@ def reset_parameters(self) -> None:
     )

     def test_meta_device_dmp_state_dict(self) -> None:
+        # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+        # `Optional[ProcessGroup]`.
         env = ShardingEnv.from_process_group(dist.GroupMember.WORLD)
         m1 = TestSparseNN(
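
Not part of the diff above, just a reviewer-style sketch for context: the pyre-fixme[6] suppressions are needed because `dist.GroupMember.WORLD` is typed `Optional[ProcessGroup]`. An alternative (assuming the torchrec import path shown) is to narrow the Optional with an assert before use, so pyre accepts the call without a suppression comment:

    import torch.distributed as dist
    from torchrec.distributed.types import ShardingEnv  # assumed import path

    pg = dist.GroupMember.WORLD
    # Narrow Optional[ProcessGroup] to ProcessGroup so no pyre-fixme[6] is needed.
    assert pg is not None, "default process group is not initialized"
    env = ShardingEnv.from_process_group(pg)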