In [24]:
# Import necessary libraries
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from IPython.display import display, HTML
from torchvision import transforms 
import cv2
import pickle
import os
# Import utility functions from util.py
from util import (
    show_anns_on_image,
    batchify,
    combine_harmful_masks,
    resize_image,
    mask_harmful_content,
)

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

## Uncomment the cell below if you need to download the SAM weights

In [6]:
# import requests

# # URL to the weight file
# url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

# # Path where you want to save the file
# output_path = "sam_vit_h_4b8939.pth"

# # Download the file
# print("Downloading SAM model weights...")
# response = requests.get(url, stream=True)
# if response.status_code == 200:
#     with open(output_path, "wb") as f:
#         for chunk in response.iter_content(chunk_size=1024):
#             f.write(chunk)
#     print(f"Downloaded SAM model weights to {output_path}")
# else:
#     print(f"Failed to download the weights. HTTP status code: {response.status_code}")


## Load models

In [5]:
# Module-level model handles shared by the helper functions below.

# SAM: build the chosen backbone from its local checkpoint and place it on `device`.
model_type = "vit_h"  # Options: 'vit_h', 'vit_l', 'vit_b'
sam_checkpoint = "sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device=device)

# CLIP: the processor prepares text/image inputs, the model scores them.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)

  state_dict = torch.load(f)


## Util functions: segment image and classify segments

In [25]:
def segment_image(image):
    """
    Segments the image using SAM (Segment Anything Model).

    The input is resized to 1024x1024 before segmentation, so the returned
    masks are produced for that resized image, not for the input's original
    resolution. Callers that index the original image with a mask must
    resize the image to 1024x1024 first.

    Args:
        image (PIL.Image.Image): Image to segment. PIL images of any mode
            (RGBA, palette, grayscale, ...) are converted to RGB first.
            Non-PIL inputs are passed to ``np.array`` unchanged.

    Returns:
        list: Mask records as returned by
            ``SamAutomaticMaskGenerator.generate``.
    """
    # SAM's mask generator expects an HxWx3 uint8 array. A PNG opened with
    # PIL is often RGBA or palette-based, which would otherwise produce a
    # 4-channel or 2-D array and fail inside the predictor.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    # Convert to a NumPy array and bring it to the fixed working resolution
    image_np = np.array(image)
    image_np = cv2.resize(image_np, (1024, 1024), interpolation=cv2.INTER_AREA)

    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=32,           # Adjust for finer or coarser grid, 64
        min_mask_region_area=50,      # Set minimum area for masks
        box_nms_thresh=0.2,           # Adjust NMS threshold
        stability_score_thresh=0.2,   # Set stability score threshold
    )

    # Make sure the underlying model sits on the selected device
    mask_generator.predictor.model.to(device)

    masks = mask_generator.generate(image_np)

    print(f"Generated {len(masks)} masks.")
    return masks

def classify_segments(image, masks, descriptions):
    """
    Classifies each image segment using CLIP.

    For every mask, the pixels outside the mask are painted white, the
    result is passed through ``resize_image``, and CLIP scores it against
    all ``descriptions``. Scores are softmax-normalised across the
    descriptions for each segment.

    Args:
        image (PIL.Image.Image): Original image. Its height/width must match
            the masks' ``'segmentation'`` arrays, because the mask is used
            to index the image directly.
        masks (list): Mask records; each must provide a ``'segmentation'``
            array that can be cast to bool.
        descriptions (list): List of descriptions for classification.

    Returns:
        tuple: ``(overall_probs, overall_masks)`` where ``overall_probs[i]``
            is a 1-D NumPy array with one probability per description for
            the segment described by ``overall_masks[i]``.
    """
    # Nothing to classify: avoid handing CLIP an empty image batch
    if len(masks) == 0:
        return [], []

    # Convert the image once; every segment works on its own copy
    image_np = np.array(image)

    # Preprocess segments
    segment_inputs = []
    for mask in masks:
        mask_bool = mask['segmentation'].astype(bool)

        # Mask out all other parts
        segment_np = image_np.copy()
        segment_np[~mask_bool] = 255  # Set the background to white

        # Bring the segment to the input size expected downstream
        segment_inputs.append(resize_image(segment_np))

    # Define batch size
    batch_size = 16

    # Split images into batches
    image_batches = batchify(segment_inputs, batch_size)

    # Initialize lists to store probabilities and masks
    overall_probs = []
    overall_masks = []

    # Process each batch
    for batch_idx, image_batch in enumerate(image_batches):
        # truncation=True keeps over-long descriptions within the text
        # encoder's maximum sequence length instead of failing at inference.
        inputs = clip_processor(
            text=descriptions,
            images=image_batch,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Move inputs to the appropriate device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Perform inference in a batch
        with torch.no_grad():
            outputs = clip_model(**inputs)

        # Shape: [batch_size, num_descriptions]; one device-to-host copy per batch
        batch_probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()

        # Pair each probability row with the mask it was computed from
        for sub_batch_idx, text_probs in enumerate(batch_probs):
            global_image_idx = batch_idx * batch_size + sub_batch_idx  # Absolute image index
            overall_probs.append(text_probs)
            overall_masks.append(masks[global_image_idx])

    return overall_probs, overall_masks

## save_masking: save the original image and its masked version to the save path

In [51]:
def save_masking(original_img, save_path, descriptions, harmful_descriptions):
    """
    Segments an image, masks the segments classified as harmful, and saves
    both the resized original and the masked result as pickle files.

    Two files are written into ``save_path`` (created if missing):
    ``original_image.pkl`` (the resized PIL image) and ``masked_image.pkl``
    (the value returned by ``mask_harmful_content``).

    Args:
        original_img (PIL.Image.Image): Input image; converted to RGB and
            resized to 1024x1024 before processing.
        save_path (str): Directory to write the pickle files into.
        descriptions (list): All candidate descriptions scored by CLIP.
        harmful_descriptions (list): Descriptions that mark a segment as
            harmful; forwarded to ``combine_harmful_masks``.

    Returns:
        tuple: ``(original_img, masked_image)`` - the resized RGB PIL image
            and the masked image.
    """
    # Normalise to 3-channel RGB up front: PNGs are frequently RGBA or
    # palette-based, which the segmentation/classification steps cannot
    # handle. For inputs that are already RGB this is a no-op copy.
    original_img = original_img.convert("RGB").resize((1024, 1024))
    original_np = np.array(original_img)

    masks = segment_image(original_img)

    overall_probs, overall_masks = classify_segments(original_img, masks, descriptions)

    harmful_mask = combine_harmful_masks(
        image_shape=original_np.shape,
        masks=overall_masks,
        probs=overall_probs,
        descriptions=descriptions,
        harmful_descriptions=harmful_descriptions
    )

    masked_image = mask_harmful_content(original_np, harmful_mask)

    # Ensure save_path exists
    os.makedirs(save_path, exist_ok=True)

    masked_image_path = os.path.join(save_path, "masked_image.pkl")
    original_image_path = os.path.join(save_path, "original_image.pkl")

    # Masked result, exactly as returned by mask_harmful_content
    with open(masked_image_path, "wb") as f:
        pickle.dump(masked_image, f)

    # NOTE: the original is pickled as a PIL image, not as a NumPy array
    with open(original_image_path, "wb") as f:
        pickle.dump(original_img, f)

    return original_img, masked_image


In [52]:
# Candidate text prompts scored by CLIP for every segment. Probabilities are
# softmax-normalised across this whole list, so the benign entries act as
# alternatives competing with the harmful ones. These strings are model
# input: editing their wording changes the classification results.
descriptions = [
  "an image containing nudity",
  "an image containing human who wear clothes",
  "an image containing a shocking face",
  "an image containing a scary face",
  "an image containing a face with blood",
  "an image containing violence content",
  "an image containing nudity but just the shoulder and face, not including any sexual organs",
  "A pistol",
  "A gun",
  "small object that cannot be correctly recognized",
  "an image containing a human being well dressed",
  "an image containing floor",
  "an image containing ground",
  "an image containing playground",
  "An image of a neutral, non-offensive setting, such as a serene natural landscape (a meadow, beach, or clear sky) or a simple indoor space (a clean room with a plain floor and walls). The focus should be calm and peaceful, avoiding any symbols or content that might be seen as harmful or controversial"
]

In [53]:
# Prompts that mark a segment as harmful; forwarded to combine_harmful_masks.
# NOTE(review): each entry presumably has to match an entry of `descriptions`
# exactly (same spelling and case) - confirm against util.combine_harmful_masks.
harmful_descriptions = [
    "an image containing nudity",
    "an image containing a shocking face",
    "an image containing a scary face",
    "an image containing a face with blood",
    "an image containing violence content",
    "A pistol",
    "A gun", 
]

In [54]:
# Open the demo input image; the path is relative to the notebook's working directory.
original_img = Image.open('../../nudity_img.png')

In [55]:
# Run the full pipeline (segment -> classify -> mask). This also writes
# masked_image.pkl and original_image.pkl into the given save directory.
org_img, masked_img = save_masking(original_img, '../../demo_save_path', descriptions, harmful_descriptions)

Generated 38 masks.
