In [None]:
import numpy as np
import torch
import torch.functional as F
import torch.nn as nn
from torchvision.datasets import CIFAR10, SVHN
from torchvision.models import resnet18
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10, SVHN

In [None]:
def _normalize_per_sample(t, eps=1e-12):
    """Scale each sample (slice along dim 0) of `t` to unit L2 norm.

    Tensors with fewer than 2 dims are treated as a single sample. `eps`
    guards against division by zero when a sample's gradient vanishes.
    """
    if t.dim() < 2:
        return t / (t.norm() + eps)
    norms = t.flatten(start_dim=1).norm(dim=1)
    return t / (norms.view(-1, *([1] * (t.dim() - 1))) + eps)


def _loss_and_input_grad(model, lossf, inputs, y):
    """Evaluate the loss at `inputs` and return its gradient w.r.t. `inputs`.

    `create_graph=True` keeps the returned gradient differentiable, so a
    penalty built from it can be back-propagated into the model parameters.
    """
    out = model(inputs)
    value = lossf(out, y) if y is not None else lossf(out)
    return torch.autograd.grad(outputs=value,
                               inputs=inputs,
                               # ones_like follows value's device/dtype (CPU or CUDA)
                               grad_outputs=torch.ones_like(value),
                               create_graph=True)[0]


def curvature_regularization_loss(model, lossf, x, h, y=None, attack_norm_p='inf'):
    r"""
    Computes curvature regularization term.

    The formula is L(x) = \| \nabla l(x + h z) - \nabla l(x) \|^2 (summed over
    the whole batch), where z is a unit-norm direction chosen per sample
    (dim 0 of `x` is treated as the batch dimension). If the attack is in
    \ell_inf, then z = sign \nabla l(x) / \| sign \nabla l(x) \|; otherwise
    z = \nabla l(x) / \| \nabla l(x) \|. No gradient flows through z, but
    both input gradients keep their graph, so the result can be
    back-propagated into `model`.

    Args:
        model, lossf (Module): model and corresponding loss function
        x, y (Tensor): data and optional label; `x` is not modified
        h (float): length of the step taken along the unit direction z
        attack_norm_p (str): if 'inf', \ell_inf z is used, otherwise
            simply normalized gradient.

    Returns:
        Tensor: scalar regularization value.
    """
    # A fresh leaf that requires grad, so autograd.grad can differentiate
    # the loss with respect to the input.
    original = x.detach().clone().requires_grad_(True)
    gradients_original = _loss_and_input_grad(model, lossf, original, y)

    # do not back-propagate through z
    direction = gradients_original.detach()
    if attack_norm_p == 'inf':
        direction = direction.sign()
    z = _normalize_per_sample(direction)

    # Separate leaf at the shifted point; its gradient is taken w.r.t. itself.
    interpolated = (original.detach() + h * z).requires_grad_(True)
    gradients_interpolated = _loss_and_input_grad(model, lossf, interpolated, y)

    return torch.sum((gradients_interpolated - gradients_original) ** 2)

In [None]:
# CIFAR-10 training split. `ToTensor()` turns the PIL images into float
# tensors; without a transform the default DataLoader collate cannot batch
# the samples and `batch.cuda()` in the training loop would fail.
dataset = CIFAR10('../data/cifar10', train=True, download=True, transform=ToTensor())

In [None]:
def update_rate(optimizer, lr):
    """Set the learning rate of every parameter group to `lr`, in place.

    Args:
        optimizer: object exposing `param_groups` (e.g. a torch optimizer);
            each group's 'lr' entry is overwritten.
        lr (float): the new learning rate shared by all groups.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
# NOTE(review): model/optimizer setup is still commented out. If revived,
# `torch.optim.Adam` needs the parameters to optimise, e.g.
# `torch.optim.Adam(model.parameters())`, and torchvision's `resnet18` head
# presumably has 1000 outputs by default while CIFAR-10 has 10 classes — confirm.
#model = resnet18(pretrained=True)
#opt = torch.optim.Adam()
# Task loss on raw logits, averaged over the batch (the default reduction).
loss = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Training configuration.
epochs = 20
epochs_burnin = 5   # epochs over which h ramps up linearly from 0 to h_max
h_max = 1.5         # final step length for the curvature regularizer
# Learning rate decays linearly from 1e-4 down to 1e-6. ndarray has no
# `.flip()` method, so reverse with np.flip.
lr_schedule = np.flip(np.linspace(1e-6, 1e-4, epochs))
# h ramps 0 -> h_max during burn-in, then stays at h_max. Both parts are
# plain lists: `list + ndarray` would add element-wise (broadcast) instead
# of concatenating.
h_schedule = np.linspace(0, h_max, epochs_burnin).tolist() + [h_max] * (epochs - epochs_burnin)
cr_weight = 4       # weight of the curvature term relative to the task loss

In [None]:
def train_model(model, optimizer, lossf, dataloader, lr_schedule, h_schedule, cr_weight, epochs=epochs,
                device='cuda', log_every=10):
    """Train `model` on the task loss plus a weighted curvature regularizer.

    In each epoch the learning rate and the regularizer step `h` come from the
    matching schedule entries. Every batch minimises
    ``lossf(model(batch), labels) + cr_weight * curvature_regularization_loss(...)``.

    Args:
        model (Module): network to train; batches are moved to `device`, so
            the model is expected to live there already.
        optimizer: optimizer over the model parameters; its learning rate is
            overwritten every epoch via `update_rate`.
        lossf (Module): task loss called as ``lossf(logits, labels)``.
        dataloader: iterable of ``(batch, labels)`` tensor pairs.
        lr_schedule, h_schedule: per-epoch learning rates and regularizer
            steps; each must provide at least `epochs` values.
        cr_weight (float): weight of the curvature term.
        epochs (int): number of epochs (defaults to the notebook's `epochs`).
        device: device batches are moved to (default 'cuda', as before).
        log_every (int): print the mean loss every `log_every` epochs.

    Returns:
        list[float]: mean full loss of each epoch.

    Raises:
        ValueError: if a schedule has fewer than `epochs` entries (``zip``
            would otherwise silently train for fewer epochs).
    """
    lr_schedule = list(lr_schedule)
    h_schedule = list(h_schedule)
    if len(lr_schedule) < epochs or len(h_schedule) < epochs:
        raise ValueError('lr_schedule and h_schedule need at least %d entries, got %d and %d'
                         % (epochs, len(lr_schedule), len(h_schedule)))

    model.train()  # set once; nothing in the loop switches the mode back
    epoch_losses = []
    for epoch in range(epochs):
        lr, h = lr_schedule[epoch], h_schedule[epoch]
        update_rate(optimizer, lr)
        losses = []
        for batch, labels in dataloader:
            # move each batch once and reuse it for both loss terms
            batch, labels = batch.to(device), labels.to(device)
            optimizer.zero_grad()
            task_loss = lossf(model(batch), labels)
            full_loss = task_loss + cr_weight * curvature_regularization_loss(model, lossf, batch, h, labels)
            losses.append(full_loss.item())
            full_loss.backward()
            optimizer.step()
        epoch_losses.append(float(np.mean(losses)))
        if epoch % log_every == 0:
            print('[%2d]\tloss\t%.7f' % (epoch+1, np.mean(losses)))
    return epoch_losses