diff --git a/torchft/optim.py b/torchft/optim.py
index ce24823e..0583d94b 100644
--- a/torchft/optim.py
+++ b/torchft/optim.py
@@ -12,8 +12,9 @@
 """
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
 
+import torch
 from torch.optim import Optimizer
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -52,3 +53,11 @@ def step(self, closure: Optional[object] = None) -> None:
         assert closure is None, "optimizers that use closures are not supported"
         if self.manager.should_commit():
             self.optim.step()
+
+    @property
+    def param_groups(self) -> List[Dict[str, Any]]:
+        return self.optim.param_groups
+
+    @property
+    def state(self) -> Mapping[torch.Tensor, Any]:  # pyre-fixme[3]
+        return self.optim.state
diff --git a/torchft/optim_test.py b/torchft/optim_test.py
index 50412d85..5dd69640 100644
--- a/torchft/optim_test.py
+++ b/torchft/optim_test.py
@@ -7,6 +7,7 @@
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec
 
+import torch
 from torch.nn import Linear
 from torch.optim import AdamW
 
@@ -34,9 +35,16 @@ def test_optimizer_wrapper(self) -> None:
         optim.zero_grad()
         self.assertEqual(manager.start_quorum.call_count, 1)
 
+        b = torch.rand(3)
+        m(b).sum().backward()
+
         manager.should_commit.return_value = True
         optim.step()
         manager.should_commit.return_value = False
         optim.step()
+        self.assertEqual(len(optim.param_groups), 2)
+        self.assertEqual(optim.param_groups[1]["lr"], 1e-4)
+        self.assertEqual(optim.param_groups[1]["params"], [])
+        self.assertEqual(len(optim.state), len(list(m.parameters())))
 
         self.assertEqual(manager.should_commit.call_count, 2)
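
For context on why the two pass-through properties are useful: code that introspects a plain `torch.optim.Optimizer` (LR schedulers, checkpointing, and logging hooks) reads `param_groups` and `state` directly off the optimizer object, so exposing them on the wrapper lets that code operate on the wrapped optimizer without unwrapping it. The sketch below illustrates the intended access pattern; the class name `OptimizerWrapper` and the `OptimizerWrapper(manager, optimizer)` constructor order are assumptions not shown in this diff, and the `Manager` is mocked the same way the test does.

```python
# Minimal usage sketch (not part of the diff). Assumes the wrapper class is
# torchft.optim.OptimizerWrapper, constructed as OptimizerWrapper(manager, optimizer).
from unittest.mock import create_autospec

from torch.nn import Linear
from torch.optim import AdamW

from torchft.manager import Manager
from torchft.optim import OptimizerWrapper

m = Linear(3, 4)
manager = create_autospec(Manager)  # stand-in for a real torchft Manager
optim = OptimizerWrapper(manager, AdamW(m.parameters(), lr=1e-3))

# Scheduler- and logging-style code reads `param_groups` directly off the
# optimizer; the pass-through property makes the wrapper look like the
# underlying Optimizer for this purpose.
for group in optim.param_groups:
    group["lr"] *= 0.1  # mutates the wrapped optimizer's groups in place

print(optim.param_groups[0]["lr"])  # now ~1e-4
print(len(optim.state))  # 0 until the first committed step() populates state
```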