In [None]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.classification import PyTorchClassifier
from art.utils import load_nursery

from matplotlib import pyplot as plt


# Fixed seed so that weight initialisation and other RNG-driven steps repeat
# after "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class Net(nn.Module):
    """Small fully connected classifier: Linear -> Sigmoid -> Linear.

    Args:
        in_features: Width of each input row (defaults to 24, the value the
            notebook was written with).
        hidden_features: Width of the single hidden layer (default 10).
        num_classes: Number of output scores per row (default 4).

    The attribute names ``linear1`` / ``sigmoid1`` / ``linear2`` are kept so
    the parameter names in ``state_dict()`` stay the same.
    """

    def __init__(self, in_features=24, hidden_features=10, num_classes=4):
        super().__init__()
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.sigmoid1 = nn.Sigmoid()
        self.linear2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the unnormalised class scores (no softmax is applied)."""
        hidden = self.sigmoid1(self.linear1(x))
        return self.linear2(hidden)

In [None]:
# Load the Nursery train/test split from ART together with the two range values
# it returns (named *_pixel_value here; they are forwarded as clip_values later).
(x_train, y_train), (x_test, y_test), min_pixel_value, max_pixel_value = load_nursery()

# Features become float32 and labels become "long" integers.
x_train, x_test = (features.astype("float32") for features in (x_train, x_test))
y_train, y_test = (labels.astype("long") for labels in (y_train, y_test))

In [None]:
# Victim network, plus the loss and optimizer that are handed to the ART wrapper.
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Wrap the PyTorch model in ART's estimator so ART can train and attack it.
classifier = PyTorchClassifier(
    model=model,
    # Value range passed through from load_nursery().
    clip_values=(min_pixel_value, max_pixel_value),
    loss=criterion,
    optimizer=optimizer,
    # Must be a tuple: `(24)` is just the integer 24, `(24,)` is a 1-tuple.
    input_shape=(24,),
    nb_classes=4,
)

In [None]:
# Train the victim classifier: 3 passes over the training split in mini-batches of 4.
classifier.fit(x_train, y_train, nb_epochs=3, batch_size=4)

In [None]:
# Baseline: accuracy of the trained classifier on the unperturbed test split.
predictions = classifier.predict(x_test)
predicted_labels = np.argmax(predictions, axis=1)
accuracy = (predicted_labels == y_test).sum() / len(y_test)
print("Accuracy on test examples: {}%".format(accuracy * 100))

In [None]:
# Craft adversarial versions of the test split with PGD (eps=0.6; every other
# attack setting is left at the library default).
attack = ProjectedGradientDescent(classifier, eps=0.6)
x_test_adv = attack.generate(x_test)

In [None]:
# Accuracy of the unprotected classifier on the adversarial test samples.
predictions = classifier.predict(x_test_adv)
predicted_labels = np.argmax(predictions, axis=1)
accuracy = (predicted_labels == y_test).sum() / len(y_test)
print("Accuracy on adversarial test examples: {}%".format(accuracy * 100))

In [None]:
from transformers import BertModel, AdamW

In [None]:
model_name = 'bert-base-uncased'
themaxlength=512  # NOTE(review): not referenced anywhere else in the visible notebook

# Pre-trained BERT encoder used as the purifier ("defense filter").
defense_filter = BertModel.from_pretrained(model_name).to(DEVICE)

# Adapters around BERT: each scalar input is lifted to the encoder's hidden
# width, and each hidden vector is projected back to one scalar. The width is
# read from the loaded config (768 for bert-base-uncased) instead of being
# hard-coded, so changing `model_name` keeps the shapes consistent.
hidden_size = defense_filter.config.hidden_size
linear_in = nn.Linear(1, hidden_size).to(DEVICE)
linear_out = nn.Linear(hidden_size, 1).to(DEVICE)

In [None]:
# Inside BERT only parameters whose name contains 'LayerNorm' stay trainable.
for name, param in defense_filter.named_parameters():
    param.requires_grad = 'LayerNorm' in name

# The victim classifier is frozen so the purifier training cannot change it.
for param in classifier.model.parameters():
    param.requires_grad = False

# Both adapter layers are trained in full.
for adapter in (linear_in, linear_out):
    for param in adapter.parameters():
        param.requires_grad = True

In [None]:
# Gather every trainable parameter, in the order: input adapter, output
# adapter, BERT, and hand them to a single AdamW optimizer.
params = [
    param
    for module in (linear_in, linear_out, defense_filter)
    for param in module.parameters()
    if param.requires_grad
]
filter_optimizer = AdamW(params)

In [None]:
# Train the purifier (linear_in -> BERT -> linear_out) so that
# `x + correction` is classified correctly by the frozen victim classifier,
# then run one evaluation-mode pass that stores the purified samples.
#
# NOTE(review): the purifier is fitted on x_test_adv / y_test, and the purified
# version of those same samples is what the following accuracy cell scores.
epoch = 5          # number of training epochs
batch_size = 4
x_in = x_test_adv
y_in = y_test
classifier._model.eval()
l = []             # summed batch loss of each training epoch

# Step through the data by start offset so that a final, shorter batch is
# processed too (dividing the length by the batch size and truncating would
# silently drop the last len(x_in) % batch_size samples).
batch_starts = range(0, len(x_in), batch_size)

defense_filter.train()
for e in range(epoch):
    pbar = tqdm(batch_starts)
    pbar.set_description("Epoch {:0>2d}".format(e + 1))
    ll = 0
    for start in pbar:
        filter_optimizer.zero_grad()
        x = torch.tensor(x_in[start:start + batch_size]).to(DEVICE)
        y = torch.tensor(y_in[start:start + batch_size]).to(DEVICE).long()
        # Append a trailing axis of size 1 so linear_in can embed every scalar
        # feature on its own; BERT then receives the result as `inputs_embeds`.
        embedded = linear_in(x.unsqueeze(-1))
        hidden = defense_filter(inputs_embeds=embedded)['last_hidden_state']
        # squeeze(-1) removes only the size-1 output axis of linear_out; a bare
        # squeeze() would also remove the batch axis of a single-sample batch.
        imp = linear_out(hidden).squeeze(-1)
        x_pured = x + imp
        out = classifier._model(x_pured)
        # As in the original code, the last element of the model output is
        # what gets passed to the loss.
        loss = classifier._loss(out[-1], y).sum()
        ll = ll + loss.item()
        pbar.set_postfix({'loss_sum': '{:.4f}'.format(ll)})
        loss.backward()
        filter_optimizer.step()
    l.append(ll)

# Extra pass (shown as epoch `epoch + 1` on the progress bar): no parameter
# updates, purifier in eval mode, purified samples collected on the CPU.
defense_filter.eval()
purified_batches = []
pbar = tqdm(batch_starts)
pbar.set_description("Epoch {:0>2d}".format(epoch + 1))
with torch.no_grad():  # nothing is back-propagated here
    for start in pbar:
        x = torch.tensor(x_in[start:start + batch_size]).to(DEVICE)
        embedded = linear_in(x.unsqueeze(-1))
        hidden = defense_filter(inputs_embeds=embedded)['last_hidden_state']
        imp = linear_out(hidden).squeeze(-1)
        purified_batches.append((x + imp).cpu())

# One concatenation at the end instead of re-copying the growing tensor after
# every batch; an empty input leaves an empty list, as before.
final_x_pured = torch.cat(purified_batches, dim=0) if purified_batches else []

In [None]:
# Plot the summed loss recorded for each training epoch and write the figure
# to "loss_nursery" (matplotlib appends its default image extension).
fig, ax = plt.subplots()
ax.plot(l)
fig.savefig("loss_nursery")

In [None]:
defense_filter.eval()
classifier._model.eval()

# ART's predict() is documented for NumPy input, so convert the collected
# purified samples (a CPU torch tensor) before handing them over.
if torch.is_tensor(final_x_pured):
    x_pured_np = final_x_pured.detach().cpu().numpy()
else:
    x_pured_np = np.asarray(final_x_pured, dtype=np.float32)

predictions = classifier.predict(x_pured_np)

# The purified samples are produced in order from the start of the test set.
# If fewer rows than labels were produced (e.g. an incomplete last batch was
# skipped), compare against the matching prefix of the labels instead of
# comparing two arrays of different length.
y_eval = y_test[:len(predictions)]
accuracy = np.sum(np.argmax(predictions, axis=1) == y_eval) / len(y_eval)
print("Accuracy on adversarial test examples after purifying: {}%".format(accuracy * 100))

In [None]:
# Run the same PGD attack object against the training split.
x_train_adv = attack.generate(x_train)

In [None]:
# Accuracy of the unprotected classifier on the adversarial training samples.
predictions = classifier.predict(x_train_adv)
predicted_labels = np.argmax(predictions, axis=1)
accuracy = (predicted_labels == y_train).sum() / len(y_train)
print("Accuracy on adversarial train examples: {}%".format(accuracy * 100))

In [None]:
# Apply the trained purifier to the adversarial training samples (no parameter
# updates) and measure the classifier's accuracy on the purified result.
defense_filter.eval()
classifier._model.eval()

x_in = x_train_adv
y_in = y_train
batch_size = 4

purified_batches = []
# Iterate by start offset so a final, shorter batch is included; truncating
# len(x_in) / batch_size would drop the tail and leave fewer predictions than
# labels in the accuracy computation below.
pbar = tqdm(range(0, len(x_in), batch_size))
with torch.no_grad():
    for start in pbar:
        x = torch.tensor(x_in[start:start + batch_size]).to(DEVICE)
        # One embedding per scalar feature, fed to BERT as `inputs_embeds`.
        embedded = linear_in(x.unsqueeze(-1))
        hidden = defense_filter(inputs_embeds=embedded)['last_hidden_state']
        # squeeze(-1) drops only linear_out's size-1 axis, never the batch axis.
        imp = linear_out(hidden).squeeze(-1)
        purified_batches.append((x + imp).cpu())

# Single concatenation instead of re-copying the growing tensor every batch;
# an empty input leaves an empty list, as before.
final_x_pured = torch.cat(purified_batches, dim=0) if purified_batches else []

# ART's predict() is documented for NumPy input.
if torch.is_tensor(final_x_pured):
    x_pured_np = final_x_pured.numpy()
else:
    x_pured_np = np.asarray(final_x_pured, dtype=np.float32)

predictions = classifier.predict(x_pured_np)
accuracy = np.sum(np.argmax(predictions, axis=1) == y_in) / len(y_in)
print("Accuracy on adversarial train examples after purifying: {}%".format(accuracy * 100))

In [None]:
# Persist the purifier: the complete BERT state dict plus both adapter layers,
# each to its own file in the current working directory.
for module, path in (
    (defense_filter, "defense_filter_nursery.pdparams"),
    (linear_in, "linear_in_nursery.pdparams"),
    (linear_out, "linear_out_nursery.pdparams"),
):
    torch.save(module.state_dict(), path)

In [None]:
# defense_filter.load_state_dict(torch.load("defense_filter_nursery.pdparams"))
# linear_in.load_state_dict(torch.load("linear_in_nursery.pdparams"))
# linear_out.load_state_dict(torch.load("linear_out_nursery.pdparams"))