In [None]:
"""
PyTorch MNIST example using the Fast Gradient Sign Method.
"""
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from art.attacks.evasion import FastGradientMethod
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist

from matplotlib import pyplot as plt


# Device used for the purifier modules: the GPU when CUDA is available, otherwise the CPU.
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Net(nn.Module):
    """Small CNN classifier: two 5x5 convolutions followed by two linear layers.

    The first linear layer expects 4 * 4 * 10 features per sample, i.e. ten
    4x4 feature maps after the second pooling step. The network outputs 10
    raw scores (no softmax) per sample.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation under a fixed seed is unchanged.
        self.conv_1 = nn.Conv2d(1, 4, kernel_size=5, stride=1)
        self.conv_2 = nn.Conv2d(4, 10, kernel_size=5, stride=1)
        self.fc_1 = nn.Linear(4 * 4 * 10, 100)
        self.fc_2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in ``x``."""
        features = F.max_pool2d(F.relu(self.conv_1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv_2(features)), 2, 2)
        # Flatten the pooled feature maps into one row per sample.
        flat = features.view(-1, self.fc_1.in_features)
        hidden = F.relu(self.fc_1(flat))
        return self.fc_2(hidden)

In [None]:
# Load MNIST through ART, which also reports the pixel value range used for clipping.
train_split, test_split, min_pixel_value, max_pixel_value = load_mnist()
x_train, y_train = train_split
x_test, y_test = test_split

# Move axis 3 to position 1 (presumably channels-last -> channels-first, matching the
# (1, 28, 28) input_shape given to the classifier) and store the images as float32.
x_train = x_train.transpose(0, 3, 1, 2).astype(np.float32)
x_test = x_test.transpose(0, 3, 1, 2).astype(np.float32)

In [None]:
# Instantiate the CNN together with the loss and the Adam optimizer (lr=0.01)
# that are passed to the ART classifier wrapper.
model = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [None]:
# Wrap the PyTorch model in ART's PyTorchClassifier so ART can train, predict and attack it.
# clip_values carries the pixel range reported by load_mnist().
classifier = PyTorchClassifier(
    model=model,
    clip_values=(min_pixel_value, max_pixel_value),
    loss=criterion,
    optimizer=optimizer,
    input_shape=(1, 28, 28),
    nb_classes=10,
)

In [None]:
# Train the wrapped classifier on the training split: 3 epochs, mini-batches of 64.
classifier.fit(x_train, y_train, batch_size=64, nb_epochs=3)

In [None]:
# Accuracy on the clean test set: compare arg-max predictions with arg-max of the label rows.
predictions = classifier.predict(x_test)
predicted_labels = predictions.argmax(axis=1)
true_labels = y_test.argmax(axis=1)
accuracy = np.sum(predicted_labels == true_labels) / len(y_test)
print("Accuracy on test examples: {}%".format(accuracy * 100))

In [None]:
# Build the Fast Gradient Method attack against the classifier (eps=0.2)
# and generate adversarial versions of the test images.
attack = FastGradientMethod(estimator=classifier, eps=0.2)
x_test_adv = attack.generate(x=x_test)

In [None]:
# Accuracy of the undefended classifier on the adversarial test images.
predictions = classifier.predict(x_test_adv)
predicted_labels = predictions.argmax(axis=1)
true_labels = y_test.argmax(axis=1)
accuracy = np.sum(predicted_labels == true_labels) / len(y_test)
print("Accuracy on adversarial test examples: {}%".format(accuracy * 100))

In [None]:
from transformers import BertModel, AdamW

In [None]:
# Pretrained BERT encoder that serves as the purification ("defense") filter.
model_name = 'bert-base-uncased'
themaxlength=512  # NOTE(review): not referenced anywhere in the visible notebook - confirm it is still needed

defense_filter = BertModel.from_pretrained(model_name).to(DEVICE)
# Linear adapters between 16-value patch vectors and the 768-wide vectors fed to the
# encoder as inputs_embeds (presumably 768 is the encoder's hidden size - confirm).
linear_in = nn.Linear(16,768).to(DEVICE)
linear_out = nn.Linear(768,16).to(DEVICE)

In [None]:
# In the BERT filter only parameters whose name contains 'LayerNorm' stay trainable.
for name, param in defense_filter.named_parameters():
    param.requires_grad = 'LayerNorm' in name

# The attacked classifier is frozen while the purifier is trained.
for param in classifier.model.parameters():
    param.requires_grad = False

# Both linear adapters are trained in full.
for adapter in (linear_in, linear_out):
    for param in adapter.parameters():
        param.requires_grad = True

In [None]:
# Collect every parameter that is still trainable, in the order
# linear_in -> linear_out -> defense_filter.
params = [
    p
    for module in (linear_in, linear_out, defense_filter)
    for p in module.parameters()
    if p.requires_grad
]
# Use PyTorch's own AdamW instead of the deprecated transformers.AdamW.
# lr, eps and weight_decay are pinned to what I recall as the transformers defaults
# (1e-3, 1e-6, 0.0) so training behaves comparably; torch's own defaults would
# otherwise add weight decay of 1e-2 and use eps=1e-8.
filter_optimizer = optim.AdamW(params, lr=1e-3, eps=1e-6, weight_decay=0.0)

In [None]:
def get_batches_mnist(x, patch_size=4):
    """Split single-channel images into flattened, non-overlapping square patches.

    Args:
        x: Array-like or tensor of shape (batch, 1, height, width). Height and
            width must be multiples of ``patch_size``.
        patch_size: Side length of each square patch. The default of 4 yields
            the 7x7 grid of 16-value patches for 28x28 images.

    Returns:
        A new tensor of shape (batch, n_patches, patch_size ** 2). Patches are
        ordered row-major over the patch grid and each patch is flattened
        row-major. The result never shares memory with ``x``.

    Raises:
        ValueError: If ``x`` is not a 4-D single-channel batch, or its spatial
            size is not divisible by ``patch_size``.
    """
    # Always work on a copy (as torch.tensor(x) did) but avoid the UserWarning
    # torch.tensor raises when it is handed an existing tensor.
    if isinstance(x, torch.Tensor):
        x_torch = x.detach().clone()
    else:
        x_torch = torch.tensor(x)

    if x_torch.dim() != 4 or x_torch.shape[1] != 1:
        raise ValueError(
            "expected input of shape (batch, 1, height, width), got {}".format(tuple(x_torch.shape))
        )
    batch, _, height, width = x_torch.shape
    if patch_size <= 0 or height % patch_size or width % patch_size:
        raise ValueError(
            "image size {}x{} is not divisible by patch_size={}".format(height, width, patch_size)
        )

    rows, cols = height // patch_size, width // patch_size
    # (batch, rows, patch, cols, patch) -> (batch, rows, cols, patch, patch):
    # one reshape/permute replaces the repeated torch.cat over every patch.
    grid = x_torch.reshape(batch, rows, patch_size, cols, patch_size).permute(0, 1, 3, 2, 4)
    return grid.reshape(batch, rows * cols, patch_size * patch_size)

In [None]:
def get_image_mnist(x):
    """Reassemble a 7x7 grid of flattened 4x4 patches into full images.

    ``x`` is reshaped to (batch, 7, 7, 4, 4). The patches of each grid row are
    joined along the last axis, the rows are stacked along the second-to-last
    axis, and a channel axis is inserted, giving shape (batch, 1, 28, 28).
    """
    grid = x.reshape((x.shape[0], 7, 7, 4, 4))
    # Each entry is one full-width strip: the 7 patches of a grid row side by side.
    strips = [
        torch.cat([grid[:, row, col, :, :] for col in range(7)], dim=-1)
        for row in range(7)
    ]
    return torch.cat(strips, dim=-2).unsqueeze(1)

In [None]:
def _purify_batch(x):
    """Return ``x`` plus the correction predicted by the BERT filter, on DEVICE.

    The batch is cut into patches, projected by ``linear_in``, run through
    ``defense_filter`` as ``inputs_embeds``, projected back by ``linear_out``
    and reassembled into images that are added to the input.
    """
    patches = get_batches_mnist(x)
    hidden = defense_filter(inputs_embeds=linear_in(patches.to(DEVICE)))['last_hidden_state']
    correction = get_image_mnist(linear_out(hidden))
    return torch.tensor(x).to(DEVICE) + correction


epoch = 20          # number of training epochs; one inference-only pass follows them
batch_size = 100
x_in = x_test_adv
y_in = y_test
classifier._model.eval()

# --- Training: fit the purifier so the frozen classifier labels purified images correctly.
l = []              # summed batch loss per training epoch (plotted in a later cell)
defense_filter.train()
for e in range(epoch):
    # Stepping by batch_size also covers a trailing partial batch instead of dropping it.
    pbar = tqdm(range(0, len(x_in), batch_size))
    pbar.set_description("Epoch {:0>2d}".format(e + 1))
    ll = 0
    for start in pbar:
        x = x_in[start:start + batch_size]
        y = y_in[start:start + batch_size]
        filter_optimizer.zero_grad()
        x_pured = _purify_batch(x)
        out = classifier._model(x_pured)
        # out[-1] is used as the classifier output, as in the original code
        # (presumably the wrapped model returns a sequence ending with the logits - confirm).
        loss = classifier._loss(out[-1], torch.tensor(y).to(DEVICE)).sum()
        loss.backward()
        filter_optimizer.step()
        ll = ll + loss.item()
        pbar.set_postfix({'loss_sum': '{:.4f}'.format(ll)})
    l.append(ll)

# --- Final pass: purify every input once in eval mode. No gradients are needed, and the
# batches are concatenated a single time instead of re-copying a growing tensor per batch.
defense_filter.eval()
purified_batches = []
pbar = tqdm(range(0, len(x_in), batch_size))
pbar.set_description("Epoch {:0>2d}".format(epoch + 1))
with torch.no_grad():
    for start in pbar:
        purified_batches.append(_purify_batch(x_in[start:start + batch_size]).cpu())
final_x_pured = torch.cat(purified_batches, dim=0) if purified_batches else []

In [None]:
# Plot the summed training loss recorded for each epoch and save the figure under the name "loss_mnist".
plt.plot(l)
plt.savefig("loss_mnist")

In [None]:
defense_filter.eval()
classifier._model.eval()

# ART's predict() is documented for NumPy arrays, so convert the purified batch
# (a CPU tensor) first; np.asarray leaves an existing ndarray untouched.
predictions = classifier.predict(np.asarray(final_x_pured))
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)
print("Accuracy on adversarial test examples after purifying: {}%".format(accuracy * 100))

In [None]:
# Show the first clean test image (channel 0) in grayscale and save the figure under the name "mnist".
plt.imshow(x_test[0][0],"gray")
plt.savefig("mnist")

In [None]:
# Show the adversarial version of the first test image and save the figure under the name "mnist_adv".
plt.imshow(x_test_adv[0][0],"gray")
plt.savefig("mnist_adv")

In [None]:
# Show the first purified image and save the figure under the name "mnist_pured".
plt.imshow(final_x_pured[0][0],"gray")
plt.savefig("mnist_pured")

In [None]:
# Generate adversarial versions of the training images with the same attack object.
x_train_adv = attack.generate(x=x_train)

In [None]:
# Accuracy of the undefended classifier on the adversarial training images.
predictions = classifier.predict(x_train_adv)
predicted_labels = predictions.argmax(axis=1)
true_labels = y_train.argmax(axis=1)
accuracy = np.sum(predicted_labels == true_labels) / len(y_train)
print("Accuracy on adversarial train examples: {}%".format(accuracy * 100))

In [None]:
defense_filter.eval()
classifier._model.eval()

x_in = x_train_adv
y_in = y_train
batch_size = 100

# Purify the adversarial training images batch by batch without tracking gradients.
# Batches are collected in a list and concatenated once, because torch.cat on a
# growing tensor re-copies everything accumulated so far on every iteration.
purified_batches = []
# Stepping by batch_size also covers a trailing partial batch, so the output
# always has one row per label in y_in.
pbar = tqdm(range(0, len(x_in), batch_size))
with torch.no_grad():
    for start in pbar:
        x = x_in[start:start + batch_size]
        patches = get_batches_mnist(x)
        tmp = linear_in(patches.to(DEVICE))
        tmp3 = defense_filter(inputs_embeds=tmp)['last_hidden_state']
        imp = get_image_mnist(linear_out(tmp3))
        x_pured = torch.tensor(x).to(DEVICE) + imp
        purified_batches.append(x_pured.cpu())
final_x_pured = torch.cat(purified_batches, dim=0) if purified_batches else []

# ART's predict() is documented for NumPy arrays, so convert the CPU tensor first.
predictions = classifier.predict(np.asarray(final_x_pured))
accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_in, axis=1)) / len(y_in)
print("Accuracy on adversarial train examples after purifying: {}%".format(accuracy * 100))

In [None]:
# Persist the state dicts of the purifier (BERT filter and both linear adapters),
# one file per module, in the same order as before.
checkpoints = {
    "defense_filter_mnist.pdparams": defense_filter,
    "linear_in_mnist.pdparams": linear_in,
    "linear_out_mnist.pdparams": linear_out,
}
for path, module in checkpoints.items():
    torch.save(module.state_dict(), path)

In [None]:
# defense_filter.load_state_dict(torch.load("defense_filter_mnist.pdparams"))
# linear_in.load_state_dict(torch.load("linear_in_mnist.pdparams"))
# linear_out.load_state_dict(torch.load("linear_out_mnist.pdparams"))