In [3]:
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Hugging Face Hub id of the checkpoint handed to `from_pretrained` below.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Value forwarded as `guidance_scale` when the pipeline is called.
GUIDANCE_SCALE = 7.5

# -----------------------------------------------------------------------------
# GLOBALS TO STORE ATTENTION
# -----------------------------------------------------------------------------
# Module-level store the patched processors append to.
# "cross" receives maps from calls that got `encoder_hidden_states`,
# "self" receives maps from calls that did not.
ATTENTION_LOGS = {
    "cross": [],
    "self": [],
}

# -----------------------------------------------------------------------------
# PATCH PROCESSOR TO SAVE ATTENTION
# -----------------------------------------------------------------------------
def make_patched_processor(base_cls):
    """Build a subclass of ``base_cls`` that also records attention maps.

    The returned class behaves exactly like ``base_cls`` when called (the
    original result is returned untouched), and additionally recomputes
    ``softmax(q @ k^T / sqrt(head_dim))`` from ``attn.to_q`` / ``attn.to_k``,
    averages it over heads and appends the CPU copy to ``ATTENTION_LOGS``.

    The factory is idempotent: if ``base_cls`` was itself produced by this
    function it is returned unchanged. Without that guard, patching an
    already-patched module stacks two recording layers and every attention map
    is computed and logged twice.

    Args:
        base_cls: attention-processor class to extend. Its ``__call__`` must
            accept ``(attn, hidden_states, encoder_hidden_states,
            attention_mask, **kwargs)``.

    Returns:
        A processor class (not an instance).
    """
    if getattr(base_cls, "_records_attention", False):
        return base_cls

    class PatchedProcessor(base_cls):
        # Marker used by the idempotency guard above.
        _records_attention = True

        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
            result = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)

            # Recalculate the attention matrix purely for logging; no autograd
            # graph is needed because only a detached copy is kept.
            # NOTE(review): `attention_mask` and any cross-attention input
            # normalisation done by the base processor are not replayed here.
            with torch.no_grad():
                q = attn.to_q(hidden_states)
                is_cross = encoder_hidden_states is not None
                k = attn.to_k(encoder_hidden_states if is_cross else hidden_states)

                # The 3-way unpack requires q to be (batch, tokens, channels).
                bsz, seqlen, _ = q.shape
                num_heads = attn.heads
                head_dim = q.shape[-1] // num_heads
                scale = 1 / head_dim ** 0.5

                # (batch, tokens, channels) -> (batch, heads, tokens, head_dim)
                q = q.view(bsz, seqlen, num_heads, head_dim).transpose(1, 2)
                k = k.view(bsz, -1, num_heads, head_dim).transpose(1, 2)

                scores = torch.matmul(q, k.transpose(-2, -1)) * scale
                probs = torch.nn.functional.softmax(scores, dim=-1)
                attn_map = probs.mean(1).detach().cpu()  # average over heads

            ATTENTION_LOGS["cross" if is_cross else "self"].append(attn_map)

            return result

    return PatchedProcessor

# -----------------------------------------------------------------------------
# LOAD PIPELINE
# -----------------------------------------------------------------------------
# Download/load the checkpoint. Half precision is requested only when running
# on CUDA; the safety checker is explicitly switched off.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_NAME,
    safety_checker=None,
    torch_dtype=torch.float32 if DEVICE != "cuda" else torch.float16,
)
# Move every pipeline component to the selected device.
pipe = pipe.to(DEVICE)

# -----------------------------------------------------------------------------
# REPLACE ALL PROCESSORS WITH PATCHED ONES
# -----------------------------------------------------------------------------
# Walk every sub-module of the UNet and swap the processor of each Attention
# layer for a recording subclass of whatever processor class it already uses.
for attn_module in pipe.unet.modules():
    if not isinstance(attn_module, Attention):
        continue
    recording_cls = make_patched_processor(type(attn_module.processor))
    attn_module.processor = recording_cls()

# -----------------------------------------------------------------------------
# RUN GENERATION
# -----------------------------------------------------------------------------



Loading pipeline components...: 100%|██████████| 6/6 [01:14<00:00, 12.43s/it]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
100%|██████████| 5/5 [00:02<00:00,  2.13it/s]

🎯 Captured 96 cross-attn maps
🎯 Captured 96 self-attn maps
cross_attns[0] shape: torch.Size([2, 4096, 77])





In [26]:
PROMPT = "A wisdom tooth"
NUM_STEPS = 50
SEED = 0  # fixed so that re-running this cell reproduces the same image

# The patched processors only ever append to ATTENTION_LOGS. Empty it first so
# that list positions refer to this run and not to maps left by earlier runs.
for _logged_maps in ATTENTION_LOGS.values():
    _logged_maps.clear()

with torch.no_grad():
    output = pipe(
        PROMPT,
        num_inference_steps=NUM_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        # A dedicated generator keeps the run reproducible without touching
        # torch's global RNG state.
        generator=torch.Generator(device="cpu").manual_seed(SEED),
    )

  2%|▏         | 1/50 [00:00<00:37,  1.31it/s]

100%|██████████| 50/50 [00:19<00:00,  2.51it/s]


In [27]:
import matplotlib.pyplot as plt
import torch.nn.functional as F

def visualize_tokens_layers(attn_logs, image, tokenizer, prompt, layers, batch_index=-1):
    """Plot per-token attention heatmaps over ``image`` for selected log entries.

    One row is drawn per prompt token. Column 0 shows the bare image; each
    further column overlays the attention that the image positions pay to that
    token in ``attn_logs[layer]``.

    Args:
        attn_logs: sequence of attention maps indexed as
            ``attn_logs[layer][batch, query, key]``.
        image: image accepted by ``Axes.imshow``; heatmaps are upsampled to
            512x512, so the overlay presumably expects a 512x512 image.
        tokenizer: callable tokenizer providing ``convert_ids_to_tokens``.
        prompt: text whose tokens label the rows and select the key columns.
        layers: indices into ``attn_logs`` to draw, one column each.
        batch_index: which batch entry of each map to show. Defaults to the
            last entry: with classifier-free guidance diffusers'
            ``StableDiffusionPipeline`` batches ``[unconditional, prompt]``,
            so entry 0 (the previous hard-coded choice) is the empty-prompt
            branch and carries no information about ``prompt``'s tokens. Pass
            ``batch_index=0`` for pipelines that put the prompt branch first.

    Raises:
        RuntimeError: if a map's query length is not a perfect square (the
            reshape to a square grid fails).
    """
    # Get tokenized prompt
    token_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    num_tokens = len(tokens)
    num_cols = len(layers) + 1

    # squeeze=False keeps `axs` two-dimensional even for a single row or a
    # single column, so the indexing below needs no special cases.
    fig, axs = plt.subplots(
        num_tokens, num_cols, figsize=(4 * num_cols, 4 * num_tokens), squeeze=False
    )

    for i, token in enumerate(tokens):
        # First column = image only
        ax_img = axs[i][0]
        ax_img.imshow(image)
        ax_img.set_title(f"Token {i}: {token}")
        ax_img.axis("off")

        for j, layer in enumerate(layers):
            attn = attn_logs[layer][batch_index]  # shape: [Q, K]
            # Maps captured from an fp16 model are half precision; upcast
            # because bilinear interpolation of half tensors on CPU is not
            # supported by every torch build.
            heat = attn[:, i].float()  # [Q] for token i

            side = int(heat.shape[0] ** 0.5)
            heatmap = heat.reshape(1, 1, side, side)
            heatmap = F.interpolate(heatmap, size=(512, 512), mode="bilinear", align_corners=False)[0, 0]

            # Min-max normalise; a perfectly flat map stays all-zero instead
            # of becoming NaN through a 0/0 division.
            heatmap = heatmap - heatmap.min()
            peak = heatmap.max()
            if peak > 0:
                heatmap = heatmap / peak

            ax = axs[i][j + 1]
            ax.imshow(image)
            ax.imshow(heatmap.numpy(), cmap="magma", alpha=0.5)
            ax.set_title(f"Layer {layer}")
            ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
import os
import json
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import sys
import matplotlib.pyplot as plt
import torch.nn.functional as F
from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
    DiffusionPipeline,
    DDIMScheduler
)
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor
from huggingface_hub import login

# === CONFIG ===
# All project paths are resolved relative to this directory.
BASE_DIR = "../"
METADATA_PATH = os.path.join(BASE_DIR, "editing_common", "editing_metadata.json")
OUTPUT_ROOT = os.path.join(BASE_DIR, "attention_maps")
os.makedirs(OUTPUT_ROOT, exist_ok=True)
CACHE_DIR = os.path.join(BASE_DIR, "cache")
# Prefer the HF_TOKEN environment variable so a real token never has to be
# written into the notebook; the placeholder literal remains the fallback for
# anyone who still edits it in place.
HF_TOKEN = os.environ.get("HF_TOKEN", "Your huggingface token")
DIFFUSION_STEPS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === GLOBAL ATTENTION STORE ===
# Filled by the patched processors: "cross" for calls that received
# `encoder_hidden_states`, "self" for calls that did not.
ATTENTION_LOGS = {"cross": [], "self": []}

def make_patched_processor(base_cls):
    """Build a subclass of ``base_cls`` that also records attention maps.

    Calling an instance returns exactly what ``base_cls`` returns. As a side
    effect it recomputes ``softmax(q @ k^T / sqrt(head_dim))`` from
    ``attn.to_q`` / ``attn.to_k``, averages over heads and appends the CPU copy
    to ``ATTENTION_LOGS``.

    The factory is idempotent: a class that was already produced by it is
    returned as-is, so patching the same pipeline twice does not make every
    map get computed and logged twice.

    Args:
        base_cls: attention-processor class to extend.

    Returns:
        A processor class (not an instance).
    """
    if getattr(base_cls, "_records_attention", False):
        return base_cls

    class PatchedProcessor(base_cls):
        # Marker checked by the idempotency guard above.
        _records_attention = True

        def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
            result = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs)
            # Logging-only recomputation; only a detached copy is kept, so no
            # autograd graph is needed.
            # NOTE(review): `attention_mask` is not applied to the logged map.
            with torch.no_grad():
                q = attn.to_q(hidden_states)
                is_cross = encoder_hidden_states is not None
                k = attn.to_k(encoder_hidden_states if is_cross else hidden_states)
                # The 3-way unpack requires q to be (batch, tokens, channels).
                bsz, seqlen, _ = q.shape
                num_heads = attn.heads
                head_dim = q.shape[-1] // num_heads
                scale = 1 / head_dim ** 0.5
                # (batch, tokens, channels) -> (batch, heads, tokens, head_dim)
                q = q.view(bsz, seqlen, num_heads, head_dim).transpose(1, 2)
                k = k.view(bsz, -1, num_heads, head_dim).transpose(1, 2)
                scores = torch.matmul(q, k.transpose(-2, -1)) * scale
                probs = torch.nn.functional.softmax(scores, dim=-1)
                attn_map = probs.mean(1).detach().cpu()  # average over heads
            ATTENTION_LOGS["cross" if is_cross else "self"].append(attn_map)
            return result

    return PatchedProcessor

def patch_attention(pipe):
    """Give every ``Attention`` layer in ``pipe.unet`` a recording processor.

    Each layer keeps its current processor class as the base; only a recording
    subclass built by ``make_patched_processor`` is instantiated on top of it.
    The pipeline is modified in place and nothing is returned.
    """
    attention_layers = (
        layer for layer in pipe.unet.modules() if isinstance(layer, Attention)
    )
    for layer in attention_layers:
        recording_cls = make_patched_processor(type(layer.processor))
        layer.processor = recording_cls()

# === PIPELINE LOADER ===
def load_ip2p_pipeline(model_id):
    """Load an InstructPix2Pix pipeline, patch its attention and move it to ``DEVICE``.

    Args:
        model_id: identifier passed straight to ``from_pretrained``.

    Returns:
        The pipeline after ``.to(DEVICE)``, using an Euler-ancestral scheduler
        built from the checkpoint's scheduler config and with the safety
        checker disabled.
    """
    # Half precision only on CUDA; full precision everywhere else.
    on_gpu = DEVICE.type == "cuda"
    ip2p = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        model_id,
        safety_checker=None,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
    )

    ip2p.scheduler = EulerAncestralDiscreteScheduler.from_config(ip2p.scheduler.config)

    # Install the recording processors before moving the pipeline.
    patch_attention(ip2p)

    ip2p = ip2p.to(DEVICE)
    return ip2p

# === GENERATION ===
def _token_grid_shape(num_queries):
    """Return ``(H, W)`` with ``H * W == num_queries`` and ``H`` the largest divisor not above sqrt, or ``None`` if empty."""
    if num_queries <= 0:
        return None
    height = max(int(num_queries ** 0.5), 1)
    # height == 1 always divides, so the loop terminates.
    while num_queries % height != 0:
        height -= 1
    return height, num_queries // height


def _average_token_heatmap(cross_maps, token_index, target_shape):
    """Average one token's cross-attention over all logged maps.

    Every map is reshaped to a 2-D grid, resized to ``target_shape`` and
    averaged; the mean is min-max normalised to [0, 1].

    Returns:
        A ``target_shape`` float tensor, or ``None`` when no logged map has a
        key position for ``token_index``.
    """
    resized = []
    for attn in cross_maps:
        # Prompts longer than the logged key length have no column to read.
        if token_index >= attn.shape[-1]:
            continue
        # Batch entry 0, as in the original code.
        # NOTE(review): presumably the text-conditioned branch for this
        # pipeline — confirm against the installed diffusers version.
        # Upcast: maps logged from an fp16 model are half precision, and CPU
        # bilinear interpolation of half tensors is not available everywhere.
        token_map = attn[0, :, token_index].float()  # [Q]
        grid = _token_grid_shape(token_map.shape[0])
        if grid is None:
            continue
        heat = token_map.reshape(1, 1, *grid)
        resized.append(F.interpolate(heat, size=target_shape, mode="bilinear", align_corners=False))

    if not resized:
        return None

    avg_map = torch.stack(resized).mean(dim=0)[0, 0]
    # Min-max normalise; a flat map stays all-zero instead of turning into NaN.
    avg_map = avg_map - avg_map.min()
    peak = avg_map.max()
    if peak > 0:
        avg_map = avg_map / peak
    return avg_map


def _save_token_figure(heatmap, token, source_image, result_image, path):
    """Write the three-panel figure (map, source overlay, result overlay) to ``path``."""
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    try:
        im0 = axs[0].imshow(heatmap, cmap="magma")
        axs[0].set_title(f"Attention Map\nToken: '{token}'")
        axs[0].axis("off")
        fig.colorbar(im0, ax=axs[0], fraction=0.046, pad=0.04)

        axs[1].imshow(source_image)
        axs[1].imshow(heatmap, cmap="magma", alpha=0.5)
        axs[1].set_title(f"Original + Attention\nToken: '{token}'")
        axs[1].axis("off")

        axs[2].imshow(result_image)
        axs[2].imshow(heatmap, cmap="magma", alpha=0.5)
        axs[2].set_title(f"Edited + Attention\nToken: '{token}'")
        axs[2].axis("off")

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure; many are created per sample.
        plt.close(fig)


# === GENERATION ===
def generate_with_ip2p(pipe, sample, output_dir, seeds=1, igs_values=(1.6,), gs_values=(7.5,), constant_seed=False):
    """Edit one sample with InstructPix2Pix and save per-token attention figures.

    For every (seed, image guidance scale, guidance scale) combination the
    edited image is written to ``{output_dir}/{id}_{idx}.png`` and, for each
    prompt token, a figure ``{id}_{idx}_avg_attn_token{tid}.png`` showing the
    cross-attention averaged over everything logged during that run.

    Args:
        pipe: pipeline whose attention processors append to ``ATTENTION_LOGS``.
        sample: mapping with keys ``"id"``, ``"prompt"`` and
            ``"previous_image"`` (path relative to ``BASE_DIR``; backslashes
            are converted to forward slashes).
        output_dir: directory for all outputs; created if missing.
        seeds: number of seeds to run.
        igs_values: iterable of ``image_guidance_scale`` values.
        gs_values: iterable of ``guidance_scale`` values.
        constant_seed: use seed 0 for every run instead of random seeds.

    Returns:
        None. If the source image does not exist a message is printed and the
        sample is skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    img_id = sample["id"]
    prompt = sample["prompt"]
    img_path = os.path.join(BASE_DIR, sample["previous_image"].replace("\\", "/"))

    if not os.path.exists(img_path):
        print(f"Image not found for ID {img_id}: {img_path}")
        return

    # convert() returns an independent copy, so the file handle can be closed.
    with Image.open(img_path) as opened:
        original_image = opened.convert("RGB")
    image = original_image.resize((512, 512))
    seed_list = [0] * seeds if constant_seed else [random.randint(0, 9999) for _ in range(seeds)]

    # The prompt does not change between runs, so tokenise once.
    tokens = pipe.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
    decoded = pipe.tokenizer.convert_ids_to_tokens(tokens)

    target_shape = (64, 64)  # Resize all attention maps to this
    idx = 0

    for seed in seed_list:
        for igs in igs_values:
            for gs in gs_values:
                # Start from an empty log so maps left by an earlier or
                # interrupted run cannot leak into this run's average.
                ATTENTION_LOGS["cross"].clear()
                ATTENTION_LOGS["self"].clear()
                try:
                    generator = torch.manual_seed(seed)
                    result = pipe(
                        prompt,
                        image=image,
                        num_inference_steps=DIFFUSION_STEPS,
                        image_guidance_scale=igs,
                        guidance_scale=gs,
                        generator=generator
                    )
                    result_image = result.images[0]
                    result_path = os.path.join(output_dir, f"{img_id}_{idx}.png")
                    result_image.save(result_path)

                    # === ATTENTION VISUALIZATION ===
                    for tid, token in enumerate(decoded):
                        avg_map = _average_token_heatmap(ATTENTION_LOGS["cross"], tid, target_shape)
                        if avg_map is None:
                            continue
                        heatmap = F.interpolate(
                            avg_map.unsqueeze(0).unsqueeze(0),
                            size=(512, 512),
                            mode="bilinear",
                            align_corners=False,
                        )[0, 0].numpy()
                        _save_token_figure(
                            heatmap,
                            token,
                            image,
                            result_image,
                            os.path.join(output_dir, f"{img_id}_{idx}_avg_attn_token{tid}.png"),
                        )
                finally:
                    # Release the logged maps even when the run fails.
                    ATTENTION_LOGS["cross"].clear()
                    ATTENTION_LOGS["self"].clear()
                idx += 1

# === PROCESSING ===
def process_model(model_id, model_name, sample_slice=None):
    """Run the editing + attention export for every selected metadata sample.

    Args:
        model_id: identifier handed to ``load_ip2p_pipeline``.
        model_name: label used in log output and as the sub-directory of
            ``OUTPUT_ROOT`` that receives the results.
        sample_slice: optional slice applied to the ``"samples"`` list read
            from ``METADATA_PATH``; all samples are processed when omitted.
    """
    print(f"Loading model: {model_name}")

    with open(METADATA_PATH, "r") as metadata_file:
        metadata = json.load(metadata_file)
    samples = metadata["samples"]
    if sample_slice:
        samples = samples[sample_slice]

    pipe = load_ip2p_pipeline(model_id)
    output_dir = os.path.join(OUTPUT_ROOT, model_name)

    progress = tqdm(samples, desc=f"Editing with {model_name}")
    for sample in progress:
        generate_with_ip2p(pipe, sample, output_dir)

# === MAIN ===
def main():
    """Authenticate with the Hugging Face Hub if a token is set, then run the export.

    ``login`` is skipped when ``HF_TOKEN`` is empty or still the placeholder
    text: passing the placeholder would be rejected as an invalid token and
    abort the run before any model is loaded.
    """
    token = HF_TOKEN
    if token and token != "Your huggingface token":
        login(token)
    process_model(
        model_id="timbrooks/instruct-pix2pix",
        model_name="instruct-pix2pix_common",
        sample_slice=slice(None, 3)
    )

if __name__ == "__main__":
    main()
