In [1]:
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    #"Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)



  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,),

In [2]:
pipe

SanaPipeline {
  "_class_name": "SanaPipeline",
  "_diffusers_version": "0.34.0.dev0",
  "_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
  "scheduler": [
    "diffusers",
    "DPMSolverMultistepScheduler"
  ],
  "text_encoder": [
    "transformers",
    "Gemma2Model"
  ],
  "tokenizer": [
    "transformers",
    "GemmaTokenizerFast"
  ],
  "transformer": [
    "diffusers",
    "SanaTransformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderDC"
  ]
}

In [3]:
len(pipe.transformer.transformer_blocks)

20

In [5]:
import torch
import os

def zero_ablation_hook(module, input, output):
    return torch.zeros_like(output)

def mean_ablation_hook(module, input, output):
    return output.mean(dim=1, keepdim=True).expand_as(output)

# === Create output directory ===
os.makedirs("ablated_outputs", exist_ok=True)

# === Define prompt ===
prompt = "apple"

# === Generate and save baseline (no ablation) ===
print("Generating baseline image (no ablation)")
baseline_image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
baseline_image[0].save("ablated_outputs/baseline_no_ablation.png")

# === Iterate over transformer blocks for zero ablation ===
for idx, block in enumerate(pipe.transformer.transformer_blocks):
    print(f"Zero ablating FFN in block {idx}")

    # Register hook
    handle = block.ff.register_forward_hook(zero_ablation_hook)

    # Generate image
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        guidance_scale=5.0,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )[0]

    # Save image
    image[0].save(f"ablated_outputs/ablated_block{idx}_ffn_zero.png")

    # Remove hook
    handle.remove()

# === Optional: Repeat for mean ablation ===
for idx, block in enumerate(pipe.transformer.transformer_blocks):
    print(f"Mean ablating FFN in block {idx}")

    # Register hook
    handle = block.ff.register_forward_hook(mean_ablation_hook)

    # Generate image
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        guidance_scale=5.0,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),
    )[0]

    # Save image
    image[0].save(f"ablated_outputs/ablated_block{idx}_ffn_mean.png")

    # Remove hook
    handle.remove()


Generating baseline image (no ablation)


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 0


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 1


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 2


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 3


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 4


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 5


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 6


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 7


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 8


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 9


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 10


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 11


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 12


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 13


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 14


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 15


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 16


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 17


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 18


  0%|          | 0/20 [00:00<?, ?it/s]

Zero ablating FFN in block 19


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 0


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 1


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 2


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 3


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 4


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 5


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 6


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 7


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 8


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 9


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 10


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 11


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 12


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 13


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 14


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 15


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 16


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 17


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 18


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating FFN in block 19


  0%|          | 0/20 [00:00<?, ?it/s]

In [4]:
import torch
import os

def zero_ablation_hook(module, input, output):
    return torch.zeros_like(output)

def mean_ablation_hook(module, input, output):
    return output.mean(dim=1, keepdim=True).expand_as(output)

# Create output directory
os.makedirs("ablated_outputs_attn1", exist_ok=True)

# Define prompt
prompt = "a banana at the top of the image and an apple at the bottom"

# === Generate and save baseline image ===
print("Generating baseline image (no ablation)")
baseline_image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]
baseline_image[0].save("ablated_outputs_attn1/baseline_no_ablation.png")

# === Iterate over transformer blocks and ablate attn1 ===
for idx, block in enumerate(pipe.transformer.transformer_blocks):
    print(f"Mean ablating attn1 in block {idx}")

    # Register hook on attn1
    handle = block.attn1.register_forward_hook(mean_ablation_hook)

    # Generate image with ablated attn1
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        guidance_scale=5.0,
        num_inference_steps=20,
        generator=torch.Generator(device="cuda").manual_seed(42),  # fresh generator
    )[0]

    # Save image
    image[0].save(f"ablated_outputs_attn1/ablated_block{idx}_attn1_mean.png")

    # Remove hook
    handle.remove()


Generating baseline image (no ablation)


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 0


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 1


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 2


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 3


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 4


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 5


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 6


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 7


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 8


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 9


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 10


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 11


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 12


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 13


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 14


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 15


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 16


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 17


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 18


  0%|          | 0/20 [00:00<?, ?it/s]

Mean ablating attn1 in block 19


  0%|          | 0/20 [00:00<?, ?it/s]