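# SAM2 + PowerPaint V2 Inpainting
State-of-the-art inpainting using PowerPaint V2, currently one of the strongest inpainting models.

> **Note:** the model-loading cell below currently loads `stabilityai/stable-diffusion-2-inpainting` through diffusers' `StableDiffusionInpaintPipeline`, not PowerPaint V2. The advantages listed below describe PowerPaint and apply only once that checkpoint is swapped in.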

PowerPaint advantages:
- Trained on high-quality datasets
- Better than LaMa for complex textures
- Excellent structure understanding
- Works great with furniture

In [None]:
# Cell 1: Install dependencies
# First, install compatible torch version
# NOTE(review): torch/torchvision/torchaudio are pinned to the CUDA 12.4 wheels.
# The later unpinned installs (e.g. xformers) may resolve to a different torch
# build and override this pin — check torch.__version__ after this cell.
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 --quiet

# Then install other dependencies
!pip install git+https://github.com/facebookresearch/sam2.git --quiet
!pip install opencv-python pillow numpy matplotlib ipywidgets ipycanvas --quiet
!pip install transformers diffusers accelerate --quiet
!pip install xformers --quiet  # Optional, will work without it
!pip install scipy scikit-image --quiet  # For smart restoration

# Colab-only: `files` provides upload/download (used in Cells 4 and 11);
# the custom widget manager lets ipycanvas/ipywidgets render in Colab.
from google.colab import files, output
output.enable_custom_widget_manager()

In [None]:
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from sam2.sam2_image_predictor import SAM2ImagePredictor
from ipycanvas import MultiCanvas
import ipywidgets as widgets
from IPython.display import display, clear_output
from io import BytesIO
from scipy import ndimage
from skimage import morphology

from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

# Check GPU
# `device` is a notebook global used by the model-loading and inpainting cells.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    # Enable optimizations
    # TF32 tensor cores exist on compute capability 8.0+ (Ampere and newer).
    # Gating on the capability instead of GPU names also covers L4/A10/H100,
    # and skips V100 (Volta, 7.0), which has no TF32 support.
    major, _minor = torch.cuda.get_device_capability(0)
    if major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
# Cell 3: Load models
print("Loading SAM2 model...")
predictor = SAM2ImagePredictor.from_pretrained(
    "facebook/sam2-hiera-large",
    mask_threshold=0.0
)

print("Loading inpainting model...")
# Using a high-quality inpainting model
# float16 is only used on CUDA; CPU inference needs float32 weights.
pipe_dtype = torch.float16 if device == 'cuda' else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=pipe_dtype,
    use_safetensors=True,
    variant="fp16"
)

if device == 'cuda':
    # Model CPU offload manages device placement itself (sub-models are moved to
    # the GPU on demand), so the pipeline is NOT moved with .to('cuda') first —
    # that would only upload the weights for offload to move them back.
    pipe.enable_model_cpu_offload()
else:
    # Offload requires an accelerator; on a CPU-only runtime just keep it on CPU.
    pipe = pipe.to(device)

# Enable memory efficient attention if available
xformers_enabled = False
try:
    pipe.enable_xformers_memory_efficient_attention()
    xformers_enabled = True
    print("✅ XFormers enabled for faster inference")
except (ImportError, ValueError, RuntimeError):
    # xformers missing or unsupported on this device: best-effort, keep defaults.
    print("⚠️ XFormers not available, using default attention")

# Attention slicing installs its own attention processor, which would replace
# the xformers processor; use it only as the fallback memory saver.
if not xformers_enabled:
    pipe.enable_attention_slicing()

print("✅ Models loaded successfully!")

In [None]:
# Cell 4: Upload image
# Produces the notebook globals `filename`, `pil_image` (RGB) and `image_np`,
# and primes the SAM2 predictor with the image.
print("Upload furniture image with background:")
uploaded = files.upload()
if not uploaded:
    # An empty dict means nothing was uploaded; fail with a clear message
    # instead of an IndexError from indexing an empty key list.
    raise RuntimeError("No image uploaded - re-run this cell and select a file.")
filename = next(iter(uploaded))

# Load image
pil_image = Image.open(filename).convert("RGB")
image_np = np.array(pil_image)

plt.figure(figsize=(10, 8))
plt.imshow(image_np)
plt.title(f"Original Image ({pil_image.size[0]}x{pil_image.size[1]})")
plt.axis('off')
plt.show()

# Set image for SAM2 (computes the image embedding used by predictor.predict)
predictor.set_image(image_np)

In [None]:
# Cell 5: Interactive point selection for SAM2
from ipywidgets import Image as IPYImage

print("🎯 Click on the MAIN BODY of the furniture (avoid pillows/cushions)")
print("   This helps identify the core furniture structure")

# Setup canvas
# Encode the uploaded image as PNG bytes so ipycanvas can draw it at native size.
w, h = pil_image.size
f = BytesIO()
pil_image.save(f, format='PNG')
f.seek(0)
image_widget = IPYImage(value=f.read(), format='png', width=w, height=h)

# Two stacked layers: layer 0 (`base`) shows the photo, layer 1 (`overlay`)
# holds the click markers, so clearing the markers leaves the photo intact.
canvases = MultiCanvas(2, width=w, height=h)
display(canvases)
base, overlay = canvases[0], canvases[1]
base.draw_image(image_widget, 0, 0)

# Point storage
# Notebook globals: appended by on_mouse_down, reset by clear_points,
# and read by Cell 6 as the SAM2 point prompts.
points = []
labels = []  # 1 for positive, 0 for negative

# UI controls
# The radio value (1 or 0) becomes the label of the next clicked point.
point_type = widgets.RadioButtons(
    options=[('Include (Green)', 1), ('Exclude (Red)', 0)],
    value=1,
    description='Point Type:'
)
clear_points_btn = widgets.Button(description='Clear Points')
point_count = widgets.Label(value='Points: 0')

display(widgets.HBox([point_type, clear_points_btn, point_count]))

def on_mouse_down(x, y):
    """Record a clicked SAM2 prompt point and mark it on the overlay layer.

    The label is taken from the ``point_type`` radio buttons (1 = include,
    0 = exclude). The point and label are appended to the global ``points``
    and ``labels`` lists, a filled circle (green for include, red for
    exclude) is drawn at the click, and the counter label is refreshed.
    """
    chosen_label = point_type.value
    points.append((x, y))
    labels.append(chosen_label)

    # Draw point
    marker_color = 'red'
    if chosen_label == 1:
        marker_color = 'lime'
    overlay.fill_style = marker_color
    overlay.fill_circle(x, y, 8)

    # Update count display
    total = len(points)
    point_count.value = f'Points: {total}'

def clear_points(b):
    """Discard all prompt points, wipe the marker layer and reset the counter.

    ``b`` is the clicked button widget and is not used. The global ``points``
    and ``labels`` are rebound to fresh empty lists.
    """
    global points, labels
    points, labels = [], []
    overlay.clear()
    point_count.value = 'Points: 0'

# Register the handlers: clicks on the marker layer add points,
# the "Clear Points" button removes them.
overlay.on_mouse_down(on_mouse_down)
clear_points_btn.on_click(clear_points)

In [None]:
# Cell 6: Generate initial SAM2 mask
# Defines the notebook globals `initial_mask` (uint8, 0/255) and `bg_removed`
# (RGB, furniture on white) used by the following cells. They are left
# undefined when no points have been clicked.
if not points:
    print("⚠️ Please click on the furniture first!")
else:
    # Convert points for SAM2
    # coords: (N, 2) array of clicked (x, y) positions; point_labels: (N,) of 1/0.
    coords = np.array(points)
    point_labels = np.array(labels)
    
    # Generate mask
    # multimask_output=False requests a single mask; masks[0] is used below.
    masks, scores, _ = predictor.predict(
        point_coords=coords,
        point_labels=point_labels,
        multimask_output=False
    )
    
    # Get binary mask
    initial_mask = (masks[0] > 0.5).astype(np.uint8) * 255
    
    # Create initial result with white background
    # The mask is used as the alpha channel, so only furniture pixels are pasted
    # onto the opaque white canvas.
    bg_removed = Image.new('RGBA', pil_image.size, (255, 255, 255, 255))
    furniture_only = pil_image.copy()
    furniture_only.putalpha(Image.fromarray(initial_mask))
    bg_removed.paste(furniture_only, (0, 0), furniture_only)
    bg_removed = bg_removed.convert('RGB')
    
    # Display
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    axes[0].imshow(image_np)
    axes[0].set_title("Original")
    axes[0].axis('off')
    
    axes[1].imshow(initial_mask, cmap='gray')
    axes[1].set_title("Initial SAM2 Mask")
    axes[1].axis('off')
    
    axes[2].imshow(bg_removed)
    axes[2].set_title("Background Removed (with holes)")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    print("Notice the white holes where pillows/cushions were removed")

In [None]:
# Cell 7: Smart furniture boundary detection
def detect_complete_furniture_boundary(image, initial_mask, close_kernel_size=50,
                                       close_iterations=2, smooth_kernel_size=20):
    """
    Estimate the complete furniture silhouette, including regions SAM2 left out.

    The SAM2 mask is morphologically closed with a large elliptical kernel to
    bridge gaps between separated parts. Interior holes are then filled, and
    the result is smoothed with a close followed by an open.

    Args:
        image: The (background-removed) image. Not used by the algorithm; kept
            so existing call sites keep working.
        initial_mask: SAM2 mask; any non-zero pixel is treated as furniture.
        close_kernel_size: Diameter of the elliptical kernel used to bridge gaps.
        close_iterations: Number of closing iterations with that kernel.
        smooth_kernel_size: Kernel diameter for the final close/open smoothing.

    Returns:
        (furniture_boundary, hull_mask, filled_mask), each a uint8 0/255 mask
        with the same height and width as ``initial_mask``. ``hull_mask`` is
        the convex hull of the largest contour and is returned for
        visualisation only.
    """
    # findContours and the bitwise/morphology ops below expect a uint8 0/255 mask.
    initial_mask = (np.asarray(initial_mask) > 0).astype(np.uint8) * 255

    # 1. Convex hull of the largest contour (overall shape, for display)
    contours, _ = cv2.findContours(initial_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        hull = cv2.convexHull(largest_contour)
        hull_mask = np.zeros_like(initial_mask)
        cv2.fillPoly(hull_mask, [hull], 255)
    else:
        hull_mask = initial_mask.copy()

    # 2. Large closing to connect separated parts
    kernel_large = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (close_kernel_size, close_kernel_size)
    )
    closed_mask = cv2.morphologyEx(
        initial_mask, cv2.MORPH_CLOSE, kernel_large, iterations=close_iterations
    )

    # 3. Fill holes inside the furniture
    filled_mask = ndimage.binary_fill_holes(closed_mask).astype(np.uint8) * 255

    # 4. Smooth the boundary. The hull is deliberately not merged in: it would
    # only be intersected with filled_mask, which cannot add any pixels.
    kernel_smooth = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (smooth_kernel_size, smooth_kernel_size)
    )
    furniture_boundary = cv2.morphologyEx(filled_mask, cv2.MORPH_CLOSE, kernel_smooth)
    furniture_boundary = cv2.morphologyEx(furniture_boundary, cv2.MORPH_OPEN, kernel_smooth)

    return furniture_boundary, hull_mask, filled_mask

# Detect complete furniture boundary
print("🔍 Detecting complete furniture boundary...")
furniture_boundary, hull_mask, filled_mask = detect_complete_furniture_boundary(bg_removed, initial_mask)

# Areas inside the estimated boundary that SAM2 did not include
inpaint_areas = cv2.bitwise_and(furniture_boundary, cv2.bitwise_not(initial_mask))

# Display analysis
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

mask_panels = [
    (axes[0, 0], initial_mask, "Initial SAM2 Mask"),
    (axes[0, 1], hull_mask, "Convex Hull"),
    (axes[0, 2], filled_mask, "Filled Mask"),
    (axes[1, 0], furniture_boundary, "Complete Furniture Boundary"),
    (axes[1, 1], inpaint_areas, "Areas to Restore"),
]
for ax, mask, title in mask_panels:
    ax.imshow(mask, cmap='gray')
    ax.set_title(title, fontsize=14)
    ax.axis('off')

# Overlay on original.
# Do NOT reuse the name `overlay`: it is the ipycanvas layer that the Cell 5
# click/clear handlers draw on, and rebinding it breaks those handlers.
overlay_preview = np.array(bg_removed)
overlay_preview[inpaint_areas > 0] = [255, 0, 0]
axes[1, 2].imshow(overlay_preview)
axes[1, 2].set_title("Areas to Restore (Red)", fontsize=14)
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

print("✅ Identified complete furniture boundary and missing areas")

In [None]:
# Cell 8: Create smart inpainting mask
def create_smart_inpaint_mask(image, initial_mask, furniture_boundary,
                              white_threshold=240, max_spot_area=500,
                              dilate_size=10):
    """
    Build the inpainting mask used by the restoration step.

    The mask is the union of:
    1. Missing furniture parts: inside ``furniture_boundary`` but not in
       ``initial_mask`` (e.g. where pillows/cushions were cut out).
    2. White artifacts: grayscale > ``white_threshold`` inside the boundary.
    3. Small bright spots: adaptive-threshold blobs inside the boundary whose
       connected component has fewer than ``max_spot_area`` pixels.
    The union is dilated with an elliptical ``dilate_size`` kernel and then
    clipped back to ``furniture_boundary``.

    Args:
        image: RGB image (PIL image or HxWx3 uint8 array).
        initial_mask: uint8 0/255 SAM2 mask.
        furniture_boundary: uint8 0/255 mask of the complete furniture.
        white_threshold: Grayscale level above which a pixel is a white artifact.
        max_spot_area: Components with at least this many pixels are not spots.
        dilate_size: Diameter of the dilation kernel applied to the union.

    Returns:
        (final_mask, missing_parts, white_artifacts, small_artifacts_clean),
        all uint8 masks with the image's height and width.
    """
    img_array = np.array(image)
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

    # 1. Missing parts mask (inside furniture boundary but not in initial mask)
    missing_parts = cv2.bitwise_and(furniture_boundary, cv2.bitwise_not(initial_mask))

    # 2. Near-white pixels within the furniture
    white_areas = (gray > white_threshold).astype(np.uint8) * 255
    white_artifacts = cv2.bitwise_and(white_areas, furniture_boundary)

    # 3. Locally bright spots: the negative C (-5) puts the threshold 5 levels
    # above the Gaussian-weighted 21x21 neighbourhood mean.
    adaptive_thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, 21, -5
    )
    small_artifacts = cv2.bitwise_and(adaptive_thresh, furniture_boundary)

    # Keep only small connected components. A single lookup through the label
    # image replaces a full-image scan per component.
    _, component_labels, stats, _ = cv2.connectedComponentsWithStats(
        small_artifacts, connectivity=8
    )
    keep = stats[:, cv2.CC_STAT_AREA] < max_spot_area
    keep[0] = False  # label 0 is the background
    small_artifacts_clean = np.where(keep[component_labels], 255, 0).astype(small_artifacts.dtype)

    # Combine all masks
    final_mask = cv2.bitwise_or(missing_parts, white_artifacts)
    final_mask = cv2.bitwise_or(final_mask, small_artifacts_clean)

    # Slight dilation so the inpainting covers the artifact edges too
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
    final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=1)

    # Ensure mask stays within furniture boundary
    final_mask = cv2.bitwise_and(final_mask, furniture_boundary)

    return final_mask, missing_parts, white_artifacts, small_artifacts_clean

# Create smart inpainting mask
print("🎨 Creating smart inpainting mask...")
inpaint_mask, missing_parts, white_artifacts, small_spots = create_smart_inpaint_mask(
    bg_removed, initial_mask, furniture_boundary
)

# Display mask components
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

mask_panels = [
    (axes[0, 0], missing_parts, "Missing Parts (Pillows/Cushions)"),
    (axes[0, 1], white_artifacts, "White Artifacts"),
    (axes[0, 2], small_spots, "Small Spots"),
    (axes[1, 0], inpaint_mask, "Final Inpainting Mask"),
]
for ax, mask, title in mask_panels:
    ax.imshow(mask, cmap='gray')
    ax.set_title(title, fontsize=14)
    ax.axis('off')

# Overlay on image.
# Do NOT reuse the name `overlay`: it is the ipycanvas layer that the Cell 5
# click/clear handlers draw on, and rebinding it breaks those handlers.
overlay_preview = np.array(bg_removed)
overlay_preview[inpaint_mask > 0] = [255, 0, 0]
axes[1, 1].imshow(overlay_preview)
axes[1, 1].set_title("Areas to Inpaint (Red)", fontsize=14)
axes[1, 1].axis('off')

# Show furniture with complete boundary (original pixels inside it, white outside)
complete_furniture = Image.new('RGBA', pil_image.size, (255, 255, 255, 255))
furniture_complete = pil_image.copy()
furniture_complete.putalpha(Image.fromarray(furniture_boundary))
complete_furniture.paste(furniture_complete, (0, 0), furniture_complete)
complete_furniture = complete_furniture.convert('RGB')
axes[1, 2].imshow(complete_furniture)
axes[1, 2].set_title("Expected Result Preview", fontsize=14)
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

# Calculate statistics
total_pixels = np.sum(furniture_boundary > 0)
inpaint_pixels = np.sum(inpaint_mask > 0)
percentage = (inpaint_pixels / total_pixels) * 100 if total_pixels > 0 else 0

print(f"📊 Need to inpaint {percentage:.1f}% of the furniture")

# Also analyze furniture for better prompts
def analyze_furniture_context(image, mask, exclude_mask=None):
    """
    Derive coarse colour and texture words for the inpainting prompt.

    Args:
        image: RGB image (PIL image or HxWx3 uint8 array).
        mask: uint8 mask; non-zero pixels are treated as furniture.
        exclude_mask: Optional uint8 mask of pixels to ignore when scoring
            texture (e.g. the areas about to be inpainted). When omitted, the
            notebook-global ``inpaint_mask`` is used if it exists (the original
            behaviour); otherwise nothing is excluded.

    Returns:
        (color, texture): short description strings.
    """
    img_array = np.array(image)
    mask_array = np.array(mask)
    furniture = mask_array > 0

    if exclude_mask is None:
        exclude_mask = globals().get('inpaint_mask')
    excluded = np.zeros_like(furniture) if exclude_mask is None else (np.asarray(exclude_mask) > 0)

    # Dominant color analysis
    furniture_pixels = img_array[furniture]
    if furniture_pixels.size == 0:
        color = "neutral"
    else:
        avg_color = np.mean(furniture_pixels, axis=0)
        # The gray test must come first: for a near-gray average one channel is
        # almost always marginally largest, so testing warm/cool first made the
        # "neutral gray" branch practically unreachable.
        if np.std(avg_color) < 20:
            color = "neutral gray"
        elif avg_color[0] > avg_color[1] and avg_color[0] > avg_color[2]:
            color = "warm toned"
        elif avg_color[2] > avg_color[0] and avg_color[2] > avg_color[1]:
            color = "cool toned"
        else:
            color = "neutral"

    # Texture analysis using Sobel gradient magnitude
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    gradient_mag = np.sqrt(sobelx**2 + sobely**2)

    # Score texture on furniture pixels that are not scheduled for inpainting;
    # fall back to all furniture pixels, then to 0, to avoid a NaN mean.
    non_artifact = furniture & ~excluded
    if np.any(non_artifact):
        texture_score = np.mean(gradient_mag[non_artifact])
    elif np.any(furniture):
        texture_score = np.mean(gradient_mag[furniture])
    else:
        texture_score = 0.0

    if texture_score > 50:
        texture = "highly textured, tufted"
    elif texture_score > 30:
        texture = "moderately textured"
    else:
        texture = "smooth"

    return color, texture

# Analyze furniture
# Coarse colour/texture words computed over the complete furniture boundary.
color_desc, texture_desc = analyze_furniture_context(bg_removed, furniture_boundary)

# Generate smart prompt
# `prompt` and `negative_prompt` are notebook globals reused by Cell 9 and by
# the "Try Different Seed" handler in Cell 10.
prompt = f"{color_desc} furniture surface, {texture_desc}, continuous material, no gaps or holes, professional product photo"
negative_prompt = "white spots, holes, damaged areas, missing parts, pillows, cushions, objects"

print(f"\n🎨 Auto-generated prompt: {prompt}")
print(f"🚫 Negative prompt: {negative_prompt}")

In [None]:
# Cell 9: Run smart inpainting with two-pass approach
# Reads `pil_image`, `inpaint_mask`, `furniture_boundary`, `prompt`,
# `negative_prompt`, `pipe` and `device`. Defines `final_result` plus the
# size/sampler globals (`width`, `height`, `original_size`, `guidance_scale`,
# `num_inference_steps`, `strength`) that Cell 10 reuses.
print("🔄 Running smart furniture restoration with the inpainting pipeline...")


def _composite_on_white(rgb_image, alpha_mask):
    """Return rgb_image pasted onto an opaque white canvas where alpha_mask > 0.

    Args:
        rgb_image: PIL image with the same size as alpha_mask.
        alpha_mask: uint8 0/255 mask used as the paste alpha.
    """
    canvas = Image.new('RGBA', rgb_image.size, (255, 255, 255, 255))
    rgba = rgb_image.convert('RGBA')
    rgba.putalpha(Image.fromarray(alpha_mask))
    canvas.paste(rgba, (0, 0), rgba)
    return canvas.convert('RGB')


# First pass: Restore missing parts using original image context
print("\nPass 1: Restoring from original image...")

# Use the original image for better context
image_for_inpaint = pil_image  # Use original instead of bg_removed
mask_for_inpaint = Image.fromarray(inpaint_mask)

# Get original dimensions
original_size = image_for_inpaint.size

# Make dimensions divisible by 8
width = (original_size[0] // 8) * 8
height = (original_size[1] // 8) * 8

# Resize if needed. Masks use NEAREST so they stay strictly binary; LANCZOS
# would add intermediate gray levels and ringing along the mask edge.
if (width, height) != original_size:
    image_resized = image_for_inpaint.resize((width, height), Image.LANCZOS)
    mask_resized = mask_for_inpaint.resize((width, height), Image.NEAREST)
else:
    image_resized = image_for_inpaint
    mask_resized = mask_for_inpaint

# Run first pass inpainting
guidance_scale = 7.5
num_inference_steps = 50
strength = 0.99

result_pass1 = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image_resized,
    mask_image=mask_resized,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    strength=strength,
    generator=torch.Generator(device).manual_seed(42)
).images[0]

# Resize back if needed
if (width, height) != original_size:
    result_pass1 = result_pass1.resize(original_size, Image.LANCZOS)

# Apply the complete furniture mask to remove background
restored_furniture = _composite_on_white(result_pass1, furniture_boundary)

print("✅ First pass complete")

# Second pass: Clean up any remaining artifacts
print("\nPass 2: Cleaning remaining artifacts...")

# Detect any remaining near-white spots inside the furniture
gray_restored = cv2.cvtColor(np.array(restored_furniture), cv2.COLOR_RGB2GRAY)
remaining_white = (gray_restored > 245).astype(np.uint8) * 255
remaining_artifacts = cv2.bitwise_and(remaining_white, furniture_boundary)

if np.any(remaining_artifacts):
    print(f"Found {np.count_nonzero(remaining_artifacts)} pixels of remaining artifacts")

    # Dilate slightly
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    remaining_artifacts = cv2.dilate(remaining_artifacts, kernel, iterations=1)

    # Run second inpainting
    remaining_mask = Image.fromarray(remaining_artifacts)

    # Resize if needed (binary mask -> NEAREST)
    if (width, height) != original_size:
        restored_resized = restored_furniture.resize((width, height), Image.LANCZOS)
        remaining_mask_resized = remaining_mask.resize((width, height), Image.NEAREST)
    else:
        restored_resized = restored_furniture
        remaining_mask_resized = remaining_mask

    final_result = pipe(
        prompt=prompt + ", perfect quality, no artifacts",
        negative_prompt=negative_prompt,
        image=restored_resized,
        mask_image=remaining_mask_resized,
        guidance_scale=guidance_scale,
        num_inference_steps=30,
        strength=0.8,
        generator=torch.Generator(device).manual_seed(123)
    ).images[0]

    # Resize back if needed
    if (width, height) != original_size:
        final_result = final_result.resize(original_size, Image.LANCZOS)

    # The dilated pass-2 mask is not clipped to furniture_boundary, and pass 1
    # already needed this step to restore the white background; composite the
    # same way here so the final image keeps a clean white background.
    final_result = _composite_on_white(final_result, furniture_boundary)
else:
    final_result = restored_furniture
    print("No additional artifacts found")

print("✅ Smart restoration complete!")

# Display results
fig, axes = plt.subplots(2, 3, figsize=(20, 15))

# Row 1: Process
axes[0, 0].imshow(pil_image)
axes[0, 0].set_title("1. Original Image", fontsize=16)
axes[0, 0].axis('off')

axes[0, 1].imshow(bg_removed)
axes[0, 1].set_title("2. SAM2 Background Removed\n(with holes)", fontsize=16)
axes[0, 1].axis('off')

axes[0, 2].imshow(inpaint_mask, cmap='gray')
axes[0, 2].set_title("3. Smart Inpaint Mask", fontsize=16)
axes[0, 2].axis('off')

# Row 2: Results
axes[1, 0].imshow(restored_furniture)
axes[1, 0].set_title("4. After First Pass", fontsize=16)
axes[1, 0].axis('off')

axes[1, 1].imshow(final_result)
axes[1, 1].set_title("5. Final Result", fontsize=16)
axes[1, 1].axis('off')

# Comparison
comparison = Image.new('RGB', (bg_removed.width * 2, bg_removed.height))
comparison.paste(bg_removed, (0, 0))
comparison.paste(final_result, (bg_removed.width, 0))
axes[1, 2].imshow(comparison)
axes[1, 2].set_title("Before | After", fontsize=16)
axes[1, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Cell 10: Fine-tune results (optional)
# Manual touch-up options
# Each handler below reassigns the global `final_result` from its current
# value, so repeated clicks compound on each other.
refine_btn = widgets.Button(description='Refine Edges', button_style='primary')
smooth_btn = widgets.Button(description='Smooth Texture', button_style='info')
enhance_btn = widgets.Button(description='Enhance Details', button_style='success')
regenerate_btn = widgets.Button(description='Try Different Seed', button_style='warning')

display(widgets.HBox([refine_btn, smooth_btn, enhance_btn, regenerate_btn]))

# Seed for "Try Different Seed"; incremented by 100 before each regeneration.
current_seed = 42

def refine_edges(b):
    """Soften the global ``final_result`` with edge-preserving smoothing and show it.

    A bilateral-filtered copy (d=9, sigmaColor=75, sigmaSpace=75) is mixed
    back in at 70% weight. ``b`` is the clicked button and is not used.
    """
    global final_result
    print("Refining edges...")

    alpha = 0.7
    source = np.array(final_result)
    filtered = cv2.bilateralFilter(source, 9, 75, 75)
    final_result = Image.fromarray(
        cv2.addWeighted(source, 1 - alpha, filtered, alpha, 0)
    )

    plt.figure(figsize=(10, 8))
    plt.imshow(final_result)
    plt.title("Refined Edges")
    plt.axis('off')
    plt.show()

    print("✅ Edges refined")

def smooth_texture(b):
    """Blur the global ``final_result`` only where ``inpaint_mask`` is set, then show it.

    A 3x3 Gaussian-blurred copy is blended in with per-pixel weight
    ``inpaint_mask / 255``, so untouched regions keep their original pixels.
    ``b`` is the clicked button and is not used.
    """
    global final_result
    print("Smoothing texture...")

    pixels = np.array(final_result)
    blurred = cv2.GaussianBlur(pixels, (3, 3), 0)

    # Replicate the single-channel mask across RGB and scale it to [0, 1].
    weight = np.repeat(inpaint_mask[:, :, np.newaxis], 3, axis=2) / 255.0
    mixed = pixels * (1 - weight) + blurred * weight
    final_result = Image.fromarray(mixed.astype(np.uint8))

    plt.figure(figsize=(10, 8))
    plt.imshow(final_result)
    plt.title("Smoothed Texture")
    plt.axis('off')
    plt.show()

    print("✅ Texture smoothed")

def enhance_details(b):
    """Sharpen the global ``final_result`` with an unsharp mask and show it.

    The result is ``1.5 * image - 0.5 * blurred`` where ``blurred`` is a
    Gaussian blur with sigma 2.0. ``b`` is the clicked button and is not used.
    """
    global final_result
    print("Enhancing details...")

    pixels = np.array(final_result)
    sharpened = cv2.addWeighted(
        pixels, 1.5,
        cv2.GaussianBlur(pixels, (0, 0), 2.0), -0.5,
        0,
    )
    final_result = Image.fromarray(sharpened)

    plt.figure(figsize=(10, 8))
    plt.imshow(final_result)
    plt.title("Enhanced Details")
    plt.axis('off')
    plt.show()

    print("✅ Details enhanced")

def regenerate_with_new_seed(b):
    """Re-run the first inpainting pass with a new seed and show before/after.

    Bumps the global ``current_seed`` by 100 and inpaints ``pil_image`` with
    ``inpaint_mask`` using the Cell 9 prompt and sampler settings. The output
    is composited onto white inside ``furniture_boundary`` and stored in the
    global ``final_result``. ``b`` is the clicked button and is not used.
    """
    global final_result, current_seed
    current_seed += 100
    print(f"Regenerating with new seed: {current_seed}")

    # Re-run inpainting with new seed
    image_for_inpaint = pil_image
    mask_for_inpaint = Image.fromarray(inpaint_mask)

    # Resize if needed. The mask uses NEAREST so it stays strictly binary.
    if (width, height) != original_size:
        image_resized = image_for_inpaint.resize((width, height), Image.LANCZOS)
        mask_resized = mask_for_inpaint.resize((width, height), Image.NEAREST)
    else:
        image_resized = image_for_inpaint
        mask_resized = mask_for_inpaint

    new_result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image_resized,
        mask_image=mask_resized,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        strength=strength,
        generator=torch.Generator(device).manual_seed(current_seed)
    ).images[0]

    # Resize back and apply mask
    if (width, height) != original_size:
        new_result = new_result.resize(original_size, Image.LANCZOS)

    # Apply furniture mask: keep generated pixels inside the boundary, white outside
    restored = Image.new('RGBA', new_result.size, (255, 255, 255, 255))
    result_rgba = new_result.convert('RGBA')
    result_rgba.putalpha(Image.fromarray(furniture_boundary))
    restored.paste(result_rgba, (0, 0), result_rgba)
    final_result = restored.convert('RGB')

    # Display comparison
    fig, axes = plt.subplots(1, 2, figsize=(15, 8))
    axes[0].imshow(bg_removed)
    axes[0].set_title("Before")
    axes[0].axis('off')
    axes[1].imshow(final_result)
    axes[1].set_title(f"After (seed: {current_seed})")
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()

    print("✅ Regenerated with new seed")

# Wire each touch-up button to its handler.
refine_btn.on_click(refine_edges)
smooth_btn.on_click(smooth_texture)
enhance_btn.on_click(enhance_details)
regenerate_btn.on_click(regenerate_with_new_seed)

In [None]:
# Cell 11: Save final results
# Writes three files next to the uploaded image and downloads them:
# <name>_restored.png, <name>_transparent.png, <name>_smart_restoration.jpg
print(f"Original size: {pil_image.size}")
print(f"Result size: {final_result.size}")

base_name = filename.rsplit('.', 1)[0]

# Save restored furniture (PNG is lossless; the JPEG `quality` option does not apply)
restored_name = f"{base_name}_restored.png"
final_result.save(restored_name)
print(f"✅ Saved: {restored_name}")

# Save with transparent background: only pixels inside furniture_boundary are kept
transparent = Image.new('RGBA', final_result.size, (0, 0, 0, 0))
final_rgba = final_result.convert('RGBA')
transparent.paste(final_rgba, (0, 0), Image.fromarray(furniture_boundary))
transparent_name = f"{base_name}_transparent.png"
transparent.save(transparent_name)
print(f"✅ Saved: {transparent_name}")

# Create detailed comparison: original | SAM2 cut-out | restored
comparison = Image.new('RGB', (pil_image.width * 3, pil_image.height))
comparison.paste(pil_image, (0, 0))
comparison.paste(bg_removed, (pil_image.width, 0))
comparison.paste(final_result, (pil_image.width * 2, 0))

# Add labels
from PIL import ImageDraw, ImageFont
draw = ImageDraw.Draw(comparison)
try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", 50)
except OSError:
    # Font file not available on this runtime: fall back to PIL's default font.
    font = ImageFont.load_default()

# Distinct name on purpose: Cell 6 reads the global `labels` (SAM2 point
# labels), so overwriting it here would corrupt a re-run of that cell.
panel_titles = ["Original", "SAM2 Removed BG", "Smart Restored"]
for i, title in enumerate(panel_titles):
    x = i * pil_image.width + 30
    y = 30
    # Add shadow
    draw.text((x+2, y+2), title, fill='black', font=font)
    draw.text((x, y), title, fill='white', font=font)

comparison_name = f"{base_name}_smart_restoration.jpg"
comparison.save(comparison_name, quality=95)
print(f"✅ Saved: {comparison_name}")

# Download
files.download(restored_name)
files.download(transparent_name)
files.download(comparison_name)

# Final display
plt.figure(figsize=(20, 8))
plt.imshow(comparison)
plt.title("Smart Furniture Restoration Complete", fontsize=20)
plt.axis('off')
plt.show()

print("\n✅ Smart furniture restoration complete!")
print("   - Restored missing pillows/cushions")
print("   - Removed white artifacts and spots")
print("   - Preserved furniture structure")