
Jagged edges + "glitchy" overlap when rendering multiple meshes in a single scene? #657

Closed
collinskatie opened this issue Apr 27, 2021 · 12 comments
Assignees
Labels
how to (How to use PyTorch3D in my project) · question (Further information is requested) · Stale

Comments

@collinskatie

❓ Questions on how to use PyTorch3D

Hi, thank you for making such an amazing tool!

I'm currently trying to render a scene of block towers. I am following the method proposed in Issue #15 to concatenate several meshes (in this case, cubes) into a single mesh to render. However, sometimes the edges of these cubes look very jagged, particularly in low-resolution settings. And when adding blur/blending, there are still challenges with overlap.

I'm curious whether anyone on the PyTorch3D dev team (or any other user!) has recommendations on how to better render multiple meshes to a single image, where each mesh may have a different color. Does the renderer have issues resolving face overlap, or cases where one block sits directly on top of another (and if so, how can this be resolved on the fly)?

Thank you!

Below are some examples of cube "glitches" (the top is shown at 32x32 res and the bottom at 256x256). I've also included a cube mesh in isolation, showing the "jagged" edges.

[Screenshots: cube glitches at 32×32 (top) and 256×256 (bottom), plus an isolated cube mesh with jagged edges]

Here is a sampling of the code used to combine the meshes, in case anything is particularly amiss there.

    # Modified from PyTorch3D tutorial
    # https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/render_textured_meshes.ipynb
    device = sizes.device
    R, T = look_at_view_transform(1.0, 90, 180,
                                  up=((0.0, -1.0, 0.0),),
                                  at=((0.0, 1, -0.2),))  # view top to see stacking
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T,
                                    fov=45.0)
    # Settings for rasterizer (optional blur)
    # https://github.com/facebookresearch/pytorch3d/blob/1c45ec9770ee3010477272e4cd5387f9ccb8cb51/pytorch3d/renderer/mesh/shader.py
    blend_params = BlendParams(sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
    raster_settings = RasterizationSettings(
        image_size=im_size,  # crisper objects + texture w/ higher resolution
        blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
        faces_per_pixel=1,  # increase at cost of GPU memory,
        bin_size=0
    )
    lights = PointLights(device=device, location=[[0.0, 3.0, 0.0]])  # top light
    # Compose renderer and shader
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(
            device=device,
            cameras=cameras,
            lights=lights,
            blend_params=blend_params
        )
    )
    # create one mesh per elmt in batch
    meshes = []
    for batch_idx, n_cubes in enumerate(num_cubes):
        # Combine obj meshes into single mesh from rendering
        # https://github.com/facebookresearch/pytorch3d/issues/15
        vertices = []
        faces = []
        textures = []
        vert_offset = 0 # offset by vertices from prior meshes
        for i, (position, size, color) in enumerate(zip(positions[batch_idx, :n_cubes, :],
                                                        sizes[batch_idx, :n_cubes],
                                                        colors[batch_idx, :n_cubes, :])):
            cube_vertices, cube_faces = get_cube_mesh(position, size)
            # For now, apply same color to each mesh vertex (v \in V)
            texture = torch.ones_like(cube_vertices) * color  # [V, 3]
            # Offset faces (account for diff indexing, b/c treating as one mesh)
            cube_faces = cube_faces + vert_offset
            vert_offset += cube_vertices.shape[0]  # accumulate over all prior meshes, not just the last
            vertices.append(cube_vertices)
            faces.append(cube_faces)
            textures.append(texture)
        # Concatenate data into single mesh
        vertices = torch.cat(vertices)
        faces = torch.cat(faces)
        textures = torch.cat(textures)[None]  # (1, num_verts, 3)
        textures = TexturesVertex(verts_features=textures)
        # each elmt of verts array is diff mesh in batch
        mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)
        meshes.append(mesh)
    batched_mesh = join_meshes_as_batch(meshes)
    # Render image
    img = renderer(batched_mesh)   # (B, H, W, 4)
    # Remove alpha channel and return (B, im_size, im_size, 3)
    img = img[:, ..., :3]
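The subtle part of concatenating meshes like this is the face-index bookkeeping: each mesh's face indices must be shifted by the cumulative number of vertices emitted so far. A minimal plain-Python sketch (no torch; `combine_meshes` and the toy triangle are hypothetical, just to illustrate the accumulation):

```python
# Hypothetical minimal illustration of the vertex-offset bookkeeping
# needed when concatenating meshes: face indices of each later mesh must
# be shifted by the *cumulative* vertex count of all prior meshes.
def combine_meshes(meshes):
    """meshes: list of (vertices, faces), where vertices is a list of
    3-tuples and faces a list of 3-tuples of vertex indices."""
    all_verts, all_faces = [], []
    vert_offset = 0
    for verts, faces in meshes:
        all_faces.extend(tuple(i + vert_offset for i in f) for f in faces)
        all_verts.extend(verts)
        vert_offset += len(verts)  # accumulate, don't overwrite
    return all_verts, all_faces

# Three copies of one triangle, each using the same local indexing
tri = ([(0, 0, 0), (1, 0, 0), (0, 1, 0)], [(0, 1, 2)])
verts, faces = combine_meshes([tri, tri, tri])
print(faces)  # [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
```

If the offset were overwritten instead of accumulated, the third triangle would reuse the second triangle's vertex slots, which matches the glitchy geometry described in this thread.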

Thanks for any help/advice!

@bottler bottler added the question Further information is requested label Apr 27, 2021
@nikhilaravi
Contributor

@collinskatie You are currently only rendering with faces_per_pixel=1. Are you trying to generate an image to compute a loss on, in order to get gradients? The jagged edges in the low-resolution image look normal for that resolution.

Can you provide the cube mesh you are using so we can reproduce the behavior at the high resolution setting? If you can provide the full code snippet to load and render one mesh and the stacked meshes, we can help you more easily.

Regarding the overlap, both meshes are now part of the same mesh, so all the faces are handled equally by the rasterizer - as you have set faces_per_pixel=1 the rasterizer should only return the closest face in the z direction from the given camera.

Also in the rasterization settings make sure you set bin_size=None, as with bin_size=0 you're using the naive rasterizer which is significantly slower than the coarse-to-fine rasterizer.
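For concreteness, a hedged sketch of rasterization settings along those lines (illustrative values, not the poster's exact code; assumes the same PyTorch3D imports as the snippet above):

```python
# Sketch of the suggested settings: faces_per_pixel > 1 so overlapping
# faces can be soft-blended, and bin_size=None so the coarse-to-fine
# rasterizer chooses its own bin size.
import numpy as np
from pytorch3d.renderer import BlendParams, RasterizationSettings

blend_params = BlendParams(sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=10,  # illustrative; trades GPU memory for blending quality
    bin_size=None,       # let the coarse-to-fine rasterizer pick a bin size
)
```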

@nikhilaravi nikhilaravi self-assigned this Apr 28, 2021
@nikhilaravi nikhilaravi added the how to How to use PyTorch3D in my project label Apr 28, 2021
@collinskatie
Author

Hi @nikhilaravi - thank you for your reply! The cube mesh function is included below.

I realize I actually had some issues with the coordinate system! Regarding faces_per_pixel, I'm a bit confused about the exact impact of that parameter. Do we always want faces_per_pixel > 1 if we have multiple stacked meshes? Is there a particular setting (say, faces_per_pixel = 5) that you've found to typically work well?

And I'll update the bin_size - thanks!!

def get_cube_mesh(position, size):
    """Computes a cube mesh.
    Adapted from https://github.com/mikedh/trimesh/blob/master/trimesh/creation.py#L566

    Args
        position [3]: cube center (x, y, z)
        size []: scalar edge length

    Returns
        vertices [num_vertices, 3]
        faces [num_faces, 3]
    """
    device = position.device

    # 8 corner vertices of a unit cube, centered on the origin
    centered_vertices = (
        torch.tensor(
            [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1],
            dtype=torch.float,
            device=device,
        ).view(-1, 3)
        - 0.5
    )
    translation = position.clone()
    translation[-1] += size / 2
    vertices = centered_vertices * size + translation[None]

    # Hard-coded face indices: 12 triangles, 2 per cube side
    faces = torch.tensor(
        [
            1, 3, 0,  4, 1, 0,  0, 3, 2,  2, 4, 0,
            1, 7, 3,  5, 1, 4,  5, 7, 1,  3, 7, 2,
            6, 4, 2,  2, 7, 6,  6, 5, 4,  7, 5, 6,
        ],
        dtype=torch.int32,
        device=device,
    ).view(-1, 3)

    return vertices, faces
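As a plain-Python sanity check (no torch needed) that the hard-coded topology above is well-formed:

```python
# Sanity check of the hard-coded cube topology: 8 corner vertices,
# 12 triangles, every corner referenced by at least one face.
FACES = [
    (1, 3, 0), (4, 1, 0), (0, 3, 2), (2, 4, 0),
    (1, 7, 3), (5, 1, 4), (5, 7, 1), (3, 7, 2),
    (6, 4, 2), (2, 7, 6), (6, 5, 4), (7, 5, 6),
]
assert len(FACES) == 12  # 6 square sides, 2 triangles each
assert {i for f in FACES for i in f} == set(range(8))  # all 8 corners used
print("cube topology ok")
```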

@collinskatie
Author

Hi @nikhilaravi, even when adjusting faces_per_pixel and bin_size, we are having issues where a cube either floats or overlaps with another cube mesh, even though the locations we specify would prohibit overlap and place cubes directly on top of each other. This seems to particularly be an issue with small cube meshes. Do you have any sense of why small cube meshes would be problematic? Does this seem like a face-resolution issue? We are using the same method detailed above, concatenating individual cube meshes and adjusting the vertices to render several meshes to a single image. Because we encounter both floating and overlapping cubes, we can't uniformly adjust the y-coordinate of blocks in our towers to hackily ameliorate these issues. Does PyTorch3D have any way to automatically check whether meshes will overlap?

Below are some examples of floating + overlapping cubes:

[Screenshots: examples of floating and overlapping cubes]

Thank you for any help as to why this may be happening!

@nikhilaravi
Contributor

@collinskatie apologies for the delay! Are you still having problems with this? I have time to look into it now!

@collinskatie
Author

collinskatie commented Jun 19, 2021

Hi @nikhilaravi! Thank you for your response (PyTorch3D is awesome!)

We have mostly fixed the issue above. You were right that increasing faces_per_pixel definitely helps, particularly a value of around 10. I also found a bug in our mesh function -- the following code works better, in case anyone encounters a similar issue:

centered_vertices = torch.tensor(
    [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1],
    dtype=torch.float,
    device=device,
).view(-1, 3)
translation = position.clone()
translation[-1] += size / 2
vertices = centered_vertices * size + translation[None] - 0.5

However, I did have one other question related to overlap: if two cube meshes sit directly on top of each other, they end up sharing vertex positions, which seems to lead to rendered overlap. Is there a good way to render/combine meshes that sit directly on top of each other, or are otherwise in contact, without overlap issues? I have been able to hack around it by slightly moving one of the blocks up (by something like epsilon = 0.005), but I was wondering whether there is a better method, or already a way to render meshes that have contact along some vertices. Most of the time the rendering works, but sometimes surfaces are resolved oddly, leading to overlap. This is a very minor issue, and the hack works well! Just wanted to check though.
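The epsilon workaround described here can be sketched as follows (plain Python; `separate_stacked` and the axis convention are hypothetical, purely to illustrate the nudge):

```python
# Sketch of the epsilon-separation workaround: nudge each stacked block
# up slightly so coincident faces do not z-fight during rasterization.
EPS = 0.005  # small relative to block size

def separate_stacked(positions, axis=1):
    """positions: list of [x, y, z] block centers, sorted bottom-to-top.
    Returns new positions with the i-th block lifted by i * EPS."""
    return [
        [p + (i * EPS if j == axis else 0.0) for j, p in enumerate(pos)]
        for i, pos in enumerate(positions)
    ]

stacked = [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 2.0, 0.0]]
print(separate_stacked(stacked))
```

The lift grows with stack height so that every pair of touching surfaces, not just the bottom one, ends up separated.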

[Screenshot: stacked cubes showing overlap artifacts]

(As an example, the vertices of these blocks all lie directly on top of each other, meaning some blocks share vertex positions -- which I'm guessing a user needs to avoid when working in PyTorch3D?)

Thanks!

@nikhilaravi
Contributor

@collinskatie glad that increasing faces per pixel helped to partially resolve this.

Regarding the overlap issue, are you referring to this effect where the green cube appears in front of the purple one, even though you're saying the cubes are actually flush on top of each other?

[Screenshot: green cube appearing in front of the purple cube]

> some blocks share vertices

To clarify: the vertices for each mesh in the Meshes data structure are different, but they are at the same xyz position?
The rendering pipeline only checks whether a face overlaps a pixel; each overlapping face is then retained or discarded depending on its z distance and the number of faces per pixel being stored.
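That per-pixel rule can be illustrated with a toy sketch (plain Python; the names are hypothetical, and the real rasterizer works on NDC z values):

```python
# Toy illustration of the per-pixel rule described above: of all faces
# covering a pixel, keep the K nearest in z, where K = faces_per_pixel.
def faces_for_pixel(hits, faces_per_pixel):
    """hits: list of (face_id, z_distance) pairs covering one pixel."""
    return sorted(hits, key=lambda h: h[1])[:faces_per_pixel]

# Two coincident surfaces land at (numerically) almost the same z, so
# which one "wins" with faces_per_pixel=1 is fragile -- hence the
# z-fighting look when stacked cubes share a surface.
hits = [("purple_top", 0.80001), ("green_bottom", 0.80000), ("purple_side", 1.2)]
print(faces_for_pixel(hits, 1))  # [('green_bottom', 0.8)]
```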

@collinskatie
Author

Hi @nikhilaravi, yes, that's the issue. The y position is the same for some of the vertices, though, as seen, the two blocks have different textures. I partially resolved this by adding a tiny amount to the y of the upper block (offsetting y by 0.05), which helped!

Regarding your point about the rendering pipeline -- does that mean it's most important to ensure that the z coordinates are the same? I'm a little confused about the "retained or discarded depending on the z distance" aspect?

Thank you again for your help!

@github-actions

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Jul 29, 2021
@github-actions

github-actions bot commented Aug 4, 2021

This issue was closed because it has been stalled for 5 days with no activity.

@github-actions github-actions bot closed this as completed Aug 4, 2021
@culurciello

Hi, I was using your great examples @collinskatie and wanted to thank you for this. I have an issue where for some reason I cannot see more than 2 cubes... do you see anything wrong in the code below?

# Combine obj meshes into single mesh from rendering
# https://github.com/facebookresearch/pytorch3d/issues/15
meshes = []
vertices = []
faces = []
textures = []
vert_offset = 0 # offset by vertices from prior meshes

position = torch.FloatTensor([[0,0,0], [0,0,1], [0,0,-1],
                              [1,0,0], [0,1,1], [1,1,-1]]).to(device)
size = torch.FloatTensor([[0.1,0.2,0.2], [0.1,0.2,0.3],[0.3,0.3,0.4],
                          [0.3,0.5,0.2], [0.2,0.2,0.4],[0.2,0.4,0.2]]).to(device)

color = torch.FloatTensor([[1,0,0], [0,1,0], [0,0,1],
                           [1,0,0], [0,1,0], [0,0,1]]).to(device)

num_cubes = position.shape[0]
print('Number of cubes:', num_cubes)

for n_cubes in range(num_cubes):
    # print(position[n_cubes], size[n_cubes])
    cube_vertices, cube_faces = get_cube_mesh(position[n_cubes], size[n_cubes], device)
    # For now, apply same color to each mesh vertex (v \in V)
    texture = torch.ones_like(cube_vertices) * color[n_cubes]# [V, 3]
    # Offset faces (account for diff indexing, b/c treating as one mesh)
    cube_faces = cube_faces + vert_offset
    vert_offset = cube_vertices.shape[0]
    vertices.append(cube_vertices)
    faces.append(cube_faces)
    textures.append(texture)

# Concatenate data into single mesh
vertices = torch.cat(vertices)
faces = torch.cat(faces)
textures = torch.cat(textures)[None]  # (1, num_verts, 3)
textures = TexturesVertex(verts_features=textures)
# each elmt of verts array is diff mesh in batch
mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)
meshes.append(mesh)
# batched_mesh = join_meshes_as_batch(meshes)
mesh = join_meshes_as_scene(meshes)

# Render image and save:
# img = renderer(batched_mesh)   # (B, H, W, 4)
img = renderer(mesh) 
# Remove alpha channel, make tensor and then PIL image:
img = img[:, ..., :3].detach().squeeze().cpu()
img = img.permute(2,0,1)
im = transforms.ToPILImage()(img).convert("RGB")
im.save('test.png')

@culurciello

Well, never mind... found it: vert_offset was not being accumulated across cubes. Since all my cubes have the same vertex count, this fixed it:

    vert_offset = cube_vertices.shape[0]*n_cubes

@collinskatie
Author

Glad you figured things out!
