# Room Object Removal Pipeline
Remove furniture and other objects from room photos to produce empty-room images, using YOLO (object detection), SAM (segmentation), and LaMa (inpainting).

In [None]:
# Cell 1: Setup - Install packages and download models
!pip install ultralytics>=8.0.0 opencv-python>=4.8.0 segment-anything lama-cleaner
!pip install torch torchvision numpy pillow matplotlib
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
# Cell 2: Import libraries and configure parameters
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from google.colab import files

from ultralytics import YOLO
from segment_anything import sam_model_registry, SamPredictor
from lama_cleaner.model.lama import LaMa
from lama_cleaner.schema import Config, HDStrategy

# Compute device: use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# ===== CONFIGURATION PARAMETERS =====
# Tune the values below to adjust the results.

# --- Detection (used in Cell 4) ---
YOLO_MODEL = 'yolov8l.pt'  # Options: yolov8n, yolov8s, yolov8m, yolov8l, yolov8x
# Minimum confidence passed to YOLO.
# NOTE(review): 0.002 is far below Ultralytics' default (0.25), so many weak boxes
# survive and the heuristics in Cell 4 become the main filter — confirm intended.
CONFIDENCE_THRESHOLD = 0.002

# --- Objects to remove: class names as reported by YOLO (`r.names`) ---
REMOVE_CLASSES = [
    # people and furniture
    'person', 'chair', 'couch', 'bed', 'dining table', 'toilet',
    # electronics and small items
    'tv', 'laptop', 'mouse', 'keyboard', 'cell phone', 'book',
    'clock', 'vase', 'teddy bear', 'potted plant', 'suitcase',
    'handbag', 'backpack', 'umbrella', 'bottle', 'cup', 'bowl',
    # appliances and fixtures
    'refrigerator', 'oven', 'microwave', 'toaster', 'sink',
]

# --- Mask processing (used in Cell 5) ---
# Side length (pixels) of the elliptical structuring element, and how many
# dilation passes grow the SAM mask before inpainting.
DILATE_KERNEL_SIZE = 20
DILATE_ITERATIONS = 7

# --- Color correction (used in Cell 7) ---
# Multipliers on the per-channel LAB mean shift: 1 = full shift, 0 = no correction.
LIGHTNESS_CORRECTION = 1  # L channel
COLOR_CORRECTION = 1      # a and b channels

In [None]:
# Cell 3: Upload and load image
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and choose an image.")
# Only the first uploaded file is used.
image_path = next(iter(uploaded))

# Load image. cv2.imread yields BGR data and returns None (instead of raising)
# when the file is missing or cannot be decoded, e.g. an unsupported format.
image = cv2.imread(image_path)
if image is None:
    raise ValueError(f"Could not read '{image_path}' as an image (unsupported format?).")
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 8))
plt.imshow(image_rgb)
plt.title("Original Image")
plt.axis('off')
plt.show()

print(f"Image size: {image_rgb.shape[1]} x {image_rgb.shape[0]}")

In [None]:
# Cell 4: Detect objects with smart filtering
from collections import Counter

model = YOLO(YOLO_MODEL)

# Ultralytics interprets NumPy array inputs as BGR (OpenCV channel order). image_rgb
# is RGB, so hand the detector a BGR copy; passing it as-is swaps red and blue.
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
results = model(image_bgr, conf=CONFIDENCE_THRESHOLD)

# Image size is identical for every box; compute it once.
h, w = image_rgb.shape[:2]


def _exclusion_reason(class_name, confidence, bbox, img_h, img_w):
    """Return why a detection should be KEPT in the image, or "" to remove it.

    Args:
        class_name: YOLO class label of the detection.
        confidence: detector confidence score.
        bbox: (x1, y1, x2, y2) box in pixel coordinates.
        img_h, img_w: image height and width in pixels.

    The rules protect regions YOLO tends to confuse with architecture (windows
    detected as TVs/sinks/ovens, doors or walls as refrigerators, floors as
    beds/tables), based on box position, shape, relative area and confidence.
    """
    x1, y1, x2, y2 = bbox
    obj_height = y2 - y1
    obj_width = x2 - x1
    aspect_ratio = obj_height / obj_width if obj_width > 0 else 0  # > 1 means tall
    area_ratio = (obj_height * obj_width) / (img_h * img_w)
    center_y = (y1 + y2) / 2

    if class_name == 'tv':  # TV → window
        if y1 < img_h * 0.5 and confidence < 0.5:  # Upper half + low confidence
            return "Likely window (position)"
        if aspect_ratio < 0.5 and y1 < img_h * 0.6:  # Wide and high up
            return "Window-like dimensions"
    elif class_name == 'refrigerator':  # Refrigerator → door/wall
        if aspect_ratio > 2.5:  # Very tall and narrow
            return "Door-like dimensions"
        if area_ratio > 0.3:  # Takes up too much space
            return "Wall-sized"
        # NOTE(review): absolute 50 px threshold, left edge only — not resolution-independent.
        if confidence < 0.3 and x1 < 50:
            return "Likely wall/door"
    elif class_name in ('oven', 'microwave'):  # Often architectural
        if confidence < 0.4:
            return "Low confidence architectural"
        if y1 < img_h * 0.3:  # Too high for appliances
            return "Position suggests window"
    elif class_name == 'sink':  # Sometimes windows
        if y1 < img_h * 0.4 and confidence < 0.5:
            return "High position suggests window"
    elif class_name == 'bed':  # Sometimes floors
        if area_ratio > 0.5:  # Covers most of image
            return "Likely floor misdetection"
    elif class_name == 'dining table':  # Sometimes floors
        if area_ratio > 0.4 and center_y > img_h * 0.7:
            return "Likely floor pattern"
    return ""


# Smart filtering to avoid removing architectural features
remove_set = set(REMOVE_CLASSES)
detections = []
excluded = []

for r in results:
    for box in r.boxes:
        class_name = r.names[int(box.cls)]
        if class_name not in remove_set:
            continue
        confidence = float(box.conf)
        bbox = box.xyxy[0].cpu().numpy()
        reason = _exclusion_reason(class_name, confidence, bbox, h, w)
        if reason:
            excluded.append({
                'bbox': bbox,
                'class': class_name,
                'conf': confidence,
                'reason': reason
            })
        else:
            detections.append({
                'bbox': bbox,
                'class': class_name,
                'conf': confidence
            })

print(f"\n📊 Detection Summary:")
print(f"  ✓ Objects to remove: {len(detections)}")
print(f"  ✗ Protected areas: {len(excluded)}")

# Visualize. plot() draws on the frame the model was given (BGR here) and returns a
# BGR array, so the OpenCV colors below are BGR: (0, 255, 255) is yellow.
annotated = results[0].plot()

# Draw excluded items in yellow with reason
for exc in excluded:
    x1, y1, x2, y2 = exc['bbox'].astype(int)
    cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 255, 255), 3)
    label = f"KEEP: {exc['reason']}"
    # Clamp so labels of boxes touching the top edge stay inside the frame.
    cv2.putText(annotated, label, (x1, max(y1 - 10, 15)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

plt.figure(figsize=(12, 8))
plt.imshow(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB
plt.title("Smart Object Detection (Yellow = Protected from removal)")
plt.axis('off')
plt.show()

# Detailed report
if excluded:
    print("\n🛡️ Protected from removal:")
    for e in excluded:
        print(f"  - {e['class']} ({e['conf']:.2f}) → {e['reason']}")

if detections:
    print("\n🗑️ Will remove:")
    object_counts = Counter(d['class'] for d in detections)
    for obj_class, count in sorted(object_counts.items()):
        print(f"  - {count}x {obj_class}")

In [None]:
# Cell 5: Create precise masks using SAM
# Load the ViT-B SAM checkpoint fetched in Cell 1 and embed the image once; every
# box prompt below reuses that embedding.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device=device)
predictor = SamPredictor(sam)
predictor.set_image(image_rgb)

# Union of all per-object masks, stored as 0/1 uint8.
h, w = image_rgb.shape[:2]
combined_mask = np.zeros((h, w), dtype=np.uint8)

for det in detections:
    prompt_box = det['bbox'].astype(int)[:4]
    candidates, _, _ = predictor.predict(box=prompt_box, multimask_output=True)
    # Keep whichever candidate covers the most pixels (not the highest-scoring one).
    pixel_counts = candidates.reshape(candidates.shape[0], -1).sum(axis=1)
    combined_mask |= candidates[pixel_counts.argmax()]

# Scale to 0/255, then grow the mask so the inpainted region extends past each
# object's outline.
combined_mask = combined_mask * 255
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (DILATE_KERNEL_SIZE,) * 2)
dilated_mask = cv2.dilate(combined_mask, kernel, iterations=DILATE_ITERATIONS)

# Display
plt.figure(figsize=(10, 8))
plt.imshow(dilated_mask, cmap='gray')
plt.title("Objects to Remove (Mask)")
plt.axis('off')
plt.show()

In [None]:
# Cell 6: Quick preview with OpenCV inpainting
# Fast Navier-Stokes inpainting (radius 5 px) as a sanity check of the mask before LaMa.
inpainted = cv2.inpaint(image_rgb, dilated_mask, inpaintRadius=5, flags=cv2.INPAINT_NS)

plt.figure(figsize=(15, 5))
preview_panels = [
    (image_rgb, "Original"),
    (inpainted, "OpenCV Inpainting (Quick Preview)"),
]
for panel_idx, (panel_img, panel_title) in enumerate(preview_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(panel_img)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

In [None]:
# Cell 7: High-quality inpainting with LaMa
print("Loading LaMa model...")
lama_model = LaMa(device)

# LaMa configuration (field meanings: see lama-cleaner's Config / HDStrategy)
config = Config(
    ldm_steps=1,
    ldm_sampler='plms',
    hd_strategy=HDStrategy.RESIZE,
    hd_strategy_crop_margin=32,
    hd_strategy_crop_trigger_size=800,
    hd_strategy_resize_limit=1024,
)

# Ensure a strictly binary 0/255 mask (255 = region to inpaint)
binary_mask = (dilated_mask > 127).astype(np.uint8) * 255

# Run LaMa
print("Running LaMa inpainting...")
lama_result = lama_model(image_rgb, binary_mask, config)

# Normalize the output to uint8. Round before casting so values like 254.9999 are
# not truncated down; clip guards against out-of-range floats.
if lama_result.dtype != np.uint8:
    if lama_result.max() <= 1.0:  # heuristic: output looks [0, 1]-scaled
        lama_result = lama_result * 255
    lama_result = np.clip(np.rint(lama_result), 0, 255).astype(np.uint8)

# lama-cleaner models take an RGB image but return a BGR one ("return: BGR IMAGE"
# on InpaintModel.__call__). Convert so the display, the LAB conversion below and
# the RGB->BGR save in Cell 8 all operate on true RGB data.
lama_result = cv2.cvtColor(lama_result, cv2.COLOR_BGR2RGB)

print(f"LaMa result: dtype={lama_result.dtype}, range=[{lama_result.min()}, {lama_result.max()}]")

# Color correction: shift each LAB channel of the LaMa output toward the mean of the
# original image's un-inpainted pixels, scaled by the per-channel strength knobs.
lama_lab = cv2.cvtColor(lama_result, cv2.COLOR_RGB2LAB).astype(float)
original_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB).astype(float)

known_pixels = binary_mask == 0  # pixels outside the inpainting mask
channel_strengths = (LIGHTNESS_CORRECTION, COLOR_CORRECTION, COLOR_CORRECTION)  # L, a, b
if known_pixels.any():
    for channel, strength in enumerate(channel_strengths):
        # NOTE(review): the LaMa mean spans the whole image and the shift is applied
        # to every pixel, so original (non-inpainted) pixels are shifted as well.
        # Restrict both to the masked region if that is not intended.
        shift = original_lab[..., channel][known_pixels].mean() - lama_lab[..., channel].mean()
        lama_lab[..., channel] += shift * strength
else:
    # With no un-inpainted reference pixels the reference mean would be NaN.
    print("Mask covers the entire image; skipping color correction.")

color_corrected = cv2.cvtColor(
    np.clip(np.rint(lama_lab), 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB
)

# Display results
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(image_rgb)
plt.title("Original", fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(lama_result)
plt.title("LaMa Result", fontsize=14)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(color_corrected)
plt.title("Color Corrected", fontsize=14)
plt.axis('off')

plt.tight_layout()
plt.show()

best_result = color_corrected
print("First pass complete!")

In [None]:
# Cell 8: Save the cleaned room image
# best_result is RGB; cv2.imwrite expects BGR and signals failure through its
# boolean return value instead of raising, so check it explicitly.
output_path = 'empty_room.jpg'
if not cv2.imwrite(output_path, cv2.cvtColor(best_result, cv2.COLOR_RGB2BGR)):
    raise OSError(f"Failed to write image to {output_path}")

# Download the result
files.download(output_path)

print(f"Empty room image saved as: {output_path}")