Running pytorch3d with theseus optimization #1636

Closed
neanea04 opened this issue Sep 14, 2023 · 5 comments

Comments

@neanea04

neanea04 commented Sep 14, 2023

🐛 Bugs / Unexpected behaviors

I am trying to integrate the Theseus optimization library with PyTorch3D in my application. When I pass variables from Theseus to PyTorch3D for rendering, it throws the following error:

RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.

I am using a silhouette_renderer to render an OBJ model:
self.silhouette_renderer(meshes_world=self.mesh, R=R, T=T)
R and T come from Theseus during optimization.

GradTrackingTensor(lvl=2, value=
    BatchedTensor(lvl=1, bdim=0, value=
        tensor([[[[ 0.9541,  0.2773, -0.1129],
                  [-0.1744,  0.8211,  0.5435],
                  [ 0.2434, -0.4988,  0.8318]]]], device='cuda:0', dtype=torch.float64)
    )
)
GradTrackingTensor(lvl=2, value=
    BatchedTensor(lvl=1, bdim=0, value=
        tensor([[[-0.0891,  0.1401,  0.4739]]], device='cuda:0', dtype=torch.float64)
    )
)

These are the tensors Theseus passes during optimization. When I pass ordinary tensors (i.e., outside of optimization), the renderer works perfectly.

Traceback (most recent call last):
  File "main.py", line 24, in <module>
    main()
  File "main.py", line 21, in main
    optimizer.optimization()   
  File "/home/nahar/6D-Pose-Estimation/learning-tools/multi-camera-tools/optimizer.py", line 350, in optimization
    print(cost_function.jacobians())
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/theseus/core/cost_function.py", line 355, in jacobians
    jacobians_full = self._compute_autograd_jacobian_vmap(
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
    return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 434, in wrapped
    return _flat_vmap(
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 619, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 489, in wrapper_fn
    vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py", line 291, in _vjp_with_argnums
    primals_out = func(*primals)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/theseus/core/cost_function.py", line 314, in jac_fn
    return self._err_fn(optim_vars=tmp_optim_vars, aux_vars=tmp_aux_vars)[0]
  File "/home/nahar/6D-Pose-Estimation/learning-tools/multi-camera-tools/optimizer.py", line 301, in img_sim_error
    img1 = self.dr.render(c2.rotation().tensor, c2.translation().tensor)
  File "/home/nahar/6D-Pose-Estimation/learning-tools/multi-camera-tools/diff_render.py", line 106, in render
    silhouette = self.silhouette_renderer(meshes_world=self.mesh, R=R, T=T)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/renderer/mesh/renderer.py", line 61, in forward
    fragments = self.rasterizer(meshes_world, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterizer.py", line 224, in forward
    meshes_proj = self.transform(meshes_world, **kwargs)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/renderer/mesh/rasterizer.py", line 198, in transform
    verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/renderer/cameras.py", line 206, in get_world_to_view_transform
    world_to_view_transform = get_world_to_view_transform(R=R, T=T)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/renderer/cameras.py", line 1588, in get_world_to_view_transform
    T_ = Translate(T, device=T.device)
  File "/home/nahar/miniconda3/envs/theseus/lib/python3.8/site-packages/pytorch3d/transforms/transform3d.py", line 553, in __init__
    mat[:, 3, :3] = xyz
RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.
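The last frame shows where it breaks: `mat[:, 3, :3] = xyz` writes the vmapped `T` into a freshly created, non-vmapped `mat`, which vmap cannot do in place. The sketch below reproduces this with a simplified stand-in for that constructor and contrasts it with an out-of-place construction of the same matrix (translation in the last row, as the indexing `mat[:, 3, :3]` implies) that survives `vmap(jacrev(...))`, the same composition Theseus uses in `_compute_autograd_jacobian_vmap`. It assumes PyTorch 2.x with `torch.func`; `translate_in_place` and `translate_out_of_place` are hypothetical helpers, not PyTorch3D API, and the thread does not confirm that patching PyTorch3D this way is enough to make the whole renderer vmap-safe.

```python
import torch
from torch.func import jacrev, vmap


def translate_in_place(xyz):
    # Hypothetical stand-in for the failing pattern: build an identity batch,
    # then write the translation into row 3 in place.
    n = xyz.shape[0]
    mat = torch.eye(4, dtype=xyz.dtype, device=xyz.device).view(1, 4, 4).repeat(n, 1, 1)
    mat[:, 3, :3] = xyz  # fails under vmap: `mat` is not vmapped, `xyz` is
    return mat


def translate_out_of_place(xyz):
    # Hypothetical: the same (n, 4, 4) matrix built only from out-of-place ops.
    n = xyz.shape[0]
    eye = torch.eye(4, dtype=xyz.dtype, device=xyz.device).expand(n, 4, 4)
    # [x, y, z] -> [x, y, z, 0], shaped (n, 1, 4) so it becomes row 3.
    last_row = torch.cat([xyz, torch.zeros_like(xyz[:, :1])], dim=-1).unsqueeze(1)
    upper = torch.zeros(n, 3, 4, dtype=xyz.dtype, device=xyz.device)
    return eye + torch.cat([upper, last_row], dim=1)


# Made-up batch: 4 vmapped samples, each a (1, 3) translation like T above.
T_batch = torch.randn(4, 1, 3, dtype=torch.float64)

try:
    vmap(jacrev(translate_in_place))(T_batch)
    in_place_failed = False
except RuntimeError:
    in_place_failed = True

mats = vmap(translate_out_of_place)(T_batch)         # shape (4, 1, 4, 4)
jac = vmap(jacrev(translate_out_of_place))(T_batch)  # shape (4, 1, 4, 4, 1, 3)
print(in_place_failed, tuple(mats.shape), tuple(jac.shape))
```

Outside of vmap both helpers produce identical matrices; only the in-place one trips the vmap check.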
@bottler
Contributor

bottler commented Sep 18, 2023

I'm afraid many PyTorch3D ops are not written to be vmap-compatible. In #1533 you can see an example of how you might wrap a function so that the vmap dimension is folded into the batch dimension, which makes it work. Can you do something similar here?
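One way to read the "fold the vmap dimension into the batch dimension" suggestion, sketched under assumptions rather than copied from #1533: wrap the non-vmap-safe op in a `torch.autograd.Function` that defines a `vmap` staticmethod (PyTorch 2.x). That rule moves the vmapped dimension to the front, merges it with the op's ordinary batch dimension, calls the op once on plain tensors, and reshapes back. `TranslateMatrix`, `translate_in_place` and `translate` are hypothetical names; this toy only wraps a translation matrix, whose backward is a simple slice, whereas wrapping the full renderer would need a backward for the whole rendering function.

```python
import torch
from torch.func import jacrev, vmap


def translate_in_place(xyz):
    # Hypothetical non-vmap-safe op (in-place write into a fresh tensor);
    # it works fine on ordinary tensors.
    n = xyz.shape[0]
    mat = torch.eye(4, dtype=xyz.dtype, device=xyz.device).view(1, 4, 4).repeat(n, 1, 1)
    mat[:, 3, :3] = xyz
    return mat


class TranslateMatrix(torch.autograd.Function):
    """Hypothetical wrapper giving `translate_in_place` a custom vmap rule."""

    @staticmethod
    def forward(xyz):
        return translate_in_place(xyz)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for backward

    @staticmethod
    def backward(ctx, grad_mat):
        # mat[:, 3, :3] == xyz, so the gradient is just a slice (vmap-safe).
        return grad_mat[:, 3, :3]

    @staticmethod
    def vmap(info, in_dims, xyz):
        (bdim,) = in_dims
        if bdim is None:
            return TranslateMatrix.apply(xyz), None
        xyz = xyz.movedim(bdim, 0)  # (B, n, 3), a plain tensor at this level
        b, n = xyz.shape[0], xyz.shape[1]
        # Fold the vmap dim into the ordinary batch dim and call the op once.
        mat = TranslateMatrix.apply(xyz.reshape(b * n, 3))
        return mat.reshape(b, n, 4, 4), 0


def translate(xyz):
    return TranslateMatrix.apply(xyz)


T_batch = torch.randn(4, 1, 3, dtype=torch.float64)  # made-up vmapped batch
mats = vmap(translate)(T_batch)         # shape (4, 1, 4, 4)
jac = vmap(jacrev(translate))(T_batch)  # shape (4, 1, 4, 4, 1, 3)
```

The op itself is untouched; only the wrapper decides how batched inputs reach it.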

@neanea04
Author

neanea04 commented Sep 29, 2023

My input R has shape torch.Size([1, 3, 3]), so I don't think I have a batch dimension. I think some in-place operations are happening here, and that may be the issue.
I can see in the TODO here that you are planning to make the function take a 4x4 matrix as input. If that's the case, maybe this won't be an issue.
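Inside `vmap`, a tensor reports only its per-sample shape, so R showing `torch.Size([1, 3, 3])` does not rule out a batch dimension: the printout earlier in the thread shows `BatchedTensor(lvl=1, bdim=0, ...)`, and the traceback shows Theseus calling `vmap(jacrev(jac_fn, argnums=0))`. A minimal sketch with a made-up batch of 5 (`per_sample` and `R_batch` are illustrative names):

```python
import torch
from torch.func import vmap

seen_shapes = []


def per_sample(R):
    # Inside vmap, R exposes only the per-sample (logical) shape.
    seen_shapes.append(tuple(R.shape))
    return R.sum()


R_batch = torch.randn(5, 1, 3, 3)  # made-up: 5 vmapped samples of a (1, 3, 3) rotation
out = vmap(per_sample)(R_batch)
print(seen_shapes, tuple(out.shape))  # prints: [(1, 3, 3)] (5,)
```

The function body runs once and never sees the leading dimension of 5, which is why code inside the renderer cannot "see" the vmapped batch.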

@bottler
Contributor

bottler commented Sep 29, 2023

This might be quite hard to solve. What is the silhouette_renderer function? Perhaps it could be rewritten specifically for vmap compatibility.

@neanea04
Author

neanea04 commented Oct 2, 2023

My code is a slightly modified version of this example:

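import numpy as np
import torch
from pytorch3d.renderer import (
    BlendParams,
    FoVPerspectiveCameras,
    HardPhongShader,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftSilhouetteShader,
)

# device was not shown in the snippet; cuda:0 is assumed from the tensors above
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

cameras = FoVPerspectiveCameras(device=device)

# Soft silhouette renderer: blurred edges, many faces blended per pixel
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)

silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)

# Phong renderer: hard rasterization, one face per pixel
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
)
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)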

These are the definitions of the renderers. Sorry, I am new to vmap and couldn't understand what the error message means.

@bottler
Contributor

bottler commented Oct 2, 2023

What is the whole function you are passing to theseus?

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants