# Stable Diffusion v1-5 跨注意力可视化（组合概念纠缠）

In [None]:
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor
import matplotlib.pyplot as plt
import numpy as np
import cv2

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on CUDA: many CPU kernels have no float16
# implementation, so running the v1-5 pipeline in fp16 on CPU fails.
dtype = torch.float16 if device == "cuda" else torch.float32


In [None]:
class CrossAttnStore(AttnProcessor):
    """
    Attention processor that records post-softmax cross-attention weights.

    The forward computation follows the same steps as diffusers' default
    ``AttnProcessor``, built on the ``Attention`` helper methods
    (``prepare_attention_mask``, ``head_to_batch_dim``,
    ``get_attention_scores``, ``batch_to_head_dim``). This means scaling,
    masking and optional softmax upcasting match the stock implementation.

    For every call that receives ``encoder_hidden_states`` (cross-attention),
    the attention probabilities are appended to ``attn_store["maps"]`` as a
    detached CPU tensor of shape ``(B, H, N, L)``:
    batch, heads, query length, key length. For UNet layers N is presumably
    the flattened spatial grid and L the text length. With
    ``average_heads=True`` the head axis is averaged and kept with size 1.
    """

    def __init__(self, attn_store: Dict, average_heads: bool = False):
        super().__init__()
        # Must contain a "maps" list; recorded tensors are appended to it.
        self.store = attn_store
        # Averaging heads on the fly divides host memory by the head count
        # while leaving a later mean over heads unchanged.
        self.average_heads = average_heads

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, temb=None, *args, **kwargs):
        residual = hidden_states
        is_cross = encoder_hidden_states is not None

        spatial_norm = getattr(attn, "spatial_norm", None)
        if spatial_norm is not None:
            hidden_states = spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_len, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Expands a user mask to the (B*heads, 1, key_len) layout used below.
        attention_mask = attn.prepare_attention_mask(attention_mask, key_len, batch_size)

        group_norm = getattr(attn, "group_norm", None)
        if group_norm is not None:
            hidden_states = group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif getattr(attn, "norm_cross", None):
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # (B, seq, heads*d) -> (B*heads, seq, d), batch-major then heads.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # softmax(scale * Q K^T + mask): (B*heads, N, L)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # ---- record cross-attention weights after softmax ----
        if is_cross:
            probs = attention_probs.view(batch_size, attn.heads, *attention_probs.shape[1:])
            if self.average_heads:
                probs = probs.mean(dim=1, keepdim=True)
            self.store["maps"].append(probs.detach().cpu())

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if getattr(attn, "residual_connection", False):
            hidden_states = hidden_states + residual
        return hidden_states / getattr(attn, "rescale_output_factor", 1.0)


In [None]:
def register_attn_store(pipe, store: Dict):
    """
    Install one shared CrossAttnStore on every ``attn2`` module of ``pipe.unet``.

    Args:
        pipe: a diffusers pipeline exposing ``.unet``.
        store: dict with a "maps" list that the processor appends to.

    Returns:
        The number of attention modules whose processor was replaced.
    """
    proc = CrossAttnStore(store)
    hooked = 0
    for name, module in pipe.unet.named_modules():
        # named_modules() also yields children such as "...attn2.to_q" (plain
        # Linear layers without set_processor); patch only the attn2 module.
        if name.endswith("attn2") and hasattr(module, "set_processor"):
            module.set_processor(proc)
            hooked += 1
    return hooked


In [None]:
def find_token_indices(tokenizer, prompt: str, target: str) -> List[int]:
    """
    Find the positions of ``target`` inside the tokenized ``prompt``.

    Positions index the prompt's token sequence *including* special tokens
    (e.g. a leading BOS). That is the layout of the text-encoder keys seen by
    cross-attention.

    Strategy:
      1. Tokenize ``target`` alone (no special tokens). Return every position
         covered by an exact, contiguous occurrence of those ids in the prompt.
         This finds words that BPE splits into several sub-tokens, which
         per-token substring matching misses.
      2. If there is no exact occurrence, fall back to case-insensitive
         substring matching of ``target`` against each decoded non-special
         token. This lets fragments such as "smok" still match "smoking".

    Positions at or beyond ``tokenizer.model_max_length`` (when that attribute
    is an int) are dropped. The assumption is that the text encoder never sees
    them; TODO confirm for non-CLIP tokenizers.

    Returns:
        Sorted token positions; empty when nothing matches or ``target`` is blank.
    """
    target = target.strip()
    if not target:
        return []

    prompt_ids = [int(i) for i in tokenizer(prompt, return_tensors="pt")["input_ids"][0]]
    max_len = getattr(tokenizer, "model_max_length", None)
    limit = max_len if isinstance(max_len, int) and max_len > 0 else len(prompt_ids)

    # 1) exact token-id subsequence match
    target_ids = [int(i) for i in tokenizer(target, add_special_tokens=False)["input_ids"]]
    span = len(target_ids)
    exact = set()
    if span:
        for start in range(len(prompt_ids) - span + 1):
            if prompt_ids[start:start + span] == target_ids:
                exact.update(range(start, start + span))
    if exact:
        return sorted(i for i in exact if i < limit)

    # 2) fallback: substring match on decoded, non-special tokens
    special_ids = set(getattr(tokenizer, "all_special_ids", None) or ())
    needle = target.lower()
    hits = []
    for pos, tok_id in enumerate(prompt_ids[:limit]):
        if tok_id in special_ids:
            continue
        text = tokenizer.decode([tok_id]).strip().lower()
        if text and needle in text:
            hits.append(pos)
    return hits

def _square_side(n: int) -> int:
    """Return s with s * s == n, or raise ValueError if n is not a perfect square."""
    side = int(round(n ** 0.5))
    if side * side != n:
        raise ValueError(f"query length {n} is not a square grid; cannot resample it")
    return side


def _resample_square(map_nl: torch.Tensor, side: int) -> torch.Tensor:
    """Resample an (N, L) map laid out row-major on a sqrt(N) grid to (side*side, L)."""
    src = _square_side(map_nl.shape[0])
    if src == side:
        return map_nl
    n_keys = map_nl.shape[1]
    grid = map_nl.T.reshape(1, n_keys, src, src)  # (1, L, h, w)
    if side > src:
        grid = F.interpolate(grid, size=(side, side), mode="bilinear", align_corners=False)
    else:
        # "area" averages source cells when shrinking, avoiding aliasing.
        grid = F.interpolate(grid, size=(side, side), mode="area")
    return grid.reshape(n_keys, side * side).T


def aggregate_maps(attn_maps: List[torch.Tensor], res: Optional[int] = None) -> torch.Tensor:
    """
    Average captured cross-attention maps into one (N, L) matrix.

    Args:
        attn_maps: tensors of shape (B, H, N, L), one per recorded
            cross-attention call (layers x steps). N is the query length and
            L the key length. UNet layers at different resolutions give
            different N.
        res: optional target grid side. Resampling to a res x res grid happens
            when ``res`` is given, or when the maps' N values differ; in the
            latter case the largest side is used if ``res`` is None.
            Resampling requires every N to be a perfect square.

    Returns:
        float32 tensor (N_out, L): the mean over maps, batch and heads.
        N_out is the shared N when no resampling happens, otherwise res*res.

    Raises:
        ValueError: if ``attn_maps`` is empty, key lengths differ, or a map
            cannot be viewed as a square grid when resampling is needed.
    """
    if len(attn_maps) == 0:
        raise ValueError("没有捕获到 cross-attention map")

    key_lens = {m.shape[-1] for m in attn_maps}
    if len(key_lens) != 1:
        raise ValueError(f"cross-attention maps have different key lengths: {sorted(key_lens)}")

    query_lens = {m.shape[-2] for m in attn_maps}
    if res is None and len(query_lens) == 1:
        target = None  # homogeneous maps: plain mean, no square-grid requirement
    else:
        target = res if res is not None else _square_side(max(query_lens))

    # Reduce batch/head axes per map in float32 and accumulate, instead of
    # materialising a (S, B, H, N, L) stack (large, and fp16 means are lossy).
    total = None
    for m in attn_maps:
        mean_nl = m.float().mean(dim=(0, 1))  # (N, L)
        if target is not None:
            mean_nl = _resample_square(mean_nl, target)
        total = mean_nl if total is None else total + mean_nl
    return total / len(attn_maps)  # query_len x key_len


In [None]:
def get_token_map(mean_map: torch.Tensor, token_indices: List[int], hw: int) -> torch.Tensor:
    """
    Build a square spatial attention map for a group of text tokens.

    Args:
        mean_map: (N, L) averaged attention, with N == hw * hw query positions.
        token_indices: key (text-token) columns to average together.
        hw: side of the spatial grid; ``int()`` is applied to it.

    Returns:
        (hw, hw) tensor holding, per position, the mean attention over the
        selected tokens.

    Raises:
        ValueError: if ``token_indices`` is empty.
    """
    if len(token_indices) < 1:
        raise ValueError("未找到目标 token")
    side = int(hw)
    columns = mean_map[:, token_indices]  # (N, k)
    per_position = columns.mean(dim=-1)   # (N,)
    return per_position.reshape(side, side)


In [None]:
def upscale_to_512(attn_hw: torch.Tensor, size=512) -> np.ndarray:
    """
    Bilinearly resize a 2-D attention map and min-max normalise it to [0, 1].

    Args:
        attn_hw: (h, w) attention map tensor (any float dtype, any device).
        size: output size. Either an int for a square output or an
            (height, width) pair; defaults to 512 (the function's name).

    Returns:
        float numpy array of the requested size with values in [0, 1].
        A constant map becomes all zeros.
    """
    if isinstance(size, (tuple, list)):
        out_hw = (int(size[0]), int(size[1]))
    else:
        out_hw = (int(size), int(size))
    # Cast to float32 first: half-precision bilinear interpolation is not
    # implemented on CPU in many PyTorch versions, and stored maps may be fp16.
    grid = attn_hw.detach().float().cpu()[None, None]  # (1, 1, h, w)
    up = F.interpolate(grid, size=out_hw, mode="bilinear", align_corners=False)[0, 0].numpy()
    lo, hi = up.min(), up.max()
    return (up - lo) / (hi - lo + 1e-8)


In [None]:
def overlay_heatmap(rgb: np.ndarray, heat: np.ndarray, alpha=0.5, cmap=cv2.COLORMAP_JET):
    """
    Blend a colour-mapped heatmap over an RGB image.

    Args:
        rgb: (H, W, 3) RGB image, assumed uint8 in [0, 255].
        heat: 2-D map expected in [0, 1]. NaNs become 0, values are clipped to
            [0, 1], and the map is bilinearly resized to (H, W) if its shape
            differs.
        alpha: weight of the heatmap in the blend; the image gets 1 - alpha.
        cmap: OpenCV colormap id passed to ``cv2.applyColorMap``.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    height, width = rgb.shape[:2]
    heat = np.nan_to_num(np.asarray(heat, dtype=np.float32), nan=0.0)
    if heat.shape[:2] != (height, width):
        # cv2.resize takes dsize as (width, height).
        heat = cv2.resize(heat, (width, height), interpolation=cv2.INTER_LINEAR)
    # Clip before the uint8 cast: out-of-range values would otherwise overflow
    # (typically wrapping around) and show the wrong colour.
    heat_u8 = (np.clip(heat, 0.0, 1.0) * 255).astype(np.uint8)
    heat_color = cv2.applyColorMap(heat_u8, cmap)
    heat_color = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
    blended = alpha * heat_color + (1 - alpha) * rgb
    return np.clip(blended, 0, 255).astype(np.uint8)


In [None]:
def iou_soft(a: np.ndarray, b: np.ndarray):
    """
    Soft (fuzzy) IoU of two maps: sum(min(a, b)) / (sum(max(a, b)) + 1e-8).

    Meant for non-negative maps such as normalised heatmaps. The result is
    close to 1 for identical maps and 0 when their supports do not overlap.
    """
    intersection = np.sum(np.minimum(a, b))
    union = np.sum(np.maximum(a, b))
    return intersection / (union + 1e-8)

def cosine_sim(a: np.ndarray, b: np.ndarray):
    """Cosine similarity of two arrays viewed as flat vectors; 1e-8 guards zero norms."""
    u = a.ravel()
    v = b.ravel()
    norm_product = np.linalg.norm(u) * np.linalg.norm(v)
    return np.dot(u, v) / (norm_product + 1e-8)


In [None]:
def visualize_attention(prompt: str, subject: str, action: str, guidance_scale=7.5, steps=30,
                        pipe=None, attn_res: Optional[int] = 16, seed: Optional[int] = None):
    """
    Generate an image for ``prompt`` and visualise cross-attention of two concepts.

    A CrossAttnStore is installed on the UNet's cross-attention layers and the
    recorded maps are averaged. The maps for the ``subject`` and ``action``
    tokens are extracted and overlaid on the image, and their overlap (soft
    IoU and cosine similarity) is reported as a measure of concept
    entanglement.

    Args:
        prompt: text prompt.
        subject: word (or fragment) of the prompt to inspect.
        action: second word (or fragment) of the prompt to inspect.
        guidance_scale: classifier-free guidance scale passed to the pipeline.
        steps: number of denoising steps.
        pipe: optional pre-loaded StableDiffusionPipeline, reusable across
            calls; its attention processors get replaced. Loaded if None.
        attn_res: spatial side of the cross-attention layers to aggregate,
            e.g. 16 for the 16x16 layers a 512px SD v1-5 run is expected to
            have. None keeps every recorded layer, which only works when all
            layers share one resolution.
        seed: optional seed for reproducible generation.

    Returns:
        dict with "image" (pipeline output image), "subj_map" / "act_map"
        (normalised heatmaps), "iou" and "cos".

    Raises:
        ValueError: if no maps were captured, none match ``attn_res``, or a
            token cannot be found in the prompt.
    """
    store = {"maps": []}

    if pipe is None:
        # NOTE(review): this Hub repo id may no longer be available; if loading
        # fails, load a mirror of SD v1-5 yourself and pass it via `pipe`.
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)

    register_attn_store(pipe, store)

    # A CPU generator gives the same initial noise regardless of pipeline device.
    generator = None if seed is None else torch.Generator("cpu").manual_seed(seed)

    # 生成 — weights are already loaded in `dtype`, so no autocast context is used.
    image = pipe(
        prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]
    rgb = np.array(image)

    # 聚合注意力
    maps = store["maps"]
    if guidance_scale > 1.0:
        # With classifier-free guidance the UNet batch is [unconditional,
        # conditional]. The unconditional half attends to an empty prompt whose
        # token positions do not correspond to `prompt`, so keep only the
        # conditional half.
        maps = [m[m.shape[0] // 2:] for m in maps]
    if attn_res is not None:
        # Cross-attention runs at several spatial resolutions inside the UNet;
        # keep a single one so every map has the same query length N.
        all_maps = maps
        maps = [m for m in maps if m.shape[-2] == attn_res * attn_res]
        if all_maps and not maps:
            raise ValueError(f"没有分辨率为 {attn_res}x{attn_res} 的 cross-attention map")

    mean_map = aggregate_maps(maps).float()  # (N, L)
    # The processor installed on `pipe` keeps `store` alive; drop the raw maps.
    store["maps"].clear()
    hw = int(round(np.sqrt(mean_map.shape[0])))  # e.g., 16x16 or 32x32

    tokenizer = pipe.tokenizer
    subj_idx = find_token_indices(tokenizer, prompt, subject)
    act_idx = find_token_indices(tokenizer, prompt, action)

    subj_map = upscale_to_512(get_token_map(mean_map, subj_idx, hw))
    act_map = upscale_to_512(get_token_map(mean_map, act_idx, hw))

    subj_overlay = overlay_heatmap(rgb, subj_map, alpha=0.45)
    act_overlay = overlay_heatmap(rgb, act_map, alpha=0.45)

    # 量化纠缠 (quantify entanglement)
    iou = iou_soft(subj_map, act_map)
    cos = cosine_sim(subj_map, act_map)

    # 可视化
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    axes[0].imshow(rgb); axes[0].set_title("生成图像"); axes[0].axis("off")
    axes[1].imshow(subj_overlay); axes[1].set_title(f"{subject} 注意力叠加"); axes[1].axis("off")
    axes[2].imshow(act_overlay);  axes[2].set_title(f"{action} 注意力叠加"); axes[2].axis("off")
    plt.suptitle(f"IoU={iou:.3f} | Cosine={cos:.3f}", fontsize=14)
    plt.tight_layout()
    plt.show()

    return {"image": image, "subj_map": subj_map, "act_map": act_map, "iou": iou, "cos": cos}


In [None]:
# Demo: how strongly does the subject "mickey" share attention with "smoking"?
result = visualize_attention(
    "A photo of Mickey Mouse smoking",
    subject="mickey",
    action="smoking",
    steps=30,
    guidance_scale=7.5,
)
print(f"IoU={result['iou']:.4f}, Cosine={result['cos']:.4f}")
