In [1]:
import os
import cv2 
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from src.utils.image_stuff import image_stuff

In [2]:
def find_optic_disc_mask(rgb_image, threshold):
    """Return the largest roughly circular bright region of an image.

    The image is Gaussian-blurred, its channel 0 is contrast-equalized with
    CLAHE, binarized at ``threshold`` and cleaned with a morphological
    opening. The convex hull of every remaining contour is compared with a
    reference circle, and the largest hull that passes is returned.

    Args:
        rgb_image: 3-channel image array. Only channel 0 is analysed (the red
            channel, provided the array really is in RGB order).
        threshold: Binarization level applied to the equalized channel.

    Returns:
        A two-element list ``[hull, area]``: the selected convex hull and its
        area as reported by ``cv2.contourArea``.

    Raises:
        IndexError: If no contour passes the circularity test. Callers use
            this to retry with a different threshold.
    """
    # The blur result is a fresh array, so no defensive copy of the input is
    # made and no scratch output image is allocated.
    blurred = cv2.GaussianBlur(rgb_image, (25, 25), 0)
    first_channel = blurred[:, :, 0]

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = clahe.apply(first_channel)

    _, binary = cv2.threshold(equalized, threshold, 255, cv2.THRESH_BINARY)

    kernel = np.ones((5, 5), np.uint8)
    opened = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=5)

    contours, _ = cv2.findContours(opened, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)

    candidates = []
    for contour in contours:
        hull = cv2.convexHull(contour)
        area = cv2.contourArea(hull)
        _, _, box_w, box_h = cv2.boundingRect(hull)

        # Reference circle: its diameter is the shorter bounding-box side.
        # NOTE(review): there is no upper bound on the ratio, so elongated
        # hulls that are larger than this circle also pass -- confirm intent.
        radius = min(box_w, box_h) / 2
        ideal_area = np.pi * radius * radius

        if area / ideal_area >= 0.9:
            candidates.append([hull, area])

    if not candidates:
        raise IndexError(
            f"no sufficiently circular contour found at threshold {threshold}"
        )

    # max() returns the first of equally large candidates, exactly like taking
    # element 0 of a stable descending sort.
    return max(candidates, key=lambda candidate: candidate[1])

In [3]:
def find_ROI(image):
    """Find the contours of the brightest area of the image's channel 1.

    The image is Gaussian-blurred and channel index 1 is equalized with CLAHE.
    The equalized channel is then binarized at a decreasing series of levels
    (250, 245, ..., 205); each binary image is dilated and searched for
    contours, and the search stops at the first level that yields any.

    Args:
        image: 3-channel image array.

    Returns:
        The contours found at the first level that produced at least one, or
        the empty result of the last attempt if none of the ten levels did.
    """
    blurred = cv2.GaussianBlur(image, (25, 25), 0)
    # Index 1 is the middle channel (green for both RGB and BGR ordering).
    middle_channel = blurred[:, :, 1]

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = equalizer.apply(middle_channel)

    dilation_kernel = np.ones((10, 10), np.uint8)

    # Relax the brightness requirement in steps of 5 until something shows up.
    for level in range(250, 200, -5):
        _, binary = cv2.threshold(equalized, level, 255, cv2.THRESH_BINARY)
        dilated = cv2.morphologyEx(
            binary, cv2.MORPH_DILATE, dilation_kernel, iterations=5
        )
        contours, _ = cv2.findContours(
            dilated, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
        )
        if len(contours) != 0:
            break

    return contours

In [4]:
# https://www.geeksforgeeks.org/python-opencv-find-center-of-contour/
def extract_center(contour):
    """Return the integer center ``(cx, cy)`` of a contour.

    The center is the centroid derived from the contour's image moments
    (``m10 / m00``, ``m01 / m00``). A contour that encloses no area (a single
    point or a straight line) has ``m00 == 0``; in that case the mean of the
    contour's points is used instead of dividing by zero.

    Args:
        contour: Contour as accepted by ``cv2.moments``; for the fallback it
            must be convertible to an array of (x, y) pairs.

    Returns:
        Tuple ``(cx, cy)`` of ints, truncated toward zero.

    Raises:
        ValueError: If the contour has zero area and contains no points.
    """
    M = cv2.moments(contour)
    if M['m00'] != 0:
        return int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])

    # Degenerate contour: fall back to the average of its points.
    points = np.asarray(contour).reshape(-1, 2)
    if points.size == 0:
        raise ValueError("cannot compute the center of an empty contour")
    mean_x, mean_y = points.mean(axis=0)
    return int(mean_x), int(mean_y)

In [5]:
def find_optic_disc_from_red_channel(eq_image, roi_x, roi_y, d, eye_area):
    """Search for a binarization threshold that yields a plausible disc mask.

    Up to 25 thresholds are tried, starting at 190. A candidate mask from
    ``find_optic_disc_mask`` is accepted when

    * its center lies within ``d`` pixels of ``(roi_x, roi_y)`` on both axes,
    * and its area is between 1.4% and 2.5% of ``eye_area``; after the fifth
      iteration this band is widened by 0.0005 on each side per iteration
      that reaches the area check.

    Otherwise the threshold is adjusted (partly at random, via the global
    ``np.random`` state) and the search continues.

    Args:
        eq_image: Image passed through to ``find_optic_disc_mask``.
        roi_x: Expected x coordinate of the mask center.
        roi_y: Expected y coordinate of the mask center.
        d: Maximum allowed per-axis distance between mask center and ROI.
        eye_area: Area that the mask area is compared against.

    Returns:
        The accepted ``[hull, area]`` mask, or ``None`` if no threshold
        produced an acceptable one.
    """
    min_threshold, max_threshold = 75, 254
    min_ratio, max_ratio = 0.014, 0.025

    threshold = 190
    last_inside_threshold = threshold  # last threshold whose mask hit the ROI
    delta = 0  # extra tolerance on the area-ratio band

    for i in range(25):
        try:
            mask = find_optic_disc_mask(eq_image, threshold)
        except IndexError:
            # No circular candidate at this level. Only this expected failure
            # is retried; anything else (including Ctrl-C) propagates.
            # NOTE(review): unlike the other adjustments, this one is not
            # clamped to min_threshold -- confirm whether that is intended.
            threshold -= 10
            continue

        cx, cy = extract_center(mask[0])
        inside_roi = (
            roi_x - d <= cx <= roi_x + d and roi_y - d <= cy <= roi_y + d
        )
        if not inside_roi:
            # Wrong blob: jump to a random level near the last good one.
            threshold = last_inside_threshold + np.random.randint(-25, 25)
            threshold = max(min(threshold, max_threshold), min_threshold)
            continue
        last_inside_threshold = threshold

        area_ratio = mask[1] / eye_area
        if min_ratio - delta <= area_ratio <= max_ratio + delta:
            return mask

        if area_ratio > max_ratio:
            # Mask too large: raise the threshold by 5..34.
            threshold += 20 + np.random.randint(-15, 15)
            threshold = min(threshold, max_threshold)
        elif area_ratio < min_ratio:
            # Mask too small: lower the threshold by 6..35.
            threshold += -20 + np.random.randint(-15, 15)
            threshold = max(threshold, min_threshold)

        if i > 4:
            delta += 0.0005

    return None

In [6]:
def extract_optic_disc(image_path, outout_dir_path):
    """Detect the optic disc in one image and write the results to disk.

    The image is loaded, passed through the ``image_stuff`` helpers
    (``find_eye_area``, ``crop_eye_area``, ``fix_illumination``) and searched
    for a single bright region of interest. If exactly one region is found and
    a disc mask near it can be extracted, three files named after the input
    file (without its extension) are written to ``outout_dir_path``:

    * ``<name>_bbox.jpg`` -- the image returned by
      ``image_stuff.draw_bbox_form_contour`` for the disc contour,
    * ``<name>.jpg``      -- the processed image itself,
    * ``<name>.txt``      -- the bounding box as four lines:
      x_min, y_min, x_max, y_max.

    Nothing is written when the number of regions of interest is not exactly
    one or when no mask is found.

    Args:
        image_path: Path of the image to process.
        outout_dir_path: Directory receiving the output files; created if it
            does not exist.

    Raises:
        OSError: If the image cannot be read.
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # imread signals failure by returning None instead of raising.
        raise OSError(f"could not read image: {image_path}")
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    eye_area = image_stuff.find_eye_area(rgb_image)
    cropped_image = image_stuff.crop_eye_area(rgb_image)
    eq_image = image_stuff.fix_illumination(cropped_image)

    # Search radius around the ROI center: 7.5% of (height + width), halved.
    height, width = eq_image.shape[:2]
    d = int(0.075 * (height + width)) // 2

    roi_contours = find_ROI(eq_image)
    if len(roi_contours) != 1:
        return

    roi_x, roi_y = extract_center(roi_contours[0])
    final_mask = find_optic_disc_from_red_channel(eq_image, roi_x, roi_y, d, eye_area)
    if not final_mask:
        return

    # splitext copes with extensions of any length (".jpeg", ".tiff", none),
    # where slicing off the last four characters would mangle the name.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    os.makedirs(outout_dir_path, exist_ok=True)

    bbox_image = image_stuff.draw_bbox_form_contour(eq_image, final_mask[0])
    cv2.imwrite(
        os.path.join(outout_dir_path, f"{image_stem}_bbox.jpg"),
        cv2.cvtColor(bbox_image, cv2.COLOR_RGB2BGR),
    )
    cv2.imwrite(
        os.path.join(outout_dir_path, f"{image_stem}.jpg"),
        cv2.cvtColor(eq_image, cv2.COLOR_RGB2BGR),
    )

    x, y, w, h = cv2.boundingRect(final_mask[0])
    coordinates = [x, y, x + w, y + h]
    with open(os.path.join(outout_dir_path, f"{image_stem}.txt"), "w") as f:
        for coordinate in coordinates:
            f.write(f"{int(coordinate)}\n")

In [7]:
# Build the full path of every entry in the dataset directory, in sorted order.
dataset_dir_path = "/home/hue/Codes/AIROGS/datasets/5_normalized_224x224"
images_path = [
    os.path.join(dataset_dir_path, entry_name)
    for entry_name in sorted(os.listdir(dataset_dir_path))
]
print(f"Total number of images: {len(images_path)}")

Total number of images: 43


In [8]:
# Directory that receives the processed images and bounding-box text files.
outout_dir_path = "/home/hue/Codes/AIROGS/datasets/threshold_output_test"
# makedirs also creates missing parent directories and, with exist_ok, has no
# window between "check" and "create" in which another process could interfere.
os.makedirs(outout_dir_path, exist_ok=True)

In [9]:
# Run the extraction on the first dataset entry only. Indexing raises
# IndexError if the dataset directory listing was empty.
extract_optic_disc(images_path[0], outout_dir_path)