Skip to content

Commit

Permalink
ENH: add random transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 24, 2024
1 parent fc42ed2 commit 7dcb760
Show file tree
Hide file tree
Showing 15 changed files with 453 additions and 50 deletions.
56 changes: 51 additions & 5 deletions nitrain/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,46 @@ def __init__(self, inputs, outputs, transforms=None, base_dir=None, base_file=No
self.outputs = outputs
self.transforms = transforms

def split(self, p, shuffle=False):
def split(self, p, random=True):
"""
Split dataset into training, testing, and optionally validation.
dataset.split(0.8)
dataset.split((0.8,0.2))
dataset.split((0.8,0.1,0.1))
"""
if isinstance(p, float):
p = (p, 1-p, 0)

if isinstance(p, (list, tuple)):
if len(p) == 2:
p = p + (0,)

if sum(p) != 1:
raise Exception('The probabilities must sum to 1.')

n_vals = len(self)
indices = np.arange(n_vals)
train_indices = indices[:math.ceil(n_vals*p)]
test_indices = indices[math.ceil(n_vals*p):]

if random:
if p[2] > 0:
sampled_indices = np.random.choice([0,1,2], size=n_vals, p=p)
train_indices = indices[np.where(sampled_indices==0)[0]]
test_indices = indices[np.where(sampled_indices==1)[0]]
val_indices = indices[np.where(sampled_indices==2)[0]]
else:
sampled_indices = np.random.choice([0,1], size=n_vals, p=p[:-1])
train_indices = indices[np.where(sampled_indices==0)[0]]
test_indices = indices[np.where(sampled_indices==1)[0]]
else:
if p[2] > 0:
train_indices = indices[:math.ceil(n_vals*p[0])]
test_indices = indices[math.ceil(n_vals*p[0]):math.ceil(n_vals*(p[0]+p[1]))]
val_indices = indices[math.ceil(n_vals*(p[0]+p[1])):]
else:
train_indices = indices[:math.ceil(n_vals*p[0])]
test_indices = indices[math.ceil(n_vals*p[0]):]


ds_train = Dataset(self._inputs,
self._outputs,
Expand All @@ -62,13 +97,24 @@ def split(self, p, shuffle=False):
self.transforms,
self._base_dir,
self._base_file)

ds_train.inputs = ds_train.inputs.select(train_indices)
ds_train.outputs = ds_train.outputs.select(train_indices)
ds_test.inputs = ds_test.inputs.select(test_indices)
ds_test.outputs = ds_test.outputs.select(test_indices)

if p[2] > 0:
ds_val = Dataset(self._inputs,
self._outputs,
self.transforms,
self._base_dir,
self._base_file)
ds_val.inputs = ds_val.inputs.select(val_indices)
ds_val.outputs = ds_val.outputs.select(val_indices)

return ds_train, ds_test
return ds_train, ds_test, ds_val
else:
return ds_train, ds_test

def filter(self, expr):
raise NotImplementedError('Not implemented')
Expand Down
27 changes: 13 additions & 14 deletions nitrain/loaders/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import warnings
import ants
from copy import deepcopy
from copy import deepcopy, copy

from .. import samplers, transforms as tx
from ..datasets.utils import reduce_to_list, apply_transforms
Expand All @@ -13,7 +13,7 @@ def __init__(self,
dataset,
images_per_batch,
transforms=None,
expand_dims=-1,
channel_axis=-1,
shuffle=False,
sampler=None):
"""
Expand All @@ -33,21 +33,20 @@ def __init__(self,

self.dataset = dataset
self.images_per_batch = images_per_batch
self.expand_dims = expand_dims
self.channel_axis = channel_axis
self.transforms = transforms
self.shuffle = shuffle

if sampler is None:
sampler = samplers.BaseSampler(batch_size=images_per_batch)
self.sampler = sampler

def copy(self, dataset=None):

def copy(self, dataset=None, keep_transforms=True):
new_loader = Loader(
dataset = deepcopy(self.dataset),
dataset = copy(self.dataset) if dataset is None else dataset,
images_per_batch = self.images_per_batch,
expand_dims = self.expand_dims,
transforms = self.transforms,
channel_axis = self.channel_axis,
transforms = self.transforms if keep_transforms else None,
shuffle = self.shuffle,
sampler = self.sampler
)
Expand Down Expand Up @@ -104,9 +103,9 @@ def __iter__(self):

for x_batch, y_batch in sampled_batch:

if self.expand_dims:
x_batch = expand_image_dims(x_batch)
y_batch = expand_image_dims(y_batch)
if self.channel_axis:
x_batch = expand_image_dims(x_batch, self.channel_axis)
y_batch = expand_image_dims(y_batch, self.channel_axis)

x_batch = convert_to_numpy(x_batch)
y_batch = convert_to_numpy(y_batch)
Expand Down Expand Up @@ -156,10 +155,10 @@ def convert_to_numpy(x):
else:
return np.array(x)

def expand_image_dims(x):
mytx = tx.AddChannel()
def expand_image_dims(x, axis):
mytx = tx.AddChannel(axis)
if isinstance(x, list):
return [expand_image_dims(xx) for xx in x]
return [expand_image_dims(xx, axis) for xx in x]
else:
if ants.is_image(x):
return mytx(x) if not x.has_components else x
Expand Down
7 changes: 4 additions & 3 deletions nitrain/samplers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def create_patches(images, values, block_size, stride):
z_indices = grid[2].flatten()

for a, b, c in zip(x_indices, y_indices, z_indices):
cropped_image = image.crop([(a, a+block_size[0]),
(b, b+block_size[1]),
(c, c+block_size[2])])
cropped_image = image.crop_indices((a,b,c),
(a+block_size[0],
b+block_size[1],
c+block_size[2]))
cropped_images.append(cropped_image)
new_values.append(value)

Expand Down
4 changes: 2 additions & 2 deletions nitrain/samplers/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def create_patches(images, values, patch_size, stride):
y_indices = grid[1].flatten()

for a, b in zip(x_indices, y_indices):
cropped_image = image.crop([(a, a+patch_size[0]),
(b, b+patch_size[1])])
cropped_image = image.crop_indices((a,b),
(a+patch_size[0],b+patch_size[1]))
cropped_images.append(cropped_image)
new_values.append(value)

Expand Down
1 change: 1 addition & 0 deletions nitrain/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
from .math import *
from .shape import *
from .spatial import *
from .spatial_random import *
from .utility import *
4 changes: 2 additions & 2 deletions nitrain/transforms/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
]

class AddChannel(BaseTransform):
def __init__(self):
def __init__(self, axis=-1):
"""
import ants
from nitrain import transforms as tx
img = ants.image_read(ants.get_data('r16'))
mytx = tx.AddChannel()
img2 = mytx(img)
"""
pass
self.axis = axis

def __call__(self, *images):
images = [ants.merge_channels([image]) for image in images]
Expand Down
6 changes: 3 additions & 3 deletions nitrain/transforms/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'Shear',
'Rotate',
'Zoom',
'Reflect',
'Flip',
'Translate'
]

Expand Down Expand Up @@ -277,14 +277,14 @@ def __call__(self, *images):
new_images.append(new_image)
return new_images if len(new_images) > 1 else new_images[0]

class Reflect(BaseTransform):
class Flip(BaseTransform):

def __init__(self, axis=0):
"""
import ants
from nitrain import transforms as tx
img = ants.image_read(ants.get_data('r16'))
mytx = tx.Reflect()
mytx = tx.Flip()
img2 = mytx(img)
"""
self.axis = axis
Expand Down
Loading

0 comments on commit 7dcb760

Please sign in to comment.