
RuntimeError: CUDA error: an illegal memory access was encountered when training cfa #614

Closed
3 tasks done
sltlls opened this issue Nov 13, 2022 · 17 comments
Labels
bug Something isn't working dev-1.x


@sltlls
Contributor

sltlls commented Nov 13, 2022

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

1.x branch https://github.com/open-mmlab/mmrotate/tree/1.x

Environment

sys.platform: linux
Python: 3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0,1,2,3: TITAN X (Pascal)
CUDA_HOME: /usr/local/cuda-10.1
NVCC: Cuda compilation tools, release 10.1, V10.1.10
GCC: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
PyTorch: 1.7.1+cu101
PyTorch compiling details: PyTorch built with:

  • GCC 7.3
  • C++ Version: 201402
  • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v1.6.0 (Git Hash 5ef631a030a6f73131c77892041042805a06064f)
  • OpenMP 201511 (a.k.a. OpenMP 4.5)
  • NNPACK is enabled
  • CPU capability usage: AVX2
  • CUDA Runtime 10.1
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75
  • CuDNN 7.6.3
  • Magma 2.5.2
  • Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON,

TorchVision: 0.8.2+cu101
OpenCV: 4.6.0
MMEngine: 0.3.0
MMRotate: 1.0.0rc0+

Reproduces the problem - code sample

python tools/train.py configs/cfa/cfa-qbox_r50_fpn_1x_dota.py

Reproduces the problem - command or script

CUDA_LAUNCH_BLOCKING=1 python tools/train.py configs/cfa/cfa-qbox_r50_fpn_1x_dota.py

Reproduces the problem - error message

Result has been saved to /media/amax/partion2/lsy/work_dir/mmlabV2/mmrotate/cfa/cfa_qbox_r50_fpn_1x_dota1/modules_statistic_results.json
11/13 14:51:42 - mmengine - INFO - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
11/13 14:51:46 - mmengine - WARNING - Failed to search registry with scope "mmrotate" in the "optim_wrapper" registry tree. As a workaround, the current "optim_wrapper" registry in "mmengine" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmrotate" is a correct scope, or whether the registry is initialized.
fatal: Not a git repository (or any parent up to mount point /media/amax/partion2)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
11/13 14:51:47 - mmengine - INFO - load model from: torchvision://resnet50
11/13 14:51:47 - mmengine - INFO - torchvision loads checkpoint from path: torchvision://resnet50
11/13 14:51:48 - mmengine - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

11/13 14:51:48 - mmengine - INFO - Checkpoints will be saved to /media/amax/partion2/lsy/work_dir/mmlabV2/mmrotate/cfa/cfa_qbox_r50_fpn_1x_dota1.
/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/structures/bbox/quadri_boxes.py:146: UserWarning: The clip function does nothing in QuadriBoxes.
warnings.warn('The clip function does nothing in QuadriBoxes.')
/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/structures/bbox/quadri_boxes.py:146: UserWarning: The clip function does nothing in QuadriBoxes.
warnings.warn('The clip function does nothing in QuadriBoxes.')
Traceback (most recent call last):
  File "tools/train.py", line 122, in <module>
    main()
  File "tools/train.py", line 118, in main
    runner.train()
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1661, in train
    model = self.train_loop.run() # type: ignore
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 90, in run
    self.run_epoch()
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 106, in run_epoch
    self.run_iter(idx, data_batch)
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 122, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss') # type: ignore
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 320, in _run_forward
    results = self(**data, mode=mode)
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/detectors/base.py", line 92, in forward
    return self.loss(inputs, data_samples)
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/detectors/single_stage.py", line 78, in loss
    losses = self.bbox_head.loss(x, batch_data_samples)
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/dense_heads/base_dense_head.py", line 123, in loss
    losses = self.loss_by_feat(*loss_inputs)
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/dense_heads/cfa_head.py", line 186, in loss_by_feat
    pos_losses_list, = multi_apply(self.get_pos_loss, cls_scores,
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/utils/misc.py", line 218, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/dense_heads/cfa_head.py", line 379, in get_pos_loss
    loss_bbox = self.loss_bbox_refine(
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/losses/convex_giou_loss.py", line 111, in forward
    loss = self.loss_weight * convex_giou_loss(
  File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/losses/convex_giou_loss.py", line 37, in forward
    convex_gious, grad = convex_giou(pred, target)
  File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/ops/convex_iou.py", line 28, in convex_giou
    ext_module.convex_giou(pointsets, polygons, output)
RuntimeError: CUDA error: an illegal memory access was encountered
Exception raised from ConvexGIoUCUDAKernelLauncher at /tmp/mmcv/mmcv/ops/csrc/pytorch/cuda/convex_iou.cu:40 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f6faf2928b2 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: ConvexGIoUCUDAKernelLauncher(at::Tensor, at::Tensor, at::Tensor) + 0x1d6 (0x7f6f8ccb85c3 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #2: convex_giou_cuda(at::Tensor, at::Tensor, at::Tensor) + 0x67 (0x7f6f8ccc8417 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #3: auto Dispatch<DeviceRegistry<void (*)(at::Tensor, at::Tensor, at::Tensor), &(convex_giou_impl(at::Tensor, at::Tensor, at::Tensor))>, at::Tensor const&, at::Tensor const&, at::Tensor&>(DeviceRegistry<void (*)(at::Tensor, at::Tensor, at::Tensor), &(convex_giou_impl(at::Tensor, at::Tensor, at::Tensor))> const&, char const*, at::Tensor const&, at::Tensor const&, at::Tensor&) + 0x802 (0x7f6f8cbfe182 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #4: convex_giou(at::Tensor, at::Tensor, at::Tensor) + 0x67 (0x7f6f8cbfbdd7 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #5: + 0x330513 (0x7f6f8ce8a513 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #6: + 0x350190 (0x7f6f8ceaa190 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #7: + 0x34de4e (0x7f6f8cea7e4e in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)

frame #16: THPFunction_apply(_object*, _object*) + 0x93d (0x7f6ff9a732dd in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/lib/libtorch_python.so)

Additional information

I'm using the DOTA v1.0 dataset, and this error occurs during training.

@zytx121
Collaborator

zytx121 commented Nov 13, 2022

Please try installing mmcv with:

pip install mmcv==2.0.0rc2 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7/index.html

Do other models also raise errors? Or just this one?

@sltlls
Contributor Author

sltlls commented Nov 14, 2022

I've tried reinstalling mmcv with the command above, but the error still exists.
It seems that all models using convex_giou_loss hit this error (I've tested training cfa, rotated_reppoints and oriented_reppoints, and they all fail the same way). When I train roi_transformer, the error does not occur.

@Qian-CV

Qian-CV commented Nov 14, 2022

I also encountered this problem. Training starts normally, but the error occurs randomly later during training or validation.

@zytx121 zytx121 added bug Something isn't working dev-1.x labels Nov 15, 2022
@zytx121
Collaborator

zytx121 commented Nov 15, 2022

Hi @Lebron0126, please provide more information.

@zytx121
Collaborator

zytx121 commented Nov 15, 2022

It's weird. Cfa, rotated_reppoints and oriented_reppoints can be trained and tested successfully on our device.

@SlytherinGe

SlytherinGe commented Nov 19, 2022

I encountered the same problem using oriented_reppoints.
In an extra experiment, calling mmcv.ops.min_area_polygons raises the same CUDA error. I can't tell what the problem is: I carefully inspected the input data and saw nothing abnormal, and the error occurs randomly. I suspect a bug in the source code of min_area_polygons.

@yangxue0827
Collaborator

Several feasible solutions: #405

@LLsmile

LLsmile commented Dec 1, 2022

Same error, but for me it occurs in epoch 8. It seems this error can occur at any point during training.

@SlytherinGe

Update for this bug:
I found that when training with Adam, the feature maps produced by the forward pass unexpectedly contain extremely large values (1e32 or -1e32) in some iteration; before that iteration everything is normal. These extreme values crash min_area_polygons when training reppoints. When I switch the optimizer from Adam to SGD, the model trains without any problem.
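To catch such values before they reach the CUDA op, a cheap check is to scan the predicted coordinates for NaN/inf or implausibly large magnitudes. The sketch below is a hypothetical plain-Python helper: the name `find_bad_coords` and the default `limit=1e4` are assumptions (chosen because the images in this thread are 1024x1024). On a PyTorch tensor the same test is roughly `(~torch.isfinite(t)) | (t.abs() > limit)`.

```python
import math


def find_bad_coords(flat_coords, limit=1e4):
    """Hypothetical guard: return indices of coordinates that are NaN/inf
    or whose magnitude exceeds `limit` (e.g. the 1e32 values seen with Adam)."""
    return [
        i
        for i, c in enumerate(flat_coords)
        if not math.isfinite(c) or abs(c) > limit
    ]


# Two exploded values and one NaN among otherwise normal coordinates.
print(find_bad_coords([1.0, 1e32, -1e32, float("nan"), 512.0]))  # [1, 2, 3]
```

What to do when the list is non-empty (log, skip the iteration, clamp) is left open; the check only makes the failure visible and does not explain why Adam produced the values.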

@crisz94

crisz94 commented Feb 15, 2023

According to open-mmlab/mmcv#2407, adding small random noise to the input of mmcv.ops.min_area_polygons seems to be a feasible solution. As stated there, this bug may be caused by numerical instability in the min_area_polygons CUDA op.
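The noise workaround perturbs every coordinate by a tiny Gaussian offset so that exactly coincident or collinear points almost surely stop being so. A minimal plain-Python sketch, assuming flat `(x1, y1, x2, y2, ...)` coordinates; the helper name `jitter_points` is hypothetical, and the `std=1e-3` default is the value reported later in this thread. With PyTorch tensors the equivalent is `pts + torch.randn_like(pts) * std`.

```python
import random


def jitter_points(flat_coords, std=1e-3, rng=None):
    """Hypothetical helper: return a copy of flat (x1, y1, x2, y2, ...) coordinates
    with independent Gaussian noise N(0, std**2) added to every value."""
    rng = rng or random.Random()
    return [c + rng.gauss(0.0, std) for c in flat_coords]


# Nine identical corner points are no longer identical after jittering.
pts = [1024.0] * 18
noisy = jitter_points(pts, rng=random.Random(0))
print(len(set(noisy)) > 1)  # True
```

Note that a std of 1e-3 is still larger than the float32 spacing near 1024 (about 1.2e-4), so the perturbation survives if the values are later cast to float32.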

Alternatively, cv2.minAreaRect can fix this bug, since it returns the true minimum-area rectangle without adding any noise to the input. However, training may slow down because cv2.minAreaRect runs on the CPU.

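    import cv2
    import numpy as np
    import torch

    def min_area_polygons(self, point_sets):
        """Replacement for mmcv.ops.min_area_polygons that runs on the CPU via OpenCV."""
        polygons = []
        for point_set in point_sets:
            # (2K,) -> (K, 2) float32 array; detach() so tensors that require grad can go to numpy
            pts = point_set.detach().cpu().numpy().reshape(-1, 2).astype(np.float32)
            # minAreaRect -> ((cx, cy), (w, h), angle); boxPoints -> 4 corners of shape (4, 2)
            polygons.append(cv2.boxPoints(cv2.minAreaRect(pts)).reshape(-1))
        return torch.tensor(
            np.vstack(polygons), dtype=point_sets.dtype, device=point_sets.device)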

FYI @yangxue0827: unfortunately, none of the solutions you provided worked in my case.

@yellowjs0304

yellowjs0304 commented Mar 29, 2023

@yangxue0827 @zytx121 What can I do about this issue? I get the same error.

Update: it may be related to this; the same line also exists in mmrotate_handler.py.

@sltlls
Contributor Author

sltlls commented Apr 29, 2023

Update for this bug:
After several tests, it seems that the mmcv ops min_area_polygons and convex_giou both have numerical instability problems for specific predictions. Here are the details.

For the min_area_polygons function, I tested cases where the input convex is extremely small, for example all convex points set to 0, or points drawn at random with mean 0 and variance 1e-30. It works normally in these cases, so I tested the function further and found convex data that reproduces the problem:

pred_pt = [1.0350e-02, 1.0234e+03, 1.0242e-02, 1.0259e-02, 1.0291e-02, 1.0244e-02, 1.0240e+03, 1.0249e-02, 1.0240e+03, 7.1550e-01, 1.0240e+03, 1.0240e+03, 1.0240e+03, 1.0240e+03, 7.6025e-02, 1.0240e+03, 1.0211e+03, 6.6669e+02]

The points of this convex lie in the corners and along the boundary of the image (all images in my dataset are 1024*1024). I don't know the exact cause, but convex points distributed in the corners or along the border of the image seem to make this function crash.
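To check the observation that these points sit on the image corners and border, the flat list can be paired into (x, y) points and tested against the 1024x1024 bounds. The helper `border_points` and its tolerance `tol=1.0` (in pixels) are hypothetical choices for this sketch.

```python
def border_points(flat_coords, width=1024.0, height=1024.0, tol=1.0):
    """Hypothetical diagnostic: return the (x, y) points lying within `tol`
    pixels of the image border."""
    pts = list(zip(flat_coords[0::2], flat_coords[1::2]))
    return [
        (x, y)
        for x, y in pts
        if min(x, width - x) <= tol or min(y, height - y) <= tol
    ]


pred_pt = [1.0350e-02, 1.0234e+03, 1.0242e-02, 1.0259e-02, 1.0291e-02,
           1.0244e-02, 1.0240e+03, 1.0249e-02, 1.0240e+03, 7.1550e-01,
           1.0240e+03, 1.0240e+03, 1.0240e+03, 1.0240e+03, 7.6025e-02,
           1.0240e+03, 1.0211e+03, 6.6669e+02]
print(len(border_points(pred_pt)), "of", len(pred_pt) // 2)  # 8 of 9
```

Only the last point, (1021.1, 666.69), is more than a pixel away from every edge.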

For the convex_giou function, I also tested extremely small convex inputs (convex area extremely small). In that situation convex_iou works normally, but convex_giou returns a negative IoU with zero gradient. However, this does not raise an error and training keeps running. After further testing, I found that a specific predicted convex combined with a specific target quadrangle can crash convex_giou. Here is an example that reproduces the convex_giou error:

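import torch
from mmcv.ops import convex_giou

device = 'cuda'
# 9 predicted points (18 values), all at the bottom-right corner (1024, 1024)
pred_pt = torch.tensor([1024 for _ in range(18)],
                       dtype=torch.float32,
                       device=device).unsqueeze(0)
target = torch.tensor([2.4000e+01, 5.6400e+02, 1.0000e-05, 5.6700e+02, 1.0000e-05, 5.1700e+02, 1.5000e+01, 5.1400e+02],
                       dtype=torch.float32,
                       device=device).unsqueeze(0)
# RuntimeError: CUDA error: an illegal memory access was encountered
convex_gious, grad = convex_giou(pred_pt, target)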

In this example, the target quadrangle touches the image boundary, and the error only occurs when all convex points are at the bottom-right corner of the image; even a convex whose points are all 0 does not trigger it. I don't know why this happens, and it's very weird. In this case, simply adding small random noise with std=1e-3 to every convex point avoids the error.
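On the negative value mentioned above: a GIoU-style measure is expected to go negative whenever the two shapes do not overlap. Assuming convex_giou follows the standard GIoU definition, with C the convex hull enclosing both the predicted point set A and the target polygon B:

```latex
\mathrm{GIoU}(A, B) = \mathrm{IoU}(A, B) - \frac{|C| - |A \cup B|}{|C|},
\qquad -1 < \mathrm{GIoU}(A, B) \le 1
```

If A and B are disjoint, then IoU(A, B) = 0 while |C| > |A ∪ B|, so the GIoU is strictly negative; a negative value by itself does not indicate a bug.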

@DapengFeng

@sltlls In my case, it works for the min_area_polygons test.

>>> import torch
>>> from mmcv.ops import min_area_polygons
>>> pred_pt = torch.tensor([[1.0350e-02, 1.0234e+03, 1.0242e-02, 1.0259e-02, 1.0291e-02, 1.0244e-02, 1.0240e+03, 1.0249e-02, 1.0240e+03, 7.1550e-01, 1.0240e+03, 1.0240e+03, 1.0240e+03, 1.0240e+03, 7.6025e-02, 1.0240e+03, 1.0211e+03, 6.6669e+02]]).cuda()
>>> min_area_polygons(pred_pt)
tensor([[1.0240e+03, 1.0249e-02, 1.0242e-02, 1.0244e-02, 1.0237e-02, 1.0240e+03,
         1.0240e+03, 1.0240e+03]], device='cuda:0')

And for the convex_giou test, I found the same error.

>>> import torch
>>> from mmcv.ops import convex_giou
>>> device='cuda'
>>> pred_pt = torch.tensor([1024 for _ in range(18)],
...                        dtype=torch.float32,
...                        device=device).unsqueeze(0)
>>> target = torch.tensor([2.4000e+01, 5.6400e+02, 1.0000e-05, 5.6700e+02, 1.0000e-05, 5.1700e+02, 1.5000e+01, 5.1400e+02],
...                        dtype=torch.float32,
...                        device=device).unsqueeze(0)
>>> convex_gious, grad = convex_giou(pred_pt, target)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dapeng/openmmlab/mmcv/mmcv/ops/convex_iou.py", line 28, in convex_giou
    ext_module.convex_giou(pointsets, polygons, output)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@DapengFeng

The error occurs when judging the collinearity of the points. For example, as shown in the figure, sign = cross(point 6, point 7, point 1) and sign = cross(point 7, point 6, point 1) are both small negative numbers, even though swapping the first two arguments should flip the sign in exact arithmetic. This inconsistency causes an infinite loop in the following code snippet:

  while (k_index != max_index) {
    p_k = p_max;
    k_index = max_index;
    for (int i = 1; i < n_poly; i++) {
      sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
      if ((sign < 0) || (sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
                                        dis(in_poly[Stack[top2]], p_k))) {
        p_k = in_poly[i];
        k_index = i;
      }
    }
    top2++;
    Stack[top2] = k_index;
  }

open-mmlab/mmcv#2786

[[67.189827, 445.142517,
80.672424, 726.558838,
154.913605, 711.699219,
94.742928, 574.717041,
116.207588, 755.282104,
99.939827, 489.668854,
14.599480, 462.490112,
35.282326, 545.151672,
87.332382, 536.160767]]
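To make the failure mode concrete, here is a pure-Python sketch of the same gift-wrapping (Jarvis march) idea. It is a restatement, not a transliteration of the mmcv kernel, and the `orient` parameter and the step cap are hypothetical additions. Because a convex hull has at most n vertices, a consistent orientation test closes the loop within n steps; an inconsistent one, like the two slightly negative cross products described above, can keep the walk bouncing between points. In the CUDA code each extra step also executes `top2++`, which presumably runs past the end of the fixed-size `Stack` array and surfaces as an illegal memory access.

```python
def cross(o, a, b):
    """z-component of (a - o) x (b - o); positive if b is left of the ray o -> a."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def dist2(a, b):
    """Squared Euclidean distance."""
    return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2


def gift_wrap(points, orient=cross):
    """Jarvis march returning hull vertices in counter-clockwise order.

    Hypothetical safeguard: raise RuntimeError instead of looping forever
    when `orient` gives inconsistent answers."""
    n = len(points)
    if n < 3:
        return list(points)
    start = min(range(n), key=lambda i: points[i])  # lowest x, then lowest y
    hull, current = [start], start
    # A hull has at most n vertices, so a consistent predicate closes in <= n steps.
    for _ in range(n):
        candidate = (current + 1) % n
        for i in range(n):
            if i == current:
                continue
            s = orient(points[current], points[candidate], points[i])
            # Point i lies right of current -> candidate, or is a farther
            # collinear point, so it is a better next hull vertex.
            if s < 0 or (s == 0 and dist2(points[current], points[i])
                         > dist2(points[current], points[candidate])):
                candidate = i
        if candidate == start:
            return [points[k] for k in hull]
        hull.append(candidate)
        current = candidate
    raise RuntimeError("hull did not close within n steps: inconsistent orientation tests")


square = [(0, 0), (1, 0), (1, 1), (0, 1), (0.5, 0.5)]
print(gift_wrap(square))  # [(0, 0), (1, 0), (1, 1), (0, 1)]

# An exaggerated stand-in for inconsistent rounding: every triple reports
# "negative", and the walk never returns to the start point.
try:
    gift_wrap(square, orient=lambda o, a, b: -1.0)
except RuntimeError as err:
    print(err)
```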

@DapengFeng

>>> import torch
>>> from mmcv.ops import convex_giou
>>> device='cuda'
>>> pred_pt = torch.tensor([1024 for _ in range(18)],
...                        dtype=torch.float32,
...                        device=device).unsqueeze(0)
>>> target = torch.tensor([2.4000e+01, 5.6400e+02, 1.0000e-05, 5.6700e+02, 1.0000e-05, 5.1700e+02, 1.5000e+01, 5.1400e+02],
...                        dtype=torch.float32,
...                        device=device).unsqueeze(0)
>>> convex_gious, grad = convex_giou(pred_pt, target)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dapeng/openmmlab/mmcv/mmcv/ops/convex_iou.py", line 28, in convex_giou
    ext_module.convex_giou(pointsets, polygons, output)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

In this case, the convex hull of pred_pt consists of two identical points. When computing the area of the convex hull of the union of pred_pt and target, the code falls into an infinite loop in the same code snippet.

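Because this crash comes from a predicted point set that collapses to identical points, one hypothetical safeguard is to detect degenerate (coincident or collinear) point sets before they reach convex_giou and then jitter, skip, or clamp them. A plain-Python sketch; the helper name `is_degenerate` and the `eps` threshold are assumptions.

```python
import math


def is_degenerate(flat_coords, eps=1e-6):
    """Hypothetical check: True if the (x, y) points in flat_coords all coincide,
    or all lie within `eps` (perpendicular distance) of a single line."""
    pts = list(zip(flat_coords[0::2], flat_coords[1::2]))
    o = pts[0]
    # The point farthest from o fixes a direction; if even it is within eps
    # of o, every point coincides with o.
    far = max(pts, key=lambda p: (p[0] - o[0]) ** 2 + (p[1] - o[1]) ** 2)
    dx, dy = far[0] - o[0], far[1] - o[1]
    length = math.hypot(dx, dy)
    if length <= eps:
        return True
    # Largest perpendicular distance of any point from the line o -> far.
    return max(abs(dx * (p[1] - o[1]) - dy * (p[0] - o[0])) / length
               for p in pts) <= eps


print(is_degenerate([1024.0] * 18))  # True: all nine points coincide
print(is_degenerate([24.0, 564.0, 1e-5, 567.0, 1e-5, 517.0, 15.0, 514.0]))  # False
```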

@crisz94

crisz94 commented May 5, 2023

Fun fact: when the dtype of pred_pt and target is set to torch.float64, the convex_giou test works. Maybe this is another feasible solution? @DapengFeng
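One way to see why float64 helps: float32 carries a 24-bit significand, so near x = 1024 adjacent representable values are 2^-13 (about 1.2e-4) apart, and the intermediate products in a cross product of coordinates around 1e3 can reach about 1e6, where the float32 spacing is 1/16. Rounding errors of that size can flip the sign of a nearly-zero cross product; float64 shrinks them by a factor of 2^29 but does not remove them. A standard-library sketch (helper names are hypothetical) that rounds through float32 and computes the spacing:

```python
import math
import struct


def to_float32(x):
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_spacing(x):
    """Gap between adjacent float32 values around x (normal, non-zero x only)."""
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 24)


print(float32_spacing(1024.0))  # 0.0001220703125 (2**-13)
print(float32_spacing(1e6))  # 0.0625
# A 1e-5 offset survives in float64 but vanishes in float32 near 1024.
print(1024.0 + 1e-5 == 1024.0, to_float32(1024.0 + 1e-5) == 1024.0)  # False True
```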

@DapengFeng

DapengFeng commented May 5, 2023

Changing the data type relieves the numerical instability but does not solve the problem. Pre-sorting the points guarantees that each newly added point always lies outside the convex hull formed up to then. An interesting website about convex hulls.
Figure_1
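Andrew's monotone chain is one well-known hull algorithm built on pre-sorting; whether the change proposed in open-mmlab/mmcv#2786 uses exactly this variant is an assumption. Points are sorted lexicographically, then the lower and upper chains are built with a stack, popping any point that does not make a strict left turn. Each point is pushed once per chain and popped at most once, so the loops terminate even when rounding makes an individual orientation test wrong; the worst case is a slightly inaccurate hull rather than a runaway index. A plain-Python sketch:

```python
def cross(o, a, b):
    """z-component of (a - o) x (b - o); positive for a counter-clockwise turn."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def monotone_chain(points):
    """Andrew's monotone chain: convex hull in counter-clockwise order,
    starting from the lexicographically smallest point; collinear
    boundary points are dropped."""
    pts = sorted(set(points))  # the pre-sort: by x, then by y
    if len(pts) <= 2:
        return pts
    lower, upper = [], []
    for p in pts:
        # Pop while the last two points and p do not make a strict left turn.
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # The last point of each chain is the first point of the other.
    return lower[:-1] + upper[:-1]


pts = [(0.0, 0.0), (0.5, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0), (0.5, 0.5)]
print(monotone_chain(pts))  # [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(monotone_chain([(1024.0, 1024.0)] * 9))  # [(1024.0, 1024.0)]
```

Deduplicating with `set` also means a fully collapsed prediction, like the all-1024 pred_pt above, yields a single point instead of feeding repeated points into the hull loop.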
