
EMA & IterBasedRunner #2195

Open
Divadi opened this issue Aug 13, 2022 · 2 comments
Divadi commented Aug 13, 2022

Describe the Issue
I believe that EMA, when used in conjunction with IterBasedRunner, does not properly swap the EMA and regular parameters before evaluation and checkpointing.
EMA weight swapping happens in after_train_epoch, which IterBasedRunner calls only once, at the end of training. However, evaluation and checkpointing are driven by after_train_iter in IterBasedRunner, so evaluation (and the default saving of model weights) runs on the regular model weights. This differs from EpochBasedRunner, which evaluates and saves the EMA weights.
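To make the timing problem concrete, here is a minimal, framework-free sketch of the two operations involved. The class and method names (EMATracker, update, swap) are illustrative, not mmcv's actual API; plain floats stand in for model tensors.

```python
class EMATracker:
    """Toy EMA tracker: ema = (1 - m) * ema + m * param each iteration."""

    def __init__(self, params, momentum=0.001):
        # params: dict of name -> float weight (stand-in for model tensors)
        self.momentum = momentum
        self.ema = dict(params)  # the EMA copy starts equal to the model

    def update(self, params):
        # Runs every training iteration (after_train_iter in mmcv terms).
        for k, v in params.items():
            self.ema[k] = (1 - self.momentum) * self.ema[k] + self.momentum * v

    def swap(self, params):
        # Exchange model and EMA weights in place. mmcv's EMAHook performs
        # this swap in after_train_epoch, so with IterBasedRunner it never
        # runs before the per-iteration EvalHook / checkpointing fires.
        for k in params:
            params[k], self.ema[k] = self.ema[k], params[k]
```

With IterBasedRunner, update() is called every iteration but swap() is never reached before evaluation, so the evaluator only ever sees the raw weights in params.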

While I'm at it, a quick question: what is the purpose of the resume_from argument in ema.py? It seems to call runner.resume, but don't pipelines like mmdetection call runner.resume anyway?
https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/apis/train.py#L240-L244

Reproduction
I just added custom_hooks = [dict(type='ExpMomentumEMAHook', momentum=0.001, priority=49)] to an IterBasedRunner pipeline.

Environment

  1. Please run python -c "from mmcv.utils import collect_env; print(collect_env())" to collect necessary environment information and paste it here.
    This is on mmcv 1.3.16, but I think(?) this issue persists into the latest version. I apologize if I missed a fix.
{'sys.platform': 'linux', 'Python': '3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0]', 'CUDA available': True, 'GPU 0,1,2,3,4,5,6,7': 'NVIDIA RTX A6000', 'CUDA_HOME': '/usr/local/cuda', 'NVCC': 'Build cuda_11.3.r11.3/compiler.29745058_0', 'GCC': 'gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0', 'PyTorch': '1.10.1', 'PyTorch compiling details': 'PyTorch built with:\n  - GCC 7.3\n  - C++ Version: 201402\n  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications\n  - Intel(R) MKL-DNN v2.2.3 (Git Hash 7336ca9f055cf1bfa13efb658fe15dc9b41f0740)\n  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n  - LAPACK is enabled (usually provided by MKL)\n  - NNPACK is enabled\n  - CPU capability usage: AVX2\n  - CUDA Runtime 11.3\n  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37\n  - CuDNN 8.2\n  - Magma 2.5.2\n  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.3, CUDNN_VERSION=8.2.0, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast 
-fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n', 'TorchVision': '0.11.2', 'OpenCV': '4.6.0', 'MMCV': '1.3.16', 'MMCV Compiler': 'GCC 7.3', 'MMCV CUDA Compiler': '11.3'}
@HAOCHENYE (Collaborator) commented:

  1. EMAHook is not compatible with IterBasedRunner for the following reasons:
  • Evaluation is controlled by EvalHook, and it will not trigger other hooks.
  • IterBasedRunner does not trigger after_train_epoch the way EpochBasedRunner does, so EMAHook does not know when to swap parameters.
  2. We save the EMA weights in registered buffers of the model, and those buffers are created in EMAHook.before_run; therefore we have to resume the EMA weights again in EMAHook.before_run, after the buffers have been created.
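The ordering constraint in point 2 can be illustrated without mmcv: a state dict can only be loaded into keys that already exist, so a checkpoint containing EMA buffers cannot be restored before the hook has registered those buffers. The names below (FakeModel, register_ema_buffers, the ema_ prefix) are illustrative assumptions, not mmcv's actual implementation.

```python
class FakeModel:
    """Toy model whose 'state_dict' is a plain dict of floats."""

    def __init__(self):
        self.state = {'w': 0.0}  # regular parameter

    def register_ema_buffers(self):
        # Stand-in for EMAHook.before_run: create one EMA buffer
        # per parameter, initialized from the current value.
        for k in list(self.state):
            self.state['ema_' + k] = self.state[k]

    def load_state_dict(self, ckpt):
        # Loading rejects keys that were never registered, which is
        # why the EMA buffers must exist before resuming a checkpoint.
        unexpected = [k for k in ckpt if k not in self.state]
        if unexpected:
            raise KeyError(f'unexpected keys: {unexpected}')
        self.state.update(ckpt)
```

Loading a checkpoint that contains ema_w before register_ema_buffers() runs fails with a KeyError; after registration the same checkpoint loads cleanly, which mirrors why the hook resumes the EMA weights only after creating the buffers.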

@HAOCHENYE (Collaborator) commented:

@makecent Hi, thanks for your suggestions. Since there may be hooks after EvalHook that require the model to remain in eval state, we do not call model.train in EvalHook.

I strongly agree with you that exceptions or warnings should be raised; the current behavior is not friendly enough.

BTW, we have recently upgraded our training architecture and released a new repo, MMEngine. You can solve this problem by using EMAHook in MMEngine (MMEngine has richer hook entry points, so this issue can be solved easily). Currently, MMEngine is in public beta, and some of the OpenMMLab repos (including mmdet, still in progress) have been refactored based on MMEngine. We welcome you to try MMEngine; your valuable comments will help us improve it.

Last but not least, please feel free to keep using MMCV; we will still maintain its master branch. All beneficial PRs will be merged into both MMCV and MMEngine.
