diff --git a/opacus/tests/validators/gru_test.py b/opacus/tests/validators/gru_test.py new file mode 100644 index 00000000..00f0d259 --- /dev/null +++ b/opacus/tests/validators/gru_test.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch.nn as nn +from opacus.layers import DPGRU +from opacus.utils.module_utils import are_state_dict_equal +from opacus.validators.errors import ShouldReplaceModuleError +from opacus.validators.module_validator import ModuleValidator + + +class GRUValidator_test(unittest.TestCase): + def setUp(self) -> None: + self.gru = nn.GRU(8, 4) + self.mv = ModuleValidator.VALIDATORS + self.mf = ModuleValidator.FIXERS + + def test_validate(self) -> None: + val_gru = self.mv[type(self.gru)](self.gru) + self.assertEqual(len(val_gru), 1) + self.assertTrue(isinstance(val_gru[0], ShouldReplaceModuleError)) + + def test_fix(self) -> None: + fix_gru = self.mf[type(self.gru)](self.gru) + self.assertTrue(isinstance(fix_gru, DPGRU)) + self.assertTrue( + are_state_dict_equal(self.gru.state_dict(), fix_gru.state_dict()) + ) diff --git a/opacus/validators/gru.py b/opacus/validators/gru.py new file mode 100644 index 00000000..53005a0f --- /dev/null +++ b/opacus/validators/gru.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch.nn as nn +from opacus.layers import DPGRU + +from .errors import ShouldReplaceModuleError, UnsupportedModuleError +from .utils import register_module_fixer, register_module_validator + + +@register_module_validator(nn.GRU) +def validate(module: nn.GRU) -> List[UnsupportedModuleError]: + return [ + ShouldReplaceModuleError( + "We do not support nn.GRU because its implementation uses special " + "modules. We have written a GRU class that is a drop-in replacement " + "which is compatible with our Grad Sample hooks. Please run the recommended " + "replacement!" + ) + ] + + +@register_module_fixer(nn.GRU) +def fix(module: nn.GRU) -> DPGRU: + dpgru = DPGRU( + input_size=module.input_size, + hidden_size=module.hidden_size, + num_layers=module.num_layers, + bias=module.bias, + batch_first=module.batch_first, + dropout=module.dropout, + bidirectional=module.bidirectional, + proj_size=module.proj_size, + ) + dpgru.load_state_dict(module.state_dict()) + return dpgru