<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Image_Inpaint_CelebHQ_Diffusions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Install dependencies: torch/torchvision are pinned to CUDA 11.8 builds and diffusers to 0.19.3.
# NOTE(review): `datasets`, `accelerate` and `transformers` are unpinned, so their versions drift
# between runs. The `ImportError: cannot import name 'insecure_hashlib' from 'huggingface_hub.utils'`
# captured at the end of this notebook looks like such a mismatch (an installed `datasets` expecting
# a newer huggingface_hub than the one imported). Pin mutually compatible versions and restart the
# runtime after installing — verify which package pair is responsible.
# NOTE(review): execution counts show this cell as In [2] and the training cell as In [1]; run this
# cell first (then restart) so the training cell imports the packages installed here. `%pip` would
# also target the running kernel's environment more reliably than `!pip`.
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install diffusers[torch]==0.19.3 datasets accelerate transformers

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118


In [1]:
from datasets import load_dataset
from torchvision import transforms
import torch
import numpy as np
import cv2
from diffusers import DDPMPipeline, DDPMScheduler
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import gc

# --- Memory Optimization Setup ---
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# --- Reproducibility ---
# Seed every RNG this notebook draws from: NumPy (mask rectangles) and torch
# (timesteps, noise, newly initialised layers, DataLoader shuffling) on CPU and CUDA.
# Note: cudnn.benchmark=True can still pick non-deterministic kernels; disable it
# if bit-exact reruns are required.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Reduced Batch Size ---
BATCH_SIZE = 2  # Reduced from 8 to prevent OOM

# --- Dataset with Smaller Resolution ---
# NOTE(review): the checkpoint loaded below is named "ddpm-celebahq-256", suggesting it was
# trained at 256x256; fine-tuning at 128x128 is a distribution shift — confirm this is intended.
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Reduced resolution
    transforms.ToTensor(),
    # Single-value mean/std broadcasts over all 3 channels: maps [0, 1] -> [-1, 1]
    transforms.Normalize([0.5], [0.5])
])

# --- Simplified Mask Generation ---
def create_irregular_mask(image_size=(128, 128)):
    """Return a binary inpainting mask containing one random rectangular hole.

    Despite the name, the hole is an axis-aligned rectangle, not a free-form stroke.

    Args:
        image_size: Shape of the mask in NumPy order, i.e. ``(height, width)``.
            Both dimensions must be at least 2.

    Returns:
        ``np.uint8`` array of shape ``image_size`` that is 1 on known pixels and 0
        inside the hole. The hole's top-left corner lies in the top-left quadrant,
        each side spans more than a quarter of the corresponding dimension, and the
        hole always stays inside the image.

    Raises:
        ValueError: if either dimension is smaller than 2.
    """
    height, width = image_size[0], image_size[1]
    if height < 2 or width < 2:
        raise ValueError(f"image_size must be at least (2, 2), got {tuple(image_size)}")

    mask = np.ones(image_size, dtype=np.uint8)
    # x/w are columns (bounded by width) and y/h are rows (bounded by height). The original
    # bounded x by height and y by width, which let the hole fall outside non-square images.
    # The draw order is unchanged, so square masks are identical for a given NumPy seed.
    x = np.random.randint(0, width // 2)
    y = np.random.randint(0, height // 2)
    w = np.random.randint(width // 4, width - x)   # guarantees x + w <= width - 1
    h = np.random.randint(height // 4, height - y)  # guarantees y + h <= height - 1
    # Both corners are inclusive, matching cv2.rectangle(mask, (x, y), (x+w, y+h), 0, -1).
    mask[y:y + h + 1, x:x + w + 1] = 0
    return mask

# --- Model Loading with SafeTensors Handling ---
try:
    model = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
except Exception as load_error:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    # Fallback for checkpoints that cannot be loaded through safetensors
    print(f"Default load failed ({load_error!r}); retrying with use_safetensors=False")
    model = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256", use_safetensors=False)

# --- Widen the UNet input to RGB + mask, keeping pretrained weights ---
# The training loop feeds torch.cat([noisy_rgb, mask], dim=1), i.e. 4 input channels.
NUM_INPUT_CHANNELS = 4
original_conv_in = model.unet.conv_in
new_conv_in = nn.Conv2d(
    NUM_INPUT_CHANNELS,
    original_conv_in.out_channels,
    kernel_size=original_conv_in.kernel_size,
    stride=original_conv_in.stride,
    padding=original_conv_in.padding,
    bias=original_conv_in.bias is not None,
)
with torch.no_grad():
    # Copy the pretrained filters for the original (RGB) channels and zero the extra mask
    # channel, so the widened layer initially computes exactly what the pretrained conv_in
    # did. A freshly initialised conv_in would discard the pretrained first layer.
    new_conv_in.weight.zero_()
    new_conv_in.weight[:, :original_conv_in.in_channels].copy_(original_conv_in.weight)
    if original_conv_in.bias is not None:
        new_conv_in.bias.copy_(original_conv_in.bias)
model.unet.conv_in = new_conv_in
# Keep the model config consistent with the new weights so the directory written by
# save_pretrained can be reloaded without a conv_in shape mismatch.
model.unet.register_to_config(in_channels=NUM_INPUT_CHANNELS)
model.unet.enable_gradient_checkpointing()

# Move the UNet to the GPU (if available) and put it in training mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.unet.to(device)
model.unet.train()

# --- Memory-Optimized Training Loop ---
def collate_fn(batch):
    """Turn a list of dataset records into one batched image tensor on ``device``.

    Each record's ``'image'`` is converted to RGB and run through ``transform``;
    the per-sample tensors are stacked along a new leading batch dimension and the
    stacked tensor is moved to ``device``.
    """
    tensors = []
    for record in batch:
        rgb_image = record['image'].convert('RGB')
        tensors.append(transform(rgb_image))
    stacked = torch.stack(tensors)
    return stacked.to(device, non_blocking=True)

train_loader = DataLoader(
    load_dataset("saitsharipova/CelebA-HQ", split="train"),
    batch_size=BATCH_SIZE,
    shuffle=True,  # iterate in random order rather than dataset order
    collate_fn=collate_fn,
    # pin_memory must stay off: collate_fn already returns tensors on `device`, and the
    # DataLoader can only pin dense CPU tensors (pinning CUDA tensors raises RuntimeError).
    pin_memory=False,
)

optimizer = Adam(model.unet.parameters(), lr=1e-4)

NUM_EPOCHS = 5  # Reduced epochs
use_amp = device.type == "cuda"
# The GradScaler keeps small fp16 gradients from underflowing under autocast. It is a
# no-op (enabled=False) on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
num_train_timesteps = model.scheduler.config.num_train_timesteps

for epoch in range(NUM_EPOCHS):
    epoch_loss, num_batches = 0.0, 0
    for images in train_loader:
        images = images.to(device, non_blocking=True)  # no-op if collate_fn already moved it
        batch_size = images.size(0)  # the last batch can be smaller than BATCH_SIZE

        # Masks are built with NumPy on the CPU, then copied to the device (1 = known, 0 = hole)
        masks = torch.stack([
            torch.from_numpy(create_irregular_mask(tuple(images.shape[-2:]))).float()
            for _ in range(batch_size)
        ]).unsqueeze(1).to(device, non_blocking=True)

        # Mixed Precision Training
        with torch.cuda.amp.autocast(enabled=use_amp):
            # Keep the known pixels and blank the hole. The original `images * (1 - masks)`
            # did the opposite and kept only the hole.
            masked_images = images * masks
            # NOTE(review): this objective noises the *masked* image and regresses that noise,
            # so the model learns to denoise images that still contain the hole. Typical
            # inpainting training noises the full image and conditions on masked image + mask.
            # Confirm intent.
            timesteps = torch.randint(0, num_train_timesteps, (batch_size,), device=device)
            noise = torch.randn_like(masked_images)
            noisy_images = model.scheduler.add_noise(masked_images, noise, timesteps)

            outputs = model.unet(torch.cat([noisy_images, masks], dim=1), timesteps).sample
            loss = nn.functional.mse_loss(outputs.float(), noise.float())

        # Optimizer step with gradient scaling; unscale before clipping so the
        # clip threshold applies to the true gradient norm.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.unet.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1
        del noisy_images, masks, outputs

    # Release cached memory once per epoch. Doing it every step (as before) forces a
    # synchronisation and a Python GC pass on each iteration, which slows training a lot.
    gc.collect()
    torch.cuda.empty_cache()

    # Report the mean epoch loss. This also stays defined if the loader yields no batches.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")

# Save final model with safetensors
model.save_pretrained("inpainting_model", safe_serialization=True)

ImportError: cannot import name 'insecure_hashlib' from 'huggingface_hub.utils' (/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/__init__.py)