Skip to content

Commit

Permalink
add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
C1rN09 committed Dec 13, 2022
1 parent 10e7cd6 commit 1228931
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions tests/test_optim/test_optimizer/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,29 +735,18 @@ def _check_default_optimizer(self, optimizer, model):
self.assertEqual(optimizer.defaults['lr'], self.base_lr)
self.assertEqual(optimizer.defaults['momentum'], self.momentum)
self.assertEqual(optimizer.defaults['weight_decay'], self.base_wd)
param_groups = optimizer.param_groups[0]
if MMCV_FULL_AVAILABLE:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias', 'dcn.weight',
'dcn.conv_offset.weight', 'dcn.conv_offset.bias'
]
else:
param_names = [
'param1', 'conv1.weight', 'conv2.weight', 'conv2.bias',
'bn.weight', 'bn.bias', 'sub.param1', 'sub.conv1.weight',
'sub.conv1.bias', 'sub.gn.weight', 'sub.gn.bias'
]
param_dict = dict(model.named_parameters())
self.assertEqual(len(param_groups['params']), len(param_names))
for i in range(len(param_groups['params'])):
assert torch.equal(param_groups['params'][i],
param_dict[param_names[i]])
param_groups = optimizer.param_groups
params_set = set(model.parameters())
self.assertEqual(
sum(len(param_group['params']) for param_group in param_groups),
len(params_set))
self.assertTrue(
all(param in params_set for param_group in param_groups
for param in param_group['params']))

def test_build_zero_redundancy_optimizer(self):
from torch.distributed.optim import ZeroRedundancyOptimizer
self._init_dist_env(self.rank, self.world_size)
from mmengine.optim.optimizer import ZeroRedundancyOptimizer
model = ExampleModel()
self.base_lr = 0.01
self.momentum = 0.0001
Expand Down Expand Up @@ -785,6 +774,31 @@ def test_build_zero_redundancy_optimizer(self):
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)

@unittest.skipIf(
    digit_version(TORCH_VERSION) < digit_version('1.12.0'),
    reason='ZeRO started to support param groups since pytorch 1.12.0')
def test_build_zero_redundancy_optimizer_with_paramwise_cfg(self):
    """Build a ZeRO-wrapped SGD optimizer with a paramwise config.

    Verifies that ``build_optim_wrapper`` accepts ``paramwise_cfg``
    together with ``ZeroRedundancyOptimizer`` (supported since
    PyTorch 1.12.0) and that the resulting optimizer covers the
    model's parameters with the expected defaults.
    """
    self._init_dist_env(self.rank, self.world_size)
    # Imported lazily: requires an initialized distributed environment.
    from mmengine.optim.optimizer import ZeroRedundancyOptimizer
    model = ExampleModel()
    self.base_lr = 0.01
    self.momentum = 0.0001
    self.base_wd = 0.9

    # Build through the factory with a per-parameter lr multiplier
    # applied to ``conv1``.
    optimizer_cfg = dict(
        type='ZeroRedundancyOptimizer',
        optimizer_type='SGD',
        lr=self.base_lr,
        weight_decay=self.base_wd,
        momentum=self.momentum)
    optim_wrapper_cfg = dict(
        optimizer=optimizer_cfg,
        paramwise_cfg=dict(conv1_lr_mult=1))
    optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
    self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
    self._check_default_optimizer(optim_wrapper.optimizer, model)

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
os.environ['MASTER_ADDR'] = '127.0.0.1'
Expand Down

0 comments on commit 1228931

Please sign in to comment.