# Install Dependencies

In [None]:
#pip install segment-anything
#pip install yolo5

In [None]:
import os
import cv2
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import matplotlib
matplotlib.use('TkAgg')
import PIL
from PIL import Image

from segment_anything import build_sam, SamPredictor
from yolov5 import YOLOv5 

# Auxillary Function

In [89]:
# Helper functions to show box / mask taken and modified from facebookresearch/segment-anything repository

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
    

def show_mask(mask, image, ax, random_color = True, return_sticker = False):
    h, w = mask.shape[-2:]
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    
    if return_sticker == False:
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([30/255, 144/255, 255/255, 0.6])
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)
        mask_image_pil = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")

        return np.asarray(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
    
    else:
        mask_image = mask.reshape(h, w, 1) * annotated_frame_pil
        mask_image_pil = Image.fromarray((mask_image).astype(np.uint8)).convert("RGBA")

        return mask_image_pil

# Load Image

In [121]:
# Load the image
local_image_path = 'assets/bird.jpg'
image = cv2.imread(local_image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [122]:
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('off')
plt.show()

# Load YOLOv5 and SA models

In [5]:
# Install the YOLOv5 model

device = torch.device('cpu')

# set model params
model_path = "yolov5/weights/yolov5s.pt" # it automatically downloads yolov5s model to given path

# initialize the model
yolov5 = YOLOv5(model_path, device)

YOLOv5  2023-4-16 Python-3.10.9 torch-2.0.0+cpu CPU

  from .autonotebook import tqdm as notebook_tqdm
Fusing layers... 
YOLOv5s summary: 270 layers, 7235389 parameters, 0 gradients
Adding AutoShape... 


In [6]:
# Download the pretrained weights for the SAM

! wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

--2023-04-18 00:00:06--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... failed: No such host is known. .
wget: unable to resolve host address 'dl.fbaipublicfiles.com'


In [7]:
# The VIT-H is the default model

sam_checkpoint = 'sam_vit_h_4b8939.pth'

sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)

sam_predictor = SamPredictor(sam)

# Output from YOLOv5 and SA models

In [123]:
# perform inference
results = yolov5.predict(image)
box = []
for detection in results.pred:
    for i in range(len(detection.detach().cpu().numpy())):
        box.append((detection.detach().cpu().numpy())[i][:4])
        box[i] = np.array(box[i])
box = np.asarray(box)

In [125]:
plt.figure(figsize=(10,10))
plt.imshow(image)
for boxes in box:
    show_box(boxes, plt.gca())
plt.show()

In [126]:
sam_predictor.set_image(image)

In [129]:
if box.shape[0] > 1:
    boxes = torch.Tensor(box, device = sam_predictor.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes = transformed_boxes,
        multimask_output=False,
    )
    
else:
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box = box,
        multimask_output=False,
    )

In [131]:
plt.imshow(image)
for mask in masks:
    segmented_mask = show_mask(mask.cpu().numpy(), image, plt.gca(), random_color = True, return_sticker = False)
for boxes in box:
    show_box(boxes, plt.gca())
plt.axis('off')
plt.show()

In [133]:
sticker = []
for mask in masks:
    sticker.append(show_mask(mask.cpu().numpy(), image, plt.gca(), return_sticker = True))

In [134]:
plt.imshow(sticker[0])
plt.show()

In [135]:
plt.imshow(sticker[1])
plt.show()

In [136]:
# Save output in results/

directory = "results/" 
if not os.path.exists(directory):
    os.makedirs(directory)

file_heading = 'sticker_'
file_sub_heading = os.path.splitext(os.path.basename(local_image_path))[0]
file_num = []
file_ending = '.png'

for i in range(len(sticker)):
    file_num.append('_{}'.format(i + 1))
    file_name = file_heading + file_sub_heading + file_num[i] + file_ending
    if not os.path.exists(os.path.join(directory, file_name)):
        sticker[i].save(os.path.join(directory, file_name))