# SDXL Model Pipeline Setup - Lightning Fix Applied
Supports 8 distilled models (plus the optional SDXL base model, currently disabled in `model_configs`) with proper scheduler configuration

In [31]:
# Imports and Configuration
import sys
import torch
from PIL import Image
from diffusers import (
    UNet2DConditionModel,
    StableDiffusionXLPipeline,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LCMScheduler,
    TCDScheduler,
    DiffusionPipeline,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

%load_ext autoreload
%autoreload 2

print("✓ Imports completed")

# ---------- Configuration ----------
# Target device, tensor precision and base checkpoint used by the cells below.
device = "cuda"
weights_dtype = torch.bfloat16
basemodel_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Per-model sampling settings: number of inference steps and the guidance
# scale to use. Insertion order is the order in which models are iterated.
model_configs = dict(
    # base=dict(steps=100, recommended_cfg=5.0),
    dmd=dict(steps=4, recommended_cfg=0.0),
    turbo=dict(steps=4, recommended_cfg=0.0),
    lightning=dict(steps=4, recommended_cfg=0.0),
    lcm=dict(steps=4, recommended_cfg=1.0),
    hyper=dict(steps=8, recommended_cfg=5.0),
    pcm=dict(steps=4, recommended_cfg=2.0),
    tcd=dict(steps=4, recommended_cfg=3.0),
    flash=dict(steps=4, recommended_cfg=2.0),
)

print(f"✓ Configuration set - Device: {device}, Dtype: {weights_dtype}")
print(f"✓ Available models: {list(model_configs.keys())}")



def _load_lora_pipeline(repo_id, adapter_name, weights_dtype, **lora_kwargs):
    """Load a fresh base pipeline and activate a single LoRA adapter at weight 1.0.

    A separate pipeline is created so the LoRA is not applied to the
    base UNet that load_model returns alongside the distilled one.
    Extra keyword arguments are forwarded to ``load_lora_weights``.
    """
    lora_pipe = DiffusionPipeline.from_pretrained(basemodel_id, torch_dtype=weights_dtype)
    lora_pipe.load_lora_weights(repo_id, adapter_name=adapter_name, **lora_kwargs)
    lora_pipe.set_adapters([adapter_name], adapter_weights=[1.0])
    return lora_pipe


def load_model(distillation_type=None, weights_dtype=torch.float16, device='cuda'):
    """
    Load SDXL models with specified distillation type.

    Args:
      distillation_type: None or 'base' for the plain base model, otherwise one
        of 'dmd', 'lightning', 'turbo', 'lcm', 'hyper', 'pcm', 'tcd', 'flash'
        (case-insensitive).
      weights_dtype: torch dtype the models are loaded in / cast to.
      device: device the models are moved to.

    Returns:
      'base'/None: (pipe, base_unet, base_scheduler)
      others:      (pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler)

    Raises:
      ValueError: if the distillation type is not supported. This is checked
        before anything is downloaded or moved to the device.
    """
    kind = ('base' if distillation_type in (None, 'base') else distillation_type).lower()

    # Reject unknown types up front instead of after loading multi-GB models.
    supported = ('base', 'dmd', 'lightning', 'turbo', 'lcm', 'hyper', 'pcm', 'tcd', 'flash')
    if kind not in supported:
        raise ValueError(f"Unknown distillation type: '{distillation_type}'. "
                         f"Available: {', '.join(sorted(supported))}")

    print(f"Loading {kind.upper()} model...")

    # ---- base (always build this once for config/safety) ----
    base_unet = UNet2DConditionModel.from_pretrained(
        basemodel_id, subfolder="unet", torch_dtype=weights_dtype
    ).to(device)

    pipe = StableDiffusionXLPipeline.from_pretrained(
        basemodel_id,
        unet=base_unet,
        torch_dtype=weights_dtype,
        use_safetensors=True,
    )
    base_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler = base_scheduler
    pipe.to(device=device, dtype=weights_dtype)

    if kind == 'base':
        return pipe, base_unet, base_scheduler

    def fresh_unet():
        # Empty UNet matching the base config (required for state_dict load).
        # Only built for the kinds that ship a bare state dict, so the other
        # kinds do not allocate a throwaway UNet on the device.
        return UNet2DConditionModel.from_config(pipe.unet.config).to(device, dtype=weights_dtype)

    if kind == 'dmd':
        repo_name, ckpt_name = "tianweiy/DMD2", "dmd2_sdxl_4step_unet_fp16.bin"
        # The .bin file is a pickle; weights_only=True restricts unpickling to
        # tensors/plain containers so a downloaded file cannot run arbitrary code.
        state = torch.load(hf_hub_download(repo_name, ckpt_name),
                           map_location='cpu', weights_only=True)
        # Accept either a bare state dict or one wrapped under 'state_dict'.
        if isinstance(state, dict) and 'state_dict' in state:
            state = state['state_dict']
        distilled_unet = fresh_unet()
        distilled_unet.load_state_dict(state)
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    elif kind == 'lightning':
        repo, ckpt = "ByteDance/SDXL-Lightning", "sdxl_lightning_4step_unet.safetensors"
        state = load_file(hf_hub_download(repo, ckpt))
        distilled_unet = fresh_unet()
        distilled_unet.load_state_dict(state, strict=True)
        # FIX: Use EulerDiscreteScheduler with trailing timesteps for both schedulers
        distilled_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        base_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    elif kind == 'turbo':
        # turbo ships a full UNet; pull that directly
        distilled_unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/sdxl-turbo", subfolder="unet", torch_dtype=weights_dtype, variant="fp16"
        ).to(device)
        distilled_scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    elif kind == 'lcm':
        distilled_unet = UNet2DConditionModel.from_pretrained(
            "latent-consistency/lcm-sdxl", torch_dtype=weights_dtype
        ).to(device)
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    elif kind == 'hyper':
        pipe = _load_lora_pipeline("ByteDance/Hyper-SD", "hyper-sdxl-8step", weights_dtype,
                                   weight_name="Hyper-SDXL-8steps-CFG-lora.safetensors")
        distilled_unet = pipe.unet
        distilled_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    elif kind == 'pcm':
        pipe = _load_lora_pipeline("wangfuyun/PCM_Weights", "pcm-lora", weights_dtype,
                                   weight_name="pcm_sdxl_smallcfg_4step_converted.safetensors",
                                   subfolder="sdxl")
        distilled_unet = pipe.unet
        distilled_scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    elif kind == 'tcd':
        pipe = _load_lora_pipeline("h1t/TCD-SDXL-LoRA", "tcd-lora", weights_dtype)
        distilled_unet = pipe.unet
        distilled_scheduler = TCDScheduler.from_config(pipe.scheduler.config)

    else:
        # kind == 'flash' (every other value was handled or rejected above)
        pipe = _load_lora_pipeline("jasperai/flash-sdxl", "flash-sdxl", weights_dtype,
                                   weight_name="pytorch_lora_weights.safetensors")
        pipe.fuse_lora()
        distilled_unet = pipe.unet
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    # IMPORTANT: actually use the distilled UNet
    if hasattr(pipe, "unet") and distilled_unet is not pipe.unet:
        pipe.unet = distilled_unet
    pipe.scheduler = distilled_scheduler
    pipe.to(device=device, dtype=weights_dtype)
    return pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler


def load_pipe(distillation_type='base'):
    """
    Build a pipeline for the given distillation type and return only the pipeline.

    Calls load_model with the module-level weights_dtype and device and keeps
    the first element of the tuple it returns; for non-base types load_model
    has already attached the distilled UNet and scheduler to that pipeline.
    """
    # Same label rule as load_model: None and 'base' both mean the base model.
    is_base = distillation_type in (None, 'base')
    label = 'base' if is_base else distillation_type

    pipe, *_ = load_model(distillation_type, weights_dtype, device)

    print(f"✓ {label.upper()} pipeline ready")
    return pipe

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
✓ Imports completed
✓ Configuration set - Device: cuda, Dtype: torch.bfloat16
✓ Available models: ['dmd', 'turbo', 'lightning', 'lcm', 'hyper', 'pcm', 'tcd', 'flash']


## Test Across Select Models
Run tests on multiple models with first 10 prompts and first seed, organized by model folder

In [None]:
import gc
import json
import os
from pathlib import Path
from tqdm import tqdm

# ---------- Load prompts from JSON file ----------
prompts_file = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/data/prompts_noun_negative.json"
output_base_dir = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/results"

# Load prompts (JSON is UTF-8 by specification, so do not rely on the locale default)
with open(prompts_file, 'r', encoding='utf-8') as f:
    prompts_data = json.load(f)

# Use only first 10 prompts
prompts_data = prompts_data[:10]

print(f"Loaded {len(prompts_data)} prompts from {prompts_file}")
print(f"Output base directory: {output_base_dir}")

# ---------- Generate and save images organized by folder ----------
total_generated = 0
pipe = None

for model_name, model_config in model_configs.items():
    steps = model_config["steps"]
    cfg = model_config["recommended_cfg"]

    # Drop the previous model's pipeline BEFORE loading the next one, so two
    # full SDXL pipelines never have to fit on the GPU at the same time.
    pipe = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load the model pipeline
    pipe = load_pipe(model_name)

    # Create model-specific output directory
    model_output_dir = os.path.join(output_base_dir, model_name)
    os.makedirs(model_output_dir, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"Testing model: {model_name}")
    print(f"Output directory: {model_output_dir}")
    print(f"{'='*60}")

    generated_count = 0

    for idx, item in enumerate(tqdm(prompts_data, desc=f"{model_name} progress")):
        prompt = item["prompt"]
        # NOTE(review): items may also carry "negative_prompt"; this run does not
        # pass it to the pipeline — confirm that is the intended baseline.

        # Use first seed for quick generation (fall back to 42 when the item
        # has no seeds or an empty list)
        seeds = item.get("seeds") or [42]
        seed = seeds[0]

        try:
            # Generate image
            generator = torch.Generator(device=device).manual_seed(seed)
            image = pipe(
                prompt,
                guidance_scale=cfg,
                num_inference_steps=steps,
                generator=generator
            ).images[0]

            # Create output filename
            group = item.get("group", "unknown")
            filename = f"{idx:04d}_{group}_{seed}.png"
            filepath = os.path.join(model_output_dir, filename)

            # Save image
            image.save(filepath)
            generated_count += 1
            total_generated += 1

        except Exception as e:
            # Best-effort batch run: report the failure and move on to the next prompt.
            print(f"Error generating image for prompt {idx} with {model_name}: {e}")
            continue

    print(f"\n✓ {model_name}: Generated and saved {generated_count} images")

print(f"\n{'='*60}")
print(f"✓ Total generated and saved: {total_generated} images")
print(f"✓ Models tested: {list(model_configs.keys())}")
print(f"{'='*60}")

Loaded 10 prompts from /home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/data/prompts_noun_negative.json
Output base directory: /home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/results
Loading DMD model...
