[Bug] ZeroRedundancyOptimizer cannot support param-wise settings with torch.__version__ < 1.12.0 #778

Closed · 2 tasks done
nijkah opened this issue Nov 30, 2022 · 3 comments · Fixed by #818
Labels: bug (Something isn't working)
@nijkah (Contributor) commented Nov 30, 2022

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.7.2 (default, May 11 2021, 10:20:27) [GCC 7.5.0]'), ('CUDA available', True), ('numpy_random_seed', 2147483648), ('GPU 0', 'NVIDIA A100-SXM-80GB MIG 2g.20gb'), ('CUDA_HOME', '/usr/local/cuda'), ('NVCC', 'Cuda compilation tools, release 11.1, V11.1.105'), ('GCC', 'gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0'), ('PyTorch', '1.9.1+cu111'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 7.3\n - C++ Version: 201402\n - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.1\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86\n - CuDNN 8.0.5\n - Magma 2.5.2\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n'), ('TorchVision', '0.10.1+cu111'), ('OpenCV', '4.5.1'), ('MMEngine', '0.3.1')])

Reproduces the problem - code sample

configs/swin/mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py

# configs/swin/mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py
# the optimizer config is edited as follows
optim_wrapper = dict(
    type='OptimWrapper',
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }),
    optimizer=dict(
        _delete_=True,
        type='ZeroRedundancyOptimizer',
        optimizer_type='AdamW',
        lr=0.0001,
        betas=(0.9, 0.999),
        weight_decay=0.05))

Reproduces the problem - command or script

bash tools/dist_train.sh configs/swin/mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py 1 

Reproduces the problem - error message

11/30 10:45:04 - mmengine - INFO - paramwise_options -- roi_head.mask_head.conv_logits.bias:lr=0.0001
11/30 10:45:04 - mmengine - INFO - paramwise_options -- roi_head.mask_head.conv_logits.bias:weight_decay=0.05
Traceback (most recent call last):
  File "/nas/k8s/dev/mlops/hakjinlee/workspace/sandbox/mmengine/mmengine/registry/build_functions.py", line 121, in build_from_cfg
    obj = obj_cls(**args)  # type: ignore
  File "/nas/k8s/dev/mlops/hakjinlee/workspace/sandbox/mmengine/mmengine/optim/optimizer/zero_optimizer.py", line 60, in __init__
    super().__init__(params, optimizer_class, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 166, in __init__
    self._reference_is_trainable_mask = list(map(_is_trainable, self._all_params))
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py", line 49, in _is_trainable
    return param.requires_grad
AttributeError: 'dict' object has no attribute 'requires_grad'

Additional information

Related to #716

Some configurations, such as mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py, require paramwise_cfg to set different hyperparameters for certain modules.

However, torch.distributed.optim.ZeroRedundancyOptimizer only recently started supporting multiple param groups (Link); the feature is only available when torch.__version__ >= 1.12.0.

The key implementation in torch >= 1.12.0 is (a rough sketch follows the list):

  • Save parameters as List[torch.Tensor] in ZeroRedundancyOptimizer._all_params
  • Give the original parameters to the inner optimizer (optimizer_class) via Optimizer.__init__(self, params, defaults)

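For reference, a minimal sketch of what these two points amount to, assuming illustrative helper names (this is not the actual torch source):

# Rough sketch of the torch >= 1.12 behaviour described above.
# `_collect_all_params` is an illustrative name, not an upstream function.
from typing import Any, Dict, Iterable, List, Union

import torch

def _collect_all_params(
    params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
) -> List[torch.Tensor]:
    """Flatten a plain iterable of tensors or a list of param-group dicts
    (as produced by paramwise_cfg) into the flat tensor list that would be
    stored in ZeroRedundancyOptimizer._all_params."""
    params = list(params)
    if params and isinstance(params[0], dict):
        all_params: List[torch.Tensor] = []
        for group in params:
            all_params.extend(group['params'])
        return all_params
    return params

# The original params (tensors or param-group dicts) is still forwarded
# unchanged to the inner optimizer via Optimizer.__init__(self, params,
# defaults), so per-group settings such as decay_mult survive.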
We can copy-and-paste its __init__ logic to support param-wise settings when torch.__version__ < 1.12.0.
However, I think this kind of solution is quite messy and vulnerable to changes in torch.distributed.optim.ZeroRedundancyOptimizer.

I kindly ask for any good ideas to handle this!

@nijkah added the bug label Nov 30, 2022
@C1rN09 (Collaborator) commented Nov 30, 2022

I'll take a look ✋

@C1rN09 (Collaborator) commented Dec 2, 2022

I got it. So this feature requires an intrusive change in ZeroRedundancyOptimizer.

We can copy-and-paste its __init__ logic to support param-wise settings when torch.__version__ < 1.12.0. However, I think this kind of solution is quite messy and vulnerable to changes in torch.distributed.optim.ZeroRedundancyOptimizer.

I quite agree with your opinion! We cannot keep copy-and-pasting forever. As an alternative, what about doing some checks (e.g. all_dict) in our __init__ and then raising an assertion with more instructive information? We can explain the reason and suggest upgrading to torch>=1.12 😸 A rough sketch of such a check is below.
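For illustration, a minimal sketch of such a check, assuming it would live in MMEngine's ZeroRedundancyOptimizer.__init__ and use mmengine.utils.digit_version for the version comparison (the helper name and message wording are placeholders, not the final implementation):

# Hypothetical helper; the actual check added to MMEngine may differ.
import torch
from mmengine.utils import digit_version

def _assert_param_groups_supported(params):
    """Raise an instructive error when param groups are passed to
    ZeroRedundancyOptimizer on torch < 1.12.0."""
    params = list(params)
    has_param_groups = any(isinstance(p, dict) for p in params)
    if (has_param_groups
            and digit_version(torch.__version__) < digit_version('1.12.0')):
        raise AssertionError(
            'Param-wise settings (multiple param groups, e.g. from '
            'paramwise_cfg) with ZeroRedundancyOptimizer require '
            f'torch >= 1.12.0, but got torch {torch.__version__}. '
            'Please upgrade PyTorch or remove paramwise_cfg.')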

@C1rN09 (Collaborator) commented Dec 12, 2022

Hi, @nijkah ! Can you take a look? #818
