Skip to content

Commit

Permalink
Add torch.compile to RandomGaussianIllumination (#2868)
Browse files Browse the repository at this point in the history
* Add torch.compile to RandomGaussianIllumination

* update structure of class

* use self _fn variable

* update dynamo test
  • Loading branch information
vgilabert94 committed Apr 8, 2024
1 parent 85820c2 commit d1a1cc0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
36 changes: 35 additions & 1 deletion kornia/augmentation/_2d/intensity/gaussian_illumination.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch

from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.augmentation.random_generator._2d import GaussianIlluminationGenerator
from kornia.core import Tensor
Expand Down Expand Up @@ -141,6 +143,16 @@ def __init__(
# Generator of random parameters and masks.
self._param_generator = GaussianIlluminationGenerator(gain, center, sigma, sign)

def _apply_transform(
input: Tensor,
params: Dict[str, Tensor],
flags: Dict[str, Any],
transform: Optional[Tensor] = None,
) -> Tensor:
return input.add_(params["gradient"]).clamp_(0, 1)

self._fn = _apply_transform

def apply_transform(
self,
input: Tensor,
Expand All @@ -149,4 +161,26 @@ def apply_transform(
transform: Optional[Tensor] = None,
) -> Tensor:
r"""Apply random gaussian gradient illumination to the input image."""
return input.add_(params["gradient"]).clamp_(0, 1)
return self._fn(input=input, params=params, flags=flags, transform=transform)

def compile(
self,
*,
fullgraph: bool = False,
dynamic: bool = False,
backend: str = "inductor",
mode: Optional[str] = None,
options: Optional[Dict[Any, Any]] = None,
disable: bool = False,
) -> "RandomGaussianIllumination":
self._fn = torch.compile(
self._fn,
fullgraph=fullgraph,
dynamic=dynamic,
backend=backend,
mode=mode,
options=options,
disable=disable,
)

return self
8 changes: 8 additions & 0 deletions tests/augmentation/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4060,6 +4060,14 @@ def test_same_on_batch(self, device, dtype):
output_tensor = transform(input_tensor)
self.assert_close(output_tensor[0], output_tensor[1])

@pytest.mark.slow
def test_dynamo(self, device, dtype, torch_optimizer):
input_tensor = torch.ones(1, 3, 3, 3, device=device, dtype=dtype) * 0.5
aug = RandomGaussianIllumination(gain=0.5, p=1.0)
aug = aug.compile(fullgraph=True)
actual = aug(input_tensor)
assert actual.shape == input_tensor.shape


class TestRandomLinearIllumination(BaseTester):
def _get_expected(self, device, dtype):
Expand Down

0 comments on commit d1a1cc0

Please sign in to comment.