In [1]:
import os
from pathlib import Path

import cv2
import numpy as np
import torch
from segment_anything import SamPredictor, sam_model_registry
from tqdm import tqdm

In [2]:
# Load the SAM model once for the whole notebook; every later cell reuses `predictor`.
sam_checkpoint = "sam_vit_h_4b8939.pth"  # Adjust path to your downloaded checkpoint
model_type = "vit_h"

# Run on the GPU when one is available. Without this the model stays on the CPU,
# where ViT-H image embedding is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

In [3]:
def remove_background(image):
    """
    Removes the background of an image using SAM with a center point prompt.

    Args:
        image: BGR image array (the layout produced by ``cv2.imread``). It is
            converted to RGB before being handed to the SAM predictor.

    Returns:
        tuple: ``(result, mask_binary)`` where ``result`` is a BGRA copy of
        ``image`` whose alpha channel is the foreground mask, and
        ``mask_binary`` is a 2-D ``uint8`` array holding 255 for foreground
        pixels and 0 for background pixels.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. a failed ``cv2.imread``).
    """
    if image is None:
        raise ValueError("remove_background received None instead of an image")

    # Convert to RGB for SAM
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image_rgb)

    # Use center of image as prompt (assumes pig is roughly centered in crop)
    height, width = image.shape[:2]
    center_point = np.array([[width // 2, height // 2]])
    point_label = np.array([1])  # 1 = foreground

    # A single click is an ambiguous prompt, so ask SAM for its candidate masks
    # and keep the best-scoring one. With multimask_output=False only one mask
    # comes back and the argmax below would have nothing to choose between.
    masks, scores, _ = predictor.predict(
        point_coords=center_point,
        point_labels=point_label,
        multimask_output=True,
    )
    best_mask = masks[np.argmax(scores)]  # Best mask (boolean)

    # Convert mask to uint8 and scale to 0-255
    mask_binary = best_mask.astype(np.uint8) * 255

    # Create BGRA image with transparent background
    result = cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
    result[:, :, 3] = mask_binary  # Apply transparency

    return result, mask_binary

In [4]:
def apply_mask_to_depth(depth_image, mask):
    """
    Applies the background removal mask to the depth image while preserving its original color.

    Background pixels (mask == 0) are zeroed in every color channel and the mask
    becomes the alpha channel of the returned 4-channel image. The alpha values
    are scaled to the value range of ``depth_image``'s dtype, so that a fully
    selected pixel is fully opaque for 16-bit images as well as 8-bit ones.

    Args:
        depth_image: 2-D (H, W) single-channel array, or 3-D array with 1, 3 or
            4 channels. For 4-channel input the existing alpha channel is
            replaced by the mask. The input array is not modified.
        mask: 2-D mask. Boolean masks are treated as 0/255; any other dtype is
            cast to ``uint8`` and non-zero values mark foreground. If its size
            differs from the depth image it is resized with nearest-neighbour
            interpolation.

    Returns:
        Array of shape (H, W, 4) with the same dtype as ``depth_image``:
        the three color channels followed by alpha.

    Raises:
        ValueError: If ``depth_image`` is ``None`` or has an unsupported shape.
    """
    if depth_image is None:
        raise ValueError("apply_mask_to_depth received None instead of a depth image")

    # Normalise the mask to uint8 in the 0-255 range. A boolean mask cast
    # directly to uint8 would give alpha values of 0/1, i.e. almost transparent.
    mask = np.asarray(mask)
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Bring the depth image into a 3-channel (H, W, 3) layout.
    if depth_image.ndim == 2:
        color = np.stack([depth_image] * 3, axis=-1)
    elif depth_image.ndim == 3 and depth_image.shape[2] == 1:
        color = np.repeat(depth_image, 3, axis=2)
    elif depth_image.ndim == 3 and depth_image.shape[2] in (3, 4):
        # Drop any pre-existing alpha channel; the mask replaces it.
        color = depth_image[:, :, :3].copy()
    else:
        raise ValueError(f"Unsupported depth image shape: {depth_image.shape}")

    # Ensure mask matches depth_image size
    if mask.shape != color.shape[:2]:
        mask = cv2.resize(mask, (color.shape[1], color.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Zero out the background in every channel.
    color[mask == 0] = 0

    # Scale alpha to the image dtype: 255 is opaque for uint8, but in a 16-bit
    # image full opacity is 65535.
    dtype = color.dtype
    if np.issubdtype(dtype, np.integer):
        alpha_max = int(np.iinfo(dtype).max)
        alpha = (mask.astype(np.int64) * alpha_max // 255).astype(dtype)
    else:
        # Non-integer images: assumes the 0.0-1.0 alpha convention -- TODO confirm.
        alpha = (mask.astype(np.float64) / 255.0).astype(dtype)

    return np.dstack([color, alpha])

In [5]:
def process_subfolder(subdir, output_subdir, overwrite=True):
    """
    Removes the background from every RGB/depth crop pair in one folder.

    For each ``frame_*_rgb_crop.png`` in ``subdir`` the matching depth file
    (same name with ``_rgb`` replaced by ``_depth``) is looked up, the SAM mask
    computed from the RGB image is applied to both, and both results are
    written to ``output_subdir`` under their original file names.

    Args:
        subdir: ``pathlib.Path`` of the folder holding the input image pairs.
        output_subdir: ``pathlib.Path`` of the folder to write results to;
            created if it does not exist.
        overwrite: When ``True`` (default) every pair is processed and existing
            outputs are replaced. When ``False`` pairs whose two output files
            already exist are skipped, so an interrupted run can be resumed.
    """
    output_subdir.mkdir(parents=True, exist_ok=True)

    image_files = sorted(subdir.glob("frame_*_rgb_crop.png"))
    for rgb_file in tqdm(image_files, desc=f"Processing {subdir.name}"):
        depth_file = rgb_file.with_name(rgb_file.stem.replace("_rgb", "_depth") + rgb_file.suffix)

        if not depth_file.exists():
            print(f"Warning: No corresponding depth image for {rgb_file.name}")
            continue

        rgb_out = output_subdir / rgb_file.name
        depth_out = output_subdir / depth_file.name

        # Skip work that is already done before paying for a SAM inference.
        if not overwrite and rgb_out.exists() and depth_out.exists():
            continue

        # Read images
        rgb_image = cv2.imread(str(rgb_file))
        depth_image = cv2.imread(str(depth_file), cv2.IMREAD_UNCHANGED)

        # cv2.imread signals failure by returning None rather than raising;
        # skip the pair instead of letting one bad file abort the whole batch.
        if rgb_image is None or depth_image is None:
            print(f"Warning: Could not read image pair for {rgb_file.name}")
            continue

        # Remove background with SAM
        rgb_no_bg, mask_binary = remove_background(rgb_image)
        depth_no_bg = apply_mask_to_depth(depth_image, mask_binary)

        # Save output images; cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(str(rgb_out), rgb_no_bg):
            print(f"Warning: Failed to write {rgb_out}")
        if not cv2.imwrite(str(depth_out), depth_no_bg):
            print(f"Warning: Failed to write {depth_out}")

In [6]:
def process_images(input_folder, output_folder):
    """
    Runs background removal on every immediate subfolder of ``input_folder``.

    Each subfolder is processed with ``process_subfolder`` and its results are
    written to a subfolder of the same name under ``output_folder``. Files that
    sit directly in ``input_folder`` are ignored.

    Args:
        input_folder: Path (str or ``pathlib.Path``) of the folder whose
            subfolders contain the image pairs.
        output_folder: Path (str or ``pathlib.Path``) of the folder that
            receives the mirrored output subfolders.

    Returns:
        None. Prints a message and returns early when ``input_folder`` is not
        an existing directory.
    """
    input_path = Path(input_folder)
    output_path = Path(output_folder)

    # is_dir() also rejects a path that exists but is a regular file.
    if not input_path.is_dir():
        print("Input folder does not exist.")
        return

    # Filter to directories up front so the progress bar total is the real
    # number of folders, and sort so the processing order is reproducible.
    subdirs = sorted(entry for entry in input_path.glob('*') if entry.is_dir())

    for subdir in tqdm(subdirs, desc="Processing folders"):
        process_subfolder(subdir, output_path / subdir.name)

In [7]:
if __name__ == "__main__":
    # Forward slashes resolve on Windows as well as POSIX; the backslash form
    # is treated as a single odd file name outside Windows.
    input_folder = "./Jan_14_2025-Batch1/outputs_test"
    output_folder = "./Jan_14_2025-Batch1/background_remove_output_04032025_2"
    process_images(input_folder, output_folder)

Processing depth_rgb_recording(1):   6%|▌         | 29/492 [29:26<7:50:08, 60.93s/it]
Processing folders:   0%|          | 0/7 [29:26<?, ?it/s]


KeyboardInterrupt: 