Strange Laplacian smoothing effect #432
Comments
Laplacian smoothing (with the
Where L is the Laplacian matrix (or a version of it; there are many versions in the literature), as follows: pytorch3d/pytorch3d/structures/meshes.py Lines 1060 to 1071 in 83fef0a
In terms of math, the loss tries to bring each vertex close to the weighted average of its neighboring vertices (weighted by the degree of the vertex). Other forms of the Laplacian take into account the length of the edges (e.g. Taubin smoothing). So the important question is whether what our mesh Laplacian smoothing loss does, in terms of math, is what you actually want to do.
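To make the description above concrete, here is a minimal sketch of a uniform Laplacian loss in plain Python. It is not the PyTorch3D implementation: the function name `uniform_laplacian_loss`, the edge-list input, and the averaging over vertices are assumptions made for illustration. It measures, for each vertex, the distance to the unweighted average of its neighbors.

```python
import math

def uniform_laplacian_loss(verts, edges):
    """Mean over vertices of ||average(neighbors) - vertex|| (illustrative only)."""
    neighbors = {i: set() for i in range(len(verts))}
    for a, b in edges:
        neighbors[a].add(b)
        neighbors[b].add(a)

    total = 0.0
    for i, v in enumerate(verts):
        nbrs = neighbors[i]
        if not nbrs:
            continue  # isolated vertex contributes nothing
        # Unweighted centroid of the neighboring vertices.
        centroid = [sum(verts[j][k] for j in nbrs) / len(nbrs) for k in range(3)]
        total += math.dist(centroid, v)
    return total / len(verts)

# A perfectly regular closed ring of four vertices (hypothetical toy data).
square = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (-1.0, 0.0, 0.0), (0.0, -1.0, 0.0)]
ring = [(0, 1), (1, 2), (2, 3), (3, 0)]
shrunk = [(0.5 * x, 0.5 * y, 0.5 * z) for x, y, z in square]

print(uniform_laplacian_loss(square, ring))  # 1.0
print(uniform_laplacian_loss(shrunk, ring))  # 0.5
```

Even on this regular shape the loss is nonzero, and it drops when the shape shrinks. So a loss of this form, minimized on its own, does not try to preserve the original shape; it only pulls each vertex toward its neighbors.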
I thought this was the case already; that's why I don't understand why it creates those spikes. Shouldn't those vertices and their neighbours be smoothed? What I intend to use the loss for is to keep the vertex offsets from overfitting a given target, by smoothing the resulting mesh. I'm also considering using the mesh edge loss, but the Laplacian looked like a good option, especially since it is used in the mesh fitting tutorial.
It's all in the math, and the spikes are explained by the math. You could try to play with the weight on the loss. Perhaps you need to set it to a lower value so that it doesn't try to do something aggressive. Or you need to play with additional regularizers like edge length and normal consistency. All of these are in PyTorch3D.
Can confirm this behavior. Perhaps it's really a consequence of the math, but I'm not sure in which way. I would appreciate it if someone could explain it.
So far, I fixed the spikes by replacing the sum with the L2 norm in:
to: `return torch.linalg.norm(loss) / N` As far as I understand,
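One possible reading of why this change helps (not confirmed anywhere in this thread) is that a plain sum and an L2 norm rank the same per-vertex values differently. The sketch below uses hypothetical per-vertex Laplacian magnitudes and plain Python instead of torch; `sum_reduction` and `l2_reduction` are illustrative names, not PyTorch3D functions.

```python
import math

def sum_reduction(magnitudes):
    """Plain sum of the per-vertex magnitudes."""
    return sum(magnitudes)

def l2_reduction(magnitudes):
    """L2 norm of the vector of per-vertex magnitudes."""
    return math.sqrt(sum(m * m for m in magnitudes))

# Same total deviation, distributed in two ways (hypothetical numbers).
spread = [0.25, 0.25, 0.25, 0.25]  # every vertex deviates a little
spike = [1.0, 0.0, 0.0, 0.0]       # one vertex deviates a lot

print(sum_reduction(spread), sum_reduction(spike))  # 1.0 1.0
print(l2_reduction(spread), l2_reduction(spike))    # 0.5 1.0
```

The sum is indifferent between the two configurations, while the L2 norm assigns a higher cost to the one where the deviation is concentrated in a single vertex.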
Did you also try using a different optimizer or smaller learning rates? Try converting it into an L2 norm, or a hinge loss.
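The comment does not say what form the hinge loss should take. One guess is to penalize only the per-vertex Laplacian magnitudes that exceed a tolerance, as sketched below. The `margin` parameter, the function name `hinge_penalty`, and the sample values are hypothetical.

```python
def hinge_penalty(magnitudes, margin):
    """Mean of max(0, m - margin): deviations below the margin cost nothing."""
    return sum(max(0.0, m - margin) for m in magnitudes) / len(magnitudes)

# Two nearly smooth vertices and one spike (hypothetical values).
magnitudes = [0.05, 0.05, 1.0]
print(hinge_penalty(magnitudes, margin=0.1))
```

With this form, vertices that are already close to the average of their neighbors receive no gradient, so the regularizer only acts on the outliers.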
Hi.
Sorry if this is not directly related to PyTorch3D or if this is caused by a mistake I've made.
I'm having trouble understanding how the Laplacian smoothing loss works. Reading the paper linked in the documentation, I would expect the smoothed mesh to keep a shape more or less close to the original. I want to use this regularizer inside a bigger optimization problem, but I want to be sure that I'm using it right and that I know what I am doing.
I've tried optimizing some meshes using only the Laplacian loss, but even with regular meshes it somehow breaks the shape, usually by creating spikes on the mesh. See the examples below.
Am I using it wrong? Is this expected?
Thanks in advance
Examples
Initial mesh:
![Image](https://camo.githubusercontent.com/e3ca0c51b71edb18edde5ff860c4b92e73c4788a82c9417c5d2f3321184495a1/68747470733a2f2f692e696d6775722e636f6d2f42706a383658532e706e67)
Resulting mesh after 3k iterations:
![Image](https://camo.githubusercontent.com/cc9404afd9209a3108f86fc232a46b4a0fe73a30945181f06776700085a4ac14/68747470733a2f2f692e696d6775722e636f6d2f334e39724233412e706e67)
If I use an icosphere instead of a UV sphere it has a better result, but there are some vertices that seem to collapse:
Initial mesh:
![Image](https://camo.githubusercontent.com/b4f2907143d20e90b84e92705f4310bba3156a0e9a0606ad0833fdf907a935c7/68747470733a2f2f692e696d6775722e636f6d2f337456737956542e706e67)
Resulting mesh after 300 iterations:
![Image](https://camo.githubusercontent.com/f4c0051954276e6e8fdc16ef4654645285f1c5ae154cccbb6432893930921e8d/68747470733a2f2f692e696d6775722e636f6d2f386f36683168452e706e67)
On more irregular meshes it is clearly visible that there are spikes emerging from the vertices where there is a normal difference between faces:
Initial mesh:
![Image](https://camo.githubusercontent.com/2d73ae94b8decd0384cb4b910950dd0ec1a16be629f7581a313f88043036c011/68747470733a2f2f692e696d6775722e636f6d2f724a77736c4c4e2e706e67)
Result after 1k iterations:
![Image](https://camo.githubusercontent.com/997ed7c1c8ffca8e92e464d0055acfd247d0888edead7b5a1a0ce4b24892861c/68747470733a2f2f692e696d6775722e636f6d2f723872707a556a2e706e67)
Instructions To Reproduce the Issue:
Code