In [1]:
import utils
# from utils import DeltaEnsemble
from mnist_models import TwoLayerNN, MLP, CNN, MLPBN, ConvNet, LeNet, LeNet5
import torch.nn.functional as F

import torch

m = CNN()
batchSize = 1000
loss_fn = lambda x, y: F.nll_loss(torch.log(x + 1e-9), y)
learningRate = 0.001

optimizer = torch.optim.Adam(m.parameters(), lr = learningRate)

num_epochs = 5

utils.train_model(m, loss_fn, batchSize, utils.trainset, utils.valset, optimizer, num_epochs)


Validation-epoch 0. Avg-Loss: 0.2636, Accuracy: 0.9501
Validation-epoch 1. Avg-Loss: 0.0969, Accuracy: 0.9796
Validation-epoch 2. Avg-Loss: 0.0653, Accuracy: 0.9841
Validation-epoch 3. Avg-Loss: 0.0501, Accuracy: 0.9862
Validation-epoch 4. Avg-Loss: 0.0473, Accuracy: 0.9863


In [2]:
utils.attack_model(m, loss_fn, 1000, utils.valset, 400);


Before. Avg-Loss: 0.0473, Accuracy: 0.9863
After. Avg-Loss: 1.0141, Accuracy: 0.6653


In [3]:
class DeltaEnsemble(torch.nn.Module):
    def __init__(self, m, eps = 0.1, n_neighb = 0):
        super(DeltaEnsemble, self).__init__()
        self.m = m
        self.eps = eps
        self.n_neighb = n_neighb

    def _get_neighb(self, x, n_neighb):
        all_inputs = [x]
        for k in range(n_neighb):
            grad = torch.sigmoid(torch.rand_like(x).uniform_(-200, 200))
            ub = torch.clamp(x + self.eps, min = 0, max = 1)
            lb = torch.clamp(x - self.eps, min = 0, max = 1)
            delta = ub - lb
            x2 = delta * grad + lb
            all_inputs.append(x2)
        return torch.stack(all_inputs)

    def _cam(self, x):
        x_ = x.clone().detach()
        x_.requires_grad = True
        z_ = self.m._feature(x_)
        z = z_.detach()
        z = z +  torch.randn_like(z).normal_(std = z.std() / 10)
        loss_z = F.mse_loss(z_, z)
        loss_z.backward()
        return x_.grad 

    def _get_neighb_with_grad(self, x, n_neighb):
        cam_abs = self._cam(x).abs()
        cam_mask = cam_abs > np.percentile(cam_abs.cpu(), 75)

        x = x.unsqueeze(0)
        x_ = x.repeat(n_neighb, 1, 1, 1, 1)
        x_ = x + torch.randn_like(x_).sign() * self.eps * cam_mask
        x_ = torch.clamp(x_, min = 0, max = 1)
        x_ = torch.cat((x, x_), dim = 0)
        return x_

    def _get_neighb_uniform(self, x, n_neighb):
        x = x.unsqueeze(0)
        x_ = x.repeat(n_neighb, 1, 1, 1, 1)
        ub = torch.clamp(x + self.eps, min = 0, max = 1)
        lb = torch.clamp(x - self.eps, min = 0, max = 1)
        x_ = (ub - lb) * torch.rand_like(x_) + lb
        x_ = torch.cat((x, x_), dim = 0)
        return x_

    def _predict_neighb(self, x, n_neighb):
        all_inputs = self._get_neighb(x, n_neighb)
        if (n_neighb + 1) * len(x) <= 10000:
            outputs = self.m(all_inputs.view((n_neighb + 1) * len(x), *x.shape[1:])).view((n_neighb + 1), len(x), -1)
        else:
            outputs = torch.stack([self.m(i) for i in all_inputs])
        return outputs

    def forward(self, x, n_neighb = -1):
        if n_neighb == -1:
            n_neighb = self.n_neighb

        if n_neighb == 0:
            return self.m(x)
        else:            
            outputs = self._predict_neighb(x, n_neighb)
            return sum(outputs) / n_neighb

In [None]:
for i in [1, 2, 5, 10]:
    print(i)
    m_ = DeltaEnsemble(m, n_neighb = i, eps = 0.1)
    m_.eval()
    utils.attack_model(m_, loss_fn, int(40000 // max(1, i)), utils.valset, 400)
#     break

1
Before. Avg-Loss: -0.6448, Accuracy: 0.9860
After. Avg-Loss: -0.1205, Accuracy: 0.7875
2
Before. Avg-Loss: -0.3568, Accuracy: 0.9861
After. Avg-Loss: 0.0901, Accuracy: 0.8284
5
Before. Avg-Loss: -0.1410, Accuracy: 0.9865
After. Avg-Loss: 0.2036, Accuracy: 0.8560
10


In [None]:
m_ = DeltaEnsemble(m, n_neighb = 10)

import matplotlib.pyplot as plt
import torchattacks
atk = torchattacks.PGD(m_, eps=0.1, alpha=1/255, steps=400, random_start=False)


In [None]:
from tqdm.notebook import tqdm
wrong = []

for k, (x, label) in enumerate(tqdm(utils.valset)):
    m_.eval()
    x = x.unsqueeze(0).cuda()    
    adv_images = atk(x, torch.tensor(label).unsqueeze(0).cuda())    
    if m_(x).argmax().item() != m_(adv_images).argmax().item():
        wrong.append(k)
        break

In [None]:
# k = 35
x, label = utils.valset[k]
x = x.unsqueeze(0).cuda()
adv_images = atk(x, torch.tensor(label).unsqueeze(0).cuda())


In [None]:
x_ = m_._get_neighb_with_grad(x, 1000)
pred_ = m(x_.squeeze(1))
plt.plot(pred_.cpu().detach().T)
plt.show()

In [None]:
adv_images_ = m_._get_neighb(adv_images, 1000)
pred_adv_images_ = m(adv_images_.squeeze(1))
plt.plot(pred_adv_images_.cpu().detach().T)
plt.show()

from collections import Counter

print(Counter(pred_.cpu().detach().argmax(1).tolist()))
print(Counter(pred_adv_images_.cpu().detach().argmax(1).tolist()))


In [None]:
import random

fig, axs = plt.subplots(4, 2)
axs[0, 0].plot(m(x.repeat(2, 1,1,1)).tolist()[0])
axs[0, 1].plot(m(adv_images.repeat(2, 1,1,1)).tolist()[0])
axs[1, 0].matshow(x.squeeze().cpu())
axs[1, 1].matshow(adv_images.squeeze().cpu())
axs[2, 0].matshow(x_[random.randint(0, len(adv_images_))].squeeze().cpu())
axs[2, 1].matshow(adv_images_[random.randint(0, len(adv_images_))].squeeze().cpu())
axs[3, 0].plot(m_(x.repeat(2, 1,1,1)).tolist()[0])
axs[3, 1].plot(m_(adv_images.repeat(2, 1,1,1)).tolist()[0])
