# Statement of the problem
In this notebook I briefly describe a naive method that attempts to separate merged (overlapping) nuclei masks using OpenCV convexity analysis. The method can be applied to a segmentation output after skimage `label`, or a similar method, has split it into connected clusters. The analysis below assumes that nuclei are convex-shaped (this is not 100% correct).

1. Show some examples of overlapping nuclei
2. Describe the convexity-based separation code
3. Apply the correction to one of the benchmark submissions (it actually improved LB 0.274 -> LB 0.300)

## 1. Show some examples
Let us read some train images and their corresponding masks.

In [13]:
import os
import cv2
import numpy as np
from matplotlib import pyplot as plt

train_dirs = os.listdir("../input/data-science-bowl-2018/stage1_train/")
M = {}
I = {}
for file_id in train_dirs[:10]:
    train_filename = "../input/data-science-bowl-2018/stage1_train/"+file_id+"/images/"+file_id+".png"
    I[file_id] = cv2.imread(train_filename) 
    train_mask_dirs = "../input/data-science-bowl-2018/stage1_train/"+file_id+"/masks/"
    train_mask_files = os.listdir(train_mask_dirs)
    M[file_id] = []
    for train_mask in train_mask_files:
        mask = cv2.imread(train_mask_dirs + train_mask, cv2.IMREAD_GRAYSCALE)
        M[file_id].append(mask)

Plot the second image (`7f34dfccd1bc2e2466ee3d6f74ff05821a0e5404e9cf2c9568da26b59f7afda5`) and the `np.amax` of all corresponding masks.

In [14]:
id = 1
image = I[train_dirs[id]]
mask = np.amax(M[train_dirs[id]], axis=0)
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # cv2.imread returns BGR, matplotlib expects RGB
plt.show()
plt.imshow(mask)
plt.show()


Use the skimage `label` to separate the masks into connected blobs.

In [15]:
from skimage.morphology import label
labeled_mask = label(mask)
print('Found', labeled_mask.max(), 'connected masks')  # labels run from 1..N, 0 is background
B = []
for i in np.unique(labeled_mask):
    if i == 0: # id = 0 is for background
        continue
    mask_i = (labeled_mask==i).astype(np.uint8)
    B.append(mask_i)
n_blobs = len(B)

In [16]:
f, axarr = plt.subplots(6, 7)
f.set_figwidth(20)
f.set_figheight(20)

for i in range(6):
    for j in range(7):
        k = i*7 + j
        if k < n_blobs:
            axarr[i][j].set_title('mask:' + str(k))
            axarr[i][j].imshow(B[k])     
plt.show()

I guess by now you have noticed cases like mask 1, mask 5, mask 12 and mask 13 that deviate from nice convex shapes. 

## 2. Describe the separation idea
The idea is to operate in a blob-by-blob fashion and detect the points that deviate from the convex hull around the blob. More details [here](https://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html). The basic idea is described in the steps below:
1. Given a segmentation mask
2. Apply `skimage.morphology.label` or `cv2.connectedComponents` to split it into disjoint masks
3. For each disjoint mask that is not convex enough:

    a. Identify the convexity defect points

    b. Draw a line of background (0) color connecting each defect point to its nearest defect point

In [17]:
from IPython.display import Image
Image("../input/separation-case-image/mask_separation.png")


      
### 2.1 Convex enough
We use the `skimage.measure.regionprops` method to characterize each individual mask. More specifically, we check the ratio `prop.convex_area/prop.filled_area`. If the ratio exceeds a predefined threshold (1.1 here), the mask is considered non-convex.

```
from skimage.measure import regionprops

# mask_i is a binary uint8 mask that holds a single blob
props = regionprops(mask_i, cache=False)
prop = props[0]
if prop.convex_area / prop.filled_area > 1.1:
    print('non-convex')
```
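
To make the ratio concrete without relying on skimage, here is a minimal sketch that computes a comparable quantity by hand. It is not what `regionprops` does internally: it takes the convex hull of the pixel corners (Andrew's monotone chain) and divides the hull area by the pixel count, so its values can differ slightly from `convex_area/filled_area`. The helper names (`convex_hull`, `polygon_area`, `convexity_ratio`) and the two toy masks are made up for illustration.

```python
import numpy as np

def cross(o, a, b):
    """z-component of (a - o) x (b - o); > 0 means a counter-clockwise turn."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

def convex_hull(points):
    """Andrew's monotone chain; returns the hull vertices in order."""
    pts = sorted(set(points))
    if len(pts) <= 2:
        return pts
    lower, upper = [], []
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    return lower[:-1] + upper[:-1]

def polygon_area(vertices):
    """Shoelace formula."""
    area = 0.0
    n = len(vertices)
    for k in range(n):
        x1, y1 = vertices[k]
        x2, y2 = vertices[(k + 1) % n]
        area += x1 * y2 - x2 * y1
    return abs(area) / 2.0

def convexity_ratio(mask):
    """Hull area of the pixel corners divided by the number of mask pixels."""
    rows, cols = np.nonzero(mask)
    corners = []
    for r, c in zip(rows.tolist(), cols.tolist()):
        corners.extend([(c, r), (c + 1, r), (c, r + 1), (c + 1, r + 1)])
    return polygon_area(convex_hull(corners)) / len(rows)

# A filled rectangle is convex, an L-shape is not.
rect = np.zeros((10, 10), dtype=np.uint8)
rect[2:8, 2:8] = 1

l_shape = np.zeros((10, 10), dtype=np.uint8)
l_shape[2:8, 2:4] = 1   # vertical bar
l_shape[6:8, 2:8] = 1   # horizontal bar

print(convexity_ratio(rect))     # 1.0
print(convexity_ratio(l_shape))  # 1.4
```

The rectangle sits exactly at 1.0, while the L-shape (20 pixels inside a hull of area 28) lands well above the 1.1 threshold used in this notebook.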
 
 Let us apply the above filter to the 42 masks of the example above:

In [39]:
from skimage.measure import regionprops
f, axarr = plt.subplots(6, 7)
f.set_figwidth(20)
f.set_figheight(20)

for i in range(6):
    for j in range(7):
        k = i*7 + j
        if k < n_blobs:
            props = regionprops(B[k], cache=False)
            prop = props[0]
            if prop.convex_area/prop.filled_area > 1.1:
                axarr[i][j].set_title('non convex mask:' + str(k))
            else:
                axarr[i][j].set_title('convex mask:' + str(k))
                
            axarr[i][j].imshow(B[k])     
plt.show()

I guess by now we have a pretty good indication of convexity. We can move to correction:

### 2.2 Correcting non-convex masks
The idea is based on the example [here](https://docs.opencv.org/3.0-beta/doc/py_tutorials/py_imgproc/py_contours/py_contours_more_functions/py_contours_more_functions.html). Our aim is to locate the convexity defect points and connect them with a straight line in the background color, thereby separating the merged masks. Each defect is returned as `[ start point, end point, farthest point, approximate distance to farthest point ]`, where the first three entries are indices into the contour. We can use the approximate distance to the farthest point as an indication of "badness". Essentially we gather all defect points and filter them using this approximate distance. In the code below we filter the defect points according to:

1. `dd[i] > 1.0`: the defect depth (the raw value divided by 256) must be greater than 1 pixel.
2. `dd[i]/np.max(dd) > 0.2`: the defect depth must be more than 20% of the depth of the deepest defect.
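
As a small illustration of the two conditions, the sketch below applies them to a handful of made-up raw depth values (the numbers are hypothetical, not taken from any image). It assumes the values come from the fourth column of `cv2.convexityDefects`, which stores the depth as a fixed-point number with 8 fractional bits.

```python
import numpy as np

# Hypothetical raw depths, as found in the 4th column of cv2.convexityDefects.
raw_depths = np.array([128, 384, 2560, 300, 1800])

# Fixed point with 8 fractional bits: divide by 256 to get pixels.
dd = raw_depths / 256.0

# Condition 1: the defect is deeper than 1 pixel.
# Condition 2: the defect is deeper than 20% of the deepest defect.
keep = (dd > 1.0) & (dd / dd.max() > 0.2)

print(dd)                    # depths in pixels
print(np.flatnonzero(keep))  # [2 4]
```

Only the defects at positions 2 and 4 survive: the first is shallower than one pixel, and the other two are deeper than one pixel but small compared with the deepest defect.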

After gathering all convexity defect points we begin connecting them by simply connecting the two nearest each time. This is naively implemented using two nested loops:
```
for i in range(len(points)):
    f1 = points[i]
    p1 = tuple(int(v) for v in contour[f1][0])
    nearest = None
    min_dist = np.inf
    for j in range(len(points)):
        if i != j:
            f2 = points[j]
            p2 = tuple(int(v) for v in contour[f2][0])
            # Squared Euclidean distance between the two defect points.
            dist = (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2
            if dist < min_dist:
                min_dist = dist
                nearest = p2

    # Connect point p1 to its nearest one.
    cv2.line(thresh, p1, nearest, 0, 2)
```
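
The same pairing rule can also be written with a pairwise distance matrix instead of two nested loops. This is only a sketch of an alternative formulation, not the code used in this notebook, and the point coordinates are hypothetical.

```python
import numpy as np

# Hypothetical defect points (x, y); in the notebook they come from the contour.
pts = np.array([[0, 0], [1, 0], [10, 0], [10, 2]])

# Pairwise squared Euclidean distances, shape (n, n).
diff = pts[:, None, :] - pts[None, :, :]
dist = (diff ** 2).sum(axis=-1).astype(float)

# A point must not be matched with itself.
np.fill_diagonal(dist, np.inf)

# Index of the nearest other point for every point.
nearest = dist.argmin(axis=1)
print(nearest)  # [1 0 3 2]
```

Here points 0 and 1 pick each other, as do points 2 and 3, so two lines would be drawn. Nothing in this rule checks that the paired points lie on opposite sides of a blob, which is one reason the method can over-split.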

In [40]:
def split_mask_v1(mask):
    thresh = mask.copy().astype(np.uint8)
    # findContours returns (image, contours, hierarchy) in OpenCV 3 and
    # (contours, hierarchy) in OpenCV 4; [-2] picks the contours in both.
    contours = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
    for contour in contours:
        if cv2.contourArea(contour) > 20:
            hull = cv2.convexHull(contour, returnPoints=False)
            defects = cv2.convexityDefects(contour, hull)
            if defects is None:
                continue

            # Gather the depth of every defect so that the defects can be
            # filtered relative to the deepest one. OpenCV stores the depth
            # as a fixed-point number; dividing by 256 gives pixels.
            dd = [defects[i, 0][3] / 256.0 for i in range(defects.shape[0])]
            max_depth = np.max(dd)

            # Keep the contour index of the farthest point of each defect
            # that is deep enough.
            points = []
            for i in range(len(dd)):
                s, e, f, d = defects[i, 0]
                if dd[i] > 1.0 and dd[i] / max_depth > 0.2:
                    points.append(f)

            if len(points) >= 2:
                # Connect every retained defect point to its nearest neighbour.
                for i in range(len(points)):
                    p1 = tuple(int(v) for v in contour[points[i]][0])
                    nearest = None
                    min_dist = np.inf
                    for j in range(len(points)):
                        if i != j:
                            p2 = tuple(int(v) for v in contour[points[j]][0])
                            dist = (p1[0]-p2[0])**2 + (p1[1]-p2[1])**2
                            if dist < min_dist:
                                min_dist = dist
                                nearest = p2

                    # Draw a 2-pixel-wide background line between the two points.
                    cv2.line(thresh, p1, nearest, 0, 2)
    return thresh

Let us see the separation in action.

In [41]:
b_split = split_mask_v1(B[1])
f, axarr = plt.subplots(1, 2)
f.set_figheight(8)
f.set_figwidth(16)

axarr[0].set_title('Overlapping masks')
axarr[0].imshow(B[1], cmap='gray')
axarr[1].set_title('Separated masks')
axarr[1].imshow(b_split, cmap='gray')
plt.show()

## 3. Modify public kernel [Pure image processing LB 0.274](https://www.kaggle.com/ahassaine/pure-image-processing-lb-0-274)
Most of the following code is borrowed from Ali Hassaïne's public kernel.

In [42]:
test_dirs = os.listdir("../input/data-science-bowl-2018/stage1_test/")
test_filenames=["../input/data-science-bowl-2018/stage1_test/"+file_id+"/images/"+file_id+".png" for file_id in test_dirs]
test_images=[cv2.imread(imagefile) for imagefile in test_filenames]

In [43]:
from skimage.measure import regionprops
def process(img_rgb):
    #green channel happens to produce slightly better results
    #than the grayscale image and other channels
    img_gray=img_rgb[:,:,1]#cv2.cvtColor(img_rgb, cv2.COLOR_BGR2GRAY)
    #morphological opening (size tuned on training data)
    circle7=cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(7,7))
    img_open=cv2.morphologyEx(img_gray, cv2.MORPH_OPEN, circle7)
    #Otsu thresholding
    img_th=cv2.threshold(img_open,0,255,cv2.THRESH_OTSU)[1]
    #Invert the image in case the objects of interest are in the dark side
    if(np.sum(img_th==255)>np.sum(img_th==0)):
        img_th=cv2.bitwise_not(img_th)
    #second morphological opening (on binary image this time)
    bin_open=cv2.morphologyEx(img_th, cv2.MORPH_OPEN, circle7) 
    #connected components
    cc=cv2.connectedComponents(bin_open)[1]
    #cc=segment_on_dt(bin_open,20)
    return cc

def rle_encoding(cc):
    values=list(np.unique(cc))
    values.remove(0)
    RLEs=[]
    for v in values:
        dots = np.where(cc.T.flatten() == v)[0]
        run_lengths = []
        prev = -2
        for b in dots:
            if (b>prev+1):
                run_lengths.extend((b + 1, 0))
            run_lengths[-1] += 1
            prev = b
        RLEs.append(run_lengths)
    return RLEs



The exception is the following function, which applies the separation and filtering. It:

1. Takes as input a mask created by `cv2.connectedComponents`. That means the background is set to 0 and mask pixels have values 1, 2, ...
2. Breaks the mask up by label value: `mask_i = (mask==i).astype(np.uint8)`
3. Calculates the `regionprops` properties: `props = regionprops(mask_i, cache=False)`
4. If `prop.convex_area/prop.filled_area > 1.1`, proceeds to split the mask
5. Appends the (possibly split) mask to a list

Finally, all masks are merged into a single one with `np.amax(masks, axis=0)` and we call `label` to separate them once more.

In [44]:
def split_and_relabel(mask):
    masks = []
    for i in np.unique(mask):
        if i == 0: # id = 0 is for background
            continue
        mask_i = (mask==i).astype(np.uint8)
        props = regionprops(mask_i, cache=False)
        if len(props) > 0:
            prop = props[0]
            if prop.convex_area/prop.filled_area > 1.1:
                mask_i = split_mask_v1(mask_i)
        masks.append(mask_i)
        
    if not masks:  # only background: nothing to split or relabel
        return mask
    masks = np.array(masks)
    masks_combined = np.amax(masks, axis=0)
    labels = label(masks_combined)
    return labels

Process each test image and create initial masks.

In [45]:
test_connected_components=[process(img)  for img in test_images]

Apply the splitting method to each case:

In [46]:
test_connected_components_split=[split_and_relabel(img)  for img in test_connected_components]

Let us see what happened to three cases:

In [None]:
f, axarr = plt.subplots(3, 2)
f.set_figwidth(16)
f.set_figheight(16)

for i in range(3):
    axarr[i, 0].set_title('Original ')
    axarr[i, 0].imshow(test_connected_components[i])
    axarr[i, 1].set_title('Split masks')
    axarr[i, 1].imshow(test_connected_components_split[i])
plt.show()


By visual inspection we can see some good cases and some over-split cases. There is no free lunch after all. However, I am working on some extensions:

* Connecting only the nearest convexity defect points is obviously a crude approach
* For masks near the image margins we may not find a symmetric convexity defect point

In [None]:
test_RLEs=[rle_encoding(cc) for cc in test_connected_components_split]

In [49]:
with open("predictions.csv", "w") as myfile:
    myfile.write("ImageId,EncodedPixels\n")
    for i,RLEs in enumerate(test_RLEs):
        for RLE in RLEs:
            myfile.write(test_dirs[i]+","+" ".join([str(x) for x in RLE])+"\n")
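
To make the `EncodedPixels` column concrete, here is a small worked example on a toy label image. It restates the inner loop of `rle_encoding` for a single label; the helper name `rle_encode_label` and the 3x3 array are made up for illustration. Pixels are numbered from 1, going down each column first, and every run is written as a start position followed by a length.

```python
import numpy as np

def rle_encode_label(cc, v):
    """Run-length encode label v of cc: 1-indexed starts, column-major order."""
    dots = np.where(cc.T.flatten() == v)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:                      # a new run starts here
            run_lengths.extend((int(b) + 1, 0))
        run_lengths[-1] += 1                  # extend the current run
        prev = b
    return run_lengths

cc = np.array([[0, 1, 1],
               [0, 1, 0],
               [2, 2, 0]])

rle_1 = rle_encode_label(cc, 1)
rle_2 = rle_encode_label(cc, 2)
print(" ".join(str(x) for x in rle_1))  # 4 2 7 1
print(" ".join(str(x) for x in rle_2))  # 3 1 6 1
```

Label 1 covers pixels 4 and 5 (the top of the second column) and pixel 7 (the top of the third column), hence the two runs `4 2` and `7 1`.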