- Defines output directory (exp3/)
- Sets model ID, precision (fp16), number of diffusion steps (N_STEPS), transformer levels (N_LEVELS), and random seed.
- Loads the SANA model with weights from the Hugging Face repo.
- Pushes the model and VAE/text encoder to GPU with the right precision.

In [1]:
# ==============================================================
#  imports, global config, and pipeline load
#  ─────────────────────────────────────────────────────────────
#  • Sets basic experiment parameters (model id, #steps…)
#  • Loads the SANA pipeline on GPU, casting weights to fp16/bf16
# ==============================================================
import math, itertools, pathlib
import torch, tqdm 
from diffusers import SanaPipeline
import re, matplotlib.pyplot as plt
from PIL import Image

In [2]:
# ---------- directory where all outputs will be saved ----------
OUTDIR = pathlib.Path("exp3")             # folder name
OUTDIR.mkdir(exist_ok=True)                  # create if it doesn't exist

# ---------- experiment hyper‑parameters -----------------------
MODEL_ID   = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
DTYPE      = torch.float16                   # weights dtype for GPU
DEVICE     = "cuda"                          # or "mps" / "cpu" if needed
N_STEPS    = 20                              # diffusion steps  (T)
N_LEVELS   = 20                              # transformer blocks (L)
SEED       = 42                              # reproducible randomness
PROMPT_TXT = "sana_position_prompts.txt"         # text file – one prompt / line


# ---------- load SANA pipeline --------------------------------
pipe = (
    SanaPipeline           # diffusers class
    .from_pretrained(MODEL_ID, variant="fp16", torch_dtype=DTYPE)
    .to(DEVICE)            # push entire pipeline to GPU
)

pipe.vae         .to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.scheduler.set_timesteps(N_STEPS)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

This part prepares tools for running and tracking experiments:
- Tracks which timestep (t) the model is currently processing via a forward_pre_hook.
- Loads text prompts from the file sana_position_prompts.txt.
- Defines a generate() helper function to run the pipeline and save results to disk.

In [3]:
# ==============================================================
#  utility hooks / helpers
#  ─────────────────────────────────────────────────────────────
#  • step_counter: lets us know which diffusion step (t) we’re on
#  • PROMPTS: list of prompts loaded from the text file
#  • generate(): helper to run the pipeline and save an image
# ==============================================================
# -------- 1. step counter (pre‑hook on entire transformer) ----
step_counter = {"t": -1}                     # wrapped in dict for mutability

def _count_steps(_, __):
    """Executed *before* each denoising step; increments global t."""
    step_counter["t"] += 1

pipe.transformer.register_forward_pre_hook(_count_steps)

# -------- 2. load prompts -------------------------------------
with open(PROMPT_TXT) as f:
    PROMPTS = [ln.strip() for ln in f if ln.strip()]
print("Loaded", len(PROMPTS), "prompts")

# -------- 3. generate & save helper ---------------------------
def generate(prompt: str,
             seed: int,
             filename: str,
             ):
    """
    Runs the pipeline once and writes PNG → OUTDIR/filename
    If hooks_enabled=False the model runs normally; otherwise
    already‑registered hooks (mean or swap) will manipulate outputs.
    """
    step_counter["t"] = -1                   # reset timestep counter

    # Deterministic generator for reproducibility
    gen = torch.Generator(device=DEVICE).manual_seed(seed)

    # Run diffusion
    img = pipe(
        prompt,
        height=1024, width=1024,
        num_inference_steps=N_STEPS,
        guidance_scale=5.0,                 # classifier‑free guidance weight
        generator=gen
    )[0][0]                                    # diffusers returns (images, …)
    img.save(OUTDIR / filename)
    return img                              # return PIL.Image for display


Loaded 450 prompts


# Phase A – Mean Activation Collection
- This phase records the average output of every FFN block at every timestep across all prompts.
- Hooks are registered to all FFN modules.
- For each prompt, we run the pipeline and record FFN outputs.
- After all prompts, the average activation per layer and timestep is computed.
- Final means are saved as ffn_mean_maps.pt.--> exp3/ffn_mean_maps.pt

In [4]:
def print_hook_counts():
    counts = [
        len(mod._forward_hooks)          # PyTorch stores existing hooks here
        for mod in pipe.transformer.transformer_blocks
    ]
    print("Hooks per block:", counts)
    print("Total hooks    :", sum(counts))


In [None]:
# ==============================================================
#  Phase A: collect ⟨level, t⟩ mean maps
#  ─────────────────────────────────────────────────────────────
#  • For each prompt, record the *output* of every FFN block
#    at every timestep; average across prompts to build a
#    stable “mean map” tensor.
# ==============================================================
# ---------- 1. storage containers -----------------------------
means = [[None] * N_STEPS for _ in range(N_LEVELS)]   # 0‑based: 0‥N_STEPS‑1
n_prompts_seen = 0

def make_record_hook(level: int):
    """
    Returns a hook that, for its transformer *level*, records the
    average token vector produced by the FFN at every diffusion step.
    """
    def _hook(_, __, out):
        t = step_counter["t"]               # current denoising step (0‑based)

        # ------- NEW: skip the *extra* call after step N_STEPS‑1 ----------
        if t >= N_STEPS:
            return out                      # do nothing, just pass output on

        avg_tokens = out.detach().float().mean(0)  # (seq_len, hidden)

        # accumulate
        if means[level][t] is None:
            means[level][t] = avg_tokens
        else:
            means[level][t] += avg_tokens
        return out                          # always return the original output
    return _hook

# ---------- 2. attach recorders to each FFN --------------------
H_REC = [
    blk.ff.register_forward_hook(make_record_hook(lvl))
    for   lvl, blk in enumerate(pipe.transformer.transformer_blocks)
]
print_hook_counts()
# ---------- 3. run pipeline over *all* prompts -----------------
for prompt in tqdm.tqdm(PROMPTS, desc="Collect means"):
    step_counter["t"] = -1                                # reset
    seed_this_run = SEED + n_prompts_seen                 # new seed each run
    _ = pipe(
        prompt,
        height=1024, width=1024,
        num_inference_steps=N_STEPS,
        guidance_scale=5.0,
        generator=torch.Generator(device=DEVICE).manual_seed(seed_this_run)
    )
    n_prompts_seen += 1

# ---------- 4. finalise means and clean up ---------------------
for h in H_REC:
    h.remove()                                            # detach hooks
print_hook_counts()
H_MEAN.clear()
print("Phase B hooks removed; model is clean for Phase C.")
# divide by #prompts to get the actual mean
means = [[m / n_prompts_seen for m in row] for row in means]

torch.save(means, OUTDIR / "ffn_mean_maps.pt")
print("Mean maps saved →", OUTDIR / "ffn_mean_maps.pt")

Hooks per block: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Total hooks    : 0


Collect means:   0%|                                                        | 0/450 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   0%|                                                | 1/450 [00:03<25:29,  3.41s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   0%|▏                                               | 2/450 [00:06<24:05,  3.23s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   1%|▎                                               | 3/450 [00:09<23:36,  3.17s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   1%|▍                                               | 4/450 [00:12<23:20,  3.14s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   1%|▌                                               | 5/450 [00:15<23:12,  3.13s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   1%|▋                                               | 6/450 [00:18<23:04,  3.12s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   2%|▋                                               | 7/450 [00:22<22:58,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   2%|▊                                               | 8/450 [00:25<22:55,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   2%|▉                                               | 9/450 [00:28<22:51,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   2%|█                                              | 10/450 [00:31<22:47,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   2%|█▏                                             | 11/450 [00:34<22:43,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   3%|█▎                                             | 12/450 [00:37<22:40,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   3%|█▎                                             | 13/450 [00:40<22:36,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   3%|█▍                                             | 14/450 [00:43<22:33,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   3%|█▌                                             | 15/450 [00:46<22:31,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   4%|█▋                                             | 16/450 [00:49<22:27,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   4%|█▊                                             | 17/450 [00:53<22:24,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   4%|█▉                                             | 18/450 [00:56<22:21,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   4%|█▉                                             | 19/450 [00:59<22:18,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   4%|██                                             | 20/450 [01:02<22:16,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   5%|██▏                                            | 21/450 [01:05<22:13,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   5%|██▎                                            | 22/450 [01:08<22:10,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   5%|██▍                                            | 23/450 [01:11<22:06,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   5%|██▌                                            | 24/450 [01:14<22:04,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   6%|██▌                                            | 25/450 [01:17<22:00,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   6%|██▋                                            | 26/450 [01:21<21:57,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   6%|██▊                                            | 27/450 [01:24<21:54,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   6%|██▉                                            | 28/450 [01:27<21:51,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   6%|███                                            | 29/450 [01:30<21:48,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   7%|███▏                                           | 30/450 [01:33<21:44,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   7%|███▏                                           | 31/450 [01:36<21:42,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   7%|███▎                                           | 32/450 [01:39<21:39,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   7%|███▍                                           | 33/450 [01:42<21:35,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   8%|███▌                                           | 34/450 [01:45<21:32,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   8%|███▋                                           | 35/450 [01:49<21:29,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   8%|███▊                                           | 36/450 [01:52<21:26,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   8%|███▊                                           | 37/450 [01:55<21:23,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   8%|███▉                                           | 38/450 [01:58<21:20,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   9%|████                                           | 39/450 [02:01<21:16,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   9%|████▏                                          | 40/450 [02:04<21:13,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   9%|████▎                                          | 41/450 [02:07<21:11,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:   9%|████▍                                          | 42/450 [02:10<21:08,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  10%|████▍                                          | 43/450 [02:13<21:04,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  10%|████▌                                          | 44/450 [02:16<21:01,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  10%|████▋                                          | 45/450 [02:20<20:58,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  10%|████▊                                          | 46/450 [02:23<20:55,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  10%|████▉                                          | 47/450 [02:26<20:52,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  11%|█████                                          | 48/450 [02:29<20:49,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  11%|█████                                          | 49/450 [02:32<20:46,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  11%|█████▏                                         | 50/450 [02:35<20:42,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  11%|█████▎                                         | 51/450 [02:38<20:39,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  12%|█████▍                                         | 52/450 [02:41<20:35,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  12%|█████▌                                         | 53/450 [02:44<20:33,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  12%|█████▋                                         | 54/450 [02:48<20:29,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  12%|█████▋                                         | 55/450 [02:51<20:26,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  12%|█████▊                                         | 56/450 [02:54<20:23,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  13%|█████▉                                         | 57/450 [02:57<20:20,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  13%|██████                                         | 58/450 [03:00<20:17,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  13%|██████▏                                        | 59/450 [03:03<20:13,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  13%|██████▎                                        | 60/450 [03:06<20:10,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  14%|██████▎                                        | 61/450 [03:09<20:07,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  14%|██████▍                                        | 62/450 [03:12<20:05,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  14%|██████▌                                        | 63/450 [03:15<20:02,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  14%|██████▋                                        | 64/450 [03:19<19:59,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  14%|██████▊                                        | 65/450 [03:22<19:55,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  15%|██████▉                                        | 66/450 [03:25<19:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  15%|██████▉                                        | 67/450 [03:28<19:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  15%|███████                                        | 68/450 [03:31<19:45,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  15%|███████▏                                       | 69/450 [03:34<19:42,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  16%|███████▎                                       | 70/450 [03:37<19:39,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  16%|███████▍                                       | 71/450 [03:40<19:36,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  16%|███████▌                                       | 72/450 [03:43<19:33,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  16%|███████▌                                       | 73/450 [03:47<19:30,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  16%|███████▋                                       | 74/450 [03:50<19:27,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  17%|███████▊                                       | 75/450 [03:53<19:24,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  17%|███████▉                                       | 76/450 [03:56<19:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  17%|████████                                       | 77/450 [03:59<19:18,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  17%|████████▏                                      | 78/450 [04:02<19:14,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  18%|████████▎                                      | 79/450 [04:05<19:11,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  18%|████████▎                                      | 80/450 [04:08<19:08,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  18%|████████▍                                      | 81/450 [04:11<19:05,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  18%|████████▌                                      | 82/450 [04:14<19:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  18%|████████▋                                      | 83/450 [04:18<18:58,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  19%|████████▊                                      | 84/450 [04:21<18:55,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  19%|████████▉                                      | 85/450 [04:24<18:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  19%|████████▉                                      | 86/450 [04:27<18:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  19%|█████████                                      | 87/450 [04:30<18:47,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  20%|█████████▏                                     | 88/450 [04:33<18:44,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  20%|█████████▎                                     | 89/450 [04:36<18:41,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  20%|█████████▍                                     | 90/450 [04:39<18:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  20%|█████████▌                                     | 91/450 [04:42<18:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  20%|█████████▌                                     | 92/450 [04:46<18:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  21%|█████████▋                                     | 93/450 [04:49<18:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  21%|█████████▊                                     | 94/450 [04:52<18:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  21%|█████████▉                                     | 95/450 [04:55<18:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  21%|██████████                                     | 96/450 [04:58<18:18,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  22%|██████████▏                                    | 97/450 [05:01<18:15,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  22%|██████████▏                                    | 98/450 [05:04<18:12,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  22%|██████████▎                                    | 99/450 [05:07<18:09,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  22%|██████████▏                                   | 100/450 [05:10<18:05,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  22%|██████████▎                                   | 101/450 [05:13<18:02,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  23%|██████████▍                                   | 102/450 [05:17<17:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  23%|██████████▌                                   | 103/450 [05:20<17:56,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  23%|██████████▋                                   | 104/450 [05:23<17:53,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  23%|██████████▋                                   | 105/450 [05:26<17:50,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  24%|██████████▊                                   | 106/450 [05:29<17:47,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  24%|██████████▉                                   | 107/450 [05:32<17:44,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  24%|███████████                                   | 108/450 [05:35<17:41,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  24%|███████████▏                                  | 109/450 [05:38<17:38,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  24%|███████████▏                                  | 110/450 [05:41<17:35,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  25%|███████████▎                                  | 111/450 [05:44<17:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  25%|███████████▍                                  | 112/450 [05:48<17:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  25%|███████████▌                                  | 113/450 [05:51<17:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  25%|███████████▋                                  | 114/450 [05:54<17:23,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  26%|███████████▊                                  | 115/450 [05:57<17:19,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  26%|███████████▊                                  | 116/450 [06:00<17:16,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  26%|███████████▉                                  | 117/450 [06:03<17:13,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  26%|████████████                                  | 118/450 [06:06<17:10,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  26%|████████████▏                                 | 119/450 [06:09<17:07,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  27%|████████████▎                                 | 120/450 [06:12<17:04,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  27%|████████████▎                                 | 121/450 [06:16<17:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  27%|████████████▍                                 | 122/450 [06:19<16:58,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  27%|████████████▌                                 | 123/450 [06:22<16:55,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  28%|████████████▋                                 | 124/450 [06:25<16:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  28%|████████████▊                                 | 125/450 [06:28<16:48,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  28%|████████████▉                                 | 126/450 [06:31<16:45,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  28%|████████████▉                                 | 127/450 [06:34<16:42,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  28%|█████████████                                 | 128/450 [06:37<16:39,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  29%|█████████████▏                                | 129/450 [06:40<16:35,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  29%|█████████████▎                                | 130/450 [06:43<16:32,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  29%|█████████████▍                                | 131/450 [06:47<16:29,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  29%|█████████████▍                                | 132/450 [06:50<16:26,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  30%|█████████████▌                                | 133/450 [06:53<16:23,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  30%|█████████████▋                                | 134/450 [06:56<16:20,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  30%|█████████████▊                                | 135/450 [06:59<16:17,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  30%|█████████████▉                                | 136/450 [07:02<16:14,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  30%|██████████████                                | 137/450 [07:05<16:11,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  31%|██████████████                                | 138/450 [07:08<16:08,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  31%|██████████████▏                               | 139/450 [07:11<16:05,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  31%|██████████████▎                               | 140/450 [07:14<16:02,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  31%|██████████████▍                               | 141/450 [07:18<15:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  32%|██████████████▌                               | 142/450 [07:21<15:55,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  32%|██████████████▌                               | 143/450 [07:24<15:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  32%|██████████████▋                               | 144/450 [07:27<15:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  32%|██████████████▊                               | 145/450 [07:30<15:46,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  32%|██████████████▉                               | 146/450 [07:33<15:43,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  33%|███████████████                               | 147/450 [07:36<15:40,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  33%|███████████████▏                              | 148/450 [07:39<15:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  33%|███████████████▏                              | 149/450 [07:42<15:34,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  33%|███████████████▎                              | 150/450 [07:46<15:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  34%|███████████████▍                              | 151/450 [07:49<15:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  34%|███████████████▌                              | 152/450 [07:52<15:24,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  34%|███████████████▋                              | 153/450 [07:55<15:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  34%|███████████████▋                              | 154/450 [07:58<15:18,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  34%|███████████████▊                              | 155/450 [08:01<15:15,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  35%|███████████████▉                              | 156/450 [08:04<15:12,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  35%|████████████████                              | 157/450 [08:07<15:09,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  35%|████████████████▏                             | 158/450 [08:10<15:05,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  35%|████████████████▎                             | 159/450 [08:13<15:03,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  36%|████████████████▎                             | 160/450 [08:17<14:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  36%|████████████████▍                             | 161/450 [08:20<14:56,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  36%|████████████████▌                             | 162/450 [08:23<14:53,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  36%|████████████████▋                             | 163/450 [08:26<14:50,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  36%|████████████████▊                             | 164/450 [08:29<14:47,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  37%|████████████████▊                             | 165/450 [08:32<14:43,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  37%|████████████████▉                             | 166/450 [08:35<14:41,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  37%|█████████████████                             | 167/450 [08:38<14:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  37%|█████████████████▏                            | 168/450 [08:41<14:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  38%|█████████████████▎                            | 169/450 [08:44<14:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  38%|█████████████████▍                            | 170/450 [08:48<14:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  38%|█████████████████▍                            | 171/450 [08:51<14:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  38%|█████████████████▌                            | 172/450 [08:54<14:22,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  38%|█████████████████▋                            | 173/450 [08:57<14:19,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  39%|█████████████████▊                            | 174/450 [09:00<14:16,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  39%|█████████████████▉                            | 175/450 [09:03<14:13,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  39%|█████████████████▉                            | 176/450 [09:06<14:10,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  39%|██████████████████                            | 177/450 [09:09<14:07,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  40%|██████████████████▏                           | 178/450 [09:12<14:04,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  40%|██████████████████▎                           | 179/450 [09:15<14:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  40%|██████████████████▍                           | 180/450 [09:19<13:57,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  40%|██████████████████▌                           | 181/450 [09:22<13:54,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  40%|██████████████████▌                           | 182/450 [09:25<13:51,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  41%|██████████████████▋                           | 183/450 [09:28<13:48,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  41%|██████████████████▊                           | 184/450 [09:31<13:45,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  41%|██████████████████▉                           | 185/450 [09:34<13:42,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  41%|███████████████████                           | 186/450 [09:37<13:39,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  42%|███████████████████                           | 187/450 [09:40<13:36,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  42%|███████████████████▏                          | 188/450 [09:43<13:33,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  42%|███████████████████▎                          | 189/450 [09:47<13:29,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  42%|███████████████████▍                          | 190/450 [09:50<13:26,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  42%|███████████████████▌                          | 191/450 [09:53<13:23,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  43%|███████████████████▋                          | 192/450 [09:56<13:20,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  43%|███████████████████▋                          | 193/450 [09:59<13:17,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  43%|███████████████████▊                          | 194/450 [10:02<13:14,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  43%|███████████████████▉                          | 195/450 [10:05<13:11,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  44%|████████████████████                          | 196/450 [10:08<13:08,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  44%|████████████████████▏                         | 197/450 [10:11<13:05,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  44%|████████████████████▏                         | 198/450 [10:14<13:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  44%|████████████████████▎                         | 199/450 [10:18<12:58,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  44%|████████████████████▍                         | 200/450 [10:21<12:55,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  45%|████████████████████▌                         | 201/450 [10:24<12:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  45%|████████████████████▋                         | 202/450 [10:27<12:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  45%|████████████████████▊                         | 203/450 [10:30<12:46,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  45%|████████████████████▊                         | 204/450 [10:33<12:43,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  46%|████████████████████▉                         | 205/450 [10:36<12:40,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  46%|█████████████████████                         | 206/450 [10:39<12:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  46%|█████████████████████▏                        | 207/450 [10:42<12:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  46%|█████████████████████▎                        | 208/450 [10:45<12:30,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  46%|█████████████████████▎                        | 209/450 [10:49<12:27,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  47%|█████████████████████▍                        | 210/450 [10:52<12:24,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  47%|█████████████████████▌                        | 211/450 [10:55<12:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  47%|█████████████████████▋                        | 212/450 [10:58<12:18,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  47%|█████████████████████▊                        | 213/450 [11:01<12:15,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  48%|█████████████████████▉                        | 214/450 [11:04<12:12,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  48%|█████████████████████▉                        | 215/450 [11:07<12:09,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  48%|██████████████████████                        | 216/450 [11:10<12:06,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  48%|██████████████████████▏                       | 217/450 [11:13<12:02,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  48%|██████████████████████▎                       | 218/450 [11:17<11:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  49%|██████████████████████▍                       | 219/450 [11:20<11:56,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  49%|██████████████████████▍                       | 220/450 [11:23<11:53,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  49%|██████████████████████▌                       | 221/450 [11:26<11:50,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  49%|██████████████████████▋                       | 222/450 [11:29<11:47,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  50%|██████████████████████▊                       | 223/450 [11:32<11:44,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  50%|██████████████████████▉                       | 224/450 [11:35<11:41,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  50%|███████████████████████                       | 225/450 [11:38<11:38,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  50%|███████████████████████                       | 226/450 [11:41<11:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  50%|███████████████████████▏                      | 227/450 [11:44<11:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  51%|███████████████████████▎                      | 228/450 [11:48<11:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  51%|███████████████████████▍                      | 229/450 [11:51<11:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  51%|███████████████████████▌                      | 230/450 [11:54<11:22,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  51%|███████████████████████▌                      | 231/450 [11:57<11:19,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  52%|███████████████████████▋                      | 232/450 [12:00<11:16,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  52%|███████████████████████▊                      | 233/450 [12:03<11:13,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  52%|███████████████████████▉                      | 234/450 [12:06<11:10,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  52%|████████████████████████                      | 235/450 [12:09<11:07,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  52%|████████████████████████                      | 236/450 [12:12<11:03,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  53%|████████████████████████▏                     | 237/450 [12:15<11:00,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  53%|████████████████████████▎                     | 238/450 [12:19<10:57,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  53%|████████████████████████▍                     | 239/450 [12:22<10:54,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  53%|████████████████████████▌                     | 240/450 [12:25<10:52,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  54%|████████████████████████▋                     | 241/450 [12:28<10:49,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  54%|████████████████████████▋                     | 242/450 [12:31<10:45,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  54%|████████████████████████▊                     | 243/450 [12:34<10:42,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  54%|████████████████████████▉                     | 244/450 [12:37<10:39,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  54%|█████████████████████████                     | 245/450 [12:40<10:36,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  55%|█████████████████████████▏                    | 246/450 [12:43<10:33,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  55%|█████████████████████████▏                    | 247/450 [12:47<10:30,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  55%|█████████████████████████▎                    | 248/450 [12:50<10:27,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  55%|█████████████████████████▍                    | 249/450 [12:53<10:24,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  56%|█████████████████████████▌                    | 250/450 [12:56<10:20,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  56%|█████████████████████████▋                    | 251/450 [12:59<10:17,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  56%|█████████████████████████▊                    | 252/450 [13:02<10:14,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  56%|█████████████████████████▊                    | 253/450 [13:05<10:11,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  56%|█████████████████████████▉                    | 254/450 [13:08<10:08,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  57%|██████████████████████████                    | 255/450 [13:11<10:05,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  57%|██████████████████████████▏                   | 256/450 [13:14<10:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  57%|██████████████████████████▎                   | 257/450 [13:18<09:58,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  57%|██████████████████████████▎                   | 258/450 [13:21<09:55,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  58%|██████████████████████████▍                   | 259/450 [13:24<09:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  58%|██████████████████████████▌                   | 260/450 [13:27<09:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  58%|██████████████████████████▋                   | 261/450 [13:30<09:46,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  58%|██████████████████████████▊                   | 262/450 [13:33<09:43,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  58%|██████████████████████████▉                   | 263/450 [13:36<09:40,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  59%|██████████████████████████▉                   | 264/450 [13:39<09:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  59%|███████████████████████████                   | 265/450 [13:42<09:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  59%|███████████████████████████▏                  | 266/450 [13:45<09:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  59%|███████████████████████████▎                  | 267/450 [13:49<09:27,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  60%|███████████████████████████▍                  | 268/450 [13:52<09:24,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  60%|███████████████████████████▍                  | 269/450 [13:55<09:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  60%|███████████████████████████▌                  | 270/450 [13:58<09:30,  3.17s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  60%|███████████████████████████▋                  | 271/450 [14:01<09:23,  3.15s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  60%|███████████████████████████▊                  | 272/450 [14:04<09:18,  3.13s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  61%|███████████████████████████▉                  | 273/450 [14:07<09:13,  3.12s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  61%|████████████████████████████                  | 274/450 [14:11<09:08,  3.12s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  61%|████████████████████████████                  | 275/450 [14:14<09:04,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  61%|████████████████████████████▏                 | 276/450 [14:17<09:01,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  62%|████████████████████████████▎                 | 277/450 [14:20<08:57,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  62%|████████████████████████████▍                 | 278/450 [14:23<08:54,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  62%|████████████████████████████▌                 | 279/450 [14:26<08:50,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  62%|████████████████████████████▌                 | 280/450 [14:29<08:47,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  62%|████████████████████████████▋                 | 281/450 [14:32<08:44,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  63%|████████████████████████████▊                 | 282/450 [14:35<08:41,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  63%|████████████████████████████▉                 | 283/450 [14:38<08:38,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  63%|█████████████████████████████                 | 284/450 [14:42<08:35,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  63%|█████████████████████████████▏                | 285/450 [14:45<08:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  64%|█████████████████████████████▏                | 286/450 [14:48<08:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  64%|█████████████████████████████▎                | 287/450 [14:51<08:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  64%|█████████████████████████████▍                | 288/450 [14:54<08:22,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  64%|█████████████████████████████▌                | 289/450 [14:57<08:19,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  64%|█████████████████████████████▋                | 290/450 [15:00<08:16,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  65%|█████████████████████████████▋                | 291/450 [15:03<08:13,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  65%|█████████████████████████████▊                | 292/450 [15:06<08:10,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  65%|█████████████████████████████▉                | 293/450 [15:09<08:07,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  65%|██████████████████████████████                | 294/450 [15:13<08:04,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  66%|██████████████████████████████▏               | 295/450 [15:16<08:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  66%|██████████████████████████████▎               | 296/450 [15:19<07:57,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  66%|██████████████████████████████▎               | 297/450 [15:22<07:54,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  66%|██████████████████████████████▍               | 298/450 [15:25<07:51,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  66%|██████████████████████████████▌               | 299/450 [15:28<07:49,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  67%|██████████████████████████████▋               | 300/450 [15:31<07:45,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  67%|██████████████████████████████▊               | 301/450 [15:34<07:42,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  67%|██████████████████████████████▊               | 302/450 [15:37<07:39,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  67%|██████████████████████████████▉               | 303/450 [15:41<07:36,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  68%|███████████████████████████████               | 304/450 [15:44<07:33,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  68%|███████████████████████████████▏              | 305/450 [15:47<07:30,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  68%|███████████████████████████████▎              | 306/450 [15:50<07:26,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  68%|███████████████████████████████▍              | 307/450 [15:53<07:23,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  68%|███████████████████████████████▍              | 308/450 [15:56<07:20,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  69%|███████████████████████████████▌              | 309/450 [15:59<07:17,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  69%|███████████████████████████████▋              | 310/450 [16:02<07:14,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  69%|███████████████████████████████▊              | 311/450 [16:05<07:11,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  69%|███████████████████████████████▉              | 312/450 [16:08<07:08,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  70%|███████████████████████████████▉              | 313/450 [16:12<07:04,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  70%|████████████████████████████████              | 314/450 [16:15<07:01,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  70%|████████████████████████████████▏             | 315/450 [16:18<06:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  70%|████████████████████████████████▎             | 316/450 [16:21<06:55,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  70%|████████████████████████████████▍             | 317/450 [16:24<06:52,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  71%|████████████████████████████████▌             | 318/450 [16:27<06:49,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  71%|████████████████████████████████▌             | 319/450 [16:30<06:46,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  71%|████████████████████████████████▋             | 320/450 [16:33<06:43,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  71%|████████████████████████████████▊             | 321/450 [16:36<06:40,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  72%|████████████████████████████████▉             | 322/450 [16:39<06:37,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  72%|█████████████████████████████████             | 323/450 [16:43<06:34,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  72%|█████████████████████████████████             | 324/450 [16:46<06:31,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  72%|█████████████████████████████████▏            | 325/450 [16:49<06:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  72%|█████████████████████████████████▎            | 326/450 [16:52<06:24,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  73%|█████████████████████████████████▍            | 327/450 [16:55<06:21,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  73%|█████████████████████████████████▌            | 328/450 [16:58<06:18,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  73%|█████████████████████████████████▋            | 329/450 [17:01<06:15,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  73%|█████████████████████████████████▋            | 330/450 [17:04<06:12,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  74%|█████████████████████████████████▊            | 331/450 [17:07<06:09,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  74%|█████████████████████████████████▉            | 332/450 [17:11<06:06,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  74%|██████████████████████████████████            | 333/450 [17:14<06:03,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  74%|██████████████████████████████████▏           | 334/450 [17:17<05:59,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  74%|██████████████████████████████████▏           | 335/450 [17:20<05:56,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  75%|██████████████████████████████████▎           | 336/450 [17:23<05:53,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  75%|██████████████████████████████████▍           | 337/450 [17:26<05:50,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  75%|██████████████████████████████████▌           | 338/450 [17:29<05:47,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  75%|██████████████████████████████████▋           | 339/450 [17:32<05:44,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  76%|██████████████████████████████████▊           | 340/450 [17:35<05:41,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  76%|██████████████████████████████████▊           | 341/450 [17:38<05:38,  3.11s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  76%|██████████████████████████████████▉           | 342/450 [17:42<05:35,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  76%|███████████████████████████████████           | 343/450 [17:45<05:32,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  76%|███████████████████████████████████▏          | 344/450 [17:48<05:28,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  77%|███████████████████████████████████▎          | 345/450 [17:51<05:25,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  77%|███████████████████████████████████▎          | 346/450 [17:54<05:22,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

Collect means:  77%|███████████████████████████████████▍          | 347/450 [17:57<05:19,  3.10s/it]

  0%|          | 0/20 [00:00<?, ?it/s]

# Phase B – Mean Activation Ablation
This phase tests how each layer behaves when its activations are replaced with the mean from Phase A.
- Registers new hooks that overwrite each FFN's output with its mean at time t.
- Generates baseline and mean-ablated images for the first 5 prompts
- Saves both versions for visual comparison.
- Displays side-by-side comparison using matplotlib.
Output Directory: exp3/results/

In [None]:
# ==============================================================
#  PHASE B – baselines first, then mean‑ablation, then clean‑up
# ==============================================================

N_PROMPTS_TO_SAVE = 5
RESULTS_DIR = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=35):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# 1) clean baselines
for i, prompt in enumerate(PROMPTS[:N_PROMPTS_TO_SAVE]):
    fname = f"baseline_{i:02d}_{_slug(prompt)}.png"
    generate(prompt, SEED+i, fname)         # generate() has NO hook toggling
    (OUTDIR/fname).replace(RESULTS_DIR/fname)
print("Baselines saved.")

# 2) attach mean‑ablation hooks
def make_mean_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = means[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

H_MEAN = [blk.ff.register_forward_hook(make_mean_hook(lvl))
          for lvl, blk in enumerate(pipe.transformer.transformer_blocks)]

# 3) mean‑ablated images
for i, prompt in enumerate(PROMPTS[:N_PROMPTS_TO_SAVE]):
    fname = f"mean_{i:02d}_{_slug(prompt)}.png"
    generate(prompt, SEED+i, fname)          # hooks are active
    (OUTDIR/fname).replace(RESULTS_DIR/fname)
print("Mean‑ablation images saved.")

# 4) preview first pair
fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].imshow(Image.open(RESULTS_DIR / "baseline_00_" + _slug(PROMPTS[0]) + ".png")); ax[0].set_title("Baseline"); ax[0].axis("off")
ax[1].imshow(Image.open(RESULTS_DIR / "mean_00_" + _slug(PROMPTS[0]) + ".png"));     ax[1].set_title("Mean Ablation"); ax[1].axis("off")
plt.tight_layout(); plt.show()

# 5) remove mean‑ablation hooks
for h in H_MEAN: h.remove()
H_MEAN.clear()
print_hook_counts()   # optional debug print


# Phase C – Left↔Right Swap
Here, the mean maps are spatially manipulated by swapping left and right halves, simulating altered visual attention.
- Previous mean hooks are removed.
- New swap maps are created (L×T spatial tensors with left-right halves flipped).
- Hooks overwrite the FFN outputs with these swapped maps.
- Two images are generated for a single test prompt: baseline vs. swapped.
- Shows the side-by-side result.
Output Directory: exp3/results/

In [None]:
# ==============================================================
#  PHASE C – Left ↔ Right swap for a batch of prompts
# ==============================================================

import re, tqdm, matplotlib.pyplot as plt
from PIL import Image
import torch, gc

# --------‑ user controls --------------------------------------
N_PROMPTS_C   = 5                 # how many prompts to process
RESULTS_DIR   = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=40):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# --------‑ 0. make sure model is hook‑free --------------------
for lst in ("H_MEAN", "H_SWAP"):
    if lst in globals():
        for h in globals()[lst]:
            h.remove()
        globals()[lst].clear()

# --------‑ 1. pre‑compute left↔right‑swapped mean maps --------
swap_maps = []
for level_row in means:        # over transformer levels
    new_row = []
    for m in level_row:        # m : (C,H,W)
        C, H, W = m.shape; split = W // 2
        s = m.clone()
        s[:, :, :split], s[:, :, split:] = s[:, :, split:], s[:, :, :split]
        new_row.append(s)
    swap_maps.append(new_row)

# --------‑ 2. hook helpers ------------------------------------
H_SWAP = []
def make_swap_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = swap_maps[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

def attach_swap_hooks():
    global H_SWAP
    H_SWAP = [
        blk.ff.register_forward_hook(make_swap_hook(lvl))
        for lvl, blk in enumerate(pipe.transformer.transformer_blocks)
    ]

def detach_swap_hooks():
    for h in H_SWAP: h.remove()
    H_SWAP.clear()

# --------‑ 3. run over prompts -------------------------------
for i, prompt in enumerate(tqdm.tqdm(PROMPTS[:N_PROMPTS_C], desc="Phase C")):
    base_png = f"baseline_swap_{i:02d}_{_slug(prompt)}.png"
    var_png  = f"swap_ablate_{i:02d}_{_slug(prompt)}.png"

    # -- baseline (no hooks) --
    detach_swap_hooks()
    generate(prompt, SEED+i, base_png)
    (OUTDIR/base_png).replace(RESULTS_DIR/base_png)

    # -- variant (attach hooks, run, detach) --
    attach_swap_hooks()
    generate(prompt, SEED+i, var_png)
    (OUTDIR/var_png).replace(RESULTS_DIR/var_png)
    detach_swap_hooks()

    # optional VRAM tidy‑up for long loops
    gc.collect(); torch.cuda.empty_cache()

print(f"Saved {2*N_PROMPTS_C} images → {RESULTS_DIR.resolve()}")

# --------‑ 4. preview first pair ------------------------------
img_b = Image.open(RESULTS_DIR / f"baseline_swap_00_{_slug(PROMPTS[0])}.png")
img_v = Image.open(RESULTS_DIR / f"swap_ablate_00_{_slug(PROMPTS[0])}.png")

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_b); ax[0].set_title("Baseline");        ax[0].axis("off")
ax[1].imshow(img_v); ax[1].set_title("Left↔Right swap"); ax[1].axis("off")
fig.suptitle(PROMPTS[0], wrap=True); plt.tight_layout(); plt.show()


# Phase D – Map Transform Variants
This section generalizes the spatial transformations applied to the FFN maps using various strategies like:
- "first_token" – replicate the first token over the map
- "swap_tb" – swap top and bottom halves
- "shuffle" – random shuffle
- "constant" – constant average value
- build_maps() creates the transformed maps.
- attach_map_hooks() binds them to FFNs.
- Baseline and transformed images are generated and saved.
- Displays visual comparison.
Output Directory: exp3/

In [None]:
# ==============================================================
#  build_maps() supporting row/col replication
# ==============================================================
def build_maps(means, transform="first_token", n_tokens=3, seed=0):
    """
    returns a list-of-lists [level][t] of modified maps (C,H,W).

    Parameters
    ----------
    transform : str
        one of {"first_token", "first_n_tokens", "first_row", "first_col",
                "swap_lr", "swap_tb", "swap_quadrants", "shuffle", "constant"}
    n_tokens : int
        used only for 'first_n_tokens' – how many top‑left tokens to average.
    seed : int
        used only for 'shuffle'.
    """
    out = []
    rng = torch.Generator(device="cpu").manual_seed(seed)

    for level_row in means:
        new_row = []
        for m in level_row:                        # m : (C,H,W)
            C, H, W = m.shape
            mm = m.clone()

            if transform == "first_token":
                tok = mm[:, 0, 0].unsqueeze(-1).unsqueeze(-1)  # (C,1,1)
                mm = tok.expand_as(mm)

            elif transform == "first_n_tokens":
                # take n_tokens from the top‑left corner row‑major
                idx = [(r, c) for r in range(H) for c in range(W)][:n_tokens]
                mean_tok = torch.stack([mm[:, r, c] for r, c in idx], dim=0).mean(0)
                mm = mean_tok.view(C, 1, 1).expand_as(mm)

            elif transform == "first_row":
                row_mean = mm[:, 0, :].mean(1, keepdim=True)   # (C,1)
                mm = row_mean.unsqueeze(2).expand_as(mm)

            elif transform == "first_col":
                col_mean = mm[:, :, 0].mean(1, keepdim=True)   # (C,1)
                mm = col_mean.unsqueeze(1).expand_as(mm)

            elif transform == "swap_lr":
                split = W // 2
                mm[:, :, :split], mm[:, :, split:] = mm[:, :, split:], mm[:, :, :split]

            elif transform == "swap_tb":
                split = H // 2
                mm[:, :split, :], mm[:, split:, :] = mm[:, split:, :], mm[:, :split, :]

            elif transform == "swap_quadrants":
                h2, w2 = H // 2, W // 2
                tl, tr = mm[:, :h2, :w2].clone(), mm[:, :h2, w2:].clone()
                bl, br = mm[:, h2:, :w2].clone(), mm[:, h2:, w2:].clone()
                mm[:, :h2, :w2], mm[:, :h2, w2:], mm[:, h2:, w2:], mm[:, h2:, :w2] = \
                    bl, tl, tr, br

            elif transform == "shuffle":
                idx = torch.randperm(H * W, generator=rng)
                mm = mm.flatten(1)[:, idx].view(C, H, W)

            elif transform == "constant":
                mm[:] = mm.mean()

            else:
                raise ValueError(f"Unknown transform '{transform}'")

            new_row.append(mm)
        out.append(new_row)
    return out


In [None]:
# ==============================================================
#  HELPER: attach hooks for a given transformed map set
# ==============================================================
def attach_map_hooks(map_set):
    """Remove any existing H_SWAP hooks and attach new ones for map_set."""
    global H_SWAP
    # remove old swap hooks if they exist
    if "H_SWAP" in globals():
        for h in H_SWAP:
            h.remove()
    H_SWAP = []

    def make_hook(level):
        def _hook(_, __, out):
            t = step_counter["t"]
            if t >= N_STEPS:
                return out
            mp = map_set[level][t].to(out.device).to(out.dtype)
            return mp.unsqueeze(0).expand_as(out)   # (B, C, H, W)
        return _hook

    for lvl, blk in enumerate(pipe.transformer.transformer_blocks):
        H_SWAP.append(blk.ff.register_forward_hook(make_hook(lvl)))


In [None]:
# ==============================================================
#  PHASE D – variant map transform for a batch of prompts
# ==============================================================

import re, tqdm, matplotlib.pyplot as plt
from PIL import Image
import torch, gc

# ------------ experiment parameters ---------------------------
transform_choice = "swap_tb"          # e.g. "first_token", "shuffle", …
n_tokens   = 3                        # only for "first_n_tokens"
rng_seed   = 123                      # only for "shuffle"
N_PROMPTS_D = 5                       # how many prompts to run

RESULTS_DIR = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=40):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# ---------- 0. make sure model starts clean -------------------
for lst in ("H_SWAP", "H_MEAN"):
    if lst in globals():
        for h in globals()[lst]:
            h.remove()
        globals()[lst].clear()

# ---------- 1. build variant map set once --------------------
variant_maps = build_maps(
    means,
    transform=transform_choice,
    n_tokens=n_tokens,
    seed=rng_seed
)

# ---------- 2. create hook factory ---------------------------
def make_variant_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = variant_maps[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

# helper to attach & detach variant hooks
def attach_swap_hooks():
    global H_SWAP
    H_SWAP = [
        blk.ff.register_forward_hook(make_variant_hook(lvl))
        for lvl, blk in enumerate(pipe.transformer.transformer_blocks)
    ]
def detach_swap_hooks():
    for h in globals().get("H_SWAP", []):
        h.remove()
    H_SWAP.clear()

# ---------- 3. run through prompts ---------------------------
for i, prompt in enumerate(tqdm.tqdm(PROMPTS[:N_PROMPTS_D], desc="Phase D")):
    base_png  = f"baseline_{transform_choice}_{i:02d}_{_slug(prompt)}.png"
    var_png   = f"{transform_choice}_{i:02d}_{_slug(prompt)}.png"

    # --- baseline ---
    detach_swap_hooks()                            # ensure no hooks
    generate(prompt, SEED+i, base_png)
    (OUTDIR/base_png).replace(RESULTS_DIR/base_png)

    # --- variant ---
    attach_swap_hooks()                            # hooks ON
    generate(prompt, SEED+i, var_png)
    (OUTDIR/var_png).replace(RESULTS_DIR/var_png)

    detach_swap_hooks()                            # clean for next loop

    # optional GPU tidy‑up if many prompts
    gc.collect(); torch.cuda.empty_cache()

print(f"Saved {2*N_PROMPTS_D} images → {RESULTS_DIR.resolve()}")

# ---------- 4. quick preview of the first pair ---------------
img_b = Image.open(RESULTS_DIR / f"baseline_{transform_choice}_00_{_slug(PROMPTS[0])}.png")
img_v = Image.open(RESULTS_DIR / f"{transform_choice}_00_{_slug(PROMPTS[0])}.png")

fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].imshow(img_b); ax[0].set_title("Baseline");           ax[0].axis("off")
ax[1].imshow(img_v); ax[1].set_title(transform_choice);     ax[1].axis("off")
fig.suptitle(PROMPTS[0], wrap=True); plt.tight_layout(); plt.show()
