## New ver

In [None]:
# If required libraries are not installed, uncomment the following:
# !pip install opencv-python transformers accelerate


import os, math
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import matplotlib.cm as cm
from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler
from photomaker import PhotoMakerStableDiffusionXLPipeline, FaceAnalysis2, analyze_faces
from transformers import CLIPTokenizer            # ← add (needed later if you keep tokenizer logic)

from pathlib import Path

from PIL import Image, ImageDraw, ImageFont
from pathlib import Path



# ── global font object (about 3 % of the tile height) ───────────────────
FONT_SIZE = 40                                  # tweak if tiles change
try:                                            # try a TTF first
    ttf_path = next(
            p for p in (
                        Path("/usr/share/fonts"), Path("/usr/local/share/fonts"))
                                if p.is_dir()).rglob("DejaVuSans.ttf").__next__()
                                    font = ImageFont.truetype(str(ttf_path), FONT_SIZE)
                                    except StopIteration:
                                        font = ImageFont.load_default()             # fallback bitmap font




                                        # Device and dtype setup
                                        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
                                        torch_dtype = torch.bfloat16 if (device == "cuda" and torch.cuda.is_bf16_supported()) else torch.float16
                                        if device == "mps":
                                            torch_dtype = torch.float16

                                            # Load base SDXL model and PhotoMaker V2 adapter:contentReference[oaicite:4]{index=4}:contentReference[oaicite:5]{index=5}
                                            pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
                                                "SG161222/RealVisXL_V4.0", torch_dtype=torch_dtype
                                                ).to(device)
                                                from huggingface_hub import hf_hub_download
                                                ckpt_path = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model")
                                                pipe.load_photomaker_adapter(os.path.dirname(ckpt_path), subfolder="", weight_name=os.path.basename(ckpt_path), trigger_word="img")
                                                pipe.fuse_lora()
                                                pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)  # Euler sampler:contentReference[oaicite:6]{index=6}
                                                pipe.disable_xformers_memory_efficient_attention()  # disable for attention inspection

                                                # Reference image and identity embedding preparation
                                                reference_image_path = "keanu.jpg"  # <--- set your reference image path
                                                # reference_image_path = "tom.jpg" 


                                                # Which token should the heat-map follow?
                                                #   "face" – fixed auxiliary prompt "a face" (sharper ID-agnostic face map)
                                                #   "img"  – the PhotoMaker trigger word inside *your* prompt
                                                #   "man"  – a normal word inside *your* prompt
                                                TOKEN_FOCUS = "face"            # ← change to "img" or "man" when needed
                                                # TOKEN_FOCUS = "man"  
                                                prompt = "a man img with a beard in a space shuttle"
                                                # prompt = "a portrait of a man img with a beard playing football"


                                                # Load reference image
                                                ref_image = load_image(reference_image_path)
                                                # Detect face and get identity embedding:contentReference[oaicite:7]{index=7}
                                                face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
                                                face_detector.prepare(ctx_id=0, det_size=(640, 640))
                                                img_np = np.array(ref_image)[:, :, ::-1]  # convert PIL (RGB) to BGR NumPy for detector
                                                faces = analyze_faces(face_detector, img_np)
                                                if not faces:
                                                    raise RuntimeError("No face detected in the reference image.")

                                                    id_embed = torch.from_numpy(faces[0]["embedding"]).unsqueeze(0)  # identity embedding tensor


                                                        
                                                            
                                                            tokenizer = getattr(pipe, "tokenizer", None) or CLIPTokenizer.from_pretrained(
                                                                "SG161222/RealVisXL_V4.0", subfolder="tokenizer")


                                                                if TOKEN_FOCUS == "face":
                                                                    AUX_PROMPT = "a face"
                                                                        # AUX_PROMPT = "wtf"
                                                                            # AUX_PROMPT = "arms"
                                                                                with torch.no_grad():
                                                                                        face_latents, *_ = pipe.encode_prompt(
                                                                                                    prompt=AUX_PROMPT, device=device,
                                                                                                                num_images_per_prompt=1, do_classifier_free_guidance=False)   # (1,77,2048)

                                                                                                                    # ---- locate exact BPE sequence for “ face” -------------------------
                                                                                                                        aux_ids  = tokenizer(AUX_PROMPT, add_special_tokens=False).input_ids
                                                                                                                            # face_ids = tokenizer(" face",    add_special_tokens=False).input_ids
                                                                                                                                face_ids = tokenizer(" face",    add_special_tokens=False).input_ids
                                                                                                                                    def find_sub(seq, sub):
                                                                                                                                            for i in range(len(seq) - len(sub) + 1):
                                                                                                                                                        if seq[i:i+len(sub)] == sub:
                                                                                                                                                                        return list(range(i, i+len(sub)))
                                                                                                                                                                                return []
                                                                                                                                                                                    FACE_TOKEN_IDX = find_sub(aux_ids, face_ids)
                                                                                                                                                                                        if not FACE_TOKEN_IDX:
                                                                                                                                                                                                raise RuntimeError("Could not locate 'face' tokens in aux prompt")
                                                                                                                                                                                                    # debug
                                                                                                                                                                                                        print("[DEBUG] face token positions in aux prompt:", FACE_TOKEN_IDX)
                                                                                                                                                                                                            print("[DEBUG] ‖v_face‖ =",
                                                                                                                                                                                                                      face_latents[0, FACE_TOKEN_IDX].mean(0).norm().item())

                                                                                                                                                                                                                      # Prepare lists to collect overlay images

                                                                                                                                                                                                                      layer_names    = []     # to be filled after first callback
                                                                                                                                                                                                                      heatmaps_cross = {}     # dict[layer] -> list[PIL.Image]
                                                                                                                                                                                                                      final_image    = None


                                                                                                                                                                                                                      # Seed for reproducibility
                                                                                                                                                                                                                      seed = 56789
                                                                                                                                                                                                                      generator = torch.Generator(device=device).manual_seed(seed)


                                                                                                                                                                                                                      ###############################################################################
                                                                                                                                                                                                                      # 🪄  Monkey‑patch every CrossAttention: Q = to_q(hidden_states),
                                                                                                                                                                                                                      #      K = to_k(face_latents); heat‑map = softmax(Q·K_face)
                                                                                                                                                                                                                      ###############################################################################
                                                                                                                                                                                                                      from diffusers.models.attention_processor import Attention as CrossAttention

                                                                                                                                                                                                                      attn_maps_current = {}                 # {layer_name: [head_maps]}
                                                                                                                                                                                                                       


                                                                                                                                                                                                                       def make_hook(layer_name, module):
                                                                                                                                                                                                                           orig_forward = module.forward

                                                                                                                                                                                                                               def forward_with_hook(hidden_states,
                                                                                                                                                                                                                                                         encoder_hidden_states=None,
                                                                                                                                                                                                                                                                                   attention_mask=None):

                                                                                                                                                                                                                                                                                           # keep unconditional branch out of the analysis (CFG doubles batch)
                                                                                                                                                                                                                                                                                                   B_total = hidden_states.shape[0]
                                                                                                                                                                                                                                                                                                           hs_cond = hidden_states[B_total // 2:]           # guided half


                                                                                                                                                                                                                                                                                                                   
                                                                                                                                                                                                                                                                                                                           # ── 0. drop the unconditional half (CFG doubles the batch) ──────────────
                                                                                                                                                                                                                                                                                                                                   B_all = hidden_states.shape[0]
                                                                                                                                                                                                                                                                                                                                           hs_cond = hidden_states[B_all // 2:]           # keep “guided” branch

                                                                                                                                                                                                                                                                                                                                                   out = orig_forward(hidden_states, encoder_hidden_states, attention_mask)
                                                                                                                                                                                                                                                                                                                                                           if encoder_hidden_states is None:              # self-attention → skip
                                                                                                                                                                                                                                                                                                                                                                       return out

                                                                                                                                                                                                                                                                                                                                                                               # ── 1. Q from *image* latents (conditional half) ────────────────────────
                                                                                                                                                                                                                                                                                                                                                                                       proj_q = (module.to_q if hasattr(module, "to_q") else module.q_proj)(hs_cond)
                                                                                                                                                                                                                                                                                                                                                                                               
                                                                                                                                                                                                                                                                                                                                                                                                       


                                                                                                                                                                                                                                                                                                                                                                                                               B, L, C = proj_q.shape
                                                                                                                                                                                                                                                                                                                                                                                                                       h       = module.heads
                                                                                                                                                                                                                                                                                                                                                                                                                               d       = C // h

                                                                                                                                                                                                                                                                                                                                                                                                                                       Q = proj_q.view(B, L, h, d).permute(0, 2, 1, 3)         # (B,h,L,d)


                                                                                                                                                                                                                                                                                                                                                                                                                                               if TOKEN_FOCUS == "face":
                                                                                                                                                                                                                                                                                                                                                                                                                                                           # use only the true “face” token(s) found above
                                                                                                                                                                                                                                                                                                                                                                                                                                                                       B = hs_cond.shape[0]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   face_lat_batched = face_latents.to(hs_cond.dtype).repeat(B, 1, 1)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               proj_k = (module.to_k if hasattr(module, "to_k") else module.k_proj)(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               face_lat_batched)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           K_all = proj_k.view(B, -1, h, d).permute(0, 2, 3, 1)   # (B,h,d,77)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       K_tok = K_all[..., FACE_TOKEN_IDX].mean(-1)            # (B,h,d)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               else:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           # token lives inside the *text prompt* that produced encoder_hidden_states
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       enc_cond = encoder_hidden_states[B_all // 2:]          # align with CFG
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   proj_k = (module.to_k if hasattr(module, "to_k") else module.k_proj)(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   enc_cond)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               K_all = proj_k.view(B, -1, h, d).permute(0, 2, 3, 1)  # (B,h,d,T_text)


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           # Encode the word with its leading space → may yield 1-N tokens
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       word_ids = tokenizer(" " + TOKEN_FOCUS,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        add_special_tokens=False).input_ids        # list[int]

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                # locate the first matching subsequence
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            def find_subseq(seq, subseq):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            for i in range(len(seq) - len(subseq) + 1):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if seq[i : i + len(subseq)] == subseq:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        return list(range(i, i + len(subseq)))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        return []

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    idxs = find_subseq(prompt_ids, word_ids)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if not idxs:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                raise RuntimeError(f'"{TOKEN_FOCUS}" not found in prompt – '
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   "cannot build attention map.")

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               # combine K of all tokens that spell the word (mean-pool)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           K_tok = K_all[..., idxs].mean(-1)                       # (B,h,d)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   # • normalise both vectors  → focus on *direction*
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           eps  = 1e-8
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   Qn   = Q  / (Q.norm(dim=-1, keepdim=True)  + eps)          # (B,h,L,d)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           Ktn  = K_tok / (K_tok.norm(dim=-1, keepdim=True) + eps)    # (B,h,d)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   # # OLD VER
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           # logits = (Qn * Ktn.unsqueeze(2)).sum(-1)                   # (B,h,L)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   # att  = logits.float().softmax(-1).mean(1)[0]            # (L,)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   # NEW VER
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           sim    = (Qn * Ktn.unsqueeze(2)).sum(-1)               # cosine, (B,h,L)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   att  = ((sim + 1) / 2).mean(1)[0].float()        # → float32 (L,)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           L_spatial = att.numel()             # = H*W of the current feature map
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   H = int(math.sqrt(L_spatial))       # e.g. 32, 64, 128 …
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           W = L_spatial // H                  # survives non-square cases
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           layer_buf = attn_maps_current.setdefault(layer_name, [])
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   layer_buf.append(att.view(H, W).cpu().numpy())


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           return out

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               return forward_with_hook




                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               for lname, mod in pipe.unet.named_modules():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   if isinstance(mod, CrossAttention):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           mod.forward = make_hook(lname, mod)


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           num_steps = 50

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           def callback(step, timestep, latents):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               global attn_maps_current, layer_names, heatmaps_cross, final_image

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   if step % 10 == 0 or step == num_steps - 1:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           # ❶  Consolidate maps *layer by layer* into square grids
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   snapshot = {}
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           for layer, maps in attn_maps_current.items():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       flat  = np.stack(maps).mean(0)                 # mean over heads
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   H     = int(math.sqrt(flat.size))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               snapshot[layer] = flat.reshape(H, H)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   if step == 0:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               k, v = next(iter(snapshot.items()))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           print(f"[DEBUG] first-snapshot  layer={k}  max={v.max():.4f}  mean={v.mean():.4f}")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   # NEW: print where the absolute max lives and some refs
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               flat_idx = v.argmax()
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           r, c     = divmod(flat_idx, v.shape[1])
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       print(f"[VAL0] max@({r},{c})={v.max():.3f}  "
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         f"centre={v[v.shape[0]//2, v.shape[1]//2]:.3f}  "
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           f"corner={v[0,0]:.3f}")


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           # ── DEBUG: check numeric values before colouring ────────────────
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   if step == 0:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               first_layer = next(iter(snapshot))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           H0, W0      = snapshot[first_layer].shape
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       centre_val  = snapshot[first_layer][H0 // 2, W0 // 2]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   corner_val  = snapshot[first_layer][0, 0]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               print(f"[VAL] centre={centre_val:.3f}  corner={corner_val:.3f}  "
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 f"max={snapshot[first_layer].max():.3f}")

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 with torch.no_grad():
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         vae_dev = next(pipe.vae.parameters()).device
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     img = pipe.vae.decode(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     (latents / 0.18215).to(device=vae_dev, dtype=pipe.vae.dtype)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 ).sample[0]

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         img_np = ((img.float() / 2 + 0.5).clamp(0, 1).cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 if not layer_names:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             layer_names = list(snapshot.keys())
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         heatmaps_cross = {ln: [] for ln in layer_names}

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 # jet = cm.get_cmap("jet")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         # pick colour-map: "jet" (default) or "Greys" for a monotone ramp
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 cmap_name = "jet"          # ← change to "Greys" if you like
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         INVERT    = False        # ← flip to True to check if colours are inverted
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 colormap  = cm.get_cmap(cmap_name)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         # save the colour-bar once (step-0 of first layer)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 if step == 0:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             import matplotlib.pyplot as plt, numpy as _np
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         plt.figure(figsize=(4, .4))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     plt.axis("off")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 plt.imshow(_np.linspace(0, 1, 256)[None, :],
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        cmap=colormap, aspect="auto")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    plt.savefig("colourbar.png", bbox_inches="tight")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                plt.close()        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        for ln in layer_names:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    amap = snapshot[ln]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                amap = (amap / amap.max()) if amap.max() > 0 else amap
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        # normalise 0‥1 and (optionally) invert
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    amap = (amap / amap.max()) if amap.max() > 0 else amap
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                if INVERT:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                amap = 1.0 - amap
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    # hmap = (jet(amap)[..., :3] * 255).astype(np.uint8)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                hmap = (colormap(amap)[..., :3] * 255).astype(np.uint8)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            hmap = np.array(Image.fromarray(hmap).resize((1024, 1024),
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     Image.BILINEAR))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 # overlay = Image.fromarray((0.5 * img_np + 0.5 * hmap)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             #                           .astype(np.uint8))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     # ───────────────────────── add 5×5 numerical grid ──────────────────
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 overlay_arr = (0.5 * img_np + 0.5 * hmap).astype(np.uint8)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             overlay     = Image.fromarray(overlay_arr)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         draw        = ImageDraw.Draw(overlay)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     # split original similarity map (H×W) into 5×5 blocks
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 H_blk, W_blk = amap.shape[0] // 5, amap.shape[1] // 5
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             vis_h, vis_w = 1024 // 5, 1024 // 5              # overlay size

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         for bi in range(5):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         for bj in range(5):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             block = amap[
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     bi*H_blk : (bi+1)*H_blk,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             bj*W_blk : (bj+1)*W_blk]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 mean_val = block.mean()

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     # position text at block centre (resize factor already 1024/H)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         cx = bj * vis_w + vis_w // 2
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             cy = bi * vis_h + vis_h // 2

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 txt = f"{mean_val:.2f}"
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     tw, th = draw.textbbox((0, 0), txt, font=font)[2:]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         draw.text((cx - tw // 2, cy - th // 2),
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       txt, font=font, fill="white",
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     stroke_width=2, stroke_fill="black")
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         heatmaps_cross[ln].append(overlay)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 # keep latest clean image for the “Final” panel
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         final_image = Image.fromarray(img_np)

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             attn_maps_current = {}      # reset per step

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             # Run the diffusion process with the callback (50 steps):contentReference[oaicite:8]{index=8}

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             print(f'prompt: {prompt}')

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             _ = pipe(
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 prompt=prompt,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     negative_prompt="(asymmetry, worst quality, low quality, illustration, 3d, cartoon, sketch)", 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         input_id_images=[ref_image], id_embeds=id_embed, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             num_inference_steps=num_steps, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 start_merge_step=10,
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     # start_merge_step=0, 
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         generator=generator, callback=callback, callback_steps=1
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         )


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         # ─────────────────────────────────────────────────────────────
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         #  Build and save montage from the overlays collected in‑callback
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         # ─────────────────────────────────────────────────────────────
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         import re
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         header_h = 30

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         for ln in layer_names:                          # set during callback
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             cols = []
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 for i, attn_img in enumerate(heatmaps_cross[ln]):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         step_num = i * 10                       # 0,10,20,30,40,50
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 cols.append((f"S{step_num}", attn_img))

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     cols.append(("Final", final_image))         # last clean frame

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         img_w, img_h = cols[0][1].width, cols[0][1].height
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             strip = Image.new("RGB", (img_w * len(cols), img_h + header_h),
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   color=(0, 0, 0))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       draw  = ImageDraw.Draw(strip)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           font  = ImageFont.load_default()

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               for idx, (label, img) in enumerate(cols):
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       x_off = idx * img_w
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               strip.paste(img, (x_off, header_h))
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       tw, th = draw.textbbox((0, 0), label, font=font)[2:]
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               draw.text((x_off + (img_w - tw)//2, (header_h - th)//2),
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 label, font=font, fill=(255, 255, 255))

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     safe = re.sub(r"[^\w\-]+", "_", ln)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         strip.save(f"{safe}_attention_evolution.jpg")

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         # save final image
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         if final_image is not None:
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             final_image.save("final_image.jpg")

  from .autonotebook import tqdm as notebook_tqdm
  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
Loading pipeline components...: 100%|██████████| 7/7 [00:07<00:00,  1.14s/it]


Loading PhotoMaker v2 components [1] id_encoder from [/home/kolyangg/.cache/huggingface/hub/models--TencentARC--PhotoMaker-V2/snapshots/f5a1e5155dc02166253fa7e29d13519f5ba22eac]...
4096
Loading PhotoMaker v2 components [2] lora_weights from [/home/kolyangg/.cache/huggingface/hub/models--TencentARC--PhotoMaker-V2/snapshots/f5a1e5155dc02166253fa7e29d13519f5ba22eac]




Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
model ignore: /home/kolyangg/.insightface/models/buffalo_l/1k3d68.onnx landmark_3d_68
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
model ignore: /home/kolyangg/.insightface/models/buffalo_l/2d106det.onnx landmark_2d_106
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: /home/kolyangg/.insightface/models/buffalo_l/det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
model ignore: /home/kolyangg/.insightface/models/buffalo_l/genderage.onnx genderage
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: /home/kolyangg/.insightface/models/buffalo_l/w600k_r50.onnx recognition ['None', 3, 112, 112] 127.5 127.5
set det-size: (640, 640)


  deprecate(
  deprecate(


[DEBUG] face token positions in aux prompt: [1]
[DEBUG] ‖v_face‖ = 33.25
prompt: a man img with a beard in a space shuttle


  2%|▏         | 1/50 [00:00<00:17,  2.83it/s]

[DEBUG] first-snapshot  layer=down_blocks.1.attentions.0.transformer_blocks.0.attn2  max=0.6055  mean=0.5354
[VAL0] max@(35,0)=0.605  centre=0.543  corner=0.586
[VAL] centre=0.543  corner=0.586  max=0.605


  colormap  = cm.get_cmap(cmap_name)
100%|██████████| 50/50 [00:31<00:00,  1.57it/s]


In [None]:
# ─────────────────────────────────────────────────────────────
#  QUICK PIL‑ONLY PDF MAKER  (≤10 rows per portrait A4 page)
# ─────────────────────────────────────────────────────────────
from PIL import Image, ImageDraw, ImageFont
import os

# --- config --------------------------------------------------
DPI         = 150                       # output resolution
PAGE_W_PX   = int(8.27 * DPI)           # A4 portrait 8.27×11.69 in
PAGE_H_PX   = int(11.69 * DPI)
ROWS_PER_PG = 10
ROW_H_PX    = PAGE_H_PX // ROWS_PER_PG
LABEL_W_PX  = int(PAGE_W_PX * 0.15)     # 15 % gutter for filename
RIGHT_PAD   = 20                        # px margin on right
FONT        = ImageFont.load_default()

# gather montage strips
montage_files = sorted(
    f for f in os.listdir('.') if f.endswith('_attention_evolution.jpg')
)

pages, y = [], 0
page = Image.new('RGB', (PAGE_W_PX, PAGE_H_PX), 'white')
draw = ImageDraw.Draw(page)

def wrapped_label(draw, text, x, y_top, row_h, max_w):
    """Draw *any* filename (no spaces needed) within max_w pixels."""
    line_h  = FONT.getbbox("A")[3]
    max_lin = row_h // line_h
    lines, cur = [], ""

    for ch in text:
        trial = cur + ch
        if draw.textlength(trial, font=FONT) <= max_w:
            cur = trial
        else:
            lines.append(cur)
            cur = ch
    lines.append(cur)

    if len(lines) > max_lin:                 # truncate vertically
        lines = lines[:max_lin]
        if len(lines[-1]) > 1:
            while draw.textlength(lines[-1] + "…", font=FONT) > max_w:
                lines[-1] = lines[-1][:-1]
            lines[-1] += "…"

    y_txt = y_top + (row_h - line_h * len(lines)) // 2
    for ln in lines:
        draw.text((x, y_txt), ln, fill="black", font=FONT)
        y_txt += line_h


for fname in montage_files:
    # --- scale montage to fit row height *and* available width ----
    strip = Image.open(fname)
    max_w = PAGE_W_PX - LABEL_W_PX - RIGHT_PAD
    scale = min(ROW_H_PX / strip.height, max_w / strip.width)
    strip = strip.resize((int(strip.width * scale),
                          int(strip.height * scale)),
                         Image.LANCZOS)

    # --- new page if needed ---------------------------------------
    if y + ROW_H_PX > PAGE_H_PX:
        pages.append(page)
        page = Image.new('RGB', (PAGE_W_PX, PAGE_H_PX), 'white')
        draw = ImageDraw.Draw(page)
        y = 0

    # --- filename (wrapped) ---------------------------------------
    wrapped_label(draw, fname, 10, y, ROW_H_PX, LABEL_W_PX - 20)

    # --- paste montage strip --------------------------------------
    x_strip = LABEL_W_PX
    y_strip = y + (ROW_H_PX - strip.height) // 2
    page.paste(strip, (x_strip, y_strip))

    y += ROW_H_PX

# final page
pages.append(page)

# --- save multipage PDF -------------------------------------------
out_pdf = 'attention_evolution_report.pdf'
pages[0].save(out_pdf, save_all=True, append_images=pages[1:])
print(f'PDF saved to {out_pdf}')


PDF saved to attention_evolution_report.pdf


: 

In [3]:
# pipe.unet.named_modules
# layer_names
