In [None]:
from ultralytics import FastSAM
import cv2
from shapely.geometry import Polygon
import os 

In [None]:
# Load the FastSAM segmentation model from its pretrained weights file
FASTSAM_WEIGHTS = 'FastSAM-x.pt'
model = FastSAM(FASTSAM_WEIGHTS)

In [None]:
# Path of the single image used for the step-by-step walkthrough below
image_path = "input/DJI_0952-2023-11-30-11-10-42.jpg"

# Opening the image. cv2.imread does not raise on a missing/unreadable file,
# it returns None — fail loudly here instead of with an opaque AttributeError
# on `.shape` in the next statement.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# Saving the original width and height (array shape is (rows, cols, ...))
original_h, original_w = image.shape[:2]

In [None]:
# FastSAM "segment everything" inference settings, kept together so they are
# easy to tweak in one place
inference_params = dict(device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
everything_results = model(image, **inference_params)

In [None]:
# Bounding boxes of every segment; each row is indexed as (x1, y1, x2, y2)
boxes = everything_results[0].boxes.xyxy.detach().numpy()

# Truncate every coordinate toward zero, in place (the array keeps its float dtype)
boxes[:, :4] = boxes[:, :4].astype(int)

In [None]:
def _box_corners(box) -> tuple:
    """Return (x_min, y_min, x_max, y_max) for an (x1, y1, x2, y2) box, whatever its corner order."""
    x_min, x_max = sorted((box[0], box[2]))
    y_min, y_max = sorted((box[1], box[3]))
    return x_min, y_min, x_max, y_max


def _box_overlap_area(box_a, box_b):
    """Intersection area of two axis-aligned boxes; 0 when they are disjoint or only touch."""
    ax1, ay1, ax2, ay2 = _box_corners(box_a)
    bx1, by1, bx2, by2 = _box_corners(box_b)
    width = min(ax2, bx2) - max(ax1, bx1)
    height = min(ay2, by2) - max(ay1, by1)
    if width <= 0 or height <= 0:
        return 0
    return width * height


def _box_squared_gap(box_a, box_b):
    """Squared distance between the closest points of two axis-aligned boxes (0 if they touch/overlap)."""
    ax1, ay1, ax2, ay2 = _box_corners(box_a)
    bx1, by1, bx2, by2 = _box_corners(box_b)
    dx = max(ax1 - bx2, bx1 - ax2, 0)
    dy = max(ay1 - by2, by1 - ay2, 0)
    return dx * dx + dy * dy


def get_nearest_polygon(center_coords: list, boxes: list, patch_half_size: int = 25) -> list:
    """
    Return the bounding box that best matches the center of the image.

    A square patch of side ``2 * patch_half_size`` is placed around
    ``center_coords`` and the box with the largest intersection area with that
    patch is selected (ties go to the earliest box). If no box overlaps the
    patch at all, the box closest to the patch is returned instead of an
    arbitrary one.

    Args:
        center_coords: ``[x, y]`` pixel coordinates of the image center.
        boxes: sequence of ``(x1, y1, x2, y2)`` boxes (list of lists or an
            ``(N, 4)`` array).
        patch_half_size: half the side length of the center patch, in pixels.

    Returns:
        The selected box as a list of ``int`` coordinates.

    Raises:
        ValueError: if ``boxes`` is empty.
    """
    if len(boxes) == 0:
        raise ValueError("get_nearest_polygon() needs at least one box")

    # Square patch around the center: (top left x, top left y, bottom right x, bottom right y)
    center_x, center_y = center_coords[0], center_coords[1]
    center_patch = (
        center_x - patch_half_size,
        center_y - patch_half_size,
        center_x + patch_half_size,
        center_y + patch_half_size,
    )

    # Boxes are axis-aligned, so the overlap is plain interval arithmetic
    overlaps = [_box_overlap_area(box, center_patch) for box in boxes]
    # max over indices returns the first index on ties, like list.index(max(...))
    best = max(range(len(boxes)), key=overlaps.__getitem__)

    if overlaps[best] <= 0:
        # Nothing touches the patch: pick the geometrically closest box rather than boxes[0]
        best = min(range(len(boxes)), key=lambda i: _box_squared_gap(boxes[i], center_patch))

    return [int(coord) for coord in boxes[best]]

In [None]:
# Image center as [x, y] in whole pixels
center_coords = [dim // 2 for dim in (original_w, original_h)]

# Pick the detected box that best overlaps a small patch around that center
bbox_highest = get_nearest_polygon(center_coords=center_coords, boxes=boxes)

In [None]:
def adjust_center_coords(center_patch: list, bbox_highest: list) -> list: 
    """
    Slide the center patch across the nearest edge of a bounding box.

    Both arguments are ``[x1, y1, x2, y2]``. The gap between each patch edge and
    the matching box edge is measured (left, right, top, bottom — ties resolve
    in that order). The patch is then shifted toward the smallest gap by that gap
    plus its own size, so that it ends up flush against the box edge on the
    outside.

    Returns:
        A new ``[x1, y1, x2, y2]`` list; the input list is not modified.
    """
    left, top, right, bottom = center_patch[0], center_patch[1], center_patch[2], center_patch[3]
    patch_width = abs(left - right)
    patch_height = abs(top - bottom)

    # Gaps to the box edges, ordered left, right, top, bottom
    gaps = [
        abs(left - bbox_highest[0]),
        abs(right - bbox_highest[2]),
        abs(top - bbox_highest[1]),
        abs(bottom - bbox_highest[3]),
    ]
    side = gaps.index(min(gaps))
    gap = gaps[side]

    if side == 0:
        # Push the patch out past the left edge
        return [left - gap - patch_width, top, right - gap - patch_width, bottom]
    if side == 1:
        # Push the patch out past the right edge
        return [left + gap + patch_width, top, right + gap + patch_width, bottom]
    if side == 2:
        # Push the patch out past the top edge
        return [left, top - gap - patch_height, right, bottom - gap - patch_height]
    # Push the patch out past the bottom edge
    return [left, top + gap + patch_height, right, bottom + gap + patch_height]


In [None]:
def get_new_coords(center_coords: list, highest_overlap_coords: list, image_area: float, overlap_treshold: float = 0.1) -> list:
    """
    Decide whether the center patch has to be moved off the selected box.

    The patch is relocated with ``adjust_center_coords`` only when BOTH hold:
      * the patch lies completely inside the box, and
      * the box covers less than ``overlap_treshold`` of the image area.
    Otherwise the patch is returned unchanged.

    Args:
        center_coords: center patch as ``[x1, y1, x2, y2]``.
        highest_overlap_coords: selected box as ``[x1, y1, x2, y2]``.
        image_area: image area in pixels (width * height). A non-positive value
            disables the relocation.
        overlap_treshold: maximum fraction of the image the box may cover for
            the patch to be moved (name kept as-is for backward compatibility).

    Returns:
        The (possibly moved) patch as ``[x1, y1, x2, y2]``.
    """
    # Normalise both rectangles to (min, max) corners; both are axis-aligned
    px1, px2 = sorted((center_coords[0], center_coords[2]))
    py1, py2 = sorted((center_coords[1], center_coords[3]))
    bx1, bx2 = sorted((highest_overlap_coords[0], highest_overlap_coords[2]))
    by1, by2 = sorted((highest_overlap_coords[1], highest_overlap_coords[3]))

    # Exact containment test. It replaces "intersection area / patch area == 1.0",
    # which relied on float equality and divided by zero for a zero-size patch.
    patch_inside_box = bx1 <= px1 and by1 <= py1 and px2 <= bx2 and py2 <= by2

    # Fraction of the image covered by the box (guard against a zero image area)
    box_fraction = (bx2 - bx1) * (by2 - by1) / image_area if image_area > 0 else float("inf")

    # Only a patch sitting fully inside a box that is small relative to the image is moved
    if patch_inside_box and box_fraction < overlap_treshold:
        return adjust_center_coords(center_coords, highest_overlap_coords)
    return center_coords

In [None]:
# Square patch of half-size 25 px around the image center, as [x1, y1, x2, y2]
half_patch = 25
mid_x, mid_y = original_w // 2, original_h // 2
center_patch = [mid_x - half_patch, mid_y - half_patch, mid_x + half_patch, mid_y + half_patch]

# Let get_new_coords decide whether the patch must be moved off the selected box
center_patch = get_new_coords(center_patch, bbox_highest, image_area=original_w * original_h)

# Draw the final patch in blue (OpenCV colours are BGR), 2 px thick
patch_top_left = (center_patch[0], center_patch[1])
patch_bottom_right = (center_patch[2], center_patch[3])
image = cv2.rectangle(image, patch_top_left, patch_bottom_right, (255, 0, 0), 2)

In [None]:
# Saving the annotated walkthrough image. cv2.imwrite reports failure by
# returning False rather than raising, so check the result explicitly.
if not cv2.imwrite("output.jpg", image):
    raise OSError("Failed to write output.jpg")


In [None]:
def pipeline(
        image_path: str, 
        center_path_offset: int, 
        model: FastSAM, 
        output_dir: str
        ) -> None: 
    """
    Annotate one image with its (possibly relocated) center patch and save it.

    Steps:
      1. Read the image. Unreadable files are skipped with a message, so that a
         stray non-image file does not abort a whole batch.
      2. Run FastSAM and collect the xyxy boxes of all segments.
      3. Pick the box that best overlaps the image center
         (``get_nearest_polygon``) and draw it in red.
      4. Draw the original square center patch in red. Let ``get_new_coords``
         decide whether to move it, then draw the final patch in blue.
      5. Write the annotated image to ``output_dir`` under the same file name.

    If FastSAM returns no boxes, only the unmoved center patch is drawn.

    Args:
        image_path: path of the input image.
        center_path_offset: half the side length of the center patch, in pixels.
        model: loaded FastSAM model used for inference.
        output_dir: existing directory the annotated image is written to.

    Raises:
        OSError: if the annotated image cannot be written.
    """
    # Opening the image; cv2.imread returns None instead of raising on failure
    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping {image_path}: not a readable image")
        return

    # Saving the original w and h
    original_h, original_w = image.shape[:2]
    center_x, center_y = original_w // 2, original_h // 2

    # Run inference on the image
    everything_results = model(image, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    # Segment boxes truncated to whole pixels (toward zero, like int())
    detections = everything_results[0].boxes
    boxes = [] if detections is None else detections.xyxy.detach().cpu().numpy().astype(int)
    has_boxes = len(boxes) > 0

    if has_boxes:
        # NOTE(review): get_nearest_polygon selects the box with its own center-patch
        # size, which is not tied to center_path_offset.
        bbox_highest = get_nearest_polygon([center_x, center_y], boxes)
        # Selected box in red (BGR), 4 px thick
        image = cv2.rectangle(image, (bbox_highest[0], bbox_highest[1]), (bbox_highest[2], bbox_highest[3]), (0, 0, 255), 4)
    else:
        print(f"No segments found in {image_path}; drawing the original center patch only")

    # Original center patch in red
    center_patch = [
        center_x - center_path_offset,
        center_y - center_path_offset,
        center_x + center_path_offset,
        center_y + center_path_offset,
    ]
    image = cv2.rectangle(image, (center_patch[0], center_patch[1]), (center_patch[2], center_patch[3]), (0, 0, 255), 2)

    if has_boxes:
        # Move the patch off the selected box when get_new_coords asks for it
        center_patch = get_new_coords(center_patch, bbox_highest, image_area=original_w * original_h)

    # Final center patch in blue
    image = cv2.rectangle(image, (center_patch[0], center_patch[1]), (center_patch[2], center_patch[3]), (255, 0, 0), 2)

    # Saving the image; cv2.imwrite returns False on failure instead of raising
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    if not cv2.imwrite(output_path, image):
        raise OSError(f"Failed to write annotated image to {output_path}")

In [None]:
# Listing the images in the input dir. Only regular files with a common image
# extension are kept, so that stray entries (e.g. .DS_Store, sub-directories)
# never reach cv2.imread. The list is sorted to give a reproducible order.
input_dir = "input/"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
images = sorted(
    name for name in os.listdir(input_dir)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(input_dir, name))
)

# Creating the output dir (no-op if it already exists)
output_dir = "output/"
os.makedirs(output_dir, exist_ok=True)

# Iterating over the images. The loop variable is `image_name` so that it does
# not overwrite the `image` array defined in earlier cells.
for image_name in images:
    image_path = os.path.join(input_dir, image_name)
    pipeline(image_path, 25, model, output_dir)