Skip to content

Commit

Permalink
Pass data_format through TransformBoundsWrapper (#529)
Browse files Browse the repository at this point in the history
* Make data format accessible through wrapper

* Fix type

* Increase test coverage

* Use correct types
  • Loading branch information
zimmerrol committed Apr 7, 2020
1 parent 51461b1 commit bbc9389
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
4 changes: 4 additions & 0 deletions foolbox/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def _preprocess(self, inputs: ep.TensorType) -> ep.TensorType:
min_, max_ = self._model.bounds
return x * (max_ - min_) + min_

@property
def data_format(self) -> Any:
return getattr(self._model, "data_format", None)


ModelType = TypeVar("ModelType", bound="ModelWithPreprocessing")

Expand Down
18 changes: 18 additions & 0 deletions tests/test_attacks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,21 @@ class Model:
model.data_format = "invalid" # type: ignore
with pytest.raises(ValueError):
assert fbn.attacks.base.get_channel_axis(model, 3) # type: ignore


def test_transform_bounds_wrapper_data_format() -> None:
class Model(fbn.models.Model):
data_format = "channels_first"

@property
def bounds(self) -> fbn.types.Bounds:
return fbn.types.Bounds(0, 1)

def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T:
return inputs

model = Model()
wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1))
assert fbn.attacks.base.get_channel_axis(
model, 3
) == fbn.attacks.base.get_channel_axis(wrapped_model, 3)

0 comments on commit bbc9389

Please sign in to comment.