Skip to content

Commit

Permalink
Merge pull request #535 from zimmerrol/wrapper-df
Browse files Browse the repository at this point in the history
Improve behavior of data_format
  • Loading branch information
jonasrauber committed Aug 24, 2020
2 parents 7e68eb6 + 0f42d6b commit fbd350e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 19 deletions.
2 changes: 1 addition & 1 deletion foolbox/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _preprocess(self, inputs: ep.TensorType) -> ep.TensorType:

@property
def data_format(self) -> Any:
return getattr(self._model, "data_format", None)
return self._model.data_format # type: ignore


ModelType = TypeVar("ModelType", bound="ModelWithPreprocessing")
Expand Down
18 changes: 0 additions & 18 deletions tests/test_attacks_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,3 @@ class Model:
model.data_format = "invalid" # type: ignore
with pytest.raises(ValueError):
assert fbn.attacks.base.get_channel_axis(model, 3) # type: ignore


def test_transform_bounds_wrapper_data_format() -> None:
class Model(fbn.models.Model):
data_format = "channels_first"

@property
def bounds(self) -> fbn.types.Bounds:
return fbn.types.Bounds(0, 1)

def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T:
return inputs

model = Model()
wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1))
assert fbn.attacks.base.get_channel_axis(
model, 3
) == fbn.attacks.base.get_channel_axis(wrapped_model, 3)
35 changes: 35 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,38 @@ def test_preprocessing(fmodel_and_data: ModelAndData) -> None:
fmodel = fbn.models.base.ModelWithPreprocessing(
fmodel._model, fmodel.bounds, fmodel.dummy, preprocessing
)


def test_transform_bounds_wrapper_data_format() -> None:
class Model(fbn.models.Model):
data_format = "channels_first"

@property
def bounds(self) -> fbn.types.Bounds:
return fbn.types.Bounds(0, 1)

def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T:
return inputs

model = Model()
wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1))
assert fbn.attacks.base.get_channel_axis(
model, 3
) == fbn.attacks.base.get_channel_axis(wrapped_model, 3)
assert hasattr(wrapped_model, "data_format")
assert not hasattr(wrapped_model, "not_data_format")


def test_transform_bounds_wrapper_missing_data_format() -> None:
class Model(fbn.models.Model):
@property
def bounds(self) -> fbn.types.Bounds:
return fbn.types.Bounds(0, 1)

def __call__(self, inputs: fbn.models.base.T) -> fbn.models.base.T:
return inputs

model = Model()
wrapped_model = fbn.models.TransformBoundsWrapper(model, (0, 1))
assert not hasattr(wrapped_model, "data_format")
assert not hasattr(wrapped_model, "not_data_format")

0 comments on commit fbd350e

Please sign in to comment.