
Commit

Merge pull request #146 from justusschock/numba
Numba Transforms
justusschock committed Jun 17, 2019
2 parents 8da5c32 + d6ed758 commit 299b7ac
Showing 3 changed files with 103 additions and 1 deletion.
8 changes: 7 additions & 1 deletion delira/data_loading/__init__.py
@@ -11,7 +11,13 @@
    RandomSampler, \
    StoppingPrevalenceSequentialSampler, \
    SequentialSampler
from .sampler import __all__ as __all_sampling

if "TORCH" in get_backends():
    from .dataset import TorchvisionClassificationDataset


try:
    from delira.data_loading.numba_transform import NumbaTransform, \
        NumbaTransformWrapper, NumbaCompose
except ImportError:
    pass
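
Because the numba transforms are imported inside a try/except, they are only re-exported from delira.data_loading when numba is installed. A minimal sketch of how downstream code might probe for them (the HAS_NUMBA_TRANSFORMS flag is illustrative, not part of this commit):

try:
    # succeeds only if numba (and this optional module) imported cleanly
    from delira.data_loading import NumbaTransform
    HAS_NUMBA_TRANSFORMS = True
except ImportError:
    HAS_NUMBA_TRANSFORMS = False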
42 changes: 42 additions & 0 deletions delira/data_loading/numba_transform.py
@@ -0,0 +1,42 @@
from batchgenerators.transforms import AbstractTransform, Compose

import logging
from delira import get_current_debug_mode
import numba

logger = logging.getLogger(__name__)


class NumbaTransformWrapper(AbstractTransform):
    def __init__(self, transform: AbstractTransform, nopython=True,
                 target="cpu", parallel=False, **options):

        if get_current_debug_mode():
            # set options for debug mode
            logging.debug("Debug mode detected. Overwriting numba options "
                          "nopython to False and target to cpu")
            nopython = False
            target = "cpu"

        transform.__call__ = numba.jit(transform.__call__, nopython=nopython,
                                       target=target,
                                       parallel=parallel, **options)
        self._transform = transform

    def __call__(self, **kwargs):
        return self._transform(**kwargs)


class NumbaTransform(NumbaTransformWrapper):
    def __init__(self, transform_cls, nopython=True, target="cpu",
                 parallel=False, **kwargs):
        trafo = transform_cls(**kwargs)

        super().__init__(trafo, nopython=nopython, target=target,
                         parallel=parallel)


class NumbaCompose(Compose):
    def __init__(self, transforms):
        super().__init__(transforms=[NumbaTransformWrapper(trafo)
                                     for trafo in transforms])
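
For context, a minimal usage sketch of these wrappers, mirroring the tests below; the zoom factor and batch shape here are illustrative, not taken from the commit:

import numpy as np
from batchgenerators.transforms import ZoomTransform
from delira.data_loading.numba_transform import NumbaTransform, NumbaCompose

# NumbaTransform instantiates the transform class and jit-compiles its __call__
trafo = NumbaTransform(ZoomTransform, zoom_factors=2)
batch = {"data": np.random.rand(4, 1, 24, 24)}
zoomed = trafo(**batch)["data"]

# NumbaCompose wraps every transform in the pipeline with NumbaTransformWrapper
pipeline = NumbaCompose([ZoomTransform(2)])
out = pipeline(**batch)["data"]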
54 changes: 54 additions & 0 deletions tests/data_loading/test_numba_transforms.py
@@ -0,0 +1,54 @@
import unittest

from batchgenerators.transforms import ZoomTransform, PadTransform, Compose
import numpy as np

try:
    import numba
except ImportError:
    numba = None


class NumbaTest(unittest.TestCase):
    def setUp(self) -> None:
        from delira.data_loading.numba_transform import NumbaTransform, \
            NumbaCompose
        self._basic_zoom_trafo = ZoomTransform(3)
        self._numba_zoom_trafo = NumbaTransform(ZoomTransform, zoom_factors=3)
        self._basic_pad_trafo = PadTransform(new_size=(30, 30))
        self._numba_pad_trafo = NumbaTransform(PadTransform,
                                               new_size=(30, 30))

        self._basic_compose_trafo = Compose([self._basic_pad_trafo,
                                             self._basic_zoom_trafo])
        self._numba_compose_trafo = NumbaCompose([self._basic_pad_trafo,
                                                  self._basic_zoom_trafo])

        self._input = {"data": np.random.rand(10, 1, 24, 24)}

    def compare_transform_outputs(self, transform, numba_transform):
        output_normal = transform(**self._input)["data"]
        output_numba = numba_transform(**self._input)["data"]

        # only check for same shapes, since numba might apply slightly
        # different interpolations
        self.assertTupleEqual(output_normal.shape, output_numba.shape)

    @unittest.skipIf(numba is None, "Numba must be imported successfully")
    def test_zoom(self):
        self.compare_transform_outputs(self._basic_zoom_trafo,
                                       self._numba_zoom_trafo)

    @unittest.skipIf(numba is None, "Numba must be imported successfully")
    def test_pad(self):
        self.compare_transform_outputs(self._basic_pad_trafo,
                                       self._numba_pad_trafo)

    @unittest.skipIf(numba is None, "Numba must be imported successfully")
    def test_compose(self):
        self.compare_transform_outputs(self._basic_compose_trafo,
                                       self._numba_compose_trafo)


if __name__ == '__main__':
    unittest.main()
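
Assuming numba and batchgenerators are installed, the new tests can be run with the standard unittest runner, for example:

python -m unittest tests.data_loading.test_numba_transforms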
