In [1]:
import numpy as np
import torch
import torch.functional as F
import torch.nn as nn

In [None]:
def curvature_regularization_loss(model, x, h, attack_norm_p='inf'):
    r"""
    Computes curvature regularization term.

    The formula is L(x) = \| \nabla l(x + h z) - \nabla l(x) \|^2,
    where z depends on attack norm. If attack is in \ell_inf, then
    z = sign \nabla l(x) / \| sign \nabla l(x) \|. Otherwise
    z = \nabla l(x) / \| \nabla l(x) \| is used.

    Here l(x) is the sum of the model outputs for x (gradients are taken
    with ``grad_outputs`` of ones), and all norms are per-example \ell_2
    norms over the flattened non-batch dimensions.

    Args:
        model (callable): maps a batch of inputs to a tensor of outputs
        x (Tensor): batch of data points; first dimension is the batch
        h (float): interpolation parameter (step length along z)
        attack_norm_p (str): if 'inf', \ell_inf z is used; any other
            value selects the normalized-gradient direction

    Returns:
        Tensor: scalar, the squared gradient difference averaged over the
        batch. It is built with ``create_graph=True`` so it can be
        backpropagated to the model parameters.
    """
    # Work on a copy so the caller's tensor is never flagged in place, and
    # make sure it requires grad so autograd.grad can differentiate w.r.t. it.
    original = x.clone()
    if not original.requires_grad:
        original.requires_grad_(True)
    batch_size = original.size(0)

    output_original = model(original)
    gradients_original = torch.autograd.grad(
        outputs=output_original,
        inputs=original,
        # ones_like keeps device and dtype consistent with the model output
        grad_outputs=torch.ones_like(output_original),
        create_graph=True,
        retain_graph=True,
    )[0]

    # The direction z is treated as a constant: no gradient flows through it.
    direction = gradients_original.detach().reshape(batch_size, -1)
    if attack_norm_p == 'inf':
        direction = torch.sign(direction)
    # Clamp the norm so an all-zero gradient yields z = 0 instead of NaN.
    direction_norm = direction.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
    z = (direction / direction_norm).reshape(original.shape)

    interpolated = original + h * z
    output_interpolated = model(interpolated)
    gradients_interpolated = torch.autograd.grad(
        outputs=output_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(output_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten per example, take the squared l2 norm of the gradient
    # difference, then average over the batch.
    difference = (gradients_interpolated - gradients_original).reshape(batch_size, -1)
    return difference.pow(2).sum(dim=1).mean()