Skip to content

Commit

Permalink
[Fix] Transformation Matrix Mis-calculation for autoaugmentations (ko…
Browse files Browse the repository at this point in the history
…rnia#2852)

* fixes kornia#2843

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fixes kornia#2843

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* compress warnings

* update

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and christie committed May 18, 2024
1 parent 40bf4dc commit 718074c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
19 changes: 18 additions & 1 deletion kornia/augmentation/auto/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SUBPOLICY_CONFIG = List[OP_CONFIG]


class PolicyAugmentBase(TransformMatrixMinIn, ImageSequentialBase):
class PolicyAugmentBase(ImageSequentialBase, TransformMatrixMinIn):
"""Policy-based image augmentation."""

def __init__(self, policy: List[SUBPOLICY_CONFIG], transformation_matrix_mode: str = "silence") -> None:
Expand Down Expand Up @@ -92,8 +92,25 @@ def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
return params

def transform_inputs(self, input: Tensor, params: List[ParamItem], extra_args: Dict[str, Any] = {}) -> Tensor:
for param in params:
module = self.get_submodule(param.name)
input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
return input

def forward(
self, input: Tensor, params: Optional[List[ParamItem]] = None, extra_args: Dict[str, Any] = {}
) -> Tensor:
self.clear_state()

if params is None:
inp = input
_, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
params = self.forward_parameters(out_shape)

for param in params:
module = self.get_submodule(param.name)
input = InputSequentialOps.transform(input, module=module, param=param, extra_args=extra_args)
self._update_transform_matrix_by_module(module)

self._params = params
return input
4 changes: 2 additions & 2 deletions kornia/augmentation/auto/operations/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def get_transformation_matrix(
flags = override_parameters(module.op.flags, extra_args, in_place=False)
mat = module.op.generate_transformation_matrix(input, param.data, flags)
elif module.op._transform_matrix is not None:
mat = as_tensor(module.op._transform_matrix, device=input.device, dtype=input.dtype)
mat = as_tensor(module.transform_matrix, device=input.device, dtype=input.dtype)
else:
raise RuntimeError(f"{module}.op._transform_matrix is None while `recompute=False`.")
raise RuntimeError(f"{module}.transform_matrix is None while `recompute=False`.")
res_mat = mat @ res_mat
input = module.op.transform_output_tensor(input, ori_shape)
if module.op.keepdim and ori_shape != input.shape:
Expand Down
13 changes: 8 additions & 5 deletions kornia/augmentation/auto/rand_augment/rand_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union, cast

import torch
from torch.distributions import Categorical

from kornia.augmentation.auto.base import SUBPOLICY_CONFIG, PolicyAugmentBase
from kornia.augmentation.auto.operations import OperationBase
Expand Down Expand Up @@ -70,11 +69,14 @@ def __init__(
_policy = policy

super().__init__(_policy, transformation_matrix_mode=transformation_matrix_mode)
selection_weights = torch.tensor([1.0 / len(self)] * len(self))
self.rand_selector = Categorical(selection_weights)
self.n = n
self.m = m

def rand_selector(self, n: int) -> Tensor:
perm = torch.randperm(len(self._modules))
idx = perm[:n]
return idx

def compose_subpolicy_sequential(self, subpolicy: SUBPOLICY_CONFIG) -> PolicySequential:
if len(subpolicy) != 1:
raise RuntimeError(f"Each policy must have only one operation for RandAugment. Got {len(subpolicy)}.")
Expand All @@ -83,14 +85,15 @@ def compose_subpolicy_sequential(self, subpolicy: SUBPOLICY_CONFIG) -> PolicySeq

def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
if params is None:
idx = self.rand_selector.sample((self.n,))
idx = self.rand_selector(
self.n,
)
return self.get_children_by_indices(idx)

return self.get_children_by_params(params)

def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()

params: List[ParamItem] = []
mod_param: Union[Dict[str, Tensor], List[ParamItem]]
m = torch.tensor([self.m / 30] * batch_shape[0])
Expand Down
5 changes: 0 additions & 5 deletions tests/augmentation/test_auto_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from kornia.augmentation.auto.trivial_augment import TrivialAugment
from kornia.augmentation.container import AugmentationSequential
from kornia.geometry.bbox import bbox_to_mask
from kornia.utils._compat import torch_version

from testing.base import BaseTester

Expand Down Expand Up @@ -112,10 +111,6 @@ def test_smoke(self, policy):
in_tensor = torch.rand(10, 3, 50, 50, requires_grad=True)
aug(in_tensor)

@pytest.mark.xfail(
torch_version() in {"1.9.1", "1.10.2", "1.11.0", "1.12.1", "1.13.1"},
reason="randomness failing into some torch versions",
)
def test_transform_mat(self, device, dtype):
aug = RandAugment(n=3, m=15)
in_tensor = torch.rand(10, 3, 50, 50, device=device, dtype=dtype, requires_grad=True)
Expand Down

0 comments on commit 718074c

Please sign in to comment.