diff --git a/torchrec/optim/tests/test_keyed.py b/torchrec/optim/tests/test_keyed.py index 926035dae..605858b38 100644 --- a/torchrec/optim/tests/test_keyed.py +++ b/torchrec/optim/tests/test_keyed.py @@ -2,7 +2,7 @@ import os import unittest -from typing import Dict, Any +from typing import Dict, Any, List import torch import torch.distributed as dist @@ -74,34 +74,36 @@ def test_load_state_dict(self) -> None: ) # Assert state_dict is as expected. - expected_state_dict = { - "state": { - "param_1": { - "one": 1.0, - "tensor": torch.tensor([5.0, 6.0]), - "sharded_tensor": dist._sharded_tensor.full( - # pyre-ignore [28] - dist._sharded_tensor.ChunkShardingSpec( - dim=0, placements=["rank:0/cpu"] - ), - (4,), - fill_value=1.0, + state: Dict[str, Any] = { + "param_1": { + "one": 1.0, + "tensor": torch.tensor([5.0, 6.0]), + "sharded_tensor": dist._sharded_tensor.full( + # pyre-ignore [28] + dist._sharded_tensor.ChunkShardingSpec( + dim=0, placements=["rank:0/cpu"] ), - }, - "param_2": {"two": 2.0}, + (4,), + fill_value=1.0, + ), }, - "param_groups": [ - { - "params": ["param_1"], - "param_group_val_0": 3.0, - "param_group_val_1": 4.0, - }, - { - "params": ["param_2"], - "param_group_val_0": 5.0, - "param_group_val_1": 6.0, - }, - ], + "param_2": {"two": 2.0}, + } + param_groups: List[Dict[str, Any]] = [ + { + "params": ["param_1"], + "param_group_val_0": 3.0, + "param_group_val_1": 4.0, + }, + { + "params": ["param_2"], + "param_group_val_0": 5.0, + "param_group_val_1": 6.0, + }, + ] + expected_state_dict = { + "state": state, + "param_groups": param_groups, } self._assert_state_dict_equals( expected_state_dict, keyed_optimizer.state_dict()