# QLoRA Fine-Tuning for SDXL Inpainting Model

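This notebook fine-tunes a Stable Diffusion XL inpainting model with LoRA adapters for memory-efficient training on Google Colab. Note: despite the "QLoRA" name, no quantization is applied — the base UNet is loaded in fp16, and `bitsandbytes` is installed but not used.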

## Dataset Format Required:
- `input_image` or `image`: PIL.Image - The input room image
- `mask` (optional): PIL.Image (grayscale) - Mask indicating where to edit (white = edit, black = keep). If missing, the whole image is treated as editable.
- `edit_instruction` or `instruction` or `text`: str - Text instruction describing the edit
- `output_image` or `edited_image`: PIL.Image - The expected output image after editing



In [None]:
# Install required packages
%pip install -q torch diffusers transformers peft bitsandbytes accelerate datasets Pillow numpy safetensors xformers

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")



## Mount Google Drive (Optional)

If you want to save your trained adapter to Google Drive, mount it here. Otherwise, the adapter will be saved in the Colab runtime (temporary).


In [None]:
# Optional: mount Google Drive first if the adapter should survive the session.
# from google.colab import drive
# drive.mount('/content/drive')

# Destination for per-epoch checkpoints and the final adapter. The default is on
# the ephemeral Colab disk; switch to the Drive path once Drive is mounted.
OUTPUT_DIR = "./my-room-editor-qlora"
# OUTPUT_DIR = "/content/drive/MyDrive/my-room-editor-qlora"  # Use this if Drive is mounted



## Configuration

Set your training parameters here. Adjust these based on your dataset size and available GPU memory.


In [None]:
# --- Model / data sources ---
MODEL_NAME = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
DATASET_NAME = "fusing/instructpix2pix-1000-samples"  # Replace with your dataset
VAE_ID = "madebyollin/sdxl-vae-fp16-fix"

# --- Optimisation ---
LEARNING_RATE = 1e-4
NUM_TRAIN_EPOCHS = 5
TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
MIXED_PRECISION = "fp16"
RESOLUTION = 1024  # side length of the square training images

# --- LoRA adapter ---
LORA_RANK = 4
LORA_ALPHA = 4
LORA_DROPOUT = 0.1

# Echo the effective configuration so the run output is self-describing.
_config_summary = (
    ("Model", MODEL_NAME),
    ("Dataset", DATASET_NAME),
    ("Output directory", OUTPUT_DIR),
    ("Learning rate", LEARNING_RATE),
    ("Epochs", NUM_TRAIN_EPOCHS),
    ("Batch size", TRAIN_BATCH_SIZE),
    ("Gradient accumulation", GRADIENT_ACCUMULATION_STEPS),
    ("LoRA rank", LORA_RANK),
)
print("Configuration:")
for _label, _value in _config_summary:
    print(f"  {_label}: {_value}")



## Import Libraries and Setup


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    DDPMScheduler,
)
from diffusers.optimization import get_scheduler
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from PIL import Image
import numpy as np
import os
from pathlib import Path
from tqdm.auto import tqdm
import transformers
import diffusers
import datasets

# Setup logging
# Module-level logger from accelerate's get_logger; its methods accept a
# `main_process_only` keyword to control output in multi-process runs.
logger = get_logger(__name__)



## Inspect Dataset Structure

First, let's load the dataset and check what keys it has. This will help us understand the dataset format.


In [None]:
# Peek at the raw dataset so we know which column names the wrapper must map.
try:
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset loaded! Size: {len(dataset)}")

    # Describe the columns of the first row, if there is one.
    if len(dataset):
        sample_row = dataset[0]
        print("\nAvailable keys in dataset:")
        print(list(sample_row.keys()))
        print("\nFirst example structure:")
        for column, cell in sample_row.items():
            if isinstance(cell, Image.Image):
                description = f"PIL.Image ({cell.size}, {cell.mode})"
            elif isinstance(cell, str):
                description = f"str (length: {len(cell)})"
            else:
                description = type(cell).__name__
            print(f"  {column}: {description}")
except Exception as e:
    # Best-effort inspection: report and let the user fix the config.
    print(f"Error loading dataset: {e}")
    print("Please check your dataset name and format.")



## Dataset Class and Helper Functions

The dataset class now supports multiple possible key names to handle different dataset formats.


In [None]:
class InpaintingDataset(Dataset):
    """
    Dataset class for inpainting fine-tuning.

    Turns raw edit examples into the tensors the training loop consumes.
    Several column names are accepted for each field:
    - input image: input_image, image, input, original_image
    - output image: output_image, edited_image, output, edited, target_image
    - instruction: edit_instruction, instruction, text, prompt, edit_prompt, caption
      (empty string when none is present or the value is None)
    - mask (optional): mask, mask_image, edit_mask. White (255) marks the region
      to edit; when missing or of an unsupported type an all-white mask is used.

    Image values may be PIL images, numpy arrays, or undecoded ``datasets`` image
    dicts (``bytes`` and/or ``path``).

    Images and mask are scaled (up or down) to fit inside a ``size`` x ``size``
    square, keeping aspect ratio, and centred on a black canvas. The mask goes
    through exactly the same geometry so it stays pixel-aligned with the images;
    its padding is 0 (keep).

    Each item is a dict with ``input_image``/``output_image`` float tensors of
    shape (3, size, size) in [-1, 1], ``mask`` of shape (1, size, size) in
    [0, 1], and ``instruction`` (str).

    ``tokenizer``, ``vae`` and ``center_crop`` are stored but not used by this class.
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        vae,
        size=1024,
        center_crop=False,
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.vae = vae
        self.size = size
        self.center_crop = center_crop

    def __len__(self):
        return len(self.dataset)

    def _get_key(self, example, possible_keys, default=None):
        """Return the value of the first key in ``possible_keys`` present in ``example``.

        Falls back to ``default`` when none is present; raises ``KeyError`` when
        ``default`` is None.
        """
        for key in possible_keys:
            if key in example:
                return example[key]
        if default is not None:
            return default
        raise KeyError(f"Could not find any of these keys: {possible_keys}. Available keys: {list(example.keys())}")

    @staticmethod
    def _to_pil(value, name, fallback=None):
        """Coerce a dataset cell to a PIL image.

        Undecoded ``datasets`` image dicts may carry ``bytes``, ``path`` or both;
        ``bytes`` is preferred because ``path`` can be None. Unsupported values
        return ``fallback`` when one is given, otherwise raise ValueError.
        """
        import io

        if isinstance(value, dict):
            if value.get("bytes") is not None:
                value = Image.open(io.BytesIO(value["bytes"]))
            elif value.get("path"):
                value = Image.open(value["path"])
        if isinstance(value, Image.Image):
            return value
        if isinstance(value, np.ndarray):
            return Image.fromarray(value)
        if fallback is not None:
            return fallback
        raise ValueError(f"Unexpected {name} type: {type(value)}")

    @staticmethod
    def _letterbox(image, size, fill):
        """Resize ``image`` to fit a ``size`` x ``size`` square (aspect ratio kept) and centre it on a ``fill`` canvas."""
        scale = size / max(image.width, image.height)
        new_width = min(size, max(1, round(image.width * scale)))
        new_height = min(size, max(1, round(image.height * scale)))
        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        canvas = Image.new(image.mode, (size, size), fill)
        canvas.paste(resized, ((size - new_width) // 2, (size - new_height) // 2))
        return canvas

    def __getitem__(self, idx):
        example = self.dataset[idx]

        # Resolve both required image columns before decoding either.
        raw_input = self._get_key(
            example,
            ["input_image", "image", "input", "original_image"],
        )
        raw_output = self._get_key(
            example,
            ["output_image", "edited_image", "output", "edited", "target_image"],
        )
        input_image = self._to_pil(raw_input, "input_image")
        output_image = self._to_pil(raw_output, "output_image")

        # Optional mask: the first matching column wins (even if its value is None).
        full_mask = Image.new("L", input_image.size, 255)
        raw_mask = next(
            (example[key] for key in ("mask", "mask_image", "edit_mask") if key in example),
            None,
        )
        if raw_mask is None:
            mask = full_mask
        else:
            mask = self._to_pil(raw_mask, "mask", fallback=full_mask)
        mask = mask.convert("L")

        instruction = self._get_key(
            example,
            ["edit_instruction", "instruction", "text", "prompt", "edit_prompt", "caption"],
            default="",  # Default to empty string if not found
        )
        if instruction is None:
            instruction = ""  # the tokenizer is given a list of strings

        # Identical letterbox geometry for images and mask keeps them aligned.
        input_image = self._letterbox(input_image.convert("RGB"), self.size, (0, 0, 0))
        output_image = self._letterbox(output_image.convert("RGB"), self.size, (0, 0, 0))
        mask = self._letterbox(mask, self.size, 0)

        # uint8 -> float32: images to [-1, 1], mask to [0, 1].
        input_array = (np.array(input_image).astype(np.float32) / 255.0 - 0.5) / 0.5
        output_array = (np.array(output_image).astype(np.float32) / 255.0 - 0.5) / 0.5
        mask_array = np.array(mask).astype(np.float32) / 255.0

        return {
            "input_image": torch.from_numpy(input_array).permute(2, 0, 1),  # (3, size, size)
            "output_image": torch.from_numpy(output_array).permute(2, 0, 1),  # (3, size, size)
            "mask": torch.from_numpy(mask_array[None]),  # (1, size, size), single channel
            "instruction": instruction,
        }


def collate_fn(examples):
    """Merge per-example dicts into one batch dict.

    Image and mask tensors are stacked along a new leading batch dimension;
    instructions stay a plain list of strings in example order.
    """
    def _stack(field):
        return torch.stack([item[field] for item in examples])

    return {
        "input_images": _stack("input_image"),
        "output_images": _stack("output_image"),
        "masks": _stack("mask"),
        "instructions": [item["instruction"] for item in examples],
    }


def encode_prompt(text_encoder, text_encoder_2, tokenizer, tokenizer_2, prompts, device):
    """Encode prompts using both text encoders (SDXL uses dual encoders).

    Mirrors the conditioning built by diffusers' SDXL pipelines:

    * token embeddings: the penultimate hidden layer of each encoder,
      concatenated along the feature axis;
    * pooled embedding: from the second encoder only -- its projected
      ``text_embeds`` when it is a ``CLIPTextModelWithProjection``, otherwise its
      ``pooler_output`` (the EOS-token state of a plain ``CLIPTextModel``).

    Args:
        text_encoder, text_encoder_2: frozen CLIP text encoders.
        tokenizer, tokenizer_2: their matching tokenizers.
        prompts: list of prompt strings, one per batch row.
        device: device the token ids are moved to.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` with shapes
        ``(batch, 77, d1 + d2)`` and ``(batch, d_pooled)``.
    """
    def _token_ids(tok):
        # Fixed 77-token context, padded and truncated.
        return tok(
            prompts,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)

    input_ids = _token_ids(tokenizer)
    input_ids_2 = _token_ids(tokenizer_2)

    # Both encoders are frozen: a single forward pass each, with no autograd graph.
    with torch.no_grad():
        outputs = text_encoder(input_ids, output_hidden_states=True)
        outputs_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL conditions on the penultimate layer of both encoders.
    prompt_embeds = torch.cat([outputs.hidden_states[-2], outputs_2.hidden_states[-2]], dim=-1)

    # Prefer the projected pooled output; fall back for a plain CLIPTextModel.
    pooled_prompt_embeds = getattr(outputs_2, "text_embeds", None)
    if pooled_prompt_embeds is None:
        pooled_prompt_embeds = outputs_2.pooler_output

    return prompt_embeds, pooled_prompt_embeds


def get_time_ids(batch_size, device, size=1024, original_size=None, crops_coords_top_left=(0, 0)):
    """Build SDXL's size/crop micro-conditioning ``time_ids`` for a batch.

    Each row is ``[orig_h, orig_w, crop_top, crop_left, target_h, target_w]``
    (original size, crop top-left, target size).

    Args:
        batch_size: number of rows in the result.
        device: device for the returned tensor.
        size: target resolution, an int (square) or ``(height, width)``.
            The default of 1024 reproduces the previous hard-coded value.
        original_size: original resolution (int or ``(height, width)``);
            defaults to ``size``.
        crops_coords_top_left: ``(top, left)`` crop offset, default ``(0, 0)``.

    Returns:
        float32 tensor of shape ``(batch_size, 6)``.
    """
    def _as_hw(value):
        return (value, value) if isinstance(value, int) else tuple(value)

    target_hw = _as_hw(size)
    original_hw = target_hw if original_size is None else _as_hw(original_size)
    row = [*original_hw, *crops_coords_top_left, *target_hw]
    time_ids = torch.tensor([row], dtype=torch.float32, device=device)
    return time_ids.repeat(batch_size, 1)

print("Dataset class and helper functions defined!")



## Setup Accelerator


In [None]:
# Accelerate owns mixed precision, gradient accumulation and device placement.
logging_dir = Path(OUTPUT_DIR, "logs")
project_config = ProjectConfiguration(project_dir=OUTPUT_DIR, logging_dir=logging_dir)
accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    project_config=project_config,
)

# Non-main local processes only report errors; the main one is more verbose.
if not accelerator.is_local_main_process:
    for _library in (datasets, transformers, diffusers):
        _library.utils.logging.set_verbosity_error()
else:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()

logger.info(accelerator.state, main_process_only=False)
print("Accelerator setup complete!")



## Load Models


In [None]:
# Load tokenizer and text encoders
logger.info("Loading tokenizer and text encoders...")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder", torch_dtype=torch.float16)
# NOTE(review): SDXL's `text_encoder_2` checkpoint is normally a
# CLIPTextModelWithProjection; loading it as CLIPTextModel drops the text
# projection whose output SDXL uses as the pooled `text_embeds`. Only switch the
# class together with an encode_prompt() that reads hidden states / `text_embeds`
# by attribute -- for the WithProjection class, output[0] is the pooled projection,
# not the per-token states.
text_encoder_2 = CLIPTextModel.from_pretrained(MODEL_NAME, subfolder="text_encoder_2", torch_dtype=torch.float16)

# Load VAE (frozen; only used to encode images into latents)
logger.info("Loading VAE...")
vae = AutoencoderKL.from_pretrained(VAE_ID, torch_dtype=torch.float16)
vae.requires_grad_(False)
vae.eval()

# Load UNet
logger.info("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(MODEL_NAME, subfolder="unet", torch_dtype=torch.float16)

# Enable memory-efficient attention (xformers) if available
try:
    unet.enable_xformers_memory_efficient_attention()
    print("✓ XFormers memory-efficient attention enabled!")
except Exception as e:
    print(f"XFormers not available: {e}")
    # Fall back to sliced attention. On the model this is `set_attention_slice`;
    # `enable_attention_slicing` is a pipeline method, so calling it on the UNet
    # always failed and the fallback never took effect.
    try:
        unet.set_attention_slice("max")
        print("✓ Attention slicing enabled as fallback!")
    except Exception as e2:
        print(f"Could not enable attention optimizations: {e2}")

# Enable VAE tiling for memory efficiency
try:
    vae.enable_tiling()
    print("✓ VAE tiling enabled for memory efficiency!")
except Exception as e:
    print(f"VAE tiling not available: {e}")

# Freeze text encoders (the VAE was frozen above)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
text_encoder.eval()
text_encoder_2.eval()

print("Models loaded successfully!")



## Configure LoRA Adapters


In [None]:
# Configure LoRA on the UNet attention projections
logger.info("Configuring LoRA adapters...")
unet_lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    lora_dropout=LORA_DROPOUT,
)

# Apply LoRA to UNet
unet = get_peft_model(unet, unet_lora_config)

# The UNet was loaded in fp16. For fp16 mixed-precision training, the trainable
# adapter weights must be fp32:
#   - the fp16 GradScaler refuses to unscale fp16 gradients;
#   - fp16 master weights lose small AdamW updates.
# Recent PEFT releases already upcast adapters inside get_peft_model
# (autocast_adapter_dtype=True), while older ones keep the base dtype. This loop
# makes the cast explicit and is a no-op when the weights are already fp32.
# The frozen base weights stay fp16.
for param in unet.parameters():
    if param.requires_grad:
        param.data = param.data.to(torch.float32)

unet.print_trainable_parameters()

# Enable gradient checkpointing for memory efficiency
# This allows training with larger effective batch sizes
if hasattr(unet, "enable_gradient_checkpointing"):
    unet.enable_gradient_checkpointing()
    print("Gradient checkpointing enabled!")

# Compile UNet (PyTorch 2.0+).
# NOTE: torch.compile is lazy. Graph capture and compilation happen on the first
# forward pass, so failures surface in the training loop rather than in this
# try/except. If the first training step fails inside dynamo/inductor, rerun
# without this block.
try:
    if hasattr(torch, "compile"):
        print("Compiling UNet with torch.compile()...")
        unet = torch.compile(unet, mode="reduce-overhead")
        print("✓ UNet wrapped with torch.compile (compiles on first forward pass)")
    else:
        print("torch.compile not available (requires PyTorch 2.0+)")
except Exception as e:
    print(f"Warning: Could not compile UNet: {e}")
    print("Continuing without compilation...")

# Load noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")
print("LoRA adapters configured!")



## Load and Prepare Dataset

Now we'll reload the dataset and create the training dataloader. The dataset class will automatically handle different key formats.


In [None]:
# Load dataset
logger.info(f"Loading dataset: {DATASET_NAME}")
try:
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset loaded! Size: {len(dataset)}")
except Exception as e:
    logger.warning(f"Could not load dataset {DATASET_NAME}: {e}")
    logger.info("Please ensure your dataset has the required format:")
    logger.info("  - input_image or image: PIL.Image")
    logger.info("  - output_image or edited_image: PIL.Image")
    logger.info("  - mask (optional): PIL.Image (grayscale)")
    logger.info("  - edit_instruction or instruction or text: str")
    raise

# Create dataset wrapper
train_dataset = InpaintingDataset(
    dataset=dataset,
    tokenizer=tokenizer,
    vae=vae,
    size=RESOLUTION,
)

# Build one sample now so format errors surface here rather than inside a
# DataLoader worker.
try:
    test_sample = train_dataset[0]
    print("✓ Dataset test successful! Sample keys:", list(test_sample.keys()))
except Exception as e:
    print(f"✗ Error testing dataset: {e}")
    print("\nPlease check the dataset format. The dataset should have:")
    print("  - An input image (key: 'input_image', 'image', or 'input')")
    print("  - An output image (key: 'output_image', 'edited_image', or 'output')")
    print("  - An instruction text (key: 'edit_instruction', 'instruction', or 'text')")
    print("  - Optionally a mask (key: 'mask')")
    raise

# Create the dataloader.
# pin_memory=True speeds up host-to-GPU copies. DataLoader only accepts
# prefetch_factor and persistent_workers when num_workers > 0, so they are
# added only in that case (set DATALOADER_NUM_WORKERS = 0 to debug in-process).
DATALOADER_NUM_WORKERS = 2  # 2-4 workers for Colab (adjust based on available CPU)
worker_kwargs = (
    {"prefetch_factor": 2, "persistent_workers": True}  # prefetch batches; keep workers alive between epochs
    if DATALOADER_NUM_WORKERS > 0
    else {}
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=DATALOADER_NUM_WORKERS,
    pin_memory=True,  # Faster GPU transfer
    **worker_kwargs,
)

print(f"Dataset prepared! Training samples: {len(train_dataset)}")
loader_summary = f"num_workers={DATALOADER_NUM_WORKERS}, pin_memory=True"
if worker_kwargs:
    loader_summary += f", prefetch_factor={worker_kwargs['prefetch_factor']}"
print(f"✓ DataLoader optimized with {loader_summary}")



## Setup Optimizer and Learning Rate Scheduler


In [None]:
# Setup optimizer
import math

# Only the LoRA parameters require gradients. Passing just those keeps AdamW's
# parameter groups limited to the adapter weights, and fails fast if LoRA was
# not applied.
trainable_params = [param for param in unet.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

# Calculate number of training steps.
# Accelerate performs an update every GRADIENT_ACCUMULATION_STEPS batches, and
# also at the end of the dataloader. A trailing partial window therefore still
# yields an update, so round up rather than down.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS)
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
max_train_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch

# Setup learning rate scheduler
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

# Prepare everything with accelerator
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

# The frozen VAE and text encoders are not prepared; move them explicitly.
vae = vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device)
text_encoder_2 = text_encoder_2.to(accelerator.device)

print("Training setup complete!")
print(f"  Total training steps: {max_train_steps}")
print(f"  Steps per epoch: {num_update_steps_per_epoch}")



## Training Loop

This is the main training loop. It will train for the specified number of epochs and save checkpoints after each epoch.


In [None]:
# Training loop
logger.info("Starting training...")
global_step = 0
LOG_EVERY = 200  # optimizer updates between loss reports
running_loss = 0.0  # sum of micro-batch losses since the last report
running_batches = 0  # number of micro-batches summed into running_loss

for epoch in range(NUM_TRAIN_EPOCHS):
    unet.train()

    progress_bar = tqdm(
        total=num_update_steps_per_epoch,  # one tick per optimizer update, not per batch
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{NUM_TRAIN_EPOCHS}",
    )

    for batch in train_dataloader:
        with accelerator.accumulate(unet):
            # Move data to device with non_blocking for faster transfer
            input_images = batch["input_images"].to(accelerator.device, dtype=torch.float16, non_blocking=True)
            output_images = batch["output_images"].to(accelerator.device, dtype=torch.float16, non_blocking=True)
            # [batch, 1, H, W] where 1 = edit area, 0 = keep area
            masks = batch["masks"].to(accelerator.device, dtype=torch.float16, non_blocking=True)

            # The inpainting UNet is conditioned on the input image with the edit
            # region blanked out. This is the same `image * (mask < 0.5)`
            # construction diffusers' SDXL inpaint pipeline uses at inference, so
            # training and inference see the same conditioning.
            masked_images = input_images * (masks < 0.5)

            with torch.no_grad():
                masked_image_latents = vae.encode(masked_images).latent_dist.sample() * vae.config.scaling_factor
                output_latents = vae.encode(output_images).latent_dist.sample() * vae.config.scaling_factor

            # Downsample the mask to latent resolution. Values should already be
            # in [0, 1]; clamp defensively.
            mask_latents = F.interpolate(masks, size=output_latents.shape[-2:], mode="nearest")
            mask_latents = torch.clamp(mask_latents, 0.0, 1.0)

            bsz = output_latents.shape[0]

            # Frozen text conditioning plus SDXL size/crop micro-conditioning.
            with torch.no_grad():
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    text_encoder, text_encoder_2, tokenizer, tokenizer_2, batch["instructions"], accelerator.device
                )
                time_ids = get_time_ids(bsz, accelerator.device)

            # Forward diffusion on the target (edited) latents.
            noise = torch.randn_like(output_latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=output_latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(output_latents, noise, timesteps)

            # 9-channel UNet input in the order the SDXL inpaint pipeline uses:
            # [noisy latents (4), mask (1), masked-image latents (4)].
            model_input = torch.cat([noisy_latents, mask_latents, masked_image_latents], dim=1)

            # accelerator.autocast() follows MIXED_PRECISION (no-op for "no") and
            # replaces the deprecated torch.cuda.amp.autocast.
            with accelerator.autocast():
                model_pred = unet(
                    model_input,
                    timesteps,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": time_ids},
                ).sample

            # The regression target depends on the scheduler's parameterisation.
            prediction_type = noise_scheduler.config.prediction_type
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(output_latents, noise, timesteps)
            else:
                raise ValueError(f"Unsupported prediction type: {prediction_type}")

            # Compute loss (use float32 for stability)
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            running_loss += loss.detach().item()
            running_batches += 1

            # Backward pass. Accelerate's prepared optimizer and scheduler skip
            # the update on non-sync (accumulation) micro-steps.
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            # Report the mean micro-batch loss since the previous report.
            if global_step % LOG_EVERY == 0 and running_batches:
                avg_loss = running_loss / running_batches
                progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})
                if accelerator.is_local_main_process:
                    print(f"\nStep {global_step}, Loss: {avg_loss:.4f}")
                running_loss = 0.0
                running_batches = 0

    progress_bar.close()

    # Save checkpoint after each epoch
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"Saving checkpoint after epoch {epoch + 1}...")
        save_path = os.path.join(OUTPUT_DIR, f"checkpoint-{epoch + 1}")
        os.makedirs(save_path, exist_ok=True)
        # Strip Accelerate's wrapper and, when torch.compile was applied, its
        # OptimizedModule, so PEFT's save_pretrained runs on the real model.
        model_to_save = accelerator.unwrap_model(unet)
        model_to_save = getattr(model_to_save, "_orig_mod", model_to_save)
        model_to_save.save_pretrained(save_path)
        logger.info(f"Checkpoint saved to {save_path}")
        print(f"✓ Checkpoint saved: {save_path}")

print("Training complete!")



## Save Final Adapter

Save the final trained adapter weights.


In [None]:
# Save final adapter
# Make sure every process has finished training before the main process writes.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    logger.info("Saving final adapter...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Strip Accelerate's wrapper and, when torch.compile was applied, its
    # OptimizedModule, so save_pretrained runs on the PEFT model itself.
    adapter_model = accelerator.unwrap_model(unet)
    adapter_model = getattr(adapter_model, "_orig_mod", adapter_model)
    # PeftModel.save_pretrained writes adapter_config.json plus the adapter weights
    # (adapter_model.safetensors).
    # NOTE(review): those weights use PEFT's key layout. Confirm that
    # pipeline.load_lora_weights() accepts this directory; if it does not, export
    # a diffusers-format file instead (get_peft_model_state_dict ->
    # convert_state_dict_to_diffusers -> <Pipeline>.save_lora_weights).
    adapter_model.save_pretrained(OUTPUT_DIR)
    logger.info(f"Final adapter saved to {OUTPUT_DIR}")
    print(f"✓ Final adapter saved to: {OUTPUT_DIR}")
    print(f"Adapter can be loaded in app.py using: pipeline.load_lora_weights('{OUTPUT_DIR}')")

    # List saved files (sorted for a stable listing; sub-directories skipped)
    print("\nSaved files:")
    for file_name in sorted(os.listdir(OUTPUT_DIR)):
        file_path = os.path.join(OUTPUT_DIR, file_name)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
            print(f"  - {file_name} ({size_mb:.2f} MB)")

accelerator.end_training()
print("\n🎉 Training finished successfully!")



## Download Adapter (if not using Google Drive)

If you saved to the Colab runtime (not Drive), download the adapter files before the session ends.


In [None]:
# Optional: zip the adapter directory and download it through the browser.
# Only needed when OUTPUT_DIR is on the ephemeral Colab disk (not Drive).
# from google.colab import files
# import shutil

# # Writes my-room-editor-qlora.zip next to the notebook, then downloads it
# shutil.make_archive("my-room-editor-qlora", "zip", OUTPUT_DIR)
# files.download("my-room-editor-qlora.zip")

print("To download the adapter, uncomment the code above or copy from Google Drive if you mounted it.")



# QLoRA Fine-Tuning for SDXL Inpainting Model

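This notebook fine-tunes a Stable Diffusion XL inpainting model with LoRA adapters for memory-efficient training on Google Colab. Note: despite the "QLoRA" name, no quantization is applied — the base UNet is loaded in fp16, and `bitsandbytes` is installed but not used.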

## Dataset Format Required:
- `input_image`: PIL.Image - The input room image
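- `mask` (optional): PIL.Image (grayscale) - Mask indicating where to edit (white = edit, black = keep). If missing, the whole image is treated as editable.
- `edit_instruction` (falls back to `instruction`, then `text`): str - Text instruction describing the edit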
- `output_image`: PIL.Image - The expected output image after editing



In [None]:
# Install required packages
%pip install -q torch diffusers transformers peft bitsandbytes accelerate datasets Pillow numpy safetensors xformers

# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")



## Mount Google Drive (Optional)

If you want to save your trained adapter to Google Drive, mount it here. Otherwise, the adapter will be saved in the Colab runtime (temporary).


In [None]:
# Optional: mount Google Drive first if the adapter should survive the session.
# from google.colab import drive
# drive.mount('/content/drive')

# Destination for per-epoch checkpoints and the final adapter. The default is on
# the ephemeral Colab disk; switch to the Drive path once Drive is mounted.
OUTPUT_DIR = "./my-room-editor-qlora"
# OUTPUT_DIR = "/content/drive/MyDrive/my-room-editor-qlora"  # Use this if Drive is mounted



## Configuration

Set your training parameters here. Adjust these based on your dataset size and available GPU memory.


In [None]:
# --- Model / data sources ---
MODEL_NAME = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
DATASET_NAME = "fusing/instructpix2pix-1000-samples"  # Replace with your dataset
VAE_ID = "madebyollin/sdxl-vae-fp16-fix"

# --- Optimisation ---
LEARNING_RATE = 1e-4
NUM_TRAIN_EPOCHS = 5
TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
MIXED_PRECISION = "fp16"
RESOLUTION = 1024  # side length of the square training images

# --- LoRA adapter ---
LORA_RANK = 4
LORA_ALPHA = 4
LORA_DROPOUT = 0.1

# Echo the effective configuration so the run output is self-describing.
_config_summary = (
    ("Model", MODEL_NAME),
    ("Dataset", DATASET_NAME),
    ("Output directory", OUTPUT_DIR),
    ("Learning rate", LEARNING_RATE),
    ("Epochs", NUM_TRAIN_EPOCHS),
    ("Batch size", TRAIN_BATCH_SIZE),
    ("Gradient accumulation", GRADIENT_ACCUMULATION_STEPS),
    ("LoRA rank", LORA_RANK),
)
print("Configuration:")
for _label, _value in _config_summary:
    print(f"  {_label}: {_value}")



## Import Libraries and Setup


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    DDPMScheduler,
)
from diffusers.optimization import get_scheduler
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from PIL import Image
import numpy as np
import os
from pathlib import Path
from tqdm.auto import tqdm
import transformers
import diffusers
import datasets

# Setup logging
# Module-level logger from accelerate's get_logger; its methods accept a
# `main_process_only` keyword to control output in multi-process runs.
logger = get_logger(__name__)



## Dataset Class and Helper Functions


In [None]:
class InpaintingDataset(Dataset):
    """
    Dataset class for inpainting fine-tuning.

    Expects dataset with keys: 'input_image', 'mask', 'edit_instruction', 'output_image'.

    - 'input_image' / 'output_image' are required; a missing key raises KeyError.
      Values may be PIL images, arrays accepted by ``Image.fromarray``, or
      undecoded ``datasets`` image dicts (``bytes`` and/or ``path``).
    - 'mask' is optional; white (255) marks the region to edit. A missing or None
      mask becomes an all-white mask (edit everything).
    - The instruction is the first non-empty value among 'edit_instruction',
      'instruction' and 'text' (empty string if none).

    Images and mask are scaled (up or down) to fit a ``size`` x ``size`` square,
    keeping aspect ratio, and centred on a black canvas. The mask goes through the
    same geometry so it stays aligned with the images; its padding is 0 (keep).

    ``tokenizer``, ``vae`` and ``center_crop`` are stored but unused by this class.
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        vae,
        size=1024,
        center_crop=False,
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.vae = vae
        self.size = size
        self.center_crop = center_crop

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _to_pil(value):
        """Coerce a dataset cell (PIL image, array, or undecoded image dict) to a PIL image.

        For dicts, ``bytes`` is preferred because ``path`` can be None.
        """
        import io

        if isinstance(value, dict):
            if value.get("bytes") is not None:
                value = Image.open(io.BytesIO(value["bytes"]))
            elif value.get("path"):
                value = Image.open(value["path"])
        if not isinstance(value, Image.Image):
            value = Image.fromarray(value)
        return value

    @staticmethod
    def _letterbox(image, size, fill):
        """Resize ``image`` to fit a ``size`` x ``size`` square (aspect ratio kept) and centre it on a ``fill`` canvas."""
        scale = size / max(image.width, image.height)
        new_width = min(size, max(1, round(image.width * scale)))
        new_height = min(size, max(1, round(image.height * scale)))
        resized = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        canvas = Image.new(image.mode, (size, size), fill)
        canvas.paste(resized, ((size - new_width) // 2, (size - new_height) // 2))
        return canvas

    def __getitem__(self, idx):
        example = self.dataset[idx]

        # Load images
        input_image = self._to_pil(example["input_image"])
        output_image = self._to_pil(example["output_image"])

        # Load mask (optional; default is "edit everything")
        mask = example.get("mask", None)
        if mask is None:
            mask = Image.new("L", input_image.size, 255)
        else:
            mask = self._to_pil(mask)
        mask = mask.convert("L")

        # First non-empty instruction field; a None value falls through to "".
        instruction = (
            example.get("edit_instruction", "")
            or example.get("instruction", "")
            or example.get("text", "")
            or ""
        )

        # Identical letterbox geometry for images and mask keeps them aligned.
        input_image = self._letterbox(input_image.convert("RGB"), self.size, (0, 0, 0))
        output_image = self._letterbox(output_image.convert("RGB"), self.size, (0, 0, 0))
        mask = self._letterbox(mask, self.size, 0)

        # uint8 -> float32: images to [-1, 1], mask to [0, 1].
        input_array = (np.array(input_image).astype(np.float32) / 255.0 - 0.5) / 0.5
        output_array = (np.array(output_image).astype(np.float32) / 255.0 - 0.5) / 0.5
        mask_array = np.array(mask).astype(np.float32) / 255.0

        return {
            "input_image": torch.from_numpy(input_array).permute(2, 0, 1),  # (3, size, size)
            "output_image": torch.from_numpy(output_array).permute(2, 0, 1),  # (3, size, size)
            "mask": torch.from_numpy(mask_array[None]),  # (1, size, size), single channel
            "instruction": instruction,
        }


def collate_fn(examples):
    """Merge per-sample dicts from the dataset into a single batch dict.

    Image and mask tensors are stacked along a new leading batch dimension;
    the instruction strings are kept as a plain Python list.

    Args:
        examples: Items produced by the dataset, each holding ``input_image``,
            ``output_image`` and ``mask`` tensors plus an ``instruction`` string.

    Returns:
        Dict with keys ``input_images``, ``output_images``, ``masks`` and
        ``instructions``.
    """
    # Gather each field across the batch first (same order as the keys below).
    columns = {
        key: [example[key] for example in examples]
        for key in ("input_image", "output_image", "mask", "instruction")
    }

    return {
        "input_images": torch.stack(columns["input_image"]),
        "output_images": torch.stack(columns["output_image"]),
        "masks": torch.stack(columns["mask"]),
        "instructions": columns["instruction"],
    }


def _tokenize_prompts(tokenizer, prompts, device):
    """Tokenize prompts to fixed-length (77) CLIP token ids on ``device``."""
    return tokenizer(
        prompts,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)


def encode_prompt(text_encoder, text_encoder_2, tokenizer, tokenizer_2, prompts, device):
    """Encode prompts with both SDXL text encoders.

    Uses the conditioning layout of diffusers' SDXL pipelines:

    * per-token embeddings are the *penultimate* hidden states of each encoder,
      concatenated along the feature axis;
    * the pooled embedding comes from the second encoder only. It is the
      projected ``text_embeds`` when that encoder has a projection head
      (``CLIPTextModelWithProjection``), otherwise its ``pooler_output``
      (``CLIPTextModel``).

    Args:
        text_encoder: First CLIP text encoder, paired with ``tokenizer``.
        text_encoder_2: Second CLIP text encoder, paired with ``tokenizer_2``.
        tokenizer: Tokenizer for ``text_encoder``.
        tokenizer_2: Tokenizer for ``text_encoder_2``.
        prompts: Prompt string or list of prompt strings.
        device: Device the token ids are moved to before encoding.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: per-token embeddings of
        shape ``(batch, 77, hidden_1 + hidden_2)`` and pooled embeddings of
        shape ``(batch, pooled_dim)``.
    """
    input_ids = _tokenize_prompts(tokenizer, prompts, device)
    input_ids_2 = _tokenize_prompts(tokenizer_2, prompts, device)

    # Both encoders are frozen. Run each exactly once; no autograd graph is needed.
    with torch.no_grad():
        output = text_encoder(input_ids, output_hidden_states=True)
        output_2 = text_encoder_2(input_ids_2, output_hidden_states=True)

    # SDXL conditions on the penultimate layer, not on the final (layer-normed) output.
    prompt_embeds = torch.cat([output.hidden_states[-2], output_2.hidden_states[-2]], dim=-1)

    # Prefer the projected EOS embedding; fall back to the unprojected pooler output
    # when text_encoder_2 was loaded without its projection head.
    pooled_prompt_embeds = getattr(output_2, "text_embeds", None)
    if pooled_prompt_embeds is None:
        pooled_prompt_embeds = output_2.pooler_output

    return prompt_embeds, pooled_prompt_embeds


def get_time_ids(
    batch_size,
    device,
    original_size=(1024, 1024),
    crops_coords_top_left=(0, 0),
    target_size=(1024, 1024),
    dtype=torch.float32,
):
    """Build SDXL size/crop micro-conditioning ids, one identical row per batch item.

    The six values are laid out as
    ``[*original_size, *crops_coords_top_left, *target_size]``. The defaults
    describe a 1024x1024 image with no crop; pass ``(RESOLUTION, RESOLUTION)``
    sizes when training at a different resolution.

    Args:
        batch_size: Number of rows to return.
        device: Device for the returned tensor.
        original_size: ``(height, width)`` of the source image.
        crops_coords_top_left: ``(top, left)`` crop offset.
        target_size: ``(height, width)`` of the training image.
        dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``(batch_size, 6)``.

    Raises:
        ValueError: If any size/crop argument does not have exactly two entries.
    """
    parts = (original_size, crops_coords_top_left, target_size)
    if any(len(part) != 2 for part in parts):
        raise ValueError(
            "original_size, crops_coords_top_left and target_size must each have 2 entries"
        )
    values = [float(value) for part in parts for value in part]
    time_ids = torch.tensor([values], dtype=dtype, device=device)
    return time_ids.repeat(batch_size, 1)

print("Dataset class and helper functions defined!")



## Setup Accelerator


In [None]:
# Create the Accelerator; its project and log directories live under OUTPUT_DIR.
logging_dir = Path(OUTPUT_DIR) / "logs"
project_config = ProjectConfiguration(
    project_dir=OUTPUT_DIR,
    logging_dir=logging_dir,
)
accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    project_config=project_config,
)

# Library verbosity: the local main process shows warnings (info for diffusers);
# all other processes are limited to errors, presumably to avoid duplicated output.
if accelerator.is_local_main_process:
    for _lib in (datasets, transformers):
        _lib.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    for _lib in (datasets, transformers, diffusers):
        _lib.utils.logging.set_verbosity_error()

logger.info(accelerator.state, main_process_only=False)
print("Accelerator setup complete!")



## Load Models


In [None]:
# Load tokenizers and text encoders (SDXL conditions on two CLIP text encoders).
logger.info("Loading tokenizer and text encoders...")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer_2")
text_encoder = CLIPTextModel.from_pretrained(
    MODEL_NAME, subfolder="text_encoder", torch_dtype=torch.float16
)
# NOTE(review): SDXL checkpoints usually ship text_encoder_2 as CLIPTextModelWithProjection;
# loading it as CLIPTextModel leaves out its projection head. Confirm every consumer of this
# encoder's output expects CLIPTextModel before changing the class.
text_encoder_2 = CLIPTextModel.from_pretrained(
    MODEL_NAME, subfolder="text_encoder_2", torch_dtype=torch.float16
)

# The text encoders are inference-only: freeze them and switch to eval mode.
for _frozen_encoder in (text_encoder, text_encoder_2):
    _frozen_encoder.requires_grad_(False)
    _frozen_encoder.eval()

# Load the VAE (also inference-only: frozen, eval mode).
logger.info("Loading VAE...")
vae = AutoencoderKL.from_pretrained(VAE_ID, torch_dtype=torch.float16)
vae.requires_grad_(False)
vae.eval()

# Load the UNet in plain fp16 (no quantization config is passed here).
logger.info("Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    MODEL_NAME, subfolder="unet", torch_dtype=torch.float16
)

print("Models loaded successfully!")



## Configure LoRA Adapters


In [None]:
# Attach LoRA adapters to the UNet modules named to_q / to_k / to_v / to_out.0
# (diffusers' attention projection layers).
logger.info("Configuring LoRA adapters...")
unet_lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)

# Wrap the UNet with PEFT and report how many parameters remain trainable.
unet = get_peft_model(unet, unet_lora_config)
unet.print_trainable_parameters()

# DDPM scheduler from the same checkpoint; it supplies the forward-noising process for training.
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")
print("LoRA adapters configured!")



## Load Dataset


In [None]:
# Load the raw training split; on failure, explain the expected schema and re-raise.
logger.info(f"Loading dataset: {DATASET_NAME}")
try:
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset loaded! Size: {len(dataset)}")
except Exception as e:
    logger.warning(f"Could not load dataset {DATASET_NAME}: {e}")
    for _format_hint in (
        "Please ensure your dataset has the required format:",
        "  - input_image: PIL.Image",
        "  - mask: PIL.Image (grayscale)",
        "  - edit_instruction: str",
        "  - output_image: PIL.Image",
    ):
        logger.info(_format_hint)
    raise

# Wrap the raw split so every item is preprocessed at RESOLUTION.
train_dataset = InpaintingDataset(dataset=dataset, tokenizer=tokenizer, vae=vae, size=RESOLUTION)

# Shuffled batches, assembled by collate_fn; a small worker pool suits Colab.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

print(f"Dataset prepared! Training samples: {len(train_dataset)}")



## Setup Optimizer and Learning Rate Scheduler


In [None]:
# Setup optimizer
import math

# Optimize only parameters that still require grad (after get_peft_model, presumably just
# the LoRA adapter weights). Frozen base weights never get gradients.
trainable_params = [p for p in unet.parameters() if p.requires_grad]

# Keep trainable weights in fp32. The UNet was loaded in fp16, and depending on the PEFT
# version the adapter weights may inherit that dtype. PyTorch's GradScaler (used for fp16
# mixed precision) refuses to unscale fp16 gradients. Reassigning .data keeps the same
# Parameter objects, so the list above stays valid.
for param in trainable_params:
    if param.dtype != torch.float32:
        param.data = param.data.float()

optimizer = torch.optim.AdamW(
    trainable_params,
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

# Calculate number of training steps. Besides every GRADIENT_ACCUMULATION_STEPS micro-batches,
# accelerate also performs an update on the trailing partial window at the end of the
# dataloader, so round up.
num_update_steps_per_epoch = max(math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS), 1)
max_train_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch

# Setup learning rate scheduler
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps,
)

# Prepare everything with accelerator
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

# The frozen models are not passed through accelerator.prepare(), so move them manually.
vae = vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device)
text_encoder_2 = text_encoder_2.to(accelerator.device)

print("Training setup complete!")
print(f"  Total training steps: {max_train_steps}")
print(f"  Steps per epoch: {num_update_steps_per_epoch}")



## Training Loop

This is the main training loop. It trains for `NUM_TRAIN_EPOCHS` epochs and, after each epoch, saves a LoRA checkpoint to `OUTPUT_DIR/checkpoint-<epoch>`.


In [None]:
# Training loop
import math

logger.info("Starting training...")
global_step = 0

# Optimizer updates per epoch: one per GRADIENT_ACCUMULATION_STEPS micro-batches, plus one for
# a trailing partial window at the end of the dataloader (hence ceil). Drives the progress bar.
updates_per_epoch = max(math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS), 1)
LOG_EVERY = 100  # log the mean loss every LOG_EVERY optimizer updates

# The SDXL inpainting UNet expects 9 input channels:
# 4 noisy target latents + 1 downsampled mask + 4 latents of the masked input image.
_base_unet = accelerator.unwrap_model(unet)
if hasattr(_base_unet, "get_base_model"):  # PeftModel -> underlying UNet
    _base_unet = _base_unet.get_base_model()
unet_in_channels = _base_unet.config.in_channels
if unet_in_channels != 9:
    raise ValueError(
        f"Expected an inpainting UNet with 9 input channels, got {unet_in_channels}. "
        "Use an SDXL inpainting checkpoint for MODEL_NAME."
    )

prediction_type = noise_scheduler.config.prediction_type
loss_since_log = 0.0  # sum of per-update mean losses since the last log line
updates_since_log = 0

for epoch in range(NUM_TRAIN_EPOCHS):
    unet.train()
    window_loss_sum = 0.0  # summed loss of the micro-batches in the current update window
    window_batches = 0

    progress_bar = tqdm(
        total=updates_per_epoch,
        disable=not accelerator.is_local_main_process,
        desc=f"Epoch {epoch + 1}/{NUM_TRAIN_EPOCHS}",
    )

    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            input_images = batch["input_images"].to(dtype=torch.float16)
            output_images = batch["output_images"].to(dtype=torch.float16)
            # Binarize the mask (1 = region to edit, 0 = keep), as the inpainting pipeline
            # does; resizing in the dataset can leave intermediate values.
            masks = (batch["masks"] >= 0.5).to(dtype=torch.float16)
            # Conditioning image: the input with the edit region blanked out.
            masked_images = input_images * (1 - masks)

            with torch.no_grad():
                # Diffusion target: latents of the edited image.
                output_latents = vae.encode(output_images).latent_dist.sample()
                output_latents = output_latents * vae.config.scaling_factor
                # Conditioning: latents of the masked input image.
                masked_image_latents = vae.encode(masked_images).latent_dist.sample()
                masked_image_latents = masked_image_latents * vae.config.scaling_factor

            # Mask downsampled to latent resolution.
            mask_tensor = F.interpolate(
                masks,
                size=tuple(output_latents.shape[-2:]),
                mode="nearest",
            )

            bsz = output_latents.shape[0]

            # Encode text prompts (both SDXL text encoders) and size/crop conditioning.
            instructions = batch["instructions"]
            with torch.no_grad():
                prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    text_encoder, text_encoder_2, tokenizer, tokenizer_2, instructions, accelerator.device
                )
                time_ids = get_time_ids(bsz, accelerator.device)

            # Forward diffusion on the target latents.
            noise = torch.randn_like(output_latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=output_latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(output_latents, noise, timesteps)

            # Channel order matches the inpainting pipeline: [noisy latents, mask, masked-image latents].
            model_input = torch.cat([noisy_latents, mask_tensor, masked_image_latents], dim=1)

            model_pred = unet(
                model_input,
                timesteps,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": time_ids},
            ).sample

            # The regression target depends on what the scheduler says the model predicts.
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(output_latents, noise, timesteps)
            else:
                raise ValueError(f"Unsupported prediction type: {prediction_type}")

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            window_loss_sum += loss.detach().item()
            window_batches += 1

            # Backward pass; accelerate only really steps the optimizer on sync steps.
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            # Average over every micro-batch of this update, not only the last one.
            loss_since_log += window_loss_sum / window_batches
            updates_since_log += 1
            window_loss_sum = 0.0
            window_batches = 0

            if global_step % LOG_EVERY == 0:
                avg_loss = loss_since_log / updates_since_log
                logger.info(f"Step {global_step}, Loss: {avg_loss:.4f}")
                loss_since_log = 0.0
                updates_since_log = 0

    progress_bar.close()

    # Save checkpoint after each epoch
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info(f"Saving checkpoint after epoch {epoch + 1}...")
        save_path = os.path.join(OUTPUT_DIR, f"checkpoint-{epoch + 1}")
        os.makedirs(save_path, exist_ok=True)
        # Unwrap first: accelerator.prepare() may wrap the model (e.g. in DDP), and the
        # wrapper does not expose PEFT's save_pretrained().
        accelerator.unwrap_model(unet).save_pretrained(save_path)
        logger.info(f"Checkpoint saved to {save_path}")
        print(f"✓ Checkpoint saved: {save_path}")

print("Training complete!")



## Save Final Adapter

Save the final trained adapter weights to `OUTPUT_DIR`.


In [None]:
# Save final adapter
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    logger.info("Saving final adapter...")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Unwrap first: accelerator.prepare() may wrap the model (e.g. in DDP), and the wrapper
    # does not expose PEFT's save_pretrained().
    # PEFT's save_pretrained writes adapter_config.json + adapter_model.safetensors using
    # PEFT's own parameter names.
    # NOTE(review): diffusers' load_lora_weights() looks for "unet."-prefixed LoRA keys and may
    # skip these; peft.PeftModel.from_pretrained reads this format directly.
    accelerator.unwrap_model(unet).save_pretrained(OUTPUT_DIR)
    logger.info(f"Final adapter saved to {OUTPUT_DIR}")
    print(f"✓ Final adapter saved to: {OUTPUT_DIR}")
    print(
        "Adapter can be loaded in app.py using: "
        f"pipeline.unet = PeftModel.from_pretrained(pipeline.unet, '{OUTPUT_DIR}').merge_and_unload()"
    )

    # List saved files (sorted for a stable listing)
    print("\nSaved files:")
    for file_name in sorted(os.listdir(OUTPUT_DIR)):
        file_path = os.path.join(OUTPUT_DIR, file_name)
        if os.path.isfile(file_path):
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            print(f"  - {file_name} ({size_mb:.2f} MB)")

accelerator.end_training()
print("\n🎉 Training finished successfully!")



## Download Adapter (if not using Google Drive)

If you saved to the Colab runtime (not Google Drive), download the adapter files before the session ends. Runtime storage is temporary and is lost when the session ends.


In [None]:
# Optional: bundle OUTPUT_DIR into a zip archive and download it through the browser.
# Uncomment the lines below to use it (they need the google.colab runtime):
#
# import shutil
# from google.colab import files
#
# archive_path = shutil.make_archive("my-room-editor-qlora", "zip", OUTPUT_DIR)  # path of the .zip
# files.download(archive_path)

print("To download the adapter, uncomment the code above or copy from Google Drive if you mounted it.")

