
The axis of the norm of the gradient #56

Open
CharlesNord opened this issue Aug 29, 2020 · 1 comment

Comments

@CharlesNord

gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

Hi, I came across your code by accident when I Googled WGAN-GP. I think there is something wrong with your implementation here. In WGAN-GP, the norm of the interpolated gradient should be calculated across all axes except the batch axis, since the penalty is applied to the gradient of each sample. But in your code, you only calculate the norm along the second dimension, which is not reasonable. I think you are missing the following reshape step:

gradients.view(gradients.shape[0], -1)
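
For reference, a minimal sketch of the corrected penalty with this reshape applied (assuming, as in the repo's code, that `gradients` comes from `torch.autograd.grad` with shape `[batch_size, C, H, W]` and that `LAMBDA` is the penalty coefficient; the random tensor and the value of `LAMBDA` below are placeholders for illustration):

```python
import torch

LAMBDA = 10  # penalty coefficient; value assumed for illustration
gradients = torch.randn(4, 3, 32, 32)  # stand-in for the autograd.grad output

# Flatten all non-batch axes so the L2 norm is taken per sample.
gradients = gradients.view(gradients.shape[0], -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
```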

@1292224662

I also noticed this mistake; instead of using Tensor.view(), I used the following:

gradient_penalty = ((gradients.norm(2, dim=(1, 2, 3)) - 1) ** 2).mean() * LAMBDA

Since the gradients here have shape [batch_size, 3, H, W], I think this also gives the L2 norm for each sample in the batch, right?
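
The two formulations do agree: for a 4-D tensor, taking the norm over dims (1, 2, 3) reduces over the same elements as flattening the non-batch axes and taking the norm along dim 1. A quick sanity check (an illustrative sketch, not code from the repo):

```python
import torch

# Both approaches compute a per-sample L2 norm over all non-batch elements.
g = torch.randn(4, 3, 16, 16)

norm_flat = g.view(g.shape[0], -1).norm(2, dim=1)  # reshape approach
norm_dims = g.norm(2, dim=(1, 2, 3))               # multi-dim approach

print(torch.allclose(norm_flat, norm_dims))  # prints: True
```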
