## Mask images through thresholding and ORB

In [None]:
import os
import shutil
from pathlib import Path

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from skimage import img_as_ubyte
from skimage.color import label2rgb, rgb2gray
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.morphology import (closing, disk, remove_small_holes,
                                remove_small_objects)
from skimage.segmentation import clear_border


In [None]:
def mask_labels(label_filename, dataset_dir):
    """Segment the label(s) in a label-only image with Otsu thresholding.

    The image is converted to 8-bit grayscale and thresholded with Otsu's
    method (pixels brighter than the threshold are foreground). It is then
    closed with a disk of radius 5. Objects smaller than 1/50 of the image
    area are removed, and holes up to 1/10 of the image area are filled.
    Connected components are labelled, and components touching the image
    border are cleared.

    Parameters
    ----------
    label_filename : str
        File name of the image, relative to ``dataset_dir``.
    dataset_dir : str or os.PathLike
        Directory containing the image.

    Returns
    -------
    (ndarray, bool)
        ``(label_image, True)`` when clearing the border removed at most one
        region. ``label_image`` is the integer label image with
        border-touching components set to 0.
        ``(mask, False)`` when more than one region was removed. ``mask`` is
        the boolean foreground mask before labelling.
    """
    path = Path(dataset_dir) / label_filename
    image = img_as_ubyte(rgb2gray(mpimg.imread(path)))
    thresh = threshold_otsu(image)
    bw = closing(image > thresh, disk(5))
    img_size = image.shape[0] * image.shape[1]
    mask = remove_small_objects(bw, int(img_size / 50))
    mask = remove_small_holes(mask, int(img_size / 10))

    label_image = label(mask)
    n_regions_before = len(regionprops(label_image))
    # Drop components touching the image border. NOTE(review): presumably
    # meant to remove a ruler lying along the edge; a dedicated ruler
    # detector would be more reliable.
    label_image = clear_border(label_image)
    n_regions_after = len(regionprops(label_image))
    # Losing a single border object is tolerated. Losing more presumably
    # means the labels themselves touch the border, so report failure.
    if n_regions_before > n_regions_after + 1:
        return mask, False
    return label_image, True

def _to_gray_u8(image):
    """Return a single-channel version of *image* that ORB can consume.

    Uses ``cv2.COLOR_BGR2GRAY``. NOTE(review): images read with
    ``mpimg.imread`` are RGB, so the red/blue luminance weights are swapped.
    This is harmless for keypoint detection but worth knowing.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    if np.issubdtype(gray.dtype, np.floating):
        # mpimg.imread returns float images in [0, 1] for PNG files, but ORB
        # only accepts 8-bit input.
        gray = np.clip(np.rint(gray * 255.0), 0, 255).astype(np.uint8)
    return gray


def orb_find_label_coords(both_original, labels_original, mask_in, *,
                          n_features=5000, n_matches=10,
                          ransac_reproj_threshold=8.0):
    """Register the label-only image onto the combined image and warp *mask_in*.

    Adapted from
    https://docs.opencv.org/master/d1/de0/tutorial_py_feature_homography.html.
    ORB keypoints are matched by brute force (Hamming distance,
    cross-checked). The ``n_matches`` best matches feed a RANSAC homography
    ``M`` that maps pixel coordinates of ``labels_original`` to those of
    ``both_original``. ``M`` is then used to warp ``mask_in`` into
    ``both_original``'s frame.

    Parameters
    ----------
    both_original : ndarray
        Colour image that contains the label region.
    labels_original : ndarray
        Colour image of the label(s) alone.
    mask_in : ndarray
        2-D label image aligned with ``labels_original`` (e.g. the first
        output of ``mask_labels``).
    n_features : int, keyword-only
        Maximum number of ORB keypoints per image (default 5000).
    n_matches : int, keyword-only
        Number of lowest-distance matches used for the homography (default 10).
    ransac_reproj_threshold : float, keyword-only
        RANSAC reprojection threshold in pixels (default 8.0).

    Returns
    -------
    labels : ndarray
        Grayscale uint8 version of ``labels_original``.
    mask_in : ndarray
        ``mask_in`` warped into ``both_original``'s frame: float32, with
        ``both_original``'s height and width.
    orb_img : ndarray
        Match visualisation. ``labels`` is on the left and grayscale ``both``
        on the right. RANSAC inlier matches are drawn in colour (255, 0, 0),
        and the projected outline of ``labels`` in (0, 0, 255).

    Raises
    ------
    ValueError
        If either image yields no ORB descriptors, fewer than 4 matches are
        available, or no homography can be estimated.
    """
    labels = _to_gray_u8(labels_original)
    both = _to_gray_u8(both_original)

    orb = cv2.ORB_create(nfeatures=n_features)
    kp_labels, des_labels = orb.detectAndCompute(labels, None)
    kp_both, des_both = orb.detectAndCompute(both, None)
    if des_labels is None or des_both is None:
        raise ValueError('ORB found no keypoints in at least one image')

    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = sorted(bf.match(des_labels, des_both), key=lambda m: m.distance)
    good_matches = matches[:n_matches]
    if len(good_matches) < 4:
        # findHomography needs at least 4 point correspondences.
        raise ValueError(
            f'need at least 4 matches to estimate a homography, got {len(good_matches)}')

    src_pts = np.float32([kp_labels[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp_both[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    M, inlier_mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC,
                                        ransac_reproj_threshold)
    if M is None:
        raise ValueError('cv2.findHomography could not estimate a homography')

    h_both, w_both = both.shape[:2]
    # Nearest-neighbour keeps integer label ids intact. The default bilinear
    # interpolation would blend neighbouring ids into fractional labels.
    warped_mask = cv2.warpPerspective(np.float32(mask_in), M, (w_both, h_both),
                                      flags=cv2.INTER_NEAREST)

    # Outline of the labels image projected into `both`. It is shifted right
    # by the labels width because drawMatches places `both` to the right of
    # `labels`.
    h, w = labels.shape[:2]
    corners = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
    outline = cv2.perspectiveTransform(corners, M)
    outline += (w, 0)

    draw_params = dict(matchColor=(255, 0, 0),
                       singlePointColor=None,
                       # Only RANSAC inliers are drawn.
                       matchesMask=inlier_mask.ravel().tolist(),
                       flags=2)  # 2 == NOT_DRAW_SINGLE_POINTS
    orb_img = cv2.drawMatches(labels, kp_labels, both, kp_both, good_matches, None,
                              **draw_params)
    orb_img = cv2.polylines(orb_img, [np.int32(outline)], True, (0, 0, 255), 3,
                            cv2.LINE_AA)
    return labels, warped_mask, orb_img

def match_and_mask_both_label(both_filename, label_filename, mask_in, dataset_dir):
    """Project a label mask onto the combined image, display it and return it.

    Reads ``both_filename`` and ``label_filename`` from ``dataset_dir`` and
    registers them with ``orb_find_label_coords``, which warps ``mask_in``
    into the combined image's frame. The warped mask is overlaid on the
    combined image with ``label2rgb``, and the match visualisation is shown
    next to the overlay.

    Parameters
    ----------
    both_filename, label_filename : str
        File names, relative to ``dataset_dir``, of the combined image and
        of the label-only image.
    mask_in : ndarray
        Label image aligned with the label-only image (e.g. from
        ``mask_labels``).
    dataset_dir : str or os.PathLike
        Directory containing both images.

    Returns
    -------
    mask : ndarray
        ``mask_in`` warped into the combined image's frame.
    image_label_overlay : ndarray
        RGB overlay of ``mask`` on the combined image (label 0 = background).
    """
    dataset_dir = Path(dataset_dir)
    both_original = mpimg.imread(dataset_dir / both_filename)
    labels_original = mpimg.imread(dataset_dir / label_filename)

    labels, mask, orb_img = orb_find_label_coords(both_original, labels_original, mask_in)
    image_label_overlay = label2rgb(mask, image=both_original, bg_label=0, kind='overlay')

    fig, axs = plt.subplots(1, 2, figsize=(16, 6))
    # Title: file name plus the sizes of the combined and the label image.
    fig.suptitle(f'{both_filename} both={both_original.shape[:2]} labels={labels.shape}')
    axs[0].imshow(orb_img)
    axs[1].imshow(image_label_overlay)
    # Hide the axes of every subplot; plt.axis only acts on the current one.
    for ax in axs:
        ax.axis('off')
    fig.tight_layout()
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate figures.
    plt.close(fig)
    return mask, image_label_overlay


In [None]:
# Annotation table; the columns used below are 'filename', 'id' and 'labels'.
# The first CSV column (presumably a saved row index) becomes the index.
labeled = pd.read_csv(
    'datasets/label_specimen_mask.csv',
    index_col=0,
)

In [None]:
base_dir = Path('c:/Users/flori/download/segmentation')
masks_out_path = base_dir / 'masks'
overlayed_out_path = base_dir / 'overlayed'
dataset_dir = Path('c:/Users/flori/download/subset')
done_dir = base_dir / 'done'

# plt.imsave and shutil.move do not create missing parent directories.
for out_dir in (masks_out_path, overlayed_out_path, done_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Keep only annotation rows whose file is still present in dataset_dir.
# Processed files are moved to done_dir at the end of each iteration.
input_dataset = os.listdir(dataset_dir)
labeled_checked = labeled[labeled.filename.isin(input_dataset)]

# One row per specimen id and one column per distinct 'labels' value; cells
# hold file names. Column 0 is used as the combined ("both") image and
# column 1 as the label-only image. NOTE(review): this relies on the sorted
# order of the 'labels' values; confirm against the CSV.
filenames_id = labeled_checked.pivot(index='id', columns='labels', values='filename')
# An id whose two files are not both present gets NaN cells. Skip it rather
# than failing on `dataset_dir / nan`.
filenames_id = filenames_id.dropna(subset=filenames_id.columns[:2])

# Presumably specimens that gave odd results. To re-run only these, iterate
# over filenames_id[filenames_id.index.isin(weird_indexes)].iterrows() instead.
weird_indexes = [6768373, 5730836, 5730537, 2883099, 2883143, 2883181, 2883197, 2883199,
                 2883241, 2883252, 2883297, 2883349, 2883369, 2883370, 2883398, 2883401,
                 2883405, 2883454, 2883537, 2883556, 2883623, 2883731][:10]
filenames_list = filenames_id.iloc[:10].iterrows()

for index, row in filenames_list:
    both_filename, label_filename = row.iloc[0], row.iloc[1]
    mask, success = mask_labels(label_filename, dataset_dir)
    print(index)
    if success:
        print('success')
        mask, image_label_overlay = match_and_mask_both_label(
            both_filename, label_filename, mask, dataset_dir)
        plt.imsave(masks_out_path / label_filename, mask)
        plt.imsave(overlayed_out_path / label_filename, image_label_overlay)
    # Both source images leave the input folder whether or not masking
    # succeeded, so a re-run only sees unprocessed specimens.
    shutil.move(dataset_dir / label_filename, done_dir / label_filename)
    shutil.move(dataset_dir / both_filename, done_dir / both_filename)
