# Tutorial: Conditional Cat Image Generation

Generate realistic cat images of specific breeds using our pre-trained TinyDiT diffusion model.

## What You'll Learn

- Load the generator model from the Hugging Face Hub
- Generate cat images by breed
- Adjust classifier-free guidance (CFG) to trade variety against breed fidelity
- Understand speed vs. quality trade-offs

## Prerequisites

**Important:** This notebook works best with GPU acceleration. On Colab, use Runtime → Change runtime type → GPU.

In [None]:
!pip install torch huggingface_hub pillow numpy matplotlib -q

## Step 1: Check for GPU

In [None]:
import torch

if torch.cuda.is_available():
    device = "cuda"
    print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    device = "cpu"
    print("⚠️  Using CPU (slower, but works)")
    print("   For faster generation, use GPU runtime (Colab: Runtime → Change runtime type → GPU)")

## Step 2: Load Generator Model

We'll load the PyTorch checkpoint from the Hugging Face Hub.

In [None]:
import torch
from huggingface_hub import hf_hub_download
import sys
import os

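# Add the project root (parent of the working directory) to the import path so
# `src.dit` resolves; this assumes the notebook runs from a subfolder of a local
# clone of the project repository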
sys.path.insert(0, os.path.join(os.getcwd(), '..'))

# Download model
print("Downloading generator model...")
model_path = hf_hub_download(
    repo_id="d4oit/tiny-cats-model",
    filename="generator/model.pt"
)
print(f"Model downloaded to: {model_path}")

# Import model architecture
from src.dit import tinydit_128

# Load checkpoint
print("Loading model weights...")
checkpoint = torch.load(model_path, map_location=device, weights_only=True)

# Handle different checkpoint formats
if "model_state_dict" in checkpoint:
    model_state = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
    model_state = checkpoint["state_dict"]
else:
    model_state = checkpoint

# Initialize model
model = tinydit_128(
    image_size=128,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    num_classes=13  # 12 breeds + other
)

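# Load weights. strict=False tolerates mismatched keys, so report any mismatch
# instead of silently running with randomly initialized layers.
load_result = model.load_state_dict(model_state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"⚠️  Missing keys: {len(load_result.missing_keys)}, "
          f"unexpected keys: {len(load_result.unexpected_keys)}")
model = model.to(device)
model.eval()

print("✅ Model loaded successfully!")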
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Device: {device}")
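Because `load_state_dict(..., strict=False)` skips keys that do not match, a checkpoint whose parameter names differ from the model's (for example, names prefixed with `module.` when a model is saved from a `DataParallel` wrapper) can leave layers at their random initialization without raising an error. The sketch below compares key sets in plain Python. `report_key_mismatch` is a hypothetical helper, the key names in the demo are illustrative rather than real TinyDiT parameter names, and prefix stripping is an assumption about how such a checkpoint might have been saved, not something the published checkpoint is known to need.

```python
from typing import Iterable, List, Tuple


def report_key_mismatch(
    model_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    prefix: str = "module.",
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: compare model vs. checkpoint parameter names.

    Checkpoint keys starting with `prefix` are compared with the prefix removed.
    Returns (missing, unexpected), both sorted.
    """
    model_set = set(model_keys)
    ckpt_set = {k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint_keys}
    missing = sorted(model_set - ckpt_set)     # expected by the model, absent in checkpoint
    unexpected = sorted(ckpt_set - model_set)  # present in checkpoint, unknown to the model
    return missing, unexpected


# In the notebook (assumes `model` and `model_state` from the cell above):
# missing, unexpected = report_key_mismatch(model.state_dict().keys(), model_state.keys())

# Illustrative key names only:
missing, unexpected = report_key_mismatch(
    ["blocks.0.attn.qkv.weight", "final_layer.weight"],
    ["module.blocks.0.attn.qkv.weight", "pos_embed"],
)
print("missing:", missing)        # missing: ['final_layer.weight']
print("unexpected:", unexpected)  # unexpected: ['pos_embed']
```

If a prefix turns out to be the cause, the same renaming would also have to be applied to `model_state` before calling `load_state_dict`.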

## Step 3: Define Breed Names and Sampling Function

In [None]:
# Cat breed names (must match training data)
BREED_NAMES = [
    "Abyssinian",
    "Bengal",
    "Birman",
    "Bombay",
    "British Shorthair",
    "Egyptian Mau",
    "Maine Coon",
    "Persian",
    "Ragdoll",
    "Russian Blue",
    "Siamese",
    "Sphynx",
    "Other"
]

print("Supported breeds:")
for i, breed in enumerate(BREED_NAMES):
    print(f"  {i}: {breed}")
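The model is conditioned on a class index, so looking breeds up by name helps avoid off-by-one mistakes. The sketch below is a hypothetical convenience helper, not part of the repository: it maps a breed name to its index case-insensitively and builds the same 13-entry one-hot vector that the sampling function constructs as a tensor. It repeats the `BREED_NAMES` list so it runs on its own.

```python
from typing import List

BREED_NAMES = [
    "Abyssinian", "Bengal", "Birman", "Bombay", "British Shorthair",
    "Egyptian Mau", "Maine Coon", "Persian", "Ragdoll", "Russian Blue",
    "Siamese", "Sphynx", "Other",
]

# Case-insensitive name -> index lookup (hypothetical helper)
BREED_TO_INDEX = {name.lower(): i for i, name in enumerate(BREED_NAMES)}


def breed_index_for(name: str) -> int:
    """Return the class index for a breed name, ignoring case and surrounding spaces."""
    key = name.strip().lower()
    if key not in BREED_TO_INDEX:
        raise ValueError(f"Unknown breed {name!r}; choose from {BREED_NAMES}")
    return BREED_TO_INDEX[key]


def one_hot(index: int, num_classes: int = len(BREED_NAMES)) -> List[float]:
    """Plain-Python one-hot vector, mirroring `y[:, breed_index] = 1` in generate()."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


print(breed_index_for("egyptian mau"))          # 5
print(one_hot(breed_index_for("Siamese"))[10])  # 1.0
```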

Now let's define the sampling function for generation.

In [None]:
import numpy as np
from PIL import Image

def sample_t(batch_size: int, device: str) -> torch.Tensor:
    """Sample timesteps uniformly from [0, 1].

    Not used by generate() below, which steps through t deterministically.
    """
    return torch.rand(batch_size, device=device)

@torch.no_grad()
def generate(
    model: torch.nn.Module,
    breed_index: int,
    num_steps: int = 50,
    cfg_scale: float = 1.5,
    batch_size: int = 1,
    device: str = "cpu",
) -> torch.Tensor:
    """
    Generate cat images using flow matching with CFG.
    
    Args:
        model: TinyDiT model
        breed_index: Index of breed to generate (0-12)
        num_steps: Number of ODE integration steps
        cfg_scale: Classifier-free guidance scale
        batch_size: Number of images to generate
        device: Device to run on
    
    Returns:
        Generated images (batch_size, 3, 128, 128), approximately in [-1, 1]
        (tensor_to_image clips any values outside that range)
    """
    # Sample initial noise
    z = torch.randn(batch_size, 3, 128, 128, device=device)
    
    # Create one-hot breed vector
    y = torch.zeros(batch_size, 13, device=device)
    y[:, breed_index] = 1
    
    # Null condition for CFG (unconditional)
    y_null = torch.zeros(batch_size, 13, device=device)
    
    # Euler integration
    dt = 1.0 / num_steps
    x = z.clone()
    
    for step in range(num_steps):
        t = torch.full((batch_size,), step * dt, device=device)
        
        # Conditional prediction
        pred_cond = model(x, t, y)
        
        # Unconditional prediction (for CFG)
        if cfg_scale != 1.0:
            pred_uncond = model(x, t, y_null)
            # Apply CFG: v = v_uncond + cfg_scale * (v_cond - v_uncond)
            pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
        else:
            pred = pred_cond
        
        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + pred * dt
    
    return x

def tensor_to_image(tensor: torch.Tensor) -> Image.Image:
    """Convert one (3, H, W) tensor in [-1, 1] to a PIL RGB image."""
    # Map [-1, 1] to [0, 255], rounding to the nearest integer and clipping outliers
    image = ((tensor + 1) / 2 * 255).round().clip(0, 255).to(torch.uint8)
    image = image.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(image)

print("Generation functions defined!")
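The loop in `generate` integrates dx/dt = v(x, t, y) with fixed-step Euler from t = 0 (noise) to t = 1 (image). As a sanity check of that update rule, the plain-Python sketch below assumes a straight-line path from a "noise" point x0 to a "data" point x1, whose velocity x1 - x0 is constant; under that assumption Euler integration lands on x1 for any step count. Whether TinyDiT was trained on exactly this linear path is an assumption here, and a learned velocity field is not constant, so in practice fewer steps add discretization error.

```python
from typing import Callable, List

Velocity = Callable[[List[float], float], List[float]]


def euler_integrate(x0: List[float], velocity: Velocity, num_steps: int) -> List[float]:
    """Fixed-step Euler for dx/dt = velocity(x, t) from t = 0 to t = 1.

    Uses the same update as generate(): t = step * dt, then x <- x + v * dt.
    """
    dt = 1.0 / num_steps
    x = list(x0)
    for step in range(num_steps):
        t = step * dt
        v = velocity(x, t)
        x = [xi + vi * dt for xi, vi in zip(x, v)]
    return x


# Toy straight-line path (assumption): "noise" x0 -> "data" x1, constant velocity x1 - x0
x0 = [0.3, -1.2, 2.0]
x1 = [1.0, 0.5, -0.5]


def constant_v(x: List[float], t: float) -> List[float]:
    return [b - a for a, b in zip(x0, x1)]


for n in (1, 5, 50):
    end = euler_integrate(x0, constant_v, n)
    print(n, [round(e, 6) for e in end])  # each run lands on x1, up to float rounding
```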

## Step 4: Generate Your First Cat Image

Let's generate an Abyssinian cat (breed index 0).

In [None]:
import time
import matplotlib.pyplot as plt

# Generate Abyssinian
breed_index = 0  # Abyssinian
breed_name = BREED_NAMES[breed_index]

print(f"Generating {breed_name}...")
start_time = time.time()

with torch.no_grad():
    generated = generate(
        model,
        breed_index=breed_index,
        num_steps=50,
        cfg_scale=1.5,
        batch_size=1,
        device=device
    )

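# CUDA runs asynchronously; wait for queued GPU work so the timing is accurate
if device == "cuda":
    torch.cuda.synchronize()
elapsed = time.time() - start_time
print(f"Generation complete in {elapsed:.2f} seconds!")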

# Convert to image and display
image = tensor_to_image(generated[0])

plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.title(f"Generated {breed_name}", fontsize=16, fontweight='bold')
plt.axis("off")
plt.tight_layout()
plt.show()

# Save image
output_path = f"generated_{breed_name.lower().replace(' ', '_')}.png"
image.save(output_path)
print(f"Image saved to: {output_path}")

## Step 5: Generate All Breeds

Let's generate one image for each of the 13 classes (12 breeds plus "Other").

In [None]:
print(f"Generating all {len(BREED_NAMES)} classes...")
start_time = time.time()

# 3 x 5 grid = 15 panels for 13 classes; the spare panels are hidden below
fig, axes = plt.subplots(3, 5, figsize=(20, 12))
axes = axes.flatten()

generated_images = []

for i, breed_name in enumerate(BREED_NAMES):
    print(f"  [{i+1}/{len(BREED_NAMES)}] Generating {breed_name}...")
    
    with torch.no_grad():
        generated = generate(
            model,
            breed_index=i,
            num_steps=50,
            cfg_scale=1.5,
            batch_size=1,
            device=device
        )
    
    image = tensor_to_image(generated[0])
    generated_images.append(image)
    
    axes[i].imshow(image)
    axes[i].set_title(breed_name, fontsize=12, fontweight='bold')
    axes[i].axis("off")

# Hide unused panels so they don't render as empty frames
for ax in axes[len(BREED_NAMES):]:
    ax.axis("off")

elapsed = time.time() - start_time
print(f"\n✅ Generated all classes in {elapsed:.2f} seconds!")
print(f"   Average: {elapsed / len(BREED_NAMES):.2f} seconds per image")

plt.suptitle("TinyDiT Generated Cat Breeds", fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()

# Save before plt.show(); with some backends the figure is empty after showing
fig.savefig("all_breeds_grid.png", dpi=150, bbox_inches='tight')
plt.show()
print("Grid saved to: all_breeds_grid.png")
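The 3 × 5 grid above holds 15 panels for 13 classes, so two panels are left over. If the class list or column count changes, a small hypothetical helper like the one below can compute the number of rows needed and report which trailing panel indices are unused, i.e. the ones to hide with `ax.axis("off")`.

```python
import math
from typing import List, Tuple


def grid_layout(num_images: int, num_cols: int = 5) -> Tuple[int, int, List[int]]:
    """Return (rows, cols, unused_panel_indices) for num_images panels in num_cols columns."""
    num_rows = math.ceil(num_images / num_cols)
    unused = list(range(num_images, num_rows * num_cols))
    return num_rows, num_cols, unused


rows, cols, unused = grid_layout(13)
print(rows, cols, unused)  # 3 5 [13, 14]

# Possible use with matplotlib (assumes `plt` is imported):
# fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
# for j in unused:
#     axes.flatten()[j].axis("off")
```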

## Step 6: Experiment with CFG Scale

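Classifier-free guidance (CFG) trades sample diversity against adherence to the requested breed: higher `cfg_scale` values push each step's prediction further toward the breed-conditional prediction and away from the unconditional one.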
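In `generate`, guidance combines two velocity predictions per step as v = v_uncond + s · (v_cond - v_uncond). The plain-Python sketch below applies that formula elementwise to arbitrary toy numbers to show what the scale s does: s = 0 returns the unconditional prediction, s = 1 the conditional one, 0 < s < 1 interpolates between them, and s > 1 extrapolates beyond the conditional prediction.

```python
from typing import List


def cfg_combine(v_cond: List[float], v_uncond: List[float], scale: float) -> List[float]:
    """Classifier-free guidance: v_uncond + scale * (v_cond - v_uncond), elementwise."""
    return [u + scale * (c - u) for c, u in zip(v_cond, v_uncond)]


v_cond = [1.0, 2.0]    # toy conditional prediction
v_uncond = [0.0, 1.0]  # toy unconditional prediction

for s in (0.0, 0.5, 1.0, 1.5, 3.0):
    print(s, cfg_combine(v_cond, v_uncond, s))
# 0.0 [0.0, 1.0]  -> unconditional
# 0.5 [0.5, 1.5]  -> halfway between
# 1.0 [1.0, 2.0]  -> conditional
# 1.5 [1.5, 2.5]  -> extrapolated
# 3.0 [3.0, 4.0]  -> strongly extrapolated
```

This is also why `generate` skips the unconditional pass when `cfg_scale == 1.0`: the formula reduces to the conditional prediction.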

In [None]:
# Test different CFG values
cfg_values = [0.5, 1.0, 1.5, 2.0, 3.0]
breed_index = 0  # Abyssinian

print(f"Testing CFG scales: {cfg_values}")
print(f"Breed: {BREED_NAMES[breed_index]}\n")

fig, axes = plt.subplots(1, len(cfg_values), figsize=(20, 4))

for idx, cfg in enumerate(cfg_values):
    print(f"  Generating with CFG={cfg}...")
    torch.manual_seed(0)  # same starting noise for every CFG value, so only guidance changes
    
    with torch.no_grad():
        generated = generate(
            model,
            breed_index=breed_index,
            num_steps=50,
            cfg_scale=cfg,
            batch_size=1,
            device=device
        )
    
    image = tensor_to_image(generated[0])
    
    axes[idx].imshow(image)
    axes[idx].set_title(f"CFG = {cfg}", fontsize=14)
    axes[idx].axis("off")

plt.suptitle(f"CFG Scale Comparison - {BREED_NAMES[breed_index]}", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nüìù Observation:")
print("   - Low CFG (0.5-1.0): More diverse but less coherent")
print("   - Medium CFG (1.5-2.0): Good balance (recommended)")
print("   - High CFG (3.0+): Sharper but less diverse, possible artifacts")

## Step 7: Generate Multiple Variations

Generate multiple images of the same breed to see variety.

In [None]:
breed_index = 5  # Egyptian Mau
num_variations = 8

print(f"Generating {num_variations} variations of {BREED_NAMES[breed_index]}...")

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

for i in range(num_variations):
    with torch.no_grad():
        generated = generate(
            model,
            breed_index=breed_index,
            num_steps=50,
            cfg_scale=1.5,
            batch_size=1,
            device=device
        )
    
    image = tensor_to_image(generated[0])
    axes[i].imshow(image)
    axes[i].set_title(f"Variation {i+1}", fontsize=12)
    axes[i].axis("off")

plt.suptitle(f"{BREED_NAMES[breed_index]} - Multiple Variations", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n✅ Each variation is unique due to random noise initialization!")

## Step 8: Batch Generation

Generate multiple images in parallel for efficiency.

In [None]:
batch_size = 8
breed_index = 2  # Birman

print(f"Generating {batch_size} images in batch...")
start_time = time.time()

with torch.no_grad():
    generated_batch = generate(
        model,
        breed_index=breed_index,
        num_steps=50,
        cfg_scale=1.5,
        batch_size=batch_size,
        device=device
    )

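# CUDA runs asynchronously; wait for queued GPU work so the timing is accurate
if device == "cuda":
    torch.cuda.synchronize()
elapsed = time.time() - start_time
print(f"Batch generation complete in {elapsed:.2f} seconds!")
print(f"Per-image time: {elapsed/batch_size:.2f} seconds")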

# Display batch
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for i in range(batch_size):
    image = tensor_to_image(generated_batch[i])
    axes[i].imshow(image)
    axes[i].set_title(f"#{i+1}", fontsize=12)
    axes[i].axis("off")

plt.suptitle(f"Batch Generation - {BREED_NAMES[breed_index]}", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 9: Speed vs Quality Trade-off

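Fewer steps make generation faster, since runtime grows roughly linearly with `num_steps`, but each Euler step becomes larger and less accurate, which typically lowers image quality.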
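Both effects can be made concrete without the model. `generate` calls the model once per step, or twice when `cfg_scale != 1.0` (conditional plus unconditional pass). The sketch below counts those calls and, on a toy ODE dx/dt = 2t whose exact change over [0, 1] is 1.0, measures the Euler error using the same time grid as `generate`; for this ODE the error works out to exactly 1/n. The toy ODE is illustrative only and says nothing quantitative about image quality for this model.

```python
def forward_passes(num_steps: int, cfg_scale: float) -> int:
    """Model calls made by generate(): one per step, two when CFG is active."""
    return num_steps * (2 if cfg_scale != 1.0 else 1)


def euler_error(num_steps: int) -> float:
    """Euler error for dx/dt = 2t on [0, 1] (exact answer: 1.0)."""
    dt = 1.0 / num_steps
    x = 0.0
    for step in range(num_steps):
        t = step * dt       # same time grid as generate()
        x += 2.0 * t * dt   # Euler update
    return 1.0 - x          # analytically (n - 1) / n short of 1, i.e. error 1 / n


for n in (10, 25, 50, 100):
    print(f"{n:>3} steps: {forward_passes(n, 1.5):>3} model calls, toy Euler error {euler_error(n):.4f}")
```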

In [None]:
step_configs = [
    {"steps": 10, "label": "Fast (10 steps)"},
    {"steps": 25, "label": "Balanced (25 steps)"},
    {"steps": 50, "label": "Quality (50 steps)"},
    {"steps": 100, "label": "Best (100 steps)"},
]

breed_index = 10  # Siamese

fig, axes = plt.subplots(1, len(step_configs), figsize=(20, 4))

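for idx, config in enumerate(step_configs):
    print(f"Generating with {config['steps']} steps...")
    
    torch.manual_seed(0)  # same starting noise for every step count
    start = time.time()
    with torch.no_grad():
        generated = generate(
            model,
            breed_index=breed_index,
            num_steps=config["steps"],
            cfg_scale=1.5,
            batch_size=1,
            device=device
        )
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    elapsed = time.time() - start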
    
    image = tensor_to_image(generated[0])
    
    axes[idx].imshow(image)
    axes[idx].set_title(f"{config['label']}\n{elapsed:.2f}s", fontsize=12)
    axes[idx].axis("off")

plt.suptitle(f"Step Count Comparison - {BREED_NAMES[breed_index]}", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nüìù Recommendation:")
print("   - 10-25 steps: Quick previews, drafts")
print("   - 50 steps: Good quality (default)")
print("   - 100 steps: Maximum quality, final outputs")

## Step 10: Save Generated Images

In [None]:
import os

# Create output directory
output_dir = "generated_cats"
os.makedirs(output_dir, exist_ok=True)

print(f"Saving generated images to {output_dir}/...")

# Generate and save one image per breed
for i, breed_name in enumerate(BREED_NAMES):
    with torch.no_grad():
        generated = generate(
            model,
            breed_index=i,
            num_steps=50,
            cfg_scale=1.5,
            batch_size=1,
            device=device
        )
    
    image = tensor_to_image(generated[0])
    filename = f"{breed_name.lower().replace(' ', '_')}.png"
    image.save(os.path.join(output_dir, filename))
    print(f"  ✓ {breed_name}")

print(f"\n✅ Saved {len(BREED_NAMES)} images to {output_dir}/")
print(f"   Files: {os.listdir(output_dir)}")

## Performance Tips

### 1. Use GPU
A GPU is typically an order of magnitude or more faster than a CPU for diffusion sampling; the exact speedup depends on the hardware.

In [None]:
# Check current device
print(f"Current device: {device}")

# On Colab: Runtime → Change runtime type → GPU
# Then re-run all cells from the beginning

### 2. Reduce Steps for Faster Generation

In [None]:
# Compare a quick preview (5 steps) with a high-quality run (100 steps)
for steps in (5, 100):
    start = time.time()
    with torch.no_grad():
        _ = generate(model, breed_index=0, num_steps=steps, cfg_scale=1.5, device=device)
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    print(f"{steps:>3} steps: {time.time() - start:.2f} s on {device}")

### 3. Batch Generation
Generate multiple images in parallel for better GPU utilization.

In [None]:
# Generate 8 images at once
with torch.no_grad():
    batch = generate(model, breed_index=0, num_steps=50, cfg_scale=1.5, batch_size=8, device=device)

print(f"Generated {batch.shape[0]} images in one batch")

## Common Issues & Solutions

### Issue 1: "CUDA out of memory"
**Solution:** Reduce batch_size to 1 or use CPU.
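If a large batch does not fit in GPU memory, another option is to split the request into smaller batches and call `generate` once per chunk. The helper below is a hypothetical sketch that only computes the chunk sizes; the commented loop shows one way it could be combined with the notebook's `generate` and `torch.cat`.

```python
from typing import List


def chunk_sizes(total: int, max_batch: int) -> List[int]:
    """Split `total` images into batch sizes no larger than `max_batch`."""
    if total < 0 or max_batch < 1:
        raise ValueError("total must be >= 0 and max_batch must be >= 1")
    full, remainder = divmod(total, max_batch)
    return [max_batch] * full + ([remainder] if remainder else [])


print(chunk_sizes(20, 8))  # [8, 8, 4]

# Possible use in the notebook (assumes `torch`, `generate`, `model`, `device`):
# images = torch.cat([
#     generate(model, breed_index=0, batch_size=n, device=device)
#     for n in chunk_sizes(20, 8)
# ])
```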

### Issue 2: "Slow generation on CPU"
**Solution:** Use fewer steps (10-25) or switch to GPU.

### Issue 3: "Poor image quality"
**Solution:** Increase `num_steps` to 50-100, or try a CFG scale in the 1.5-2.0 range.

### Issue 4: "All images look the same"
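**Solution:** Each call to `generate` draws fresh Gaussian noise, so the outputs should differ. If they are identical, check whether the random seed (for example, `torch.manual_seed(...)`) is being reset to the same value before every call.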
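Seeding controls this directly: resetting the same seed before each call reproduces the same starting noise (and, on the same hardware and software stack, typically the same image), while distinct seeds give distinct samples. The stdlib sketch below uses Python's `random` module as a stand-in for the noise draw so it runs without PyTorch; `sample_noise` is a hypothetical helper, not part of the notebook.

```python
import random
from typing import List, Optional


def sample_noise(n: int, seed: Optional[int] = None) -> List[float]:
    """Stand-in for torch.randn: draw n Gaussian values from a private generator."""
    rng = random.Random(seed)  # seed=None -> seeded from system entropy
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


same_a = sample_noise(4, seed=123)
same_b = sample_noise(4, seed=123)
different = sample_noise(4, seed=124)

print(same_a == same_b)     # True: same seed, same noise
print(same_a == different)  # False: new seed, new noise

# The PyTorch equivalent in this notebook (assumes `generate`, `model`, `device`):
# torch.manual_seed(123); img_a = generate(model, breed_index=0, device=device)
# torch.manual_seed(124); img_b = generate(model, breed_index=0, device=device)
```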

## Summary

✅ You've learned how to:
- Load a diffusion generator from the Hugging Face Hub
- Generate cat images by breed
- Adjust CFG scale for quality/variety trade-off
- Optimize generation speed vs quality
- Batch generate for efficiency

## Next Steps

- Try [Notebook 03: Training & Fine-Tuning](03_training_fine_tuning.ipynb)
- Read about the model architecture in [ADR-017](../plans/ADR-017-tinydit-training-infrastructure.md)
- Learn about flow matching in [ADR-008](../plans/ADR-008-adapt-tiny-models-architecture-for-cats-classifier-with-web-frontend.md)
- Check out the [model repository](https://huggingface.co/d4oit/tiny-cats-model)

## References

- Model: https://huggingface.co/d4oit/tiny-cats-model
- DiT Paper: https://arxiv.org/abs/2212.09748
- Flow Matching: https://arxiv.org/abs/2210.02747
- Classifier-Free Guidance: https://arxiv.org/abs/2207.12598