Skip to content

Commit

Permalink
ENH: implement probability for random tx
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 24, 2024
1 parent 7dcb760 commit 0df3c4f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nitrain/loaders/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def __init__(self,
sampler = samplers.BaseSampler(batch_size=images_per_batch)
self.sampler = sampler

def copy(self, dataset=None, keep_transforms=True):
def copy(self, dataset=None, drop_transforms=False):
new_loader = Loader(
dataset = copy(self.dataset) if dataset is None else dataset,
images_per_batch = self.images_per_batch,
channel_axis = self.channel_axis,
transforms = self.transforms if keep_transforms else None,
transforms = self.transforms if not drop_transforms else None,
shuffle = self.shuffle,
sampler = self.sampler
)
Expand Down
3 changes: 3 additions & 0 deletions nitrain/transforms/spatial_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def __call__(self, *images):
min_rotation = self.min_rotation
max_rotation = self.max_rotation

if random.uniform(0, 1) > self.p:
return images if len(images) > 1 else images[0]

if not isinstance(min_rotation, (tuple,list)):
theta = math.pi / 180 * random.uniform(min_rotation, max_rotation)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
Expand Down
17 changes: 17 additions & 0 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ def test_2d(self):
# test repr
rep = loader.__repr__()

def test_copy(self):
import ants
import nitrain as nt
img = ants.image_read(ants.get_data('r16'))

x = [img for _ in range(100)]
y = list(range(100))

dataset = nt.Dataset(x, y)
# test copy
ds_train, ds_test = dataset.split(0.8)
loader_train = nt.Loader(ds_train, images_per_batch=12)
loader_test = loader_train.copy(ds_test)
self.assertEqual(loader_test.images_per_batch, loader_train.images_per_batch)

loader_test = loader_train.copy(ds_test, drop_transforms=True)

def test_to_keras(self):
loader = nt.Loader(self.dataset_2d, images_per_batch=4)
keras_loader = loader.to_keras()
Expand Down

0 comments on commit 0df3c4f

Please sign in to comment.