In [1]:
import sys
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.gridspec as gridspec
from matplotlib import cm
import scipy
import scipy.signal
from scipy.ndimage.filters import gaussian_filter
from skimage.morphology import opening, binary_erosion, binary_opening, binary_closing

In [2]:
# Echo only the value of the last expression of each cell ("last_expr").
# Option reference: https://ipython.readthedocs.io/en/stable/config/options/terminal.html
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "last_expr"

In [3]:
def threshold_image(img, show=False):
    """Binarise an image with Otsu's method, inverted so dark pixels become 255.

    Otsu's threshold is first applied in TO_ZERO mode: pixels at or below the
    threshold are set to 0 and brighter pixels keep their value. An inverse
    binary threshold at 0 then maps every zeroed pixel to 255 and every
    surviving pixel to 0.

    Args:
        img: single-channel 8-bit image (cv2 Otsu requirement).
        show: plot the input, the TO_ZERO result and the final mask.

    Returns:
        The inverted binary mask (0 / 255).
    """
    otsu_flags = cv2.THRESH_OTSU + cv2.THRESH_TOZERO
    img_otsu = cv2.threshold(img, 0, 255, otsu_flags)[1]
    img_otsu_binary = cv2.threshold(img_otsu, 0, 255, cv2.THRESH_BINARY_INV)[1]

    if show:
        panels = (
            (img, "Original"),
            (img_otsu, "Otsu To Zero"),
            (img_otsu_binary, "Otsu Binary"),
        )
        fig, axes = plt.subplots(1, len(panels), figsize=(18, 12))
        for axis, (panel, title) in zip(axes, panels):
            axis.imshow(panel, cmap=plt.cm.gray)
            axis.set_title(title)

    return img_otsu_binary

In [4]:
def get_erroded_img(img, local=False, global_threshold=70, show=False):
    """Morphologically open a foreground mask.

    Despite the name, the operation is a binary *opening* (erosion followed by
    dilation) via skimage's ``binary_opening``.

    Args:
        img: when ``local`` is True, a mask that is opened as-is; otherwise a
            grayscale image whose pixels NOT brighter than ``global_threshold``
            become foreground.
        local: skip the global threshold and open ``img`` directly.
        global_threshold: grayscale cut-off used when ``local`` is False.
        show: plot the input next to the opened mask.

    Returns:
        The opened mask returned by ``binary_opening``.
    """
    foreground = img if local else ~(img > global_threshold)
    eroded_img = binary_opening(foreground)

    if show:
        fig, ax = plt.subplots(1, 2, figsize=(18, 12))
        panels = ((img, "Original"), (eroded_img, "Eroded Image"))
        for axis, (panel, title) in zip(ax, panels):
            axis.imshow(panel, cmap=plt.cm.gray)
            axis.set_title(title)

    return eroded_img

In [5]:
def edge_detection(img, show=True):
    """Run Canny edge detection and return a cleaned binary edge map.

    Args:
        img: 2-D image. It is multiplied by 255 before the uint8 cast, so it is
            expected to hold values in [0, 1] (e.g. a boolean mask).
            NOTE(review): a 0-255 image would overflow that cast; confirm the
            intended input range.
        show: plot the input next to the edge map (default True, as before).

    Returns:
        uint8 edge map (0 / 255) after blur + Otsu re-binarisation.
    """
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype (float64).
    im = img.astype(float)
    edge_img = cv2.Canny((im * 255).astype(np.uint8), 10, 100)

    # Gauss Blur (helps with stronger edges)
    edge_img = cv2.GaussianBlur(edge_img, (3, 3), 0)

    # Otsu threshold + binary again
    _, edge_img = cv2.threshold(edge_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    if show:
        fig, ax = plt.subplots(1, 2, figsize=(20, 12))
        ax[0].imshow(img, cmap=plt.cm.gray)
        ax[0].set_title('Original')
        ax[1].imshow(edge_img, cmap=plt.cm.gray)
        ax[1].set_title('Canny Edges')

    return edge_img

In [520]:
def create_overlay_img(img_original, stats, centroids, label_color=None, bee_type='good', show=True, final=False):
    """Draw detections, their link to the queen and their distance on a copy of an image.

    Args:
        img_original: 2-D grayscale or 3-channel image; cast to uint8 (a 0/1
            image is stretched to 0/255). The input itself is not modified.
        stats: iterable of connected-component rows [x, y, width, height, area]
            (cv2.connectedComponentsWithStats layout); None is treated as empty.
        centroids: iterable of (x, y) centroids parallel to ``stats``; None is
            treated as empty.
        label_color: colour for boxes and dots; defaults to green for
            ``bee_type == 'good'`` and red otherwise.
        bee_type: 'good' or 'bad' — only used to pick the default colour.
        show: display the result with matplotlib.
        final: also draw the number of detections and the mean queen distance.

    Returns:
        uint8 RGB overlay image.

    Reads module-level QUEEN_X, QUEEN_Y, LINE_THICKNESS, BOX_SIZE, DOT_SIZE,
    FONT and FONT_SIZE.
    """
    overlay_img = np.copy(img_original).astype(np.uint8)

    # Draw in RGB: promote grayscale input to three channels.
    if overlay_img.ndim == 2:
        overlay_img = cv2.cvtColor(overlay_img, cv2.COLOR_GRAY2RGB)

    # A 0/1 mask would be almost black when displayed; stretch it to 0/255.
    if overlay_img.max() == 1:
        overlay_img *= 255

    if label_color is None:
        # Compare with ==: `is` on strings tests identity (SyntaxWarning on 3.8+).
        label_color = (0, 255, 0) if bee_type == 'good' else (255, 0, 0)

    stats = [] if stats is None else stats
    centroids = [] if centroids is None else centroids
    queen_xy = (QUEEN_X, QUEEN_Y)

    # Queen
    cv2.circle(overlay_img, queen_xy, 60, (255, 0, 0), -1)

    # Workers
    distances = []
    for stat, centroid in zip(stats, centroids):
        # Plain Python ints are the safest input for cv2 point arguments.
        x, y, width, height = (int(v) for v in stat[:4])
        # np.int was removed in NumPy 1.24; int() truncates the same way.
        centroid_xy = (int(centroid[0]), int(centroid[1]))

        cv2.line(overlay_img, centroid_xy, queen_xy, (102, 0, 255), LINE_THICKNESS)
        cv2.rectangle(overlay_img, (x, y), (x + width, y + height), label_color, BOX_SIZE)
        cv2.circle(overlay_img, centroid_xy, DOT_SIZE, label_color, -1)

        dx = QUEEN_X - centroid_xy[0]
        dy = QUEEN_Y - centroid_xy[1]
        distance_i = int(np.sqrt(dx**2 + dy**2))
        distances.append(distance_i)
        distance_str = f"{distance_i} px"
        cv2.putText(overlay_img, distance_str, centroid_xy, FONT, FONT_SIZE, (255, 255, 255), 2, cv2.LINE_AA)

    if final:
        num_detections = len(centroids)
        cv2.putText(overlay_img, f"Num Detections: {num_detections}", (150, 150), FONT, 2, (255, 255, 255), 4, cv2.LINE_AA)

        # np.mean([]) is NaN and int(NaN) raises, so handle "no detections" explicitly.
        if distances:
            ave_distance_str = f"Ave Distance: {int(np.mean(distances))} px"
        else:
            ave_distance_str = "Ave Distance: n/a"
        cv2.putText(overlay_img, ave_distance_str, (150, 210), FONT, 2, (255, 255, 255), 4, cv2.LINE_AA)

    if show:
        fig, ax = plt.subplots(figsize=(12, 12))
        ax.imshow(overlay_img)

    return overlay_img

In [521]:
def run_connected_components(img, min_factor=0.00001, max_factor=0.3, prev_stats=None):
    """Label connected components and keep those within an area band.

    Args:
        img: binary image (cast to uint8; non-zero pixels are foreground).
        min_factor / max_factor: a component is kept when
            ``min_factor * H * W < area < max_factor * H * W``.
        prev_stats: stat row [x, y, ...] of the region ``img`` was cropped from,
            in full-frame coordinates. When given, the returned 'global' stats and
            the centroids are shifted by its (x, y) so they are full-frame coordinates.

    Returns:
        ``(stats, centroids)`` where ``stats`` is ``{"local": rows, "global": rows}``
        (cv2 rows [x, y, width, height, area]; both entries are the same array
        when ``prev_stats`` is None), or ``(None, None)`` if nothing passes the filter.
    """
    img = np.copy(img).astype(np.uint8)

    # Run connected components
    num_regions, regions, stats, centroids = cv2.connectedComponentsWithStats(img)

    # Filter out by area (np.product is deprecated; np.prod is the supported name)
    image_area = np.prod(img.shape[:2])
    min_area = image_area * min_factor
    max_area = image_area * max_factor

    areas = stats[:, cv2.CC_STAT_AREA]
    condition = (areas < max_area) & (areas > min_area)
    # Label 0 is always the background in OpenCV; never report it as a detection.
    condition[0] = False

    new_stats = stats[condition]
    new_centroids = centroids[condition]

    if len(new_centroids) == 0:
        return None, None

    if prev_stats is None:
        global_stats = new_stats
    else:
        # Boolean indexing already produced copies, so shifting in place is safe.
        global_stats = np.copy(new_stats)
        global_stats[:, :2] += prev_stats[:2]
        new_centroids[:, 0] += prev_stats[0]
        new_centroids[:, 1] += prev_stats[1]

    stats = {
        "local"  : new_stats,
        "global" : global_stats
    }
    return stats, new_centroids

In [522]:
def get_squarish_rows_cols(num):
    """Pick a near-square (rows, cols) grid with exactly ``num`` cells.

    ``rows`` is the largest divisor of ``num`` in [2, sqrt(num)] and
    ``cols = num // rows`` (so rows <= cols and rows * cols == num). When no
    such divisor exists (num < 4 or num prime) the grid is a single column:
    ``(num, 1)``.

    The previous implementation returned e.g. (4, 8) for 16, because it picked
    the divisor next to the square root instead of its complement.
    """
    rows = None
    divisor = 2
    while divisor * divisor <= num:
        if num % divisor == 0:
            rows = divisor
        divisor += 1

    if rows is None:
        return num, 1
    return rows, num // rows

In [523]:
def crop_bees(img, stats):
    """Cut the bounding box of every detection out of an image.

    Args:
        img: image to crop from; it is copied, so the crops never alias the
            caller's array.
        stats: iterable of rows [x, y, width, height, ...] in ``img`` coordinates
            (cv2.connectedComponentsWithStats layout).

    Returns:
        A numpy array of crops: a regular ``(N, h, w, ...)`` array when all crops
        share a shape, otherwise a 1-D object array holding the individual crops.
    """
    src_img = np.copy(img)
    cropped_imgs = []
    for stat in stats:
        x, y, width, height = stat[:4]
        cropped_imgs.append(src_img[y:y + height, x:x + width])

    try:
        return np.array(cropped_imgs)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from ragged crops (older NumPy
        # silently made an object array). Build that object array explicitly
        # rather than returning None, which made callers skip every crop.
        ragged = np.empty(len(cropped_imgs), dtype=object)
        for crop_i, crop in enumerate(cropped_imgs):
            ragged[crop_i] = crop
        return ragged

In [524]:
def plot_cropped_imgs(cropped_imgs, stats):
    """Show each crop in its own subplot, titled with its index and area.

    The grid comes from get_squarish_rows_cols. A single row or column of more
    than 10 crops is stacked vertically on a taller figure; a shorter strip is
    laid out horizontally.

    Args:
        cropped_imgs: sequence of crops (e.g. from crop_bees); nothing is drawn
            when it is None or empty.
        stats: rows parallel to ``cropped_imgs``; the last column (area) goes in
            the subplot title.
    """
    if cropped_imgs is None or len(cropped_imgs) == 0:
        return

    num_imgs = len(cropped_imgs)
    nrows, ncols = get_squarish_rows_cols(num_imgs)

    figsize = (12, 12)
    if min(nrows, ncols) == 1:
        strip_len = max(nrows, ncols)
        if strip_len > 10:
            grid = (strip_len, 1)
            figsize = (12, 12 + 2 * strip_len)
        else:
            grid = (1, strip_len)
    else:
        grid = (nrows, ncols)

    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid (one crop)
    # is indexable like every other layout.
    fig, ax = plt.subplots(*grid, figsize=figsize, squeeze=False)
    axes = ax.ravel()
    for img_i, (ax_i, crop) in enumerate(zip(axes, cropped_imgs)):
        ax_i.imshow(crop, cmap=plt.cm.gray)
        ax_i.set_title(f"Img {img_i}\nArea: {stats[img_i][-1]}")

    # Blank out any grid cells that did not receive a crop.
    for ax_i in axes[num_imgs:]:
        ax_i.axis('off')

    plt.tight_layout()

In [525]:
def sort_bees(stats, centroids, area_threshold=None):
    """Split detections into single bees ("good") and oversized blobs ("bad") by area.

    Args:
        stats: ``{"local": rows, "global": rows}`` as built by
            run_connected_components (rows [x, y, width, height, area]), or None.
        centroids: centroids parallel to the stat rows, or None.
        area_threshold: a detection whose global area is <= this value is "good".
            Defaults to the module-level SINGLE_BEE_AREA_THRESHOLD (read at call time).

    Returns:
        ``(good_bees, bad_bees)``. Each is a dict
        ``{"stats": {"global": [...], "local": [...]}, "centroids": [...]}`` of
        Python lists, or None when that group is empty. ``(None, None)`` is
        returned when ``centroids`` is None (run_connected_components found nothing).
    """
    if centroids is None:
        return None, None
    if area_threshold is None:
        area_threshold = SINGLE_BEE_AREA_THRESHOLD

    good_bees = {"stats" : {"global" : [], "local" : []}, "centroids" : []}
    bad_bees = {"stats" : {"global" : [], "local" : []}, "centroids" : []}

    for i, centroid in enumerate(centroids):
        local_stat = stats['local'][i]
        global_stat = stats['global'][i]
        area = global_stat[-1]
        container = good_bees if area <= area_threshold else bad_bees

        container['stats']['global'].append(global_stat)
        container['stats']['local'].append(local_stat)
        container['centroids'].append(centroid)

    if not good_bees['stats']['global']:
        good_bees = None
    if not bad_bees['stats']['global']:
        bad_bees = None

    return good_bees, bad_bees

In [526]:
def update_bees(container, bees):
    """Append a batch of detections to an accumulator dict, in place.

    Args:
        container: ``{"stats": {"local": arr|None, "global": arr|None},
            "centroids": arr|None}``. None entries mean "empty" and are replaced
            by the first batch; later batches are concatenated along axis 0.
        bees: dict with the same layout whose values are lists/arrays of rows,
            or None. None is a no-op, because sort_bees returns None for an
            empty group.
    """
    if bees is None:
        return

    local_bee_stats = np.array(bees['stats']['local'])
    global_bee_stats = np.array(bees['stats']['global'])
    bee_centroids = np.array(bees['centroids'])

    if container['stats']['local'] is not None:
        container['stats']['local'] = np.concatenate([container['stats']['local'], local_bee_stats], axis=0)
        container['stats']['global'] = np.concatenate([container['stats']['global'], global_bee_stats], axis=0)
        container['centroids'] = np.concatenate([container['centroids'], bee_centroids], axis=0)
    else:
        container['stats']['local'] = local_bee_stats
        container['stats']['global'] = global_bee_stats
        container['centroids'] = bee_centroids


In [527]:
def update_overlay(overlay_img, color_i, all_good_bees, new_bad_bee_batch, bee_types='good'):
    """Draw the accepted bees, plus any pending oversized blobs, onto an overlay.

    Args:
        overlay_img: image to draw on (create_overlay_img copies it).
        color_i: colour for the accepted bees' boxes and dots.
        all_good_bees: accumulator ``{"stats": {"global": ..., "local": ...},
            "centroids": ...}``; skipped while its entries are still None.
        new_bad_bee_batch: same layout; drawn with the default 'bad' colour when
            it has centroids.
        bee_types: 'bad' or anything else (treated as 'good'); forwarded as
            ``bee_type``, which only picks the default colour when color_i is None.

    Returns:
        The updated overlay image.
    """
    # Always bind bee_type; only binding it for 'bad' raised UnboundLocalError
    # for the default 'good'.
    bee_type = 'bad' if bee_types == 'bad' else 'good'

    if all_good_bees['centroids'] is not None:
        overlay_img = create_overlay_img(overlay_img, bee_type=bee_type, show=False,
                                         label_color=color_i,
                                         stats=all_good_bees['stats']['global'],
                                         centroids=all_good_bees['centroids'])

    if new_bad_bee_batch['centroids'] is not None:
        overlay_img = create_overlay_img(overlay_img, bee_type='bad', show=False,
                                         stats=new_bad_bee_batch['stats']['global'],
                                         centroids=new_bad_bee_batch['centroids'])
    return overlay_img

In [528]:
class VideoLoader:
    """Iterate over the frames of a video file, converted to RGB or grayscale.

    Frames are counted from 1: ``frame_i`` is incremented after each read and
    before any check. A frame is yielded only when ``frame_i % img_skip == 0``
    and ``frame_i >= start_i``. Iteration stops when ``img_limit`` frames have
    been yielded, when ``frame_i >= end_i``, or when the stream returns no frame.
    In the last case the capture is released.

    Args:
        vid_path: path handed to cv2.VideoCapture.
        color: True -> RGB frames, False -> single-channel grayscale.
        img_limit: maximum number of frames to yield (None or 0 = unlimited).
        img_skip: yield every ``img_skip``-th frame.
        start_i: smallest frame counter value that may be yielded.
        end_i: frame counter value at which iteration stops (exclusive), or None.
    """

    def __init__(self, vid_path, color=True, img_limit=None, img_skip=1, start_i=0, end_i=None):
        self.vid_path = vid_path
        self.img_limit = img_limit
        self.img_skip = img_skip
        self.start_i = start_i
        self.end_i = end_i
        self.color_xform = cv2.COLOR_BGR2RGB if color else cv2.COLOR_BGR2GRAY

        self._open_stream(vid_path)
        self.num_images_loaded = 0

    def _open_stream(self, vid_path):
        self.cap = cv2.VideoCapture(vid_path)

    def release(self):
        """Release the underlying capture; further reads return no frames."""
        self.cap.release()

    def __iter__(self):
        # Counters are reset, but the capture is not rewound: a second iteration
        # continues from the current stream position.
        self.frame_i = 0
        self.num_images_loaded = 0
        return self

    def __next__(self):
        # A loop rather than recursion over skipped frames, so a large start_i or
        # img_skip cannot exceed the interpreter's recursion limit.
        while True:
            ret, frame = self.cap.read()
            self.frame_i += 1

            # Check for image limit / end frame
            if (self.img_limit and self.num_images_loaded >= self.img_limit) or (self.end_i is not None and self.frame_i >= self.end_i):
                raise StopIteration

            # End of stream: nothing more will be read, so free the capture.
            if frame is None:
                self.release()
                raise StopIteration

            # Check image skip / start frame
            if (self.frame_i % self.img_skip != 0) or (self.frame_i < self.start_i):
                continue

            self.num_images_loaded += 1
            return cv2.cvtColor(frame, self.color_xform)

In [529]:
def run(img):
    """Detect bees in one grayscale frame and return an annotated RGB overlay.

    Pipeline (tuning constants are module-level globals from the config cell):
      1. Binarise and open the frame (get_erroded_img with LOCAL / GLOBAL_TRHESHOLD).
      2. Connected components, filtered by MIN_FACTOR / MAX_FACTOR of the frame area.
      3. Components with area <= SINGLE_BEE_AREA_THRESHOLD are accepted. Larger
         blobs are cropped and re-segmented for up to NUM_ITERATIONS passes. A
         blob whose crop yields no component is accepted as-is.
      4. Accepted bees are drawn with queen distances and summary text. Blobs
         still unresolved after the last pass are drawn with the 'bad' colour.

    Args:
        img: 2-D grayscale frame.

    Returns:
        uint8 RGB overlay image.
    """
    all_good_bees = {"stats" : {"global" : None, "local" : None}, "centroids" : None}

    # Keep the binarised frame under its own name: every refinement pass crops
    # from this full-frame mask.
    binary_img = get_erroded_img(np.copy(img), local=LOCAL, global_threshold=GLOBAL_TRHESHOLD, show=False)

    new_stats, new_centroids = run_connected_components(binary_img, min_factor=MIN_FACTOR, max_factor=MAX_FACTOR)
    if new_centroids is None:
        # Nothing survived the area filter: return the frame with only the queen drawn.
        return create_overlay_img(img, stats=[], centroids=[], show=False)

    good_bees, bad_bees = sort_bees(new_stats, new_centroids)
    if good_bees is not None:
        update_bees(all_good_bees, good_bees)

    overlay_img = np.copy(img)

    for i in range(NUM_ITERATIONS):
        if bad_bees is None:
            # Every blob has been resolved.
            break

        color_i = np.array(cmap(i/NUM_ITERATIONS)[:3])*255
        batch_stats = bad_bees['stats']
        batch_centroids = bad_bees['centroids']

        # 'global' stats are full-frame coordinates (identical to 'local' on the
        # first pass), so they index binary_img correctly on every pass.
        cropped_imgs = crop_bees(binary_img, batch_stats['global'])
        if cropped_imgs is None:
            continue

        new_bad_bee_batch = {"stats" : {"global" : None, "local" : None}, "centroids" : None}
        for crop_i, cropped_img in enumerate(cropped_imgs):
            prev_centroids = batch_centroids[crop_i]
            prev_stats = {
                "local"  : batch_stats['local'][crop_i],
                "global" : batch_stats['global'][crop_i],
            }

            # NOTE(review): cropped_img is cut from the binarised mask, so thresholding
            # it again at get_erroded_img's default (70) makes it all foreground and
            # re-segmentation rarely splits a blob. Confirm whether crops should come
            # from the grayscale `img` instead.
            crop_mask = get_erroded_img(np.copy(cropped_img), local=False, show=False)
            new_stats_i, new_centroids_i = run_connected_components(crop_mask, prev_stats=prev_stats['global'])

            if new_centroids_i is None:
                # The blob cannot be split further: accept it as a detection.
                prev_stats['local'] = np.expand_dims(prev_stats['local'], axis=0)
                prev_stats['global'] = np.expand_dims(prev_stats['global'], axis=0)
                prev_centroids = np.expand_dims(prev_centroids, axis=0)
                bad_bees_end = {
                    "stats"     : prev_stats,
                    "centroids" : prev_centroids
                }
                update_bees(all_good_bees, bad_bees_end)
                overlay_img = update_overlay(overlay_img, color_i, all_good_bees, new_bad_bee_batch, bee_types='bad')
                continue

            new_good_bees, new_bad_bees = sort_bees(new_stats_i, new_centroids_i)

            if new_good_bees is not None:
                update_bees(all_good_bees, new_good_bees)
            if new_bad_bees is not None:
                update_bees(new_bad_bee_batch, new_bad_bees)

            overlay_img = update_overlay(overlay_img, color_i, all_good_bees, new_bad_bee_batch)

        if new_bad_bee_batch['centroids'] is not None:
            bad_bees = new_bad_bee_batch
        else:
            bad_bees = None
            break

    remaining_bad_bees = bad_bees

    # Plot overlays. Always go through create_overlay_img so the result is RGB.
    if all_good_bees['centroids'] is not None:
        overlay_img = create_overlay_img(overlay_img, bee_type='good', show=False, final=True,
                                         stats=all_good_bees['stats']['global'],
                                         centroids=all_good_bees['centroids'])
    else:
        overlay_img = create_overlay_img(overlay_img, stats=[], centroids=[], show=False)

    if remaining_bad_bees is not None:
        overlay_img = create_overlay_img(overlay_img, bee_type='bad', show=False,
                                         stats=remaining_bad_bees['stats']['global'],
                                         centroids=remaining_bad_bees['centroids'])

    return overlay_img


In [579]:
movie_name = "full_bees_trim_2.mp4" # "full_bees.mp4", "bees.mov"

# Region of interest (rows, cols) cropped from every frame, per movie.
slices = {
    "bees.mov" : {
        "slice_r" : slice(0,550),
        "slice_c" : slice(100,None),
    },
    "full_bees_trim.mp4" : {
        "slice_r" : slice(100,1800), 
        "slice_c" : slice(850, 2650)
    },
    "full_bees_trim_2.mp4" : {
        "slice_r" : slice(100,1800), 
        "slice_c" : slice(850, 2650)
    }
}

# Queen position in ROI pixel coordinates.
# NOTE(review): "bees.mov" has no entry, so selecting it raises KeyError below.
queen_positions = {
    "full_bees_trim.mp4" : {
        "x" : 1410,
        "y" : 360
    },
    "full_bees_trim_2.mp4" : {
        "x" : 1410,
        "y" : 360
    }
}

slice_r = slices[movie_name]['slice_r']
slice_c = slices[movie_name]['slice_c']

# Process frames with START_I <= index < STOP_I (0-based enumeration of the loader).
START_I = 0
STOP_I = 10

QUEEN_X = queen_positions[movie_name]['x']
QUEEN_Y = queen_positions[movie_name]['y']

LINE_THICKNESS = 1

SINGLE_BEE_AREA_THRESHOLD = 400

LOCAL = False
NUM_ITERATIONS = 4
cmap = plt.cm.cool
MIN_FACTOR = 0.0001
MAX_FACTOR = 0.025
BOX_SIZE = 2
DOT_SIZE = 5
GLOBAL_TRHESHOLD = 75

FONT = cv2.FONT_HERSHEY_SIMPLEX
FONT_SIZE = 1

NUM_ITERS = 10  # NOTE(review): not referenced by any visible cell
overlay_imgs = []

vid = VideoLoader(f"../src_data/Orit/{movie_name}") 

try:
    for img_i, img in enumerate(vid):

        if img_i < START_I:
            sys.stdout.write('\rWaiting...')
            sys.stdout.flush()
            continue
        elif img_i >= STOP_I:
            break
        img = img[slice_r, slice_c]
        # VideoLoader defaults to color=True, so frames are already RGB;
        # converting with BGR2GRAY would swap the red/blue luminance weights.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        try:
            overlay_img = run(img) 
        except Exception as e:
            # Keep going with the remaining frames, but report which one failed.
            print(f"\nE (frame {img_i}): {e}")
            continue
        overlay_imgs.append(overlay_img)

        sys.stdout.write(f'\rImage {img_i+1}')
        sys.stdout.flush()

except KeyboardInterrupt:
    print("\nKeyboard Interrupt. Finished.")
finally:
    # Free the capture even when stopping early (STOP_I, Ctrl-C or an error).
    vid.cap.release()


Image 10

In [580]:
def imgs2vid(imgs, outpath, fps=12):
    """Write a sequence of frames to a video file with the 'mp4v' codec.

    Args:
        imgs: non-empty indexable sequence of equally-sized 3-channel uint8 frames
            in BGR order (OpenCV's convention); the first frame sets the size.
        outpath: destination file path.
        fps: frames per second.

    Raises:
        ValueError: if ``imgs`` is empty.
        OSError: if OpenCV cannot open a writer for ``outpath``.
    """
    if len(imgs) == 0:
        raise ValueError("imgs2vid: no frames to write")

    height, width = imgs[0].shape[0:2]

    fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
    video = cv2.VideoWriter(outpath, fourcc, fps, (width, height), True)
    if not video.isOpened():
        raise OSError(f"imgs2vid: could not open a video writer for {outpath!r}")

    # No cv2.destroyAllWindows(): no windows are opened here, and headless
    # OpenCV builds raise on that call. Always release so the file is finalised.
    try:
        for img in imgs:
            video.write(img)
    finally:
        video.release()

In [581]:
# create_overlay_img draws in RGB, while OpenCV's VideoWriter expects BGR:
# swap the channel order of every frame. NOTE: not idempotent — re-running
# this cell swaps the channels back.
overlay_imgs = [cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) for frame_rgb in overlay_imgs]

In [582]:
# Number of annotated frames collected by the processing loop (shown as the cell output).
len(overlay_imgs)

10

In [583]:
# Writes to the current working directory under the same file name as the
# source movie (the source itself is read from ../src_data/Orit/).
imgs2vid(overlay_imgs, movie_name)

In [584]:
# Open the working directory to inspect the written video.
# NOTE(review): `open` is the macOS command; on Linux use `xdg-open .`, on Windows `explorer .`.
!open .