Distributed validation #275

Open
miguelvr opened this issue Dec 14, 2018 · 14 comments

@miguelvr
Contributor

❓ Questions and Help

I'm working on a branch where I implemented validation inference at every checkpoint.

Everything was working fine until the switch from the old torch.deprecated.distributed to the new torch.distributed.

Now either the DataLoader breaks in one of the processes or, if I run the inference only in the main process, it hangs there forever.

sample code:

        if iteration % checkpoint_period == 0 or iteration == max_iter:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
            if is_main_process():
                if val_data_loader is not None:
                    logger.info('Evaluating on validation data set')
                    iou_types = ("bbox",)
                    if cfg.MODEL.MASK_ON:
                        iou_types = iou_types + ("segm",)

                    inference(
                        model,
                        val_data_loader,
                        iou_types=iou_types,
                        box_only=cfg.MODEL.RPN_ONLY,
                        device=cfg.MODEL.DEVICE,
                        expected_results=cfg.TEST.EXPECTED_RESULTS,
                        expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
                        verbose=False
                    )

                model.train()  # reset training flag
            synchronize()

Is there any workaround for this?

@fmassa
Contributor

fmassa commented Dec 14, 2018

Don't you need to unwrap the model from DistributedDataParallel before feeding it to inference?
By that I mean passing model.module instead of model?
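
(A minimal sketch of that unwrapping, assuming the model was wrapped with torch.nn.parallel.DistributedDataParallel in train_net.py; this is just the suggestion above in code form, not necessarily the fix for the hang:)

import torch

def unwrap_ddp(model):
    # Return the underlying module when the model is wrapped in
    # DistributedDataParallel, so evaluation bypasses the DDP wrapper.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model

# e.g. inference(unwrap_ddp(model), val_data_loader, ...) in the snippet above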

@miguelvr
Contributor Author

That doesn't make a difference; the behaviour is the same.

@fmassa
Contributor

fmassa commented Dec 14, 2018

I don't actually know what else could be the problem; I'd need to check it out to identify where it hangs.

Can you maybe attach gdb and print the stack trace where it hangs? Run

python tools/train_net.py --config-file /path/to/config

and once it hangs, attach gdb

gdb attach <pid>

and then run

thread apply all bt

and paste the results?
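
For convenience, the same trace can usually be captured in one shot with gdb's batch mode (assuming gdb is allowed to attach to the process):

gdb -p <pid> -batch -ex "thread apply all bt"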

@BobZhangHT

@fmassa Hi! I am hitting the same problem. I followed your instructions, attached gdb, and ran the commands above. Here are my results:

Thread 7 (Thread 0x7f6033568700 (LWP 3350)):
#0  0x00007f6087f6803f in accept4 () from /lib64/libc.so.6
#1  0x00007f6073b644a6 in ?? () from /usr/lib64/nvidia/libcuda.so.1
#2  0x00007f6073b58a3d in ?? () from /usr/lib64/nvidia/libcuda.so.1
#3  0x00007f6073b65110 in ?? () from /usr/lib64/nvidia/libcuda.so.1
#4  0x00007f608823ce25 in start_thread () from /lib64/libpthread.so.0
#5  0x00007f6087f66bad in clone () from /lib64/libc.so.6

Thread 6 (Thread 0x7f6032d67700 (LWP 3351)):
#0  0x00007f6087f5bf0d in poll () from /lib64/libc.so.6
#1  0x00007f60712478e9 in c10d::TCPStoreDaemon::run() ()
   from /home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#2  0x00007f6071e61dc0 in ?? ()
   from /home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/../../../libstdc++.so.6
#3  0x00007f608823ce25 in start_thread () from /lib64/libpthread.so.0
#4  0x00007f6087f66bad in clone () from /lib64/libc.so.6

Thread 5 (Thread 0x7f602a6fa700 (LWP 3512)):
#0  0x00007f6087f5bf0d in poll () from /lib64/libc.so.6
#1  0x00007f6073b6369b in ?? () from /usr/lib64/nvidia/libcuda.so.1

Could you please provide any suggestions?

@fmassa
Contributor

fmassa commented Jan 11, 2019

@BobZhangHT we might need the full log for that.
Does it still hang on single-machine training?

@BobZhangHT

BobZhangHT commented Jan 11, 2019

@fmassa

Thanks for your reply. I modified my code and it has not hung since then. I actually ran several rounds successfully after that, but then suddenly got stuck on a new problem with distributed training. Here is the log; I updated my PyTorch to the latest version, but it did not help.

File "tools/train_net.py", line 230, in <module>
    main()
  File "tools/train_net.py", line 223, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "tools/train_net.py", line 121, in train
    masksgt,)# add 2019/01/11 
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 106, in do_train
    loss_dict = model(images, targets)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 364, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 50, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 100, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 116, in _forward_train
    anchors, objectness, rpn_box_regression, targets
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py", line 138, in forward
    sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py", line 113, in forward_for_single_feature_map
    boxlist = remove_small_boxes(boxlist, self.min_size)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/structures/boxlist_ops.py", line 46, in remove_small_boxes
    (ws >= min_size) & (hs >= min_size)
RuntimeError: copy_if failed to synchronize: device-side assert triggered
terminate called without an active exception

@BobZhangHT

BobZhangHT commented Jan 13, 2019

@fmassa Hi. I additionally tracked the training at every iteration and found that it actually trains for the first 7 to 9 iterations and then breaks with the following error. It seems that some index is out of bounds on the CUDA side, but I cannot figure out which one. Here is part of the error output; since much of it follows a repeated pattern, I only include a small fraction.

/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [76,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [77,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [78,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [79,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [52,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [53,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [54,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [55,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [56,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [57,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [58,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [59,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [60,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [61,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [62,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [29,0,0], thread: [63,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch-nightly_1547199076991/work/aten/src/ATen/native/cuda/IndexKernel.cu:53: lambda [](int)->auto::operator()(int)->auto: block: [8,0,0], thread: [0,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
....
Traceback (most recent call last):
  File "tools/train_net.py", line 237, in <module>
    main()
  File "tools/train_net.py", line 230, in main
    model = train(cfg, args.local_rank, args.distributed)
  File "tools/train_net.py", line 138, in train
    arguments,)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/engine/trainer.py", line 109, in do_train
    loss_dict = model(images, targets)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 364, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/detector/generalized_rcnn.py", line 50, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 100, in forward
    return self._forward_train(anchors, objectness, rpn_box_regression, targets)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/rpn.py", line 116, in _forward_train
    anchors, objectness, rpn_box_regression, targets
  File "/home/r7user3/anaconda2/envs/maskrcnn/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py", line 138, in forward
    sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/modeling/rpn/inference.py", line 113, in forward_for_single_feature_map
    boxlist = remove_small_boxes(boxlist, self.min_size)
  File "/home/r7user3/ZhangHT/github/maskrcnn-benchmark/maskrcnn_benchmark/structures/boxlist_ops.py", line 46, in remove_small_boxes
    (ws >= min_size) & (hs >= min_size)
RuntimeError: copy_if failed to synchronize: device-side assert triggered
terminate called without an active exception
terminate called without an active exception

Update: I also found that it can occasionally happen even with single-GPU training.

@fmassa
Contributor

fmassa commented Jan 14, 2019

I would say that one of your training samples contains an out-of-bounds index or something like that.
Maybe one of your training examples has a GT label index that is larger than NUM_CLASSES - 1?
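
A rough way to check this, assuming a COCO-style annotation JSON (the path below is a placeholder):

import json

# Placeholder path; point this at the annotation file actually used for training.
with open("annotations/train.json") as f:
    coco = json.load(f)

valid_ids = {c["id"] for c in coco["categories"]}
bad = [a for a in coco["annotations"] if a["category_id"] not in valid_ids]
print("annotations with unknown category_id:", len(bad))
print("max category_id:", max(a["category_id"] for a in coco["annotations"]))
# With one foreground class plus background, every category_id should be 1
# and NUM_CLASSES in the config should be 2.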

@BobZhangHT

@fmassa Thank you very much for your reply. After double-checking my code, I still cannot figure out where my training data's indices go out of bounds. Here is my code for generating the COCO-format data. All ids start from 1 and there are 2 classes in total (including the background).

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from pycocotools import mask as mask_utils

# bimask_to_polygon is a user-defined helper (not shown) that converts a
# binary mask into a COCO-style polygon.

def coco_dict_gen(files,imagesFile,masksFile):
    img_lst=[]
    anno_lst=[]

    for i,file in enumerate(tqdm(files)):

        # image dict
        image_dict={}
        # image file directory
        image_dir=imagesFile+file
        # load the image
        image=plt.imread(image_dir,'RGB')[:,:,:3]
        # specify the image dict
        image_dict['file_name']=image_dir
        image_dict['height']=int(image.shape[0])
        image_dict['width']=int(image.shape[1])
        image_dict['id']=i+1
        # append
        img_lst.append(image_dict)

        # annotation dict
        anno_dict={}
        # load mask
        mask_dir=masksFile+file
        mask=plt.imread(mask_dir,'RGB')[:,:,:3]
        # convert RGB mask into gray mask
        bimask=cv2.cvtColor(mask,cv2.COLOR_RGB2GRAY).astype(np.uint8)
        poly=bimask_to_polygon(bimask)
        bimask=np.asfortranarray(bimask,dtype=np.uint8) # convert to fortran format
        rle_cprs=mask_utils.encode(bimask)
        # save dict
        anno_dict['image_id']=i+1
        anno_dict['id']=i+1
        anno_dict['category_id']=1
        anno_dict['iscrowd']=0
        anno_dict['segmentation']=poly#rle_uncprs
        anno_dict['area']=float(mask_utils.area(rle_cprs))
        anno_dict['bbox']=list(mask_utils.toBbox(rle_cprs))
        # append
        anno_lst.append(anno_dict)

    data_dict={}
    data_dict['info']={}
    data_dict['licenses']=[]
    data_dict['images']=img_lst
    data_dict['annotations']=anno_lst
    data_dict['categories']=[{'supercategory': 'Slice', 'id': 1, 'name': 'Lesion Slice'}]

    return data_dict
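
For completeness, a hedged usage sketch of the function above (the output path is a placeholder):

import json

# files, imagesFile and masksFile are assumed to be defined as in the function above.
data_dict = coco_dict_gen(files, imagesFile, masksFile)
with open("annotations/train.json", "w") as f:
    json.dump(data_dict, f)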

@fmassa
Contributor

fmassa commented Jan 14, 2019

So, do you have two classes including background, or excluding background?

As a sanity check: if your indices go up to 10, then NUM_CLASSES should be 11.

@BobZhangHT

BobZhangHT commented Jan 14, 2019

Including background, actually, so I set the number of classes to 2 (1 for the object and 1 for the background). Besides, I found a very confusing fact. On a single GPU, if I use the default settings of e2e_mask_rcnn_R_50_FPN_1x.yaml, that is, if I run

python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 TEST.IMS_PER_BATCH 1

to train my model, I cannot avoid this error.
But if I switch to

python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 60000 SOLVER.STEPS "(30000, 40000)" TEST.IMS_PER_BATCH 1

or

python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1

training works. I don't know why; it's so weird.

@fmassa
Contributor

fmassa commented Jan 14, 2019

Oh, your model probably diverged during training because the learning rate was too large for a batch size of 2.

Given that the problem doesn't seem to be related to distributed validation, if the error persists would you mind opening a new issue?

Thanks
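
As a side note, the second and third commands above follow the linear scaling heuristic from the maskrcnn-benchmark README's single-GPU instructions: the default schedule in e2e_mask_rcnn_R_50_FPN_1x.yaml assumes 16 images per batch, so with 2 images per batch the learning rate is divided by 8 and the schedule is stretched by 8. A rough sketch of that arithmetic:

# Linear scaling sketch; reference values are the defaults from
# e2e_mask_rcnn_R_50_FPN_1x.yaml, which are tuned for 16 images per batch.
REFERENCE = {"ims_per_batch": 16, "base_lr": 0.02, "max_iter": 90000, "steps": (60000, 80000)}

def scale_schedule(ims_per_batch, ref=REFERENCE):
    factor = ref["ims_per_batch"] / ims_per_batch
    return {
        "base_lr": ref["base_lr"] / factor,
        "max_iter": int(ref["max_iter"] * factor),
        "steps": tuple(int(s * factor) for s in ref["steps"]),
    }

print(scale_schedule(2))
# -> {'base_lr': 0.0025, 'max_iter': 720000, 'steps': (480000, 640000)}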

@BobZhangHT

@fmassa Thank you so much! After tuning the learning rate, it works! As for the distributed case, I had been following the LR schedule from Detectron, where the LR is doubled for 2 GPUs (lr = 0.005); it seems that this LR is still too large, so I turned it back to 0.0025.

@1119066022

1119066022 commented Dec 27, 2019

@BobZhangHT @fmassa Hello,
I use the following config for single-GPU Cityscapes instance segmentation training:
SOLVER.IMS_PER_BATCH 1
SOLVER.BASE_LR 0.00125
Steps=32000

Should I change the config as follows when I use two GPUs?
SOLVER.IMS_PER_BATCH 2
SOLVER.BASE_LR 0.00125  # or 0.0025?
Steps=16000  # or 8000?

Thanks!
