In [1]:
import glob
import os

import cv2
import numpy as np
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry


# --- Configuration ------------------------------------------------------------
# Glob selecting the input images; each image's masks are saved next to it.
IMAGE_PATH = "./segmentation/*.png"
# Target size as (width, height), the order cv2.resize expects for dsize.
RESIZE_TO = (1920, 1080)
# SAM backbone variant (vit_b, vit_l or vit_h) and the checkpoint file for it.
# The two must match -- change both together.
MODEL_TYPE = "vit_b"
SAM_CHECKPOINT = "sam_vit_b_01ec64.pth"


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the SAM variant named by MODEL_TYPE from its checkpoint and move it to
# the selected device (nn.Module.to returns the module itself).
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=device)

# Prompt-free mask generator using the library defaults. Optional memory knobs
# that can be passed as keyword arguments:
#   points_per_batch -- points decoded per forward pass (default 64); lowering
#                       it (e.g. to 32) reduces peak GPU memory.
#   crop_n_layers    -- when > 0, runs extra passes over image crops; leaving it
#                       at 0 avoids that additional memory/compute cost.
mask_generator = SamAutomaticMaskGenerator(sam)

# 3x3 sharpening kernel (identity minus the 4-neighbour Laplacian). It never
# changes, so build it once rather than on every iteration.
kernel = np.array([[0, -1, 0],
                   [-1, 5, -1],
                   [0, -1, 0]])

for image_path in sorted(glob.glob(IMAGE_PATH)):
    # Masks are cached next to the source image as "<name>_masks.npz";
    # skip images that were already processed by a previous run.
    mask_path = os.path.splitext(image_path)[0] + '_masks.npz'
    if os.path.exists(mask_path):
        continue

    # Load image. cv2.imread returns None (rather than raising) when the file
    # cannot be read, which would otherwise crash the filters below.
    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping unreadable image: {image_path}")
        continue

    # Edge-preserving denoise, then sharpen, then resize to the working size.
    image = cv2.bilateralFilter(image, 9, 75, 75)
    image = cv2.filter2D(image, -1, kernel)
    image = cv2.resize(image, RESIZE_TO)
    # OpenCV loads BGR; convert to RGB channel order for SAM.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Generate all masks
    masks = mask_generator.generate(image)

    # Save masks for further processing. Write to a temporary file and rename
    # it into place so an interrupted run never leaves a truncated .npz that
    # the cache check above would mistake for a finished result.
    tmp_path = mask_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        # Passing a file object stops np.savez from appending another ".npz".
        np.savez(f, masks=masks)
    os.replace(tmp_path, mask_path)