# SANA FFN Ablation – Step-Wise vs. Block-Wise

This experiment ablates the **Feed-Forward Network (FFN)** outputs in the SANA Transformer to assess how they influence the generated image.

Two types of ablation are explored:
- **Step-wise**: Replace FFN outputs at a specific **timestep** across all blocks.
- **Block-wise**: Replace FFN outputs at a specific **transformer block** across all timesteps.

Each ablation is done using:
- `zero` — replaces FFN output with all zeros.
- `mean` — replaces FFN output with its spatial mean: each channel is averaged over all tokens/pixels, and that value is broadcast back to every position.

One image is saved per ablation case. Each results folder also contains an un-ablated baseline image for comparison.


In [1]:
# prompt= "a banana at the top of the image and an apple at the bottom"
# PARENT = "exp2-sana_ablation_results_banana_apple"

In [2]:
# prompt = "a flower at the top, house in the bootom and sky in the right."
# PARENT = "exp2-sana_ablation_results_flower, house, sky"

In [3]:
# Active prompt for this run. PARENT is the root folder under which the
# step_*/block_* result directories are created.
prompt= 'a cyberpunk cat with a neon sign that says "Sana"'
PARENT = "exp2-sana_ablation_results_sana"

In [4]:
# =============================================================
#  SANA – two ablation experiments
#      1. step‑wise  (all blocks at timestep j)
#      2. block‑wise (block i across all timesteps)
# =============================================================
import os, shutil, torch
from PIL import Image
from diffusers import SanaPipeline

# ------------------------------------------------------------------
# 0)  config
# ------------------------------------------------------------------
num_steps = 20          # denoising steps per generated image
num_blocks = 20         # SANA‑1600M has 20 blocks (not referenced below; loops use the real block list)
guidance_scale = 3.0    # classifier-free guidance scale passed to the pipeline
seed = 42               # re-applied before every generation so all runs share the same noise

# One output directory per (experiment, ablation type).
ROOT_STEP = {
    "zero": os.path.join(PARENT, "step_zero"),
    "mean": os.path.join(PARENT, "step_mean"),
}
ROOT_BLOCK = {
    "zero": os.path.join(PARENT, "block_zero"),
    "mean": os.path.join(PARENT, "block_mean"),
}

# Wipe results from any previous run, then recreate every directory empty.
_all_result_dirs = list(ROOT_STEP.values()) + list(ROOT_BLOCK.values())
for result_dir in _all_result_dirs:
    shutil.rmtree(result_dir, ignore_errors=True)
    os.makedirs(result_dir, exist_ok=True)

# ------------------------------------------------------------------
# 1)  load pipeline: fp16 weights on the GPU
# ------------------------------------------------------------------
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")

# The VAE and text encoder are moved to bfloat16; the transformer stays fp16.
for _submodule in (pipe.vae, pipe.text_encoder):
    _submodule.to(torch.bfloat16)

# NOTE(review): the pipeline call also receives num_inference_steps, which
# likely re-configures the scheduler's timesteps — confirm whether this is needed.
pipe.scheduler.set_timesteps(num_steps)

# ------------------------------------------------------------------
# 2)  global timestep counter
# ------------------------------------------------------------------
# step_counter["t"] is bumped before every transformer forward call and reset to
# -1 by run_with_hooks, so the first call of a generation sees t == 0.
# NOTE(review): t equals the denoising-step index only if the pipeline calls the
# transformer exactly once per step (e.g. CFG batched into one call) — confirm.
step_counter = {"t": -1}


def _advance_step_counter(*_):
    """Forward pre-hook: increment the global call counter; inputs are untouched."""
    step_counter["t"] += 1


pipe.transformer.register_forward_pre_hook(_advance_step_counter)

# ------------------------------------------------------------------
# 3)  utility: run once with supplied hooks → final PIL.Image
# ------------------------------------------------------------------
def run_with_hooks(hook_handles):
    """Generate one image with the given hooks active, then remove them.

    The global step counter is reset and a fresh CUDA generator is seeded with
    ``seed``. Every call therefore starts from identical noise, and only the
    registered hooks differ between runs.

    Args:
        hook_handles: handles returned by ``register_forward_hook`` for hooks
            that are already attached to the model. They are always removed
            before returning, even if the pipeline call raises. This prevents
            a failed run from leaving ablation hooks active for later runs.

    Returns:
        The first image of the pipeline output (a ``PIL.Image``).
    """
    step_counter["t"] = -1
    try:
        g = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            img = pipe(
                prompt              = prompt,
                guidance_scale      = guidance_scale,
                num_inference_steps = num_steps,
                generator           = g,
            ).images[0]
    finally:
        # Detach hooks unconditionally so they cannot leak into the next run.
        for h in hook_handles:
            h.remove()
    return img

# ------------------------------------------------------------------
# 4)  baseline (no ablation) into every dir
# ------------------------------------------------------------------
# A single un-ablated reference image is written into each results directory,
# so every folder can be compared against the baseline on its own.
baseline = run_with_hooks(hook_handles=[])
_baseline_dirs = [*ROOT_STEP.values(), *ROOT_BLOCK.values()]
for result_dir in _baseline_dirs:
    baseline_path = os.path.join(result_dir, "baseline.png")
    baseline.save(baseline_path)

# ------------------------------------------------------------------
# 5)  STEP‑WISE experiment
# ------------------------------------------------------------------
def make_step_hook(abl_type, target_step):
    """Build a forward hook that ablates a module's output at one step only.

    The hook reads the global ``step_counter["t"]``. It acts only while that
    value equals ``target_step``; on every other call it returns the output
    unchanged.

    Args:
        abl_type: ``"zero"`` replaces the output with zeros. ``"mean"``
            replaces it with its mean over dim 1 for 3-D outputs (presumably
            the token axis of (B, N, C)) or over dims 2-3 otherwise
            (presumably the spatial axes of (B, C, H, W)). The mean is
            broadcast back to the original shape.
        target_step: value of ``step_counter["t"]`` at which to ablate.

    Returns:
        A ``hook(module, inputs, output)`` callable suitable for
        ``register_forward_hook``.

    Raises:
        ValueError: if ``abl_type`` is neither ``"zero"`` nor ``"mean"``.
    """
    # Validate eagerly so a typo cannot silently turn into a mean ablation.
    if abl_type not in ("zero", "mean"):
        raise ValueError(f"abl_type must be 'zero' or 'mean', got {abl_type!r}")

    def hook(_m, _inp, out):
        if step_counter["t"] != target_step:
            return out
        if abl_type == "zero":
            return torch.zeros_like(out)
        # Average away the token/spatial axes while keeping batch and channel.
        reduce_dims = 1 if out.ndim == 3 else (2, 3)
        return out.mean(dim=reduce_dims, keepdim=True).expand_as(out)

    return hook

for abl_type, root in ROOT_STEP.items():
    print(f"\n== STEP‑WISE  {abl_type.upper()} ==")
    all_blocks = pipe.transformer.transformer_blocks
    for step_idx in range(num_steps):
        # Attach the same step-gated hook to the FFN of every block; it fires
        # only while the global step counter equals step_idx.
        ff_hooks = []
        for tblock in all_blocks:
            ff_hooks.append(tblock.ff.register_forward_hook(make_step_hook(abl_type, step_idx)))
        ablated = run_with_hooks(ff_hooks)
        out_path = os.path.join(root, f"step_{step_idx:02d}.png")
        ablated.save(out_path)
        print(f"{abl_type} step {step_idx:02d}", end="\r")

# ------------------------------------------------------------------
# 6)  BLOCK‑WISE experiment
# ------------------------------------------------------------------
def make_block_hook(abl_type):
    """Build a forward hook that ablates a module's output on every call.

    Args:
        abl_type: ``"zero"`` replaces the output with zeros. ``"mean"``
            replaces it with its mean over dim 1 for 3-D outputs (presumably
            the token axis of (B, N, C)) or over dims 2-3 otherwise
            (presumably the spatial axes of (B, C, H, W)). The mean is
            broadcast back to the original shape.

    Returns:
        A ``hook(module, inputs, output)`` callable suitable for
        ``register_forward_hook``.

    Raises:
        ValueError: if ``abl_type`` is neither ``"zero"`` nor ``"mean"``.
    """
    # Validate eagerly so a typo cannot silently turn into a mean ablation.
    if abl_type not in ("zero", "mean"):
        raise ValueError(f"abl_type must be 'zero' or 'mean', got {abl_type!r}")

    def hook(_m, _inp, out):
        if abl_type == "zero":
            return torch.zeros_like(out)
        # Average away the token/spatial axes while keeping batch and channel.
        reduce_dims = 1 if out.ndim == 3 else (2, 3)
        return out.mean(dim=reduce_dims, keepdim=True).expand_as(out)

    return hook

for abl_type, root in ROOT_BLOCK.items():
    print(f"\n== BLOCK‑WISE {abl_type.upper()} ==")
    for i, blk in enumerate(pipe.transformer.transformer_blocks):
        # Ablate only block i's FFN, on every forward call of the generation.
        handle = blk.ff.register_forward_hook(make_block_hook(abl_type))
        img = run_with_hooks([handle])
        img.save(os.path.join(root, f"block_{i:02d}.png"))
        print(f"{abl_type} block {i:02d}", end="\r")

# Report the directories actually written above (sub-folders of PARENT).
print(f"\n✔ Done – results in {PARENT}/step_* / {PARENT}/block_*")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]


== STEP‑WISE  ZERO ==


  0%|          | 0/20 [00:00<?, ?it/s]

zero step 00

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 01

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 02

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 03

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 04

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 05

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 06

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 07

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 08

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 09

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 10

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 11

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 12

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 13

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 14

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 15

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 16

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 17

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 18

  0%|          | 0/20 [00:00<?, ?it/s]

zero step 19
== STEP‑WISE  MEAN ==


  0%|          | 0/20 [00:00<?, ?it/s]

mean step 00

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 01

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 02

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 03

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 04

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 05

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 06

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 07

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 08

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 09

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 10

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 11

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 12

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 13

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 14

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 15

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 16

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 17

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 18

  0%|          | 0/20 [00:00<?, ?it/s]

mean step 19
== BLOCK‑WISE ZERO ==


  0%|          | 0/20 [00:00<?, ?it/s]

zero block 00

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 01

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 02

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 03

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 04

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 05

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 06

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 07

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 08

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 09

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 10

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 11

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 12

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 13

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 14

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 15

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 16

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 17

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 18

  0%|          | 0/20 [00:00<?, ?it/s]

zero block 19
== BLOCK‑WISE MEAN ==


  0%|          | 0/20 [00:00<?, ?it/s]

mean block 00

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 01

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 02

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 03

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 04

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 05

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 06

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 07

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 08

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 09

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 10

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 11

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 12

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 13

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 14

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 15

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 16

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 17

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 18

  0%|          | 0/20 [00:00<?, ?it/s]

mean block 19
✔ Done – results in sana_step_* / sana_block_*
