Skip to content

Commit

Permalink
Update Rasterizer and add end2end fisheye integration test
Browse files Browse the repository at this point in the history
Summary:
1) Update rasterizer/point rasterizer to accommodate fisheyecamera. Specifically, transform_points is in placement of explicit transform compositions.

2) In rasterizer unittests, update corresponding tests for rasterizer and point_rasterizer. Address comments to test fisheye against perspective camera when distortions are turned off.

3) Address comments to add end2end test for fisheyecameras. In test_render_meshes, fisheyecameras are added to camera enuerations whenever possible.

4) Test renderings with fisheyecameras of different params on cow mesh.

5) Use compositions for linear cameras whenever possible.

Reviewed By: kjchalup

Differential Revision: D38932736

fbshipit-source-id: 5b7074fc001f2390f4cf43c7267a8b37fd987547
  • Loading branch information
davidsonic authored and facebook-github-bot committed Aug 31, 2022
1 parent b0515e1 commit d19e624
Show file tree
Hide file tree
Showing 63 changed files with 566 additions and 76 deletions.
24 changes: 22 additions & 2 deletions pytorch3d/renderer/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import math
import warnings
from typing import List, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -91,7 +91,7 @@ class CamerasBase(TensorProperties):
# When joining objects into a batch, they will have to agree.
_SHARED_FIELDS: Tuple[str, ...] = ()

def get_projection_transform(self):
def get_projection_transform(self, **kwargs):
"""
Calculate the projective transformation matrix.
Expand Down Expand Up @@ -1841,3 +1841,23 @@ def get_screen_to_ndc_transform(
image_size=image_size,
).inverse()
return transform


def try_get_projection_transform(cameras, kwargs) -> Optional[Callable]:
"""
Try block to get projection transform.
Args:
cameras instance, can be linear cameras or nonliear cameras
Returns:
If the camera implemented projection_transform, return the
projection transform; Otherwise, return None
"""

transform = None
try:
transform = cameras.get_projection_transform(**kwargs)
except NotImplementedError:
pass
return transform
18 changes: 13 additions & 5 deletions pytorch3d/renderer/mesh/rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform

from .rasterize_meshes import rasterize_meshes

Expand Down Expand Up @@ -197,12 +198,19 @@ def transform(self, meshes_world, **kwargs) -> torch.Tensor:
verts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
verts_world, eps=eps
)
# view to NDC transform
# Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
verts_proj = cameras.transform_points(verts_world, eps=eps)
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose(
to_ndc_transform
)
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
projection_transform = try_get_projection_transform(cameras, kwargs)
if projection_transform is not None:
projection_transform = projection_transform.compose(to_ndc_transform)
verts_ndc = projection_transform.transform_points(verts_view, eps=eps)
else:
# Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
verts_proj = cameras.transform_points(verts_world, eps=eps)
verts_ndc = to_ndc_transform.transform_points(verts_proj, eps=eps)

verts_ndc[..., 2] = verts_view[..., 2]
meshes_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)
Expand Down
15 changes: 10 additions & 5 deletions pytorch3d/renderer/points/rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
import torch.nn as nn
from pytorch3d.renderer.cameras import try_get_projection_transform
from pytorch3d.structures import Pointclouds

from .rasterize_points import rasterize_points
Expand Down Expand Up @@ -103,12 +104,16 @@ def transform(self, point_clouds, **kwargs) -> Pointclouds:
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world, eps=eps
)
# view to NDC transform
to_ndc_transform = cameras.get_ndc_camera_transform(**kwargs)
projection_transform = cameras.get_projection_transform(**kwargs).compose(
to_ndc_transform
)
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
projection_transform = try_get_projection_transform(cameras, kwargs)
if projection_transform is not None:
projection_transform = projection_transform.compose(to_ndc_transform)
pts_ndc = projection_transform.transform_points(pts_view, eps=eps)
else:
# Call transform_points instead of explicitly composing transforms to handle
# the case, where camera class does not have a projection matrix form.
pts_proj = cameras.transform_points(pts_world, eps=eps)
pts_ndc = to_ndc_transform.transform_points(pts_proj, eps=eps)

pts_ndc[..., 2] = pts_view[..., 2]
point_clouds = point_clouds.update_padded(pts_ndc)
Expand Down
Binary file added tests/data/test_FishEyeCameras_silhouette.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/test_render_fisheye_sphere_points.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
159 changes: 159 additions & 0 deletions tests/test_rasterizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PointsRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.fisheyecameras import FishEyeCameras
from pytorch3d.renderer.opengl.rasterizer_opengl import (
_check_cameras,
_check_raster_settings,
Expand Down Expand Up @@ -51,6 +52,9 @@ class TestMeshRasterizer(unittest.TestCase):
def test_simple_sphere(self):
self._simple_sphere(MeshRasterizer)

def test_simple_sphere_fisheye(self):
self._simple_sphere_fisheye_against_perspective(MeshRasterizer)

def test_simple_sphere_opengl(self):
self._simple_sphere(MeshRasterizerOpenGL)

Expand Down Expand Up @@ -155,6 +159,91 @@ def _simple_sphere(self, rasterizer_type):

self.assertTrue(torch.allclose(image, image_ref))

def _simple_sphere_fisheye_against_perspective(self, rasterizer_type):
device = torch.device("cuda:0")

# Init mesh
sphere_mesh = ico_sphere(5, device)

# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)

# Init Fisheye camera params
focal = torch.tensor([[1.7321]], dtype=torch.float32)
principal_point = torch.tensor([[0.0101, -0.0101]])
perspective_cameras = PerspectiveCameras(
R=R,
T=T,
focal_length=focal,
principal_point=principal_point,
device="cuda:0",
)
fisheye_cameras = FishEyeCameras(
device=device,
R=R,
T=T,
focal_length=focal,
principal_point=principal_point,
world_coordinates=True,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
)

# Init rasterizer
perspective_rasterizer = rasterizer_type(
cameras=perspective_cameras, raster_settings=raster_settings
)
fisheye_rasterizer = rasterizer_type(
cameras=fisheye_cameras, raster_settings=raster_settings
)

####################################################################################
# Test rasterizing a single mesh comparing fisheye camera against perspective camera
####################################################################################

perspective_fragments = perspective_rasterizer(sphere_mesh)
perspective_image = perspective_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
perspective_image[perspective_image >= 0] = 1.0
perspective_image[perspective_image < 0] = 0.0

if DEBUG:
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR
/ f"DEBUG_test_perspective_rasterized_sphere_{rasterizer_type.__name__}.png"
)

fisheye_fragments = fisheye_rasterizer(sphere_mesh)
fisheye_image = fisheye_fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
fisheye_image[fisheye_image >= 0] = 1.0
fisheye_image[fisheye_image < 0] = 0.0

if DEBUG:
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR
/ f"DEBUG_test_fisheye_rasterized_sphere_{rasterizer_type.__name__}.png"
)

self.assertTrue(torch.allclose(fisheye_image, perspective_image))

##################################
# 2. Test with a batch of meshes
##################################

batch_size = 10
sphere_meshes = sphere_mesh.extend(batch_size)
fragments = fisheye_rasterizer(sphere_meshes)
for i in range(batch_size):
image = fragments.pix_to_face[i, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, perspective_image))

def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
Expand Down Expand Up @@ -412,6 +501,76 @@ def test_simple_sphere(self):
image[image < 0] = 0.0
self.assertTrue(torch.allclose(image, image_ref[..., 0]))

def test_simple_sphere_fisheye_against_perspective(self):
device = torch.device("cuda:0")

# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(points=verts_padded)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
perspective_cameras = PerspectiveCameras(
R=R,
T=T,
device=device,
)
fisheye_cameras = FishEyeCameras(
device=device,
R=R,
T=T,
world_coordinates=True,
use_radial=False,
use_tangential=False,
use_thin_prism=False,
)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)

#################################
# 1. Test init without cameras.
##################################

# Initialize without passing in the cameras
rasterizer = PointsRasterizer()

# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
rasterizer(pointclouds)

########################################################################################
# 2. Test rasterizing a single pointcloud with fisheye camera agasint perspective camera
########################################################################################

perspective_fragments = rasterizer(
pointclouds, cameras=perspective_cameras, raster_settings=raster_settings
)
fisheye_fragments = rasterizer(
pointclouds, cameras=fisheye_cameras, raster_settings=raster_settings
)

# Convert idx to a binary mask
perspective_image = perspective_fragments.idx[0, ..., 0].squeeze().cpu()
perspective_image[perspective_image >= 0] = 1.0
perspective_image[perspective_image < 0] = 0.0

fisheye_image = fisheye_fragments.idx[0, ..., 0].squeeze().cpu()
fisheye_image[fisheye_image >= 0] = 1.0
fisheye_image[fisheye_image < 0] = 0.0

if DEBUG:
Image.fromarray((perspective_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_perspective_sphere_points.png"
)
Image.fromarray((fisheye_image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_fisheye_sphere_points.png"
)

self.assertTrue(torch.allclose(fisheye_image, perspective_image))

def test_simple_to(self):
# Check that to() works without a cameras object.
device = torch.device("cuda:0")
Expand Down
Loading

0 comments on commit d19e624

Please sign in to comment.