In [1]:
import os
import cv2
import numpy as np
from pathlib import Path
from tqdm import tqdm
from segment_anything import SamPredictor, sam_model_registry
import torch

In [2]:
# Select the compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU name: NVIDIA GeForce RTX 3050 Laptop GPU


In [3]:
# Build the ViT-H Segment Anything model from its checkpoint, place it on `device`,
# and wrap it in a SamPredictor used for point-prompted mask prediction.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"  # Adjust path to your downloaded checkpoint
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
predictor = SamPredictor(sam)

In [4]:
def remove_background(image, sam_predictor=None, multimask_output=False):
    """
    Segment the object at the image centre with SAM and make everything else transparent.

    A single positive point prompt is placed at pixel ``(width // 2, height // 2)``;
    of the masks SAM returns, the one with the highest predicted score is used.

    Args:
        image: BGR image as loaded by ``cv2.imread`` (H, W, 3). BGRA (H, W, 4) and
            single-channel (H, W) / (H, W, 1) inputs are also accepted: alpha is
            dropped and a single channel is replicated to three.
        sam_predictor: the ``SamPredictor`` to use. Defaults to the notebook-level
            ``predictor`` created in the model-loading cell.
        multimask_output: forwarded to ``SamPredictor.predict``. ``False`` (the
            default, previous behaviour) asks SAM for one mask; ``True`` asks for
            several candidates and keeps the best-scoring one, which the SAM
            authors recommend for ambiguous single-point prompts.

    Returns:
        (result, mask_binary): ``result`` is an (H, W, 4) BGRA image whose colour
        channels are the input's and whose alpha is the mask; ``mask_binary`` is an
        (H, W) uint8 array with 255 = foreground and 0 = background.

    Raises:
        ValueError: if ``image`` is ``None`` (e.g. ``cv2.imread`` failed) or has an
            unsupported shape.
    """
    if image is None:
        raise ValueError("remove_background() got None; the image could not be read")
    active_predictor = predictor if sam_predictor is None else sam_predictor

    # Normalise to 3-channel BGR so SAM and the BGRA output see the same pixels.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        image_bgr = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    elif image.ndim == 3 and image.shape[2] in (3, 4):
        image_bgr = image[:, :, :3]
    else:
        raise ValueError(
            f"Unsupported image shape {image.shape}; "
            "expected (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)"
        )

    # SAM expects RGB: reverse the channel axis and make the buffer contiguous.
    image_rgb = np.ascontiguousarray(image_bgr[:, :, ::-1])
    active_predictor.set_image(image_rgb)

    height, width = image_bgr.shape[:2]
    center_point = np.array([[width // 2, height // 2]], dtype=np.float32)  # (x, y) pixel coords
    point_label = np.array([1], dtype=np.int64)  # 1 marks a foreground point

    with torch.no_grad():  # inference only; no autograd graph needed
        masks, scores, _ = active_predictor.predict(
            point_coords=center_point,
            point_labels=point_label,
            multimask_output=multimask_output,
            return_logits=False,
        )

    best_mask = masks[int(np.argmax(scores))]
    # Defensive: tolerate a predictor that hands back torch tensors instead of NumPy.
    if isinstance(best_mask, torch.Tensor):
        best_mask = best_mask.detach().cpu().numpy()
    mask_binary = np.where(np.asarray(best_mask, dtype=bool), 255, 0).astype(np.uint8)

    # Append the mask as the alpha channel -> transparent background.
    result = np.dstack((image_bgr, mask_binary))
    return result, mask_binary

In [5]:
def apply_mask_to_depth(depth_image, mask):
    """
    Zero out depth pixels outside ``mask`` and return a BGRA image whose alpha
    channel marks the kept region.

    Args:
        depth_image: depth map, e.g. from ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``.
            Single-channel (H, W) / (H, W, 1) or multi-channel (H, W, 3) / (H, W, 4);
            for 4-channel input the existing alpha is replaced.
        mask: 2-D mask; after conversion to uint8, any non-zero value means "keep".
            Resized with nearest-neighbour interpolation if its size differs from
            the depth map.

    Returns:
        (H, W, 4) array with the same dtype as ``depth_image``. The first three
        channels hold the masked depth (a single channel is replicated). Alpha is
        the dtype's full-scale value inside the mask and 0 outside: 255 for uint8,
        65535 for uint16, 1.0 for floats. This keeps 16-bit outputs fully opaque in
        the kept region instead of ~0.4% opaque.

    Raises:
        ValueError: if ``depth_image`` is ``None`` or has an unsupported channel count.
    """
    if depth_image is None:
        raise ValueError("apply_mask_to_depth() got None; the depth image could not be read")

    mask = np.asarray(mask)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    height, width = depth_image.shape[:2]
    if mask.shape != (height, width):
        # Nearest-neighbour keeps the mask binary (no interpolated edge values).
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    keep = mask != 0

    dtype = depth_image.dtype
    if depth_image.ndim == 3 and depth_image.shape[2] == 1:
        depth_image = depth_image[:, :, 0]
    if depth_image.ndim == 2:
        colour = np.repeat(depth_image[:, :, np.newaxis], 3, axis=2)
    elif depth_image.ndim == 3 and depth_image.shape[2] in (3, 4):
        colour = depth_image[:, :, :3]
    else:
        raise ValueError(
            f"Unsupported depth image shape {depth_image.shape}; "
            "expected (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)"
        )

    masked = np.where(keep[:, :, np.newaxis], colour, 0).astype(dtype, copy=False)

    # Alpha must use the full range of the pixel type, not the 0/255 mask values.
    full_scale = np.iinfo(dtype).max if np.issubdtype(dtype, np.integer) else 1.0
    alpha = np.where(keep, full_scale, 0).astype(dtype)

    return np.dstack((masked, alpha))

In [6]:
def process_subfolder(subdir, output_subdir):
    """
    Remove the background from every RGB/depth crop pair in one folder.

    Every ``frame_*_rgb_crop_*.png`` in ``subdir`` is paired with the file whose
    name has ``_rgb_crop_`` replaced by ``_depth_crop_``. SAM segments the RGB
    crop, the same mask is applied to the depth crop, and both results are written
    to ``output_subdir`` (created if needed) under their original file names.

    Pairs with a missing depth file, or where either image cannot be decoded, are
    skipped with a warning. A failed write is reported but does not stop the run.

    Args:
        subdir: folder containing the crops (``str`` or ``Path``).
        output_subdir: destination folder (``str`` or ``Path``).
    """
    subdir = Path(subdir)
    output_subdir = Path(output_subdir)
    output_subdir.mkdir(parents=True, exist_ok=True)

    # Sorted for a stable processing order (lexicographic, so frame_10 precedes frame_2).
    image_files = sorted(subdir.glob("frame_*_rgb_crop_*.png"))
    for rgb_file in tqdm(image_files, desc=f"Processing {subdir.name}"):
        # e.g. "frame_0_rgb_crop_1" -> "frame_0_depth_crop_1"
        depth_stem = rgb_file.stem.replace("_rgb_crop_", "_depth_crop_")
        depth_file = rgb_file.with_name(depth_stem + rgb_file.suffix)

        # tqdm.write keeps messages from corrupting the progress bar.
        if not depth_file.exists():
            tqdm.write(f"Warning: No corresponding depth image for {rgb_file.name}")
            continue

        rgb_image = cv2.imread(str(rgb_file))
        # IMREAD_UNCHANGED keeps the depth file's bit depth and channel count.
        depth_image = cv2.imread(str(depth_file), cv2.IMREAD_UNCHANGED)
        # cv2.imread signals an unreadable/corrupt file by returning None, not raising.
        if rgb_image is None or depth_image is None:
            tqdm.write(f"Warning: Could not read {rgb_file.name} or {depth_file.name}; skipping")
            continue

        # Segment the RGB crop with SAM, then reuse that mask on the depth crop.
        rgb_no_bg, mask_binary = remove_background(rgb_image)
        depth_no_bg = apply_mask_to_depth(depth_image, mask_binary)

        for out_path, out_image in (
            (output_subdir / rgb_file.name, rgb_no_bg),
            (output_subdir / depth_file.name, depth_no_bg),
        ):
            # cv2.imwrite reports failure through its boolean return value.
            if not cv2.imwrite(str(out_path), out_image):
                tqdm.write(f"Warning: Failed to write {out_path}")

In [7]:
def process_images(input_folder, output_folder):
    """
    Run background removal over every ``<input_folder>/<group>/<sequence>/`` folder.

    Only directories exactly two levels below ``input_folder`` are processed; each
    is passed to ``process_subfolder`` together with an output folder at the same
    relative path under ``output_folder``. Files at other depths are ignored.

    Directories are visited in sorted order so repeated runs process them in the
    same sequence, and the outer progress bar counts only directories.

    Args:
        input_folder: root of the crop tree (``str`` or ``Path``).
        output_folder: root under which the mirrored output tree is written.
    """
    # Release PyTorch's cached, currently unused GPU memory before the long run.
    torch.cuda.empty_cache()
    input_path = Path(input_folder)
    output_path = Path(output_folder)

    if not input_path.exists():
        print("Input folder does not exist.")
        return

    group_dirs = sorted(p for p in input_path.glob("*") if p.is_dir())
    for group_dir in tqdm(group_dirs, desc="Processing folders"):
        sequence_dirs = sorted(p for p in group_dir.glob("*") if p.is_dir())
        for sequence_dir in sequence_dirs:
            process_subfolder(sequence_dir, output_path / sequence_dir.relative_to(input_path))

In [8]:
if __name__ == "__main__":
    # Build the paths with pathlib so they resolve on any OS; backslash-joined
    # string literals only work as path separators on Windows.
    batch_dir = Path("April_05_2025-Batch3")
    input_folder = batch_dir / "outputs_1-2"
    output_folder = batch_dir / "background_remove_output_segregated"
    process_images(input_folder, output_folder)

Processing folders:   0%|          | 0/2 [00:00<?, ?it/s]







Processing long: 100%|██████████| 388/388 [21:30<00:00,  3.33s/it]




Processing short: 100%|██████████| 248/248 [13:56<00:00,  3.37s/it]












Processing twirl: 100%|██████████| 364/364 [20:08<00:00,  3.32s/it]
Processing folders:  50%|█████     | 1/2 [55:35<55:35, 3335.20s/it]



















Processing long: 100%|██████████| 564/564 [32:02<00:00,  3.41s/it]
























Processing short: 100%|██████████| 573/573 [32:25<00:00,  3.40s/it]
















Processing twirl: 100%|██████████| 538/538 [29:54<00:00,  3.33s/it]
Processing folders: 100%|██████████| 2/2 [2:29:57<00:00, 4498.70s/it]
