RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. #2117

Closed
tankche1 opened this issue Feb 19, 2020 · 2 comments

tankche1 commented Feb 19, 2020

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

Describe the bug
A clear and concise description of what the bug is.

Reproduction

  1. What command or script did you run?
./tools/dist_train.sh configs/pascal_voc/fast_rcnn_r50_fpn_1x_voc0712.py 4
  2. Did you make any modifications to the code or config? Did you understand what you have modified?

Yes. I load Selective Search bounding boxes as proposals and changed num_classes from 81 to 21 (from COCO to Pascal VOC).
Here is the config:

# model settings
model = dict(
    type='FastRCNN',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=21,  # changed from 81 (COCO)
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            type='RandomSampler',
            num=512,
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
# dataset settings
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadProposals', num_max_proposals=2000),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadProposals', num_max_proposals=2000),  # modified
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img', 'proposals']),
        ])
]
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=4,  # modified
        dataset=dict(
            type=dataset_type,
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt',
                data_root + 'VOC2012/ImageSets/Main/trainval.txt'
            ],
            proposal_file=[
                data_root + 'voc_2007_trainval.pkl',
                data_root + 'voc_2012_trainval.pkl'
            ],
            img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'],
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        proposal_file=data_root + 'voc_2007_test.pkl',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        proposal_file=data_root + 'voc_2007_test.pkl',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(policy='step', step=[3])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 6  # modified
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/fast_rcnn_r50_fpn_1x_voc0712'
load_from = None
resume_from = None
workflow = [('train', 1)]

  3. What dataset did you use?
    Pascal VOC 07+12
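For reference, the change from 81 to 21 follows the MMDetection 1.x convention that a softmax bbox head's num_classes also counts the background class: COCO has 80 object categories and Pascal VOC has 20. The tiny check below spells out that arithmetic; the names are illustrative only.

```python
# The 20 Pascal VOC object categories.
VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

NUM_COCO_CLASSES = 80  # COCO detection has 80 object categories


def softmax_head_num_classes(num_object_classes):
    """num_classes for a softmax bbox head that also predicts background.

    Assumes the MMDetection 1.x convention (background counted in
    num_classes); 2.x configs count object classes only.
    """
    return num_object_classes + 1


print(softmax_head_num_classes(NUM_COCO_CLASSES))  # 81
print(softmax_head_num_classes(len(VOC_CLASSES)))  # 21
```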

Environment

  1. Please run python mmdet/utils/collect_env.py to collect the necessary environment information and paste it here.

sys.platform: linux
Python: 3.7.6 (default, Jan 8 2020, 19:59:22) [GCC 7.3.0]
CUDA available: True
CUDA_HOME: /cm/shared/apps/cuda100/10.0.130
NVCC: Cuda compilation tools, release 10.0, V10.0.130
GPU 0,1,2,3: GeForce GTX TITAN X
GCC: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-36)
PyTorch: 1.1.0
PyTorch compiling details: PyTorch built with:

  • GCC 4.9
  • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v0.18.1 (Git Hash 7de7e5d02bf687f971e7668963649728356e0c20)
  • OpenMP 201307 (a.k.a. OpenMP 4.0)
  • NNPACK is enabled
  • CUDA Runtime 10.0
  • NVCC architecture flags: -gencode;arch=compute_35,code=sm_35;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_50,code=compute_50
  • CuDNN 7.5.1
  • Magma 2.5.0
  • Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math, DISABLE_NUMA=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, USE_CUDA=True, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=True, USE_NNPACK=True, USE_OPENMP=ON,

TorchVision: 0.3.0
OpenCV: 4.2.0
MMCV: 0.3.1
MMDetection: 1.0.0+2afa063
MMDetection Compiler: GCC 6.1
MMDetection CUDA Compiler: 10.0

  2. You may add additional information that may be helpful for locating the problem, such as
    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback
If applicable, paste the error traceback here.
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing its output (the return value of forward). You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel. If you already have this argument set, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable). (prepare_for_backward at /opt/conda/conda-bld/pytorch_1556653114079/work/torch/csrc/distributed/c10d/reducer.cpp:408)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x45 (0x2aaaf74b1dc5 in /home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: c10d::Reducer::prepare_for_backward(std::vector<torch::autograd::Variable, std::allocator<torch::autograd::Variable> > const&) + 0x5ff (0x2aaac82fabbf in /home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #2: + 0x6cb6c8 (0x2aaac82f06c8 in /home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #3: + 0x12d07a (0x2aaac7d5207a in /home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: _PyMethodDef_RawFastCallKeywords + 0x264 (0x5555556b5114 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #5: _PyCFunction_FastCallKeywords + 0x21 (0x5555556b5231 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #6: _PyEval_EvalFrameDefault + 0x52cf (0x555555719e8f in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #7: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #8: _PyFunction_FastCallDict + 0x400 (0x55555566fa30 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #9: _PyObject_Call_Prepend + 0x63 (0x55555568a943 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #10: PyObject_Call + 0x6e (0x55555567db9e in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #11: _PyEval_EvalFrameDefault + 0x1e35 (0x5555557169f5 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #12: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #13: _PyFunction_FastCallDict + 0x400 (0x55555566fa30 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #14: _PyObject_Call_Prepend + 0x63 (0x55555568a943 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #15: + 0x17512a (0x5555556c912a in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #16: PyObject_Call + 0x6e (0x55555567db9e in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #17: _PyEval_EvalFrameDefault + 0x1e35 (0x5555557169f5 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #18: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #19: _PyFunction_FastCallDict + 0x400 (0x55555566fa30 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #20: _PyEval_EvalFrameDefault + 0x1e35 (0x5555557169f5 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #21: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #22: _PyFunction_FastCallDict + 0x1d5 (0x55555566f805 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #23: _PyObject_Call_Prepend + 0x63 (0x55555568a943 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #24: PyObject_Call + 0x6e (0x55555567db9e in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #25: _PyEval_EvalFrameDefault + 0x1e35 (0x5555557169f5 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #26: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #27: _PyFunction_FastCallKeywords + 0x387 (0x5555556b4917 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #28: _PyEval_EvalFrameDefault + 0x6a0 (0x555555715260 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #29: _PyEval_EvalCodeWithName + 0xc30 (0x55555566f030 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #30: _PyFunction_FastCallKeywords + 0x387 (0x5555556b4917 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x14e6 (0x5555557160a6 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #32: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #33: _PyFunction_FastCallKeywords + 0x387 (0x5555556b4917 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #34: _PyEval_EvalFrameDefault + 0x14e6 (0x5555557160a6 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #35: _PyFunction_FastCallKeywords + 0xfb (0x5555556b468b in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #36: _PyEval_EvalFrameDefault + 0x416 (0x555555714fd6 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #37: _PyEval_EvalCodeWithName + 0x2f9 (0x55555566e6f9 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #38: PyEval_EvalCodeEx + 0x44 (0x55555566f5f4 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #39: PyEval_EvalCode + 0x1c (0x55555566f61c in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #40: + 0x21c974 (0x555555770974 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #41: PyRun_FileExFlags + 0xa1 (0x55555577acf1 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #42: PyRun_SimpleFileExFlags + 0x1c3 (0x55555577aee3 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #43: + 0x227f95 (0x55555577bf95 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #44: _Py_UnixMain + 0x3c (0x55555577c0bc in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)
frame #45: __libc_start_main + 0xf5 (0x2aaaaaf0d3d5 in /lib64/libc.so.6)
frame #46: + 0x1d0990 (0x555555724990 in /home/zitianchen/anaconda2/envs/open-mmlab/bin/python)

Traceback (most recent call last):
  File "./tools/train.py", line 141, in <module>
    main()
  File "./tools/train.py", line 137, in main
    meta=meta)
  File "/home/zitianchen/code/mmdetection/mmdet/apis/train.py", line 102, in train_detector
    meta=meta)
  File "/home/zitianchen/code/mmdetection/mmdet/apis/train.py", line 251, in _dist_train
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/runner.py", line 371, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/runner.py", line 275, in train
    self.model, data_batch, train_mode=True, **kwargs)
  File "/home/zitianchen/code/mmdetection/mmdet/apis/train.py", line 75, in batch_processor
    losses = model(**data)
  File "/home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/zitianchen/anaconda2/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 392, in forward
    self.reducer.prepare_for_backward([])

Bug fix
This error happened after 4 epochs and 1500 iterations. I think it is caused by parameters that are not used in the loss calculation. Is there any way to track this down? init_dist() does not accept a find_unused_parameters=True keyword argument. My code runs correctly on a single GPU (./tools/train.py). I tried printing out some variables, and nothing about the input looks strange.

Is there a way to print out the unused parameters?
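One way to find them is to run the same model on a single GPU (without the DDP wrapper), reset every gradient to None before the forward pass, and list the trainable parameters whose .grad is still None after backward(). The helper names below (clear_grads, list_unused_parameters) are hypothetical; they rely only on named_parameters(), requires_grad and .grad. Resetting to None matters because optimizer.zero_grad() in older PyTorch releases, including the 1.1.0 used here, zeroes existing gradients in place, so after the first iteration an unused parameter would otherwise still hold a zero tensor.

```python
def clear_grads(model):
    """Set every parameter's .grad to None before a forward pass.

    Unlike an in-place zero_grad(), this makes "received no gradient
    in this iteration" observable as grad being None.
    """
    for _, param in model.named_parameters():
        param.grad = None


def list_unused_parameters(model):
    """Names of trainable parameters that got no gradient in the last backward."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


# Usage sketch (single GPU, plain model without the DDP wrapper):
#   clear_grads(model)
#   losses = model(**data)   # mmdet returns a dict of losses
#   loss = ...               # reduce to one scalar the way batch_processor does
#   loss.backward()
#   print(list_unused_parameters(model))
```

Running this every iteration until the list is non-empty should point at the layer that a particular batch skips.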

BTW, another detection model I wrote hits this issue within the first 100 iterations.

Some related references:
ultralytics/yolov3#331

pytorch/pytorch#22049

tankche1 (author) commented

Adding find_unused_parameters=True at mmdet/apis/train.py line 215 seems to fix it for now.
I'm not sure why this happened, and nobody seems to have mentioned it before.
I will reopen the issue if I run into this error again.


YAOYI626 commented Aug 20, 2021

By adding
mmdet/apis/train.py line215: find_unused_parameters=True
It seems to work fine right now.
Not sure why this happened and nobody mentions this before.
I will reopen if I meet this error again.

I ran into this issue too, but in my case it was related to including empty images (images without annotated bounding boxes). Setting filter_empty_gt=True solves it. If I have to include empty images, I also have to set find_unused_parameters=True to keep training running normally.
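In config form, the two workarounds described above might look roughly like the fragment below, meant to be merged into a full config. This is a sketch under assumptions: filter_empty_gt is set on the inner dataset of the RepeatDataset wrapper used in the config earlier in this issue, and the top-level find_unused_parameters key is, to my knowledge, read by mmdet/apis/train.py in MMDetection 2.x via cfg.get('find_unused_parameters', False) and forwarded to MMDistributedDataParallel. In 1.x the flag has to be added in train.py directly, as tankche1 did. Check your installed version before relying on either key.

```python
# Workaround 1: drop training images that have no ground-truth boxes.
data = dict(
    train=dict(
        type='RepeatDataset',
        times=4,
        dataset=dict(
            type='VOCDataset',
            filter_empty_gt=True)))  # skip images without annotations

# Workaround 2: keep empty images, but let DDP tolerate parameters that
# receive no gradient in some iterations (assumed 2.x top-level config key).
find_unused_parameters = True
```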

Does anyone know why setting find_unused_parameters=True solves this issue?
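A toy model may help explain it. The sketch below is NOT PyTorch's implementation; it is a simplified, assumed model of DistributedDataParallel's gradient reducer that is consistent with the error message above: in each iteration every registered parameter must report a gradient before the reduction can finish, and the next forward pass checks that it did. Without find_unused_parameters, DDP passes an empty list to prepare_for_backward (visible in the traceback as self.reducer.prepare_for_backward([])), so a parameter that a batch never touches (plausibly a head whose loss is skipped for an image without boxes) leaves the reduction unfinished. With the flag set, DDP walks the autograd graph from forward's outputs and marks unreached parameters as ready up front. ToyReducer and train_step are hypothetical names.

```python
class ToyReducer:
    """Toy model of DDP's gradient reducer (NOT PyTorch's actual code).

    Every registered parameter must report a gradient before one
    iteration's reduction can finish.
    """

    def __init__(self, param_names, find_unused_parameters=False):
        self.param_names = set(param_names)
        self.find_unused_parameters = find_unused_parameters
        self.pending = set()  # parameters still waiting for a gradient

    def prepare_for_backward(self, reached_params):
        """Called at the end of each forward pass."""
        if self.pending:
            # The previous backward never produced gradients for these.
            raise RuntimeError(
                "Expected to have finished reduction in the prior iteration "
                "before starting a new one. Unreduced: "
                + ", ".join(sorted(self.pending)))
        self.pending = set(self.param_names)
        if self.find_unused_parameters:
            # Real DDP walks the autograd graph from forward()'s outputs;
            # in this toy the caller passes the reachable parameters directly.
            for name in self.param_names - set(reached_params):
                self.mark_ready(name)

    def mark_ready(self, name):
        """Stand-in for the autograd hook that fires when a grad is ready."""
        self.pending.discard(name)


def train_step(reducer, used_params):
    """One simulated iteration: forward, then backward over used params only."""
    reducer.prepare_for_backward(used_params)
    for name in used_params:  # backward only reaches parameters that were used
        reducer.mark_ready(name)
```

In this model the failure surfaces one iteration late, at the forward pass after the batch that skipped a parameter, which fits an error that appears suddenly after thousands of normal iterations. The graph walk enabled by find_unused_parameters costs extra work every iteration, which is presumably why it is off by default.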
