<a href="https://colab.research.google.com/github/kelanliu1/segment-anything/blob/main/SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
!pip install torch torchvision
!pip install opencv-python matplotlib
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'

In [None]:
import torch
import torchvision
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Initialize YOLOv5 (the small "yolov5s" variant, fetched through torch.hub).
# Use the GPU when one is available and fall back to CPU otherwise, so this
# cell does not crash on runtimes without CUDA.
yolo_device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True).to(yolo_device)

In [None]:
# Initialize Segment-Anything
# The checkpoint file is expected in the working directory (it is downloaded
# by the wget cell at the top of the notebook).
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to CPU so the cell still runs (slowly) on
# runtimes without CUDA instead of raising.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic mask generator: prompts SAM with a grid of points and filters the
# resulting masks by predicted quality and stability.
mask_generator_ = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,                 # point-grid resolution along each image side
    pred_iou_thresh=0.9,                # drop masks with low predicted quality
    stability_score_thresh=0.96,        # drop masks that are unstable under threshold changes
    crop_n_layers=1,                    # additionally run on one layer of image crops
    crop_n_points_downscale_factor=2,   # use a coarser point grid on the crops
    min_mask_region_area=100,           # remove small disconnected regions/holes (pixels)
)

In [None]:
# Blur every SAM mask found inside each YOLOv5 detection of the chosen class,
# then write the result to disk.

# User input class to blur
class_to_blur = "person"

# Load an image
input_image = "trifecta.jpg"
image = cv2.imread(input_image)
# cv2.imread signals failure by returning None rather than raising; fail early
# with a clear message instead of a cryptic cvtColor error.
if image is None:
    raise FileNotFoundError(f"Could not read image file: {input_image!r}")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Make a copy of the original image
original_image = image_rgb.copy()

# Perform YOLOv5 object detection
detections = model(image_rgb)
detections_list = detections.xyxy[0].tolist()

# Process detections
img_h, img_w = image_rgb.shape[:2]
for det in detections_list:
    x1, y1, x2, y2, conf, cls = det
    class_name = model.names[int(cls)]
    if class_name != class_to_blur:
        continue

    # Truncate to pixel indices and clamp to the image so slicing is well-defined.
    x1, y1 = max(int(x1), 0), max(int(y1), 0)
    x2, y2 = min(int(x2), img_w), min(int(y2), img_h)
    if x2 <= x1 or y2 <= y1:
        # Degenerate box after truncation: nothing to segment or blur.
        continue

    cropped_frame = image_rgb[y1:y2, x1:x2].copy()
    masks = mask_generator_.generate(cropped_frame)
    if not masks:
        continue

    # Union of all masks in the crop: a pixel is blurred if any mask covers it.
    combined_mask = np.zeros(cropped_frame.shape[:2], dtype=bool)
    for mask in masks:
        combined_mask |= np.asarray(mask["segmentation"], dtype=bool)

    # Blur the crop once (instead of once per mask on an already-blurred crop)
    # and copy the blurred pixels in only where the combined mask is set.
    blurred_region = cv2.GaussianBlur(cropped_frame, (99, 99), 30)
    cropped_frame = np.where(combined_mask[..., None], blurred_region, cropped_frame)
    image_rgb[y1:y2, x1:x2] = cropped_frame


# Save the blurred image
output_image_path = "output_image.jpg"
output_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
# Left as the cell's last expression so the boolean write status is displayed.
cv2.imwrite(output_image_path, output_image)


True

Test: blur only the largest SAM mask inside each detected box, instead of every mask.

In [None]:
# Variant of the blur pipeline: inside each YOLOv5 detection of the chosen
# class, blur only the largest SAM mask, then write the result to disk.

# User input class to blur
class_to_blur = "person"

# Load an image
input_image = "trifecta.jpg"
image = cv2.imread(input_image)
# cv2.imread signals failure by returning None rather than raising; fail early
# with a clear message instead of a cryptic cvtColor error.
if image is None:
    raise FileNotFoundError(f"Could not read image file: {input_image!r}")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Make a copy of the original image
original_image = image_rgb.copy()

# Perform YOLOv5 object detection
detections = model(image_rgb)
detections_list = detections.xyxy[0].tolist()

# Process detections
img_h, img_w = image_rgb.shape[:2]
for det in detections_list:
    x1, y1, x2, y2, conf, cls = det
    class_name = model.names[int(cls)]
    if class_name != class_to_blur:
        continue

    # Truncate to pixel indices and clamp to the image so slicing is well-defined.
    x1, y1 = max(int(x1), 0), max(int(y1), 0)
    x2, y2 = min(int(x2), img_w), min(int(y2), img_h)
    if x2 <= x1 or y2 <= y1:
        # Degenerate box after truncation: nothing to segment or blur.
        continue

    cropped_frame = image_rgb[y1:y2, x1:x2].copy()
    masks = mask_generator_.generate(cropped_frame)

    # Find the largest mask by pixel count; None when SAM returned no masks.
    largest_mask = max(
        masks,
        key=lambda m: np.sum(m["segmentation"]),
        default=None,
    )
    if largest_mask is None:
        continue

    # Apply blur only on the largest mask
    segmentation = largest_mask["segmentation"]
    blurred_region = cv2.GaussianBlur(cropped_frame, (99, 99), 30)
    cropped_frame = np.where(segmentation[..., None], blurred_region, cropped_frame)
    image_rgb[y1:y2, x1:x2] = cropped_frame

# Save the blurred image
output_image_path = "output_image.jpg"
output_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
# Left as the cell's last expression so the boolean write status is displayed.
cv2.imwrite(output_image_path, output_image)