
Wrong rendering result of point cloud #373

Closed
BostonLobster opened this issue Sep 22, 2020 · 5 comments
Labels: how to (How to use PyTorch3D in my project)

Comments

@BostonLobster


🐛 Bugs / Unexpected behaviors

I'm trying to use the code from render_colored_point.ipynb to render a point cloud from ScanNet.
My procedure is as follows:

  1. Read the .ply point cloud file and convert it to a Pointclouds instance.
  2. Read the intrinsics and extrinsics.
  3. Create a PerspectiveCameras instance from the intrinsics and extrinsics.
  4. Render the point cloud.

But the rendering result is very different from the ground truth view image, as shown below:

[Figure: rendered result (render_perspective_ori) next to the ground-truth view (0000)]

The rendered image was downloaded from Jupyter and has been resized; its original size is 1296 × 1296. The GT image is 968 × 1296 (height × width).

So the camera is clearly in the wrong place: the correct pose looks at the two screens, but the rendered view looks at the room from above.

I know that ScanNet's extrinsics are in the OpenCV coordinate system, so I checked https://github.com/facebookresearch/pytorch3d/blob/master/docs/notes/cameras.md and https://github.com/vvvv/VL.OpenCV/wiki/Coordinate-system-conversions-between-OpenCV,-DirectX-and-vvvv.
I found that in OpenCV the x-axis points right and the y-axis points down, so I guessed that I only need to rotate the axes by π around the z-axis to align them with PyTorch3D. I tried negating the first two columns of R, but got an empty rendering result.
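The rotation described above can be checked numerically. A rotation by π about the z-axis is exactly diag(-1, -1, 1): it negates x and y and keeps z. The sketch below assumes that `R_cv`, `t_cv` form an OpenCV world-to-camera transform (x_cam = R_cv x_world + t_cv, column vectors) and that PyTorch3D applies X_cam = X_world R + T to row vectors. Under those assumptions the flip has to be applied to the *transposed* rotation and to the translation as well. The helper name `opencv_w2c_to_pytorch3d` and all numbers are hypothetical.

```python
import numpy as np

# Rotation by pi about the z-axis.
theta = np.pi
Rz_pi = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta),  np.cos(theta), 0.0],
    [0.0,            0.0,           1.0],
])
# Up to floating-point round-off this is diag(-1, -1, 1): x and y flip, z is kept.
FLIP_XY = np.diag([-1.0, -1.0, 1.0])
assert np.allclose(Rz_pi, FLIP_XY)


def opencv_w2c_to_pytorch3d(R_cv, t_cv):
    """Hypothetical helper: OpenCV world-to-camera (R_cv, t_cv) -> PyTorch3D-style (R, T).

    Assumes x_cam = R_cv @ x_world + t_cv (OpenCV, column vectors) and
    X_cam = X_world @ R + T (PyTorch3D, row vectors).
    """
    # Flip camera x/y after the OpenCV transform: x_p3d = F (R_cv x + t_cv).
    # In row-vector form that means R = (F R_cv)^T = R_cv^T F and T = F t_cv.
    R = R_cv.T @ FLIP_XY   # transpose, then negate the first two columns
    T = FLIP_XY @ t_cv     # the translation is flipped too
    return R, T


# Usage: compare both conventions on a single world point.
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R_cv = q * np.sign(np.linalg.det(q))  # force a proper rotation (det = +1)
t_cv = np.array([0.1, -0.2, 2.0])
x_world = np.array([0.3, 0.4, 0.5])

R, T = opencv_w2c_to_pytorch3d(R_cv, t_cv)
x_cam_cv = R_cv @ x_world + t_cv
x_cam_p3d = x_world @ R + T
print(x_cam_cv, x_cam_p3d)  # same depth, x and y negated
```

Newer PyTorch3D releases provide `pytorch3d.utils.cameras_from_opencv_projection`, which, to our understanding, performs the same transpose-and-flip together with the intrinsics conversion.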

Instructions To Reproduce the Issue:

My code is as follows:
import torch
import numpy as np
import matplotlib.pyplot as plt
import trimesh  # used to load the ScanNet .ply point cloud

# Data structures and functions for rendering
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    SfMPerspectiveCameras,
)

# Setup
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the point cloud (vertex positions and RGB colors) with trimesh
pc = trimesh.load('./scene0010_00/scene0010_00_vh_clean.ply')
verts = torch.Tensor(pc.vertices).to(device)
rgb = torch.Tensor(pc.visual.vertex_colors[:, :3] / 255.).to(device)
point_cloud = Pointclouds(points=[verts], features=[rgb])

# Convert each NumPy array to a float tensor with a leading batch dimension
to_tensor = lambda x: [torch.Tensor(i).unsqueeze(0) for i in x]

# Load camera parameters
pose_path = './scene0010_00/pose/0000.txt'
intrinsic_path = './scene0010_00/intrinsic/intrinsic_color.txt'
extrinsic = np.loadtxt(pose_path)

R, T = to_tensor([extrinsic[:3, :3], extrinsic[:3, -1]])
K = torch.Tensor(np.loadtxt(intrinsic_path)).unsqueeze(0)

image_size = 1296

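# Convert the pixel-space focal length and principal point to NDC
# for a square image of side image_size
f_screen = torch.stack([K[:, 0, 0], K[:, 1, 1]], dim=1)
p_screen = torch.stack([K[:, 0, 2], K[:, 1, 2]], dim=1)

f_ndc = f_screen * 2.0 / image_size
p_ndc = - (p_screen - image_size / 2.0) * 2.0 / image_size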

cameras = SfMPerspectiveCameras(focal_length=f_ndc, principal_point=p_ndc, R=R, T=T, device=device)

# Rasterization settings: render an image_size x image_size image, draw each point
# as a disk of the given radius (in NDC units), and keep up to points_per_pixel
# points per pixel for compositing. bin_size controls coarse-to-fine binning.
raster_settings = PointsRasterizationSettings(
    image_size=image_size, 
    radius = 0.003,
    points_per_pixel = 10,
    bin_size=100
)

# Create a points renderer that composites the rasterized points with an alpha
# compositor (nearer points are weighted more heavily).
renderer = PointsRenderer(
    rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
    compositor=AlphaCompositor()
)

images = renderer(point_cloud)
plt.figure(figsize=(10, 10))
plt.imshow(images[0, ..., :3].cpu().numpy())
plt.grid(False)
plt.axis("off");
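Because the render is square (1296 × 1296) while the ground truth is 968 × 1296, it is worth ruling out the intrinsics before blaming the extrinsics. The sketch below re-implements the NDC formulas from the code above in NumPy and maps the result back to pixels. It assumes the SfM perspective projection x_ndc = f_ndc · X / Z + p_ndc on PyTorch3D camera axes (+X left, +Y up) and that NDC +1 is the left/top edge of the square image; the function names and the intrinsic values are illustrative, not ScanNet's. Under these assumptions a point lands on the same pixel (up to sub-pixel center conventions) as with the OpenCV pinhole model, so the ground-truth view would correspond to the top 968 rows of the render. This check only covers the square render used here.

```python
import numpy as np

S = 1296.0  # side of the square render, as in image_size above


def project_opencv(K, x_cv):
    """Pixel (u, v) of a point given in OpenCV camera coordinates."""
    u = K[0, 0] * x_cv[0] / x_cv[2] + K[0, 2]
    v = K[1, 1] * x_cv[1] / x_cv[2] + K[1, 2]
    return np.array([u, v])


def project_ndc_square(K, x_cv, S):
    """Same point via the issue's NDC formulas, mapped back to square-render pixels."""
    f_ndc = np.array([K[0, 0], K[1, 1]]) * 2.0 / S
    p_ndc = -(np.array([K[0, 2], K[1, 2]]) - S / 2.0) * 2.0 / S
    x_p3d = np.array([-x_cv[0], -x_cv[1], x_cv[2]])  # OpenCV -> PyTorch3D axes
    ndc = f_ndc * x_p3d[:2] / x_p3d[2] + p_ndc       # assumed SfM projection
    return (1.0 - ndc) * S / 2.0                     # NDC +1 is the left / top edge


K = np.array([  # illustrative values only
    [1170.0, 0.0, 648.0],
    [0.0, 1170.0, 484.0],
    [0.0, 0.0, 1.0],
])
x_cv = np.array([0.2, -0.1, 2.5])
print(project_opencv(K, x_cv), project_ndc_square(K, x_cv, S))
```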
@gkioxari added the "how to" label Sep 22, 2020
@gkioxari
Contributor

Hi @BostonLobster
I think this is an issue of coordinate systems. You need to figure out ScanNet's world coordinate system and the convention of the R, T it provides. I haven't worked with ScanNet a lot, but I think their R, T follow a different convention.
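One way to pin down the R, T convention, before involving the renderer at all, is to count how many scene points land in front of the camera for a candidate (R, T). The sketch below assumes PyTorch3D's row-vector convention (X_cam = X_world R + T, with +Z pointing into the scene); `fraction_in_front` is a hypothetical helper and the data is synthetic. A result near zero for the real ScanNet vertices would be consistent with an empty render like the one described in the issue.

```python
import numpy as np


def fraction_in_front(points_world, R, T, z_near=1e-3):
    """Hypothetical helper: share of world points with positive depth.

    Assumes PyTorch3D's row-vector convention X_cam = X_world @ R + T,
    where +Z points from the camera into the scene.
    """
    points_cam = points_world @ R + T
    return float(np.mean(points_cam[:, 2] > z_near))


# Usage on synthetic data: a "scene" about 3 units in front of an identity camera.
rng = np.random.default_rng(0)
scene = rng.uniform(-1.0, 1.0, size=(1000, 3)) + np.array([0.0, 0.0, 3.0])

R_good, T_good = np.eye(3), np.zeros(3)
# Same camera with a wrong convention: rotated by pi about x, so it looks away.
R_bad, T_bad = np.diag([1.0, -1.0, -1.0]), np.zeros(3)

print(fraction_in_front(scene, R_good, T_good))  # 1.0: every point is in front
print(fraction_in_front(scene, R_bad, T_bad))    # 0.0: nothing in front
```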

@gkioxari self-assigned this Sep 22, 2020
@BostonLobster
Author

@gkioxari Thanks for your reply.
However, as far as I know, ScanNet's R, T are in the OpenCV coordinate system, shown in the following figure:
[Figure: OpenCV camera coordinate system]

and PyTorch3D uses the coordinate system below:
[Figure: PyTorch3D camera coordinate system]

So rotating the X-Y plane by π around the z-axis in the OpenCV coordinate system should give the PyTorch3D coordinate system. Is anything wrong with this?

@BostonLobster
Author

My understanding above is correct: rotating the X-Y plane around the z-axis gives the right coordinate system. The wrong rendering result came from a mistake elsewhere.
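A likely candidate for that mistake, offered as an assumption rather than a confirmed diagnosis: ScanNet's `pose/*.txt` files are, to our understanding, camera-to-world matrices (the SensReader exporter writes each frame's `camera_to_world`), while PyTorch3D's R, T describe world-to-camera. The code in this issue copies `extrinsic[:3, -1]` directly into T. The sketch below inverts the pose first and then applies the x/y flip; `scannet_pose_to_pytorch3d` is a hypothetical helper and the pose is synthetic. Under these assumptions the resulting R equals the pose rotation with its first two columns negated, which matches what the author tried, but T has to be recomputed from the inverted pose.

```python
import numpy as np

# Rotation by pi about z: maps OpenCV camera axes to PyTorch3D camera axes.
FLIP_XY = np.diag([-1.0, -1.0, 1.0])


def scannet_pose_to_pytorch3d(pose_c2w):
    """Hypothetical helper: 4x4 camera-to-world pose -> PyTorch3D-style (R, T).

    Assumes x_world = R_c2w @ x_cam + t_c2w with OpenCV camera axes
    (+x right, +y down, +z forward), and returns (R, T) for PyTorch3D's
    row-vector convention X_cam = X_world @ R + T.
    """
    R_c2w = pose_c2w[:3, :3]
    t_c2w = pose_c2w[:3, 3]
    # Invert the rigid transform: world-to-camera, still in OpenCV axes.
    R_w2c = R_c2w.T
    t_w2c = -R_c2w.T @ t_c2w
    # Flip x/y into PyTorch3D axes and transpose for row vectors.
    R = R_w2c.T @ FLIP_XY  # same as R_c2w with its first two columns negated
    T = FLIP_XY @ t_w2c
    return R, T


# Usage: a camera centered at (1, 2, 3) whose optical axis (+z) is world -y.
pose = np.eye(4)
pose[:3, :3] = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 0.0, -1.0],
    [0.0, 1.0, 0.0],
])
pose[:3, 3] = [1.0, 2.0, 3.0]

R, T = scannet_pose_to_pytorch3d(pose)
center = np.array([1.0, 2.0, 3.0])
ahead = center + 5.0 * np.array([0.0, -1.0, 0.0])  # 5 units along the optical axis
right = ahead + np.array([1.0, 0.0, 0.0])          # 1 unit to the camera's right
print(center @ R + T)  # camera center maps to the origin
print(ahead @ R + T)   # on the optical axis at depth 5
print(right @ R + T)   # PyTorch3D +X points left, so x is negative here
```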

@ZX-Yin

ZX-Yin commented Sep 19, 2021

Quoting @BostonLobster: "My above understanding is correct, by rotating X-Y plane around z-axis we can get the right coordinate. The wrong rendering result comes from other mistake made elsewhere."

Hi,
Have you solved the problem? I've been stuck on this problem for a long time.

@Minisal

Minisal commented Sep 13, 2022

@BostonLobster @JasonYinn

Hello,
Did any of you solve this problem? I tried rotating the X-Y plane around the z-axis, but I still don't get the correct rendering result.
