Skip to content

Commit

Permalink
Merge 749132c into 003f9fa
Browse files Browse the repository at this point in the history
  • Loading branch information
jcohenadad committed Nov 11, 2018
2 parents 003f9fa + 749132c commit c712dc6
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 330 deletions.
95 changes: 39 additions & 56 deletions AxonDeepSeg/apply_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,15 @@
from .visualization.get_masks import get_masks
import AxonDeepSeg.ads_utils

def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder, config_dict,
acquisitions_names = None,
ckpt_name='model', inference_batch_size=1,
overlap_value=25, resampled_resolutions=[0.1], prediction_proba_activate=False,
gpu_per=1.0, verbosity_level=0):

'''
def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder, config_dict, ckpt_name='model',
inference_batch_size=1, overlap_value=25, resampled_resolutions=[0.1],
prediction_proba_activate=False, gpu_per=1.0, verbosity_level=0):
"""
Preprocesses the images, transform them into patches, applies the network, stitches the predictions and return them.
:param path_acquisitions: List of path to the acquisitions.
:param acquisitions_resolutions: List of the acquisitions resolutions (floats).
:param path_model_folder: Path to the model folder.
:param config_dict: Dictionary containing the model's parameters.
:param acquisitions_names: List of names of the acquisitions.
:param ckpt_name: String, checkpoint to use.
:param inference_batch_size: Int, batch size to use when doing inference.
:param overlap_value: Int, number of pixels to use when overlapping the predictions of the network.
Expand All @@ -31,7 +27,7 @@ def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder
:param gpu_per: Float, percentage of GPU to use if we use it.
:param verbosity_level: Int, how much information to display.
:return: List of segmentations, and list of probability maps if requested.
'''
"""

# We set the logging from python and Tensorflow to a high level, to avoid messages
# in the console when performing segmentation.
Expand All @@ -44,7 +40,7 @@ def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder
patch_size = config_dict["trainingset_patchsize"]
n_classes = config_dict["n_classes"]

########### STEP 1: we load and rescale the acquisitions, and transform them into patches.
# STEP 1: Load and rescale the acquisitions, and transform them into patches.

rs_acquisitions, rs_coeffs, original_acquisitions_shapes = load_acquisitions(
path_acquisitions, acquisitions_resolutions, resampled_resolutions, verbose_mode=verbosity_level)
Expand All @@ -56,27 +52,26 @@ def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder

L_data, L_n_patches, L_positions = prepare_patches(rs_acquisitions, patch_size, overlap_value)


########### STEP 2: Construction of Tensorflow's computing graph and restoration of the session
# STEP 2: Construct Tensorflow's computing graph and restoration of the session

# Construction of the graph
if verbosity_level>=2:
if verbosity_level >= 2:
print("Graph construction ...")
x = tf.placeholder(tf.float32, shape=(None, patch_size, patch_size))
pred = uconv_net(x, config_dict, phase=False, verbose=False) # Inference
saver = tf.train.Saver() # Loading the previous model
saver = tf.train.Saver() # Load previous model

# We limit the amount of GPU we are going to use for inference.
# We limit the amount of GPU for inference
config_gpu = tf.ConfigProto(log_device_placement=False)
config_gpu.gpu_options.per_process_gpu_memory_fraction = gpu_per

# Launch the session. This is the part that takes time, and we are now going to process all images by loading the session just once.
# Launch the session (this part takes time). All images will be processed by loading the session just once.
sess = tf.Session(config=config_gpu)
saver.restore(sess, os.path.join(path_model_folder, ckpt_name + '.ckpt'))

########### STEP 3: Inference
# STEP 3: Inference

if verbosity_level>=2:
if verbosity_level >= 2:
print("Beginning inference ...")

n_patches = len(L_data)
Expand All @@ -88,7 +83,7 @@ def apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder
# Inference of complete batches
for i in range(it):

if verbosity_level>=3:
if verbosity_level >= 3:
print(('processing patch %s of %s' % (i+1, it)))

batch_x = np.asarray(L_data[i * inference_batch_size:(i + 1) * inference_batch_size])
Expand Down Expand Up @@ -178,7 +173,7 @@ def axon_segmentation(path_acquisitions_folders, acquisitions_filenames, path_mo
overlap_value=25, resampled_resolutions=0.1, acquired_resolution=None,
prediction_proba_activate=False, write_mode=True, gpu_per=1.0, verbosity_level=0):

'''
"""
Wrapper performing the segmentation of all the requested acquisitions and generates (if requested) the segmentation
images.
:param path_acquisitions_folders: List of folders where the acquisitions to segment are located.
Expand All @@ -189,18 +184,19 @@ def axon_segmentation(path_acquisitions_folders, acquisitions_filenames, path_mo
:param segmentations_filenames: List of the names of the segmentations files, to be used when creating the files.
:param inference_batch_size: Size of the batches fed to the network.
:param overlap_value: Int, number of pixels to use for overlapping the predictions.
:param resampled_resolutions: List of the resolutions (floats) we are going to resample to.
:param resampled_resolutions: List of the resolutions (in µm) to resample to.
:param acquired_resolution: List of the resolutions (in µm) for native images.
:param prediction_proba_activate: Boolean, whether to compute probability maps or not.
:param write_mode: Boolean, whether to create segmentation images or not.
:param gpu_per: Percentage of the GPU to use, if we use it.
:param verbosity_level: Int, level of verbosity. The higher, the more information is displayed.
:return: List of predictions, and optionally of probability maps.
'''
"""

# Processing input so they are lists in every situation
path_acquisitions_folders, acquisitions_filenames, resampled_resolutions, segmentations_filenames = list(map(
ensure_list_type, [path_acquisitions_folders, acquisitions_filenames,
resampled_resolutions, segmentations_filenames]))
path_acquisitions_folders, acquisitions_filenames, resampled_resolutions, segmentations_filenames = \
list(map(ensure_list_type, [path_acquisitions_folders, acquisitions_filenames, resampled_resolutions,
segmentations_filenames]))

if len(segmentations_filenames) != len(path_acquisitions_folders):
segmentations_filenames = ['AxonDeepSeg.png'] * len(path_acquisitions_folders)
Expand All @@ -216,25 +212,17 @@ def axon_segmentation(path_acquisitions_folders, acquisitions_filenames, path_mo

# If we did not receive any resolution we read the pixel size in micrometer from each pixel.
if acquired_resolution == None:

if os.path.exists(os.path.join(path_acquisitions_folders[0], 'pixel_size_in_micrometer.txt')):

resolutions_files = [open(os.path.join(path_acquisition_folder, 'pixel_size_in_micrometer.txt'), 'r')
for path_acquisition_folder in path_acquisitions_folders]

acquisitions_resolutions = [float(file_.read()) for file_ in resolutions_files]


else:

exception_msg = "ERROR: No pixel size is provided, and there is no pixel_size_in_micrometer.txt file in image folder. " \
"Please provide a pixel size (using argument -s), or add a pixel_size_in_micrometer.txt file " \
"containing the pixel size value."
raise Exception(exception_msg)



# If we received a resolution to use we use this one.
# If resolution is specified as input argument, use it
else:
acquisitions_resolutions = [acquired_resolution]*len(path_acquisitions_folders)

Expand All @@ -251,60 +239,55 @@ def axon_segmentation(path_acquisitions_folders, acquisitions_filenames, path_mo
gpu_per=gpu_per, verbosity_level=verbosity_level)
# Predictions are shape of image, value = class of pixel
else:
prediction = apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder,
config_dict, ckpt_name=ckpt_name,
inference_batch_size=inference_batch_size, overlap_value=overlap_value,
resampled_resolutions=resampled_resolutions,
prediction_proba_activate=prediction_proba_activate,
gpu_per=gpu_per, verbosity_level=verbosity_level)
prediction = apply_convnet(path_acquisitions, acquisitions_resolutions, path_model_folder, config_dict,
ckpt_name=ckpt_name, inference_batch_size=inference_batch_size,
overlap_value=overlap_value, resampled_resolutions=resampled_resolutions,
prediction_proba_activate=prediction_proba_activate, gpu_per=gpu_per,
verbosity_level=verbosity_level)
# Predictions are shape of image, value = class of pixel

# Final part of the function : generating the image if needed/ returning values
if write_mode:
for i, pred in enumerate(prediction):
# We now transform the prediction to an image
# Transform the prediction to an image
n_classes = config_dict['n_classes']
paint_vals = [int(255 * float(j) / (n_classes - 1)) for j in range(n_classes)]

# Now we create the mask with values in range 0-255
# Create the mask with values in range 0-255
mask = np.zeros_like(pred)
for j in range(n_classes):
mask[pred == j] = paint_vals[j]
# Then we save the image
imsave(os.path.join(path_acquisitions_folders[i], segmentations_filenames[i]), mask, 'png')


axon_prediction, myelin_prediction = get_masks(os.path.join(path_acquisitions_folders[i], segmentations_filenames[i]))


if prediction_proba_activate:
return prediction, prediction_proba
else:
return prediction


# ---------------------------------------------------------------------------------------------------------

def ensure_list_type(elem):
    """
    Wrap the argument in a list if it is not one already.

    :param elem: Element to transform into a list.
    :return: A list containing the element, or the element itself if it is
        already a list.
    """
    # isinstance is the idiomatic type check (PEP 8) and, unlike the original
    # `type(elem) != list` comparison, also accepts list subclasses as lists.
    if not isinstance(elem, list):
        elem = [elem]
    return elem


def load_acquisitions(path_acquisitions, acquisitions_resolutions, resampled_resolutions, verbose_mode=0):
'''
"""
Load and resamples acquisitions located in the indicated folders' paths.
:param path_acquisitions: List of paths to the acquisitions images.
:param acquisitions_resolutions: List of float containing the resolutions the acquisitions were acquired with.
:param resampled_resolutions: List of resolutions (floats) to resample to.
:param verbose_mode: Int, how much information to display.
:return:
'''
"""

path_acquisitions, acquisitions_resolutions, resampled_resolutions = list(map(
ensure_list_type, [path_acquisitions, acquisitions_resolutions, resampled_resolutions]))
Expand Down Expand Up @@ -337,13 +320,13 @@ def load_acquisitions(path_acquisitions, acquisitions_resolutions, resampled_res


def prepare_patches(resampled_acquisitions, patch_size, overlap_value=25):
'''
"""
Transform resampled acquisitions into patches. Each patch is also preprocessed during this step.
:param resampled_acquisitions: List of acquisitions images that have been resampled
:param patch_size: Input size of the network.
:param overlap_value: How much overlap to include when doing the inference.
:return: List of 512x512 patches ready to be fed to the network.
'''
"""


# Handle case when image is too small after resampling to target resolution of the model
Expand Down Expand Up @@ -380,7 +363,7 @@ def process_segmented_patches(predictions_list, L_n_patches, L_positions, L_orig
overlap_value, n_classes,
predictions_proba_list = None, prediction_proba_activate=False, verbose_mode=0):

'''
"""
Gathers the segmented patches into lists corresponding to each acquisition, stitches them and resamples them.
:param predictions_list: List of all segmented patches.
:param L_n_patches: List containing the number of patches related to each acquisition.
Expand All @@ -393,7 +376,7 @@ def process_segmented_patches(predictions_list, L_n_patches, L_positions, L_orig
:param verbose_mode: Int, the level of verbosity.
:return: the reconstructed list of segmentations, as well as the list of probability maps for each acquisition,
if requested.
'''
"""
patch_size = predictions_list[0].shape[0]
L_predictions = []
L_predictions_proba = []
Expand Down Expand Up @@ -454,7 +437,7 @@ def process_segmented_patches(predictions_list, L_n_patches, L_positions, L_orig

def perform_batch_inference(tf_session, tf_prediction_op, tf_input, batch_x, size_batch, input_size, n_classes,
prediction_proba_activate=False):
'''
"""
Performs the segmentation of all the patches in the batch.
:param tf_session: Current Tensorflow session.
:param tf_prediction_op: Tensorflow prediction operator.
Expand All @@ -464,7 +447,7 @@ def perform_batch_inference(tf_session, tf_prediction_op, tf_input, batch_x, siz
:param n_classes: Int, number of classes.
:param prediction_proba_activate: Boolean, whether to compute the probability maps or not.
:return: List of segmentation of the patches, and optionally list of the probabilty maps for each patch.
'''
"""

p = tf_session.run(tf_prediction_op, feed_dict={
tf_input: batch_x}) # here we get the predictions under prob format (float, between 0 and 1, shape = (bs, size_image*size_image, n_classes).
Expand Down
6 changes: 3 additions & 3 deletions notebooks/performance_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,14 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit c712dc6

Please sign in to comment.