From 1b61bbad327d2bf32502b3b9a770b57714cc43dc Mon Sep 17 00:00:00 2001
From: Vimal Manohar
Date: Thu, 27 Jan 2022 13:01:49 -0800
Subject: [PATCH] Fix broken EMA in fairseq

Summary: EMA broken since D33649708 (https://github.com/pytorch/fairseq/commit/995c204337d16a6146a433cee360e5a5bfbc9a6f) due to indentation error.

Reviewed By: cruvadom

Differential Revision: D33809223

fbshipit-source-id: c6c4d0d327443bfea787817040e1832eef0f50e4
---
 fairseq/models/ema/ema.py | 10 +++----
 tests/test_ema.py         | 61 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/fairseq/models/ema/ema.py b/fairseq/models/ema/ema.py
index bc966a9aed..dcc19d85cf 100644
--- a/fairseq/models/ema/ema.py
+++ b/fairseq/models/ema/ema.py
@@ -185,11 +185,11 @@ def step(self, new_model, updates=None):
             self._set_decay(
                 0 if updates < self.config.ema_start_update else self.config.ema_decay
             )
-            if updates is not None and self.config.ema_update_freq > 1:
-                self.update_freq_counter += 1
-                if self.update_freq_counter >= self.config.ema_update_freq:
-                    self._step_internal(new_model, updates)
-                    self.update_freq_counter = 0
+        if self.config.ema_update_freq > 1:
+            self.update_freq_counter += 1
+            if self.update_freq_counter >= self.config.ema_update_freq:
+                self._step_internal(new_model, updates)
+                self.update_freq_counter = 0
         else:
             self._step_internal(new_model, updates)

diff --git a/tests/test_ema.py b/tests/test_ema.py
index e6f10ce9c2..1f8f71b9b6 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import unittest
+from unittest.mock import patch
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Optional
@@ -36,9 +37,10 @@ class EMAConfig(object):
     ema_start_update: int = 0
     ema_fp32: bool = False
     ema_seed_model: Optional[str] = None
+    ema_update_freq: int = 1


-class TestEMAGPU(unittest.TestCase):
+class TestEMA(unittest.TestCase):
     def assertTorchAllClose(self, x, y, atol=1e-8, rtol=1e-5, msg=None):
         diff = x.float() - y.float()
         diff_norm = torch.norm(diff)
@@ -104,6 +106,63 @@ def test_ema(self):
             ema_param = ema_state_dict[key]
             self.assertTrue(torch.allclose(ema_param, param))

+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model)
+            mock_method.assert_called_once_with(model, None)
+
+    def _test_ema_start_update(self, updates):
+        model = DummyModule()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        state = deepcopy(model.state_dict())
+        config = EMAConfig(ema_start_update=1)
+        ema = EMA(model, config)
+
+        # EMA step
+        x = torch.randn(32)
+        y = model(x)
+        loss = y.sum()
+        loss.backward()
+        optimizer.step()
+
+        ema.step(model, updates=updates)
+        ema_state_dict = ema.get_model().state_dict()
+
+        self.assertEqual(ema.get_decay(), 0 if updates == 0 else config.ema_decay)
+
+        for key, param in model.state_dict().items():
+            ema_param = ema_state_dict[key]
+            prev_param = state[key]
+
+            if "version" in key:
+                # Do not decay a model.version pytorch param
+                continue
+            if updates == 0:
+                self.assertTorchAllClose(
+                    ema_param,
+                    param,
+                )
+            else:
+                self.assertTorchAllClose(
+                    ema_param,
+                    config.ema_decay * prev_param + (1 - config.ema_decay) * param,
+                )
+
+        # Check that step_internal is called once
+        with patch.object(
+            ema, "_step_internal", return_value=None
+        ) as mock_method:
+            ema.step(model, updates=updates)
+            mock_method.assert_called_once_with(model, updates)
+
+    def test_ema_before_start_update(self):
+        self._test_ema_start_update(updates=0)
+
+    def test_ema_after_start_update(self):
+        self._test_ema_start_update(updates=1)
+
     def test_ema_fp32(self):
         model = DummyModule().half()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
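
For context on the fix above: with the old guard and indentation, the update-frequency branch sat inside the decay-setting block, so with the default ema_update_freq of 1 neither branch reached _step_internal and the EMA weights were never updated. The sketch below is a minimal, self-contained illustration of the corrected gating logic only; the _Config and _FreqGatedEMA names are hypothetical stand-ins, not fairseq's EMA class or API.

# Hypothetical, simplified sketch of the fixed update-frequency gating in
# EMA.step (illustrative names, not fairseq's actual implementation).
from dataclasses import dataclass


@dataclass
class _Config:
    ema_update_freq: int = 1  # assumed default, mirroring EMAConfig in the test


class _FreqGatedEMA:
    def __init__(self, config):
        self.config = config
        self.update_freq_counter = 0
        self.calls = 0  # counts _step_internal invocations for illustration

    def _step_internal(self, new_model, updates=None):
        self.calls += 1

    def step(self, new_model, updates=None):
        # Fixed logic from the patch: when ema_update_freq == 1 the else
        # branch always runs, so every step() updates the EMA weights;
        # when ema_update_freq > 1 the counter gates the update.
        if self.config.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.config.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            self._step_internal(new_model, updates)


if __name__ == "__main__":
    ema = _FreqGatedEMA(_Config(ema_update_freq=1))
    for step in range(4):
        ema.step(new_model=None, updates=step)
    assert ema.calls == 4  # updates every step with the default freq of 1

    ema = _FreqGatedEMA(_Config(ema_update_freq=2))
    for step in range(4):
        ema.step(new_model=None, updates=step)
    assert ema.calls == 2  # updates every other step when freq == 2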