# Stable Diffusion 1.5: Capturing Layer Representations Example

In [None]:
import torch
from diffusers import StableDiffusionPipeline

from utils.reprezentation import LayerPath, capture_layer_representations

## Setup Model

In [None]:
# Checkpoint identifier passed to `from_pretrained` below.
model_id = "sd-legacy/stable-diffusion-v1-5"

# Pick the device and the matching precision together:
# float16 when CUDA is available, float32 on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

In [None]:
# Load the pipeline weights in the chosen precision. Passing
# `safety_checker=None` loads the pipeline without a safety checker component.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    safety_checker=None,
)

# Move the loaded pipeline onto the selected device.
pipe = pipe.to(device)

In [None]:
# Text prompt handed to the capture helper.
prompt = "A close-up photo of a futuristic clock with glowing numbers, detailed"

# Generator on the selected device, seeded with a fixed value (42).
generator = torch.Generator(device=device)
generator.manual_seed(42)

# Layers to record, grouped by role. The groups are concatenated below in
# this order, which fixes the order of `layers_to_capture`.

# Text conditioning
text_layers = [
    LayerPath.TEXT_EMBEDDING_FINAL,      # What U-Net reads from prompt
]

# Critical attention layers
attention_layers = [
    LayerPath.UNET_MID_ATT,              # Global composition
    LayerPath.UNET_DOWN_2_ATT_0,         # Object-level alignment (16x16)
    LayerPath.UNET_UP_1_ATT_2,           # Fine detail alignment (32x32 -> 64x64)
]

# ResNet features for comparison
resnet_layers = [
    LayerPath.UNET_DOWN_1_RES_0,         # Visual features before text (32x32)
    LayerPath.UNET_MID_RES_1,            # Post-attention features at bottleneck
    LayerPath.UNET_UP_0_RES_2,           # Final features before output
]

# Time conditioning
time_layers = [
    LayerPath.UNET_TIME_EMBED,           # Timestep conditioning vector
]

layers_to_capture = [
    *text_layers,
    *attention_layers,
    *resnet_layers,
    *time_layers,
]

In [None]:
# Arguments for the capture helper, collected in one place so the run
# configuration is easy to inspect or tweak.
capture_kwargs = {
    "pipe": pipe,
    "prompt": prompt,
    "layer_paths": layers_to_capture,
    "generator": generator,
    "num_inference_steps": 50,
    "guidance_scale": 7.5,
}

# NOTE(review): presumably returns one tensor per entry of `layer_paths`, in
# the same order -- the reporting cell below pairs results with
# `layers_to_capture` by position. Confirm against utils.reprezentation.
captured_tensors = capture_layer_representations(**capture_kwargs)

In [None]:
# Print a short report for every captured tensor: member name, path value,
# shape and memory footprint.
print("Captured Representations:")
print("=" * 80)
for position, rep in enumerate(captured_tensors, start=1):
    # Results are paired with the requested layers by position.
    layer = layers_to_capture[position - 1]
    # Keep only the text after the last '.' of the member's string form.
    short_name = str(layer).split('.')[-1]
    print(f"\n{position}. {short_name}")
    print(f"   Path: {layer.value}")
    print(f"   Shape: {rep.shape}")
    # Bytes per element times element count, converted to MiB for display.
    size_bytes = rep.element_size() * rep.nelement()
    print(f"   Memory: {size_bytes / 1024 / 1024:.2f} MB")