Skip to content

Commit

Permalink
Merge pull request #508 from bethgelab/mnist
Browse files Browse the repository at this point in the history
added pretrained mnist zoo example and test model; fixed mnist samples
  • Loading branch information
jonasrauber committed Mar 21, 2020
2 parents 1066e1c + 294356e commit 64dc1ea
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 2 deletions.
35 changes: 35 additions & 0 deletions examples/zoo/mnist/foolbox_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import os
from foolbox.models import PyTorchModel
from foolbox.utils import accuracy, samples


def create() -> PyTorchModel:
model = nn.Sequential(
nn.Conv2d(1, 32, 3),
nn.ReLU(),
nn.Conv2d(32, 64, 3),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.25),
nn.Flatten(), # type: ignore
nn.Linear(9216, 128),
nn.ReLU(),
nn.Dropout2d(0.5),
nn.Linear(128, 10),
)
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mnist_cnn.pth")
model.load_state_dict(torch.load(path)) # type: ignore
model.eval()
preprocessing = dict(mean=0.1307, std=0.3081)
fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)
return fmodel


if __name__ == "__main__":
# test the model
fmodel = create()
images, labels = samples(fmodel, dataset="mnist", batchsize=20)
print(accuracy(fmodel, images, labels))
Binary file added examples/zoo/mnist/mnist_cnn.pth
Binary file not shown.
2 changes: 1 addition & 1 deletion foolbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _samples(

assert image.ndim == 3

if dataset != "mnist" and data_format == "channels_first":
if data_format == "channels_first":
image = np.transpose(image, (2, 0, 1))

images.append(image)
Expand Down
1 change: 1 addition & 0 deletions foolbox/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .zoo import get_model # noqa: F401
from .weights_fetcher import fetch_weights # noqa: F401
from .git_cloner import GitCloneError # noqa: F401
from .model_loader import ModelLoader # noqa: F401
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ def pytorch_simple_model_object(request: Any) -> ModelAndData:
return pytorch_simple_model(torch.device("cpu"))


@register("pytorch", real=True)
def pytorch_mnist(request: Any) -> ModelAndData:
fmodel = fbn.zoo.ModelLoader.get().load(
"examples/zoo/mnist/", module_name="foolbox_model"
)
x, y = fbn.samples(fmodel, dataset="mnist", batchsize=16)
x = ep.astensor(x)
y = ep.astensor(y)
return fmodel, x, y


@register("pytorch", real=True)
def pytorch_resnet18(request: Any) -> ModelAndData:
if request.config.option.skipslow:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_brendel_bethge_attack.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union, List
from typing import Tuple, Union, List, Any
import eagerpy as ep

import foolbox as fbn
Expand All @@ -22,11 +22,15 @@ def get_attack_id(x: Tuple[BrendelBethgeAttack, Union[int, float]]) -> str:

@pytest.mark.parametrize("attack_and_p", attacks, ids=get_attack_id)
def test_brendel_bethge_untargeted_attack(
request: Any,
fmodel_and_data_ext_for_attacks: Tuple[
Tuple[fbn.Model, ep.Tensor, ep.Tensor], bool
],
attack_and_p: Tuple[BrendelBethgeAttack, Union[int, float]],
) -> None:
if request.config.option.skipslow:
pytest.skip()

(fmodel, x, y), real = fmodel_and_data_ext_for_attacks

if isinstance(x, ep.NumPyTensor):
Expand Down

0 comments on commit 64dc1ea

Please sign in to comment.