<a href="https://colab.research.google.com/github/nikolagojakovic/text-to-image-diffusion_model/blob/main/conditional_diffusion_text_to_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Conditional Diffusion Model - Educational Implementation**

This notebook contains my implementation of a conditional diffusion model, created for educational purposes to understand the core concepts behind modern text-to-image generation systems.










Dataset: https://www.kaggle.com/datasets/jessicali9530/celeba-dataset/data

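Papers:
 - Denoising Diffusion Probabilistic Models (DDPM): https://arxiv.org/pdf/2006.11239
 - Learning Transferable Visual Models From Natural Language Supervision (CLIP): https://arxiv.org/pdf/2103.00020
 - High-Resolution Image Synthesis with Latent Diffusion Models: https://arxiv.org/pdf/2112.10752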


# Imports

In [3]:
# Core PyTorch imports
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

# Diffusers library for diffusion models
from diffusers import UNet2DConditionModel, DDPMScheduler
from diffusers import AutoencoderKL

# Transformers library for CLIP text encoder
from transformers import CLIPTextModel, CLIPTokenizer

# Additional utilities
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torchvision.datasets import CelebA
import random
from tqdm import tqdm


# Dataset


In [None]:
class CelebAWithText(Dataset):
    """CelebA wrapper that pairs every image with a short attribute-based caption.

    Each item is ``(image, text)`` where ``text`` is built from up to two
    randomly chosen attribute phrases, e.g.
    ``"a photo of person smiling, female person"``.

    Args:
        root_dir: Directory passed to ``torchvision.datasets.CelebA`` as ``root``.
        split: CelebA split name passed through to torchvision.
        transform: Optional callable applied to the image before it is returned.
    """

    # Zero-based CelebA attribute column -> phrase used when the attribute is
    # present (value 1). Columns follow the order of ``list_attr_celeba.txt``;
    # the attribute name is given next to each entry.
    ATTR_DESCRIPTIONS = {
        5: "person with bangs",          # Bangs
        8: "person with black hair",     # Black_Hair
        9: "person with blond hair",     # Blond_Hair
        11: "person with brown hair",    # Brown_Hair
        18: "person with heavy makeup",  # Heavy_Makeup
        20: "male person",               # Male
        31: "person smiling",            # Smiling
        34: "person wearing earrings",   # Wearing_Earrings
        36: "person wearing lipstick",   # Wearing_Lipstick
    }

    # Column -> phrase used when the attribute is absent (value 0). CelebA has
    # no "Female" column, so it is derived from Male == 0.
    ABSENT_ATTR_DESCRIPTIONS = {
        20: "female person",  # not Male
    }

    def __init__(self, root_dir, split='train', transform=None):
        self.celeba = CelebA(root=root_dir, split=split, download=True)
        self.transform = transform

        # Per-instance copies so callers can customise the phrases safely.
        self.attr_descriptions = dict(self.ATTR_DESCRIPTIONS)
        self.absent_attr_descriptions = dict(self.ABSENT_ATTR_DESCRIPTIONS)

    def __len__(self):
        return len(self.celeba)

    def _caption_fragments(self, attributes):
        """Return every phrase that applies to one attribute vector."""
        fragments = []
        for index, value in enumerate(attributes):
            flag = int(value)  # works for plain ints and 0-dim tensors alike
            if flag == 1 and index in self.attr_descriptions:
                fragments.append(self.attr_descriptions[index])
            elif flag == 0 and index in self.absent_attr_descriptions:
                fragments.append(self.absent_attr_descriptions[index])
        return fragments

    def __getitem__(self, idx):
        image, attributes = self.celeba[idx]

        # Build the caption from at most two randomly selected phrases.
        fragments = self._caption_fragments(attributes)
        if fragments:
            chosen = random.sample(fragments, min(2, len(fragments)))
            text = "a photo of " + ", ".join(chosen)
        else:
            text = "a photo of a person"

        if self.transform:
            image = self.transform(image)

        return image, text

# Cache of {"loader": DataLoader, "iterator": iterator} entries keyed by batch
# size, so repeated load_batch_data() calls reuse one dataset and one worker pool.
_BATCH_STREAMS = {}


def load_batch_data(batch_size=128):
    """Return one ``(images, text)`` batch from the CelebA training split.

    The dataset and DataLoader are built on the first call for a given
    ``batch_size`` and reused afterwards. Rebuilding them on every call (as
    a per-step caller would otherwise trigger) re-creates the dataset and
    restarts the worker processes each time. Successive calls walk through
    the shuffled loader; when it is exhausted a new shuffled pass starts.

    Args:
        batch_size: Number of samples per returned batch.

    Returns:
        Tuple ``(images, text)`` as produced by the DataLoader's default
        collation of ``CelebAWithText`` items.
    """
    stream = _BATCH_STREAMS.get(batch_size)
    if stream is None:
        # Setup data transformation
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        # Create dataset and dataloader once per batch size
        dataset = CelebAWithText(root_dir="./data", split='train', transform=transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        stream = {"loader": dataloader, "iterator": iter(dataloader)}
        _BATCH_STREAMS[batch_size] = stream

    try:
        images, text = next(stream["iterator"])
    except StopIteration:
        # Finished a full pass over the data: start a freshly shuffled one.
        stream["iterator"] = iter(stream["loader"])
        images, text = next(stream["iterator"])
    return images, text

def augment(images):
    """Augmentation hook for training images; currently a pass-through.

    This is the intended place for augmentations such as flipping or
    brightness/contrast changes. None are implemented yet, so the input
    object is returned as-is.
    """
    augmented = images
    return augmented


# Initialization

Key Takeaways:
- Conditional UNet
- Conditional Latent Diffusion Model

In [None]:
# Initialize models for your training loop
def initialize_models():
    """Build every component needed for training and sampling.

    Loads the pretrained VAE and CLIP text encoder/tokenizer, creates a fresh
    conditional UNet and a DDPM noise scheduler, moves the three networks to
    the selected device, and disables gradients for the VAE and CLIP model.

    Returns:
        Tuple ``(vae, clip_model, tokenizer, unet_model, noise_scheduler, device)``.
    """
    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_checkpoint = "openai/clip-vit-base-patch32"

    # Pretrained image autoencoder.
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    # Pretrained text encoder and its tokenizer.
    clip_model = CLIPTextModel.from_pretrained(clip_checkpoint)
    tokenizer = CLIPTokenizer.from_pretrained(clip_checkpoint)

    # The UNet is the only network created from scratch here.
    # NOTE(review): cross_attention_dim=512 presumably matches the CLIP text
    # encoder's hidden size — confirm against the checkpoint config.
    unet_settings = {
        "sample_size": 32,
        "in_channels": 4,  # VAE latent channels
        "out_channels": 4,
        "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
        "up_block_types": ("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
        "block_out_channels": (64, 128, 256),
        "cross_attention_dim": 512,
    }
    unet_model = UNet2DConditionModel(**unet_settings)

    # Noise schedule used both for adding noise and for sampling.
    scheduler_settings = {
        "num_train_timesteps": 1000,
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
    }
    noise_scheduler = DDPMScheduler(**scheduler_settings)

    # Place all networks on the target device.
    vae, clip_model, unet_model = (
        module.to(device) for module in (vae, clip_model, unet_model)
    )

    # Only the UNet keeps gradients; VAE and CLIP stay frozen.
    for frozen_module in (vae, clip_model):
        frozen_module.requires_grad_(False)

    return vae, clip_model, tokenizer, unet_model, noise_scheduler, device

# Enhanced text encoding function
def encode_text_with_clip(text, tokenizer, clip_model, device):
    """Tokenize ``text`` and encode it with the CLIP text model, gradient-free.

    Args:
        text: Input accepted by ``tokenizer`` (callers in this file pass a
            batch of caption strings).
        tokenizer: Callable tokenizer; invoked with padding, truncation and a
            maximum length of 77, returning PyTorch tensors.
        clip_model: Text model called on the token ids.
        device: Device the tokenizer output is moved to before encoding.

    Returns:
        Tuple ``(text_embedding, attention_mask)``: the first element of the
        model output and the tokenizer's attention mask.
    """
    tokenizer_options = {
        "padding": True,
        "truncation": True,
        "max_length": 77,
        "return_tensors": "pt",
    }
    encoded_batch = tokenizer(text, **tokenizer_options).to(device)

    with torch.no_grad():
        model_output = clip_model(encoded_batch.input_ids)

    return model_output[0], encoded_batch.attention_mask

# Training loop - enhanced with proper setup
def train_diffusion_model(num_epochs=5, steps_per_epoch=100, batch_size=4,
                          learning_rate=1e-4, log_every=50):
    """Train the conditional UNet to predict the noise added to VAE latents.

    Each step loads a batch, encodes the images with the frozen VAE and the
    captions with the frozen CLIP text model, adds noise at random timesteps,
    and minimises the MSE between the UNet prediction and that noise. The
    UNet weights are saved to ``unet_celeba_epoch_<n>.pt`` after every epoch.

    Args:
        num_epochs: Number of epochs to run.
        steps_per_epoch: Optimisation steps per epoch (100 keeps the demo short).
        batch_size: Samples per step; small by default for memory efficiency.
        learning_rate: AdamW learning rate.
        log_every: Print the loss every this many steps.
    """
    # Initialize everything
    vae, clip_model, tokenizer, unet_model, noise_scheduler, device = initialize_models()
    optimizer = AdamW(unet_model.parameters(), lr=learning_rate)
    max_timesteps = noise_scheduler.config.num_train_timesteps

    # Make the training mode explicit rather than relying on the default.
    unet_model.train()

    print(f"Training on device: {device}")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        for step in tqdm(range(steps_per_epoch)):

            # Load data
            images, text = load_batch_data(batch_size=batch_size)
            images = images.to(device)

            # Run augmentations like flipping, brightness/contrast
            augmented_images = augment(images)

            # Encode images into scaled VAE latents (VAE is frozen)
            with torch.no_grad():
                encoded = vae.encode(augmented_images).latent_dist.sample()
                encoded = encoded * vae.config.scaling_factor

            # Encode the captions (conditioning signal)
            text_embedding, attn_mask = encode_text_with_clip(text, tokenizer, clip_model, device)

            # Draw one random timestep and one noise tensor per sample
            timesteps = torch.randint(0, max_timesteps, (encoded.shape[0],), device=device)
            noise = torch.randn_like(encoded)

            # Add noise to the latents according to the schedule
            noisy_images = noise_scheduler.add_noise(encoded, noise, timesteps)

            # UNet forward pass: predict the noise that was added
            prediction = unet_model(noisy_images, timesteps,
                                    encoder_hidden_states=text_embedding,
                                    encoder_attention_mask=attn_mask,
                                    return_dict=False)[0]

            # Noise-prediction loss
            loss = F.mse_loss(prediction, noise)

            # Update weights
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % log_every == 0:
                print(f"Step {step}, Loss: {loss.item():.4f}")

        # Save model after each epoch
        torch.save(unet_model.state_dict(), f'unet_celeba_epoch_{epoch+1}.pt')
        print(f"Epoch {epoch+1} completed, model saved!")

# Generating


In [None]:
def generate_sample(prompt="a photo of a person with blond hair, smiling",
                    num_inference_steps=50, image_size=64,
                    checkpoint_path="unet_celeba_epoch_5.pt"):
    """Generate one image for ``prompt`` with the conditional diffusion model.

    Args:
        prompt: Text description used as conditioning.
        num_inference_steps: Number of denoising steps for the scheduler.
        image_size: Side length in pixels of the generated image. Defaults to
            64 to match the 64x64 images used for training; the latent size
            is derived from it and the VAE downsampling factor.
        checkpoint_path: UNet ``state_dict`` file to load. If the file does
            not exist a warning is printed and the randomly initialised UNet
            is used, which produces noise-like output.

    Returns:
        A ``PIL.Image.Image`` with the decoded sample.

    Raises:
        ValueError: If ``image_size`` is smaller than the VAE downsampling factor.
    """
    from pathlib import Path

    vae, clip_model, tokenizer, unet_model, noise_scheduler, device = initialize_models()

    # Load the trained weights; without them the UNet is randomly initialised.
    if checkpoint_path and Path(checkpoint_path).is_file():
        state_dict = torch.load(checkpoint_path, map_location=device)
        unet_model.load_state_dict(state_dict)
    else:
        print(f"Warning: checkpoint {checkpoint_path!r} not found; "
              "sampling from an untrained UNet.")

    unet_model.eval()

    # Encode prompt
    text_embedding, attn_mask = encode_text_with_clip([prompt], tokenizer, clip_model, device)

    # Sample latents at the resolution the UNet was trained on: the VAE
    # halves the resolution once per block after the first.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    latent_size = image_size // vae_scale_factor
    if latent_size < 1:
        raise ValueError(
            f"image_size must be at least {vae_scale_factor}, got {image_size}"
        )

    # Keep the timesteps on the model's device so the UNet receives them there.
    noise_scheduler.set_timesteps(num_inference_steps, device=device)
    latents = torch.randn(
        (1, unet_model.config.in_channels, latent_size, latent_size),
        device=device,
    )

    with torch.no_grad():
        # Iterative denoising
        for t in tqdm(noise_scheduler.timesteps):
            noise_pred = unet_model(
                latents, t,
                encoder_hidden_states=text_embedding,
                encoder_attention_mask=attn_mask,
                return_dict=False
            )[0]
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Decode to image (undo the latent scaling applied during training)
        latents = latents / vae.config.scaling_factor
        image = vae.decode(latents).sample

    # Convert from [-1, 1] to uint8 HWC and wrap in a PIL image
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)

    return Image.fromarray(image[0])

if __name__ == "__main__":
    # Stage 1: train the UNet (checkpoints are written by the training function).
    print("Starting training with your code structure...")
    train_diffusion_model()

    # Stage 2: sample one image with the default prompt and write it to disk.
    print("Generating sample image...")
    output_file = "sample_output.png"
    sample_image = generate_sample()
    sample_image.save(output_file)
    print("Sample saved as 'sample_output.png'")