
Strange Laplacian smoothing effect #432

Closed
AndresCasado opened this issue Nov 9, 2020 · 6 comments
Assignees
Labels
how to How to use PyTorch3D in my project

Comments

@AndresCasado

Hi.

Sorry if this is not directly related to PyTorch3D, or if it is caused by a mistake I've made.

I'm having trouble understanding how the Laplacian smoothing loss works. Reading the paper linked in the documentation, I would expect the smoothed mesh to stay more or less close to the original shape. I want to use this regularizer inside a bigger optimization problem, but I want to be sure I'm using it right and that I understand what it is doing.

I've tried optimizing some meshes using only the Laplacian loss, but even with regular meshes it somehow breaks the shape, usually by creating spikes on the mesh. See the examples below.

Am I using it wrong? Is this expected?

Thanks in advance

Examples

Initial mesh: [image]

Resulting mesh after 3k iterations: [image]

If I use an icosphere instead of a UV sphere the result is better, but some vertices seem to collapse:
Initial mesh: [image]

Resulting mesh after 300 iterations: [image]

On more irregular meshes it is clearly visible that spikes emerge from the vertices where adjacent faces have different normals:
Initial mesh: [image]

Result after 1k iterations: [image]

Instructions To Reproduce the Issue:

Code

```python
import pytorch3d.io as torch3d_io
import pytorch3d.loss as torch3d_loss
import pytorch3d.structures as torch3d_struct
import torch
import torch.nn as nn

device = torch.device('cuda')

orig_file = 'mesh.obj'
orig_mesh = torch3d_io.load_objs_as_meshes([orig_file], device=device)


class OptimizationModule(nn.Module):
    def __init__(self, starting_mesh: torch3d_struct.Meshes, device):
        super().__init__()

        self.mesh = starting_mesh.clone()
        self.device = device

        # Per-vertex displacements, initialized to zero.
        # nn.Parameter already sets requires_grad=True.
        self.mesh_offsets = nn.Parameter(
            torch.zeros_like(self.mesh.verts_packed(), device=device)
        )

    def forward(self, **kwargs):
        step_mesh = self.mesh.offset_verts(self.mesh_offsets)
        return torch3d_loss.mesh_laplacian_smoothing(step_mesh)


model = OptimizationModule(orig_mesh, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

torch.cuda.empty_cache()

for step in range(3000):
    loss = model()  # type: torch.Tensor

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Export the mesh to OBJ every 100 iterations
    if step % 100 == 0:
        print(step, '-->', loss.item())
        with torch.no_grad():
            step_mesh = model.mesh.offset_verts(model.mesh_offsets)
            torch3d_io.save_obj(
                f'mesh_{step}.obj',
                step_mesh.verts_padded().squeeze(),
                step_mesh.faces_padded().squeeze(),
            )
```
@gkioxari gkioxari self-assigned this Nov 9, 2020
@gkioxari
Contributor

gkioxari commented Nov 9, 2020

Laplacian smoothing (with the uniform mode), as you are using it, does the following:

```python
loss = L.mm(verts)
```

where L is the Laplacian matrix (or a version of it; there are many versions in the literature), defined as follows:

```
Computes the laplacian in packed form.
The definition of the laplacian is
L[i, j] =       -1 , if i == j
L[i, j] = 1/deg(i) , if (i, j) is an edge
L[i, j] =        0 , otherwise
where deg(i) is the degree of the i-th vertex in the graph
Returns:
    Sparse FloatTensor of shape (V, V) where V = sum(V_n)
```

In terms of math, the loss tries to bring each vertex close to the average of its neighboring vertices (weighted by the degree of the vertex). Other forms of the Laplacian take the edge lengths into account (e.g. Taubin smoothing). So the important question to answer is whether what our mesh Laplacian smoothing loss does mathematically is what you actually want to do.
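To make the math above concrete, here is a minimal NumPy sketch (not PyTorch3D's actual implementation) of the uniform Laplacian on a toy 4-vertex graph; the graph and coordinates are made up for illustration. Each row of `L @ verts` is the average of a vertex's neighbours minus the vertex itself, which is exactly the residual the loss drives to zero:

```python
import numpy as np

# Toy graph: vertex 0 connected to 1, 2, 3; plus edges 1-2 and 2-3.
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3)]
V = 4

deg = np.zeros(V)
for i, j in edges:
    deg[i] += 1
    deg[j] += 1

# Uniform Laplacian: L[i, i] = -1, L[i, j] = 1/deg(i) for each edge (i, j).
L = -np.eye(V)
for i, j in edges:
    L[i, j] = 1.0 / deg[i]
    L[j, i] = 1.0 / deg[j]

verts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])

# Row i = (average of i's neighbours) - verts[i].
residual = L @ verts
print(residual[0])  # vertex 0 is pulled toward the centroid of 1, 2, 3
```

Note that the residual is zero when every vertex sits at the average of its neighbours, i.e. a maximally "flat" configuration, which is not necessarily close to the original shape.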

@gkioxari gkioxari added the how to How to use PyTorch3D in my project label Nov 9, 2020
@AndresCasado
Author

AndresCasado commented Nov 10, 2020

In terms of math, the loss tries to bring each vertex close to the weighted average of the neighboring vertices (weighted by the degree of the vertex).

I thought this was already the case; that's why I don't understand why it creates those spikes. Shouldn't those vertices and their neighbours be smoothed?

What I intend to use the loss for is to prevent the vertex offsets from overfitting a given target, by smoothing the resulting mesh. I'm also considering using the mesh edge loss, but the Laplacian looked like a good option, especially since it is used in the mesh fitting tutorial.

@gkioxari
Contributor

It's all in the math, and the spikes are explained by the math. You could try playing with the weight on the loss; perhaps you need to set it to a lower value so that it doesn't try to do something aggressive. Or you could add further regularizers such as edge length and normal consistency. All are in PyTorch3D.
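The suggestion above boils down to a weighted sum of regularizer terms. Here is a minimal sketch of that weighting logic; the weights and the `fake_*` callables are hypothetical stand-ins for illustration only — in real code you would pass PyTorch3D's `pytorch3d.loss.mesh_laplacian_smoothing`, `mesh_edge_loss`, and `mesh_normal_consistency`, and tune the weights for your mesh:

```python
def combined_loss(mesh, loss_fns, weights):
    """Weighted sum of regularizer terms applied to the same mesh."""
    return sum(w * fn(mesh) for fn, w in zip(loss_fns, weights))

# Hypothetical stand-ins so the weighting is visible without PyTorch3D.
fake_laplacian = lambda mesh: 2.0  # e.g. mesh_laplacian_smoothing(mesh)
fake_edge = lambda mesh: 0.5       # e.g. mesh_edge_loss(mesh)

# Down-weight the Laplacian term relative to the edge term.
total = combined_loss(None, [fake_laplacian, fake_edge], [0.1, 1.0])
print(total)  # 0.7
```

With a lower Laplacian weight, the smoothing term can no longer dominate the gradient and flatten the mesh toward a degenerate configuration.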

@seva100

seva100 commented Oct 20, 2022

I can confirm this behavior. Perhaps it really is a consequence of the math, but I'm not sure in what way; I would appreciate it if someone could explain it.

@seva100

seva100 commented Oct 25, 2022

So far, I have fixed the spikes by replacing the sum reduction in `mesh_laplacian_smoothing` with the L2 norm:

```python
return torch.linalg.norm(loss) / N
```

As far as I understand, the sum encouraged the Laplacian to be small for most of the vertices while tolerating large values for a few of them. With the L2 norm, the Laplacian has to be more uniformly small across all vertices (i.e., "outliers" are penalized harder). Perhaps it would be more reasonable to support this as a default behavior...
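The gradient structure explains this. A NumPy sketch with a made-up per-vertex Laplacian magnitude vector (99 smooth vertices, one spike) shows that under a sum/mean reduction the gradient on every vertex is identical, whereas under the L2 norm the gradient on a vertex is proportional to its own Laplacian magnitude, so the outlier is pulled down much harder:

```python
import numpy as np

# Hypothetical per-vertex Laplacian magnitudes: 99 smooth, 1 spike.
lap = np.array([0.01] * 99 + [1.0])
N = lap.size

mean_reduction = lap.sum() / N          # current behaviour
l2_reduction = np.linalg.norm(lap) / N  # seva100's proposal

# d(mean)/d(lap_i) = 1/N: the spike gets no extra pressure.
grad_mean = np.full(N, 1.0 / N)

# d(||lap||/N)/d(lap_i) = lap_i / (N * ||lap||): proportional to lap_i.
grad_l2 = lap / (N * np.linalg.norm(lap))

print(grad_l2[-1] / grad_l2[0])  # 100.0: the spike is penalized 100x harder
```

So with the mean, the optimizer has no incentive to prefer shrinking the spike over shrinking already-small entries, which is consistent with a few vertices being allowed to blow up.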

@rohitrango

rohitrango commented May 3, 2023

Did you also try using a different optimizer or smaller learning rates?
Also, the code for the Laplacian is taking an L1 norm here, which I think is undesirable:

https://github.com/facebookresearch/pytorch3d/blob/995b60e3b99faa1ee1bcdbe244426d54d98a7242/pytorch3d/loss/mesh_laplacian_smoothing.py#LL131-LL132

Try converting it into an L2 norm, or a hinge loss.
