Skip to content

Commit

Permalink
support pickling ants images
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 24, 2024
1 parent b4ba534 commit d2e7595
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 44 deletions.
54 changes: 27 additions & 27 deletions nitrain/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
import numpy as np
import math
import random
from copy import deepcopy, copy

from ..readers.utils import infer_reader
Expand All @@ -26,8 +27,6 @@ def __init__(self, inputs, outputs, transforms=None, base_dir=None, base_file=No
inputs = infer_reader(inputs)
outputs = infer_reader(outputs)

self._inputs = inputs
self._outputs = outputs
self._base_dir = base_dir
self._base_file = base_file

Expand All @@ -45,7 +44,23 @@ def __init__(self, inputs, outputs, transforms=None, base_dir=None, base_file=No
self.outputs = outputs
self.transforms = transforms

def split(self, p, shuffle=False):
def select(self, n, random=False):
    """
    Select a number of records from the dataset.

    Arguments
    ---------
    n : int
        Number of records to keep. Must not exceed the dataset length.
    random : bool
        If True, sample `n` records uniformly without replacement (in
        random order); otherwise keep the first `n` records in order.
        NOTE(review): this keyword shadows the module-level `random`
        import inside the method body — harmless here since the module
        is not used in this method, but renaming would break callers,
        so it is left as-is.

    Returns
    -------
    A deep copy of this dataset restricted to the selected records.
    """
    total = len(self)
    # Guard against silently producing out-of-range indices: the
    # sequential branch previously built np.arange(n) even when
    # n > len(self), which would be passed straight to reader.select().
    if n > total:
        raise ValueError(
            'Cannot select %d records from a dataset with only %d records.' % (n, total)
        )

    if random:
        # uniform sample without replacement; an int population arg is
        # equivalent to sampling from np.arange(total)
        selected_indices = np.random.choice(total, size=n, replace=False)
    else:
        # deterministic: first n records
        selected_indices = np.arange(n)

    ds = deepcopy(self)
    ds.inputs = ds.inputs.select(selected_indices)
    ds.outputs = ds.outputs.select(selected_indices)

    return ds

def split(self, p, random=False):
"""
Split dataset into training, testing, and optionally validation.
Expand All @@ -66,16 +81,16 @@ def split(self, p, shuffle=False):
n_vals = len(self)
indices = np.arange(n_vals)

if shuffle:
if random:
if p[2] > 0:
sampled_indices = np.random.choice([0,1,2], size=n_vals, p=p)
train_indices = indices[np.where(sampled_indices==0)[0]]
test_indices = indices[np.where(sampled_indices==1)[0]]
val_indices = indices[np.where(sampled_indices==2)[0]]
train_indices = np.where(sampled_indices==0)[0]
test_indices = np.where(sampled_indices==1)[0]
val_indices = np.where(sampled_indices==2)[0]
else:
sampled_indices = np.random.choice([0,1], size=n_vals, p=p[:-1])
train_indices = indices[np.where(sampled_indices==0)[0]]
test_indices = indices[np.where(sampled_indices==1)[0]]
train_indices = np.where(sampled_indices==0)[0]
test_indices = np.where(sampled_indices==1)[0]
else:
if p[2] > 0:
train_indices = indices[:math.ceil(n_vals*p[0])]
Expand All @@ -85,33 +100,18 @@ def split(self, p, shuffle=False):
train_indices = indices[:math.ceil(n_vals*p[0])]
test_indices = indices[math.ceil(n_vals*p[0]):]


ds_train = Dataset(self._inputs,
self._outputs,
self.transforms,
self._base_dir,
self._base_file)

ds_test = Dataset(self._inputs,
self._outputs,
self.transforms,
self._base_dir,
self._base_file)
ds_train = deepcopy(self)
ds_test = deepcopy(self)

ds_train.inputs = ds_train.inputs.select(train_indices)
ds_train.outputs = ds_train.outputs.select(train_indices)
ds_test.inputs = ds_test.inputs.select(test_indices)
ds_test.outputs = ds_test.outputs.select(test_indices)

if p[2] > 0:
ds_val = Dataset(self._inputs,
self._outputs,
self.transforms,
self._base_dir,
self._base_file)
ds_val = deepcopy(self)
ds_val.inputs = ds_val.inputs.select(val_indices)
ds_val.outputs = ds_val.outputs.select(val_indices)

return ds_train, ds_test, ds_val
else:
return ds_train, ds_test
Expand Down
9 changes: 6 additions & 3 deletions nitrain/readers/folder_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def map_gcs_values(self, bucket, credentials=None, base_dir=None, base_file=None
self.values = [np.where(unique_values==v)[0][0] for v in values]
elif self.format == 'onehot':
self.values = [list(np.eye(len(unique_values),
dtype='uint32')[np.where(unique_values==v)[0][0]]) for v in values]
dtype='uint8')[np.where(unique_values==v)[0][0]]) for v in values]
elif self.format == 'string':
self.values = values
else:
Expand Down Expand Up @@ -152,7 +152,7 @@ def map_values(self, base_dir=None, base_label=None, **kwargs):
self.values = [np.where(unique_values==v)[0][0] for v in values]
elif self.format == 'onehot':
self.values = [list(np.eye(len(unique_values),
dtype='uint32')[np.where(unique_values==v)[0][0]]) for v in values]
dtype='uint8')[np.where(unique_values==v)[0][0]]) for v in values]
elif self.format == 'string':
self.values = values
else:
Expand All @@ -168,7 +168,10 @@ def map_values(self, base_dir=None, base_label=None, **kwargs):
self.label = 'folder_name'

def __getitem__(self, idx):
    """Return the record at *idx* as a dict keyed by this reader's label."""
    value = self.values[idx]
    # one-hot records are stored as plain lists; hand them back as arrays
    if self.format == 'onehot':
        value = np.array(value)
    return {self.label: value}

def __len__(self):
    """Total number of records this reader yields."""
    record_count = len(self.values)
    return record_count
13 changes: 8 additions & 5 deletions nitrain/transforms/intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, min=0, max=1):
import ants
from nitrain import transforms as tx
img = ants.image_read(ants.get_data('r16'))
mytx = tx.RangeNormalize(0, 2)
mytx = tx.RangeNormalize(0, 1)
img2 = mytx(img)
"""
self.min = min
Expand All @@ -79,10 +79,13 @@ def __call__(self, *images):
image = image.clone('float')
minimum = image.min()
maximum = image.max()
m = (self.max - self.min) / (maximum - minimum)
b = self.min - m * minimum
image = m * image + b
new_images.append(image)
if maximum - minimum == 0:
new_images.append(image)
else:
m = (self.max - self.min) / (maximum - minimum)
b = self.min - m * minimum
image = m * image + b
new_images.append(image)
return new_images if len(new_images) > 1 else new_images[0]

class Clip(BaseTransform):
Expand Down
14 changes: 7 additions & 7 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_memory_double_inputs(self):
self.assertEqual(y, 4)

# test split
ds_train, ds_test = dataset.split(0.8, shuffle=False)
ds_train, ds_test = dataset.split(0.8, random=False)
self.assertEqual(len(ds_train), 8)
self.assertEqual(len(ds_test), 2)

Expand Down Expand Up @@ -213,7 +213,7 @@ def test_split(self):
)

ds0,ds1,ds2 = ds.split((0.6,0.2,0.2))
ds0,ds1,ds2 = ds.split((0.6,0.2,0.2), shuffle=False)
ds0,ds1,ds2 = ds.split((0.6,0.2,0.2), random=False)

with self.assertRaises(Exception):
ds0,ds1,ds2 = ds.split((0.6,0.2,0.5))
Expand Down Expand Up @@ -260,7 +260,7 @@ def test_2d(self):
self.assertEqual(y, [50, 51])

# test split
ds_train, ds_test = dataset.split(0.8, shuffle=False)
ds_train, ds_test = dataset.split(0.8, random=False)
self.assertTrue(len(ds_train) > len(ds_test))

# test repr
Expand All @@ -275,7 +275,7 @@ def test_2d_split(self):
base_file=os.path.join(tmp_dir, 'participants.csv')
)

ds_train, ds_test = dataset.split(0.8, shuffle=False)
ds_train, ds_test = dataset.split(0.8, random=False)
self.assertTrue(len(ds_train) > len(ds_test))

# test repr
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_pattern_compose(self):
outputs=readers.ImageReader('*/img3d_100.nii.gz'),
base_dir=base_dir)

data_train, data_test = dataset.split(0.8, shuffle=False)
data_train, data_test = dataset.split(0.8, random=False)

self.assertEqual(len(data_train), 8)
self.assertEqual(len(data_test), 2)
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_folder_name(self):
outputs=readers.FolderNameReader('*/img3d_100.nii.gz'),
base_dir=base_dir)

data_train, data_test = dataset.split(0.8, shuffle=False)
data_train, data_test = dataset.split(0.8, random=False)

self.assertEqual(len(data_train), 8)
self.assertEqual(len(data_test), 2)
Expand Down Expand Up @@ -411,7 +411,7 @@ def test_folder_name_compose(self):
readers.FolderNameReader('*/img3d.nii.gz')],
base_dir=base_dir)

data_train, data_test = dataset.split(0.8, shuffle=False)
data_train, data_test = dataset.split(0.8, random=False)

x,y=dataset[3]
self.assertEqual(x.mean(), 4)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_multiple_image_slice_after_split(self):
outputs=readers.ImageReader('*/img3d_seg.nii.gz'),
base_dir=base_dir)

ds_train, ds_test = dataset.split(0.8, shuffle=False)
ds_train, ds_test = dataset.split(0.8, random=False)

loader = nt.Loader(ds_train,
images_per_batch=1,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_multiclass_segmentation_no_expand_dims(self):

x,y = dataset[0]

data_train, data_test = dataset.split(0.8, shuffle=False)
data_train, data_test = dataset.split(0.8, random=False)

loader = nt.Loader(data_train,
images_per_batch=4,
Expand Down

0 comments on commit d2e7595

Please sign in to comment.