Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Fixed coordinates scaling for classification prediction (#272)
Browse files Browse the repository at this point in the history
* updated gitignore to ignore jupyter notebook and temporary files

* fixed coordinates scaling for classification

* fixed meshgrid generation for cropped patch

now meshgrid is generated relative to symmetrical padding

* added tests for real modules classification

* fixed preprocessing tests
  • Loading branch information
Serhiy-Shekhovtsov authored and reubano committed Jan 9, 2018
1 parent a2b55a8 commit 413bc74
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 96 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Expand Up @@ -56,6 +56,9 @@ package-lock.json
prediction/data/*
!prediction/data/.gitkeep

# temporary data
prediction/extracted/*

# ignore redis database dumps
*.rdb

Expand All @@ -67,3 +70,6 @@ prediction/data/*
data/
docs/apidoc_*/**/*.rst
!docs/apidoc_*/*_doc.rst

# jupyter notebook files
**/.ipynb_checkpoints/*
4 changes: 3 additions & 1 deletion prediction/src/algorithms/classify/src/gtr123_model.py
Expand Up @@ -258,10 +258,12 @@ def predict(ct_path, nodule_list, model_path=None):
# else:
# casenet = torch.nn.parallel.DistributedDataParallel(casenet)

preprocess = PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
preprocess = PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=True, order=1,
min_max_normalize=True, scale=255, dtype='uint8')

# convert the image to voxels(apply the real spacing between pixels)
ct_array, meta = preprocess(*load_ct(ct_path))

patches = patches_from_ct(ct_array, meta, config['crop_size'], nodule_list,
stride=config['stride'], pad_value=config['filling_value'])

Expand Down
2 changes: 1 addition & 1 deletion prediction/src/algorithms/identify/src/gtr123_model.py
Expand Up @@ -501,7 +501,7 @@ def predict(ct_path, model_path=None):
# We have to use small batches until the next release of PyTorch, as bigger ones will segfault for CPU
# split_comber = SplitComb(side_len=int(32), margin=16, max_stride=16, stride=4, pad_value=170)
# Transform image to the 0-255 range and resample to 1x1x1mm
preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=True, order=1,
min_max_normalize=True, scale=255, dtype='uint8')

ct_array, meta = preprocess(ct_array, meta)
Expand Down
58 changes: 20 additions & 38 deletions prediction/src/preprocess/crop_patches.py
Expand Up @@ -2,31 +2,10 @@

import numpy as np
import scipy.ndimage
from src.preprocess import load_ct
from src.preprocess.preprocess_ct import mm_coordinates_to_voxel


def mm2voxel(coord, origin=0., spacing=1.):
    """Convert coordinates in mm into voxel locations.

    The origin is subtracted from the coordinates, the offset is divided by
    the spacing, and every component is rounded up to the next integer.

    Args:
        coord (scalar | list[scalar]): coordinates in mm.
        origin (scalar | list[scalar]): an origin of the CT scan in mm.
            A scalar is repeated for every component of ``coord``.
        spacing (scalar | list[scalar]): a CT scan's spacing, i.e. the size of
            one voxel in mm. A scalar is repeated for every component of ``coord``.

    Returns:
        numpy.ndarray: integer voxel location, one entry per component of ``coord``.
    """
    if np.isscalar(coord):
        coord = [coord]

    coord = np.array(coord)
    # Expand scalar origin/spacing to one value per coordinate component
    # (scipy's private helper, kept so the accepted inputs stay the same).
    origin = scipy.ndimage._ni_support._normalize_sequence(origin, len(coord))
    spacing = scipy.ndimage._ni_support._normalize_sequence(spacing, len(coord))
    coord = np.ceil((coord - np.array(origin)) / np.array(spacing))
    # ``np.int`` was only an alias of the builtin ``int`` and no longer exists
    # in current NumPy releases; the builtin yields the same dtype.
    return coord.astype(int)


def crop_patch(ct_array, meta, patch_shape=None, centroids=None, stride=None, pad_value=0):
def crop_patch(ct_array, patch_shape=None, centroids=None, stride=None, pad_value=0):
""" Generator yield a patch of a desired shape for each centroid
from a given a CT scan.
Expand All @@ -38,7 +17,6 @@ def crop_patch(ct_array, meta, patch_shape=None, centroids=None, stride=None, pa
{'x': int,
'y': int,
'z': int}
meta (src.preprocess.load_ct.MetaData): meta information of the CT scan.
stride (int): stride for patch coordinates meshgrid.
If None is set (default), then no meshgrid will be returned.
pad_value (int): value with which an array padding will be performed.
Expand All @@ -52,29 +30,28 @@ def crop_patch(ct_array, meta, patch_shape=None, centroids=None, stride=None, pa
if patch_shape is None:
patch_shape = []

if not isinstance(meta, load_ct.MetaData):
meta = load_ct.MetaData(meta)

patch_shape = scipy.ndimage._ni_support._normalize_sequence(patch_shape, len(ct_array.shape))
patch_shape = np.array(patch_shape)
init_shape = np.array(ct_array.shape)
padding = np.ceil(patch_shape / 2.).astype(np.int)
padding = np.stack([padding, padding], axis=1)
ct_array = np.pad(ct_array, padding, mode='constant', constant_values=pad_value)

for centroid in centroids:
centroid = mm2voxel([centroid[axis] for axis in 'zyx'], meta.origin, meta.spacing)
# array with padding size for each dimension
padding_size = np.ceil(patch_shape / 2.).astype(np.int)

# array with left and right padding for each dimension
padding_array = np.stack([padding_size, padding_size], axis=1)

# adding paddings at both ends of all dimensions
ct_array = np.pad(ct_array, padding_array, mode='constant', constant_values=pad_value)

for centroid in centroids:
# cropping a patch with selected centroid in the center of it
patch = ct_array[centroid[0]: centroid[0] + patch_shape[0],
centroid[1]: centroid[1] + patch_shape[1],
centroid[2]: centroid[2] + patch_shape[2]]

if stride:
init_shape += np.clip(patch_shape // 2 - centroid, 0, np.inf).astype(np.int64)
init_shape += np.clip(centroid + patch_shape // 2 - init_shape, 0, np.inf).astype(np.int64)
normstart = np.array(centroid) / np.array(ct_array.shape) - 0.5
normsize = np.array(patch_shape) / np.array(ct_array.shape)

normstart = (np.array(centroid) - patch_shape / 2) / init_shape - 0.5
normsize = patch_shape / init_shape
xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], patch_shape[0] // stride),
np.linspace(normstart[1], normstart[1] + normsize[1], patch_shape[1] // stride),
np.linspace(normstart[2], normstart[2] + normsize[2], patch_shape[2] // stride),
Expand Down Expand Up @@ -113,6 +90,11 @@ def patches_from_ct(ct_array, meta, patch_shape=None, centroids=None, stride=Non
if centroids is None:
centroids = []

patch_generator = crop_patch(ct_array, meta, patch_shape, centroids, stride, pad_value)
centroids = [[centroid[axis] for axis in 'zyx'] for centroid in centroids]

# scale the coordinates according to spacing
centroids = [mm_coordinates_to_voxel(centroid, meta) for centroid in centroids]

patch_generator = crop_patch(ct_array, patch_shape, centroids, stride, pad_value)
patches = itertools.islice(patch_generator, len(centroids))
return list(patches)
42 changes: 29 additions & 13 deletions prediction/src/preprocess/preprocess_ct.py
Expand Up @@ -16,10 +16,7 @@ class Params:
If None is set (default), then no lower bound will applied.
clip_upper (int | float): clip the voxels' value to be less or equal to clip_upper.
If None is set (default), then no upper bound will applied.
spacing (float | sequence[float]): resample CT array to satisfy the desired spacing (voxel size along the axes).
If a float, `voxel_shape` is the same for each axis.
If a sequence, `voxel_shape` should contain one value for each axis.
If None is set (default), then no re-sampling will applied.
spacing (boolean): If True, resample CT array according to the meta.spacing.
order ({0, 1, 2, 3, 4}): the order of the spline interpolation used by re-sampling.
The default value is 0.
ndim (int): the dimension of CT array, should be greater than 1. The default value is 3.
Expand All @@ -34,7 +31,7 @@ class Params:
preprocess.preprocess_dicom.Params
"""

def __init__(self, clip_lower=None, clip_upper=None, spacing=None, order=0, # noqa: C901
def __init__(self, clip_lower=None, clip_upper=None, spacing=False, order=0, # noqa: C901
ndim=3, min_max_normalize=False, scale=None, dtype=None, to_hu=False):
if not isinstance(clip_lower, (int, float)) and (clip_lower is not None):
raise TypeError('The clip_lower should be int or float')
Expand All @@ -52,9 +49,7 @@ def __init__(self, clip_lower=None, clip_upper=None, spacing=None, order=0, # n
raise ValueError('The ndim should be greater than 0')
self.ndim = ndim

self.spacing = None
if spacing is not None:
self.spacing = scipy.ndimage._ni_support._normalize_sequence(spacing, self.ndim)
self.spacing = spacing

if not isinstance(min_max_normalize, (bool, int)) and (min_max_normalize is not None):
raise TypeError('The min_max_normalize should be bool or int')
Expand Down Expand Up @@ -137,17 +132,38 @@ def __call__(self, voxel_data, meta): # noqa: C901
data_min = voxel_data.min()
voxel_data = (voxel_data - data_min) / float(data_max - data_min)

if self.spacing is not None:
if self.scale is not None:
voxel_data *= self.scale

if self.spacing:
zoom_fctr = meta.spacing / np.asarray(self.spacing)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
voxel_data = scipy.ndimage.interpolation.zoom(voxel_data, zoom_fctr, order=self.order)
meta.spacing = [axis for axis in self.spacing]

if self.scale is not None:
voxel_data *= self.scale

if self.dtype:
voxel_data = voxel_data.astype(dtype=self.dtype, copy=False)

return voxel_data, meta


def mm_coordinates_to_voxel(coord, meta):
    """Scale coordinates by the CT scan's origin and spacing into voxel locations.

    The origin is subtracted from the coordinates, the offset is multiplied by
    the spacing, and every component is rounded to the nearest integer.

    Args:
        coord (scalar | list[scalar]): coordinates to convert.
        meta (src.preprocess.load_ct.MetaData): meta information of the CT scan;
            only its ``origin`` and ``spacing`` attributes are read.

    Returns:
        numpy.ndarray: integer voxel location, one entry per component of ``coord``.
    """
    if np.isscalar(coord):
        coord = [coord]

    coord = np.array(coord)
    # Expand scalar origin/spacing to one value per coordinate component
    # (scipy's private helper, kept so the accepted inputs stay the same).
    origin = scipy.ndimage._ni_support._normalize_sequence(meta.origin, len(coord))
    spacing = scipy.ndimage._ni_support._normalize_sequence(meta.spacing, len(coord))
    # NOTE(review): multiplying (rather than dividing) by the spacing presumably
    # targets a volume already resampled by ``meta.spacing`` - confirm with callers.
    coord = np.rint((coord - np.array(origin)) * np.array(spacing))

    # ``np.int`` was only an alias of the builtin ``int`` and no longer exists
    # in current NumPy releases; the builtin yields the same dtype.
    return coord.astype(int)
2 changes: 1 addition & 1 deletion prediction/src/tests/conftest.py
Expand Up @@ -40,7 +40,7 @@ def full_mhd_path(scope='session'):

@pytest.fixture
def dicom_paths(scope='session'):
yield glob(path.join(Config.FULL_DICOM_PATHS_WILDCARD))
yield sorted(glob(path.join(Config.FULL_DICOM_PATHS_WILDCARD)))


@pytest.fixture
Expand Down
12 changes: 12 additions & 0 deletions prediction/src/tests/test_classification.py
Expand Up @@ -11,6 +11,18 @@ def test_classify_dicom(dicom_paths, nodule_locations, model_path):
assert 0 <= predicted[0]['p_concerning'] <= 1


def test_classify_real_nodule_small_dicom(dicom_path_003, model_path):
    """Classify one nodule location in the small DICOM series and check the probability range."""
    nodule = {'x': 369, 'y': 350, 'z': 5}
    results = trained_model.predict(dicom_path_003, [nodule], model_path)
    assert results
    probability = results[0]['p_concerning']
    assert 0.3 <= probability <= 1


def test_classify_real_nodule_full_dicom(dicom_paths, model_path):
    """Classify one nodule location in the third full DICOM series and check the probability range."""
    nodule = {'x': 367, 'y': 349, 'z': 75}
    results = trained_model.predict(dicom_paths[2], [nodule], model_path)
    assert results
    probability = results[0]['p_concerning']
    assert 0.3 <= probability <= 1


def test_classify_luna(metaimage_path, luna_nodule, model_path):
predicted = trained_model.predict(metaimage_path, [luna_nodule], model_path)
assert predicted
Expand Down
9 changes: 8 additions & 1 deletion prediction/src/tests/test_cropping.py
Expand Up @@ -6,6 +6,7 @@
from ..preprocess.crop_dicom import crop_dicom
from ..preprocess.load_ct import load_ct
from ..preprocess.crop_patches import patches_from_ct
from ..preprocess.preprocess_ct import PreprocessCT


def test_crop_dicom(dicom_path):
Expand All @@ -32,7 +33,13 @@ def test_crop_dicom(dicom_path):


def test_patches_from_ct(ct_path, luna_nodules):
patches = patches_from_ct(*load_ct(ct_path), patch_shape=12, centroids=luna_nodules)
preprocess = PreprocessCT(spacing=True)

# convert the image to voxels(apply the real spacing between pixels)
# convert the meta to load_ct.MetaData
ct_array, meta = preprocess(*load_ct(ct_path))

patches = patches_from_ct(ct_array, meta, patch_shape=12, centroids=luna_nodules)
assert isinstance(patches, list)
assert len(patches) == 3
assert all(patch.shape == (12, 12, 12) for patch in patches)
61 changes: 28 additions & 33 deletions prediction/src/tests/test_grt123_preprocess.py
Expand Up @@ -23,28 +23,21 @@ def __init__(self):

def __call__(self, imgs, target):
crop_size = np.array(self.crop_size).astype('int')
padding_size = np.ceil(np.array(self.crop_size) / 2.).astype(np.int)

start = (target[:3] - crop_size / 2).astype('int')
pad = [[0, 0]]
# array with left and right padding for each dimension
pad = np.stack([padding_size, padding_size], axis=1)

for i in range(3):
if start[i] < 0:
leftpad = -start[i]
start[i] = 0
else:
leftpad = 0
if start[i] + crop_size[i] > imgs.shape[i + 1]:
rightpad = start[i] + crop_size[i] - imgs.shape[i + 1]
else:
rightpad = 0

pad.append([leftpad, rightpad])
start = np.rint(target).astype(np.int)

imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value)
crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1],
crop = imgs[start[0]:start[0] + crop_size[0],
start[1]:start[1] + crop_size[1],
start[2]:start[2] + crop_size[2]]
normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])

normstart = np.array(start).astype('float32') / np.array(imgs.shape) - 0.5
normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape)

xx, yy, zz = np.meshgrid(np.linspace(normstart[0],
normstart[0] + normsize[0],
self.crop_size[0] // self.stride),
Expand All @@ -69,11 +62,16 @@ def lum_trans(img):
Returns: Image windowed to [-1200; 600] and scaled to 0-255
"""
lungwin = np.array([-1200., 600.])
newimg = (img - lungwin[0]) / (lungwin[1] - lungwin[0])
newimg[newimg < 0] = 0
newimg[newimg > 1] = 1
return (newimg * 255).astype('uint8')
clip_lower = -1200.
clip_upper = 600.

newimg = np.copy(img)
newimg[newimg < clip_lower] = clip_lower
newimg[newimg > clip_upper] = clip_upper

newimg = (newimg - clip_lower) / float(clip_upper - clip_lower)

return newimg * 255


def resample(imgs, spacing, new_spacing, order=2):
Expand Down Expand Up @@ -115,7 +113,7 @@ def resample(imgs, spacing, new_spacing, order=2):

def test_lum_trans(metaimage_path):
ct_array, meta = load_ct.load_ct(metaimage_path)
lumed = lum_trans(ct_array)
lumed = lum_trans(ct_array).astype('uint8')
functional = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600.,
min_max_normalize=True, scale=255, dtype='uint8')

Expand All @@ -126,7 +124,7 @@ def test_lum_trans(metaimage_path):
def test_resample(metaimage_path):
ct_array, meta = load_ct.load_ct(metaimage_path)
resampled, _ = resample(ct_array, np.array(load_ct.MetaData(meta).spacing), np.array([1, 1, 1]), order=1)
preprocess = preprocess_ct.PreprocessCT(spacing=1., order=1)
preprocess = preprocess_ct.PreprocessCT(spacing=True, order=1)
processed, _ = preprocess(ct_array, meta)
assert np.abs(resampled - processed).sum() == 0

Expand All @@ -140,22 +138,19 @@ def test_preprocess(metaimage_path):
origin = np.array(image_itk.GetOrigin())[::-1]
image = lum_trans(image)
image = resample(image, spacing, np.array([1, 1, 1]), order=1)[0]
image = image.astype('uint8')

crop = SimpleCrop()

for nodule in nodule_list:
nod_location = np.array([np.float32(nodule[s]) for s in ["z", "y", "x"]])
nod_location = np.ceil((nod_location - origin) / 1.)
cropped_image, coords = crop(image[np.newaxis], nod_location)
nod_location = (nod_location - origin) * spacing
cropped_image, coords = crop(image, nod_location)

# New style
ct_array, meta = load_ct.load_ct(metaimage_path)

preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600.,
min_max_normalize=True, scale=255, dtype='uint8')
preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., min_max_normalize=True, scale=255,
spacing=True, order=1, dtype='uint8')

ct_array, meta = preprocess(ct_array, meta)
preprocess = preprocess_ct.PreprocessCT(spacing=1., order=1)
ct_array, meta = load_ct.load_ct(metaimage_path)
ct_array, meta = preprocess(ct_array, meta)

cropped_image_new, coords_new = crop_patches.patches_from_ct(ct_array, meta, 96, nodule_list,
Expand Down
8 changes: 0 additions & 8 deletions prediction/src/tests/test_preprocess_dicom.py
Expand Up @@ -6,21 +6,13 @@

def test_create_params():
preprocess_ct.Params()
params = preprocess_ct.Params(spacing=1., ndim=3)
assert len(params.spacing) == 3

spacing = [shape == 1. for shape in params.spacing]
assert all(spacing)

with pytest.raises(TypeError):
preprocess_ct.Params(clip_lower='one', clip_upper=0)
preprocess_ct.Params(clip_lower=1, clip_upper=0)
preprocess_ct.Params(ndim=0)
preprocess_ct.Params(min_max_normalize=[False])

with pytest.raises(RuntimeError):
preprocess_ct.Params(spacing=(1, 1, 1, 1), ndim=3)


def test_preprocess_dicom_pure(dicom_path):
preprocess = preprocess_ct.PreprocessCT()
Expand Down

0 comments on commit 413bc74

Please sign in to comment.