# Task 8: Fine-tuning Text-to-Image Model

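This notebook sets up the building blocks for fine-tuning a Stable Diffusion model for text-to-image generation on custom datasets: it loads the pre-trained pipeline, defines a dataset class that pairs images with captions, and generates sample images from text prompts. The training loop itself (optimizer and loss computation) is not implemented here.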


In [None]:
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers.optimization import get_scheduler
import matplotlib.pyplot as plt
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import json


## Load Pre-trained Model


In [None]:
# Choose the checkpoint and the compute device used by the rest of the notebook.
model_id = "CompVis/stable-diffusion-v1-4"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Loading model: {model_id}")
print(f"Device: {device}")

# Half precision is requested only when running on the GPU; the CPU path
# keeps full float32 weights.
weights_dtype = torch.float32 if device != "cuda" else torch.float16

# Gated checkpoints may require authentication (e.g. an access token passed
# to from_pretrained).
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=weights_dtype,
).to(device)

print("Model loaded successfully!")


## Dataset Preparation


In [None]:
class TextImageDataset(Dataset):
    """Dataset that pairs image files in a directory with captions from a text file.

    The captions file is read as UTF-8, one caption per line (surrounding
    whitespace stripped). Image file names are sorted, so the i-th image in
    sorted-name order is paired with the i-th caption line. Sorting matters
    because ``os.listdir`` returns entries in arbitrary order, which would
    otherwise make the image/caption pairing differ between runs or machines.

    Args:
        data_dir: Directory containing the image files.
        captions_file: Path to a text file with one caption per line.

    Attributes:
        data_dir: The directory passed to the constructor.
        captions: List of caption strings, in file order.
        images: Sorted list of image file names found in ``data_dir``.
    """

    # Extensions (compared case-insensitively) that count as image files.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, data_dir, captions_file):
        self.data_dir = data_dir
        # Explicit encoding: the platform default may not be UTF-8 and could
        # fail on, or mis-decode, non-ASCII captions.
        with open(captions_file, 'r', encoding='utf-8') as f:
            self.captions = [line.strip() for line in f]
        # Sorted for a deterministic pairing; lower() so that files such as
        # "photo.JPG" are not silently skipped.
        self.images = sorted(
            name for name in os.listdir(data_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of (image, caption) pairs available."""
        # Only indices that have both an image and a caption are counted.
        return min(len(self.images), len(self.captions))

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for position ``idx``.

        The image is opened from ``data_dir`` and converted to RGB. If
        ``idx`` is past the last caption, the caption is an empty string.

        Raises:
            IndexError: If ``idx`` is outside the range of ``self.images``.
        """
        img_path = os.path.join(self.data_dir, self.images[idx])
        # convert() returns a new, fully loaded image, so the file handle
        # can be released as soon as the with-block ends.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        caption = self.captions[idx] if idx < len(self.captions) else ""
        return image, caption

# Example usage (update paths as needed). Uncomment the lines below once
# data_dir points at a folder of images and captions_file at a text file
# with one caption per line.
# data_dir = "path/to/your/images"
# captions_file = "path/to/captions.txt"
# dataset = TextImageDataset(data_dir, captions_file)
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
#
# NOTE(review): TextImageDataset.__getitem__ returns a PIL image, and the
# default DataLoader collate function presumably cannot batch PIL images.
# Confirm, and either pass a custom collate_fn or convert images to tensors
# before iterating over the dataloader.

print("Dataset class defined. Update paths to use with your data.")


## Fine-tuning Process


In [None]:
# Sample generation example
# This is a simplified example. No training step runs in this notebook, so the
# images below come from the pre-trained pipeline loaded earlier; they serve as
# a baseline to compare against once a fine-tuning loop has been added.

# Sampling settings, named so they are easy to find and tune.
SEED = 42                  # base seed; prompt i uses SEED + i for reproducible images
GUIDANCE_SCALE = 7.5       # passed to the pipeline as guidance_scale
NUM_INFERENCE_STEPS = 50   # passed to the pipeline as num_inference_steps
MAX_TITLE_CHARS = 40       # prompts longer than this are truncated in figure titles

text_prompts = [
    "A beautiful sunset over mountains",
    "A red car on a highway",
    "A cat sitting on a windowsill"
]

# The message says "pre-trained" on purpose: claiming a fine-tuned model here
# would be wrong, because nothing above has updated the weights.
print("Generating images with the pre-trained model (no fine-tuning has been applied yet):")
for i, prompt in enumerate(text_prompts):
    print(f"\nPrompt {i+1}: {prompt}")

    # A per-prompt seeded generator makes Restart & Run All reproduce the
    # same images instead of sampling fresh noise on every run.
    generator = torch.Generator(device=device).manual_seed(SEED + i)

    # Generate image. Autocast is enabled only on CUDA; on CPU the pipeline
    # was loaded in float32 and CPU autocast would switch ops to reduced
    # precision, which is not intended here.
    with torch.autocast(device, enabled=(device == "cuda")):
        image = pipe(
            prompt,
            guidance_scale=GUIDANCE_SCALE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
        ).images[0]

    # Append "..." only when the prompt was actually truncated.
    if len(prompt) > MAX_TITLE_CHARS:
        title_prompt = f"{prompt[:MAX_TITLE_CHARS]}..."
    else:
        title_prompt = prompt

    # Display
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.set_title(f"Generated: {title_prompt}")
    ax.axis('off')
    fig.tight_layout()
    plt.show()
    # Release the figure so figures do not accumulate across loop iterations.
    plt.close(fig)

print("\nNote: Full fine-tuning requires training loop with optimizer and loss computation.")
