
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group. #3972

Closed
saidineshpola opened this issue Feb 16, 2022 · 2 comments


saidineshpola commented Feb 16, 2022

I tried to train detectron2 with a LazyConfig on a single GPU, but the run fails with: File "/home/user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 358, in _get_default_group raise RuntimeError("Default process group has not been initialized"). Can anybody help me run this on a single machine without using DistributedDataParallel from PyTorch?
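A minimal sketch of a possible workaround (an assumption on my part: the crash seems to come from the SyncBatchNorm layers in the new_baselines config, which call into torch.distributed even with --num-gpus 1, so initializing a trivial single-process group before training might avoid it; the MASTER_ADDR/MASTER_PORT values are placeholders):

```python
import os
import torch.distributed as dist

# Assumed workaround, not an official detectron2 fix: create a
# one-process "gloo" group so that code expecting an initialized
# default process group (e.g. SyncBatchNorm) does not raise.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # placeholder local address
os.environ.setdefault("MASTER_PORT", "29500")      # placeholder free port
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
```

With world_size=1 the env:// rendezvous completes immediately, so this should not block; it only makes the default group exist.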

Instructions To Reproduce the Issue:

  1. Full runnable code or full changes you made:
!python ./tools/lazyconfig_train_net.py --num-gpus 1 --config-file ./configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py 

  2. What exact command you run:
  3. Full logs or other relevant observations:
[02/16 08:16:26 detectron2]: Rank of current process: 0. World size: 1
[02/16 08:16:27 detectron2]: Environment info:
----------------------  -----------------------------------------------------------------------------------------------
sys.platform            linux
Python                  3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:23:11) [GCC 9.4.0]
numpy                   1.22.1
detectron2              0.6 @/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2
Compiler                GCC 7.3
CUDA compiler           CUDA 11.1
detectron2 arch flags   3.7, 5.0, 5.2, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6
DETECTRON2_ENV_MODULE   <not set>
PyTorch                 1.9.0+cu111 @/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch
PyTorch debug build     False
GPU available           Yes
GPU 0                   Tesla T4 (arch=7.5)
Driver version          470.57.02
CUDA_HOME               /usr/local/cuda
Pillow                  9.0.0
torchvision             0.10.0+cu111 @/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torchvision
torchvision arch flags  3.5, 5.0, 6.0, 7.0, 7.5, 8.0, 8.6
fvcore                  0.1.5.post20211023
iopath                  0.1.9
cv2                     4.5.5
----------------------  -----------------------------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86
  - CuDNN 8.0.5
  - Magma 2.5.2
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, 

[02/16 08:16:27 detectron2]: Command line arguments: Namespace(config_file='./configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py', resume=False, eval_only=False, num_gpus=1, num_machines=1, machine_rank=0, dist_url='tcp://127.0.0.1:50152', opts=[])
[02/16 08:16:27 detectron2]: Contents of args.config_file=./configs/new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ.py:
from .mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ import (
    dataloader,
    lr_multiplier,
    model,
    optimizer,
    train,
)

train.max_iter *= 2  # 100ep -> 200ep

lr_multiplier.scheduler.milestones = [
    milestone * 2 for milestone in lr_multiplier.scheduler.milestones
]
lr_multiplier.scheduler.num_updates = train.max_iter

WARNING [02/16 08:16:28 d2.config.lazy]: The config contains objects that cannot serialize to a valid yaml. ./output/config.yaml is human-readable but cannot be loaded.
WARNING [02/16 08:16:28 d2.config.lazy]: Config is saved using cloudpickle at ./output/config.yaml.pkl.
[02/16 08:16:28 detectron2]: Full config saved to ./output/config.yaml
[02/16 08:16:28 d2.utils.env]: Using a generated random seed 28073747
[02/16 08:16:28 detectron2]: Model:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(
      80, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_output2): Conv2d(
      256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_lateral3): Conv2d(
      240, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_output3): Conv2d(
      256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_lateral4): Conv2d(
      560, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_output4): Conv2d(
      256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_lateral5): Conv2d(
      1360, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fpn_output5): Conv2d(
      256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (norm): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (top_block): LastLevelMaxPool()
    (bottom_up): RegNet(
      (stem): SimpleStem(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): SyncBatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (af): ReLU(inplace=True)
      )
      (s1): AnyStage(
        (b1): ResBottleneckBlock(
          (proj): Conv2d(32, 80, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (f): BottleneckTransform(
            (a): Conv2d(32, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(80, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=2, bias=False)
            (b_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b2): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
            (b_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
      )
      (s2): AnyStage(
        (b1): ResBottleneckBlock(
          (proj): Conv2d(80, 240, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (f): BottleneckTransform(
            (a): Conv2d(80, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=6, bias=False)
            (b_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b2): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(240, 240, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
            (b_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b3): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(240, 240, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
            (b_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b4): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(240, 240, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
            (b_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b5): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(240, 240, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=6, bias=False)
            (b_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(240, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(240, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
      )
      (s3): AnyStage(
        (b1): ResBottleneckBlock(
          (proj): Conv2d(240, 560, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (f): BottleneckTransform(
            (a): Conv2d(240, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b2): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b3): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b4): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b5): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b6): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b7): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b8): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b9): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b10): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b11): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b12): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b13): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b14): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(560, 560, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=14, bias=False)
            (b_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(560, 560, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(560, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
      )
      (s4): AnyStage(
        (b1): ResBottleneckBlock(
          (proj): Conv2d(560, 1360, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (f): BottleneckTransform(
            (a): Conv2d(560, 1360, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(1360, 1360, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=34, bias=False)
            (b_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(1360, 1360, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
        (b2): ResBottleneckBlock(
          (f): BottleneckTransform(
            (a): Conv2d(1360, 1360, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (a_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (a_af): ReLU(inplace=True)
            (b): Conv2d(1360, 1360, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=34, bias=False)
            (b_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (b_af): ReLU(inplace=True)
            (c): Conv2d(1360, 1360, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (c_bn): SyncBatchNorm(1360, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (af): ReLU(inplace=True)
        )
      )
    )
  )
  (proposal_generator): RPN(
    (rpn_head): StandardRPNHead(
      (conv): Sequential(
        (conv0): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation): ReLU()
        )
        (conv1): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (activation): ReLU()
        )
      )
      (objectness_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (anchor_deltas): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
    (anchor_generator): DefaultAnchorGenerator(
      (cell_anchors): BufferList()
    )
  )
  (roi_heads): StandardROIHeads(
    (box_pooler): ROIPooler(
      (level_poolers): ModuleList(
        (0): ROIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=0, aligned=True)
        (1): ROIAlign(output_size=(7, 7), spatial_scale=0.125, sampling_ratio=0, aligned=True)
        (2): ROIAlign(output_size=(7, 7), spatial_scale=0.0625, sampling_ratio=0, aligned=True)
        (3): ROIAlign(output_size=(7, 7), spatial_scale=0.03125, sampling_ratio=0, aligned=True)
      )
    )
    (box_head): FastRCNNConvFCHead(
      (conv1): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (conv2): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (conv3): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (conv4): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (fc1): Linear(in_features=12544, out_features=1024, bias=True)
      (fc_relu1): ReLU()
    )
    (box_predictor): FastRCNNOutputLayers(
      (cls_score): Linear(in_features=1024, out_features=499, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=1992, bias=True)
    )
    (mask_pooler): ROIPooler(
      (level_poolers): ModuleList(
        (0): ROIAlign(output_size=(14, 14), spatial_scale=0.25, sampling_ratio=0, aligned=True)
        (1): ROIAlign(output_size=(14, 14), spatial_scale=0.125, sampling_ratio=0, aligned=True)
        (2): ROIAlign(output_size=(14, 14), spatial_scale=0.0625, sampling_ratio=0, aligned=True)
        (3): ROIAlign(output_size=(14, 14), spatial_scale=0.03125, sampling_ratio=0, aligned=True)
      )
    )
    (mask_head): MaskRCNNConvUpsampleHead(
      (mask_fcn1): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (mask_fcn2): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (mask_fcn3): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (mask_fcn4): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): NaiveSyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (deconv): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
      (deconv_relu): ReLU()
      (predictor): Conv2d(256, 498, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

[02/16 08:16:39 d2.data.common]: Serialized dataset takes 117.75 MiB
[02/16 08:16:39 fvcore.common.checkpoint]: No checkpoint found. Initializing model from scratch
[02/16 08:16:39 d2.engine.train_loop]: Starting training from iteration 0
/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/fvcore/transforms/transform.py:724: ShapelyDeprecationWarning: Iteration over multi-part geometries is deprecated and will be removed in Shapely 2.0. Use the `geoms` property to access the constituent parts of a multi-part geometry.
  for poly in cropped:
/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/_tensor.py:575: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /pytorch/aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
ERROR [02/16 08:16:42 d2.engine.train_loop]: Exception during training:
Traceback (most recent call last):
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/engine/train_loop.py", line 395, in run_step
    loss_dict = self.model(data)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 154, in forward
    features = self.backbone(images.tensor)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/fpn.py", line 126, in forward
    bottom_up_features = self.bottom_up(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/regnet.py", line 315, in forward
    x = self.stem(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/regnet.py", line 87, in forward
    x = layer(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward
    world_size = torch.distributed.get_world_size(process_group)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 748, in get_world_size
    return _get_group_size(group)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 274, in _get_group_size
    default_pg = _get_default_group()
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 358, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
[02/16 08:16:42 d2.engine.hooks]: Total training time: 0:00:02 (0:00:00 on hooks)
[02/16 08:16:42 d2.utils.events]:  iter: 0    lr: N/A  max_mem: 2371M
Traceback (most recent call last):
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/detectron2/./tools/lazyconfig_train_net.py", line 195, in <module>
    launch(
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/engine/launch.py", line 82, in launch
    main_func(*args)
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/detectron2/./tools/lazyconfig_train_net.py", line 190, in main
    do_train(args, cfg)
  File "/home/studio-lab-user/sagemaker-studiolab-notebooks/detectron2/./tools/lazyconfig_train_net.py", line 175, in do_train
    trainer.train(start_iter, cfg.train.max_iter)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/engine/train_loop.py", line 149, in train
    self.run_step()
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/engine/train_loop.py", line 395, in run_step
    loss_dict = self.model(data)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 154, in forward
    features = self.backbone(images.tensor)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/fpn.py", line 126, in forward
    bottom_up_features = self.bottom_up(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/regnet.py", line 315, in forward
    x = self.stem(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/detectron2/modeling/backbone/regnet.py", line 87, in forward
    x = layer(x)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward
    world_size = torch.distributed.get_world_size(process_group)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 748, in get_world_size
    return _get_group_size(group)
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 274, in _get_group_size
    default_pg = _get_default_group()
  File "/home/studio-lab-user/.conda/envs/default/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 358, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
ppwwyyxx (Contributor) commented Feb 18, 2022

The model you use contains SyncBatchNorm, which cannot run on a single (CPU/GPU) worker. This is expected behavior, therefore closing.

I have asked PyTorch to let SyncBatchNorm work on a single GPU in pytorch/pytorch#63662, but there has been no positive feedback so far.

UPDATE: pytorch 2.0 will include my fix pytorch/pytorch@56e40fe
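On PyTorch versions before that fix, one possible workaround (not an official detectron2 recommendation, and only a sketch) is to initialize a one-process distributed group before training, so `torch.distributed.get_world_size()` returns 1 and the sync collectives degenerate to no-ops. The address and port values below are illustrative assumptions; any free local port works for a single process:

```python
import os
import torch.distributed as dist

# Illustrative defaults for a single local process (assumed values).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# A 1-process "gloo" group: SyncBatchNorm's collective calls become no-ops.
dist.init_process_group(backend="gloo", rank=0, world_size=1)
print(dist.get_world_size())  # 1
```

Call this before building the trainer; `detectron2.engine.launch` only initializes a process group itself when more than one worker is requested.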

ppwwyyxx (Contributor) commented:

As a workaround:

  1. Print the config.
  2. Find all keys that have a value of "SyncBN" (or similar).
  3. Edit the config file or code to set these values to "BN" instead.
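The steps above can be sketched as a small recursive helper. For illustration it walks a plain nested dict; a LazyConfig (omegaconf) tree can be traversed the same way. The config fragment and key names below are hypothetical, not the actual keys of this baseline:

```python
def replace_syncbn(node):
    """Recursively replace every "SyncBN" value with "BN" in a nested config."""
    if isinstance(node, dict):
        return {k: replace_syncbn(v) for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(replace_syncbn(v) for v in node)
    return "BN" if node == "SyncBN" else node

# Hypothetical fragment mirroring what a printed LazyConfig might contain:
cfg = {"model": {"backbone": {"norm": "SyncBN"},
                 "roi_heads": {"mask_head": {"conv_norm": "SyncBN"}}}}
cfg = replace_syncbn(cfg)
print(cfg["model"]["backbone"]["norm"])  # BN
```

Editing the config `.py` file directly (replacing the "SyncBN" strings with "BN") achieves the same result without any code.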
