# Detect rings with MRCNN
##### Braedyn Au 30/8/19
Description

Using the trained model, detect, crop, and pad rings out of the image each to 120x120 sizes. Biggest known issue is that the ring center may not align with the center of the image. Further processing in ImageJ may help but due to rings not being perfectly symmetrical, manual editing will probably be unavoidable.   

In [1]:
from os import listdir
import os
import scipy
import numpy as np
import skimage.color
import skimage.io
import skimage.transform
from skimage.util import pad
from numpy import zeros
from numpy import asarray
from numpy import expand_dims
from matplotlib import pyplot
from matplotlib.patches import Rectangle
from mrcnn.config import Config
from mrcnn.model import MaskRCNN
from mrcnn.model import mold_image
from mrcnn.utils import Dataset
from mrcnn import visualize as vs
#IF WINDOWS OS THEN USE tkinter 
from tkinter import filedialog, messagebox
from tkinter import *

Using TensorFlow backend.


# Configuration
To change parameters for prediction, same as we did while training. Overwrite original configuration class in config.py. DETECTION_MIN_CONFIDENCE is a parameter you may want to play around with.

In [2]:
class PredictionConfig(Config):
    # define the name of the configuration
    NAME = "Septin_cfg"
    # number of classes (background + Septin)
    NUM_CLASSES = 1 + 1
    # simplify GPU config
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1   
    DETECTION_MIN_CONFIDENCE = 0.8

# Prediction
Now we can load images and detect the rings using the model we just trained. In the notebook, we will also plot them nice and pretty. However, to prevent a ton of images popping up when processing a large number of images, this option is defaulted to not occur.

The first cell is a function to plot the images with boxes around the rings, and also to crop those boxes and save them as tif files.

The second cell takes your input for images, model, and output folder, and then runs the function.

In [3]:
# plot a number of photos with ground truth and predictions
def plot_predicted(imageDirectory, outputDirectory, model, cfg, plot=False):
    # load image and mask
    for i in listdir(imageDirectory):
        if i.endswith('.tif'):
            image_number = listdir(imageDirectory).index(i)
            # load the image and mask
            
            imgPath = os.path.join(imageDirectory, i)
            image = skimage.io.imread(imgPath)
            #image = SeptinDataset.load_Septin(imgPath)
            outPath = os.path.join(outputDirectory, i)
            # If grayscale. Convert to RGB for consistency.
            if image.ndim != 3:
                image = skimage.color.gray2rgb(image)
            # If has an alpha channel, remove it for consistency
            if image.shape[-1] == 4:
                image = image[..., :3]
            #mask, _ = SeptinDataset.load_mask(i)
            # convert pixel values (e.g. center)
            scaled_image = mold_image(image, cfg)
            # convert image into one sample
            sample = expand_dims(scaled_image, 0)
            # make prediction
            yhat = model.detect(sample, verbose=0)[0]
            # define subplot
            if plot:
                pyplot.imshow(image)

                mask = yhat['masks']
                for j in range(mask.shape[2]):
                    pyplot.imshow(mask[:, :, j], cmap='gray', alpha=0.3)
                    pyplot.title('Predicted Mask')

                pyplot.title('Predicted')
                ax = pyplot.gca()
            # plot each box
            n = 1
            print("Looking at image", i ,"...")
            for box in yhat['rois']:
                # get coordinates
                y1, x1, y2, x2 = box
                # increase box size
                x1 = x1 - 10
                y1 = y1 - 10
                y2 = y2 + 10
                x2 = x2 + 10
                # calculate width and height of the box
                width, height = x2 - x1, y2 - y1
                # create the shape
                rect = Rectangle((x1, y1), width, height, fill=False, color='red')
                # draw the box
                if plot:
                    ax.add_patch(rect)
                # Cut the rois from each image
                # take output directory as user input in seperate script
                img_crop = image[y1:y2,x1:x2]
                # pad to 100x100
                h, w, d = img_crop.shape
                if h < 120 and w < 120:
                    dh = (120-h)//2
                    dw = (120-w)//2
                    img_crop = pad(img_crop, ((dh,0),(dw,0),(0,0)), 'constant')
                    h, w, d = img_crop.shape
                    img_crop = pad(img_crop,((0,120-h),(0,120-w), (0,0)),'constant')
                else:
                    pass
                if img_crop.size != 0 and np.mean(img_crop) != 0:
                    skimage.io.imsave(outPath+str(n)+'.tif',img_crop, check_contrast = False)
                    
                n = n + 1
            # show the figure
            if plot:
                pyplot.show()
            print('Rings cropped: ' , n-1)
            
        
        else:
            pass

In [4]:
# load the images
messagebox.showinfo("Images", "Load the folder containing images")
imageDirectory = filedialog.askdirectory()
#train_set.load_dataset('Ring', is_train=True)
#train_set.prepare()
#print('Train: %d' % len(train_set.image_ids))
# create config
cfg = PredictionConfig()
# define the model
model = MaskRCNN(mode='inference', model_dir='./', config=cfg)
# load model weights with tkinter
messagebox.showinfo("Model", "Load model found in septin_cfg... folder")
model_path = filedialog.askopenfilename() 
model.load_weights(model_path, by_name=True)
messagebox.showinfo("Output", "Choose output folder")
outputDirectory = filedialog.askdirectory()

# plot predictions for images
plot_predicted(imageDirectory, outputDirectory, model, cfg)

W0904 09:47:47.059910 18548 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0904 09:47:47.085911 18548 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0904 09:47:47.090912 18548 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0904 09:47:47.124913 18548 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py:1919: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

W0904 09:47:47.128914 18548 deprecation_wrapper.py:119]

Re-starting from epoch 8
Looking at image 20190822_S2aS5b-g_3-15.tif ...
Rings cropped:  9
Looking at image 20190822_S2aS5b-g_3-16.tif ...
Rings cropped:  2
Looking at image 20190822_S2aS5bS2c-g_1-11.tif ...
Rings cropped:  2
Looking at image 20190822_S2aS5bS2c-g_1-12.tif ...
Rings cropped:  1
Looking at image 20190822_S2aS5bS2c-g_1-5.tif ...
Rings cropped:  0
Looking at image 20190822_S2aS5bS2c-g_1-6.tif ...
Rings cropped:  1
Looking at image 20190822_S2aS5bS2c-g_2-1.tif ...
Rings cropped:  2
Looking at image 20190822_S2aS5bS2c-g_2-13.tif ...
Rings cropped:  0
Looking at image 20190822_S2aS5bS2c-g_2-14.tif ...
Rings cropped:  8
Looking at image 20190822_S2aS5bS2c-g_2-2.tif ...
Rings cropped:  4
Looking at image 20190822_S2aS5bS2c-g_2-7.tif ...
Rings cropped:  7
Looking at image 20190822_S2aS5bS2c-g_2-8.tif ...
Rings cropped:  3
Looking at image 20190822_S2aS5bS2c-g_3-10.tif ...
Rings cropped:  3
Looking at image 20190822_S2aS5bS2c-g_3-15.tif ...
Rings cropped:  0
Looking at image 2019