In [1]:
import cv2
from cv2 import VideoCapture
import numpy as np
from tqdm import tqdm
from typing import Callable, Tuple

In [2]:
def is_bbox(array: np.ndarray) -> bool:
    """Return True if ``array`` is a valid bounding-box array.

    A valid array is a 2-D NumPy array of shape (N, 4) with a signed-integer dtype.
    Each row is read elsewhere in this notebook as ``(x, y, w, h)``. N may be 0.

    Any signed integer width (int32, int64, ...) is accepted. Comparing against
    the builtin ``int`` only matched the platform default width. Unsigned dtypes
    are rejected because the box arithmetic can produce negative coordinates.

    Args:
        array: candidate array; ``None`` or non-ndarray values yield False.

    Returns:
        True when ``array`` satisfies the shape and dtype requirements.
    """
    if not isinstance(array, np.ndarray):
        return False
    return (
        array.ndim == 2
        and array.shape[1] == 4
        and np.issubdtype(array.dtype, np.signedinteger)
    )

In [3]:
def to_grayscale(image):
    """Convert a BGR image to single-channel grayscale.

    A 2-D array is taken to be grayscale already and is returned unchanged.
    Without this check, ``cvtColor`` with ``COLOR_BGR2GRAY`` raises on
    single-channel input. Any other input is passed to OpenCV as BGR.

    Args:
        image: BGR image (H, W, 3) or grayscale image (H, W).

    Returns:
        A single-channel (H, W) image.
    """
    if image.ndim == 2:
        return image
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

In [12]:
# TODO: IMPROVE
def find_boxes(
    image: np.ndarray,
    threshold: int = 160,
    kernel_size: int = 5,
    min_area: float = 100,
    max_area: float = 2000,
) -> np.ndarray:
    """Detect dark blobs in ``image`` and return their bounding boxes.

    Steps:
    1. Pixels with value <= ``threshold`` become foreground (inverse binary threshold).
    2. A morphological opening with a ``kernel_size`` square kernel removes small specks.
    3. The outer contours are found.
    4. Contours with ``min_area <= area <= max_area`` are kept.

    OpenCV's ``findContours`` needs a single-channel 8-bit image, so ``image``
    should already be grayscale, e.g. the output of ``to_grayscale``.

    Args:
        image: single-channel 8-bit image to search.
        threshold: intensity cutoff; darker-or-equal pixels count as foreground.
        kernel_size: side length of the square opening kernel.
        min_area: smallest contour area (inclusive) to keep.
        max_area: largest contour area (inclusive) to keep.

    Returns:
        (N, 4) integer array of ``(x, y, w, h)`` boxes. When nothing is found it is
        an empty (0, 4) array, so downstream ``is_bbox`` checks still pass.
    """
    _, mask = cv2.threshold(image, threshold, 255, cv2.THRESH_BINARY_INV)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    opened_mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    # [-2] selects the contour list on both OpenCV 3.x (3 return values)
    # and OpenCV 4.x (2 return values).
    contours = cv2.findContours(opened_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    bbox_list = [
        cv2.boundingRect(c)
        for c in contours
        if min_area <= cv2.contourArea(c) <= max_area
    ]

    # reshape keeps the empty case 2-D, giving shape (0, 4) rather than (0,).
    return np.asarray(bbox_list, dtype=int).reshape(-1, 4)

In [5]:
def resize_boxes(bboxes: np.ndarray, new_width: int, new_height: int) -> np.ndarray:
    """Re-size every box to ``new_width`` x ``new_height`` around its own centre.

    The centre of each input box is ``(x + w // 2, y + h // 2)``. The new top-left
    corner is that centre minus half the new size, using integer division.

    Args:
        bboxes: (N, 4) integer array of ``(x, y, w, h)`` rows.
        new_width: width given to every output box.
        new_height: height given to every output box.

    Returns:
        (N, 4) array of re-centred ``(x, y, new_width, new_height)`` rows.
    """
    assert is_bbox(bboxes)

    count = bboxes.shape[0]
    widths = np.full(shape=[count], fill_value=new_width)
    heights = np.full(shape=[count], fill_value=new_height)

    left, top, box_w, box_h = bboxes.T
    center_x = left + box_w // 2
    center_y = top + box_h // 2

    return np.column_stack(
        (
            center_x - widths // 2,
            center_y - heights // 2,
            widths,
            heights,
        )
    )

In [6]:
def sanitize_boxes(bboxes: np.ndarray, image_shape: tuple[int, ...]) -> np.ndarray:
    """Keep only boxes that have a positive size and lie fully inside the image.

    A box ``(x, y, w, h)`` covers columns ``x .. x + w - 1`` and rows
    ``y .. y + h - 1``, which is exactly what ``image[y:y + h, x:x + w]`` selects.
    The box is therefore in bounds when ``x, y >= 0``, ``x + w <= width`` and
    ``y + h <= height``. Boxes flush with the right or bottom edge are kept.

    Args:
        bboxes: (N, 4) integer array of ``(x, y, w, h)`` rows.
        image_shape: image shape as given by ``ndarray.shape``, i.e.
            ``(height, width, ...)``; extra trailing dimensions are ignored.

    Returns:
        The rows of ``bboxes`` that pass the checks, in their original order.
    """
    assert is_bbox(bboxes)

    height, width = image_shape[0], image_shape[1]
    x, y, w, h = bboxes.T

    good_mask = (x >= 0) & (y >= 0) & (w > 0) & (h > 0)
    good_mask &= (x + w <= width) & (y + h <= height)

    return bboxes[good_mask]

In [7]:
def remove_overlapping_boxes(bboxes: np.ndarray, bboxes_other: np.ndarray = None) -> np.ndarray:
    """Drop every box that takes part in an overlap.

    Boxes are ``(x, y, w, h)`` rows whose edges are treated as inclusive, so
    boxes that merely touch also count as overlapping.

    If ``bboxes_other`` is not None, overlap is checked against these boxes.
    Row ``i`` of ``bboxes_other`` is paired with row ``i`` of ``bboxes``, and
    box ``i`` is compared with other-box ``j`` for every ``i != j``. Otherwise
    ``bboxes`` is compared with itself. When a pair ``(i, j)`` overlaps, both
    row ``i`` and row ``j`` of ``bboxes`` are removed.

    Args:
        bboxes: (N, 4) integer array of boxes to filter.
        bboxes_other: optional (N, 4) integer array paired row-for-row with ``bboxes``.

    Returns:
        The rows of ``bboxes`` not involved in any overlap, in original order.
    """
    assert is_bbox(bboxes)
    assert bboxes_other is None or (is_bbox(bboxes_other) and bboxes.shape == bboxes_other.shape)

    if bboxes_other is None:
        bboxes_other = bboxes

    # Limits of `bboxes` as (N, 1) columns and of `bboxes_other` as (N,) rows,
    # so the comparisons broadcast to an (N, N) matrix indexed [i, j].
    left = bboxes[:, [0]]
    top = bboxes[:, [1]]
    right = left + bboxes[:, [2]]
    bottom = top + bboxes[:, [3]]

    left_other = bboxes_other[:, 0]
    top_other = bboxes_other[:, 1]
    right_other = left_other + bboxes_other[:, 2]
    bottom_other = top_other + bboxes_other[:, 3]

    # Closed intervals [a0, a1] and [b0, b1] intersect iff a0 <= b1 and b0 <= a1.
    # Unlike a corner-inside test, this also catches cross-shaped overlaps.
    overlap = (
        (left <= right_other) & (left_other <= right)
        & (top <= bottom_other) & (top_other <= bottom)
    )

    # A box is never compared with itself (or with its own paired box).
    np.fill_diagonal(overlap, False)

    # Pair (i, j) disqualifies both row i and row j.
    involved = overlap.any(axis=1) | overlap.any(axis=0)
    return bboxes[~involved]

In [8]:
def extract_regions(image: np.ndarray, bboxes: np.ndarray) -> list[np.ndarray]:
    """Crop ``image`` to each ``(x, y, w, h)`` box.

    Returns one crop per row of ``bboxes``, in the same order. Each crop is a
    basic NumPy slice, i.e. a view into ``image`` rather than a copy.
    """
    assert is_bbox(bboxes)

    return [
        image[top : top + height, left : left + width]
        for left, top, width, height in bboxes
    ]

In [9]:
def extract_worms(
    image: np.ndarray,
    image_transform: Callable[[np.ndarray], np.ndarray] = None,
    camera_size: int = 350,
    worm_size: int = 150,
) -> Tuple[np.ndarray, list[np.ndarray]]:
    """Locate worms in ``image`` and crop a fixed-size square region around each.

    Steps:
    1. Detect candidate boxes with ``find_boxes``.
    2. Re-centre a ``camera_size`` square and a ``worm_size`` square on each detection.
    3. Drop detections whose camera square overlaps another detection's worm square.
    4. Drop camera squares that leave the image.
    5. Crop the surviving squares.

    Args:
        image: frame to search. Crops are always taken from this image.
        image_transform: optional preprocessing applied only to the copy used for
            detection, e.g. ``to_grayscale``. When None, ``image`` is used
            directly and must already be acceptable to ``find_boxes``.
        camera_size: side length, in pixels, of each returned crop box.
        worm_size: side length of the smaller square used for the overlap test.

    Returns:
        ``(camera_bboxes, regions)``: the (N, 4) ``(x, y, w, h)`` array of kept
        crop boxes and the list of corresponding crops of ``image``.
    """
    # Fall back to the raw image so the function also works without a transform.
    detect_input = image if image_transform is None else image_transform(image)

    bboxes = find_boxes(detect_input)

    # Both squares share each detection's centre; the smaller one presumably
    # approximates the worm itself.
    camera_bboxes = resize_boxes(bboxes, new_width=camera_size, new_height=camera_size)
    worm_bboxes = resize_boxes(camera_bboxes, worm_size, worm_size)

    # Remove overlapping bboxes between camera-bboxes and worm-bboxes
    camera_bboxes = remove_overlapping_boxes(camera_bboxes, worm_bboxes)

    # Remove bboxes which are out of bounds
    camera_bboxes = sanitize_boxes(camera_bboxes, image.shape)

    regions = extract_regions(image, camera_bboxes)
    return camera_bboxes, regions

In [10]:
# Play the video, detect worms in every frame and show their crop boxes.
# Press "q" in the window to stop early.
VIDEO_PATH = "worms.avi"
WINDOW_NAME = "BBOXES"
BOX_COLOR = (0, 0, 255)  # BGR, i.e. red
BOX_THICKNESS = 3

cap = VideoCapture(VIDEO_PATH)
try:
    # Otherwise a missing or unreadable file silently produces no output.
    if not cap.isOpened():
        raise OSError(f"Could not open video file: {VIDEO_PATH}")

    while True:
        ok, image = cap.read()
        if not ok:
            break  # end of stream or read failure

        coords, rois = extract_worms(image, image_transform=to_grayscale)

        # Draw on a copy: `rois` are slices (views) of `image`, so drawing on
        # `image` itself would paint the rectangles into the extracted regions.
        frame = image.copy()
        for x, y, w, h in coords:
            cv2.rectangle(frame, (int(x), int(y)), (int(x + w), int(y + h)), BOX_COLOR, BOX_THICKNESS)

        cv2.imshow(WINDOW_NAME, frame)
        # waitKey also lets HighGUI render the frame just passed to imshow.
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Always free the capture and close the window, even if processing raised.
    cap.release()
    cv2.destroyAllWindows()

In [11]:
# Safety net: close any OpenCV windows that are still open, e.g. if the
# playback cell was interrupted before reaching its own cleanup.
cv2.destroyAllWindows()