# Background Removal using Mask R-CNN in PyTorch - Part 1: Core Implementation

This notebook demonstrates how to use Mask R-CNN to replace image backgrounds with black. The implementation is based on PyTorch and the Torchvision library.

## Dataset Overview
- Total number of images: 4,431
- Number of categories: 413
- Date range: 2025-02-17
- Image formats: JPEG, MPO
- Average dimensions: 1197x1308 pixels
- Average file size: 265.03 KB

## Data Quality Issues
- Missing values: 5,416 total across all columns
- Duplicate files: 0
- Date inconsistencies: 0
- Size outliers: 45
- Dimension outliers: 16

## 1. Setup and Import Libraries

In [None]:
# Install required packages if not already installed
!pip install torch torchvision matplotlib opencv-python tqdm

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from PIL import Image
from torchvision import transforms

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Load Mask R-CNN Model

In [None]:
# Build Mask R-CNN (ResNet-50 FPN) with torchvision's default pre-trained
# weights, move it to the selected device and put it in evaluation mode.
# nn.Module.to() and .eval() both return the module itself, so they chain.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
model = maskrcnn_resnet50_fpn(weights=weights).to(device).eval()

# Category names shipped in the weights' metadata (index = predicted label id).
COCO_INSTANCE_CATEGORY_NAMES = weights.meta["categories"]
print(f"Model loaded with {len(COCO_INSTANCE_CATEGORY_NAMES)} classes")

## 3. Image Preprocessing and Prediction Functions

In [None]:
def load_image(img_path):
    """Load an image from disk as an RGB ``PIL.Image``.

    Args:
        img_path: Path to the image file.

    Returns:
        The image converted to RGB, or ``None`` if it could not be opened or
        decoded. The error is printed rather than raised so that callers
        processing many files can skip bad ones.
    """
    try:
        # ``convert`` loads the pixel data and returns a new image, so the
        # source file can be closed immediately. The ``with`` block guarantees
        # that, including for multi-frame formats such as MPO, where Pillow
        # otherwise keeps the underlying file handle open.
        with Image.open(img_path) as src:
            return src.convert("RGB")
    except Exception as e:
        print(f"Error loading image {img_path}: {e}")
        return None

def preprocess_image(img):
    """Convert an image into the tensor form the detection model consumes.

    ``ToTensor`` turns a PIL image (or an HxWxC ``uint8`` array) into a
    ``float32`` tensor shaped (C, H, W) with values scaled to [0, 1]. No
    resizing or normalization happens here.
    """
    to_tensor = transforms.ToTensor()
    return to_tensor(img)

def get_prediction(img_tensor, threshold=0.5):
    """Run the module-level Mask R-CNN ``model`` on one image tensor.

    Args:
        img_tensor: Image tensor, e.g. as produced by ``preprocess_image``.
        threshold: Not used by this function. Score filtering is left to the
            caller (``remove_background`` applies its own
            ``confidence_threshold``). Kept so existing call sites still work.

    Returns:
        The prediction dict for the single image, i.e. element 0 of the
        model's batched output. Its tensors are left on ``device``.
    """
    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        batch = [img_tensor.to(device)]
        outputs = model(batch)
    return outputs[0]

## 4. Background Removal Function

In [None]:
def _union_of_masks(prediction, confidence_threshold, mask_threshold=0.5):
    """Return the 0/255 ``uint8`` union of all confident instance masks, or ``None`` if none qualify."""
    keep = prediction['scores'] > confidence_threshold
    masks = prediction['masks'][keep]
    if masks.shape[0] == 0:
        return None
    # Each instance mask is indexed as masks[i, 0] (soft per-pixel values).
    # Binarize them all and OR them together on the model's device, so that
    # only one device-to-host transfer happens instead of one per instance.
    union = (masks[:, 0] > mask_threshold).any(dim=0)
    return union.cpu().numpy().astype(np.uint8) * 255


def remove_background(img_path, confidence_threshold=0.7, save_path=None, mask_threshold=0.5):
    """Keep only the objects Mask R-CNN detects and paint everything else black.

    Args:
        img_path: Path to the input image.
        confidence_threshold: Detections whose score is not strictly greater
            than this value are ignored.
        save_path: Optional output path. The result is written (converted to
            BGR for OpenCV) only when at least one object was kept. A failed
            write is reported with a printed warning.
        mask_threshold: Per-pixel value above which a soft instance mask
            counts as foreground. The default of 0.5 matches the previous
            hard-coded value.

    Returns:
        ``(result, combined_mask)``. ``result`` is the RGB ``uint8`` image with
        background pixels set to 0. ``combined_mask`` is a ``uint8`` 0/255 mask
        with the image's height and width. If no detection passes
        ``confidence_threshold``, the unmodified image and an all-zero mask
        are returned and nothing is saved. Returns ``(None, None)`` if the
        image cannot be loaded.
    """
    img = load_image(img_path)
    if img is None:
        return None, None

    prediction = get_prediction(preprocess_image(img))

    img_np = np.array(img)
    height, width = img_np.shape[:2]

    combined_mask = _union_of_masks(prediction, confidence_threshold, mask_threshold)
    if combined_mask is None:
        print(f"No objects detected with confidence threshold {confidence_threshold}")
        return img_np, np.zeros((height, width), dtype=np.uint8)

    # Zero every channel wherever the mask is 0; keep pixels where it is 255.
    mask_3ch = cv2.merge([combined_mask, combined_mask, combined_mask])
    result = cv2.bitwise_and(img_np, mask_3ch)

    # cv2.imwrite reports failure (e.g. missing directory) by returning False.
    if save_path and not cv2.imwrite(save_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR)):
        print(f"Warning: could not write result to {save_path}")

    return result, combined_mask

## 5. Visualize Results

In [None]:
def visualize_results(original_img, processed_img, mask):
    """Plot the input image, its segmentation mask and the black-background output in one row."""
    panels = (
        ('Original Image', original_img, None),
        ('Segmentation Mask', mask, 'gray'),
        ('Black Background Result', processed_img, None),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, image, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(image, cmap=cmap)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

## 6. Process a Single Image

Let's test the implementation on a single image.

In [None]:
# Replace with your image path
sample_img_path = "path/to/sample_image.jpg"  # Change this to your image path

# Check if the file exists
if os.path.exists(sample_img_path):
    # Load the original image for display. load_image returns None (and prints
    # the error) for unreadable files; np.array(None) would otherwise crash
    # the plotting step.
    sample_pil = load_image(sample_img_path)
    if sample_pil is None:
        print(f"Could not load {sample_img_path}; skipping visualization.")
    else:
        original_img = np.array(sample_pil)

        # Process image
        processed_img, mask = remove_background(sample_img_path, confidence_threshold=0.7)

        # Visualize results (remove_background returns (None, None) on load failure)
        if processed_img is not None:
            visualize_results(original_img, processed_img, mask)
else:
    print(f"Image not found at {sample_img_path}. Please provide a valid image path.")

## 7. Advanced Post-Processing (Optional)

The basic Mask R-CNN segmentation might have imperfections at object boundaries. We can use GrabCut to refine the masks for better results.

In [None]:
def refine_mask_grabcut(img, mask, iterations=5, band_width=10):
    """Refine a binary object mask along its edges with OpenCV's GrabCut.

    GrabCut only relabels pixels marked "probable" (GC_PR_BGD / GC_PR_FGD).
    Pixels marked definite (GC_BGD / GC_FGD) stay fixed. Seeding the whole
    mask as definite therefore leaves GrabCut nothing to change. Instead, a
    band of ``band_width`` pixels around the mask boundary is left uncertain:

    * inside the eroded mask           -> definite foreground
    * rest of the mask                 -> probable foreground
    * outside the mask, within dilation -> probable background
    * everything else                  -> definite background

    Args:
        img: 8-bit 3-channel image with the same height and width as ``mask``.
        mask: 2-D array; non-zero pixels are treated as foreground.
        iterations: Number of GrabCut iterations.
        band_width: Radius, in pixels, of the uncertain band around the mask
            edge. ``0`` seeds everything as definite, so the mask is returned
            unchanged (binarized to 0/255).

    Returns:
        ``uint8`` mask with 255 for (probable) foreground and 0 elsewhere.
        GrabCut needs both background and foreground seed pixels to build its
        colour models. When either is missing (e.g. the mask covers the whole
        image), the binarized input mask is returned instead of calling
        GrabCut.
    """
    foreground = (mask > 0).astype(np.uint8)

    if band_width > 0:
        size = 2 * band_width + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
        sure_fg = cv2.erode(foreground, kernel)
        maybe_fg = cv2.dilate(foreground, kernel)
    else:
        sure_fg = maybe_fg = foreground

    # 0: background, 1: foreground, 2: probable background, 3: probable foreground
    grabcut_mask = np.full(foreground.shape, cv2.GC_BGD, dtype=np.uint8)
    grabcut_mask[maybe_fg > 0] = cv2.GC_PR_BGD
    grabcut_mask[foreground > 0] = cv2.GC_PR_FGD
    grabcut_mask[sure_fg > 0] = cv2.GC_FGD

    has_bg = np.any((grabcut_mask == cv2.GC_BGD) | (grabcut_mask == cv2.GC_PR_BGD))
    has_fg = np.any((grabcut_mask == cv2.GC_FGD) | (grabcut_mask == cv2.GC_PR_FGD))
    if not (has_bg and has_fg):
        return (foreground * 255).astype(np.uint8)

    # Background and foreground colour models (GMM state) filled in by GrabCut.
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)

    # The rect argument is only used in GC_INIT_WITH_RECT mode; it is ignored here.
    rect = (0, 0, img.shape[1], img.shape[0])
    cv2.grabCut(img, grabcut_mask, rect, bgdModel, fgdModel, iterations, cv2.GC_INIT_WITH_MASK)

    # Foreground = definite or probable foreground after GrabCut's relabelling.
    is_fg = (grabcut_mask == cv2.GC_FGD) | (grabcut_mask == cv2.GC_PR_FGD)
    return np.where(is_fg, 255, 0).astype(np.uint8)

In [None]:
def apply_refined_mask(img_path, confidence_threshold=0.7, use_grabcut=True, save_path=None):
    """Remove the background, optionally refining the Mask R-CNN mask with GrabCut.

    Args:
        img_path: Path to the input image.
        confidence_threshold: Minimum (exclusive) detection score, passed to
            ``remove_background``.
        use_grabcut: Refine the mask with ``refine_mask_grabcut`` when at least
            one object was detected.
        save_path: Optional output path.
            * When GrabCut runs, the refined result is written.
            * Otherwise the plain Mask R-CNN result is saved the way
              ``remove_background`` saves it, i.e. only when an object was
              detected.

    Returns:
        ``(result_rgb, mask)`` with a 0/255 ``uint8`` mask, or ``(None, None)``
        if the image cannot be loaded.
    """
    img = load_image(img_path)
    if img is None:
        return None, None
    img_np = np.array(img)

    # Hand save_path to remove_background whenever GrabCut is disabled, so the
    # basic result is still saved in that case.
    basic_save_path = None if use_grabcut else save_path
    processed_img, mask = remove_background(img_path, confidence_threshold, save_path=basic_save_path)
    if mask is None:  # the second load inside remove_background failed
        return None, None

    # Only apply GrabCut if there is a mask to refine.
    if not (use_grabcut and np.any(mask > 0)):
        return processed_img, mask

    refined_mask = refine_mask_grabcut(img_np, mask)

    # Zero all channels outside the refined mask.
    mask_3ch = cv2.merge([refined_mask, refined_mask, refined_mask])
    refined_result = cv2.bitwise_and(img_np, mask_3ch)

    # cv2.imwrite reports failure by returning False rather than raising.
    if save_path and not cv2.imwrite(save_path, cv2.cvtColor(refined_result, cv2.COLOR_RGB2BGR)):
        print(f"Warning: could not write result to {save_path}")

    return refined_result, refined_mask

In [None]:
# Test advanced processing on the same image if it exists
if os.path.exists(sample_img_path):
    # Load the original image for display; bail out cleanly if it is unreadable
    # (load_image has already printed the error).
    sample_pil = load_image(sample_img_path)
    if sample_pil is None:
        print(f"Could not load {sample_img_path}; skipping visualization.")
    else:
        original_img = np.array(sample_pil)

        # Process image with refinement
        refined_img, refined_mask = apply_refined_mask(sample_img_path, confidence_threshold=0.7, use_grabcut=True)

        # Visualize results (apply_refined_mask returns (None, None) on load failure)
        if refined_img is not None:
            visualize_results(original_img, refined_img, refined_mask)
else:
    print(f"Image not found at {sample_img_path}. Please provide a valid image path.")

## 8. Comparing Results: Basic vs. Refined

Let's compare the results of basic Mask R-CNN segmentation with the GrabCut-refined version.

In [None]:
def compare_results(img_path, confidence_threshold=0.7):
    """Plot basic Mask R-CNN background removal next to the GrabCut-refined result.

    Mask R-CNN runs only once. The refined result is derived from that same
    mask, following the steps ``apply_refined_mask(..., use_grabcut=True)``
    performs. This avoids a second, identical model inference.

    Args:
        img_path: Path to the input image.
        confidence_threshold: Minimum (exclusive) detection score.
    """
    if not os.path.exists(img_path):
        print(f"Image not found at {img_path}")
        return

    # Load original image (load_image prints the error and returns None on failure).
    pil_img = load_image(img_path)
    if pil_img is None:
        return
    original_img = np.array(pil_img)

    # Basic processing
    basic_img, basic_mask = remove_background(img_path, confidence_threshold)
    if basic_img is None:
        return

    # Refined processing, reusing the basic mask. GrabCut only runs when
    # something was detected.
    if np.any(basic_mask > 0):
        refined_mask = refine_mask_grabcut(original_img, basic_mask)
        mask_3ch = cv2.merge([refined_mask, refined_mask, refined_mask])
        refined_img = cv2.bitwise_and(original_img, mask_3ch)
    else:
        refined_img = basic_img

    # Visualize comparison in a 2x2 grid
    panels = (
        ('Original Image', original_img, None),
        ('Basic Mask R-CNN Mask', basic_mask, 'gray'),
        ('Basic Background Removal', basic_img, None),
        ('Refined Background Removal (GrabCut)', refined_img, None),
    )
    plt.figure(figsize=(15, 10))
    for position, (title, image, cmap) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.title(title)
        plt.imshow(image, cmap=cmap)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Run the basic-vs-refined comparison only when the sample image is present.
if not os.path.exists(sample_img_path):
    print(f"Image not found at {sample_img_path}. Please provide a valid image path.")
else:
    compare_results(sample_img_path, confidence_threshold=0.7)

## 9. Conclusion and Next Steps

In this notebook, we've implemented basic and refined background removal using Mask R-CNN and GrabCut algorithms. The key components include:

1. Loading a pre-trained Mask R-CNN model
2. Processing images to detect objects and generate segmentation masks
3. Converting masks to binary format for background removal
4. Applying GrabCut refinement for improved edge quality
5. Visualizing and comparing the results

For batch processing and handling dataset quality issues, please refer to Part 2 of this notebook series.