# DQAR DiT-XL Benchmark

Test the DQARAttentionProcessor on DiT-XL-2-256 with real attention weight capture.

In [None]:
# Install dependencies (run once)
!pip install -q torch diffusers accelerate transformers

In [None]:
# Clone DQAR repo (or upload the src folder).
# NOTE: YOUR_USERNAME is a placeholder -- replace it with the account that
# hosts the repository before running this cell.
# If the clone fails (for example because DQAR/ already exists from an earlier
# run) its error output is discarded and the existing checkout is updated with
# `git pull` instead.
!git clone https://github.com/YOUR_USERNAME/DQAR.git 2>/dev/null || (cd DQAR && git pull)
import sys
# Put the checkout's src directory first on the import path so that the `dqar`
# package is importable from this notebook.
sys.path.insert(0, 'DQAR/src')

In [None]:
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
print(f"Using device: {device}")
if has_cuda:
    # Report the name and total memory of GPU 0.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Load DiT-XL-2-256
print("Loading DiT-XL-2-256...")
# Use half precision only on the GPU. On the CPU fallback float16 is poorly
# supported (slow, and some operators may be unavailable), so load in float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiTPipeline.from_pretrained(
    "facebook/DiT-XL-2-256",
    torch_dtype=model_dtype,
)
# Replace the pipeline's default scheduler with DPM-Solver multistep, built
# from the existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(device)
print(f"Loaded! Transformer has {len(pipe.transformer.transformer_blocks)} blocks")

In [None]:
# Import DQAR and patch the pipeline
# (requires DQAR/src to be on sys.path).
from dqar import DQARController, DQARConfig
from dqar.dit_wrapper import patch_dit_pipeline, get_dit_layer_count

# Patch pipeline with DQAR processors.
# NOTE(review): presumably this installs a DQARAttentionProcessor on every
# attention layer, as the message printed below suggests -- confirm in
# dqar.dit_wrapper.
pipe = patch_dit_pipeline(pipe)
# num_layers is reused later in the notebook to size each DQARController.
num_layers = get_dit_layer_count(pipe)
print(f"Patched {num_layers} attention layers with DQARAttentionProcessor")

In [None]:
# Helper function to create configs
def make_baseline_config():
    """Return a DQARConfig intended to rule out every attention reuse.

    Starts from a default ``DQARConfig`` and overrides three gate fields
    (``min_step``, ``entropy_threshold``, ``min_probability``) and two
    scheduler fields (``max_gap``, ``max_reuse_per_block``) with values chosen
    so that reuse should never be permitted. All other fields keep their
    defaults.
    """
    config = DQARConfig()

    gate = config.gate
    gate.min_step = 9999  # far above the 25/50 steps used in this notebook
    gate.entropy_threshold = 0.0
    gate.min_probability = 1.0

    scheduler = config.scheduler
    scheduler.max_gap = 0
    scheduler.max_reuse_per_block = 0

    return config

def make_static_config():
    """Return a DQARConfig intended to reuse cached attention whenever possible.

    Starts from a default ``DQARConfig`` and relaxes the gate (step 0 onwards,
    a very large entropy threshold, zero minimum probability, an SNR range of
    ``(0.0, 1e9)``, no cooldown) and allows up to 50 for both scheduler limits
    (``max_gap`` and ``max_reuse_per_block``).
    """
    config = DQARConfig()

    gate = config.gate
    gate.min_step = 0
    gate.entropy_threshold = 1e9
    gate.min_probability = 0.0
    gate.snr_range = (0.0, 1e9)
    gate.cooldown_steps = 0

    scheduler = config.scheduler
    scheduler.max_gap = 50
    scheduler.max_reuse_per_block = 50

    return config

def make_dqar_config(entropy_threshold=3.0):
    """Return a DQARConfig where reuse is gated by an entropy threshold.

    Args:
        entropy_threshold: Value stored in ``gate.entropy_threshold``.

    Returns:
        A ``DQARConfig`` whose other gate conditions are relaxed (step 0
        onwards, zero minimum probability, SNR range ``(0.0, 1e9)``, no
        cooldown) and whose scheduler limits ``max_gap`` and
        ``max_reuse_per_block`` are both 10.
    """
    config = DQARConfig()

    gate = config.gate
    gate.min_step = 0
    gate.entropy_threshold = entropy_threshold
    gate.min_probability = 0.0
    gate.snr_range = (0.0, 1e9)
    gate.cooldown_steps = 0

    scheduler = config.scheduler
    scheduler.max_gap = 10
    scheduler.max_reuse_per_block = 10

    return config

In [None]:
# Benchmark function
def benchmark_config(pipe, config, class_labels, num_steps=25, num_runs=3, seed=42,
                     warmup_runs=1):
    """Time image generation under one DQAR configuration.

    Each run builds a fresh ``DQARController`` and a freshly seeded generator,
    calls the pipeline once and records wall-clock time and the controller's
    reuse count. Relies on the notebook-level ``device`` and ``num_layers``.

    Args:
        pipe: Patched pipeline; called with ``class_labels``,
            ``num_inference_steps``, ``generator``, ``controller`` and
            ``output_type="pt"``.
        config: DQAR configuration handed to every controller.
        class_labels: Class label list forwarded to the pipeline.
        num_steps: Number of inference steps per run.
        num_runs: Number of timed runs; must be at least 1.
        seed: Base seed; timed run ``i`` (0-based) uses ``seed + i``.
        warmup_runs: Untimed runs executed before timing starts, so that the
            first timed run does not absorb one-time start-up costs.

    Returns:
        dict with ``avg_time``/``std_time`` (seconds) and
        ``avg_reuse``/``std_reuse`` computed with ``np.mean``/``np.std`` over
        the timed runs, plus ``image``: the first image of the last run's
        output, or ``None`` if the output has no ``images`` attribute.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    # Only touch the CUDA API when the notebook is actually running on a GPU;
    # torch.cuda.synchronize() fails on CPU-only machines.
    use_cuda = torch.cuda.is_available() and str(device).startswith("cuda")

    def _sync():
        # CUDA kernels launch asynchronously; wait for them so the wall-clock
        # measurement covers the real work.
        if use_cuda:
            torch.cuda.synchronize()

    def _run_once(run_seed):
        # Fresh controller for every run.
        controller = DQARController(num_layers=num_layers, config=config)
        generator = torch.Generator(device=device).manual_seed(run_seed)

        _sync()
        start = time.perf_counter()
        output = pipe(
            class_labels=class_labels,
            num_inference_steps=num_steps,
            generator=generator,
            controller=controller,
            output_type="pt",
        )
        _sync()
        elapsed = time.perf_counter() - start

        return output, elapsed, controller.get_reuse_count()

    # Untimed warm-up passes; their results are discarded.
    for _ in range(warmup_runs):
        _run_once(seed)

    times = []
    reuse_counts = []
    output = None

    for run in range(num_runs):
        output, elapsed, reuse = _run_once(seed + run)
        times.append(elapsed)
        reuse_counts.append(reuse)
        print(f"  Run {run+1}: {elapsed:.2f}s, reuse={reuse}")

    return {
        "avg_time": np.mean(times),
        "std_time": np.std(times),
        "avg_reuse": np.mean(reuse_counts),
        "std_reuse": np.std(reuse_counts),
        "image": output.images[0] if hasattr(output, 'images') else None,
    }

In [None]:
# Benchmark every configuration with the same label, step count and run count.
class_labels = [207]  # Golden retriever
num_steps = 25
num_runs = 3

rule = "=" * 50  # separator printed above and below each section title

print(rule)
print("BASELINE (no reuse)")
print(rule)
baseline = benchmark_config(pipe, make_baseline_config(), class_labels, num_steps, num_runs)

print("\n" + rule)
print("STATIC (always reuse)")
print(rule)
static = benchmark_config(pipe, make_static_config(), class_labels, num_steps, num_runs)

print("\n" + rule)
print("DQAR (entropy threshold=3.0)")
print(rule)
dqar_3 = benchmark_config(pipe, make_dqar_config(3.0), class_labels, num_steps, num_runs)

print("\n" + rule)
print("DQAR (entropy threshold=4.0)")
print(rule)
dqar_4 = benchmark_config(pipe, make_dqar_config(4.0), class_labels, num_steps, num_runs)

print("\n" + rule)
print("DQAR (entropy threshold=5.0)")
print(rule)
dqar_5 = benchmark_config(pipe, make_dqar_config(5.0), class_labels, num_steps, num_runs)

In [None]:
# Print summary table: one row per configuration, with speedup relative to
# the baseline's average time.
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"{'Config':<20} {'Time (s)':<15} {'Reuse Events':<15} {'Speedup':<10}")
print("-"*60)

summary_rows = [
    ("Baseline", baseline),
    ("Static", static),
    ("DQAR τ=3.0", dqar_3),
    ("DQAR τ=4.0", dqar_4),
    ("DQAR τ=5.0", dqar_5),
]
for label, stats in summary_rows:
    # Pad every cell to the same width as the header so the columns line up
    # regardless of how many digits each value has.
    time_cell = f"{stats['avg_time']:.2f} ± {stats['std_time']:.2f}"
    reuse_cell = f"{stats['avg_reuse']:.0f}"
    speedup = baseline['avg_time'] / stats['avg_time']
    print(f"{label:<20} {time_cell:<15} {reuse_cell:<15} {speedup:.2f}x")

In [None]:
# Plot results: average inference time (with std error bars) and average
# reuse count for each benchmarked configuration.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

configs = ['Baseline', 'Static', 'DQAR\nτ=3.0', 'DQAR\nτ=4.0', 'DQAR\nτ=5.0']
times = [baseline['avg_time'], static['avg_time'], dqar_3['avg_time'], dqar_4['avg_time'], dqar_5['avg_time']]
time_errs = [baseline['std_time'], static['std_time'], dqar_3['std_time'], dqar_4['std_time'], dqar_5['std_time']]
reuse = [baseline['avg_reuse'], static['avg_reuse'], dqar_3['avg_reuse'], dqar_4['avg_reuse'], dqar_5['avg_reuse']]

colors = ['#4C72B0', '#55A868', '#C44E52', '#8172B2', '#CCB974']
x = np.arange(len(configs))

# Plot 1: Inference Time
bars1 = axes[0].bar(x, times, yerr=time_errs, capsize=5, color=colors)
axes[0].set_ylabel('Time (seconds)', fontsize=12)
# Take the step count from the benchmark settings so the title cannot drift
# from the data being plotted.
axes[0].set_title(f'Inference Time ({num_steps} steps)', fontsize=13)
axes[0].set_xticks(x)
axes[0].set_xticklabels(configs)
axes[0].bar_label(bars1, fmt='%.2f')

# Plot 2: Reuse Events
bars2 = axes[1].bar(x, reuse, color=colors)
axes[1].set_ylabel('Reuse Events', fontsize=12)
axes[1].set_title('Attention Reuse Count', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels(configs)
axes[1].bar_label(bars2, fmt='%.0f')

fig.suptitle('DQAR DiT-XL-2-256 Benchmark', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('dit_xl_benchmark_results.png', dpi=150)
plt.show()

In [None]:
# Show sample images: generate the same class with the same seed under the
# baseline and the DQAR (threshold 4.0) configurations and display them side
# by side.
from PIL import Image

# Generate comparison images
print("Generating comparison images...")

# Baseline image
controller_base = DQARController(num_layers=num_layers, config=make_baseline_config())
generator = torch.Generator(device=device).manual_seed(42)
img_baseline = pipe(
    class_labels=[207],
    num_inference_steps=50,
    generator=generator,
    controller=controller_base,
).images[0]

# DQAR image (same seed)
controller_dqar = DQARController(num_layers=num_layers, config=make_dqar_config(4.0))
generator = torch.Generator(device=device).manual_seed(42)
img_dqar = pipe(
    class_labels=[207],
    num_inference_steps=50,
    generator=generator,
    controller=controller_dqar,
).images[0]

# Read each counter once and use the same numbers in the log and the titles.
base_reuse = controller_base.get_reuse_count()
dqar_reuse = controller_dqar.get_reuse_count()
print(f"Baseline reuse: {base_reuse}")
print(f"DQAR reuse: {dqar_reuse}")

# Display side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(img_baseline)
# Show the measured count rather than a hard-coded 0, so an unexpected reuse
# in the baseline is visible in the figure.
axes[0].set_title(f'Baseline (reuse={base_reuse})', fontsize=12)
axes[0].axis('off')

axes[1].imshow(img_dqar)
axes[1].set_title(f'DQAR τ=4.0 (reuse={dqar_reuse})', fontsize=12)
axes[1].axis('off')

fig.suptitle('DiT-XL-2-256: Golden Retriever (class 207)', fontsize=14)
plt.tight_layout()
plt.savefig('dit_xl_comparison_results.png', dpi=150)
plt.show()

In [None]:
# Entropy threshold sweep: one timed generation per threshold, all with the
# same seed, recording wall-clock time and reuse count.
print("Running entropy threshold sweep...")
thresholds = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0]
sweep_results = []

for thresh in thresholds:
    cfg = make_dqar_config(thresh)
    controller = DQARController(num_layers=num_layers, config=cfg)
    generator = torch.Generator(device=device).manual_seed(42)

    # Synchronize only on GPU: torch.cuda.synchronize() fails on CPU-only
    # machines, and on GPU it makes the timer cover the queued kernels.
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()

    # Use the benchmark's label and step count so the sweep is comparable with
    # the baseline it is plotted against.
    output = pipe(
        class_labels=class_labels,
        num_inference_steps=num_steps,
        generator=generator,
        controller=controller,
        output_type="pt",
    )

    if device == "cuda":
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    reuse_count = controller.get_reuse_count()
    sweep_results.append({
        "threshold": thresh,
        "time": elapsed,
        "reuse": reuse_count,
    })
    print(f"τ={thresh:.1f}: time={elapsed:.2f}s, reuse={reuse_count}")

# Plot sweep: reuse count and inference time against the entropy threshold,
# each with the baseline average drawn as a dashed reference line.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

threshs = [r['threshold'] for r in sweep_results]
times = [r['time'] for r in sweep_results]
reuses = [r['reuse'] for r in sweep_results]

# (axis, y-values, baseline reference, y-label, title) for each panel.
panels = [
    (axes[0], reuses, baseline['avg_reuse'], 'Reuse Events', 'Threshold vs Reuse Count'),
    (axes[1], times, baseline['avg_time'], 'Time (seconds)', 'Threshold vs Inference Time'),
]
for ax, ys, baseline_y, ylabel, title in panels:
    ax.plot(threshs, ys, 'o-', color='#C44E52', linewidth=2, markersize=8)
    ax.axhline(y=baseline_y, color='#4C72B0', linestyle='--', label='Baseline')
    ax.set_xlabel('Entropy Threshold', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3)
    ax.legend()

fig.suptitle('DQAR Entropy Threshold Sweep (DiT-XL)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('dit_xl_sweep_results.png', dpi=150)
plt.show()