<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Image_Inpaint_CelebHQ_Diffusions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install diffusers transformers datasets accelerate safetensors

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [6]:
from datasets import load_dataset
from torchvision import transforms
import torch
import numpy as np
import cv2
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import gc

# --- Memory Optimization Setup ---
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# --- Reduced Batch Size ---
BATCH_SIZE = 2

# --- Dataset with Smaller Resolution ---
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# --- Corrected Mask Generation ---
def create_irregular_mask(image_size=(128, 128)):
    H, W = image_size
    mask = np.zeros((H, W), dtype=np.uint8)
    x = np.random.randint(0, W//2)
    y = np.random.randint(0, H//2)
    w = np.random.randint(W//4, W - x)
    h = np.random.randint(H//4, H - y)
    cv2.rectangle(mask, (x, y), (x + w, y + h), 1, -1)
    return mask

# --- Model Loading with Explicit safetensors Disable ---
try:
    # First try loading with safetensors disabled
    model = DDPMPipeline.from_pretrained(
        "google/ddpm-celebahq-256",
        use_safetensors=False,
        safety_checker=None
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # Fallback to basic UNet architecture
    model = DDPMPipeline.from_pretrained(
        UNet2DModel(
            sample_size=128,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 256, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D"
            ),
            up_block_types=(
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D"
            ),
        ),
        safety_checker=None
    )

# --- Modified UNet with Proper Initialization ---
original_conv_in = model.unet.conv_in
new_conv_in = nn.Conv2d(4, original_conv_in.out_channels,
                        kernel_size=original_conv_in.kernel_size,
                        padding=original_conv_in.padding,
                        stride=original_conv_in.stride)

with torch.no_grad():
    # Handle different weight dimensions
    if original_conv_in.weight.shape[1] == 3:
        new_conv_in.weight[:, :3] = original_conv_in.weight.clone()
        new_conv_in.weight[:, 3:] = 0
    else:  # If loading failed, initialize randomly
        new_conv_in.weight.normal_()
    new_conv_in.bias = original_conv_in.bias.clone()

model.unet.conv_in = new_conv_in
model.unet.enable_gradient_checkpointing()

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.unet.to(device)
model.unet.train()

# --- Data Loading ---
ds = load_dataset("saitsharipov/CelebA-HQ")

def collate_fn(batch):
    images = [transform(img['image'].convert('RGB')) for img in batch]
    return torch.stack(images).to(device, non_blocking=True)

train_loader = DataLoader(
    ds['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

# --- Training Setup ---
optimizer = Adam(model.unet.parameters(), lr=1e-4)

# --- Training Loop ---
for epoch in range(5):
    for images in train_loader:
        masks = torch.stack([
            torch.from_numpy(create_irregular_mask()).float()
            for _ in range(images.size(0))
        ]).unsqueeze(1).to(device)

        masked_images = images * (1 - masks)
        noise = torch.randn_like(images)
        timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps,
                                (BATCH_SIZE,), device=device)

        with torch.cuda.amp.autocast():
            noisy_images = model.scheduler.add_noise(images, noise, timesteps)
            model_output = model.unet(
                torch.cat([noisy_images, masks], dim=1),
                timesteps
            ).sample
            loss = nn.functional.mse_loss(model_output, noise)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.unet.parameters(), 1.0)
        optimizer.step()

        del noisy_images, masks, model_output
        torch.cuda.empty_cache()

    print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

# Save model
model.save_pretrained("inpainting_model", safe_serialization=torch.cuda.is_available())

Keyword arguments {'safety_checker': None} are not expected by DDPMPipeline and will be ignored.


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

TypeError: cannot assign 'torch.FloatTensor' as parameter 'bias' (torch.nn.Parameter or None expected)