- Defines output directory (exp3/)
- Sets model ID, precision (fp16), number of diffusion steps (N_STEPS), transformer levels (N_LEVELS), and random seed.
- Loads the SANA model with weights from the Hugging Face repo.
- Pushes the model and VAE/text encoder to GPU with the right precision.

In [1]:
# ==============================================================
#  imports, global config, and pipeline load
#  ─────────────────────────────────────────────────────────────
#  • Sets basic experiment parameters (model id, #steps…)
#  • Loads the SANA pipeline on GPU, casting weights to fp16/bf16
# ==============================================================
import math, itertools, pathlib
import torch, tqdm 
from diffusers import SanaPipeline
import re, matplotlib.pyplot as plt
from PIL import Image

  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "/home/galkesten/miniconda3/envs/sana/lib/python3.10/site-packages/xformers/__init__.py", line 57, in _is_triton_available
    import triton  # noqa
  File "/home/galkesten/miniconda3/envs/sana/lib/python3.10/site-packages/triton/__init__.py", line 8, in <module>
    from .runtime import (
  File "/home/galkesten/miniconda3/envs/sana/lib/python3.10/site-packages/triton/runtime/__init__.py", line 1, in <module>
    from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
  File "/home/galkesten/miniconda3/envs/sana/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 9, in <module>
    from ..testing import do_bench, do_bench_cudagraph
  File "/home/galkesten/miniconda3/envs/sana/lib/python3.10/site-packages/triton/testing.

In [2]:
# ---------- directory where all outputs will be saved ----------
OUTDIR = pathlib.Path("exp3")             # folder name
OUTDIR.mkdir(exist_ok=True)                  # create if it doesn't exist

# ---------- experiment hyper‑parameters -----------------------
MODEL_ID   = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
DTYPE      = torch.float16                   # weights dtype for GPU
DEVICE     = "cuda"                          # or "mps" / "cpu" if needed
N_STEPS    = 20                              # diffusion steps  (T)
N_LEVELS   = 20                              # transformer blocks (L)
SEED       = 42                              # reproducible randomness
PROMPT_TXT = "sana_position_prompts.txt"         # text file – one prompt / line


# ---------- load SANA pipeline --------------------------------
pipe = (
    SanaPipeline           # diffusers class
    .from_pretrained(MODEL_ID, variant="fp16", torch_dtype=DTYPE)
    .to(DEVICE)            # push entire pipeline to GPU
)

pipe.vae         .to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.scheduler.set_timesteps(N_STEPS)

Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

text_encoder/model.fp16-00001-of-00002.s(…):   0%|          | 0.00/4.99G [00:00<?, ?B/s]

vae/diffusion_pytorch_model.fp16.safeten(…):   0%|          | 0.00/1.25G [00:00<?, ?B/s]

transformer/diffusion_pytorch_model.fp16(…):   0%|          | 0.00/3.21G [00:00<?, ?B/s]

text_encoder/model.fp16-00002-of-00002.s(…):   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

This part prepares tools for running and tracking experiments:
- Tracks which timestep (t) the model is currently processing via a forward_pre_hook.
- Loads text prompts from the file sana_position_prompts.txt.
- Defines a generate() helper function to run the pipeline and save results to disk.

In [3]:
# ==============================================================
#  utility hooks / helpers
#  ─────────────────────────────────────────────────────────────
#  • step_counter: lets us know which diffusion step (t) we’re on
#  • PROMPTS: list of prompts loaded from the text file
#  • generate(): helper to run the pipeline and save an image
# ==============================================================
# -------- 1. step counter (pre‑hook on entire transformer) ----
step_counter = {"t": -1}                     # wrapped in dict for mutability

def _count_steps(_, __):
    """Executed *before* each denoising step; increments global t."""
    step_counter["t"] += 1

pipe.transformer.register_forward_pre_hook(_count_steps)

# -------- 2. load prompts -------------------------------------
with open(PROMPT_TXT) as f:
    PROMPTS = [ln.strip() for ln in f if ln.strip()]
print("Loaded", len(PROMPTS), "prompts")

# -------- 3. generate & save helper ---------------------------
def generate(prompt: str,
             seed: int,
             filename: str,
             ):
    """
    Runs the pipeline once and writes PNG → OUTDIR/filename
    If hooks_enabled=False the model runs normally; otherwise
    already‑registered hooks (mean or swap) will manipulate outputs.
    """
    step_counter["t"] = -1                   # reset timestep counter

    # Deterministic generator for reproducibility
    gen = torch.Generator(device=DEVICE).manual_seed(seed)

    # Run diffusion
    img = pipe(
        prompt,
        height=1024, width=1024,
        num_inference_steps=N_STEPS,
        guidance_scale=5.0,                 # classifier‑free guidance weight
        generator=gen
    )[0][0]                                    # diffusers returns (images, …)
    img.save(OUTDIR / filename)
    return img                              # return PIL.Image for display


Loaded 453 prompts


# Phase A – Mean Activation Collection
- This phase records the average output of every FFN block at every timestep across all prompts.
- Hooks are registered to all FFN modules.
- For each prompt, we run the pipeline and record FFN outputs.
- After all prompts, the average activation per layer and timestep is computed.
- Final means are saved as ffn_mean_maps.pt.--> exp3/ffn_mean_maps.pt

In [4]:
def print_hook_counts():
    counts = [
        len(mod._forward_hooks)          # PyTorch stores existing hooks here
        for mod in pipe.transformer.transformer_blocks
    ]
    print("Hooks per block:", counts)
    print("Total hooks    :", sum(counts))


In [8]:
# ==============================================================
#  Phase A: collect ⟨level, t⟩ mean maps
#  ─────────────────────────────────────────────────────────────
#  • For each prompt, record the *output* of every FFN block
#    at every timestep; average across prompts to build a
#    stable “mean map” tensor.
# ==============================================================
# ---------- 1. storage containers -----------------------------
means = [[None] * N_STEPS for _ in range(N_LEVELS)]   # 0‑based: 0‥N_STEPS‑1
n_prompts_seen = 0

def make_record_hook(level: int):
    """
    Returns a hook that, for its transformer *level*, records the
    average token vector produced by the FFN at every diffusion step.
    """
    def _hook(_, __, out):
        print(out.shape)
        t = step_counter["t"]               # current denoising step (0‑based)

        # ------- NEW: skip the *extra* call after step N_STEPS‑1 ----------
        if t >= N_STEPS:
            return out                      # do nothing, just pass output on

        avg_tokens = out.detach().float().mean(0)  # (seq_len, hidden)
        print(avg_tokens.shape)
        # accumulate
        if means[level][t] is None:
            means[level][t] = avg_tokens
        else:
            means[level][t] += avg_tokens
        return out                          # always return the original output
    return _hook

# ---------- 2. attach recorders to each FFN --------------------
H_REC = [
    blk.ff.register_forward_hook(make_record_hook(lvl))
    for   lvl, blk in enumerate(pipe.transformer.transformer_blocks)
]
print_hook_counts()
# ---------- 3. run pipeline over *all* prompts -----------------
for prompt in tqdm.tqdm(PROMPTS, desc="Collect means"):
    step_counter["t"] = -1                                # reset
    seed_this_run = SEED + n_prompts_seen                 # new seed each run
    _ = pipe(
        prompt,
        height=1024, width=1024,
        num_inference_steps=N_STEPS,
        guidance_scale=5.0,
        generator=torch.Generator(device=DEVICE).manual_seed(seed_this_run)
    )
    n_prompts_seen += 1

# ---------- 4. finalise means and clean up ---------------------
for h in H_REC:
    h.remove()                                            # detach hooks
print_hook_counts()
H_MEAN.clear()
print("Phase B hooks removed; model is clean for Phase C.")
# divide by #prompts to get the actual mean
means = [[m / n_prompts_seen for m in row] for row in means]

torch.save(means, OUTDIR / "ffn_mean_maps.pt")
print("Mean maps saved →", OUTDIR / "ffn_mean_maps.pt")

Hooks per block: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Total hooks    : 0


Collect means:   0%|          | 0/453 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

torch.Size([2, 2240, 32, 32])
torch.Size([2, 2240, 32, 32])
tensor([[[-1.1587, -0.7791, -0.5500,  ..., -1.7427, -0.6382, -0.7141],
         [-1.8701,  0.2274,  0.3870,  ...,  0.3434,  0.3674, -1.0171],
         [-1.4390, -0.1783,  0.0144,  ..., -0.1049,  0.7380, -1.3589],
         ...,
         [-1.1670,  0.6030,  0.2031,  ...,  0.8696,  1.0269, -1.2466],
         [-2.2441,  0.3792,  0.0813,  ...,  0.8599,  0.7861, -1.1982],
         [-2.8193, -0.5774,  0.1285,  ...,  0.2087, -0.5288, -1.5366]],

        [[ 0.9160,  0.2396,  1.2173,  ...,  0.6240,  0.4344,  0.1884],
         [ 1.8223, -0.4266, -0.7600,  ..., -0.4991, -0.9399,  0.5029],
         [ 0.3053,  0.1715, -0.4572,  ...,  0.2915, -0.5502, -0.3055],
         ...,
         [-0.1031, -1.1934,  0.1425,  ..., -0.3982, -0.4521, -0.7144],
         [-1.2153,  0.2515,  0.2111,  ..., -0.7776, -0.3843,  0.5913],
         [-0.5469, -0.7197, -0.2647,  ..., -0.2552,  0.4257,  0.2263]],

        [[ 1.4360, -0.4835,  0.5386,  ...,  2.0449,  0.2

Collect means:   0%|          | 0/453 [00:02<?, ?it/s]

tensor([[[ -6.2168, -13.6562, -19.9609,  ..., -11.4375, -20.2578, -21.2031],
         [ -7.6289,  -1.0020,  -9.8164,  ...,  -9.8672,  -8.4961, -17.2031],
         [-16.7891,  -2.1250, -16.3906,  ..., -10.0820,  -8.2891, -11.3594],
         ...,
         [-14.5039,  -2.2627, -13.1133,  ...,  -6.3438,  -8.7656, -12.8945],
         [  2.6914,   7.8789, -14.8281,  ...,  -7.0078,  -7.2324, -22.6562],
         [  6.2266,   0.5043,  -0.7498,  ...,   3.6357,  -1.9219,  -8.6562]],

        [[  9.3164,  -3.2607,   2.5986,  ...,  -0.3719,   6.9609,  -2.7305],
         [ 15.9492,  -6.5898,  -8.1348,  ...,   2.1079,   1.1616,   1.6416],
         [  6.6133,  -7.4746,  -4.3691,  ...,   5.3496,   3.9961,   4.2021],
         ...,
         [  3.4248,  -5.2891,  -8.9141,  ...,  -7.4961,   2.5078,  -2.7568],
         [  9.6523,   8.4277,  -7.0020,  ..., -13.5586,  -3.8633,  -0.0577],
         [ 10.6523,  -8.3359, -11.3555,  ..., -12.6953, -13.7891, -20.9141]],

        [[ -2.7324, -14.7188,  -7.2285,  ...




KeyboardInterrupt: 

# Phase B – Mean Activation Ablation
This phase tests how each layer behaves when its activations are replaced with the mean from Phase A.
- Registers new hooks that overwrite each FFN's output with its mean at time t.
- Generates baseline and mean-ablated images for the first 5 prompts
- Saves both versions for visual comparison.
- Displays side-by-side comparison using matplotlib.
Output Directory: exp3/results/

In [None]:
# ==============================================================
#  PHASE B – baselines first, then mean‑ablation, then clean‑up
# ==============================================================

N_PROMPTS_TO_SAVE = 5
RESULTS_DIR = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=35):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# 1) clean baselines
for i, prompt in enumerate(PROMPTS[:N_PROMPTS_TO_SAVE]):
    fname = f"baseline_{i:02d}_{_slug(prompt)}.png"
    generate(prompt, SEED+i, fname)         # generate() has NO hook toggling
    (OUTDIR/fname).replace(RESULTS_DIR/fname)
print("Baselines saved.")

# 2) attach mean‑ablation hooks
def make_mean_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = means[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

H_MEAN = [blk.ff.register_forward_hook(make_mean_hook(lvl))
          for lvl, blk in enumerate(pipe.transformer.transformer_blocks)]

# 3) mean‑ablated images
for i, prompt in enumerate(PROMPTS[:N_PROMPTS_TO_SAVE]):
    fname = f"mean_{i:02d}_{_slug(prompt)}.png"
    generate(prompt, SEED+i, fname)          # hooks are active
    (OUTDIR/fname).replace(RESULTS_DIR/fname)
print("Mean‑ablation images saved.")

# 4) preview first pair
fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].imshow(Image.open(RESULTS_DIR / "baseline_00_" + _slug(PROMPTS[0]) + ".png")); ax[0].set_title("Baseline"); ax[0].axis("off")
ax[1].imshow(Image.open(RESULTS_DIR / "mean_00_" + _slug(PROMPTS[0]) + ".png"));     ax[1].set_title("Mean Ablation"); ax[1].axis("off")
plt.tight_layout(); plt.show()

# 5) remove mean‑ablation hooks
for h in H_MEAN: h.remove()
H_MEAN.clear()
print_hook_counts()   # optional debug print


# Phase C – Left↔Right Swap
Here, the mean maps are spatially manipulated by swapping left and right halves, simulating altered visual attention.
- Previous mean hooks are removed.
- New swap maps are created (L×T spatial tensors with left-right halves flipped).
- Hooks overwrite the FFN outputs with these swapped maps.
- Two images are generated for a single test prompt: baseline vs. swapped.
- Shows the side-by-side result.
Output Directory: exp3/results/

In [None]:
# ==============================================================
#  PHASE C – Left ↔ Right swap for a batch of prompts
# ==============================================================

import re, tqdm, matplotlib.pyplot as plt
from PIL import Image
import torch, gc

# --------‑ user controls --------------------------------------
N_PROMPTS_C   = 5                 # how many prompts to process
RESULTS_DIR   = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=40):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# --------‑ 0. make sure model is hook‑free --------------------
for lst in ("H_MEAN", "H_SWAP"):
    if lst in globals():
        for h in globals()[lst]:
            h.remove()
        globals()[lst].clear()

# --------‑ 1. pre‑compute left↔right‑swapped mean maps --------
swap_maps = []
for level_row in means:        # over transformer levels
    new_row = []
    for m in level_row:        # m : (C,H,W)
        C, H, W = m.shape; split = W // 2
        s = m.clone()
        s[:, :, :split], s[:, :, split:] = s[:, :, split:], s[:, :, :split]
        new_row.append(s)
    swap_maps.append(new_row)

# --------‑ 2. hook helpers ------------------------------------
H_SWAP = []
def make_swap_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = swap_maps[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

def attach_swap_hooks():
    global H_SWAP
    H_SWAP = [
        blk.ff.register_forward_hook(make_swap_hook(lvl))
        for lvl, blk in enumerate(pipe.transformer.transformer_blocks)
    ]

def detach_swap_hooks():
    for h in H_SWAP: h.remove()
    H_SWAP.clear()

# --------‑ 3. run over prompts -------------------------------
for i, prompt in enumerate(tqdm.tqdm(PROMPTS[:N_PROMPTS_C], desc="Phase C")):
    base_png = f"baseline_swap_{i:02d}_{_slug(prompt)}.png"
    var_png  = f"swap_ablate_{i:02d}_{_slug(prompt)}.png"

    # -- baseline (no hooks) --
    detach_swap_hooks()
    generate(prompt, SEED+i, base_png)
    (OUTDIR/base_png).replace(RESULTS_DIR/base_png)

    # -- variant (attach hooks, run, detach) --
    attach_swap_hooks()
    generate(prompt, SEED+i, var_png)
    (OUTDIR/var_png).replace(RESULTS_DIR/var_png)
    detach_swap_hooks()

    # optional VRAM tidy‑up for long loops
    gc.collect(); torch.cuda.empty_cache()

print(f"Saved {2*N_PROMPTS_C} images → {RESULTS_DIR.resolve()}")

# --------‑ 4. preview first pair ------------------------------
img_b = Image.open(RESULTS_DIR / f"baseline_swap_00_{_slug(PROMPTS[0])}.png")
img_v = Image.open(RESULTS_DIR / f"swap_ablate_00_{_slug(PROMPTS[0])}.png")

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_b); ax[0].set_title("Baseline");        ax[0].axis("off")
ax[1].imshow(img_v); ax[1].set_title("Left↔Right swap"); ax[1].axis("off")
fig.suptitle(PROMPTS[0], wrap=True); plt.tight_layout(); plt.show()


# Phase D – Map Transform Variants
This section generalizes the spatial transformations applied to the FFN maps using various strategies like:
- "first_token" – replicate the first token over the map
- "swap_tb" – swap top and bottom halves
- "shuffle" – random shuffle
- "constant" – constant average value
- build_maps() creates the transformed maps.
- attach_map_hooks() binds them to FFNs.
- Baseline and transformed images are generated and saved.
- Displays visual comparison.
Output Directory: exp3/

In [None]:
# ==============================================================
#  build_maps() supporting row/col replication
# ==============================================================
def build_maps(means, transform="first_token", n_tokens=3, seed=0):
    """
    returns a list-of-lists [level][t] of modified maps (C,H,W).

    Parameters
    ----------
    transform : str
        one of {"first_token", "first_n_tokens", "first_row", "first_col",
                "swap_lr", "swap_tb", "swap_quadrants", "shuffle", "constant"}
    n_tokens : int
        used only for 'first_n_tokens' – how many top‑left tokens to average.
    seed : int
        used only for 'shuffle'.
    """
    out = []
    rng = torch.Generator(device="cpu").manual_seed(seed)

    for level_row in means:
        new_row = []
        for m in level_row:                        # m : (C,H,W)
            C, H, W = m.shape
            mm = m.clone()

            if transform == "first_token":
                tok = mm[:, 0, 0].unsqueeze(-1).unsqueeze(-1)  # (C,1,1)
                mm = tok.expand_as(mm)

            elif transform == "first_n_tokens":
                # take n_tokens from the top‑left corner row‑major
                idx = [(r, c) for r in range(H) for c in range(W)][:n_tokens]
                mean_tok = torch.stack([mm[:, r, c] for r, c in idx], dim=0).mean(0)
                mm = mean_tok.view(C, 1, 1).expand_as(mm)

            elif transform == "first_row":
                row_mean = mm[:, 0, :].mean(1, keepdim=True)   # (C,1)
                mm = row_mean.unsqueeze(2).expand_as(mm)

            elif transform == "first_col":
                col_mean = mm[:, :, 0].mean(1, keepdim=True)   # (C,1)
                mm = col_mean.unsqueeze(1).expand_as(mm)

            elif transform == "swap_lr":
                split = W // 2
                mm[:, :, :split], mm[:, :, split:] = mm[:, :, split:], mm[:, :, :split]

            elif transform == "swap_tb":
                split = H // 2
                mm[:, :split, :], mm[:, split:, :] = mm[:, split:, :], mm[:, :split, :]

            elif transform == "swap_quadrants":
                h2, w2 = H // 2, W // 2
                tl, tr = mm[:, :h2, :w2].clone(), mm[:, :h2, w2:].clone()
                bl, br = mm[:, h2:, :w2].clone(), mm[:, h2:, w2:].clone()
                mm[:, :h2, :w2], mm[:, :h2, w2:], mm[:, h2:, w2:], mm[:, h2:, :w2] = \
                    bl, tl, tr, br

            elif transform == "shuffle":
                idx = torch.randperm(H * W, generator=rng)
                mm = mm.flatten(1)[:, idx].view(C, H, W)

            elif transform == "constant":
                mm[:] = mm.mean()

            else:
                raise ValueError(f"Unknown transform '{transform}'")

            new_row.append(mm)
        out.append(new_row)
    return out


In [None]:
# ==============================================================
#  HELPER: attach hooks for a given transformed map set
# ==============================================================
def attach_map_hooks(map_set):
    """Remove any existing H_SWAP hooks and attach new ones for map_set."""
    global H_SWAP
    # remove old swap hooks if they exist
    if "H_SWAP" in globals():
        for h in H_SWAP:
            h.remove()
    H_SWAP = []

    def make_hook(level):
        def _hook(_, __, out):
            t = step_counter["t"]
            if t >= N_STEPS:
                return out
            mp = map_set[level][t].to(out.device).to(out.dtype)
            return mp.unsqueeze(0).expand_as(out)   # (B, C, H, W)
        return _hook

    for lvl, blk in enumerate(pipe.transformer.transformer_blocks):
        H_SWAP.append(blk.ff.register_forward_hook(make_hook(lvl)))


In [None]:
# ==============================================================
#  PHASE D – variant map transform for a batch of prompts
# ==============================================================

import re, tqdm, matplotlib.pyplot as plt
from PIL import Image
import torch, gc

# ------------ experiment parameters ---------------------------
transform_choice = "swap_tb"          # e.g. "first_token", "shuffle", …
n_tokens   = 3                        # only for "first_n_tokens"
rng_seed   = 123                      # only for "shuffle"
N_PROMPTS_D = 5                       # how many prompts to run

RESULTS_DIR = OUTDIR / "results"; RESULTS_DIR.mkdir(exist_ok=True)

def _slug(s, n=40):
    s = re.sub(r"[^\w\s-]", "", s).strip().lower()
    return re.sub(r"[\s_-]+", "_", s)[:n]

# ---------- 0. make sure model starts clean -------------------
for lst in ("H_SWAP", "H_MEAN"):
    if lst in globals():
        for h in globals()[lst]:
            h.remove()
        globals()[lst].clear()

# ---------- 1. build variant map set once --------------------
variant_maps = build_maps(
    means,
    transform=transform_choice,
    n_tokens=n_tokens,
    seed=rng_seed
)

# ---------- 2. create hook factory ---------------------------
def make_variant_hook(level):
    def _hook(_, __, out):
        t = step_counter["t"]
        if t < N_STEPS:
            m = variant_maps[level][t].to(out.device).to(out.dtype)
            return m.unsqueeze(0).expand_as(out)
        return out
    return _hook

# helper to attach & detach variant hooks
def attach_swap_hooks():
    global H_SWAP
    H_SWAP = [
        blk.ff.register_forward_hook(make_variant_hook(lvl))
        for lvl, blk in enumerate(pipe.transformer.transformer_blocks)
    ]
def detach_swap_hooks():
    for h in globals().get("H_SWAP", []):
        h.remove()
    H_SWAP.clear()

# ---------- 3. run through prompts ---------------------------
for i, prompt in enumerate(tqdm.tqdm(PROMPTS[:N_PROMPTS_D], desc="Phase D")):
    base_png  = f"baseline_{transform_choice}_{i:02d}_{_slug(prompt)}.png"
    var_png   = f"{transform_choice}_{i:02d}_{_slug(prompt)}.png"

    # --- baseline ---
    detach_swap_hooks()                            # ensure no hooks
    generate(prompt, SEED+i, base_png)
    (OUTDIR/base_png).replace(RESULTS_DIR/base_png)

    # --- variant ---
    attach_swap_hooks()                            # hooks ON
    generate(prompt, SEED+i, var_png)
    (OUTDIR/var_png).replace(RESULTS_DIR/var_png)

    detach_swap_hooks()                            # clean for next loop

    # optional GPU tidy‑up if many prompts
    gc.collect(); torch.cuda.empty_cache()

print(f"Saved {2*N_PROMPTS_D} images → {RESULTS_DIR.resolve()}")

# ---------- 4. quick preview of the first pair ---------------
img_b = Image.open(RESULTS_DIR / f"baseline_{transform_choice}_00_{_slug(PROMPTS[0])}.png")
img_v = Image.open(RESULTS_DIR / f"{transform_choice}_00_{_slug(PROMPTS[0])}.png")

fig, ax = plt.subplots(1,2, figsize=(10,5))
ax[0].imshow(img_b); ax[0].set_title("Baseline");           ax[0].axis("off")
ax[1].imshow(img_v); ax[1].set_title(transform_choice);     ax[1].axis("off")
fig.suptitle(PROMPTS[0], wrap=True); plt.tight_layout(); plt.show()
