In [None]:
%pip install -qq diffusers datasets accelerate wandb open-clip-torch
%pip install -q ip-adapter
!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login widget.
# NOTE(review): presumably required for the model downloads in later cells — confirm
# whether any of the checkpoints used are actually gated.
notebook_login()

In [None]:
# Print the GPU(s) attached to this runtime before any model is loaded.
!nvidia-smi

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
import torch
from google.colab import drive
import zipfile
import shutil
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid
import os
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('torch version:',torch.__version__)
print('device:', device)

# dtype passed to the model loaders in the cells below.
weight_dtype = torch.float16

In [None]:
# Hub identifiers / paths for the models involved in this notebook.
# NOTE(review): this names a Stable Diffusion 2 inpainting checkpoint, while the pipeline
# cell below loads "diffusers/stable-diffusion-xl-1.0-inpainting-0.1" — confirm which
# base model is intended.
base_model_path = "stabilityai/stable-diffusion-2-inpainting"
# Hub repository that hosts the IP-Adapter weights.
ip_adapter_path = "h94/IP-Adapter"
# NOTE(review): this looks like a local Hugging Face cache folder name rather than a hub
# repo id ("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") — verify before using it to load.
image_encoder_path = "models--laion--CLIP-ViT-H-14-laion2B-s32B-b79K"

In [None]:
# Mount Google Drive so the dataset archive is reachable from the Colab VM.
drive.mount('/content/drive')

source_path = '/content/drive/MyDrive/viton_plus.zip'
destination_path = '/content/clothes.zip'

# Copy the archive onto the VM's local disk first; reading a large zip through the
# Drive mount is slow.
shutil.copy2(source_path, destination_path)

# Extract from the local copy that was just made. Extracting from `source_path`
# instead would read the archive through the Drive mount and make the copy pointless.
with zipfile.ZipFile(destination_path, 'r') as zip_ref:
    zip_ref.extractall('/content/clothes_data')

print("File copied from Drive and extracted")

In [None]:
# SDXL inpainting pipeline: fp16 weight variant, loaded in half precision.
print("Loading SDXL Inpainting pipeline...")
pipe = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", variant="fp16", torch_dtype=torch.float16
)
# Move every component of the pipeline to the selected device.
pipe = pipe.to(device)

In [None]:
# Attach a CLIP image encoder to the inpainting pipeline for image conditioning.
#
# The SDXL base checkpoint ("stabilityai/stable-diffusion-xl-base-1.0") does not ship an
# `image_encoder` component — it is an optional pipeline module and is None after
# `from_pretrained` — so loading the whole base pipeline only to read
# `base_pipe.image_encoder` yields None (and costs a second multi-GB model on the GPU).
# The encoder is therefore loaded directly from the IP-Adapter repository, whose
# `sdxl_models/image_encoder` folder holds the CLIP vision tower used for SDXL.
print("Loading CLIP image encoder for image conditioning...")
pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    ip_adapter_path,
    subfolder="sdxl_models/image_encoder",
    torch_dtype=weight_dtype,
).to(device)

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import json
from pathlib import Path

class ClothingDataset(Dataset):
    """Paired person / garment images read from an extracted try-on dataset.

    Files are looked up under ``<data_root>/<split>/`` where ``<split>`` is
    ``"train"`` or ``"test"`` depending on ``is_train``:

    * ``image/<person_name>``       person photo
    * ``cloth/<cloth_name>``        garment photo
    * ``image-mask/<person_name>``  person mask (optional)
    * ``cloth-mask/<cloth_name>``   garment mask (optional)

    A missing mask file is replaced by an all-zero mask so that every sample
    has the same keys and tensor shapes and can be batched by a DataLoader.
    """

    def __init__(self, data_root, pairs_file, image_size=512, is_train=True):
        """
        Dataset for loading clothing data for Stable Diffusion fine-tuning

        Args:
            data_root: Path to clothes_data folder
            pairs_file: train_pairs.txt or test_pairs.txt
            image_size: Target image size for training
            is_train: Whether this is training data
        """
        self.data_root = Path(data_root)
        self.image_size = image_size
        self.is_train = is_train

        # List of (person_image_name, cloth_image_name) tuples.
        self.pairs = self._load_pairs(pairs_file)

        # RGB images: resize to a square, convert to tensor, map [0, 1] -> [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
        ])

        # Masks: resize and convert to tensor only, values stay in [0, 1].
        self.mask_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def _load_pairs(self, pairs_file):
        """Parse ``pairs_file`` into a list of ``(person_name, cloth_name)`` tuples.

        Each non-empty line is split on whitespace; the first two fields are
        kept. Blank lines and lines with fewer than two fields are skipped.
        """
        pairs = []
        with open(pairs_file, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) >= 2:
                    pairs.append((parts[0], parts[1]))
        return pairs

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as a loaded RGB image, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _load_mask(self, path):
        """Return the mask stored at ``path`` as a tensor, or zeros if it is absent.

        The fallback has shape ``(1, image_size, image_size)``.
        """
        if not path.exists():
            return torch.zeros(1, self.image_size, self.image_size)
        with Image.open(path) as img:
            return self.mask_transform(img.convert('L'))

    def __len__(self):
        """Number of (person, cloth) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair.

        Returns:
            dict with keys ``person_image`` / ``cloth_image`` (normalised image
            tensors), ``person_name`` / ``cloth_name`` (file names from the
            pairs file) and ``person_mask`` / ``cloth_mask`` (mask tensors).
        """
        person_img_name, cloth_img_name = self.pairs[idx]
        split_dir = self.data_root / ("train" if self.is_train else "test")

        # Load images
        person_img = self._load_rgb(split_dir / "image" / person_img_name)
        cloth_img = self._load_rgb(split_dir / "cloth" / cloth_img_name)

        # Load masks (zero-filled when the mask file does not exist)
        person_mask = self._load_mask(split_dir / "image-mask" / person_img_name)
        cloth_mask = self._load_mask(split_dir / "cloth-mask" / cloth_img_name)

        return {
            'person_image': self.transform(person_img),
            'cloth_image': self.transform(cloth_img),
            'person_name': person_img_name,
            'cloth_name': cloth_img_name,
            'person_mask': person_mask,
            'cloth_mask': cloth_mask
        }


In [None]:
class ConditionedClothingDataset(ClothingDataset):
    """
    ClothingDataset variant that re-keys each sample for conditioned fine-tuning.

    Every item is a dict with the keys 'target', 'condition', 'caption',
    'person_mask' and 'cloth_mask'.
    """

    def __init__(self, data_root, pairs_file, image_size=512, is_train=True):
        super().__init__(data_root, pairs_file, image_size, is_train)

    def __getitem__(self, idx):
        base = super().__getitem__(idx)

        # The caption is built from the cloth file name up to its first '.'.
        cloth_id = base['cloth_name'].split('.')[0]

        item = {}
        item['target'] = base['person_image']      # person photo: what should be generated
        item['condition'] = base['cloth_image']    # garment photo: what generation is conditioned on
        item['caption'] = f"person wearing {cloth_id}"
        item['person_mask'] = base.get('person_mask')
        item['cloth_mask'] = base.get('cloth_mask')
        return item

In [None]:
def create_dataloaders(data_root, train_pairs, test_pairs, batch_size=4, num_workers=4, image_size=512):
    """Build the train and test DataLoaders over ConditionedClothingDataset.

    Args:
        data_root: Root folder of the extracted dataset.
        train_pairs: Pairs file for the training split.
        test_pairs: Pairs file for the test split.
        batch_size: Batch size used by both loaders.
        num_workers: Worker process count used by both loaders.
        image_size: Square size images are resized to.

    Returns:
        (train_loader, test_loader). The train loader shuffles and drops the
        last incomplete batch; the test loader does neither.
    """

    def _dataset(pairs_file, is_train):
        return ConditionedClothingDataset(
            data_root=data_root, pairs_file=pairs_file, image_size=image_size, is_train=is_train
        )

    # Both datasets are created before either loader, as in a plain sequential setup.
    train_set = _dataset(train_pairs, True)
    test_set = _dataset(test_pairs, False)

    shared = {"batch_size": batch_size, "num_workers": num_workers}
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared)
    test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **shared)

    return train_loader, test_loader

In [None]:
# Create the train/test loaders over the dataset extracted to /content/clothes_data.
# num_workers=0 keeps data loading in the main process.
train_loader, test_loader = create_dataloaders(
    data_root="/content/clothes_data",  # path to extracted dataset
    train_pairs="/content/clothes_data/train_pairs.txt",
    test_pairs="/content/clothes_data/test_pairs.txt",
    batch_size=2,
    num_workers=0,
    image_size=512
)

# Sanity check: pull a single batch and print its keys, tensor shapes and one caption.
# The train loader shuffles, so the printed caption differs between runs.
batch = next(iter(train_loader))
print("Keys:", batch.keys())
print("Target shape:", batch["target"].shape)
print("Condition shape:", batch["condition"].shape)
print("Caption example:", batch["caption"][0])


In [None]:
# Number of batches the training loader yields per epoch.
len(train_loader)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Training hyper-parameters. num_epochs and grad_accumulation_steps are only defined
# here; they are not used in this cell.
num_epochs = 2
lr = 1e-5
grad_accumulation_steps = 2
# Re-declared with the same value as in the device-setup cell.
weight_dtype = torch.float16

# Extract components for direct use in training
vae = pipe.vae
unet = pipe.unet
image_encoder = pipe.image_encoder
scheduler = pipe.scheduler

# --- 2. Projection Layer for Image Conditioning ---
# Linear map from 1280 to 2048 features.
# NOTE(review): presumably 1280 is the image encoder's projection size and 2048 the
# UNet cross-attention width — confirm against the loaded models' configs.
print("Setting up projection layer...")
image_projection = nn.Linear(1280, 2048).to(device, dtype=weight_dtype)

# --- 3. Optimizer Setup ---
# Freeze everything first, then re-enable gradients only on the selected UNet
# parameters below. The order matters: the list is built by name, and requires_grad
# is switched back on afterwards.
print("Configuring optimizer...")
vae.requires_grad_(False)

image_encoder.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.text_encoder_2.requires_grad_(False)
unet.requires_grad_(False)

# Select UNet parameters whose qualified name contains "attn", "to_k" or "to_v".
unet_trainable_params = [
    param for name, param in unet.named_parameters()
    if "attn" in name or "to_k" in name or "to_v" in name
]
for param in unet_trainable_params:
    param.requires_grad = True

# NOTE(review): the projection layer is created in float16 and, if the UNet was loaded
# in float16, its trainable parameters are too. AdamW then updates half-precision
# weights directly — confirm the training loop accounts for this.
optimizer = torch.optim.AdamW(
    list(image_projection.parameters()) + unet_trainable_params, lr=lr
)

# --- 4. Helper Function for Conditioning ---
# Per-channel RGB mean / std that CLIP image encoders expect their inputs to be
# normalised with.
_CLIP_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)


def get_conditioning_embeds(pipe, captions, cloth_image):
    """Encode captions and garment images into conditioning embeddings.

    Args:
        pipe: Pipeline exposing ``encode_prompt`` and a non-None ``image_encoder``.
        captions: List of prompt strings, one per sample.
        cloth_image: Float tensor of shape (batch, 3, H, W). When H/W differ from
            the encoder's ``config.image_size`` the tensor is assumed to come
            straight from the dataset transform (values in [-1, 1]) and is resized
            and re-normalised with CLIP statistics. A tensor already at the
            encoder's resolution is passed through unchanged.

    Returns:
        Tuple ``(prompt_embeds, pooled_prompt_embeds, image_embeds)``.

    Raises:
        ValueError: If ``pipe.image_encoder`` is None.
    """
    encoder = pipe.image_encoder
    if encoder is None:
        raise ValueError(
            "pipe.image_encoder is None; attach a CLIP vision model before calling get_conditioning_embeds"
        )

    # Use the encoder's own device / dtype so the inputs always match its weights.
    reference_param = next(encoder.parameters())
    enc_device, enc_dtype = reference_param.device, reference_param.dtype

    encoded = pipe.encode_prompt(
        prompt=captions, device=enc_device, num_images_per_prompt=1, do_classifier_free_guidance=False
    )
    # SDXL pipelines return (prompt, negative_prompt, pooled, negative_pooled);
    # unpacking that into two names raises ValueError. A 2-tuple is still accepted.
    if len(encoded) == 4:
        prompt_embeds, _, pooled_prompt_embeds, _ = encoded
    else:
        prompt_embeds, pooled_prompt_embeds = encoded

    with torch.no_grad():
        encoder.eval()
        # Preprocess in float32; bicubic resampling in half precision is not
        # available on every backend.
        pixels = cloth_image.to(enc_device, dtype=torch.float32)

        side = getattr(getattr(encoder, "config", None), "image_size", None)
        if side is not None and tuple(pixels.shape[-2:]) != (side, side):
            # [-1, 1] -> [0, 1], resize to the encoder resolution, then apply CLIP
            # normalisation. clamp removes bicubic overshoot.
            pixels = (pixels + 1.0) / 2.0
            pixels = F.interpolate(
                pixels, size=(side, side), mode="bicubic", align_corners=False, antialias=True
            ).clamp(0.0, 1.0)
            mean = pixels.new_tensor(_CLIP_IMAGE_MEAN).view(1, 3, 1, 1)
            std = pixels.new_tensor(_CLIP_IMAGE_STD).view(1, 3, 1, 1)
            pixels = (pixels - mean) / std

        image_embeds = encoder(pixels.to(dtype=enc_dtype)).image_embeds
    return prompt_embeds, pooled_prompt_embeds, image_embeds

print("\n--- All variables are declared and ready for the training loop. ---")