Skip to content

Commit

Permalink
add predictor
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 3, 2024
1 parent 2d0d2d1 commit 0071cc2
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 7 deletions.
48 changes: 46 additions & 2 deletions nitrain/predictors/predictor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import numpy as np
import ntimage as nti

from ..loaders import Loader

class Predictor:

def __init__(self, model, task, sampler=None):
def __init__(self, model, task, sampler=None, expand_dims=-1):
self.model = model
self.task = task
self.sampler = sampler
self.expand_dims = expand_dims

def predict(self, dataset):
"""
Expand All @@ -22,4 +27,43 @@ def predict(self, dataset):
The result of the prediction will be one (or a sequence of) of the
following depending on the model: ntimage, np.ndarray, scalar.
"""
pass

y_pred_list = []
for x, y in dataset:
sampled_batch = self.sampler([x], [y])

y_pred = []
for x_batch, y_batch in sampled_batch:
if isinstance(x_batch[0], list):
x_batch_return = []
for i in range(len(x_batch[0])):
tmp_x_batch = np.array([np.expand_dims(xx[i].numpy(), self.expand_dims) if self.expand_dims else xx.numpy() for xx in x_batch])
x_batch_return.append(tmp_x_batch)
x_batch = x_batch_return
else:
x_batch = np.array([np.expand_dims(xx.numpy(), self.expand_dims) if self.expand_dims else xx.numpy() for xx in x_batch])

# TODO: write general function for model prediction
tmp_y_pred = self.model.predict(x_batch)
y_pred.append(tmp_y_pred)

# TODO: handle multiple inputs
y_pred = np.concatenate(y_pred)
y_pred = np.squeeze(y_pred)

# put sampled axis in correct place
if 'SliceSampler' in str(type(self.sampler)):
y_pred = np.rollaxis(y_pred, 0, self.sampler.axis)

# process prediction according to task
if self.task == 'segmentation':
y_pred = np.round(y_pred).astype('uint8')
y_pred = nti.from_numpy(y_pred)

y_pred_list.append(y_pred)

return y_pred_list




2 changes: 1 addition & 1 deletion nitrain/samplers/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class SliceSampler:
"""
Sampler that returns batches of 2D slices from 3D images.
"""
def __init__(self, batch_size, axis=0, shuffle=False):
def __init__(self, batch_size=24, axis=-1, shuffle=False):
self.batch_size = batch_size
self.axis = axis
self.shuffle = shuffle
Expand Down
7 changes: 3 additions & 4 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setUp(self):
def tearDown(self):
pass

def test_classification(self):
def test_segmentation(self):
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=PatternReader('*/img3d.nii.gz'),
Expand Down Expand Up @@ -57,9 +57,8 @@ def test_classification(self):
# inference on test data
predictor = nt.Predictor(model,
task='segmentation',
sampler=SliceSampler(batch_size=20, axis=-1))
#y_pred = predictor.predict(data_test)

sampler=SliceSampler(axis=-1))
y_pred = predictor.predict(data_test)


if __name__ == '__main__':
Expand Down

0 comments on commit 0071cc2

Please sign in to comment.