
# Base→Student Handoff (Diversity Distillation) — *Nag‑time* compatible

This notebook is a variation that **loads models the same way as _nag-time_** but implements a **custom pipeline class** (no monkey patching) to expose the right components for the diversity‑distillation base→student handoff:

- `SDXLHandoffPipeline` (subclass of `StableDiffusionXLPipeline`) exposes:
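  - `from_timestep`, `till_timestep` (indices into the scheduler's timestep list, applied as a slice `timesteps[from_timestep:till_timestep]` — not raw timestep values)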
  - `start_latents`
  - `output_type="latent"` support for returning raw latents
- `diversity_distillation(...)` runs base UNet for the first few steps, hands latents to the distilled UNet, and completes generation.

> **Note:** This notebook avoids `StableDiffusionXLPipeline.__call__ = ...` and uses a new class instead, matching the requirement "same as Nag time, not monkey patched".


In [None]:

# Core libs
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F

# Diffusers / Transformers
from diffusers import (
    DiffusionPipeline,
    UNet2DConditionModel,
    LCMScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    TCDScheduler,
)
from diffusers.pipelines import StableDiffusionXLPipeline
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


In [None]:

def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0):
    """Blend the CFG prediction with a copy rescaled to the text prediction's spread.

    Follows Section 3.4 of 'Common Diffusion Noise Schedules and Sample Steps are Flawed'
    (arXiv 2305.08891), where the rescale is used to counter overexposed images.

    Args:
        noise_cfg: Guided noise prediction.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Blend weight; 0.0 returns ``noise_cfg`` unchanged in value,
            1.0 returns the fully rescaled prediction.

    Returns:
        Tensor with the same shape as ``noise_cfg``.
    """
    # Standard deviations are taken per sample, i.e. over every axis except the first.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    std_ratio = noise_pred_text.std(dim=text_axes, keepdim=True) / noise_cfg.std(dim=cfg_axes, keepdim=True)

    rescaled = noise_cfg * std_ratio
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


In [None]:

class SDXLHandoffPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline that can run a slice of the denoising schedule and hand latents off.

    The class overrides ``__call__`` (instead of monkey patching the parent) so that a
    caller can run only ``timesteps[from_timestep:till_timestep]``, start from latents
    produced by a previous call (``start_latents``) and get raw latents back with
    ``output_type="latent"``.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[..., Any]] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        from_timestep: int = 0,
        till_timestep: Optional[int] = None,
        start_latents: Optional[torch.Tensor] = None,
    ) -> Union[ImagePipelineOutput, List[torch.Tensor], torch.Tensor]:
        """Run (part of) the SDXL denoising loop, optionally starting from given latents.

        Arguments not listed below are forwarded to the parent pipeline's helpers
        (``check_inputs``, ``encode_prompt``, ``prepare_latents``, ...) unchanged.

        Args:
            from_timestep: Index of the first scheduler step to run. It is a position in
                the timestep list returned by ``retrieve_timesteps``, not a timestep value.
            till_timestep: Index one past the last scheduler step to run; ``None`` runs
                to the end of the schedule.
            start_latents: Latents to denoise instead of freshly sampled noise. They are
                moved to the execution device and to the dtype of the prepared latents.
            output_type: ``"latent"`` skips VAE decoding and returns the latents.
            callback_on_step_end: Optional callable invoked after every executed step as
                ``callback(pipe, step_index, timestep, callback_kwargs)``. ``step_index``
                is the position in the full schedule. If it returns a dict, the entries
                ``latents``, ``prompt_embeds``, ``add_text_embeds`` and ``add_time_ids``
                replace the corresponding loop state.
            callback_on_step_end_tensor_inputs: Names of loop tensors handed to the
                callback; defaults to ``["latents"]``.

        Returns:
            ``ImagePipelineOutput`` when ``return_dict`` is true. Otherwise the images
            (or latents) object itself is returned, not wrapped in a tuple.
        """
        # 0) Defaults and bookkeeping
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)
        # Build the default per call so no list object is shared between invocations.
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        # 1) Check inputs
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            None,  # callback_steps (unused here)
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        # 2) Batch size
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3) Encode prompts
        lora_scale = self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4) Timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas)

        # 5) Latents (still prepared when start_latents is given: the result fixes the dtype)
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6) Extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7) Added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 8) Denoising loop setup
        if (self.denoising_end is not None and isinstance(self.denoising_end, float)
                and 0 < self.denoising_end < 1):
            discrete_timestep_cutoff = int(round(self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps)))
            num_inference_steps = len([ts for ts in timesteps if ts >= discrete_timestep_cutoff])
            timesteps = timesteps[:num_inference_steps]

        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)

        # Handoff support
        if start_latents is not None:
            latents = start_latents.to(device=device, dtype=latents.dtype)

        # The part of the schedule this call is responsible for. step_indices keeps the
        # position of every selected step in the full schedule for the callback.
        step_indices = range(len(timesteps))[from_timestep:till_timestep]
        active_timesteps = timesteps[from_timestep:till_timestep]

        # One progress tick per executed timestep, so the bar always completes even when
        # only a sub-trajectory is run.
        with self.progress_bar(total=len(active_timesteps)) as progress_bar:
            for step_index, t in zip(step_indices, active_timesteps):
                # A callback may set the interrupt flag; remaining steps are then skipped.
                if self.interrupt:
                    continue

                # Prepare input (CFG)
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Predict noise
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # CFG
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # Step
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # Per-step callback (previously accepted but never invoked)
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # Plain loop on purpose: locals() inside a comprehension would not
                    # see this method's variables on older Python versions.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    # Tolerate callbacks that return None instead of a dict.
                    callback_outputs = callback_on_step_end(self, step_index, t, callback_kwargs) or {}

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)

                # Progress
                progress_bar.update()

        # Decode or return latents
        if output_type == "latent":
            image = latents
        else:
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype and torch.backends.mps.is_available():
                self.vae = self.vae.to(latents.dtype)

            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # Undo the temporary upcast; otherwise the next call would find the VAE in a
            # different dtype than the fp16 latents it is asked to decode.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return image
        return ImagePipelineOutput(images=image)


In [None]:

@torch.no_grad()
def diversity_distillation(
    prompt: str,
    negative_prompt: Optional[str],
    seed: int,
    pipe: SDXLHandoffPipeline,
    base_unet: UNet2DConditionModel,
    distilled_unet: UNet2DConditionModel,
    distilled_scheduler,
    base_guidance_scale: float = 5.0,
    distilled_guidance_scale: float = 0.0,
    num_inference_steps: int = 4,
    run_base_till: int = 1,
    output_type: str = "pil",
):
    """Run the base UNet for the first steps, then let the distilled UNet finish.

    Both phases use ``distilled_scheduler`` so they walk the same timestep list; the
    base phase covers steps ``[0, run_base_till)`` and the student the remainder.

    Args:
        prompt: Text prompt used in both phases.
        negative_prompt: Negative prompt used in both phases, or ``None`` for none.
        seed: Seed for the single generator shared by both phases.
        pipe: Handoff pipeline. Its ``scheduler`` and ``unet`` attributes are
            reassigned here; on return ``pipe.unet`` is ``distilled_unet``.
        base_unet: UNet used for the first ``run_base_till`` steps.
        distilled_unet: UNet used for the remaining steps.
        distilled_scheduler: Scheduler installed on ``pipe`` for both phases.
        base_guidance_scale: Guidance scale for the base phase.
        distilled_guidance_scale: Guidance scale for the student phase.
        num_inference_steps: Length of the full schedule.
        run_base_till: Number of leading steps run with ``base_unet``.
        output_type: Output type of the final call (``'pil'`` or ``'latent'``).

    Returns:
        Whatever the pipeline returns with ``return_dict=False`` for ``output_type``.
    """
    pipe.scheduler = distilled_scheduler  # schedule shape should match handoff trajectory

    # One generator for both phases: it seeds the initial noise and also any noise the
    # scheduler injects during the student steps, so a seed pins the whole trajectory
    # even with stochastic schedulers.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    # 1) Run the base to get partial latents
    pipe.unet = base_unet
    base_latents = pipe(
        prompt,
        guidance_scale=base_guidance_scale,
        till_timestep=run_base_till,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        output_type='latent',
        return_dict=False,
    )

    # 2) Continue from those latents with the distilled student
    pipe.unet = distilled_unet
    images = pipe(
        prompt,
        guidance_scale=distilled_guidance_scale,
        start_latents=base_latents,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        from_timestep=run_base_till,
        generator=generator,
        output_type=output_type,
        return_dict=False,
    )
    return images


In [None]:

# Every distillation_type accepted by load_model (besides None for base-only).
_DISTILLATION_TYPES = (
    'dmd',
    'lightning',
    'turbo',
    'lcm',
    'hyper',
    'hyper_1step',
    'pcm',
    'tcd',
    'flash',
)


def _load_lora_student_unet(basemodel_id, weights_dtype, lora_repo, adapter_name, fuse=False, **lora_kwargs):
    """Load the base pipeline, attach one LoRA adapter at weight 1.0 and return its UNet."""
    pipe_tmp = DiffusionPipeline.from_pretrained(basemodel_id, torch_dtype=weights_dtype)
    pipe_tmp.load_lora_weights(lora_repo, adapter_name=adapter_name, **lora_kwargs)
    pipe_tmp.set_adapters([adapter_name], adapter_weights=[1.0])
    if fuse:
        pipe_tmp.fuse_lora()
    return pipe_tmp.unet


def load_model(distillation_type: Optional[str] = None, weights_dtype=torch.bfloat16, device: str = 'cuda:0'):
    """Load the SDXL base pipeline and, optionally, a distilled student UNet.

    Args:
        distillation_type: ``None`` for the base model only, otherwise one of
            ``'dmd'``, ``'lightning'``, ``'turbo'``, ``'lcm'``, ``'hyper'``,
            ``'hyper_1step'``, ``'pcm'``, ``'tcd'`` or ``'flash'``.
        weights_dtype: dtype the pipeline and UNets are cast to.
        device: Device the pipeline and UNets are moved to.

    Returns:
        ``(pipe, base_unet, base_scheduler)`` when ``distillation_type`` is ``None``;
        otherwise ``(pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler)``.
        ``pipe.unet`` is the base UNet in both cases.

    Raises:
        ValueError: If ``distillation_type`` is not a known variant. This is checked
            before any weights are loaded.
    """
    # Fail fast: reject typos before downloading or loading the base weights.
    if distillation_type is not None and distillation_type not in _DISTILLATION_TYPES:
        raise ValueError(f"Unknown distillation type: {distillation_type}")

    basemodel_id = "stabilityai/stable-diffusion-xl-base-1.0"

    # Base model and scheduler
    base_unet = UNet2DConditionModel.from_pretrained(basemodel_id, subfolder="unet").to(device, weights_dtype)
    pipe = SDXLHandoffPipeline.from_pretrained(basemodel_id, unet=base_unet, torch_dtype=weights_dtype, use_safetensors=True)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    base_scheduler = pipe.scheduler

    if distillation_type is None:
        pipe.to(device).to(weights_dtype)
        return pipe, base_unet, base_scheduler

    # Distilled variants
    if distillation_type == 'dmd':
        repo_name = "tianweiy/DMD2"
        ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
        unet_config = UNet2DConditionModel.load_config(basemodel_id, subfolder="unet")
        distilled_unet = UNet2DConditionModel.from_config(unet_config).to(device, weights_dtype)
        # Deserialise on CPU so the checkpoint loads on any machine; load_state_dict
        # then copies the values into the UNet on its own device.
        state_dict = torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cpu", weights_only=True)
        distilled_unet.load_state_dict(state_dict)
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    elif distillation_type == 'lightning':
        repo = "ByteDance/SDXL-Lightning"
        ckpt = "sdxl_lightning_4step_unet.safetensors"
        unet_config = UNet2DConditionModel.load_config(basemodel_id, subfolder="unet")
        distilled_unet = UNet2DConditionModel.from_config(unet_config).to(device, weights_dtype)
        distilled_unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        distilled_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        base_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    elif distillation_type == 'turbo':
        distilled_unet = UNet2DConditionModel.from_pretrained("stabilityai/sdxl-turbo", subfolder="unet", torch_dtype=weights_dtype, variant="fp16").to(device, weights_dtype)
        distilled_scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
        base_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    elif distillation_type == 'lcm':
        distilled_unet = UNet2DConditionModel.from_pretrained("latent-consistency/lcm-sdxl", torch_dtype=weights_dtype).to(device, weights_dtype)
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config)

    elif distillation_type == 'hyper':
        # Hyper-SDXL 8-step CFG-preserved LoRA
        distilled_unet = _load_lora_student_unet(
            basemodel_id,
            weights_dtype,
            "ByteDance/Hyper-SD",
            "hyper-sdxl-8step",
            weight_name="Hyper-SDXL-8steps-CFG-lora.safetensors",
        )
        distilled_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    elif distillation_type == 'hyper_1step':
        distilled_unet = _load_lora_student_unet(
            basemodel_id,
            weights_dtype,
            "ByteDance/Hyper-SD",
            "hyper-sdxl-1step",
            weight_name="Hyper-SDXL-1step-lora.safetensors",
        )
        distilled_scheduler = TCDScheduler.from_config(pipe.scheduler.config)

    elif distillation_type == 'pcm':
        distilled_unet = _load_lora_student_unet(
            basemodel_id,
            weights_dtype,
            "wangfuyun/PCM_Weights",
            "pcm-lora",
            weight_name="pcm_sdxl_smallcfg_4step_converted.safetensors",
            subfolder="sdxl",
        )
        distilled_scheduler = DDIMScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            clip_sample=False,
            set_alpha_to_one=False,
        )

    elif distillation_type == 'tcd':
        distilled_unet = _load_lora_student_unet(basemodel_id, weights_dtype, "h1t/TCD-SDXL-LoRA", "tcd-lora")
        distilled_scheduler = TCDScheduler.from_config(pipe.scheduler.config)

    elif distillation_type == 'flash':
        # The only variant whose LoRA is fused into the UNet weights.
        distilled_unet = _load_lora_student_unet(
            basemodel_id,
            weights_dtype,
            "jasperai/flash-sdxl",
            "flash-sdxl",
            fuse=True,
            weight_name="pytorch_lora_weights.safetensors",
        )
        distilled_scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

    else:
        # Guards against _DISTILLATION_TYPES and this chain drifting apart.
        raise ValueError(f"Unknown distillation type: {distillation_type}")

    # Finalise device / dtype
    pipe.to(device).to(weights_dtype)
    distilled_unet = distilled_unet.to(device, dtype=weights_dtype)
    return pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler


def load_pipe(distillation_type: Optional[str] = None, weights_dtype=torch.bfloat16, device: str = 'cuda:0'):
    """Build the handoff pipeline via ``load_model`` and return only the pipeline.

    With a ``distillation_type`` the pipeline's scheduler is replaced by the distilled
    scheduler that ``load_model`` returns; with ``None`` the pipeline is returned as loaded.
    """
    pipe, *extras = load_model(distillation_type, weights_dtype, device)
    if distillation_type is not None:
        # extras is (base_unet, base_scheduler, distilled_unet, distilled_scheduler).
        # NOTE(review): the distilled UNet is dropped here, so pipe.unet stays the base
        # UNet while the scheduler is the distilled one — confirm this is intended.
        _, _, _, distilled_scheduler = extras
        pipe.scheduler = distilled_scheduler
    return pipe



## Minimal usage

Run the cell below to test; note that it loads the SDXL base and SDXL-Turbo weights onto `cuda:0`. This keeps defaults very light (4 steps) for quick checks, but you can scale up.


In [None]:

# Example: this cell is live code. It loads the SDXL base and SDXL-Turbo weights
# through load_model('turbo') and places them on `device`.
device = 'cuda:0' 
pipe, base_unet, base_sched, distilled_unet, distilled_sched = load_model('turbo', device=device)
# One base step (run_base_till=1) followed by the remaining steps on the Turbo UNet.
img = diversity_distillation(
     prompt="a older librarian",
     negative_prompt="male",
     seed=44,
     pipe=pipe,
     base_unet=base_unet,
     distilled_unet=distilled_unet,
     distilled_scheduler=distilled_sched,
     base_guidance_scale=5.0,
     distilled_guidance_scale=1.0,
     num_inference_steps=4,
     run_base_till=1,
)
# Last expression of the cell: shows the first returned image in the notebook.
img[0]


In [None]:
# --- Diversity Distillation: negatives vs not (full test cell) ---

import json
import os
from pathlib import Path
from tqdm import tqdm
from IPython.display import display as notebook_display

import gc
import torch

# ============== Config ==============
dry_run = False  # When True: process only 1 prompt and 1 seed, and display images instead of saving

# JSON list of {"prompt": "...", "negative_prompt": "...", "score": <optional>}.
# When the first entry has a "score", the 50 highest-scoring prompts are used;
# otherwise the first 50.
prompts_file = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/data/prompts_general.json"

# Per distilled model:
#   steps       -> num_inference_steps of the full schedule
#   base_cfg    -> guidance scale for the base phase (7.0 for every entry here)
#   student_cfg -> guidance scale for the student phase; varies by distillation type
# Each key is passed to load_model as its distillation_type.
model_configs = {
    'dmd':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.5},
    'turbo':     {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 0.0},
    'lightning': {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 0.0},
    'lcm':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.5},
    'hyper':     {'steps': 8, 'base_cfg': 7.0, 'student_cfg': 5.0},
    'pcm':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.0},
}

# Handoff settings
dd_run_base_till = 1  # run the base UNet for this many steps, then hand off to the student

# Fixed seeds; every prompt is generated once per seed (only the first one in dry-run)
fixed_seeds = [2025, 42, 1337]

# Output root; images go to <output_base_dir>/<model>/dd_noneg and .../dd_withneg
output_base_dir = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/results-dd-neg-compare"

# ===================================

def clear_cuda(*objs):
    """Collect Python garbage, then release cached CUDA memory.

    ``objs`` is accepted so existing call sites keep working, but a function cannot
    free its caller's variables: deleting a name here only removes this frame's own
    reference. Callers must drop their references first (``del x`` or ``x = None``)
    for the memory to become reclaimable.
    """
    # Drop the only references this frame holds, so it never keeps the objects alive.
    del objs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def write_or_display(image, filepath, title="preview"):
    """Show ``image`` in the notebook (dry run) or save it to ``filepath``.

    Reads the module-level ``dry_run`` flag. When saving, missing parent directories
    are created; a bare filename without a directory part is saved relative to the
    current working directory.

    Args:
        image: Object shown with ``notebook_display`` or saved through its ``save`` method.
        filepath: Destination path, used only when not in dry-run mode.
        title: Currently unused; kept so existing callers that pass it keep working.
    """
    if dry_run:
        notebook_display(image)
        return
    # Path.parent is "." for a bare filename, where os.makedirs("") would raise.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    image.save(filepath)

# Load prompts. JSON text is UTF-8, so do not depend on the platform default encoding.
with open(prompts_file, 'r', encoding='utf-8') as f:
    prompts_data = json.load(f)

# Select top 50 by "score" if present, else first 50.
# Only the first entry is inspected to decide whether scores exist; entries without a
# score sort as 0.
if isinstance(prompts_data, list) and len(prompts_data) > 0 and isinstance(prompts_data[0], dict) and 'score' in prompts_data[0]:
    prompts_data = sorted(prompts_data, key=lambda x: x.get('score', 0), reverse=True)[:50]
else:
    prompts_data = prompts_data[:50]

if dry_run:
    prompts_data = prompts_data[:1]  # only one prompt

# Choose seeds (all vs one when dry-run)
seeds_to_use = fixed_seeds[:1] if dry_run else fixed_seeds

print(f"Loaded {len(prompts_data)} prompt(s) from {prompts_file}")
print(f"[CONFIG] models={list(model_configs.keys())} | dry_run={dry_run} | seeds={seeds_to_use}")
print(f"[CONFIG] dd_run_base_till={dd_run_base_till}\n")

# ============== Generate ==============
# For every model: load it once, generate each prompt with and without its negative
# prompt for every seed, then release the model before loading the next one.
total_generated = 0

for model_name, cfgs in model_configs.items():
    steps      = cfgs["steps"]
    base_cfg   = cfgs["base_cfg"]
    student_cfg= cfgs["student_cfg"]

    print(f"[MODEL] {model_name} -> steps={steps}, base_cfg={base_cfg}, student_cfg={student_cfg}")

    # Load pipe + components for this distilled model
    try:
        load_result = load_model(model_name, device=("cuda:0" if torch.cuda.is_available() else "cpu"))
    except Exception as e:
        print(f"[ERROR] Failed to load model {model_name}: {e}")
        clear_cuda()
        continue

    # load_model returns 3 items for the base model and 5 for a distilled one; only
    # the 5-item form is usable for this comparison.
    if len(load_result) != 5:
        print(f"[WARN] Loader did not return distilled components for {model_name}; skipping.")
        load_result = None
        clear_cuda()
        continue
    pipe, base_unet, base_sched, distilled_unet, distilled_sched = load_result
    # Do not keep a second reference to the models in the tuple: it would keep this
    # model alive while the next one is being loaded.
    load_result = None

    model_output_dir = os.path.join(output_base_dir, model_name)
    noneg_dir  = os.path.join(model_output_dir, "dd_noneg")
    withneg_dir= os.path.join(model_output_dir, "dd_withneg")

    if not dry_run:
        os.makedirs(noneg_dir, exist_ok=True)
        os.makedirs(withneg_dir, exist_ok=True)

    print("\n[DIRS]")
    print(f"  base_output: {output_base_dir}")
    print(f"  model_root : {model_output_dir}")
    print(f"  noneg      : {noneg_dir}")
    print(f"  withneg    : {withneg_dir}")

    print(f"\n{'='*60}")
    print(f"Testing diversity distillation (negatives vs not): {model_name}")
    print("Displaying images (no writes)\n" if dry_run else f"Writing under: {model_output_dir}\n")
    print(f"{'='*60}")

    generated_count = 0

    for idx, item in enumerate(tqdm(prompts_data, desc=f"{model_name} progress")):
        prompt = item["prompt"]
        neg = item.get("negative_prompt", None)

        runs = [
            {"out_dir": noneg_dir,   "use_negative": False, "label": "noneg"},
            {"out_dir": withneg_dir, "use_negative": True,  "label": "withneg"},
        ]

        for run in runs:
            for seed in seeds_to_use:
                # Best effort per image: a failure is reported and the sweep continues.
                try:
                    # One call per (run, seed): the "noneg" run passes no negative
                    # prompt, the "withneg" run passes the prompt's own negative.
                    # The function builds its own torch.Generator from the seed & pipe.device.
                    images = diversity_distillation(
                        prompt=prompt,
                        seed=seed,
                        pipe=pipe,
                        base_unet=base_unet,
                        distilled_unet=distilled_unet,
                        distilled_scheduler=distilled_sched,
                        base_guidance_scale=base_cfg,
                        distilled_guidance_scale=student_cfg,
                        num_inference_steps=steps,
                        run_base_till=dd_run_base_till,
                        output_type="pil",
                        negative_prompt=(neg if run["use_negative"] else None)
                    )

                    # Take the first image when a sequence comes back, else use the value as is
                    image = images[0] if isinstance(images, (list, tuple)) else images
                    filename = f"{idx:04d}_{seed}_{run['label']}.png"
                    filepath = os.path.join(run["out_dir"], filename)

                    if dry_run:
                        print(f"[DISPLAY] {model_name} ({run['label']} | seed={seed})")
                        write_or_display(image, filepath, title=f"{model_name}:{run['label']}:{seed}")
                    else:
                        print(f"[WRITE] {model_name} -> {filepath}")
                        write_or_display(image, filepath)

                    generated_count += 1
                    total_generated += 1

                except Exception as e:
                    print(f"[ERROR] model={model_name}, prompt_idx={idx}, seed={seed}, run={run['label']} -> {e}")
                    continue

    print(f"\n✓ {model_name}: Generated and {'displayed' if dry_run else 'saved'} {generated_count} image(s)")

    # --- hard cleanup between models ---
    # Drop every reference to the model first; only then can garbage collection and
    # the CUDA cache flush actually return its memory.
    pipe = None
    base_unet = None
    distilled_unet = None
    clear_cuda()

print(f"\n{'='*60}")
print(f"✓ Total generated and {'displayed' if dry_run else 'saved'}: {total_generated} image(s)")
print(f"✓ Models tested: {list(model_configs.keys())}")
print(f"{'='*60}")


In [None]:
# ==========================================================
# Diversity Distillation benchmark: latency (s/image) & peak GPU MB
# Techniques: dd_noneg vs dd_withneg for each model in model_configs
# ==========================================================
import json, os, time, csv, statistics, gc
from pathlib import Path
from collections import defaultdict

import torch  # <-- ensure torch is imported

# The benchmark records CUDA peak memory, so it refuses to run without a GPU.
assert torch.cuda.is_available(), "CUDA is required for this benchmark."
torch.backends.cudnn.benchmark = True  # faster kernels with fixed shapes
device = "cuda"

# Handoff settings
dd_run_base_till = 1  # run the base UNet for this many steps, then hand off to the student

# Fixed seeds
fixed_seeds = [2025, 42, 1337]


# JSON list of prompt entries; the benchmark uses a single entry from it.
prompts_file = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/data/prompts_general.json"

# Per distilled model:
#   steps       -> length of the full schedule
#   base_cfg    -> guidance scale for the base phase (7.0 for every entry here)
#   student_cfg -> guidance scale for the student phase; varies by distillation type
model_configs = {
    'dmd':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.5},
    'turbo':     {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 0.0},
    'lightning': {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 0.0},
    'lcm':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.5},
    'hyper':     {'steps': 8, 'base_cfg': 7.0, 'student_cfg': 5.0},
    'pcm':       {'steps': 4, 'base_cfg': 7.0, 'student_cfg': 1.0},
}

# --------- Benchmark knobs ---------
WARMUP_RUNS   = 2      # untimed runs before measuring
MEASURE_RUNS  = 5      # timed runs after warm-up; keep this at 5 or more
HEIGHT        = None   # keep None to use pipeline defaults; set e.g. 1024 for apples-to-apples
WIDTH         = None
OUT_CSV = "/home/azureuser/cloudfiles/code/Users/Normalized-Attention-Guidance/results-dd-neg-compare/metrics_dd_latency_memory.csv"
# --------------------------------------

def _clear_cuda(*objs):
    """Collect Python garbage, then release cached CUDA memory.

    ``objs`` is accepted for call compatibility only. Deleting a name inside this
    function cannot free the caller's variables, so callers must drop their own
    references (``del x`` or ``x = None``) before calling.
    """
    # Release this frame's references so it never keeps the objects alive itself.
    del objs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

def _measure_one_call(func, *args, **kwargs):
    """Time one call of ``func`` and record the CUDA peak memory it caused.

    Whether CUDA is used is decided with ``torch.cuda.is_available()`` rather
    than by comparing a global device string, so a device spelled ``"cuda:0"``
    cannot silently disable synchronisation and memory tracking.

    Args:
        func: Callable to measure.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Tuple ``(seconds_per_image, peak_mb, out)``. ``out`` is whatever
        ``func`` returned. ``seconds_per_image`` is the wall-clock time divided
        by ``len(out)`` when ``out`` is a non-empty list or tuple, otherwise by
        1. ``peak_mb`` is ``torch.cuda.max_memory_allocated()`` in MiB, or
        ``0.0`` when CUDA is unavailable.
    """
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        # Start from a clean peak counter and make sure no earlier kernels are
        # still running when the clock starts.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = func(*args, **kwargs)
    if use_cuda:
        # CUDA work is asynchronous; wait for it before stopping the clock.
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # Determine number of images returned to compute seconds/image
    n_imgs = len(out) if isinstance(out, (list, tuple)) and out else 1
    elapsed = (t1 - t0) / n_imgs

    peak_mb = (torch.cuda.max_memory_allocated() / (1024**2)) if use_cuda else 0.0
    return elapsed, peak_mb, out

# ---- Load prompts (reusing your config above) ----
# Read as UTF-8 explicitly so non-ASCII prompt text does not depend on the
# platform's default encoding.
with open(prompts_file, 'r', encoding="utf-8") as f:
    prompts_data = json.load(f)

# Fail with a clear message instead of an opaque IndexError further down.
if not prompts_data:
    raise ValueError(f"No prompts found in {prompts_file}")

# When entries carry a 'score' (checked on the first entry only), benchmark the
# highest-scoring prompt; otherwise keep the file's order.
if isinstance(prompts_data, list) and isinstance(prompts_data[0], dict) and 'score' in prompts_data[0]:
    prompts_sorted = sorted(prompts_data, key=lambda x: x.get('score', 0), reverse=True)
else:
    prompts_sorted = prompts_data

# We'll benchmark using a single prompt to keep the measurement tight & repeatable.
bench_item = prompts_sorted[0]
BENCH_PROMPT = bench_item["prompt"]
# `or ""` also covers an explicit `"negative_prompt": null` in the JSON, which
# would otherwise crash the slice in the print below.
BENCH_NEG    = bench_item.get("negative_prompt") or ""

print(f"[BENCHMARK PROMPT]\n  prompt: {BENCH_PROMPT[:100]}{'...' if len(BENCH_PROMPT)>100 else ''}\n  negative_prompt: {BENCH_NEG[:100]}{'...' if BENCH_NEG and len(BENCH_NEG)>100 else ''}")

# One summary row (dict) per (model, technique) pair.
results = []
total_measured_images = 0

# The two techniques differ only in whether the benchmark's negative prompt is
# forwarded. The list is loop-invariant, so it is built once.
techniques = [
    {"label": "dd_noneg",   "use_negative": False},
    {"label": "dd_withneg", "use_negative": True},
]

# ============= Models outermost; techniques inner; hard cleanup per model =============
for model_name, cfgs in model_configs.items():
    steps       = cfgs["steps"]
    base_cfg    = cfgs["base_cfg"]
    student_cfg = cfgs["student_cfg"]

    print("\n" + "="*70)
    print(f"Model: {model_name} | steps={steps} | base_cfg={base_cfg} | student_cfg={student_cfg}")
    print("="*70)

    # Load once per model
    try:
        load_result = load_model(model_name, device=("cuda:0" if torch.cuda.is_available() else "cpu"))
    except Exception as e:
        # Best effort: a model that fails to load is reported and skipped so
        # the remaining models are still benchmarked.
        print(f"[ERROR] Failed to load model {model_name}: {e}")
        _clear_cuda()
        continue

    # Expect distilled components (pipe, base_unet, base_sched, distilled_unet, distilled_sched)
    try:
        pipe, base_unet, base_sched, distilled_unet, distilled_sched = load_result
    except (TypeError, ValueError):
        # ValueError: wrong number of items; TypeError: result is not iterable.
        print(f"[WARN] Loader did not return distilled components for {model_name}; skipping.")
        load_result = None  # release whatever was loaded before clearing caches
        _clear_cuda()
        continue
    # The unpacked names now own the components; drop the tuple so it does not
    # keep them alive past the cleanup below.
    load_result = None

    # Common kwargs to diversity_distillation
    base_kwargs = dict(
        pipe=pipe,
        base_unet=base_unet,
        distilled_unet=distilled_unet,
        distilled_scheduler=distilled_sched,
        base_guidance_scale=base_cfg,
        distilled_guidance_scale=student_cfg,
        num_inference_steps=steps,
        run_base_till=dd_run_base_till,
        output_type="pil",
        # height=HEIGHT,  # uncomment if your function/plumbing supports explicit size
        # width=WIDTH,
    )

    try:
        for tech in techniques:
            label = tech["label"]
            print(f"\n -> Technique: {label}")

            negative_prompt = BENCH_NEG if tech["use_negative"] else None

            # Warm-up (not recorded)
            for _ in range(WARMUP_RUNS):
                _measure_one_call(
                    diversity_distillation,
                    prompt=BENCH_PROMPT,
                    seed=1234,
                    negative_prompt=negative_prompt,
                    **base_kwargs,
                )

            # Measured runs
            latencies, peaks = [], []
            for i in range(MEASURE_RUNS):
                elapsed, peak_mb, _ = _measure_one_call(
                    diversity_distillation,
                    prompt=BENCH_PROMPT,
                    seed=42 + i,  # vary to avoid cache coincidences
                    negative_prompt=negative_prompt,
                    **base_kwargs,
                )
                latencies.append(elapsed)
                peaks.append(peak_mb)
                total_measured_images += 1
                print(f"    run {i+1}/{MEASURE_RUNS}: {elapsed:.3f}s per image, peak {peak_mb:.1f}MB")

            # Population standard deviation over the measured runs.
            lat_mean = statistics.fmean(latencies)
            lat_std  = statistics.pstdev(latencies) if len(latencies) > 1 else 0.0
            mem_mean = statistics.fmean(peaks)
            mem_std  = statistics.pstdev(peaks) if len(peaks) > 1 else 0.0

            results.append({
                "model": model_name,
                "technique": label,
                "steps": steps,
                "base_cfg": base_cfg,
                "student_cfg": student_cfg,
                "runs": MEASURE_RUNS,
                "latency_mean_s_per_image": round(lat_mean, 4),
                "latency_std_s_per_image": round(lat_std, 4),
                "peak_gpu_mb_mean": round(mem_mean, 1),
                "peak_gpu_mb_std": round(mem_std, 1),
            })
    finally:
        # Hard cleanup between models (keep stats honest). Runs even when a
        # measurement raises, so a failed run does not leave the model on GPU.
        try:
            pipe.to("cpu")
        except Exception as e:
            # Best effort: the references are dropped below regardless.
            print(f"[WARN] Could not move pipeline for {model_name} to CPU: {e}")
        # Drop every reference BEFORE collecting: gc / empty_cache cannot
        # reclaim memory that these names (or base_kwargs) still point to.
        base_kwargs = None
        pipe = base_unet = distilled_unet = base_sched = distilled_sched = None
        _clear_cuda()

# --------- Pretty print summary ----------
print("\n================ SUMMARY: Diversity Distillation (neg vs not) ================")
headers = ["Model", "Technique", "Steps", "Base CFG", "Student CFG", "Runs",
           "Latency (s/image)", "Peak GPU (MB)"]
table = [
    [
        row["model"], row["technique"], row["steps"],
        row["base_cfg"], row["student_cfg"], row["runs"],
        f'{row["latency_mean_s_per_image"]:.3f} ± {row["latency_std_s_per_image"]:.3f}',
        f'{row["peak_gpu_mb_mean"]:.1f} ± {row["peak_gpu_mb_std"]:.1f}',
    ]
    for row in results
]
try:
    from tabulate import tabulate
except ImportError:
    # tabulate is purely cosmetic. Fall back to a plain table so a missing
    # optional package cannot abort the cell before the CSV below is written.
    print(" | ".join(headers))
    for cells in table:
        print(" | ".join(str(cell) for cell in cells))
else:
    print(tabulate(table, headers=headers, tablefmt="github"))

# --------- Save CSV ----------
# dirname is "" when OUT_CSV is a bare filename, and os.makedirs("") raises.
out_dir = os.path.dirname(OUT_CSV)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
if results:
    with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
        # All rows share the same keys, so the first row defines the header.
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)
    print(f"\nSaved metrics to: {OUT_CSV}\nTotal measured images: {total_measured_images}")
else:
    print("\nNo results collected (unexpected).")
