Skip to content

Commit

Permalink
fix UT
Browse files Browse the repository at this point in the history
  • Loading branch information
nijkah committed Sep 27, 2022
1 parent a408f69 commit dd64538
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions tests/test_optim/test_optimizer/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.nn as nn

from mmengine.model import MMDistributedDataParallel
from mmengine.optim import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
DefaultOptimWrapperConstructor, OptimWrapper,
build_optim_wrapper)
Expand Down Expand Up @@ -759,6 +760,7 @@ def _check_default_optimizer(self, optimizer, model, prefix=''):
def test_build_zero_redundancy_optimizer(self):
self._init_dist_env(self.rank, self.world_size)
model = ExampleModel()
model = MMDistributedDataParallel(module=model)
self.base_lr = 0.01
self.momentum = 0.0001
self.base_wd = 0.9
Expand All @@ -773,17 +775,8 @@ def test_build_zero_redundancy_optimizer(self):
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper.optimizer, ZeroRedundancyOptimizer)
self._check_default_optimizer(optim_wrapper.optimizer, model)

# test build optimizer without ``optimizer_type``
with self.assertRaises(TypeError):
optim_wrapper_cfg = dict(
optimizer=dict(
type='ZeroRedundancyOptimizer',
lr=self.base_lr,
weight_decay=self.base_wd,
momentum=self.momentum))
optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg)
self._check_default_optimizer(
optim_wrapper.optimizer, model, prefix='module.')

def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
Expand Down

0 comments on commit dd64538

Please sign in to comment.