
CUDA Multinomial Bug in Mesh Sampling #82

Closed
weichengkuo opened this issue Feb 23, 2020 · 17 comments
Labels
bug Something isn't working question Further information is requested


@weichengkuo

weichengkuo commented Feb 23, 2020

I'm using the Mesh R-CNN evaluation pipeline, which depends on PyTorch3D. Everything works fine for the most part, and I was able to reproduce the numbers in the paper. However, every once in a while I run into the following error for meshes that look normal to me. I tried printing the mesh tensors and they seem reasonable. I noticed that line 62 of pytorch3d/ops/sample_points_from_meshes.py has a TODO to fix a multinomial bug. I wonder if that's exactly what I'm running into here.

Is there anything I can do to avoid this?

/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,1,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,2,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,0,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,3,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
Traceback (most recent call last):
  File "tools/train_net.py", line 292, in <module>
    args=(args,),
  File "/usr/local/google/home/weicheng/Documents/research/detectron2_repo/detectron2/engine/launch.py", line 52, in launch
    main_func(*args)
  File "tools/train_net.py", line 274, in main
    args.num_eval_images)
  File "tools/train_net.py", line 213, in eval_no_predict
    results_i = evaluation_on_dataset(model_preds, evaluator)
  File "/usr/local/google/home/weicheng/Documents/research/detectron2_repo/detectron2/evaluation/evaluator.py", line 185, in evaluation_on_dataset
    results = evaluator.evaluate()
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 170, in evaluate
    self._eval_predictions()
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 192, in _eval_predictions
    vis_output_name=self._vis_output_name,
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 443, in evaluate_for_pix3d
    shape_metrics = compare_meshes(meshes, gt_mesh, reduce=False)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 68, in compare_meshes
    pred_points, pred_normals = _sample_meshes(pred_meshes, num_samples_pred)
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 153, in _sample_meshes
    verts, normals = sample_points_from_meshes(meshes, num_samples, return_normals=True)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/pytorch3d/ops/sample_points_from_meshes.py", line 66, in sample_points_from_meshes
    print(meshes.valid)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/tensor.py", line 159, in __repr__
    return torch._tensor_str._str(self)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/_tensor_str.py", line 311, in _str
    tensor_str = _tensor_str(self, indent)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/_tensor_str.py", line 209, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/_tensor_str.py", line 83, in __init__
    value_str = '{}'.format(value)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/tensor.py", line 409, in __format__
    return self.item().__format__(format_spec)
RuntimeError: CUDA error: device-side assert triggered
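The assertion above is raised by the CUDA kernel behind `torch.multinomial`, which requires each row of sampling weights to have a positive total (`cumdist[size - 1] > 0`). `sample_points_from_meshes` appears to pick faces with probability proportional to face area, so a mesh whose areas are all zero, or contain NaN, would trip this assert. A minimal sketch, assuming you want to validate the weights yourself and fail early with a readable message; `checked_multinomial` is a hypothetical helper, not part of PyTorch or PyTorch3D:

```python
import torch


def checked_multinomial(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Validate per-row sampling weights, then draw indices with replacement.

    weights: (N, F) non-negative weights, e.g. face areas per mesh (assumption).
    """
    finite = torch.isfinite(weights)
    if not finite.all():
        bad_rows = (~finite).any(dim=1).nonzero().flatten().tolist()
        raise ValueError(f"non-finite weights in rows {bad_rows}")
    if (weights < 0).any():
        raise ValueError("negative weights")
    row_sums = weights.sum(dim=1)
    if (row_sums <= 0).any():
        bad_rows = (row_sums <= 0).nonzero().flatten().tolist()
        raise ValueError(f"rows with zero total weight: {bad_rows}")
    # Each row is sampled independently; output shape is (N, num_samples).
    return weights.multinomial(num_samples, replacement=True)


areas = torch.tensor([[0.5, 1.5, 0.0], [2.0, 0.0, 1.0]])
idxs = checked_multinomial(areas, 4)  # indices in [0, 3) for each row
```

Running the check on the CPU copy of the weights keeps the failure a normal Python exception instead of a device-side assert that poisons the CUDA context.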

@gkioxari
Contributor

gkioxari commented Feb 23, 2020

Hey @weichengkuo! The multinomial bug has been fixed (its source was an old PyTorch bug that is now fixed, and we also have a test for it in the test suite). The TODO should be removed. I think the source of your error is different; it actually seems related to issue #63. To help us reproduce this, could you provide the output of some print statements, as discussed in the other issue? This likely happens because of some invalid indices in the indexing of the meshes, but we need to find exactly where :)

@weichengkuo
Author

Hey @gkioxari! Thanks for your quick response, the clarification about the test, and the pointer to the relevant issue #63! It's great to know that the bug has been taken care of. Following the instructions in #63, I added print statements in meshrcnn/utils/metrics.py before the sample_points_from_meshes call. Attached is the change I made to the code. Here's the link to the full log: https://cl1p.net/debug_cuda
Code change

Let me know if there's anything else I can print to shed more light on this. Thanks again!!

@weichengkuo
Author

weichengkuo commented Feb 24, 2020

Oops, the link doesn't seem to work. Here's the full log:

Debug_log.txt

@nikhilaravi
Contributor

nikhilaravi commented Feb 24, 2020

@weichengkuo are you building pytorch3d from a local clone? If so, can you try adding print statements inside the sample_points_from_meshes function before the line sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)?

Alternatively, we can try to reproduce the issue ourselves. Can you save meshes.verts_padded() and meshes.faces_padded() to a file before the call to sample_points_from_meshes and share them here? (You can save these values at each iteration, but we only need the values from the iteration where the error occurs.) We can then load and run the inputs that trigger the error.
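A minimal sketch of that suggestion, assuming padded tensors shaped like the outputs of `verts_padded()` (N, V, 3) and `faces_padded()` (N, F, 3). `dump_batch`, `load_batch`, and the file names are hypothetical; because the same files are overwritten every iteration, they always hold the most recent batch that was saved:

```python
import os
import tempfile

import torch


def dump_batch(verts_padded: torch.Tensor, faces_padded: torch.Tensor, out_dir: str) -> None:
    # Move to CPU so the files can be loaded on a machine without a GPU.
    torch.save(verts_padded.detach().cpu(), os.path.join(out_dir, "pred_verts.pt"))
    torch.save(faces_padded.detach().cpu(), os.path.join(out_dir, "pred_faces.pt"))


def load_batch(out_dir: str):
    verts = torch.load(os.path.join(out_dir, "pred_verts.pt"))
    faces = torch.load(os.path.join(out_dir, "pred_faces.pt"))
    return verts, faces


out_dir = tempfile.mkdtemp()
verts = torch.rand(2, 4, 3)
# Second mesh has one real face; -1 rows are padding.
faces = torch.tensor([[[0, 1, 2], [0, 2, 3]], [[0, 1, 2], [-1, -1, -1]]])
dump_batch(verts, faces, out_dir)
loaded_verts, loaded_faces = load_batch(out_dir)
```

In the evaluation loop this would be called as `dump_batch(meshes.verts_padded(), meshes.faces_padded(), some_dir)` right before `sample_points_from_meshes`.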

@nikhilaravi nikhilaravi self-assigned this Feb 24, 2020
@nikhilaravi nikhilaravi added bug Something isn't working question Further information is requested labels Feb 24, 2020
@weichengkuo
Author

weichengkuo commented Feb 24, 2020

Hi @nikhilaravi, I'm using the pre-built pytorch3d installed through conda. I tried saving the vertices and faces using verts_padded() and faces_padded(), and it works for every iteration except the one that crashes. The error occurs in the self.isempty() function.

I tried printing the vertices or faces of the meshes before calling isempty, but both crash with a similar error as well. Any thoughts on what might have caused this?

Here's the stack trace:
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,2,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,1,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,3,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
/opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/ATen/native/cuda/MultinomialKernel.cu:87: int at::native::<unnamed>::binarySearchForMultinomial(scalar_t *, scalar_t *, int, scalar_t) [with scalar_t = float]: block: [21,0,0], thread: [0,0,0] Assertion `cumdist[size - 1] > static_cast<scalar_t>(0)` failed.
Traceback (most recent call last):
  File "tools/train_net.py", line 292, in <module>
    args=(args,),
  File "/usr/local/google/home/weicheng/Documents/research/detectron2_repo/detectron2/engine/launch.py", line 52, in launch
    main_func(*args)
  File "tools/train_net.py", line 274, in main
    args.num_eval_images)
  File "tools/train_net.py", line 213, in eval_no_predict
    results_i = evaluation_on_dataset(model_preds, evaluator)
  File "/usr/local/google/home/weicheng/Documents/research/detectron2_repo/detectron2/evaluation/evaluator.py", line 185, in evaluation_on_dataset
    results = evaluator.evaluate()
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 170, in evaluate
    self._eval_predictions()
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 192, in _eval_predictions
    vis_output_name=self._vis_output_name,
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/evaluation/pix3d_evaluation.py", line 443, in evaluate_for_pix3d
    shape_metrics = compare_meshes(meshes, gt_mesh, reduce=False)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 68, in compare_meshes
    pred_points, pred_normals = _sample_meshes(pred_meshes, num_samples_pred)
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 153, in _sample_meshes
    verts, normals = sample_points_from_meshes(meshes, num_samples, return_normals=True)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/pytorch3d/ops/sample_points_from_meshes.py", line 67, in sample_points_from_meshes
    mverts = meshes.verts_padded()
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/pytorch3d/structures/meshes.py", line 553, in verts_padded
    self._compute_padded()
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/pytorch3d/structures/meshes.py", line 866, in _compute_padded
    if self.isempty():
RuntimeError: CUDA error: device-side assert triggered

@nikhilaravi
Contributor

nikhilaravi commented Feb 24, 2020

@weichengkuo ok, thanks for the update. Instead of meshes.verts_padded(), can you try to save meshes._verts_list and meshes._faces_list? These are accessed in the lines preceding the call to self.isempty(), so you should be able to get their values, i.e.:

verts_list = self._verts_list
faces_list = self._faces_list

@weichengkuo
Author

weichengkuo commented Feb 24, 2020

@nikhilaravi, I tried that, but it turns out I could do the assignment but couldn't print or otherwise access the values of the Meshes structure's members. Here's a pdb session that shows this:

(Pdb) meshes.valid
*** RuntimeError: CUDA error: device-side assert triggered
(Pdb) meshes._verts_list
*** RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/THC/THCCachingHostAllocator.cpp:278
(Pdb) verts = meshes.verts_packed()
(Pdb) verts
*** RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1579027003190/work/aten/src/THC/THCGeneral.cpp:313
(Pdb) mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
(Pdb) mesh_to_face
*** RuntimeError: CUDA error: device-side assert triggered

@weichengkuo
Author

I managed to find a place where saving the meshes doesn't crash: inside the compare_meshes function. The pred_meshes are attached here:
pred_meshes.zip

The snippet for saving them is here:

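import torch

# Inside compare_meshes: save the padded predicted vertices and faces
# so the failing batch can be reproduced offline.
mverts = pred_meshes.verts_padded()
mfaces = pred_meshes.faces_padded()
torch.save(mverts, '/tmp/pred_verts.pt')
torch.save(mfaces, '/tmp/pred_faces.pt')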

The error log due to the sampling is here:
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 73, in compare_meshes
    pred_points, pred_normals = _sample_meshes(pred_meshes, num_samples_pred)
  File "/google/src/cloud/weicheng/movemesh/google3/third_party/py/cadretinanet/meshrcnn/meshrcnn/utils/metrics.py", line 165, in _sample_meshes
    verts, normals = sample_points_from_meshes(meshes, num_samples, return_normals=True)
  File "/usr/local/google/home/weicheng/Documents/research/miniconda3/envs/pytorch3d2/lib/python3.6/site-packages/pytorch3d/ops/sample_points_from_meshes.py", line 67, in sample_points_from_meshes
    sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
RuntimeError: copy_if failed to synchronize: device-side assert triggered
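The failing line converts face indices sampled per mesh into indices into the packed (concatenated) face tensor by adding each valid mesh's first packed face index. A minimal sketch of that offset arithmetic in plain PyTorch, assuming `faces_per_mesh` holds each mesh's face count and `valid` marks meshes with at least one face; `to_packed_face_idxs` is a hypothetical helper, not PyTorch3D's code. Since this line only does indexing, and CUDA kernels run asynchronously, an assert reported here is likely left over from an earlier kernel such as the multinomial call:

```python
import torch


def to_packed_face_idxs(local_idxs: torch.Tensor, faces_per_mesh: torch.Tensor,
                        valid: torch.Tensor) -> torch.Tensor:
    """Shift per-mesh face indices into the packed face index space.

    local_idxs: (num_valid, S) sampled face indices, local to each valid mesh.
    faces_per_mesh: (N,) number of faces in each mesh of the batch.
    valid: (N,) bool mask of meshes that have at least one face.
    """
    # Exclusive cumulative sum: position of each mesh's first face in the packed tensor.
    first_idx = torch.cumsum(faces_per_mesh, dim=0) - faces_per_mesh
    num_valid = int(valid.sum())
    return local_idxs + first_idx[valid].view(num_valid, 1)


faces_per_mesh = torch.tensor([2, 0, 3])  # mesh 1 is empty, so it is not valid
valid = faces_per_mesh > 0
local = torch.tensor([[0, 1], [2, 0]])    # two samples each for meshes 0 and 2
packed = to_packed_face_idxs(local, faces_per_mesh, valid)  # [[0, 1], [4, 2]]
```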

@weichengkuo
Author

weichengkuo commented Feb 25, 2020

Hey @nikhilaravi and @gkioxari, I prepared a minimal script that reproduces the bug by loading the tensors from the zip file:

```python
import torch
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes

# Padded tensors saved from compare_meshes (contents of pred_meshes.zip)
verts = torch.load('pred_verts.pt')
faces = torch.load('pred_faces.pt')
mesh = Meshes(verts=verts, faces=faces)
points, normals = sample_points_from_meshes(mesh, 10000, return_normals=True)
```

@nikhilaravi
Contributor

@weichengkuo great, thank you! We will debug this and get back to you. Feel free to keep this issue open until we resolve the cause of the error.

@weichengkuo
Author

Thanks so much!! We have an ECCV 2020 deadline coming up in a week so we'd really appreciate your help!

@nikhilaravi
Contributor

nikhilaravi commented Feb 25, 2020

@weichengkuo the issue is not with sample_points_from_meshes but with the batch of meshes itself. Several vertices are NaN, e.g. see this output:

(Pdb) mesh.verts_packed()[torch.isnan(mesh.verts_packed())]
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       device='cuda:0')
(Pdb) torch.isnan(mesh.verts_packed()).sum()
tensor(112, device='cuda:0')

We need to determine what upstream of the call to sample_points_from_meshes is causing the vertices to be set to nan.
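This also explains the original assertion: a NaN vertex gives every face that uses it a NaN area, so that mesh's sampling weights no longer have a positive total. A minimal sketch, assuming area is computed with the standard triangle formula (half the norm of the cross product of two edge vectors); PyTorch3D computes areas in its own `mesh_face_areas_normals` op, which this does not reproduce:

```python
import torch


def face_areas(verts: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """Triangle areas for a single mesh: verts (V, 3) float, faces (F, 3) long."""
    v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
    # Area = 0.5 * |(v1 - v0) x (v2 - v0)|
    return 0.5 * torch.cross(v1 - v0, v2 - v0, dim=1).norm(dim=1)


verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2], [1, 3, 2]])
areas = face_areas(verts, faces)  # both triangles have area 0.5

bad_verts = verts.clone()
bad_verts[3, 0] = float("nan")  # corrupt one coordinate of vertex 3
bad_areas = face_areas(bad_verts, faces)  # face 1 uses vertex 3, so its area is NaN
# NaN > 0 is False, which is why a check like `cumdist[size - 1] > 0` fails.
print(bool(bad_areas.sum() > 0))  # False
```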

Can you provide some more details about your pipeline?

@weichengkuo
Author

Thanks for the great find @nikhilaravi !! I'll look more carefully at my pipeline again to identify the cause. Will update you as soon as possible.

@gkioxari
Contributor

gkioxari commented Feb 25, 2020

@weichengkuo Are you getting these NaNs after retraining a Mesh R-CNN model with the meshrcnn codebase? I have never run into this error, so I wonder how the NaNs ended up there. If you are diverging from the training recipe, or if something is going wrong in the optimization, the predictions can be NaN and you can then run into this.

@weichengkuo
Author

weichengkuo commented Feb 25, 2020

Hi @gkioxari and @nikhilaravi, I've identified the cause: my model's prediction went wrong for one particular mesh. I don't think there's anything wrong with the Mesh R-CNN optimization here. I should have checked for NaN values when initializing the mesh structure :)
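One way to act on that lesson, as a minimal sketch: check the padded vertices of each batch element for non-finite values before building the `Meshes` structure, then drop (or log) the offending meshes. `finite_mesh_mask` is a hypothetical helper, and whether to drop, skip, or raise is a choice for your own evaluation code:

```python
import torch


def finite_mesh_mask(verts_padded: torch.Tensor) -> torch.Tensor:
    """Return an (N,) bool mask that is True for meshes whose vertices are all finite.

    verts_padded: (N, V, 3). Padding is assumed to be zeros, so it never trips the check.
    """
    return torch.isfinite(verts_padded).flatten(start_dim=1).all(dim=1)


verts = torch.zeros(3, 4, 3)
verts[1, 2, 0] = float("nan")  # mesh 1 has a NaN coordinate
verts[2, 0, 1] = float("inf")  # mesh 2 has an infinite coordinate
mask = finite_mesh_mask(verts)  # tensor([True, False, False])
good_verts = verts[mask]        # keep only meshes that are safe to sample from
```

With PyTorch3D, the surviving tensors could then be passed as `Meshes(verts=verts[mask], faces=faces[mask])`.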

But huge thanks for your prompt assistance! Very, very much appreciated!

@ChenFengYe

ChenFengYe commented Jun 12, 2020

@gkioxari Hi, I am running the deform_source_mesh_to_target_mesh tutorial and get this error:

RuntimeError: CUDA error: invalid device function

around line 41 of pytorch3d/ops/mesh_face_areas_normals.py. I just load the provided OBJ file, dolphin.obj, and run the script from loading to visualization; the error occurs at points = sample_points_from_meshes(mesh, 5000).

Another related thing is that loading shows a rather strange warning:

/home/pytorch3d/pytorch3d/io/obj_io.py:70: UserWarning: Faces have invalid indices
warnings.warn("Faces have invalid indices")

This is a bit strange: I checked trg_mesh after loading, and there are no NaN values and no indices that are negative or larger than max_face_id. Any advice?
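For reference, a minimal sketch of the kind of index check described above, assuming 0-based face indices into a vertex tensor of shape (V, 3). `invalid_face_rows` is a hypothetical helper and is not how PyTorch3D's OBJ loader performs its own validation:

```python
import torch


def invalid_face_rows(faces: torch.Tensor, num_verts: int) -> torch.Tensor:
    """Return the indices of faces that reference a vertex outside [0, num_verts)."""
    bad = (faces < 0) | (faces >= num_verts)
    return bad.any(dim=1).nonzero().flatten()


faces = torch.tensor([[0, 1, 2], [2, 3, 4], [-1, 0, 1]])
rows = invalid_face_rows(faces, num_verts=4)  # tensor([1, 2])
```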

4 participants