In [None]:
import torch
# Inference-only notebook: switch autograd off globally.
# torch.enable_grad is a context manager / decorator, so calling it with False
# does not turn gradients off; torch.set_grad_enabled(False) is the function
# form that actually changes the global grad mode.
torch.set_grad_enabled(False)

# Enable PyTorch performance optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse
import random

import sys
sys.path.append('../.')
from utils.load_util import load_sdxl_models, load_pipe



# Which distilled SDXL variant to load: "dmd", "lcm", "turbo" or "lightning"
distillation_type = 'dmd'
device = 'cuda:0'  # Use CUDA for A100 GPU
weights_dtype = torch.bfloat16  # Use bfloat16 for better performance on A100

# load_sdxl_models returns five objects, unpacked here in the order it returns
# them: the pipeline, then the base UNet/scheduler, then the distilled
# UNet/scheduler.
(
    pipe,
    base_unet,
    base_scheduler,
    distilled_unet,
    distilled_scheduler,
) = load_sdxl_models(
    distillation_type=distillation_type,
    weights_dtype=weights_dtype,
    device=device,
)

In [None]:
def diversity_distillation(prompt, seed, pipe, base_unet, distilled_unet, distilled_scheduler, base_guidance_scale=5, distilled_guidance_scale=0, num_inference_steps=4, run_base_till=1):
    """Generate with the base UNet first, then hand over to the distilled UNet.

    ``pipe`` is called twice, both times with ``distilled_scheduler`` installed:
    once with ``base_unet`` (requesting latents) and once with
    ``distilled_unet`` (requesting PIL output), the second call receiving the
    first call's result as ``start_latents``.

    Args:
        prompt: Text prompt forwarded to both pipeline calls.
        seed: Seed for the CPU ``torch.Generator`` used by the first call only.
        pipe: Pipeline object; its ``scheduler`` and ``unet`` attributes are
            overwritten and left pointing at the distilled scheduler / UNet.
        base_unet: UNet used for the first call.
        distilled_unet: UNet used for the second call.
        distilled_scheduler: Scheduler used for both calls.
        base_guidance_scale: ``guidance_scale`` for the first call.
        distilled_guidance_scale: ``guidance_scale`` for the second call.
        num_inference_steps: Passed unchanged to both calls.
        run_base_till: Passed as ``till_timestep`` to the first call and as
            ``from_timestep`` to the second. Presumably the step at which the
            base model hands over -- confirm against the pipeline built in
            ``utils.load_util``.

    Returns:
        Whatever ``pipe`` returns for ``output_type='pil'``; callers in this
        notebook index it with ``[0]``.
    """
    # The same (distilled) scheduler drives both stages.
    pipe.scheduler = distilled_scheduler

    # Stage 1: base UNet, stopping early and returning latents.
    pipe.unet = base_unet
    noise_generator = torch.Generator().manual_seed(seed)
    base_stage_kwargs = {
        'guidance_scale': base_guidance_scale,
        'till_timestep': run_base_till,
        'num_inference_steps': num_inference_steps,
        'generator': noise_generator,
        'output_type': 'latent',
    }
    handoff_latents = pipe(prompt, **base_stage_kwargs)

    # Stage 2: distilled UNet resumes from the stage-1 latents.
    pipe.unet = distilled_unet
    distilled_stage_kwargs = {
        'guidance_scale': distilled_guidance_scale,
        'start_latents': handoff_latents,
        'num_inference_steps': num_inference_steps,
        'from_timestep': run_base_till,
        'output_type': 'pil',
    }
    return pipe(prompt, **distilled_stage_kwargs)

In [None]:
# Quick smoke test: one prompt, one random seed, default hand-over settings.
prompt = 'cartoon character'
seed = random.randint(0, 2**15)

images = diversity_distillation(
    prompt,
    seed,
    pipe,
    base_unet,
    distilled_unet,
    distilled_scheduler,
)

# Trailing expression so the notebook displays the first returned image.
images[0]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

# Baseline 1: the distilled model on its own
def distilled_only_generation(prompt, seed, pipe, distilled_unet, distilled_scheduler, guidance_scale=0, num_inference_steps=4):
    """Run ``pipe`` once using only the distilled UNet and scheduler.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` handed to the pipeline.
        pipe: Pipeline object; its ``scheduler`` and ``unet`` attributes are
            overwritten and left pointing at the distilled components.
        distilled_unet: UNet installed on ``pipe`` before the call.
        distilled_scheduler: Scheduler installed on ``pipe`` before the call.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        Whatever ``pipe`` returns for ``output_type='pil'``; callers in this
        notebook index it with ``[0]``.
    """
    pipe.scheduler = distilled_scheduler
    pipe.unet = distilled_unet

    call_kwargs = {
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'generator': torch.Generator().manual_seed(seed),
        'output_type': 'pil',
    }
    return pipe(prompt, **call_kwargs)

# Baseline 2: the base model on its own
def base_only_generation(prompt, seed, pipe, base_unet, base_scheduler, guidance_scale=5, num_inference_steps=20):
    """Run ``pipe`` once using only the base UNet and base scheduler.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` handed to the pipeline.
        pipe: Pipeline object; its ``scheduler`` and ``unet`` attributes are
            overwritten and left pointing at the base components.
        base_unet: UNet installed on ``pipe`` before the call.
        base_scheduler: Scheduler installed on ``pipe`` before the call.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.

    Returns:
        Whatever ``pipe`` returns for ``output_type='pil'``; callers in this
        notebook index it with ``[0]``.
    """
    pipe.scheduler = base_scheduler
    pipe.unet = base_unet

    call_kwargs = {
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'generator': torch.Generator().manual_seed(seed),
        'output_type': 'pil',
    }
    return pipe(prompt, **call_kwargs)

# User prompt
prompt = "bear in a top hat"  # Replace with your desired prompt
num_images = 6  # 2x3 grid

# Create output directory
output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

nrows = 2
ncols = 3


def _timed_first_image(generate, *args):
    """Call ``generate(*args)`` and return ``(result[0], elapsed_seconds)``.

    The ``time.perf_counter`` window covers both the call and the ``[0]``
    lookup on its result.
    """
    start = time.perf_counter()
    image = generate(*args)[0]
    return image, time.perf_counter() - start


def _draw_image_grid(ax, images, nrows, ncols, title):
    """Tile ``images`` row-major into an ``nrows`` x ``ncols`` grid on ``ax``.

    Each image gets its own inset axes; images beyond ``nrows * ncols`` are
    not drawn. ``ax`` itself only carries the title and is otherwise hidden.
    """
    ax.set_title(title, fontsize=14, pad=20)
    for idx, image in enumerate(images[:nrows * ncols]):
        row, col = divmod(idx, ncols)
        # inset_axes takes [x0, y0, width, height] in axes fractions with y
        # growing upward, so row 0 is placed in the topmost band.
        img_ax = ax.inset_axes([col / ncols, (nrows - 1 - row) / nrows, 1 / ncols, 1 / nrows])
        img_ax.imshow(image)
        img_ax.axis('off')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')


# Initialize variables
total_time_diversity = 0
total_time_distilled = 0
total_time_base = 0
all_diversity_images = []
all_distilled_images = []
all_base_images = []

pipe.set_progress_bar_config(disable=True)

# Generate images with all three methods
for i in tqdm(range(num_images)):
    # Generate random seed (use same seed for all methods for fair comparison).
    # dtype=int64 keeps the 32-bit range valid where NumPy's default integer is
    # 32-bit; int() gives a plain Python int for manual_seed and the filenames.
    seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))

    # Generate diversity distillation image
    diversity_image, runtime_diversity = _timed_first_image(
        diversity_distillation, prompt, seed, pipe, base_unet, distilled_unet, distilled_scheduler
    )
    total_time_diversity += runtime_diversity

    # Generate distilled-only image
    distilled_image, runtime_distilled = _timed_first_image(
        distilled_only_generation, prompt, seed, pipe, distilled_unet, distilled_scheduler
    )
    total_time_distilled += runtime_distilled

    # Generate base-only image
    base_image, runtime_base = _timed_first_image(
        base_only_generation, prompt, seed, pipe, base_unet, base_scheduler
    )
    total_time_base += runtime_base

    # Save individual images to disk
    diversity_filename = f"{output_dir}/diversity_image_{i+1:02d}_seed_{seed}.png"
    distilled_filename = f"{output_dir}/distilled_image_{i+1:02d}_seed_{seed}.png"
    base_filename = f"{output_dir}/base_image_{i+1:02d}_seed_{seed}.png"
    diversity_image.save(diversity_filename)
    distilled_image.save(distilled_filename)
    base_image.save(base_filename)

    # Append to lists for grid creation
    all_diversity_images.append(diversity_image)
    all_distilled_images.append(distilled_image)
    all_base_images.append(base_image)

# Create comparison figure with three subplots side by side
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(ncols*9, nrows*3), dpi=200)

_draw_image_grid(ax1, all_diversity_images, nrows, ncols, "Diversity Distillation\n(Base + Distilled)")
_draw_image_grid(ax2, all_distilled_images, nrows, ncols, "Distilled Model Only\n(4 steps)")
_draw_image_grid(ax3, all_base_images, nrows, ncols, "Base Model Only\n(20 steps)")

plt.tight_layout()

# Print timing information (guard the average so num_images = 0 does not raise)
images_timed = max(num_images, 1)
print(f"Diversity Distillation - Total Runtime: {total_time_diversity:.4f} seconds")
print(f"Diversity Distillation - Average per image: {total_time_diversity/images_timed:.4f} seconds")
print(f"Distilled Only - Total Runtime: {total_time_distilled:.4f} seconds")
print(f"Distilled Only - Average per image: {total_time_distilled/images_timed:.4f} seconds")
print(f"Base Only - Total Runtime: {total_time_base:.4f} seconds")
print(f"Base Only - Average per image: {total_time_base/images_timed:.4f} seconds")
print(f"Individual images saved to: {output_dir}/")

# Save the comparison grid
comparison_filename = f"{output_dir}/three_way_comparison_{prompt.replace(' ', '_')}.png"
plt.savefig(comparison_filename, bbox_inches='tight', pad_inches=0.1, dpi=200)
print(f"Three-way comparison grid saved to: {comparison_filename}")

plt.show()

# Multi-Model Diversity Comparison

Compare the three generation methods (diversity distillation, distilled-only, and base-only) across three different distillation models: DMD, Lightning, and Turbo.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import time
from tqdm.auto import tqdm

# Configuration
prompt = "frog in a hat"
num_images = 6  # 2x3 grid
nrows = 2
ncols = 3

# Models to compare
model_types = ['dmd', 'lightning', 'turbo']
model_names = {
    'dmd': 'DMD',
    'lightning': 'Lightning', 
    'turbo': 'Turbo'
}

# Generation methods, in the order they are run and plotted (top row first).
# Storage, figure layout and the printed summaries are all derived from this
# list and from model_types, so neither is assumed to have exactly 3 entries.
methods = ['diversity', 'distilled', 'base']
method_labels = {
    'diversity': 'Diversity Distillation\n(Base + Distilled)',
    'distilled': 'Distilled Model Only\n(4 steps)',
    'base': 'Base Model Only\n(20 steps)'
}
method_print_names = {
    'diversity': 'Diversity Distillation',
    'distilled': 'Distilled Only',
    'base': 'Base Only'
}

# Create output directory
output_dir = "multi_model_comparison"
os.makedirs(output_dir, exist_ok=True)

# Store results for each model and method
all_results = {}
timing_results = {}

print(f"Comparing {len(methods)} generation methods across {len(model_types)} models: {', '.join(model_names[m] for m in model_types)}")
print(f"Generating {num_images} images per method per model with prompt: '{prompt}'\n")

# Load and test each model
for model_type in model_types:
    print(f"\n{'='*60}")
    print(f"Loading {model_names[model_type]} model...")
    print(f"{'='*60}")
    
    # Load the model (perf_counter: same monotonic clock as the generation timings)
    start_load = time.perf_counter()
    model_pipe, model_base_unet, model_base_scheduler, model_distilled_unet, model_distilled_scheduler = load_sdxl_models(
        distillation_type=model_type,
        weights_dtype=torch.bfloat16,
        device=device
    )
    load_time = time.perf_counter() - start_load
    print(f"✓ {model_names[model_type]} loaded in {load_time:.2f}s")
    
    # Initialize storage for this model
    all_results[model_type] = {method: [] for method in methods}
    timing_results[model_type] = {method: {'total': 0, 'average': 0} for method in methods}
    timing_results[model_type]['load_time'] = load_time
    
    model_pipe.set_progress_bar_config(disable=True)
    
    # Generate images with all three methods
    for i in tqdm(range(num_images), desc=f"Generating {model_names[model_type]} images"):
        # Use same seed for all methods for fair comparison.
        # dtype=int64 keeps the 32-bit range valid where NumPy's default integer
        # is 32-bit; int() gives a plain Python int for manual_seed and filenames.
        seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
        
        # 1. Diversity Distillation (Base + Distilled)
        start_time = time.perf_counter()
        diversity_images = diversity_distillation(
            prompt, seed, model_pipe, 
            model_base_unet, model_distilled_unet, 
            model_distilled_scheduler
        )
        gen_time = time.perf_counter() - start_time
        timing_results[model_type]['diversity']['total'] += gen_time
        all_results[model_type]['diversity'].append(diversity_images[0])
        
        # Save individual image
        diversity_filename = f"{output_dir}/{model_type}_diversity_{i+1:02d}_seed_{seed}.png"
        diversity_images[0].save(diversity_filename)
        
        # 2. Distilled Model Only
        start_time = time.perf_counter()
        distilled_images = distilled_only_generation(
            prompt, seed, model_pipe,
            model_distilled_unet, model_distilled_scheduler
        )
        gen_time = time.perf_counter() - start_time
        timing_results[model_type]['distilled']['total'] += gen_time
        all_results[model_type]['distilled'].append(distilled_images[0])
        
        # Save individual image
        distilled_filename = f"{output_dir}/{model_type}_distilled_{i+1:02d}_seed_{seed}.png"
        distilled_images[0].save(distilled_filename)
        
        # 3. Base Model Only
        start_time = time.perf_counter()
        base_images = base_only_generation(
            prompt, seed, model_pipe,
            model_base_unet, model_base_scheduler
        )
        gen_time = time.perf_counter() - start_time
        timing_results[model_type]['base']['total'] += gen_time
        all_results[model_type]['base'].append(base_images[0])
        
        # Save individual image
        base_filename = f"{output_dir}/{model_type}_base_{i+1:02d}_seed_{seed}.png"
        base_images[0].save(base_filename)
    
    # Calculate averages (guarded so num_images = 0 does not raise)
    for method in methods:
        timing_results[model_type][method]['average'] = timing_results[model_type][method]['total'] / max(num_images, 1)
    
    print(f"✓ Generated {num_images*len(methods)} images total ({num_images} per method)")
    
    # Clean up to free memory
    del model_pipe, model_base_unet, model_base_scheduler, model_distilled_unet, model_distilled_scheduler
    torch.cuda.empty_cache()

print(f"\n{'='*60}")
print("Creating comparison visualization...")
print(f"{'='*60}\n")

# Create comparison figure: one row per method, one column per model.
# squeeze=False keeps `axes` 2-D even with a single method or model, and the
# figsize scales so the 3 x 3 case is still (ncols*10, nrows*11).
fig, axes = plt.subplots(
    len(methods), len(model_types),
    figsize=(ncols * 10 * len(model_types) / 3, nrows * 11 * len(methods) / 3),
    dpi=200,
    squeeze=False,
)

# Create grid for each combination of method and model
for method_idx, method in enumerate(methods):
    for model_idx, model_type in enumerate(model_types):
        ax = axes[method_idx, model_idx]
        model_name = model_names[model_type]
        avg_time = timing_results[model_type][method]['average']
        
        # Title includes model name, method, and timing
        if method_idx == 0:  # Top row gets model name
            title = f"{model_name}\n{method_labels[method]}\n{avg_time:.2f}s/img"
        else:
            title = f"{method_labels[method]}\n{avg_time:.2f}s/img"
        
        ax.set_title(title, fontsize=12, pad=15, fontweight='bold')
        
        # Tile this method's images row-major; anything beyond nrows*ncols is skipped
        for idx, image in enumerate(all_results[model_type][method][:nrows * ncols]):
            row, col = divmod(idx, ncols)
            # inset_axes takes [x0, y0, width, height] in axes fractions with
            # y growing upward, so row 0 is placed in the topmost band.
            img_ax = ax.inset_axes([col / ncols, (nrows - 1 - row) / nrows, 1 / ncols, 1 / nrows])
            img_ax.imshow(image)
            img_ax.axis('off')
        
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

plt.tight_layout()

# Print comprehensive timing information
print("\n" + "="*60)
print("TIMING RESULTS")
print("="*60)
for model_type in model_types:
    model_name = model_names[model_type]
    results = timing_results[model_type]
    print(f"\n{model_name} (Load time: {results['load_time']:.2f}s):")
    for method in methods:
        stats = results[method]
        # Throughput is undefined when nothing was timed; avoid dividing by zero.
        throughput = f"{1.0/stats['average']:.2f}" if stats['average'] > 0 else "n/a"
        print(f"  {method_print_names[method]}:")
        print(f"    Total: {stats['total']:.2f}s | Avg: {stats['average']:.2f}s | Throughput: {throughput} img/s")

print(f"\nIndividual images saved to: {output_dir}/")

# Save the comparison grid
comparison_filename = f"{output_dir}/full_comparison_{prompt.replace(' ', '_')}.png"
plt.savefig(comparison_filename, bbox_inches='tight', pad_inches=0.1, dpi=200)
print(f"Full comparison grid saved to: {comparison_filename}")

plt.show()

print("\n✅ Multi-model multi-method comparison complete!")
print(f"Total images generated: {len(model_types) * len(methods) * num_images} ({len(model_types)} models × {len(methods)} methods × {num_images} images)")