FGSM implementation is incorrect #3

carlini opened this issue Feb 26, 2019 · 9 comments

carlini commented Feb 26, 2019

Despite the simplicity of the Fast Gradient Sign Method, it is surprisingly effective at generating adversarial examples on unsecured models. However, Table XIV reports the misclassification rate of FGSM at eps=0.3 on MNIST as 30.4%, significantly less effective than expected given the results of prior work.

I investigate this further by taking the one-line script and following the README to run the FGSM attack on the baseline MNIST model. Doing this yields a misclassification rate of 38.3%. It is mildly concerning that this number is 25% larger than the value reported in the paper, and I'm unable to account for this statistically significant deviation between the paper's value and what the code returns. However, this error is only of secondary concern: as prior work indicates, the success rate of FGSM should be substantially higher.

So I compare the result of attacking with the CleverHans framework. Because DeepSec is implemented in PyTorch, and CleverHans only supports TensorFlow, I load the DeepSec pre-trained PyTorch model weights into a TensorFlow model and generate adversarial examples on this model with the CleverHans implementation of FGSM. CleverHans obtains a 61% misclassification rate, over double the rate reported in the DeepSec paper. To confirm that the results I obtain are correct, I save these adversarial examples and run the original DeepSec PyTorch model on them, again finding a misclassification rate of 61%. I'm currently not able to explain how DeepSec incorrectly implemented FGSM; however, the fact that the simplest attack is implemented incorrectly is deeply concerning.
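
Roughly, the CleverHans side of this looks like the sketch below (an illustrative outline rather than the exact script; it assumes presoftmax is a TensorFlow re-implementation of the DeepSec MNIST model that returns logits, and x_test holds the candidate images in NHWC format):

import tensorflow as tf
from cleverhans.model import CallableModelWrapper
from cleverhans.attacks import FastGradientMethod

sess = tf.Session()
# Wrap the logits-producing function so CleverHans can call it.
model = CallableModelWrapper(presoftmax, output_layer='logits')
fgsm = FastGradientMethod(model, sess=sess)
# eps=0.3 in the [0, 1] pixel range, matching Table XIV.
cleverhans_adv = fgsm.generate_np(x_test, eps=0.3, clip_min=0.0, clip_max=1.0)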

The remaining issues I'm filing on DeepSec therefore discuss only the methodology and analysis, and not any specific numbers, which may or may not be trustworthy.

ryderling (Owner) commented

First of all, the pre-trained model currently in the repo is not the exact model we used when writing the paper, as all of the code was reconstructed (for example, all random seeds are now set manually for reproducibility) when we released this repo. For the re-trained raw MNIST model, the size of the validation dataset is slightly different from before (it is 0.1 * 60000 = 6000 samples instead of a fixed 5000), since we now uniformly use a ratio instead of an absolute count when sampling the validation dataset from the training dataset for both MNIST and CIFAR10. Therefore, we think an FGSM misclassification rate of 38% for the new model is within a reasonable range.
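
For illustration, the ratio-based validation split described above corresponds to something like the following sketch (not the exact repo code; the seed value 100 is the one mentioned later in this thread):

import torch
from torchvision import datasets, transforms

train_set = datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor())
val_ratio = 0.1
val_size = int(val_ratio * len(train_set))    # 0.1 * 60000 = 6000 validation samples
train_size = len(train_set) - val_size        # 54000 training samples

torch.manual_seed(100)                        # random seeds are set manually for reproducibility
train_subset, val_subset = torch.utils.data.random_split(train_set, [train_size, val_size])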

On the other hand, we are unable to identify the cause of the discrepancy in misclassification rate between DEEPSEC and CleverHans for FGSM. Is your TensorFlow model architecture exactly the same as the pre-trained model in this repo? Are all parameters in the model matched exactly? If you like, could you share the script you used to transfer the PyTorch model to TensorFlow? So far I have no explanation, as we cannot find any bug in the FGSM attack. This deserves more discussion and contributions from the community, which is the reason we open-sourced our platform.

carlini commented Mar 16, 2019

Alright, here's what I did.

First train the MNIST conv net and run the candidates selection process to get 1000 examples.

python train_mnist.py 
python CandidatesSelection.py --dataset MNIST

Then make the following patch to get the model weights out of PyTorch and to save the images we're using to attack:

diff --git a/RawModels/MNISTConv.py b/RawModels/MNISTConv.py
index eb220ad..c833f4f 100644
--- a/RawModels/MNISTConv.py
+++ b/RawModels/MNISTConv.py
@@ -57,6 +57,9 @@ class MNISTConvNet(BasicModule):
         # softmax ? or not

     def forward(self, x):
+        import numpy as np
+        np.save("/tmp/params.npy", [x.cpu().detach().numpy() for x in list(self.conv32.parameters())+list(self.conv64.parameters())+
+                                    list(self.fc1.parameters())+list(self.fc2.parameters())+list(self.fc3.parameters())])
         out = self.conv32(x)
         out = self.conv64(out)
         out = out.view(-1, 4 * 4 * 64)

diff --git a/Attacks/FGSM_Generation.py b/Attacks/FGSM_Generation.py
index 7443786..8d5d0eb 100644
--- a/Attacks/FGSM_Generation.py
+++ b/Attacks/FGSM_Generation.py
@@ -37,9 +37,11 @@ class FGSMGeneration(Generation):
                                                   device=self.device)
         # prediction for the adversarial examples
         adv_labels = predict(model=self.raw_model, samples=adv_samples, device=self.device)
+        np.save("/tmp/adversarial_predictions.npy", adv_labels.cpu().detach().numpy())
         adv_labels = torch.max(adv_labels, 1)[1]
         adv_labels = adv_labels.cpu().numpy()

+        np.save('{}{}_Original.npy'.format(self.adv_examples_dir, self.attack_name), self.nature_samples)
         np.save('{}{}_AdvExamples.npy'.format(self.adv_examples_dir, self.attack_name), adv_samples)
         np.save('{}{}_AdvLabels.npy'.format(self.adv_examples_dir, self.attack_name), adv_labels)
         np.save('{}{}_TrueLabels.npy'.format(self.adv_examples_dir, self.attack_name), self.labels_samples)

Then attack the baseline model

python FGSM_Generation.py --dataset=MNIST --epsilon=0.3 --attack_batch_size 1000

This time when I run it I get a different number, and see the result:

For **FGSM** on **MNIST**: misclassification ratio is 180/1000=18.0%

So it's mildly concerning that I've now seen 38% and 18% as the result of the FGSM attack on two different models. Averaged over 1000 examples, this result is statistically significant (with p some absurdly low value).
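
As a rough illustration of that claim (assuming the earlier 38.3% run also used 1000 candidate examples), a chi-squared test on the two misclassification counts gives a vanishingly small p-value:

import numpy as np
from scipy.stats import chi2_contingency

# 180/1000 misclassified on this run vs. 383/1000 on the earlier run.
counts = np.array([[180, 1000 - 180],
                   [383, 1000 - 383]])
chi2, p, dof, expected = chi2_contingency(counts)
print("p-value:", p)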

Now let's write some TensorFlow code to load everything.

import numpy as np
import tensorflow as tf

l = np.load("/tmp/params.npy")
l = [np.array(x,dtype=np.float32) for x in l]

def presoftmax(x):
    out = tf.nn.relu(tf.nn.conv2d(x, l[0].transpose((2,3,1,0)), [1,1,1,1], "VALID") + l[1].reshape((1,1,1,-1)))
    out = tf.nn.relu(tf.nn.conv2d(out, l[2].transpose((2,3,1,0)), [1,1,1,1], "VALID") + l[3].reshape((1,1,1,-1)))
    out = tf.nn.max_pool(out, [1,2,2,1], [1, 2, 2, 1], 'VALID')

    out = tf.nn.relu(tf.nn.conv2d(out, l[4].transpose((2,3,1,0)), [1,1,1,1], "VALID") + l[5].reshape((1,1,1,-1)))
    out = tf.nn.relu(tf.nn.conv2d(out, l[6].transpose((2,3,1,0)), [1,1,1,1], "VALID") + l[7].reshape((1,1,1,-1)))
    out = tf.nn.max_pool(out, [1,2,2,1], [1, 2, 2, 1], 'VALID')

    out = tf.transpose(out, (0, 3, 1, 2))
    out = tf.reshape(out, [-1, 1024])

    out = tf.nn.relu(tf.matmul(out, l[8].transpose())+l[9])
    out = tf.nn.relu(tf.matmul(out, l[10].transpose())+l[11])
    out = tf.matmul(out, l[12].transpose())+l[13]
    return out

sess = tf.Session()

x_test = np.load("AdversarialExampleDatasets/FGSM/MNIST/FGSM_Original.npy")
y_test = np.load("AdversarialExampleDatasets/FGSM/MNIST/FGSM_TrueLabels.npy")
x_test = np.transpose(x_test, [0, 2, 3, 1])

xs = tf.placeholder(tf.float32, [None, 28, 28, 1])
ys = tf.placeholder(tf.float32, [None, 10])
logits = presoftmax(xs)

print("Clean error", 1-np.mean(np.argmax(sess.run(logits, {xs: x_test}),axis=1)==np.argmax(y_test,axis=1)))

x_test_deepsec = np.load("AdversarialExampleDatasets/FGSM/MNIST/FGSM_AdvExamples.npy")
x_test_deepsec = np.transpose(x_test_deepsec, [0, 2, 3, 1])
print("DEEPSEC FGSM error", 1.0-np.mean(np.argmax(sess.run(logits, {xs: x_test_deepsec}),axis=1)==np.argmax(y_test,axis=1)))

And when we run it we see:

Clean error 0.0
DEEPSEC FGSM error 0.18000000000000005

Which matches very nicely so far. But just to make absolutely sure we're doing things right, let's compare against the saved logits.

our_logits = sess.run(logits, {xs: x_test_deepsec})
deepsec_logits = np.load("/tmp/adversarial_predictions.npy")

print("Maximum error", np.max(np.abs(our_logits-deepsec_logits)))

And we see that the answer is basically zero.

Maximum error 3.4332275e-05

So now that we know our implementation does exactly the same thing as the PyTorch model, let's write and run a naive implementation of FGSM:

loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits-tf.reduce_max(logits, axis=-1, keepdims=True),
                                              labels=ys)
grad_step = tf.sign(tf.gradients(loss, [xs]))[0]*0.3
direction = sess.run(grad_step, {xs: x_test,
                                 ys: y_test})
fgsm_adv = np.clip(x_test+direction, 0, 1)
np.save("/tmp/our_fgsm_adv.npy", fgsm_adv)
print("FGSM error", 1-np.mean(np.argmax(sess.run(logits, {xs: fgsm_adv}),axis=1)==np.argmax(y_test,axis=1)))

And then I get the following:

FGSM error 0.661

Now because I'm paranoid, let's make sure we haven't broken anything:

print("Check bounds", np.min(fgsm_adv), np.max(fgsm_adv))
print("Check distortion", np.max(np.abs(fgsm_adv-x_test)))

And I see what's expected:

Check bounds 0.0 1.0
Check distortion 0.30000004

But you know what, let's be really sure that we haven't messed anything up. I saved the adversarial examples, so let's patch the FGSM code once more so that it just loads the ones I generated and returns those directly:

diff --git a/Attacks/FGSM_Generation.py b/Attacks/FGSM_Generation.py
index 7443786..54c43a2 100644
--- a/Attacks/FGSM_Generation.py
+++ b/Attacks/FGSM_Generation.py
@@ -33,13 +33,14 @@ class FGSMGeneration(Generation):
         attacker = FGSMAttack(model=self.raw_model, epsilon=self.epsilon)

         # generating
-        adv_samples = attacker.batch_perturbation(xs=self.nature_samples, ys=self.labels_samples, batch_size=self.attack_batch_size,
-                                                  device=self.device)
+        adv_samples = np.load("/tmp/our_fgsm_adv.npy").transpose((0,3,1,2))
+
         # prediction for the adversarial examples
         adv_labels = predict(model=self.raw_model, samples=adv_samples, device=self.device)
         adv_labels = torch.max(adv_labels, 1)[1]
         adv_labels = adv_labels.cpu().numpy()

         np.save('{}{}_AdvExamples.npy'.format(self.adv_examples_dir, self.attack_name), adv_samples)
         np.save('{}{}_AdvLabels.npy'.format(self.adv_examples_dir, self.attack_name), adv_labels)
         np.save('{}{}_TrueLabels.npy'.format(self.adv_examples_dir, self.attack_name), self.labels_samples)           

And now let's run this FGSM replay-attack again to see how it does:

For **FGSM** on **MNIST**: misclassification ratio is 661/1000=66.1%

So, the adversarial examples generated with TensorFlow give an identical misclassification rate in PyTorch. I'm pretty sure that (1) the implementation I have is identical to the PyTorch model; however, (2) the naive implementation of FGSM I wrote is OVER THREE TIMES more effective than the code in this repository.

So while I don't know why the implementation in this repository is incorrect, I do know that it is incorrect. If I had to guess, I would say it's likely that there's some numerical instability in some of your code somewhere.

(Now it's also deeply concerning that the variance I've seen on two different runs is between 18% and 38%. I would recommend you think about looking at error bars on your data.)

ryderling (Owner) commented

Thank you very much for sharing this.
For FGSM in DEEPSEC, we ran the attack several times; with the seed manually set to 100 it is always 38.3%.
On the other hand, the choice of loss function used to generate adversarial examples matters a great deal for the FGSM attack. DEEPSEC uses loss = torch.nn.CrossEntropyLoss(), which combines LogSoftmax and NLLLoss and is what PyTorch officially suggests (https://github.com/pytorch/examples/blob/master/mnist/main.py). However, it seems that you use a different loss function here, generate the adversarial examples with it, and then feed them to our model, which makes the comparison unfair.
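
For reference, a single FGSM step built around torch.nn.CrossEntropyLoss looks roughly like the following sketch (illustrative only, not the exact DEEPSEC code; model is assumed to return raw logits):

import torch

def fgsm_step(model, x, y, epsilon=0.3):
    # CrossEntropyLoss applies LogSoftmax + NLLLoss internally,
    # so the model is expected to output raw (pre-softmax) scores.
    x = x.clone().detach().requires_grad_(True)
    loss = torch.nn.CrossEntropyLoss()(model(x), y)
    loss.backward()
    # One signed-gradient step, clipped back to the valid pixel range.
    return torch.clamp(x + epsilon * x.grad.sign(), 0.0, 1.0).detach()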

carlini commented Mar 24, 2019

  1. To compute the variance properly, you can't fix both the model and the dataset subset and then repeat a deterministic computation multiple times; the variance of a deterministic computation is clearly 0. The only way to correctly estimate the variance is to train several different models (with different random seeds), select several different dataset subsets (with different random seeds), and then estimate the variance over those runs. When I do this eight times, I get the values [25.4, 25.9, 39.8, 40.1, 30.9, 26.0, 30.6, 43.5]. Assuming a normal distribution (a Kolmogorov-Smirnov test fails to reject at p>.3, so it's not such a bad assumption), this would give a 95% confidence interval of 32 +/- 13; a sketch of this computation appears after this list. It would be nice to know when reading the paper that the attack success rate might range from 19 to 45.

  2. Regardless of what PyTorch recommends for what loss function to use when training a neural network, you need to make sure that the implementation does what the attack specifies. The definition of FGSM is a fixed mathematical equation. In order to implement this equation in code, you may have to write more than exactly that line of code to ensure that the calculation is correct. I have implemented FGSM the correct way in TensorFlow and got a 66% attack success rate. If you want to call the attack something different---"numerically unstable FGSM" for example---then that would be fine. But don't report it as FGSM and measure it incorrectly.
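
A quick sketch of the interval computation from point 1 (not part of the original comment), using the eight measured success rates:

import numpy as np
from scipy import stats

rates = np.array([25.4, 25.9, 39.8, 40.1, 30.9, 26.0, 30.6, 43.5])
mean, std = rates.mean(), rates.std(ddof=1)
# Kolmogorov-Smirnov test against a normal fit (the comment above reports it fails to reject).
print(stats.kstest((rates - mean) / std, 'norm'))
# Approximate 95% range for a single run under the normality assumption.
print("{:.0f} +/- {:.0f}".format(mean, 1.96 * std))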

ryderling (Owner) commented

The definition of FGSM is as follows:

x_adv = x + epsilon * sign(∇_x J(θ, x, y))

Did we violate this definition of FGSM anywhere? I am quite sure our implementation does exactly what the attack specifies, using the recommended cost function torch.nn.CrossEntropyLoss().

On the other hand, I am very confused by the logits argument in tf.nn.softmax_cross_entropy_with_logits:
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits-tf.reduce_max(logits, axis=-1, keepdims=True), labels=ys). Since I am not familiar with TensorFlow, is there any other reference that uses logits=logits-tf.reduce_max(logits, axis=-1, keepdims=True)?

ryderling (Owner) commented

From my perspective, it is unfair to compare implementations that use different loss functions.

When I investigate the loss function further and change it from torch.nn.CrossEntropyLoss() to torch.nn.NLLLoss(), the attack success rate changes from 38.2% to 79.5%, which is about 20% higher than the 66% from your implementation. Do you think it would be fair for me to conclude that "your FGSM implementation is incorrect"?

carlini commented Mar 27, 2019

The difference is that my implementation is just the numerically stable way to implement softmax. It's still the same function, just numerically stable. See Section 4.1 of the Deep Learning Book by Goodfellow et al. which recommends the exact implementation I write.
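
To make the distinction concrete, here is a small numpy illustration (not code from either framework) of why subtracting the per-row maximum changes nothing mathematically but avoids overflow:

import numpy as np

def softmax_naive(z):
    e = np.exp(z)                          # exp overflows to inf for large logits
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(z):
    z = z - z.max(axis=-1, keepdims=True)  # same function: exp(z-c)/sum exp(z-c) == exp(z)/sum exp(z)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

z = np.array([[200.0, 0.0, 0.0]], dtype=np.float32)
print(softmax_naive(z))    # [[nan  0.  0.]] -- inf/inf, so any gradient through it is garbage
print(softmax_stable(z))   # [[1.  0.  0.]]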

ryderling (Owner) commented

Fixed in d4e1181 by changing the model definition for both MNIST and CIFAR10, even though the original definition followed what PyTorch officially suggests (https://github.com/pytorch/examples/blob/master/mnist/main.py).

Nothing needs to be changed in our implementation of FGSM.

After retraining the model for MNIST and attacking, the misclassification rate of FGSM at eps=0.3 on MNIST is 80.8%.

carlini commented May 21, 2019

Moving the numerical stability fix to the MNIST model does, at a technical level, resolve this specific issue.

However, the stated purpose of DeepSec is to support arbitrary defenses written in the future. It would therefore be preferable for the attack itself to be numerically stable. Otherwise, each new defense will have to ensure it is not unintentionally causing gradient masking, or it will artificially appear robust against gradient-based attacks.

So while making this particular model numerically stable definitely isn't bad, you might want to consider making the attack numerically stable as well, so that the evaluation framework can better measure the robustness of a general (sight-unseen) deep learning model.
