In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pillow_avif
import torch
from PIL import Image

from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator, SamPredictor

In [None]:
def show_anns(anns, alpha=0.35, rng=None):
    """Render a list of mask annotations as a single RGBA overlay image.

    Each annotation is painted in a random colour. Annotations are drawn from
    the largest ``'area'`` to the smallest, so smaller masks stay visible on
    top of the larger ones they overlap.

    Parameters
    ----------
    anns : sequence of dict
        Annotations; each dict must provide ``'segmentation'`` (a 2-D mask
        whose truthy entries mark the pixels of the segment) and ``'area'``
        (used only for ordering). All masks are expected to share the shape
        of the largest one.
    alpha : float, optional
        Opacity written to the alpha channel of every painted pixel.
        Defaults to 0.35.
    rng : numpy.random.Generator, optional
        Source of the random colours. Pass a seeded generator to get a
        reproducible overlay. When omitted, the global NumPy random state
        (``np.random.random``) is used.

    Returns
    -------
    numpy.ndarray or None
        Float array of shape ``(H, W, 4)`` in RGBA order; pixels not covered
        by any mask are fully transparent. ``None`` when ``anns`` is empty.
    """
    if len(anns) == 0:
        return None

    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape[:2]

    # Start from a white canvas whose alpha channel is 0, i.e. fully transparent.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0

    # Draws one random float per colour channel.
    draw_rgb = np.random.random if rng is None else rng.random

    for ann in sorted_anns:
        # Coerce to bool so a 0/1 integer mask is applied as a mask rather
        # than being interpreted as fancy row indices.
        mask = np.asarray(ann['segmentation']).astype(bool)
        overlay[mask] = np.concatenate([draw_rgb(3), [alpha]])
    return overlay

In [None]:
# List the keys of sam_model_registry, i.e. the model names that can be used
# to look up a SAM builder (such as 'vit_h', which is used below).
print(sam_model_registry.keys())

In [None]:
# Pick the SAM variant and its checkpoint file.
# 'vit_h' is the variant the original author notes as the default.
model_type = 'vit_h'
checkpoint_path = '../models/sam_vit_h_4b8939.pth'

# Look up the builder for that variant and load the weights from disk.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=checkpoint_path)

# Wrap the model in the automatic mask generator used by the cells below.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
# Load the input image as an RGB array and preview it.
# Decoding .avif through PIL relies on the AVIF plugin registered by the
# `pillow_avif` import at the top of the notebook.
image_path = '../../../fashion/images/brands/wear-1703467514_1000.avif'

# The context manager closes the underlying file once the pixels are read.
# convert('RGB') guarantees an HxWx3 array even when the file carries an
# alpha channel or is grayscale; the mask generator expects a 3-channel image.
with Image.open(image_path) as pil_img:
    img = np.array(pil_img.convert('RGB'))

fig, ax = plt.subplots(figsize=(12, 12))
ax.axis('off')
ax.imshow(img)
plt.show()

In [None]:
# Make sure `img` is a NumPy array before it is handed to the mask generator.
# np.asarray returns an existing ndarray as-is, so no redundant full copy is
# made when the loading cell has already produced one.
img = np.asarray(img)

In [None]:
# First timed run of automatic mask generation on the loaded image; the result
# is stored in `masks`.
# NOTE(review): the model has not been moved to any device at this point, so
# this is presumably the CPU baseline for the later re-run — confirm.
%time masks = mask_generator.generate(img)

In [None]:
sam.to(device='cuda')
%time masks = mask_generator.generate(img)

In [None]:
# Draw the coloured mask overlay produced by show_anns.
fig, ax = plt.subplots(figsize=(12, 12))
annotated_img = show_anns(masks)
if annotated_img is None:
    # show_anns returns None when there are no masks; imshow(None) would raise.
    ax.set_title('No masks were generated')
else:
    ax.imshow(annotated_img)
ax.axis('off')
plt.show()