In [None]:
import torch
# Inference-only notebook: switch autograd off globally so no computation graph is recorded.
# (`torch.enable_grad` is a context manager / decorator that takes no mode argument, so
# `torch.enable_grad(False)` does not disable gradients; `set_grad_enabled` is the global switch.)
torch.set_grad_enabled(False)

# Enable PyTorch performance optimizations
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for fp32 matmuls on supporting GPUs
torch.backends.cudnn.allow_tf32 = True  # allow TF32 inside cuDNN convolutions
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv algorithms per input shape

import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse
import random

import sys
sys.path.append('../.')
from utils.load_util import load_model, load_pipe

# Distilled model to load: "dmd", "lcm", "turbo" or "lightning"; None loads the base model only.
distillation_type = 'dmd'
device = 'cuda:0'  # single CUDA device (settings below are tuned for an A100)
weights_dtype = torch.bfloat16  # bf16 weights for faster inference on A100-class GPUs

# load_model yields (pipe, base_unet, base_scheduler) for a base-only load and additionally
# (distilled_unet, distilled_scheduler) when a distillation type is requested.
result = load_model(distillation_type=distillation_type, weights_dtype=weights_dtype, device=device)

if distillation_type is not None:
    pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler = result
else:
    # Base-only load: make the absence of a distilled model explicit.
    pipe, base_unet, base_scheduler = result
    distilled_unet = distilled_scheduler = None

# NOTE(review): the comparison and test cells below load their own models and never use these
# globals; consider deleting them before running those cells to free GPU memory.

In [None]:
def diversity_distillation(prompt, seed, pipe, base_unet, distilled_unet, distilled_scheduler, base_guidance_scale=5, distilled_guidance_scale=0, num_inference_steps=4, run_base_till=1):
    """Generate images using diversity distillation (base + distilled model).

    The base UNet runs the start of a ``num_inference_steps`` schedule and returns
    latents. The distilled UNet then continues from those latents and produces PIL
    output. ``run_base_till`` is the hand-off point: it is passed as ``till_timestep``
    to the first call and as ``from_timestep`` to the second.

    Args:
        prompt: Text prompt used for both phases.
        seed: Seeds a fresh ``torch.Generator`` for each phase, so a seed reproduces an image.
        pipe: Pipeline object. Its ``unet`` (and ``scheduler``, when one is given) are
            overwritten and not restored.
        base_unet: UNet used for the first phase.
        distilled_unet: UNet for the second phase. If ``None``, the base UNet finishes the run.
        distilled_scheduler: Scheduler installed for both phases. If ``None``, the pipe's
            current scheduler is kept.
        base_guidance_scale: Guidance scale for the base phase.
        distilled_guidance_scale: Guidance scale for the distilled phase.
        num_inference_steps: Length of the schedule shared by both phases.
        run_base_till: Hand-off point between the two phases.

    Returns:
        Whatever ``pipe`` returns for ``output_type='pil'``.
    """
    # Base-only loads have no distilled scheduler. Installing None would leave the pipeline
    # without a scheduler, so keep the one it already has.
    if distilled_scheduler is not None:
        pipe.scheduler = distilled_scheduler
    pipe.unet = base_unet

    # Phase 1: base model up to the hand-off point, returning latents rather than images.
    base_latents = pipe(prompt,
                        guidance_scale=base_guidance_scale,
                        till_timestep=run_base_till,
                        num_inference_steps=num_inference_steps,
                        generator=torch.Generator().manual_seed(seed),
                        output_type='latent'
                        )

    # Phase 2: the distilled model (or the base model when none is loaded) finishes the run.
    if distilled_unet is not None:
        pipe.unet = distilled_unet

    images = pipe(prompt,
                  guidance_scale=distilled_guidance_scale,
                  start_latents=base_latents,
                  num_inference_steps=num_inference_steps,
                  from_timestep=run_base_till,
                  # Seed this phase too, so any noise the scheduler injects is reproducible.
                  generator=torch.Generator().manual_seed(seed),
                  output_type='pil'
                  )
    return images

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def distilled_only_generation(prompt, seed, pipe, distilled_unet, distilled_scheduler, guidance_scale=0, num_inference_steps=4):
    """Run the whole schedule with the distilled UNet and scheduler alone.

    ``pipe.scheduler`` and ``pipe.unet`` are replaced by the distilled components and
    left that way. The generator is seeded with ``seed`` so results are repeatable.

    Raises:
        ValueError: If ``distilled_unet`` is None (i.e. a base-only load).
    """
    if distilled_unet is None:
        raise ValueError("No distilled model loaded. Cannot use distilled_only_generation.")

    pipe.scheduler, pipe.unet = distilled_scheduler, distilled_unet
    seeded_generator = torch.Generator().manual_seed(seed)
    return pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=seeded_generator,
        output_type='pil',
    )

def base_only_generation(prompt, seed, pipe, base_unet, base_scheduler, guidance_scale=5, num_inference_steps=20):
    """Run the whole schedule with the base UNet and base scheduler.

    Installs ``base_scheduler`` and ``base_unet`` on ``pipe`` (left in place afterwards)
    and calls it with a generator seeded by ``seed``. Always usable, since every load
    provides a base model.
    """
    pipe.scheduler = base_scheduler
    pipe.unet = base_unet

    call_kwargs = dict(
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator().manual_seed(seed),
        output_type='pil',
    )
    return pipe(prompt, **call_kwargs)

# Multi-Model Comparison with load_model

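This notebook uses the refactored `load_model` function. When no distillation model is requested (`distillation_type=None`), `load_model` returns only `(pipe, base_unet, base_scheduler)`, and the notebook sets `distilled_unet` and `distilled_scheduler` to `None`.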

In [None]:
import gc

# Configuration
prompt = "frog in a hat"
num_images = 3  # Reduce for testing
nrows = 1
ncols = 3
base_seed = 0  # fixes the per-image seeds so that re-running reproduces the same images

# Models to compare
model_types = ['dmd', 'lightning', 'turbo']
model_names = {
    'dmd': 'DMD',
    'lightning': 'Lightning',
    'turbo': 'Turbo'
}
methods = ['diversity', 'distilled', 'base']

# Create output directory
output_dir = "multi_model_comparison"
os.makedirs(output_dir, exist_ok=True)

# Draw all seeds once, up front, so every model and every method is compared on the
# same initial noise. default_rng.integers returns int64 on every platform, so the
# 2**32 - 1 bound cannot overflow the default integer type.
seed_rng = np.random.default_rng(base_seed)
seeds = [int(s) for s in seed_rng.integers(0, 2**32 - 1, size=num_images)]

# Store results for each model and method
all_results = {}
timing_results = {}

print(f"Comparing {len(methods)} generation methods across {len(model_types)} models: {', '.join(model_names.values())}")
print(f"Generating {num_images} images per method per model with prompt: '{prompt}'\n")

# Load and test each model
for model_type in model_types:
    print(f"\n{'='*60}")
    print(f"Loading {model_names[model_type]} model...")
    print(f"{'='*60}")

    # Unpack straight from the call. Binding the returned tuple to a variable would keep
    # every model alive after the cleanup at the end of this iteration.
    start_load = time.time()
    (model_pipe, model_base_unet, model_base_scheduler,
     model_distilled_unet, model_distilled_scheduler) = load_model(
        distillation_type=model_type,
        weights_dtype=weights_dtype,
        device=device
    )
    load_time = time.time() - start_load
    print(f"✓ {model_names[model_type]} loaded in {load_time:.2f}s")

    # Initialize storage for this model
    all_results[model_type] = {method: [] for method in methods}
    timing_results[model_type] = {method: {'total': 0, 'average': 0} for method in methods}
    timing_results[model_type]['load_time'] = load_time

    try:
        model_pipe.set_progress_bar_config(disable=True)

        # Generate images with all three methods, reusing the shared seeds
        for seed in tqdm(seeds, desc=f"Generating {model_names[model_type]} images"):
            # 1. Diversity Distillation (Base + Distilled)
            start_time = time.perf_counter()
            diversity_images = diversity_distillation(
                prompt, seed, model_pipe,
                model_base_unet, model_distilled_unet,
                model_distilled_scheduler
            )
            timing_results[model_type]['diversity']['total'] += time.perf_counter() - start_time
            all_results[model_type]['diversity'].append(diversity_images[0])

            # 2. Distilled Model Only
            start_time = time.perf_counter()
            distilled_images = distilled_only_generation(
                prompt, seed, model_pipe,
                model_distilled_unet, model_distilled_scheduler
            )
            timing_results[model_type]['distilled']['total'] += time.perf_counter() - start_time
            all_results[model_type]['distilled'].append(distilled_images[0])

            # 3. Base Model Only
            start_time = time.perf_counter()
            base_images = base_only_generation(
                prompt, seed, model_pipe,
                model_base_unet, model_base_scheduler
            )
            timing_results[model_type]['base']['total'] += time.perf_counter() - start_time
            all_results[model_type]['base'].append(base_images[0])

        # Calculate averages (guarded so num_images = 0 does not divide by zero)
        for method in methods:
            method_timing = timing_results[model_type][method]
            method_timing['average'] = method_timing['total'] / num_images if num_images else 0

        print(f"✓ Generated {num_images * len(methods)} images total ({num_images} per method)")
    finally:
        # Free this model before the next load, even if generation raised.
        del model_pipe, model_base_unet, model_base_scheduler, model_distilled_unet, model_distilled_scheduler
        gc.collect()
        torch.cuda.empty_cache()

print(f"\n{'='*60}")
print("✅ Multi-model comparison complete!")
print(f"{'='*60}")

In [None]:
import gc

# Test loading with distillation_type=None (base model only)
print("Testing load_model with distillation_type=None...\n")

# When distillation_type=None, only 3 values are returned. Unpack directly (no intermediate
# tuple variable) so the cleanup below really drops every reference to the models.
base_pipe, base_model_unet, base_model_scheduler = load_model(
    distillation_type=None,
    weights_dtype=weights_dtype,
    device=device
)
distilled_unet_test = None
distilled_scheduler_test = None

print("✓ Base model loaded successfully")
print(f"  pipe: {type(base_pipe).__name__}")
print(f"  base_unet: {type(base_model_unet).__name__}")
print(f"  base_scheduler: {type(base_model_scheduler).__name__}")
print(f"  distilled_unet: {distilled_unet_test}")
print(f"  distilled_scheduler: {distilled_scheduler_test}")

# Test that diversity_distillation still works with None distilled model
test_prompt = "a cat on a table"
test_seed = 42

# A base-only load has no distilled scheduler. Run with the base scheduler rather than
# handing None to the pipeline, which would leave it without a scheduler.
test_scheduler = distilled_scheduler_test if distilled_scheduler_test is not None else base_model_scheduler

test_passed = False
try:
    test_images = diversity_distillation(
        test_prompt, test_seed, base_pipe,
        base_model_unet, distilled_unet_test,
        test_scheduler
    )
    print(f"\n✓ diversity_distillation works with distilled_unet=None")
    print(f"  Generated {len(test_images)} image(s)")
    test_passed = True
except Exception as e:
    # Report rather than crash the notebook; the outcome is summarised below.
    print(f"\n✗ Error: {e}")
finally:
    # Cleanup runs even when generation failed, so a failed test does not pin GPU memory.
    del base_pipe, base_model_unet, base_model_scheduler, test_scheduler
    gc.collect()
    torch.cuda.empty_cache()

print("\n✅ Test complete!" if test_passed else "\n❌ Test failed (see error above)")

## Summary of Refactoring

### Key Changes:

1. **Updated Imports**: Changed from `load_sdxl_models` to `load_model`
   - `load_model` is the loader that actually exists in `utils/load_util.py`

2. **Proper Return Value Handling**:
   - When `distillation_type=None`: Returns only `(pipe, base_unet, base_scheduler)`
   - When `distillation_type` is specified: Returns `(pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler)`

3. **None Handling**:
   - `distilled_unet = None` clearly indicates when no distillation model is loaded
   - Functions check for `None` before using distilled models

4. **Function Updates**:
   - `diversity_distillation()`: Checks `if distilled_unet is not None` before using it
   - `distilled_only_generation()`: Raises error if `distilled_unet` is None
   - `base_only_generation()`: Always works since base model is always present

5. **Fixed `load_util.py`**:
   - Updated `load_pipe()` to use `load_model()` instead of non-existent `load_sdxl_models()`

### Benefits:
- ✅ Clear API: `distilled_unet = None` explicitly shows when no distillation model exists
- ✅ Consistent: All model loading uses the same `load_model()` function
- ✅ Flexible: Supports both base-only and base+distilled configurations
- ✅ Safe: Functions validate model availability before use