Skip to content

Commit

Permalink
applying the resize on a per-image basis, now. #48
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Apr 27, 2021
1 parent 83a54d8 commit 91283ec
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions GANDLF/data/ImagesFromDataFrame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from torchio import Image, Subject
import SimpleITK as sitk
# from GANDLF.utils import resize_image
from GANDLF.preprocessing import (NonZeroNormalizeOnMaskedRegion, CropExternalZeroplanes,
resize_image_resolution, threshold_intensities,
from GANDLF.preprocessing import (NonZeroNormalizeOnMaskedRegion, CropExternalZeroplanes, apply_resize, threshold_intensities,
tensor_rotate_180, tensor_rotate_90, clip_intensities,
normalize_imagenet, normalize_standardize,
normalize_div_by_255)
normalize_imagenet, normalize_standardize, normalize_div_by_255)

import copy, sys

Expand Down Expand Up @@ -163,10 +161,14 @@ def ImagesFromDataFrame(dataframe,
for i in range(len(augmentation_patchAxesPoints)):
augmentation_patchAxesPoints[i] = max(round(augmentation_patchAxesPoints[i] / 10), 1) # always at least have 1

# initialize resizeCheck if not present
if not('resizeCheck' in preprocessing):
preprocessing['resizeCheck'] = False

resize_images = False
# if resize has been defined but resample is not (or is none)
if not(preprocessing is None) and ('resize' in preprocessing):
if (preprocessing['resize'] is not None):
if not('resample' in preprocessing):
resize_images = True
else:
print('WARNING: \'resize\' is ignored as \'resample\' is defined under \'data_processing\', this will be skipped', file = sys.stderr)
# iterating through the dataframe
for patient in range(num_row):
# We need this dict for storing the meta data for each subject
Expand All @@ -181,24 +183,17 @@ def ImagesFromDataFrame(dataframe,
else:
img = sitk.ReadImage(str(dataframe[channel][patient]))
array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
print("Image shape : ", img.shape, flush=True)
print("Array shape : ", array.shape, flush=True)
# print("Image shape : ", img.shape, flush=True)
# print("Array shape : ", array.shape, flush=True)
subject_dict[str(channel)] = Image(tensor=array, type=torchio.INTENSITY, path=dataframe[channel][patient])

# if resize is requested, the perform per-image resize with appropriate interpolator
if resize_images:
img = subject_dict[str(channel)].as_sitk()
img_resized = apply_resize(img, preprocessing_params=preprocessing)
array = np.expand_dims(sitk.GetArrayFromImage(img_resized), axis=0)
subject_dict[str(channel)] = Image(tensor=array, type=torchio.INTENSITY, path=dataframe[channel][patient])

# if resize has been defined but resample is not (or is none)
if not preprocessing['resizeCheck']:
if not(preprocessing is None) and ('resize' in preprocessing):
if (preprocessing['resize'] is not None):
preprocessing['resizeCheck'] = True
if not('resample' in preprocessing):
preprocessing['resample'] = {}
if not('resolution' in preprocessing['resample']):
preprocessing['resample']['resolution'] = resize_image_resolution(subject_dict[str(channel)].as_sitk(), preprocessing['resize'])
else:
print('WARNING: \'resize\' is ignored as \'resample\' is defined under \'data_processing\', this will be skipped', file = sys.stderr)
else:
preprocessing['resizeCheck'] = True

# # for regression
# if predictionHeaders:
# # get the mask
Expand All @@ -213,6 +208,12 @@ def ImagesFromDataFrame(dataframe,
array = np.expand_dims(sitk.GetArrayFromImage(img), axis=0)
subject_dict['label'] = Image(tensor=array, type=torchio.LABEL, path=dataframe[labelHeader][patient])

# if resize is requested, the perform per-image resize with appropriate interpolator
if resize_images:
img = sitk.ReadImage(str(dataframe[labelHeader][patient]))
img_resized = apply_resize(img, preprocessing_params=preprocessing, interpolator=sitk.sitkNearestNeighbor)
array = np.expand_dims(sitk.GetArrayFromImage(img_resized), axis=0)
subject_dict['label'] = Image(tensor=array, type=torchio.LABEL, path=dataframe[channel][patient])

subject_dict['path_to_metadata'] = str(dataframe[labelHeader][patient])
else:
Expand Down

0 comments on commit 91283ec

Please sign in to comment.