
'assert (boxes1[:, 2:] >= boxes1[:, :2]).all()' happened when training #101

Open
LovPe opened this issue Jun 24, 2020 · 26 comments

@LovPe

LovPe commented Jun 24, 2020

Thanks for the amazing work.
I have a question about training with your code: assert (boxes1[:, 2:] >= boxes1[:, :2]).all() fails in the function generalized_box_iou.
After reading the code, I see that boxes1 holds the boxes predicted by an MLP head, so I suspect the assertion can fire early in training and then break the run.
Is there a mechanism that guarantees this cannot happen?
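
For reference, here is a quick check that can be dropped in right before the criterion to see whether the predictions already contain bad values (a minimal debugging sketch; outputs is the model's output dict, whose 'pred_boxes' entry holds the normalized (cx, cy, w, h) predictions):

import torch

pred = outputs['pred_boxes']             # shape [batch, num_queries, 4]
bad = ~torch.isfinite(pred).all(dim=-1)  # queries with any NaN/Inf coordinate
if bad.any():
    print("non-finite predicted boxes:", pred[bad])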

My Environment:

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.0

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 4.9.3-13ubuntu2) 4.9.3
CMake version: version 3.16.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 410.48
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.17.1
[pip3] torch==1.4.0
[pip3] torchfile==0.1.0
[pip3] torchvision==0.5.0
[conda] mkl 2019.4 243 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
[conda] pytorch 1.4.0 py3.6_cuda10.0.130_cudnn7.6.3_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
[conda] torch 1.0.0
[conda] torchfile 0.1.0
[conda] torchvision 0.5.0 py36_cu100 https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch

@alcinos
Contributor

alcinos commented Jun 24, 2020

@LovPe Thank you for your interest in DETR.

This assertion is not a fluke: if it fires, something is going wrong in your training. Here are some potential things you can look into:

  1. We normally enforce that the boxes are non-degenerate thanks to this sigmoid:
    outputs_coord = self.bbox_embed(hs).sigmoid()
    Did you, by any chance, remove that? (See the sketch after this list.)
  2. This is possibly a red herring, and the error is somewhere else. You can try running the code with CUDA_LAUNCH_BLOCKING=1 python main.py to see if anything comes up. See #28 (custom training asserts with "degenerate bboxes" over and over - but bboxes look correct, any debugging insight?) for more details.
  3. I note that your libraries are a bit old; we recommend PyTorch 1.5+ and torchvision 0.6+. If everything else fails, you can try upgrading.
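
To make item 1 concrete, here is a minimal sketch mirroring box_cxcywh_to_xyxy from util/box_ops.py: with sigmoid outputs, w and h are non-negative, so x1 >= x0 and y1 >= y0 always hold and the assert can only trip on NaN values.

import torch

def box_cxcywh_to_xyxy(x):
    # (cx, cy, w, h) -> (x0, y0, x1, y1); mirrors util/box_ops.py
    x_c, y_c, w, h = x.unbind(-1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h,
                        x_c + 0.5 * w, y_c + 0.5 * h], dim=-1)

boxes = torch.rand(5, 4)                   # sigmoid outputs all lie in [0, 1]
xyxy = box_cxcywh_to_xyxy(boxes)
assert (xyxy[:, 2:] >= xyxy[:, :2]).all()  # holds because x1 - x0 = w >= 0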

Hope this helps, good luck with the debugging.

@raviv

raviv commented Jun 24, 2020

@LovPe I was getting this error when the learning rate was too high.

@kuixu

kuixu commented Jun 28, 2020

@LovPe

I was getting this error and found that all the predicted boxes are NaN. The culprit is a fully padded mask produced in the interpolation step (in the backbone forward), mainly caused by very large zero-padding when batching images and masks of different sizes.

When that happens, the variables below (outputs_coord, outputs_class and hs) are all NaN.

detr/models/detr.py

Lines 65 to 68 in 10a2c75

hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
outputs_class = self.class_embed(hs)
outputs_coord = self.bbox_embed(hs).sigmoid()
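
A quick way to confirm this failure mode is to check whether any sample's padding mask is True everywhere after interpolation (a minimal debugging sketch; mask is the boolean padding mask from the snippet above, True at padded positions):

# A sample whose mask is all True contributes only padded keys, and the
# attention over it produces NaN activations downstream.
fully_padded = mask.flatten(1).all(dim=1)  # one bool per image in the batch
if fully_padded.any():
    print("fully padded samples at batch indices:", fully_padded.nonzero().flatten().tolist())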

@zlyin

zlyin commented Jul 11, 2020

I think @raviv is correct. On my dataset, LR = 2e-4 works well, but I ran into this error when I set LR = 2e-3.

@zlyin

zlyin commented Jul 12, 2020

Sorry, my previous answer was wrong. Lowering the LR only delays the error. I have already reduced the LR to the 1e-6 range but still get the error.

@fmassa
Contributor

fmassa commented Jul 12, 2020

Hi @zlyin

I believe we went over most of the debugging tips for identifying the root cause of this issue in #28.

In particular, I would check whether other error messages appear before the assert from the beginning of this thread, such as

RuntimeError: CUDA error: device-side assert triggered

which would point to a different issue, typically caused by a wrong number of classes.

@zlyin

zlyin commented Jul 14, 2020

Hi @fmassa, thanks for your reply. I solved this issue by changing the bbox format into the normalized coco format.
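
In case it helps others: DETR's dataloader feeds targets as normalized (cx, cy, w, h). A minimal sketch of converting absolute COCO [x, y, w, h] boxes to that format (the function name is just illustrative):

import torch

def coco_xywh_to_normalized_cxcywh(boxes, img_w, img_h):
    # boxes: [N, 4] tensor in absolute COCO format [x_min, y_min, w, h]
    x, y, w, h = boxes.unbind(-1)
    cxcywh = torch.stack([x + 0.5 * w, y + 0.5 * h, w, h], dim=-1)
    return cxcywh / torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)

boxes = torch.tensor([[10.0, 20.0, 100.0, 50.0]])
print(coco_xywh_to_normalized_cxcywh(boxes, img_w=640, img_h=480))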

@liminn

liminn commented Jul 15, 2020

I am training DETR on the COCO panoptic dataset, and I also hit the error assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
I run:
python3 main.py --lr_drop 100 --epochs 150 --coco_path /mnt/hdd1/zlm/data/open_source/coco/ --coco_panoptic_path /mnt/hdd1/zlm/data/open_source/panoptic_annotations_trainval2017/ --dataset_file coco_panoptic --output_dir /mnt/hdd1/zlm/data/train_results/detr_4/
I use a single GPU and change the lr to lr/8.0; nothing else is changed.

@LovPe
Author

LovPe commented Jul 20, 2020

@fmassa
Sorry for the late reply. You are right, the code does make sure the assert cannot fire on its own (sorry for my misunderstanding).
The core problem lies elsewhere: the box values are all NaN. From some of the replies this may be due to an incorrect learning rate, so I trained with 4 GPUs and scaled the learning rate with the total batch size following the linear scaling rule (see the sketch below), but it still does not work.
There are 2 main differences in my training setup:
1. 4 GPUs vs 8 GPUs (8 GPUs are not available to me; I wonder whether you have ever trained successfully on 4 GPUs).
2. torch 1.4 / CUDA 10.0 / torchvision 0.5 vs torch 1.5 / CUDA 10.1 / torchvision 0.6 (for some reason, I cannot upgrade to CUDA 10.1).
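
This is the scaling I used (a minimal sketch of the linear scaling rule; the reference numbers assume DETR's default recipe of 8 GPUs x batch size 2 per GPU with lr = 1e-4):

def scaled_lr(base_lr, base_total_batch, new_total_batch):
    # Linear scaling rule: keep the lr proportional to the total batch size.
    return base_lr * new_total_batch / base_total_batch

print(scaled_lr(1e-4, base_total_batch=8 * 2, new_total_batch=4 * 2))  # 5e-05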

@liminn

liminn commented Jul 20, 2020

the box value is all NaN

@LovPe
I agree with you; I see the same phenomenon: the box values are all NaN.
My env: torch==1.5.0, cuda==10.1, torchvision==0.6.0, single GPU

@fmassa Hi, could you give us some advice? I have experimented with both the COCO dataset and a custom dataset, and I did not change any hyperparameters except the number of GPUs, yet I hit this problem every time.
My experiments are panoptic segmentation, and 'assert (boxes1[:, 2:] >= boxes1[:, :2]).all()' happens during the first-stage (box model) training: python main.py --coco_path /path/to/coco --coco_panoptic_path /path/to/coco_panoptic --dataset_file coco_panoptic --output_dir /output/path/box_model

@Chris-hughes10

I also seem to be running into this error when training on custom data. It happens with certain combinations of batch size and image size, with no pattern that I can determine. I am feeding the bbox values to the network in normalised xywh format, and I do not see any other errors.

@d-li14

d-li14 commented Aug 9, 2020

I met the same issue when training with the default setting on COCO detection.

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --coco_path /path/to/coco 

@sarmientoj24

Any solutions for this?

@zhengye1995

I am hitting the same issue; any solutions for this?

@1338199

1338199 commented Jan 15, 2021

@LovPe

I was getting this error and found that all the predicted boxes are NaN. The culprit is a fully padded mask produced in the interpolation step (in the backbone forward), mainly caused by very large zero-padding when batching images and masks of different sizes.

When that happens, the variables below (outputs_coord, outputs_class and hs) are all NaN.

detr/models/detr.py

Lines 65 to 68 in 10a2c75

hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
outputs_class = self.class_embed(hs)
outputs_coord = self.bbox_embed(hs).sigmoid()

Hi, I recently hit a similar problem, and when I set key_padding_mask in multiheadattention to None, the bug went away. So I wonder whether a large zero-padding mask is enough to trigger the issue, or whether the mask has to be entirely padding. Thank you.

@aedirn

aedirn commented Feb 23, 2021

@1338199
Hi, I recently hit a similar problem, and when I set key_padding_mask in multiheadattention to None, the bug went away. So I wonder whether a large zero-padding mask is enough to trigger the issue, or whether the mask has to be entirely padding. Thank you.

Hi, I found multiple uses of key_padding_mask by a multiheadattention in transformer.py; do you know which lines you changed?

Edit for posterity: The lines I tried changing were transformer.py 227 and 251, each from "key_padding_mask=memory_padding_mask" to "key_padding_mask=None". It did not fix the box assertion error that all of us are having. Is this the same change that you made?

@eslambakr

I am sharing my experience in case it is beneficial for you:
I got this error when num_classes was set the wrong way.
When I fixed it, the error was resolved. (See the sketch below for how num_classes relates to the category ids.)
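
For reference, num_classes in DETR should be max_obj_id + 1 (the model adds the extra "no object" logit on top of that). A minimal sketch of deriving it from COCO-style annotations (the file path is hypothetical):

import json

with open("annotations/instances_train.json") as f:  # hypothetical path
    coco = json.load(f)

max_obj_id = max(cat["id"] for cat in coco["categories"])
num_classes = max_obj_id + 1  # e.g. 91 for COCO, whose highest category id is 90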

@yonadance

I have changed num_classes and the lr, but the error is still there.

@jvcop

jvcop commented May 24, 2022

1. We normally enforce that the boxes are non-degenerate thanks to this sigmoid: https://github.com/facebookresearch/detr/blob/10a2c759454930813aeac7af5e779f835dcb75f5/models/detr.py#L68

I'm wondering whether this is sufficient. For example, if w / 2 > cx, isn't the box degenerate even though each individual value is within the bounds of the image?

@Puranjay-del-Mishra

Puranjay-del-Mishra commented May 31, 2022

Hi @fmassa, thanks for your reply. I solved this issue by changing the bbox format into the normalized coco format.

Hello @zlyin! I ran into this error as well and was hoping you could share the details of your workaround. Thanks.

@Puranjay-del-Mishra

I still get this error; the only change I made was setting the batch size to 1. Has anyone found a solution?

@rocklee2022

Setting num_classes = number of classes + 1 in the config made this error stop popping up for me.

@bencevans

The bounding box loss == NaN seems to happen when there are no detections in the targets. The fix, at least it seems to resolve it so far here... 🤞... is to update these two lines:

From:

losses['loss_bbox'] = loss_bbox.sum() / num_boxes

To:

losses['loss_bbox'] = loss_bbox.sum() / num_boxes if num_boxes > 0 else loss_bbox.sum()

From:

losses['loss_giou'] = loss_giou.sum() / num_boxes

To:

losses['loss_giou'] = loss_giou.sum() / num_boxes if num_boxes > 0 else loss_giou.sum()
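
Equivalently, the same guard can be applied once to the normalizer instead of to each loss term (a minimal sketch, assuming num_boxes is the plain Python count of target boxes in the batch):

num_boxes = max(num_boxes, 1)  # avoid dividing by zero on batches with no boxes
losses['loss_bbox'] = loss_bbox.sum() / num_boxes
losses['loss_giou'] = loss_giou.sum() / num_boxes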

@lilligao

Hi, I also have the same error. I tried all the suggestions above (changing num_classes, dividing by num_boxes only when it is > 0).
What is really strange is that everything works if I train on CPU; the error appears only when I train on GPU. Can anyone help me? Thanks!

@YAOSL98

YAOSL98 commented Jan 22, 2024 via email

@lilligao

Hi, I also have the same error. I tried all the suggestions above (changing num_classes, dividing by num_boxes only when it is > 0). What is really strange is that everything works if I train on CPU; the error appears only when I train on GPU. Can anyone help me? Thanks!

I found where the problem was! It is the torch.bmm calls used in modeling_detr.py. The GPU I am using has reduced precision for float16/float32 computations (https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200), so my solution is to add the following lines.
In class DetrAttention, add

query_states = torch.as_tensor(query_states, dtype=torch.float64)
key_states = torch.as_tensor(key_states, dtype=torch.float64)

before the line

attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

add

value_states = torch.as_tensor(value_states, dtype=torch.float64)

before the line

attn_output = torch.bmm(attn_probs, value_states)

and add

attn_output = torch.as_tensor(attn_output, dtype=torch.float32)

before the line

attn_output = self.out_proj(attn_output)
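
For readability, here are the same three edits in one place (a sketch only: the surrounding lines of DetrAttention.forward are abbreviated, and .to(...) is just an equivalent way to write the casts):

# Inside DetrAttention.forward (abbreviated), with the added casts:
query_states = query_states.to(torch.float64)
key_states = key_states.to(torch.float64)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
# ... masking, softmax and dropout produce attn_probs (still float64) ...
value_states = value_states.to(torch.float64)
attn_output = torch.bmm(attn_probs, value_states)
# ... reshape back to (batch, seq_len, embed_dim) ...
attn_output = attn_output.to(torch.float32)  # cast back before the float32 output projection
attn_output = self.out_proj(attn_output)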
