# SDXL Inpainting with Google Drive Models

Uses pre-downloaded models from your Drive - no downloading needed!

In [None]:
# Imports for this cell: Colab's Drive helper and filesystem access.
import os
from google.colab import drive

# Make Google Drive visible under /content/drive.
drive.mount('/content/drive')

# Report whether the model directory is present on the mounted Drive and,
# if it is, list what it contains.
model_path = '/content/drive/MyDrive/inpainting_models/'
if not os.path.exists(model_path):
    print("❌ Models not found. Please run download_inpainting_models.py locally first.")
else:
    print("✅ Models found in Drive")
    print("Available models:", os.listdir(model_path))

In [None]:
# Install minimal dependencies
# `diffusers` and `pillow` (imported as PIL) are imported directly by the
# model-loading cell; the other packages are presumably runtime requirements
# of the diffusers pipeline (TODO confirm against the diffusers install docs).
# NOTE(review): xformers is not in this list, although the model-loading cell
# calls pipe.enable_xformers_memory_efficient_attention() - confirm it is
# available in the runtime.
!pip install diffusers transformers accelerate safetensors pillow

In [None]:
# Load SDXL from your Drive
from diffusers import AutoPipelineForInpainting
import torch
from PIL import Image
import numpy as np

# Load from Drive (no downloading!)
# The pipeline is intentionally NOT moved to the GPU with .to("cuda") here:
# enable_model_cpu_offload() below takes over device placement, and moving the
# whole pipeline to CUDA first only raises peak GPU memory before it is moved
# back to the CPU again.
pipe = AutoPipelineForInpainting.from_pretrained(
    "/content/drive/MyDrive/inpainting_models/sdxl_inpainting",
    torch_dtype=torch.float16,
    variant="fp16",
    local_files_only=True,  # Don't download, use local files
)

# Optional optimisation: xformers attention. xformers is not installed by the
# dependency cell, so treat it as best-effort and keep the default attention
# implementation when it cannot be enabled, instead of aborting the notebook.
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as exc:
    print(f"⚠️ xformers attention not enabled ({exc}); using default attention.")

# Reduce GPU memory use: sub-models are moved to the GPU only while they run.
pipe.enable_model_cpu_offload()

print("✅ SDXL model loaded from Drive!")

In [None]:
# Upload your exported files
from google.colab import files
import zipfile

print("Upload the zip file from room_removal_ultimate.py export")
uploaded = files.upload()

# `uploaded` maps uploaded filenames to their contents. If nothing was
# uploaded (for example the dialog was dismissed) fail with a clear message
# instead of an IndexError on the first-key lookup.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded. Re-run this cell and select the exported zip file."
    )

# Extract the first uploaded file into the current working directory.
zip_name = next(iter(uploaded))
with zipfile.ZipFile(zip_name, 'r') as zip_ref:
    zip_ref.extractall('.')

# Load images: the archive is expected to contain image.png and mask.png.
# `Image` is PIL.Image, imported in the model-loading cell.
image = Image.open("image.png").convert("RGB")
mask = Image.open("mask.png").convert("L")

print(f"Image size: {image.size}")
print(f"Mask size: {mask.size}")

In [None]:
# Advanced inpainting with multiple passes for best quality
def high_quality_inpaint(image, mask, prompt=None, num_passes=2):
    """
    Multi-pass inpainting for highest quality.

    Runs the module-level ``pipe`` inpainting pipeline ``num_passes`` times,
    feeding each pass's output back in as the next pass's input image while
    reusing the same mask.

    Args:
        image: Input picture; must expose ``width``, ``height`` and ``size``
            (a PIL image in this notebook).
        mask: Mask passed to the pipeline as ``mask_image`` on every pass.
        prompt: Text prompt. When ``None`` a default "empty room" prompt
            is used.
        num_passes: Number of pipeline invocations. With ``0`` (or a negative
            value) the pipeline is never called and ``image`` is returned
            unchanged.

    Returns:
        The image produced by the final pass, at the same size as ``image``.
    """
    # Auto-generate prompt if not provided
    if prompt is None:
        prompt = "high quality interior, empty room, professional photography, clean walls and floor"

    # The SDXL inpainting pipeline rejects heights/widths that are not
    # divisible by 8, so request the largest multiple of 8 that fits inside
    # the input. The pipeline resizes image and mask to the requested size.
    width = max(8, image.width - image.width % 8)
    height = max(8, image.height - image.height % 8)

    result = image
    for pass_num in range(num_passes):
        print(f"Pass {pass_num + 1}/{num_passes}...")

        # First pass repaints the masked area almost from scratch; later
        # passes use a lower strength so they refine the previous output.
        strength = 0.99 if pass_num == 0 else 0.85

        result = pipe(
            prompt=prompt,
            negative_prompt="furniture, objects, people, artifacts, blurry, distorted",
            image=result,
            mask_image=mask,
            num_inference_steps=50,
            strength=strength,
            guidance_scale=8.0,
            height=height,
            width=width,
        ).images[0]

    # If the working size had to be rounded down, scale the output back so it
    # lines up pixel-for-pixel with the original image and mask.
    if result.size != image.size:
        result = result.resize(image.size)

    return result

# Run high-quality inpainting
# Uses the `image` and `mask` loaded in the upload cell, with the default
# prompt and pass count. `result` is read by the display/save cell below.
result = high_quality_inpaint(image, mask)
print("✅ Inpainting complete!")

In [None]:
# Show original, mask and inpainted result side by side, then save and
# download the result.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# One entry per subplot: (picture, title, extra imshow keyword arguments).
panels = [
    (image, 'Original', {}),
    (mask, 'Mask', {'cmap': 'gray'}),
    (result, 'SDXL Inpainted (High Quality)', {}),
]

for panel_axis, (picture, panel_title, imshow_kwargs) in zip(axes, panels):
    panel_axis.imshow(picture, **imshow_kwargs)
    panel_axis.set_title(panel_title, fontsize=14)
    panel_axis.axis('off')

plt.tight_layout()
plt.show()

# Write the result to disk and trigger a browser download via Colab's
# `files` helper (imported in the upload cell).
result.save("result_sdxl_hq.png")
files.download("result_sdxl_hq.png")

In [None]:
# Alternative: Batch processing for multiple images
def batch_process(zip_file_path):
    """
    Process multiple image/mask pairs.

    Extracts ``zip_file_path`` into ``batch/``, then inpaints every
    ``image*.png`` that has a matching ``mask*.png`` (same name with the
    leading ``image`` replaced by ``mask``) using ``high_quality_inpaint``.
    Outputs are written to ``results/`` under the image's filename.

    Args:
        zip_file_path: Path to a zip archive containing the image/mask pairs.

    Returns:
        List of paths of the saved result images, in sorted filename order.
        Empty when the archive contains no complete pair.
    """
    import zipfile
    import os

    input_dir = 'batch'
    output_dir = 'results'

    # Create both directories up front: extracting an empty archive does not
    # create the target directory, and saving into a missing output directory
    # fails.
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(input_dir)

    results = []
    # Sorted so processing order (and the returned list) is deterministic.
    for file in sorted(os.listdir(input_dir)):
        if not (file.startswith('image') and file.endswith('.png')):
            continue

        img_path = f'{input_dir}/{file}'
        # Swap only the leading "image" prefix of the filename; a blanket
        # str.replace would also rewrite later occurrences of "image".
        mask_path = f"{input_dir}/mask{file[len('image'):]}"

        if not os.path.exists(mask_path):
            print(f"Skipped (no matching mask): {file}")
            continue

        img = Image.open(img_path).convert('RGB')
        msk = Image.open(mask_path).convert('L')

        result = high_quality_inpaint(img, msk)
        result_path = f'{output_dir}/{file}'
        result.save(result_path)
        results.append(result_path)
        print(f"Processed: {file}")

    return results

# Uncomment to use batch processing
# results = batch_process('your_batch.zip')