# 🎨 Lecture 17: Efficient Diffusion Models - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/17_efficient_diffusion_models/demo.ipynb)

## What You'll Learn
- Diffusion model basics and why they're slow
- Fast samplers (DDIM, DPM++)
- Model distillation (LCM, SDXL Turbo)
- Architecture optimizations

In [None]:
# Setup cell. `!pip` is an IPython shell escape, so this cell only runs inside
# Jupyter/Colab, not as a plain .py script. torch and torch.nn are imported
# here but are not referenced by any of the cells visible in this notebook;
# the demos below use matplotlib and numpy.
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

print('Ready for Efficient Diffusion!')

## Part 1: The Diffusion Bottleneck

In [None]:
def diffusion_analysis():
    """Print and return an illustrative latency comparison of generators.

    Prints why diffusion sampling is slow, then a table of denoising steps,
    per-step latency and total latency for several methods. The latencies are
    hard-coded illustrative figures, not measurements. ``total`` is derived as
    ``steps * time_per_step`` so the three fields cannot drift apart.

    Returns:
        dict: method name -> {'steps': int, 'time_per_step': int (ms),
        'total': int (ms)}, in display order.
    """
    print('📊 DIFFUSION MODEL INFERENCE')
    print('=' * 60)

    print('\n🔹 Why diffusion is slow:')
    print('   1. Sequential denoising: T steps, each needs full forward pass')
    print('   2. Large UNet: ~860M parameters (SD 1.5)')
    print('   3. High resolution: 512×512 or 1024×1024')

    # Comparison: (denoising steps, assumed milliseconds per step).
    profiles = {
        'GAN (StyleGAN)': (1, 50),
        'DDPM (original)': (1000, 50),
        'DDIM': (50, 50),
        'DPM++ 2M': (20, 50),
        'LCM': (4, 50),
        'SDXL Turbo': (1, 100),
    }
    methods = {
        name: {'steps': steps, 'time_per_step': ms, 'total': steps * ms}
        for name, (steps, ms) in profiles.items()
    }

    # Header and rows share column widths (20/10/12/12) so the table lines up;
    # numeric columns are right-aligned including their "ms" suffix.
    print(f'\n{"Method":<20} {"Steps":<10} {"Time/Step":>12} {"Total (ms)":>12}')
    print('-' * 57)
    for name, info in methods.items():
        print(f'{name:<20} {info["steps"]:<10} '
              f'{info["time_per_step"]:>10}ms {info["total"]:>10}ms')

    return methods

methods = diffusion_analysis()

In [None]:
# Visualize speedup: denoising steps and end-to-end latency per method. Both
# panels use log scales because the values span several orders of magnitude.
# Uses the `methods` dict returned by diffusion_analysis() in the previous cell.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

names = list(methods.keys())
steps = [methods[n]['steps'] for n in names]
times = [methods[n]['total'] for n in names]

# Colour each bar by its total latency (log scale) so the colour encodes
# speed: fastest -> green end of RdYlGn (0.9), slowest -> red end (0.2).
log_times = np.log10(times)
span = log_times.max() - log_times.min()
slowness = (log_times - log_times.min()) / span if span > 0 else np.zeros_like(log_times)
colors = plt.cm.RdYlGn(0.9 - 0.7 * slowness)

# Steps comparison
axes[0].bar(names, steps, color=colors)
axes[0].set_ylabel('Number of Steps')
axes[0].set_title('Denoising Steps by Method')
axes[0].set_yscale('log')
plt.setp(axes[0].xaxis.get_majorticklabels(), rotation=20, ha='right')

# Value labels slightly above each bar (multiplicative offset on a log axis).
for i, s in enumerate(steps):
    axes[0].text(i, s * 1.2, str(s), ha='center', fontsize=10)

# Total time comparison
axes[1].bar(names, times, color=colors)
axes[1].set_ylabel('Total Time (ms)')
axes[1].set_title('Generation Time by Method')
axes[1].set_yscale('log')
plt.setp(axes[1].xaxis.get_majorticklabels(), rotation=20, ha='right')

for i, t in enumerate(times):
    label = f'{t/1000:.1f}s' if t >= 1000 else f'{t}ms'
    axes[1].text(i, t * 1.2, label, ha='center', fontsize=10)

plt.tight_layout()
plt.show()

speedup = methods['DDPM (original)']['total'] / methods['SDXL Turbo']['total']
print(f'\n💡 SDXL Turbo is {speedup:.0f}x faster than DDPM!')

## Part 2: Fast Samplers (DDIM, DPM++)

In [None]:
def explain_samplers():
    """Print a summary of common diffusion sampling strategies.

    For each sampler, prints its update rule (or a one-line description of
    it), the key idea behind it, and a rough minimum step count for usable
    quality. Returns None.
    """
    print('📊 SAMPLING STRATEGIES')
    print('=' * 70)

    samplers = {
        'DDPM': {
            'formula': 'x_{t-1} = μ(x_t, t) + σ_t × ε',
            'key_idea': 'Stochastic, follows Markov chain',
            'min_steps': 1000,
        },
        'DDIM': {
            'formula': 'x_{t-1} = √(α_{t-1}) × pred_x0 + √(1-α_{t-1}) × pred_noise',
            'key_idea': 'Deterministic, can skip steps',
            'min_steps': 20,
        },
        'DPM++ 2M': {
            'formula': 'Uses ODE solver with 2nd order multistep',
            'key_idea': 'Higher order = fewer steps needed',
            'min_steps': 15,
        },
        'Euler Ancestral': {
            'formula': 'First-order Euler method + noise',
            'key_idea': 'Simple, good for creative outputs',
            'min_steps': 25,
        },
    }

    for name, info in samplers.items():
        print(f'\n🔹 {name}')
        # Every entry carries an update rule, so show it alongside the idea.
        print(f'   Update rule: {info["formula"]}')
        print(f'   Key idea: {info["key_idea"]}')
        print(f'   Minimum steps: ~{info["min_steps"]}')

explain_samplers()

In [None]:
# Simulate denoising process
def simulate_denoising(n_steps, method='linear'):
    """Return a toy noise-level trajectory for ``n_steps`` denoising steps.

    The trajectory has ``n_steps + 1`` points and runs from noise level 1.0
    (pure noise) down to 0.0 (clean sample) for every schedule, so the
    schedules can be compared on the same axes.

    Args:
        n_steps: number of denoising steps (expected to be non-negative).
        method: 'linear', 'cosine', or 'quadratic'.

    Returns:
        np.ndarray of length ``n_steps + 1`` with non-increasing noise levels.

    Raises:
        ValueError: if ``method`` is not one of the supported schedules.
    """
    t = np.linspace(1, 0, n_steps + 1)  # normalized time, runs 1 -> 0

    if method == 'linear':
        noise_levels = t
    elif method == 'cosine':
        # Cosine schedule: the signal scale follows cos(t*pi/2), so the noise
        # scale is sin(t*pi/2), which goes 1 -> 0 like the other schedules.
        noise_levels = np.sin(t * np.pi / 2)
    elif method == 'quadratic':
        noise_levels = t ** 2
    else:
        raise ValueError(
            f"unknown schedule {method!r}; expected 'linear', 'cosine' or 'quadratic'"
        )

    return noise_levels

# Visualize denoising schedules: the left panel varies the step count for the
# cosine schedule; the right panel compares schedules at a fixed 50 steps.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
step_ax, schedule_ax = axes

# Left: one schedule sampled with progressively more steps.
for steps in (10, 20, 50, 100):
    trajectory = simulate_denoising(steps, 'cosine')
    progress = np.linspace(0, 1, len(trajectory))
    step_ax.plot(progress, trajectory, 'o-', label=f'{steps} steps', markersize=3)

# Right: each schedule sampled with the same number of steps.
schedules = ['linear', 'cosine', 'quadratic']
colors = ['#3b82f6', '#22c55e', '#ef4444']
for schedule, color in zip(schedules, colors):
    trajectory = simulate_denoising(50, schedule)
    progress = np.linspace(0, 1, len(trajectory))
    schedule_ax.plot(progress, trajectory, '-', label=schedule, color=color, linewidth=2)

# Both panels share the same axis labels and decorations.
for ax, title in ((step_ax, 'Denoising Trajectory at Different Step Counts'),
                  (schedule_ax, 'Different Noise Schedules (50 steps)')):
    ax.set_xlabel('Progress')
    ax.set_ylabel('Noise Level')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Part 3: Model Distillation (LCM, SDXL Turbo)

In [None]:
def explain_distillation():
    """Print an overview of diffusion distillation methods and a FID table.

    Covers progressive distillation, Latent Consistency Models and
    adversarial distillation (SDXL Turbo), then prints FID scores for a few
    configurations. The FID values are hard-coded illustrative numbers.
    Returns None.
    """
    print('📊 DIFFUSION MODEL DISTILLATION')
    print('=' * 70)

    print('\n🔹 Progressive Distillation:')
    # Each round halves the step count; a power-of-two schedule keeps every
    # stage an integer all the way down to a single step.
    print('   Teacher: 1024 steps → Student: 512 steps')
    print('   Student learns to match 2 teacher steps in 1 step')
    print('   Repeat: 512 → 256 → 128 → ... → 1 step')

    print('\n🔹 Latent Consistency Models (LCM):')
    print('   Train model to directly predict clean image')
    print('   Uses consistency loss: f(x_t) ≈ f(x_{t-k})')
    print('   Result: 4 steps with good quality')

    print('\n🔹 Adversarial Distillation (SDXL Turbo):')
    print('   Add discriminator to guide single-step generation')
    print('   Student learns to fool discriminator in 1 step')
    print('   Result: 1 step with remarkable quality')

    # Quality comparison: (label, steps, FID); lower FID is better.
    print('\n📊 QUALITY COMPARISON (FID on COCO)')
    print('-' * 50)

    comparisons = [
        ('SD 1.5 (50 steps)', 50, 8.5),
        ('SD 1.5 (20 steps)', 20, 9.5),
        ('LCM-LoRA (4 steps)', 4, 10.2),
        ('SDXL Turbo (1 step)', 1, 11.5),
    ]

    print(f'{"Method":<25} {"Steps":<10} {"FID ↓":<10}')
    for name, steps, fid in comparisons:
        print(f'{name:<25} {steps:<10} {fid:<10.1f}')

explain_distillation()

In [None]:
# Visualize quality vs speed trade-off: FID (lower is better) against the
# number of sampling steps on a log x-axis, inverted so fewer steps (faster)
# appear to the right.
fig, ax = plt.subplots(figsize=(10, 6))

# (label, steps, FID) — hard-coded illustrative values.
methods_data = [
    ('DDPM 1000', 1000, 8.0),
    ('DDIM 50', 50, 8.5),
    ('DPM++ 20', 20, 9.0),
    ('LCM 8', 8, 9.5),
    ('LCM 4', 4, 10.2),
    ('Turbo 1', 1, 11.5),
]

names = [d[0] for d in methods_data]
steps = [d[1] for d in methods_data]
fids = [d[2] for d in methods_data]

ax.scatter(steps, fids, s=200, c='#3b82f6', zorder=5)
ax.plot(steps, fids, '--', color='gray', alpha=0.5)

for name, s, f in zip(names, steps, fids):
    ax.annotate(name, (s, f), xytext=(10, 5), textcoords='offset points', fontsize=10)

ax.set_xlabel('Number of Steps (log scale)', fontsize=12)
ax.set_ylabel('FID Score (lower is better)', fontsize=12)
# No emoji in the title: matplotlib's default font typically has no emoji
# glyphs, so it would render as an empty box with a "Glyph missing" warning.
ax.set_title('Quality vs Speed Trade-off in Diffusion', fontsize=14)
ax.set_xscale('log')
ax.grid(True, alpha=0.3)
ax.invert_xaxis()

plt.tight_layout()
plt.show()

# NOTE(review): the "~90%" figure is not derived from methods_data above
# (FID 8.0 at 1000 steps vs 11.5 at 1 step); treat it as a rough claim.
print('\n💡 1-step generation achieves ~90% of 1000-step quality!')

## Part 4: Architecture Optimizations

In [None]:
def architecture_optimizations():
    """Print a table of architecture-level speed/memory optimizations.

    Speedup and memory figures are hard-coded rough ranges, not measurements.
    Also prints an illustrative per-image latency breakdown. The overall
    speedup is computed from those listed timings. Returns None.
    """
    print('📊 ARCHITECTURE OPTIMIZATIONS')
    print('=' * 70)

    optimizations = {
        'FlashAttention': {
            'speedup': '2-4x',
            'memory': '5-20x less',
            'description': 'Fused attention kernel, no N² memory'
        },
        'xFormers': {
            'speedup': '2x',
            'memory': '2x less',
            'description': 'Memory-efficient attention implementations'
        },
        'VAE Tiling': {
            'speedup': '1x',
            'memory': '4x less',
            'description': 'Process large images in tiles'
        },
        'FP16/BF16': {
            'speedup': '2x',
            'memory': '2x less',
            'description': 'Half precision computation'
        },
        'Torch Compile': {
            'speedup': '1.5-2x',
            'memory': '1x',
            'description': 'Graph optimization and fusion'
        },
    }

    # Column widths 20/12/12/30 plus single-space separators = 77 characters.
    print(f'{"Optimization":<20} {"Speedup":<12} {"Memory":<12} {"Description":<30}')
    print('-' * 77)

    for name, info in optimizations.items():
        print(f'{name:<20} {info["speedup"]:<12} {info["memory"]:<12} {info["description"]:<30}')

    # Total impact: illustrative seconds per image at each optimization stage.
    stages = [
        ('Baseline (SD 1.5, FP32, naive):', 15),
        ('Optimized (all above):', 1.5),
        ('+ LCM (4 steps):', 0.3),
    ]
    print('\n📊 COMBINED IMPACT')
    print('=' * 50)
    for label, seconds in stages:
        print(f'{label:<31} {seconds:g} seconds/image')

    # Derive the headline number from the stages instead of hard-coding it.
    speedup = stages[0][1] / stages[-1][1]
    print(f'\n💡 Total: {speedup:.0f}x faster end-to-end!')

architecture_optimizations()

In [None]:
# Recap of the lecture's main points.
print('🎯 KEY TAKEAWAYS')
print('=' * 60)
print('\n1. Diffusion is slow: 1000 sequential denoising steps')
print('\n2. Fast samplers: DDIM, DPM++ reduce to 20-50 steps')
print('\n3. Distillation: LCM/Turbo achieve 1-4 step generation')
print('\n4. Architecture: FlashAttention, xFormers save memory')
print('\n5. Combined: 50x faster than naive implementation')
print('\n6. Quality trade-off: 1-step ≈ 90% of 1000-step quality')
print('\n' + '=' * 60)
print('\n📚 Next: Quantum Machine Learning!')