# DreamBooth for MNIST Digit Generation

This notebook explores **DreamBooth**, a fine-tuning technique for diffusion models, and adapts it to MNIST digit generation.

## What is DreamBooth?

**DreamBooth** is a training technique that teaches a diffusion model to generate images of a specific subject from only a few example images (typically 3-5). Its key ingredients are:

1. **Unique Identifier**: You bind a special token (e.g., `[sks]`) to your subject in text prompts
2. **Full Model Fine-tuning**: Unlike LoRA or textual inversion, DreamBooth updates the entire UNet (and optionally the text encoder)
3. **Prior Preservation**: Uses the model's own generated images of the general class to prevent overfitting and language drift

### How it works:
- **Instance Prompt**: `"a photo of sks dog"` (where `sks` is your unique identifier)
- **Class Prompt**: `"a photo of a dog"` (the general class)
- **Loss Function**: Combines reconstruction loss on your images + prior preservation loss on class images

### For MNIST:
We'll adapt this to teach a model to generate specific styles/variations of digits. For example:
- Train on a few examples of a specific handwriting style
- Generate new digits in that style

**References:**
- [HuggingFace DreamBooth Docs](https://huggingface.co/docs/diffusers/en/training/dreambooth)
- [Original Paper](https://arxiv.org/abs/2208.12242)
- [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)

## 1. Setup and Imports

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import math

# Diffusers imports
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Load and Prepare MNIST Dataset

We'll load MNIST and extract a few examples of a specific digit to use as our "instance" images.

In [None]:
# Download MNIST with torchvision; the files are stored under the Hugging Face cache directory

mnist_train = torchvision.datasets.MNIST(
    root="~/.cache/huggingface/datasets/MNIST",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    root="~/.cache/huggingface/datasets/MNIST",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

print(f"Training samples: {len(mnist_train)}")
print(f"Test samples: {len(mnist_test)}")

### Select Instance Images

Let's pick a specific digit (e.g., digit "3") and select a few instances to represent our "subject".

In [None]:
# Configuration
TARGET_DIGIT = 3  # The digit we want to personalize
NUM_INSTANCE_IMAGES = 5  # How many examples to use (DreamBooth typically uses 3-5)
IMAGE_SIZE = 32  # Resize 28x28 MNIST to 32x32: divisible by 8, so it halves cleanly through the UNet's downsampling stages

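# Find instances of our target digit using the label tensor directly
# (avoids loading and transforming all 60,000 images just to read their labels)
instance_indices = (mnist_train.targets == TARGET_DIGIT).nonzero(as_tuple=True)[0][:NUM_INSTANCE_IMAGES].tolist()
instance_images = [mnist_train[i][0] for i in instance_indices]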

# Visualize our instance images
fig, axes = plt.subplots(1, NUM_INSTANCE_IMAGES, figsize=(15, 3))
fig.suptitle(f'Instance Images: Digit {TARGET_DIGIT} (Our "Subject")', fontsize=14, fontweight='bold')
for idx, (ax, img) in enumerate(zip(axes, instance_images)):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Instance {idx+1}')
    ax.axis('off')
plt.tight_layout()
plt.show()

### Prepare Class Images

For prior preservation, we need images of the general class (all digits).
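
If a random rather than fixed subset of class images is wanted, a reproducible draw that skips the instance images could look like the sketch below. The helper name `sample_class_indices`, the dataset size, and the example instance indices are illustrative assumptions, not values taken from this notebook.

```python
import numpy as np

def sample_class_indices(dataset_size, exclude, num_samples, seed=0):
    """Draw `num_samples` distinct indices from range(dataset_size), skipping `exclude`.

    Hypothetical helper, for illustration only.
    """
    rng = np.random.default_rng(seed)  # fixed seed keeps the draw reproducible
    excluded = np.asarray(list(exclude), dtype=int)
    candidates = np.setdiff1d(np.arange(dataset_size), excluded)
    return rng.choice(candidates, size=num_samples, replace=False).tolist()

# Illustrative values only: a 60,000-image dataset and five made-up instance indices
example_instance_indices = [7, 10, 12, 27, 30]
example_class_indices = sample_class_indices(60_000, example_instance_indices, 100)

print(len(example_class_indices))
print(set(example_class_indices).isdisjoint(example_instance_indices))
```

Sampling without replacement from the filtered candidates guarantees that no class image is repeated and that none of the instance images leak into the class set.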

In [None]:
NUM_CLASS_IMAGES = 100  # Number of class images for prior preservation

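# Take the first NUM_CLASS_IMAGES images across all digits, skipping our instance images
# (MNIST is not sorted by label, so this slice already contains a mix of digits)
instance_index_set = set(instance_indices)
class_indices = [i for i in range(len(mnist_train)) if i not in instance_index_set][:NUM_CLASS_IMAGES]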
class_images = [mnist_train[i][0] for i in class_indices]
class_labels = [mnist_train[i][1] for i in class_indices]

# Visualize some class images
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
fig.suptitle('Class Images: General Digits (For Prior Preservation)', fontsize=14, fontweight='bold')
for idx, (ax, img, label) in enumerate(zip(axes.flat, class_images[:20], class_labels[:20])):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'{label}')
    ax.axis('off')
plt.tight_layout()
plt.show()

## 3. Create DreamBooth Dataset

We'll create a custom dataset that:
1. Returns instance images with an "instance prompt"
2. Returns class images with a "class prompt"
3. Preprocesses images to the right size and format
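
Each dataset item bundles one instance image with one class image, and the shorter list wraps around via modulo indexing. A small sketch of that pairing rule, using 5 instance and 100 class images as in this notebook (the helper name `paired_indices` is hypothetical):

```python
import math

def paired_indices(idx, num_instance, num_class):
    """Return the (instance, class) positions that dataset item `idx` combines."""
    return idx % num_instance, idx % num_class

num_instance, num_class = 5, 100
dataset_len = max(num_instance, num_class)

# Each instance image is reused this many times in one pass over the data
uses_per_instance = math.ceil(dataset_len / num_instance)

for idx in (0, 4, 5, 99):
    print(idx, paired_indices(idx, num_instance, num_class))
```

With these sizes every instance image is seen 20 times per epoch while each class image is seen once, which is why a small instance set can still dominate what the model learns.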

In [None]:
class DreamBoothDataset(Dataset):
    """
    Dataset for DreamBooth training.
    
    Combines instance images (your specific subject) with class images (prior preservation).
    """
    def __init__(
        self,
        instance_images,
        class_images,
        instance_prompt,
        class_prompt,
        size=32,
        repeats=1  # Repeat instance images to balance dataset
    ):
        self.instance_images = instance_images * repeats  # Repeat to balance with class images
        self.class_images = class_images
        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        
        # Image transformations
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            # No horizontal flip: mirroring changes a digit's identity (a flipped 3 is not a 3)
            transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
        ])
        
        # Length is the larger of the two sets; the smaller one wraps around via modulo
        self.num_instance_images = len(self.instance_images)
        self.num_class_images = len(self.class_images)
        
    def __len__(self):
        return max(self.num_instance_images, self.num_class_images)
    
    def __getitem__(self, idx):
        example = {}
        
        # Get instance image and prompt
        instance_idx = idx % self.num_instance_images
        instance_image = self.instance_images[instance_idx]
        instance_image = self.transform(instance_image)
        example["instance_images"] = instance_image
        example["instance_prompt"] = self.instance_prompt
        
        # Get class image and prompt (for prior preservation)
        class_idx = idx % self.num_class_images
        class_image = self.class_images[class_idx]
        class_image = self.transform(class_image)
        example["class_images"] = class_image
        example["class_prompt"] = self.class_prompt
        
        return example

# Create dataset
# Note: the UNet used below is unconditional, so these prompts are only descriptive labels;
# they are never fed to the model
instance_prompt = f"digit_{TARGET_DIGIT}_sks"  # 'sks' is our unique identifier
class_prompt = "digit"

# Calculate repeats to balance instance and class images
repeats = math.ceil(NUM_CLASS_IMAGES / NUM_INSTANCE_IMAGES)

dreambooth_dataset = DreamBoothDataset(
    instance_images=instance_images,
    class_images=class_images,
    instance_prompt=instance_prompt,
    class_prompt=class_prompt,
    size=IMAGE_SIZE,
    repeats=repeats
)

print(f"Dataset size: {len(dreambooth_dataset)}")
print(f"Instance prompt: '{instance_prompt}'")
print(f"Class prompt: '{class_prompt}'")
print(f"Instance images repeated {repeats}x to balance with class images")

In [None]:
# Test the dataset
sample = dreambooth_dataset[0]
print("Sample keys:", sample.keys())
print("Instance image shape:", sample["instance_images"].shape)
print("Class image shape:", sample["class_images"].shape)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(sample["instance_images"].squeeze() * 0.5 + 0.5, cmap='gray')  # Denormalize
axes[0].set_title(f'Instance: {sample["instance_prompt"]}')
axes[0].axis('off')

axes[1].imshow(sample["class_images"].squeeze() * 0.5 + 0.5, cmap='gray')
axes[1].set_title(f'Class: {sample["class_prompt"]}')
axes[1].axis('off')
plt.tight_layout()
plt.show()

## 4. Initialize Diffusion Model

We'll use a small `UNet2DModel` from diffusers, trained from scratch. In a real DreamBooth application, you would start from a pretrained model and fine-tune it.
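
The noise scheduler created alongside the model decides how much noise each training timestep adds. Below is a rough numpy sketch of a cosine schedule in the style of `squaredcos_cap_v2`, together with the closed-form forward step x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε. The function names and the offset constant `s=0.008` follow the usual formulation of this schedule and are assumptions here, not code copied from diffusers.

```python
import numpy as np

def cosine_alphas_cumprod(num_steps=1000, max_beta=0.999, s=0.008):
    """Cumulative signal fraction (alpha-bar) for a cosine noise schedule."""
    def alpha_bar(u):
        return np.cos((u + s) / (1 + s) * np.pi / 2) ** 2

    steps = np.arange(num_steps)
    betas = 1.0 - alpha_bar((steps + 1) / num_steps) / alpha_bar(steps / num_steps)
    betas = np.minimum(betas, max_beta)  # cap beta near the end of the schedule
    return np.cumprod(1.0 - betas)

def add_noise(x0, noise, t, alphas_cumprod):
    """Closed-form forward diffusion: jump from a clean image straight to step t."""
    a = alphas_cumprod[t]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise

rng = np.random.default_rng(0)
alphas_cumprod = cosine_alphas_cumprod()
x0 = rng.uniform(-1.0, 1.0, size=(1, 32, 32))  # stand-in for a normalized image
noise = rng.standard_normal(x0.shape)

for t in (0, 500, 999):
    xt = add_noise(x0, noise, t, alphas_cumprod)
    print(t, float(alphas_cumprod[t]), xt.shape)
```

At `t=0` the result is almost the clean image, while at the last timestep the signal coefficient is close to zero and the result is essentially pure noise.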

In [None]:
# Model configuration
model = UNet2DModel(
    sample_size=IMAGE_SIZE,
    in_channels=1,  # Grayscale
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

# Noise scheduler (DDPM)
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2"  # Good for small images
)

model = model.to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

### Option: Load a Pretrained Model

For better results, you could start with a pretrained model. Uncomment to use:

In [None]:
# # Load a pretrained model (optional - requires fine-tuning)
# from diffusers import DDPMPipeline
# 
# # Load a small pretrained model (e.g., trained on CIFAR-10)
# pretrained_pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
# model = pretrained_pipeline.unet
# noise_scheduler = pretrained_pipeline.scheduler
# 
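# # Adapt for grayscale (MNIST)
# # Note: the CIFAR-10 UNet expects 3 input and 3 output channels, so both the
# # first conv layer (conv_in) and the last conv layer (conv_out) must be adapted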
# model = model.to(device)
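
One way to adapt an RGB-pretrained first convolution to grayscale input is to sum its kernels over the input-channel axis. The sketch below uses a randomly initialized `nn.Conv2d` as a stand-in for a pretrained layer; the helper name `rgb_conv_to_grayscale` is hypothetical, and this is only one possible approach, not something prescribed by diffusers.

```python
import torch
import torch.nn as nn

def rgb_conv_to_grayscale(conv: nn.Conv2d) -> nn.Conv2d:
    """Build a 1-input-channel copy of a 3-input-channel conv layer.

    Summing the kernels over the input-channel axis reproduces what the
    original layer outputs when a grayscale image is repeated across R, G, B.
    """
    gray = nn.Conv2d(
        in_channels=1,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        gray.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            gray.bias.copy_(conv.bias)
    return gray

# Stand-in for a pretrained first layer (random weights, illustration only)
torch.manual_seed(0)
rgb_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
gray_conv = rgb_conv_to_grayscale(rgb_conv)

x_gray = torch.randn(2, 1, 32, 32)
with torch.no_grad():
    out_rgb = rgb_conv(x_gray.repeat(1, 3, 1, 1))  # grayscale copied to 3 channels
    out_gray = gray_conv(x_gray)

print(torch.allclose(out_rgb, out_gray, atol=1e-5))
```

The output layer needs a matching change, for example averaging its three output kernels into one. In `UNet2DModel` these layers should be `conv_in` and `conv_out`, but check the attribute names and the stored channel counts against your installed diffusers version.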

## 5. DreamBooth Training Loop

The key to DreamBooth is the **combined loss**:

```
total_loss = instance_loss + prior_preservation_weight * class_loss
```

Where:
- `instance_loss`: Reconstruction loss on your specific subject images
- `class_loss`: Regularization loss on general class images (prevents overfitting)
- `prior_preservation_weight`: Controls strength of regularization (typically 1.0)
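
Written out for the noise-prediction objective this notebook trains with, the combined loss can be sketched as below. The symbols are notational assumptions: ε_θ is the UNet, x is an instance image, x^pr is a class image, λ is `prior_preservation_weight`, and each term draws its own noise and timestep.

```latex
\mathcal{L}
= \underbrace{\mathbb{E}_{x,\,\epsilon,\,t}\,
    \bigl\lVert \epsilon - \epsilon_\theta(x_t, t) \bigr\rVert_2^2}_{\text{instance loss}}
+ \lambda\,
  \underbrace{\mathbb{E}_{x^{\mathrm{pr}},\,\epsilon',\,t'}\,
    \bigl\lVert \epsilon' - \epsilon_\theta(x^{\mathrm{pr}}_{t'}, t') \bigr\rVert_2^2}_{\text{class loss}},
\qquad
x_t = \sqrt{\bar\alpha_t}\,x + \sqrt{1-\bar\alpha_t}\,\epsilon
```

In text-conditioned DreamBooth, ε_θ also receives the instance or class prompt as conditioning. The unconditional UNet used here has no such input, so the two terms differ only in the data they are computed on.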

In [None]:
# Training configuration
config = {
    "batch_size": 4,
    "num_epochs": 100,
    "learning_rate": 1e-4,
    "lr_warmup_steps": 50,
    "prior_loss_weight": 1.0,  # Weight for prior preservation loss
    "gradient_accumulation_steps": 1,  # Placeholder: the simple loop below steps the optimizer every batch
    "mixed_precision": "no",  # Placeholder: the loop below always trains in full precision
    "save_model_epochs": 20,  # Placeholder: the loop below does not save checkpoints
    "sample_every_n_epochs": 10,  # Generate samples during training
}

# Create dataloader
train_dataloader = DataLoader(
    dreambooth_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=0  # Increase if you have CPU cores available
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8,
)

# Learning rate scheduler with warmup
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config["lr_warmup_steps"],
    num_training_steps=len(train_dataloader) * config["num_epochs"],
)

print("Training configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")
print(f"\nSteps per epoch: {len(train_dataloader)}")
print(f"Total training steps: {len(train_dataloader) * config['num_epochs']}")
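
To get a feel for the learning-rate schedule, the warmup-plus-cosine multiplier can be written in a few lines of plain Python. This is a sketch of the standard formula (linear warmup, then half a cosine cycle down to zero), which to our understanding is what `get_cosine_schedule_with_warmup` does with its default settings. The function name `warmup_cosine_multiplier` is hypothetical.

```python
import math

def warmup_cosine_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)  # linear ramp from 0 to 1
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))  # half-cosine decay to 0

# Values matching this notebook's config: 100 samples, batch size 4, 100 epochs
steps_per_epoch = math.ceil(100 / 4)
total_steps = steps_per_epoch * 100
base_lr = 1e-4

for step in (0, 25, 50, 1275, total_steps):
    lr = base_lr * warmup_cosine_multiplier(step, 50, total_steps)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

With 25 steps per epoch, the 50 warmup steps cover the first two epochs, after which the rate decays smoothly over the remaining 2,450 steps.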

In [None]:
# Training loop
def train_dreambooth(
    model,
    noise_scheduler,
    train_dataloader,
    optimizer,
    lr_scheduler,
    config,
    device
):
    """DreamBooth training loop with prior preservation."""
    
    model.train()
    global_step = 0
    losses = []
    
    progress_bar = tqdm(
        range(config["num_epochs"] * len(train_dataloader)),
        desc="Training"
    )
    
    for epoch in range(config["num_epochs"]):
        epoch_loss = 0.0
        
        for step, batch in enumerate(train_dataloader):
            # Get instance and class images
            instance_images = batch["instance_images"].to(device)
            class_images = batch["class_images"].to(device)
            
            # Combine into single batch for efficiency
            images = torch.cat([instance_images, class_images], dim=0)
            batch_size = images.shape[0]
            
            # Sample random timesteps for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (batch_size,),
                device=device
            ).long()
            
            # Add noise to images (forward diffusion process)
            noise = torch.randn_like(images)
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)
            
            # Predict the noise
            noise_pred = model(noisy_images, timesteps).sample
            
            # Calculate loss
            # Split predictions back into instance and class
            instance_pred = noise_pred[:len(instance_images)]
            class_pred = noise_pred[len(instance_images):]
            
            instance_noise = noise[:len(instance_images)]
            class_noise = noise[len(instance_images):]
            
            # Instance loss (MSE between predicted and actual noise)
            instance_loss = F.mse_loss(instance_pred, instance_noise, reduction="mean")
            
            # Prior preservation loss (regularization)
            class_loss = F.mse_loss(class_pred, class_noise, reduction="mean")
            
            # Combined loss
            loss = instance_loss + config["prior_loss_weight"] * class_loss
            
            # Backpropagation
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # Update weights
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            # Logging
            epoch_loss += loss.item()
            losses.append(loss.item())
            
            progress_bar.update(1)
            progress_bar.set_postfix({
                "epoch": epoch + 1,
                "loss": loss.item(),
                "inst_loss": instance_loss.item(),
                "class_loss": class_loss.item(),
                "lr": lr_scheduler.get_last_lr()[0]
            })
            
            global_step += 1
        
        # End of epoch
        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"\nEpoch {epoch + 1}/{config['num_epochs']} - Avg Loss: {avg_epoch_loss:.4f}")
        
        # Generate samples
        if (epoch + 1) % config["sample_every_n_epochs"] == 0:
            print(f"Generating samples at epoch {epoch + 1}...")
            generate_samples(model, noise_scheduler, device, num_samples=8)
    
    progress_bar.close()
    return losses

def generate_samples(model, noise_scheduler, device, num_samples=8):
    """Generate samples from the model."""
    model.eval()
    
    with torch.no_grad():
        # Start from random noise
        images = torch.randn(
            num_samples, 1, IMAGE_SIZE, IMAGE_SIZE,
            device=device
        )
        
        # Denoise iteratively
        for t in noise_scheduler.timesteps:
            # Predict noise
            noise_pred = model(images, t).sample
            
            # Remove noise
            images = noise_scheduler.step(noise_pred, t, images).prev_sample
        
        # Denormalize and clamp
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu()
    
    # Visualize
    fig, axes = plt.subplots(1, num_samples, figsize=(16, 2))
    for idx, (ax, img) in enumerate(zip(axes, images)):
        ax.imshow(img.squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('Generated Samples', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    model.train()

print("Training functions defined. Ready to train!")

### Run Training

**Note**: Training can take a while depending on your hardware. Start with fewer epochs for experimentation.

In [None]:
# Run training
print("Starting DreamBooth training...\n")
losses = train_dreambooth(
    model=model,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    config=config,
    device=device
)
print("\nTraining complete!")

### Plot Training Loss

In [None]:
# Plot loss curve
plt.figure(figsize=(12, 4))
plt.plot(losses, alpha=0.3, label='Step loss')

# Moving average for smoother visualization
window_size = 10
if len(losses) >= window_size:
    moving_avg = np.convolve(losses, np.ones(window_size)/window_size, mode='valid')
    plt.plot(range(window_size-1, len(losses)), moving_avg, linewidth=2, label=f'{window_size}-step moving avg')

plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('DreamBooth Training Loss')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Generate and Evaluate Results

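Now let's generate images and compare them to our original instance images. Because the model is unconditional and was trained on both the instance and the class images, the samples may mix the target digit with other digits rather than show only the instance style.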
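
Each iteration of the denoising loop calls `noise_scheduler.step`. A rough numpy sketch of what one textbook DDPM reverse step computes is shown below. It assumes a linear beta schedule and the variance choice σ_t² = β_t from the DDPM paper, so it differs from this notebook's cosine schedule, and the real scheduler may use a different variance and clip the predicted clean image. The function name `ddpm_reverse_step` is hypothetical.

```python
import numpy as np

def ddpm_reverse_step(x_t, noise_pred, t, betas, alphas_cumprod, rng):
    """One textbook DDPM ancestral sampling step from x_t to x_{t-1}."""
    alpha_t = 1.0 - betas[t]
    # Posterior mean: remove the predicted noise, then undo this step's scaling
    mean = (x_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * noise_pred) / np.sqrt(alpha_t)
    if t == 0:
        return mean  # the final step adds no fresh noise
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x_t.shape)

# Assumed linear beta schedule, for illustration only
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(1, 32, 32))
eps = rng.standard_normal(x0.shape)

# Noise x0 to step 0, then denoise with a "perfect" noise prediction
x_noisy = np.sqrt(alphas_cumprod[0]) * x0 + np.sqrt(1.0 - alphas_cumprod[0]) * eps
x_back = ddpm_reverse_step(x_noisy, eps, 0, betas, alphas_cumprod, rng)
print(np.allclose(x_back, x0))
```

When the noise prediction is exact, the last step recovers the clean image, which is why the quality of the samples depends entirely on how well the UNet predicts the noise at every timestep.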

In [None]:
# Generate a grid of samples
num_samples = 16

model.eval()
with torch.no_grad():
    # Start from random noise
    generated_images = torch.randn(
        num_samples, 1, IMAGE_SIZE, IMAGE_SIZE,
        device=device
    )
    
    # Denoise with progress bar
    for t in tqdm(noise_scheduler.timesteps, desc="Generating samples"):
        noise_pred = model(generated_images, t).sample
        generated_images = noise_scheduler.step(noise_pred, t, generated_images).prev_sample
    
    # Denormalize
    generated_images = (generated_images / 2 + 0.5).clamp(0, 1)
    generated_images = generated_images.cpu()

# Display results
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
fig.suptitle(f'Generated Digits (Fine-tuned on Digit {TARGET_DIGIT})', fontsize=16, fontweight='bold')
for idx, (ax, img) in enumerate(zip(axes.flat, generated_images)):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f'Sample {idx+1}')
    ax.axis('off')
plt.tight_layout()
plt.show()

### Compare with Original Instance Images

In [None]:
# Side-by-side comparison
fig, axes = plt.subplots(2, NUM_INSTANCE_IMAGES, figsize=(15, 6))
fig.suptitle(f'Comparison: Original Instance Images vs Generated', fontsize=14, fontweight='bold')

# Original instances
for idx in range(NUM_INSTANCE_IMAGES):
    axes[0, idx].imshow(instance_images[idx].squeeze(), cmap='gray')
    axes[0, idx].set_title(f'Original {idx+1}')
    axes[0, idx].axis('off')

# Generated samples
for idx in range(NUM_INSTANCE_IMAGES):
    axes[1, idx].imshow(generated_images[idx].squeeze(), cmap='gray')
    axes[1, idx].set_title(f'Generated {idx+1}')
    axes[1, idx].axis('off')

plt.tight_layout()
plt.show()

## 7. Save and Load Model

In [None]:
# Save model
output_dir = Path("./dreambooth_mnist_output")
output_dir.mkdir(exist_ok=True)

# Save using diffusers format
pipeline = DDPMPipeline(
    unet=model,
    scheduler=noise_scheduler
)
pipeline.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")

# To load later:
# loaded_pipeline = DDPMPipeline.from_pretrained(output_dir)
# generated = loaded_pipeline(batch_size=8, num_inference_steps=1000).images

## 8. Experiments and Analysis

### Experiment 1: Effect of Prior Preservation Weight

In [None]:
# Try different prior preservation weights
# Higher weight = stronger regularization = less overfitting but potentially less similarity to instances
# Lower weight = weaker regularization = more overfitting but potentially more similarity to instances

print("""
Experiment: Try different prior_loss_weight values:
- 0.0: No prior preservation (pure fine-tuning on instances) - may overfit
- 0.5: Weak regularization
- 1.0: Balanced (default)
- 2.0: Strong regularization - may not learn instance well

Modify config['prior_loss_weight'] above and re-run training to compare results!
""")

### Experiment 2: Number of Instance Images

In [None]:
print("""
Experiment: Try different numbers of instance images:
- 3-5 images: Standard DreamBooth setup
- 1-2 images: Extreme few-shot learning
- 10+ images: More training data (may not need DreamBooth)

Modify NUM_INSTANCE_IMAGES above and re-run from dataset creation!
""")

### Experiment 3: Different Target Digits

In [None]:
print("""
Experiment: Try different target digits:
- Simple digits (0, 1): May be easier to learn
- Complex digits (8, 9): More variation in handwriting

Modify TARGET_DIGIT above and re-run!
""")

## 9. Key Takeaways

### What We Learned:

1. **DreamBooth Core Idea**: Fine-tune a diffusion model on a few images of a specific subject using a unique identifier token

2. **Prior Preservation**: Essential to prevent overfitting and "language drift" - maintains the model's general knowledge

3. **Loss Function**:
   ```
   L = L_instance + λ * L_class
   ```
   Where λ controls the regularization strength

4. **Few-Shot Learning**: DreamBooth is designed to work with just 3-5 training examples

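5. **Adaptation for MNIST**: DreamBooth is normally used with text-conditioned models (like Stable Diffusion). The UNet in this notebook is unconditional, so there is no identifier token to bind: the prompts are only labels, and what carries over is the combined instance + prior-preservation loss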

### Limitations:

- **Memory intensive**: Full model fine-tuning requires significant GPU memory
- **Overfitting risk**: Too few class images or too many epochs can lead to overfitting
- **Hyperparameter sensitivity**: prior_loss_weight, learning rate, and number of training steps all matter

### Extensions:

1. **Use with Stable Diffusion**: Apply to real images with text conditioning
2. **LoRA + DreamBooth**: Combine with Low-Rank Adaptation for efficiency
3. **Custom regularization images**: Generate better class images for your specific use case
4. **Multi-subject training**: Train on multiple subjects with different identifiers

## References

- [DreamBooth Paper](https://arxiv.org/abs/2208.12242)
- [HuggingFace DreamBooth Documentation](https://huggingface.co/docs/diffusers/en/training/dreambooth)
- [HuggingFace train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)
- [Diffusers Library](https://github.com/huggingface/diffusers)