# README:
Train a Stable Diffusion inpainting pipeline to automatically restore damaged artwork. Use an encoder-decoder diffusion architecture for high-quality image completion, and specialized loss functions that focus on authentic historical style reconstruction rather than generic inpainting. Use memory-optimized training with gradient accumulation and mixed precision for efficient fine-tuning.

In [None]:
!pip install -q diffusers transformers accelerate datasets xformers ftfy bitsandbytes huggingface_hub torchvision tqdm wandb

In [None]:
import os
import torch
import numpy as np
from glob import glob
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from datasets import Dataset as HFDataset
import random
import logging
import json
from datetime import datetime
from diffusers import (
    StableDiffusionInpaintPipeline,
    DDPMScheduler,
    UNet2DConditionModel,
    AutoencoderKL
)
from transformers import CLIPTextModel, CLIPTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.optim import AdamW
import torch.nn as nn
from torchvision.transforms.functional import to_pil_image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

In [None]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Define a class to hold all configuration parameters

class TrainingConfig:
    """Bundle of every path, hyper-parameter and runtime switch used below."""

    def __init__(self):
        # --- Where the data lives and where results are written ---
        self.damaged_dir = "/content/drive/My Drive/complete_dataset_25k_plus/damaged"
        self.mask_dir = "/content/drive/My Drive/complete_dataset_25k_plus/masks"
        self.target_dir = "/content/drive/My Drive/complete_dataset_25k_plus/Data Categories"
        self.export_dir = "/content/inpaint_dataset"
        self.output_dir = "/content/drive/My Drive/sd-inpainting-finetuned"

        # --- Base model and optimisation schedule ---
        self.model_id = "runwayml/stable-diffusion-inpainting"
        self.train_batch_size = 1
        self.gradient_accumulation_steps = 8
        self.learning_rate = 5e-6
        self.num_train_epochs = 10
        self.save_model_epochs = 2
        self.max_grad_norm = 1.0
        self.warmup_steps = 100
        self.logging_steps = 10
        self.validation_steps = 50

        # --- Image handling ---
        self.resolution = 512
        self.valid_extensions = {'.jpg', '.jpeg', '.png'}

        # --- Hardware and memory switches ---
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.mixed_precision = True
        self.enable_xformers = True
        self.gradient_checkpointing = True

    def validate_paths(self):
        """Check that the input directories exist, then create the output ones.

        Raises:
            FileNotFoundError: for the first of damaged/mask/target directory
                (checked in that order) that does not exist.
        """
        required = (self.damaged_dir, self.mask_dir, self.target_dir)
        first_missing = next((p for p in required if not os.path.exists(p)), None)
        if first_missing is not None:
            raise FileNotFoundError(f"Required directory not found: {first_missing}")

        # Output locations are created on demand.
        for out_dir in (self.export_dir, self.output_dir):
            os.makedirs(out_dir, exist_ok=True)

# Build the shared configuration object and fail fast if an input directory is
# missing; this also creates the export and output directories.
config = TrainingConfig()
config.validate_paths()

In [None]:
# Find files with valid extensions and strip the given suffixes from their names

def get_files_map_recursive(directory, suffixes, valid_extensions):
    """Map base names to file paths for the images found under *directory*.

    The directory is searched recursively. A file is kept when its lower-cased
    extension is in *valid_extensions* and its stem matches one of *suffixes*
    (tried in order): a non-empty suffix matches stems ending with it and is
    stripped to form the key; the empty suffix matches stems that do NOT end
    in ``_damaged`` or ``_mask`` and uses the stem unchanged. When two files
    produce the same key, the one visited last wins.

    Returns an empty dict if the directory is missing; scanning errors are
    logged and whatever was collected so far is returned.
    """
    found = {}
    if not os.path.exists(directory):
        logger.warning(f"Directory does not exist: {directory}")
        return found

    def resolve_key(stem):
        """Return the map key for *stem*, or None when no suffix rule applies."""
        for suffix in suffixes:
            if suffix != '':
                if stem.endswith(suffix):
                    return stem[:-len(suffix)]
            elif not stem.endswith(('_damaged', '_mask')):
                return stem
        return None

    try:
        for path in glob(os.path.join(directory, '**', '*.*'), recursive=True):
            try:
                extension = os.path.splitext(path)[1].lower()
                if extension not in valid_extensions:
                    continue

                stem = os.path.splitext(os.path.basename(path))[0]
                key = resolve_key(stem)
                if key is not None:
                    found[key] = path
            except Exception as e:
                logger.warning(f"Error processing file {path}: {e}")
                continue

    except Exception as e:
        logger.error(f"Error scanning directory {directory}: {e}")

    return found


In [None]:
# Load an image, pad it to a square with black borders, and resize it
def pad_and_resize_safe(image_path, target_size=(512, 512)):
    """Load an image as RGB, pad it to a square and resize it.

    The shorter side is padded symmetrically with black so the aspect ratio is
    preserved, then the square image is resized with LANCZOS resampling.

    Args:
        image_path: Path of the image file to load.
        target_size: ``(width, height)`` of the returned image.

    Returns:
        The processed ``PIL.Image``, or ``None`` if the file is missing,
        unreadable or has a zero dimension (the error is logged).
    """
    try:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")

        # convert() returns a fully loaded copy, so the underlying file handle
        # can be released immediately instead of lingering until GC.
        with Image.open(image_path) as source:
            img = source.convert("RGB")
        width, height = img.size

        if width == 0 or height == 0:
            raise ValueError(f"Invalid image dimensions: {width}x{height}")

        max_side = max(width, height)
        delta_w = max_side - width
        delta_h = max_side - height
        # (left, top, right, bottom); any odd pixel goes to the right/bottom.
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded = ImageOps.expand(img, padding, fill=(0, 0, 0))
        return padded.resize(target_size, Image.LANCZOS)
    except Exception as e:
        logger.error(f"Error processing image {image_path}: {e}")
        return None

In [None]:
# Try to find triplets of files: a damaged image, its mask, and the clean target image

def find_triplets_recursive_safe(category, damaged_dir, mask_dir, target_dir, valid_extensions):
    """Find (damaged, mask, target) image triplets for one category.

    Files are matched by base name: ``<base>_damaged.*`` in the damaged tree,
    ``<base>_mask.*`` in the mask tree and ``<base>.*`` in the target tree,
    each under a ``category`` sub-directory.

    Args:
        category: Name of the sub-directory to look for under each root.
        damaged_dir, mask_dir, target_dir: Root directories of the three trees.
        valid_extensions: Lower-case extensions (with dot) accepted as images.

    Returns:
        A list of dicts with keys ``name``, ``category``, ``damaged``, ``mask``
        and ``target``, sorted by ``name``. Empty if any category directory is
        missing. Triplets with a missing or unopenable file are skipped.
    """
    logger.info(f"Processing category: {category}")

    cat_damaged = os.path.join(damaged_dir, category)
    cat_mask = os.path.join(mask_dir, category)
    cat_target = os.path.join(target_dir, category)

    # Bail out early (with one combined warning) if any tree lacks the category.
    missing_dirs = [
        f"{name}: {path}"
        for name, path in [("damaged", cat_damaged), ("mask", cat_mask), ("target", cat_target)]
        if not os.path.exists(path)
    ]
    if missing_dirs:
        logger.warning(f"Missing directories for category {category}: {missing_dirs}")
        return []

    damaged_map = get_files_map_recursive(cat_damaged, ['_damaged'], valid_extensions)
    mask_map = get_files_map_recursive(cat_mask, ['_mask'], valid_extensions)
    target_map = get_files_map_recursive(cat_target, [''], valid_extensions)

    logger.info(f"Found - Damaged: {len(damaged_map)}, Masks: {len(mask_map)}, Targets: {len(target_map)}")

    common_keys = set(damaged_map.keys()) & set(mask_map.keys()) & set(target_map.keys())
    logger.info(f"Complete triplets: {len(common_keys)}")

    triplets = []
    # Sorted so the result order (and any index-based naming derived from it)
    # is the same on every run; set iteration order of strings is not.
    for base in sorted(common_keys):
        files_to_check = [
            ('damaged', damaged_map[base]),
            ('mask', mask_map[base]),
            ('target', target_map[base])
        ]

        valid_triplet = True
        for file_type, file_path in files_to_check:
            if not os.path.exists(file_path):
                logger.warning(f"Missing {file_type} file: {file_path}")
                valid_triplet = False
                break
            try:
                # Quick check: opening only parses the header, it does not
                # decode the pixel data.
                with Image.open(file_path):
                    pass
            except Exception as e:
                logger.warning(f"Invalid {file_type} image {file_path}: {e}")
                valid_triplet = False
                break

        if valid_triplet:
            triplets.append({
                'name': base,
                'category': category,
                'damaged': damaged_map[base],
                'mask': mask_map[base],
                'target': target_map[base]
            })

    logger.info(f"Valid triplets after verification: {len(triplets)}")
    return triplets


In [None]:
# Set category
category = "Medieval"
logger.info(f"Starting processing for category: {category}")

# Collect every complete damaged/mask/target triplet for the chosen category.
triplets = find_triplets_recursive_safe(
    category=category,
    damaged_dir=config.damaged_dir,
    mask_dir=config.mask_dir,
    target_dir=config.target_dir,
    valid_extensions=config.valid_extensions,
)

# Nothing to train on: stop the notebook here.
if not triplets:
    raise ValueError(f"No valid triplets found for category {category}")

logger.info(f"Found {len(triplets)} valid triplets")

In [None]:
# Prepare the dataset

def prepare_dataset_safe(triplets, export_dir, resolution=512):
    """Export each triplet as three square PNGs and return the success count.

    Triplet number ``i`` is written to ``img_{i:04d}.png`` (damaged),
    ``img_{i:04d}_mask.png`` and ``img_{i:04d}_target.png`` inside
    *export_dir*. A triplet is skipped (and counted as failed) when any of its
    images cannot be processed or saving raises.
    """
    size = (resolution, resolution)
    # (triplet key, file-name suffix) in the order the files are written.
    outputs = (("damaged", ""), ("mask", "_mask"), ("target", "_target"))
    n_ok = 0
    n_failed = 0

    for index, triplet in enumerate(tqdm(triplets, desc="Processing images")):
        try:
            stem = f"img_{index:04d}"

            # Process all three before deciding, so one bad file skips the lot.
            images = [pad_and_resize_safe(triplet[key], size) for key, _ in outputs]
            if any(image is None for image in images):
                logger.warning(f"Failed to process triplet {triplet['name']}")
                n_failed += 1
                continue

            for image, (_, suffix) in zip(images, outputs):
                image.save(os.path.join(export_dir, f"{stem}{suffix}.png"))

            n_ok += 1

        except Exception as e:
            logger.error(f"Error processing triplet {triplet['name']}: {e}")
            n_failed += 1
            continue

    logger.info(f"Dataset preparation complete: {n_ok} successful, {n_failed} failed")
    return n_ok

In [None]:
# Export every triplet to config.export_dir at the training resolution;
# successful_count is the number of triplets written without error.
successful_count = prepare_dataset_safe(triplets, config.export_dir, config.resolution)

In [None]:
def generate_prompts(category):
    """Return the restoration text prompt for an artwork of the given era.

    *category* is interpolated twice into a fixed three-sentence prompt.
    """
    sentences = (
        "You are an expert art restorer.",
        f"Carefully restore and complete this {category} era artwork, "
        "preserving its original composition, brushwork, and textures.",
        "Emphasize historical accuracy, authentic color palettes, and fine detail "
        f"true to the {category} period.",
    )
    return " ".join(sentences)


In [None]:
# prepare dataset for training by combining images, masks, and prompts

examples = []
for i, triplet in enumerate(triplets):
    # Must mirror the file names written by prepare_dataset_safe.
    base = f"img_{i:04d}"
    damaged_path = os.path.join(config.export_dir, f"{base}.png")
    mask_path = os.path.join(config.export_dir, f"{base}_mask.png")
    target_path = os.path.join(config.export_dir, f"{base}_target.png")

    # Triplets that failed to export have no files; skip them.
    if not all(os.path.exists(p) for p in (damaged_path, mask_path, target_path)):
        continue

    examples.append({
        # The diffusion objective needs the image it should learn to produce.
        # Using the damaged picture here would train the model to re-create
        # the damage inside the mask, so "image" points at the target export.
        # NOTE(review): assumes *_target.png is the undamaged artwork and that
        # damage is confined to the masked area - confirm against the dataset.
        "image": target_path,
        # Kept for reference/inspection; not required by the training tensors.
        "damaged": damaged_path,
        "mask": mask_path,
        "prompt": generate_prompts(triplet['category']),
    })

processed_count = len(examples)
logger.info(f"Created {processed_count} dataset examples")

if len(examples) < 10:
    raise ValueError(f"Too few examples ({len(examples)}) for training")

# Create and split dataset (90% train / 10% held out for validation)
random.shuffle(examples)
dataset = HFDataset.from_list(examples)
dataset = dataset.train_test_split(test_size=0.1)

logger.info(f"Dataset split - Train: {len(dataset['train'])}, Test: {len(dataset['test'])}")

In [None]:
# PyTorch dataset: convert the images and masks into tensors and tokenize the prompt

class SafeInpaintingDataset(Dataset):
    """PyTorch dataset turning image/mask/prompt records into training tensors.

    Each record of *dataset* must provide ``"image"`` and ``"mask"`` file
    paths and may provide a ``"prompt"`` string (defaults to ``""``).
    """

    def __init__(self, dataset, tokenizer, size=512):
        """
        Args:
            dataset: Indexable collection of records (see class docstring).
            tokenizer: Callable tokenizer exposing ``model_max_length``; it is
                invoked with ``return_tensors="pt"``.
            size: Side length in pixels that images and masks are forced to.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.size = size

    def __len__(self):
        return len(self.dataset)

    def _load_resized(self, path, mode, resample):
        """Open *path*, convert it to *mode* and resize to ``size`` if needed."""
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(path) as handle:
            img = handle.convert(mode)
        if img.size != (self.size, self.size):
            img = img.resize((self.size, self.size), resample)
        return img

    def __getitem__(self, idx):
        """Return the tensors for record *idx*.

        Returns:
            dict with
            ``pixel_values``: float tensor (3, size, size) scaled to [-1, 1];
            ``mask_values``: float tensor (1, size, size) holding 0.0 / 1.0;
            ``input_ids`` / ``attention_mask``: tokenizer output with the
            leading batch dimension removed.

        Raises:
            RuntimeError: if the record cannot be loaded. The original error is
                chained as ``__cause__``. Returning ``None`` instead would only
                surface later as an obscure collate failure in the DataLoader.
        """
        # Outside the try block so an out-of-range index stays an IndexError.
        item = self.dataset[idx]

        try:
            if not os.path.exists(item["image"]):
                raise FileNotFoundError(f"Image not found: {item['image']}")
            if not os.path.exists(item["mask"]):
                raise FileNotFoundError(f"Mask not found: {item['mask']}")

            image = self._load_resized(item["image"], "RGB", Image.LANCZOS)
            # Nearest-neighbour keeps the mask binary; an interpolating filter
            # would create grey ringing along the mask boundary.
            mask = self._load_resized(item["mask"], "L", Image.NEAREST)

            # uint8 [0, 255] -> float [-1, 1] for the image, [0, 1] for the mask
            image = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0
            mask = torch.from_numpy(np.array(mask)).float() / 255.0

            # Binarise: anything brighter than mid-grey counts as masked.
            mask = (mask > 0.5).float()

            prompt = item.get("prompt", "")
            text_input = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            )

            return {
                "pixel_values": image.permute(2, 0, 1),   # HWC -> CHW
                "mask_values": mask.unsqueeze(0),         # HW -> 1HW
                # squeeze(0) removes only the batch dimension added by
                # return_tensors="pt".
                "input_ids": text_input.input_ids.squeeze(0),
                "attention_mask": text_input.attention_mask.squeeze(0)
            }

        except Exception as e:
            raise RuntimeError(f"Error loading dataset item {idx}: {e}") from e

In [None]:
# Safely load the core models for inpainting pipeline

def load_models_safe(config):
    """Load the tokenizer, text encoder, VAE and UNet of ``config.model_id``.

    The text encoder and VAE are moved to ``config.device`` and frozen. The
    UNet is loaded in full precision: it is the model being optimised, and
    gradient scaling with autocast needs FP32 parameters (autocast takes care
    of running the forward pass in reduced precision).

    Args:
        config: Object providing ``model_id``, ``device``,
            ``gradient_checkpointing`` and ``enable_xformers``.

    Returns:
        ``(tokenizer, text_encoder, vae, unet)``.

    Raises:
        Exception: whatever the underlying loaders raise, after logging it.
    """
    try:
        logger.info("Loading tokenizer and text encoder...")
        tokenizer = CLIPTokenizer.from_pretrained(
            config.model_id,
            subfolder="tokenizer",
            use_fast=False  # More stable for fine-tuning
        )
        text_encoder = CLIPTextModel.from_pretrained(
            config.model_id,
            subfolder="text_encoder"
        ).to(config.device)

        logger.info("Loading VAE...")
        vae = AutoencoderKL.from_pretrained(
            config.model_id,
            subfolder="vae",
            low_cpu_mem_usage=True
        ).to(config.device)

        logger.info("Loading UNet...")
        # No torch_dtype=float16 here: down-casting the trainable weights loses
        # precision and makes GradScaler.unscale_ reject the FP16 gradients.
        unet = UNet2DConditionModel.from_pretrained(
            config.model_id,
            subfolder="unet",
            low_cpu_mem_usage=True
        )

        # Enable memory-saving features
        if config.gradient_checkpointing:
            unet.enable_gradient_checkpointing()
        if config.enable_xformers:
            # xformers is an optional optimisation; fall back to the default
            # attention instead of aborting when it is missing or unusable.
            try:
                unet.set_use_memory_efficient_attention_xformers(True)
            except Exception as e:
                logger.warning(f"xformers attention unavailable, using default attention: {e}")

        unet.to(config.device)

        # Freeze models that shouldn't be trained
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)

        logger.info("Model loading complete")
        return tokenizer, text_encoder, vae, unet

    except Exception as e:
        logger.error(f"Model loading failed: {e}")
        raise

# Load all model components once; the cells below use these module-level names.
tokenizer, text_encoder, vae, unet = load_models_safe(config)

In [None]:
# Create datasets
# Wrap the two splits ("train" for fitting, "test" for validation).
train_dataset, val_dataset = (
    SafeInpaintingDataset(dataset[split], tokenizer, config.resolution)
    for split in ("train", "test")
)

# Training loader: shuffled, loaded by two worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Validation loader: one sample at a time, in dataset order.
val_loader = DataLoader(val_dataset, batch_size=1, pin_memory=True)

In [None]:
# Optimizer (only train UNet parameters)
optimizer = AdamW(
    unet.parameters(),
    lr=config.learning_rate,
    weight_decay=1e-6
)

# The LR scheduler is stepped once per optimizer update, and an update happens
# only every `gradient_accumulation_steps` batches. Sizing the schedule in
# batches would leave the cosine decay far from finished at the end of training.
# Ceil division so the schedule is never shorter than the number of updates.
updates_per_epoch = max(
    1, -(-len(train_loader) // max(1, config.gradient_accumulation_steps))
)
total_update_steps = updates_per_epoch * config.num_train_epochs

# Learning rate scheduler (warmup is capped so that tiny runs still leave it)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(config.warmup_steps, total_update_steps),
    num_training_steps=total_update_steps
)

# Noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(config.model_id, subfolder="scheduler")

# Loss function
loss_fn = nn.MSELoss()

# Mixed precision scaler (a pass-through when mixed precision is disabled)
scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)

In [None]:
# Train the UNet for a single epoch on the diffusion (noise-prediction) objective

def train_one_epoch(unet, vae, text_encoder, train_loader, optimizer,
                   noise_scheduler, scaler, epoch, config):
    """Train the UNet for one epoch and return the mean per-batch loss.

    For every batch the image and its masked copy are encoded to latents,
    noise is added at a random timestep, and the UNet is asked to predict that
    noise from ``[noisy latents, masked-image latents, mask]``. The MSE is
    taken over the masked latent positions only.

    Relies on the module-level ``loss_fn`` and ``lr_scheduler``.

    Args:
        unet: Model being trained (put into train mode here).
        vae, text_encoder: Frozen encoders, only used under ``no_grad``.
        train_loader: Yields dicts with ``pixel_values``, ``mask_values`` and
            ``input_ids``.
        optimizer: Optimizer over the UNet parameters.
        noise_scheduler: Provides ``add_noise`` and ``num_train_timesteps``.
        scaler: ``GradScaler``; acts as a pass-through when disabled.
        epoch: Zero-based epoch index (used for the progress-bar label).
        config: Provides ``device``, ``mixed_precision``,
            ``gradient_accumulation_steps``, ``max_grad_norm``, ``logging_steps``.

    Returns:
        Average loss over the epoch, NOT divided by the accumulation factor,
        so it is on the same scale as the validation loss.
    """
    unet.train()
    accum_steps = max(1, config.gradient_accumulation_steps)
    num_batches = len(train_loader)
    # Falls back to the Stable Diffusion v1 constant if the config lacks it.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)
    total_loss = 0.0

    # Start from clean gradients; nothing may leak in from a previous epoch.
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}")):
        # Move batch to device
        pixel_values = batch["pixel_values"].to(config.device)
        mask_values = batch["mask_values"].to(config.device)
        input_ids = batch["input_ids"].to(config.device)

        with torch.no_grad():
            # No attention_mask: validate() below encodes prompts the same way,
            # so training and evaluation see identical text conditioning.
            text_embeddings = text_encoder(input_ids)[0]

            # Blank out the region to be inpainted (mask == 1).
            masked_images = pixel_values * (mask_values < 0.5)

            # Convert images to latent space
            latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale
            masked_latents = vae.encode(masked_images).latent_dist.sample() * latent_scale

        # Sample noise and add it to the latents at a random timestep
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=config.device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Forward pass with mixed precision
        with torch.amp.autocast(device_type='cuda', enabled=config.mixed_precision):
            mask_resized = F.interpolate(mask_values, size=latents.shape[2:], mode='nearest')

            model_input = torch.cat([
                noisy_latents,          # 4 channels
                masked_latents,         # 4 channels
                mask_resized            # 1 channel
            ], dim=1)

            noise_pred = unet(model_input, timesteps, text_embeddings).sample

            # Loss on the masked region only. If the down-sampled mask is empty,
            # MSE over zero elements would be NaN, so use the whole latent.
            loss_mask = mask_resized.expand_as(noise) > 0
            if loss_mask.any():
                loss = loss_fn(noise_pred[loss_mask], noise[loss_mask])
            else:
                loss = loss_fn(noise_pred, noise)

        # Scale down for accumulation only in the backward pass, so the value
        # that is logged and returned stays the true per-batch loss.
        scaler.scale(loss / accum_steps).backward()
        total_loss += loss.item()

        # Update every `accum_steps` batches, and also on the last batch so a
        # trailing partial window is applied instead of being carried over.
        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            # unscale_/step/update are no-ops or pass-throughs when the scaler
            # is disabled, so one code path serves both precision modes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), config.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            lr_scheduler.step()

        # Logging
        if step % config.logging_steps == 0:
            avg_loss = total_loss / (step + 1)
            lr = optimizer.param_groups[0]["lr"]
            logger.info(f"Step {step}: Loss {avg_loss:.4f}, LR {lr:.2e}")

    return total_loss / max(num_batches, 1)

def validate(unet, vae, text_encoder, val_loader, noise_scheduler, epoch, config):
    """Evaluate the UNet on the validation set.

    Every sample is noised at timestep 0, the UNet predicts the noise, and the
    clean latents are reconstructed from that prediction and decoded. The loss
    is the masked MSE used in training; PSNR and SSIM compare the decoded image
    with the input over the WHOLE image.

    NOTE(review): only timestep 0 (almost no noise) is evaluated, so the image
    metrics are dominated by VAE reconstruction quality. Consider sampling a
    fixed set of timesteps.

    Relies on the module-level ``loss_fn``.

    Args:
        unet: Model under evaluation (put into eval mode here).
        vae, text_encoder: Frozen encoders/decoder.
        val_loader: Yields dicts with ``pixel_values``, ``mask_values`` and
            ``input_ids``.
        noise_scheduler: Provides ``add_noise`` and ``alphas_cumprod``.
        epoch: Unused; kept for call compatibility.
        config: Provides ``device`` and ``mixed_precision``.

    Returns:
        ``(avg_loss, avg_psnr, avg_ssim)``. The metric averages are 0 when no
        sample qualified for metric computation.
    """
    unet.eval()
    val_loss = 0.0
    psnr_values = []
    ssim_values = []
    # Falls back to the Stable Diffusion v1 constant if the config lacks it.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            pixel_values = batch["pixel_values"].to(config.device)
            mask_values = batch["mask_values"].to(config.device)
            input_ids = batch["input_ids"].to(config.device)

            # Get text embeddings
            text_embeddings = text_encoder(input_ids)[0]

            # Blank out the region to be inpainted (mask == 1).
            masked_images = pixel_values * (mask_values < 0.5)

            # Convert to latent space
            latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale
            masked_latents = vae.encode(masked_images).latent_dist.sample() * latent_scale

            # Noise the latents at timestep 0; one timestep per batch element
            noise = torch.randn_like(latents)
            timesteps = torch.zeros((latents.shape[0],), device=config.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Resize mask to match latent dimensions
            mask_resized = F.interpolate(mask_values, size=latents.shape[2:], mode='nearest')

            # Predict noise (same autocast policy as the training forward pass)
            model_input = torch.cat([
                noisy_latents,
                masked_latents,
                mask_resized
            ], dim=1)
            with torch.amp.autocast(device_type='cuda', enabled=config.mixed_precision):
                noise_pred = unet(model_input, timesteps, text_embeddings).sample
            noise_pred = noise_pred.float()

            # Masked loss; fall back to the full latent when the mask is empty
            # (MSE over zero elements would be NaN).
            loss_mask = mask_resized.expand_as(noise) > 0
            if loss_mask.any():
                loss = loss_fn(noise_pred[loss_mask], noise[loss_mask])
            else:
                loss = loss_fn(noise_pred, noise)
            val_loss += loss.item()

            # Invert the forward process x_t = sqrt(a)*x0 + sqrt(1-a)*eps to
            # recover x0 from the predicted noise (epsilon-prediction, matching
            # the loss above), with a = alphas_cumprod[t].
            alpha_bar = noise_scheduler.alphas_cumprod.to(
                device=latents.device, dtype=latents.dtype
            )[timesteps].view(-1, 1, 1, 1)
            pred_latents = (noisy_latents - (1 - alpha_bar).sqrt() * noise_pred) / alpha_bar.sqrt()
            pred_images = vae.decode(pred_latents / latent_scale).sample

            # [-1, 1] -> [0, 1], NCHW -> NHWC; the decoder can overshoot, so the
            # prediction is clipped to the declared data_range.
            pred_np = ((pred_images.float().clamp(-1, 1) + 1) / 2).permute(0, 2, 3, 1).cpu().numpy()
            target_np = ((pixel_values.float() + 1) / 2).permute(0, 2, 3, 1).cpu().numpy()

            # Metrics are computed over the whole image, and only for samples
            # that contain at least one unmasked pixel.
            has_unmasked = (mask_values < 0.5).flatten(1).any(dim=1).cpu().tolist()
            for pred_img, target_img, usable in zip(pred_np, target_np, has_unmasked):
                if not usable:
                    continue
                psnr_values.append(psnr(target_img, pred_img, data_range=1.0))
                ssim_values.append(ssim(target_img, pred_img, channel_axis=-1, data_range=1.0))

    avg_loss = val_loss / max(len(val_loader), 1)
    avg_psnr = np.mean(psnr_values) if psnr_values else 0
    avg_ssim = np.mean(ssim_values) if ssim_values else 0

    logger.info(f"Validation - Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

    return avg_loss, avg_psnr, avg_ssim

In [None]:
# Training loop with checkpointing and improved error handling
best_psnr = 0

# Mixed-precision training keeps FP32 weights and lets autocast run the forward
# pass in reduced precision; GradScaler cannot unscale FP16 gradients. The cast
# is a no-op when the UNet is already FP32.
if config.mixed_precision:
    unet = unet.float()

# Fresh scaler for this run (a pass-through when mixed precision is disabled).
scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)

for epoch in range(config.num_train_epochs):
    try:
        # Train
        train_loss = train_one_epoch(
            unet, vae, text_encoder, train_loader, optimizer,
            noise_scheduler, scaler, epoch, config
        )

        # Validate
        val_loss, val_psnr, val_ssim = validate(
            unet, vae, text_encoder, val_loader,
            noise_scheduler, epoch, config
        )

        # Update the running best BEFORE writing the training state, so the
        # saved "best_psnr" includes this epoch's result.
        is_best = val_psnr > best_psnr
        if is_best:
            best_psnr = val_psnr

        # Save checkpoint on the regular schedule or whenever PSNR improves
        if (epoch + 1) % config.save_model_epochs == 0 or is_best:
            checkpoint_path = os.path.join(config.output_dir, f"checkpoint-{epoch}")
            os.makedirs(checkpoint_path, exist_ok=True)

            # Save model components (the UNet is FP32 at this point)
            unet.save_pretrained(os.path.join(checkpoint_path, "unet"))
            tokenizer.save_pretrained(os.path.join(checkpoint_path, "tokenizer"))

            # Save training state
            torch.save({
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "best_psnr": best_psnr,
            }, os.path.join(checkpoint_path, "training_state.pt"))

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            if is_best:
                # Save as best model
                best_path = os.path.join(config.output_dir, "best_model")
                os.makedirs(best_path, exist_ok=True)
                unet.save_pretrained(os.path.join(best_path, "unet"))
                logger.info(f"New best model with PSNR: {best_psnr:.2f}")

        # Log epoch results
        logger.info(f"Epoch {epoch + 1}/{config.num_train_epochs} - "
                   f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                   f"Val PSNR: {val_psnr:.2f}, Val SSIM: {val_ssim:.4f}")

    except Exception as e:
        logger.error(f"Error during epoch {epoch}: {str(e)}")
        # Best-effort emergency checkpoint so the work done so far is not lost
        emergency_path = os.path.join(config.output_dir, f"emergency_checkpoint_epoch_{epoch}")
        os.makedirs(emergency_path, exist_ok=True)
        try:
            unet.save_pretrained(os.path.join(emergency_path, "unet"))
            torch.save({
                "epoch": epoch,
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "best_psnr": best_psnr,
                "error": str(e)
            }, os.path.join(emergency_path, "training_state.pt"))
            logger.info(f"Saved emergency checkpoint to {emergency_path}")
        except Exception as save_error:
            logger.error(f"Failed to save emergency checkpoint: {str(save_error)}")

        # Bare raise re-raises the original error with its traceback intact.
        raise

logger.info("Training complete!")

In [None]:
# View training results and metrics
# This cell is only meaningful after the training loop has run to completion.
print("TRAINING COMPLETED SUCCESSFULLY!")
print("=" * 50)

# Summary of training
# best_psnr is tracked across epochs; the remaining values are the ones left
# over from the last epoch of the training loop.
print(f"Best PSNR achieved: {best_psnr:.4f}")
print(f"Final training loss: {train_loss:.6f}")
print(f"Final validation loss: {val_loss:.6f}")
print(f"Final validation PSNR: {val_psnr:.4f}")
print(f"Final validation SSIM: {val_ssim:.6f}")