## Segment Anything Image Preprocessing

In [1]:
NEW_ENV = False

if NEW_ENV:
    !pip install numpy pandas matplotlib
    !pip install opencv-python torch
    !pip install segment-anything


In [2]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import torch

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')


In [3]:
# --- Load SAM-2 model and predictor ---
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
sam2_model = build_sam2(config_file=model_cfg, ckpt_path=sam2_checkpoint, device="mps")
sam2_model.eval().to(device)

# --- Create the predictor ---
predictor = SAM2ImagePredictor(sam2_model, device=device)


In [4]:
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    print(box)
    print(box[0])
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def show_mask(mask, ax, random_color=False, borders = True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()


In [24]:
image_dir = '/Users/m0h00x7/Projects/ncstate/CSC542/ECE542_Project/data/images'
segmented_image_dir = '/Users/m0h00x7/Projects/ncstate/CSC542/ECE542_Project/data/segm_images'
masks_dir = '/Users/m0h00x7/Projects/ncstate/CSC542/ECE542_Project/data/masks'

# input_box = np.array([256, 512, 768, 1536])
# input_label = np.array([1])

# load mask from file
mask_from_file = np.load(f'{masks_dir}/mask01.npy')


img_files = list(Path(image_dir).glob('*.jpg'))
for index, img_file in enumerate(img_files):

    image = Image.open(img_file)
    image_rgb = np.array(image.convert("RGB"))

    # plt.figure(figsize=(10, 10))
    # plt.imshow(image_rgb)
    # show_box(input_box, plt.gca())
    # plt.axis('on')
    # plt.show()

    # set image in mask predictor
    predictor.set_image(image_rgb)
    # masks, scores, _ = predictor.predict(box=input_box[None, :], multimask_output=True)
    masks, scores, _ = predictor.predict(mask_input=mask_from_file, multimask_output=False)
    # masks, scores, _ = predictor.predict(multimask_output=False)
    show_masks(image, masks, scores)

    pass

    # Create a mask where foreground is white (255) and background is black (0)
    mask_uint8 = (masks[0] * 255).astype(np.uint8)

    # # Invert the mask to get the background mask
    # background_mask = 255 - mask_uint8

    # Create the output image with the background removed
    output_image = cv2.bitwise_and(image_rgb, image_rgb, mask=mask_uint8)
    plt.figure(figsize=(5, 5))
    plt.imshow(output_image)
    plt.axis('on')
    plt.show()

    # save segmented image
    output_image = Image.fromarray(output_image)
    output_image.save(f'{segmented_image_dir}/{img_file.name}')

    if index % 100 == 99:
        print(f'{index+1}')
    elif index % 10 == 9:
        print(f'{index+1}', end='')
    else:
        print('.', end='')


RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

In [21]:
# masks_dir = '/Users/m0h00x7/Projects/ncstate/CSC542/ECE542_Project/data/masks'
# .save(f'{masks_dir}/mask01')

np.save(f'{masks_dir}/mask01.npy', masks)


In [None]:
best_masks.append(masks[0])


In [None]:
show_masks(image, best_masks, scores)
