# SANA Latent-Space Visualizer

This experiment hooks into the **self-attention**, **cross-attention**, and **FFN** modules of the SANA Transformer to visualize how the latent representations evolve across **transformer layers** and **diffusion timesteps**.

For each timestep in the denoising process, the **input and output** of each of these modules in every block are captured, decoded using the VAE, and saved as an image. This helps reveal the internal dynamics of how the model builds up the final generated image. The prompts explored in these experiments are:


"A banana on the left side and an apple on the right side."

"A tree on the left side and a car on the right side."

"a flower at the top of the image and a balloon at the bottom of the image."

"a flower at the top, house in the bootom and sky in the right."

"a cyberpunk cat with a neon sign that says 'Sana'"

In [None]:
# -----------------------------------------------
#  SANA latent‑space visualiser  (with attn1/attn2 in+out)
# -----------------------------------------------
import os, math, torch
from PIL import Image
from diffusers import SanaPipeline

# -----------------------------
# 1)  load pipeline
# -----------------------------
# Every component is loaded in fp16 first; the VAE and the text encoder are
# then cast to bf16 in place.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")
for _component in (pipe.vae, pipe.text_encoder):
    _component.to(torch.bfloat16)

root_dir = "exp1-sana_latent_vis_flower_house_sky"
os.makedirs(root_dir, exist_ok=True)

# -----------------------------
# 2)  bookkeeping
# -----------------------------
# step_idx["t"] holds the current denoising-step index. It starts at -1 and is
# incremented before every transformer forward pass, so the first pass is 0.
step_idx = {"t": -1}


def _next_step(*_hook_args):
    """Forward pre-hook on the transformer: advance the step counter."""
    step_idx["t"] += 1


pipe.transformer.register_forward_pre_hook(_next_step)

# -----------------------------
# 3)  helper – reshape, decode, save
# -----------------------------
@torch.no_grad()
def decode_and_save(lat, tag, step, block=None, batch_index=-1):
    """Map a hooked activation to latent space, VAE-decode it and save a PNG.

    Args:
        lat: hooked activation, either a tensor or a tuple whose first element
            is the tensor. Two layouts are handled: token-major ``(B, S, C)``
            (``S`` must be a perfect square) and channel-major ``(B, C, H, W)``.
            When ``C`` equals ``pipe.transformer.proj_out.in_features`` the
            activation is first projected to latent channels with ``proj_out``;
            otherwise it is decoded as-is.
        tag: file-name tag, e.g. ``"attn1_out"``.
        step: denoising-step index; the image goes to
            ``<root_dir>/step_<step:02d>/``.
        block: optional transformer-block index, prefixed to the file name.
        batch_index: which batch element to decode (default: the last one).
            With classifier-free guidance the pipeline presumably stacks
            ``[unconditional, conditional]`` along the batch, so the last
            element is the prompt-conditioned branch; index 0 would show the
            unconditional one.

    Raises:
        ValueError: if a token-major activation has a non-square token count.
    """
    if isinstance(lat, tuple):
        lat = lat[0]

    # Keep a single sample: running the whole batch through the VAE is wasted
    # work when only one image is saved.
    lat = lat[batch_index].unsqueeze(0)

    proj_out = pipe.transformer.proj_out
    width = proj_out.in_features  # transformer hidden size (2240 for Sana-1600M)
    # NOTE(review): only proj_out is applied; any output normalisation the
    # transformer performs before proj_out is skipped, so projected
    # activations are an approximate view — confirm against the model code.

    if lat.ndim == 3:                                   # (1, S, C)
        _, S, C = lat.shape
        if C == width:
            lat = proj_out(lat)                         # (1, S, latent_ch)
        size = math.isqrt(S)
        if size * size != S:
            raise ValueError(f"{tag}: cannot fold {S} tokens into a square grid")
        lat = lat.transpose(1, 2).reshape(1, -1, size, size)
    elif lat.shape[1] == width:                         # (1, C, H, W)
        _, _, H, W = lat.shape
        lat = lat.flatten(2).transpose(1, 2)            # (1, H*W, C), row-major
        lat = proj_out(lat).transpose(1, 2).reshape(1, -1, H, W)

    lat = lat.float() / pipe.vae.config.scaling_factor
    lat = lat.to(pipe.vae.decoder.conv_in.weight.dtype)
    rgb = pipe.vae.decode(lat).sample.float()
    # [-1, 1] -> [0, 255]; round rather than truncate before the uint8 cast.
    img = ((rgb[0] * 0.5 + 0.5).clamp(0, 1) * 255).round()
    img = img.permute(1, 2, 0).to(torch.uint8).cpu().numpy()

    sub = os.path.join(root_dir, f"step_{step:02d}")
    os.makedirs(sub, exist_ok=True)
    fn = f"{tag}.png" if block is None else f"block{block:02d}_{tag}.png"
    Image.fromarray(img).save(os.path.join(sub, fn))

# -----------------------------
# 4)  hooks on every block
# -----------------------------
# (submodule attribute, tag for its input, tag for its output)
_HOOK_SPECS = (
    ("ff", "ff_in", "ff_out"),           # FFN
    ("attn1", "attn1_in", "attn1_out"),  # self-attention
    ("attn2", "attn2_in", "attn2_out"),  # cross-attention
)


def _input_saver(tag, block_no):
    """Build a forward pre-hook that decodes the module's first positional input."""
    def hook(_module, args):
        decode_and_save(args[0], tag, step_idx["t"], block_no)
    return hook


def _output_saver(tag, block_no):
    """Build a forward hook that decodes the module's output."""
    def hook(_module, _args, output):
        decode_and_save(output, tag, step_idx["t"], block_no)
    return hook


for block_no, transformer_block in enumerate(pipe.transformer.transformer_blocks):
    for attr, in_tag, out_tag in _HOOK_SPECS:
        submodule = getattr(transformer_block, attr)
        submodule.register_forward_pre_hook(_input_saver(in_tag, block_no))
        submodule.register_forward_hook(_output_saver(out_tag, block_no))

# -----------------------------
# 5)  transformer output  (post‑hook)
# -----------------------------
def _save_transformer_output(_module, _args, output):
    """Forward hook on the whole transformer: decode what it returns this step."""
    prediction = output[0] if isinstance(output, tuple) else output
    decode_and_save(prediction, "transformer_out", step_idx["t"])


pipe.transformer.register_forward_hook(_save_transformer_output)

# -----------------------------
# 6)  run diffusion
# -----------------------------
num_steps = 20
pipe.scheduler.set_timesteps(num_steps)

prompt = "a flower at the top, house in the bootom and sky in the right."
generator = torch.Generator(device="cuda").manual_seed(42)

# Each transformer call made during this run fires the hooks registered on
# pipe.transformer and its blocks.
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=4.0,
        generator=generator,
    )
images = result.images

final_path = os.path.join(root_dir, "final.png")
images[0].save(final_path)
print(f"✔ Activations saved under “{root_dir}/”")


In [1]:
# -----------------------------------------------
#  SANA latent‑space visualiser  (with attn1/attn2 in+out)
# -----------------------------------------------
import os, math, torch
from PIL import Image
from diffusers import SanaPipeline

# -----------------------------
# 1)  load pipeline
# -----------------------------
# Every component is loaded in fp16 first; the VAE and the text encoder are
# then cast to bf16 in place.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")
for _component in (pipe.vae, pipe.text_encoder):
    _component.to(torch.bfloat16)

root_dir = "sana_latent_vis_banana_apple"
os.makedirs(root_dir, exist_ok=True)

# -----------------------------
# 2)  bookkeeping
# -----------------------------
# step_idx["t"] holds the current denoising-step index. It starts at -1 and is
# incremented before every transformer forward pass, so the first pass is 0.
step_idx = {"t": -1}


def _next_step(*_hook_args):
    """Forward pre-hook on the transformer: advance the step counter."""
    step_idx["t"] += 1


pipe.transformer.register_forward_pre_hook(_next_step)

# -----------------------------
# 3)  helper – reshape, decode, save
# -----------------------------
@torch.no_grad()
def decode_and_save(lat, tag, step, block=None, batch_index=-1):
    """Map a hooked activation to latent space, VAE-decode it and save a PNG.

    Args:
        lat: hooked activation, either a tensor or a tuple whose first element
            is the tensor. Two layouts are handled: token-major ``(B, S, C)``
            (``S`` must be a perfect square) and channel-major ``(B, C, H, W)``.
            When ``C`` equals ``pipe.transformer.proj_out.in_features`` the
            activation is first projected to latent channels with ``proj_out``;
            otherwise it is decoded as-is.
        tag: file-name tag, e.g. ``"attn1_out"``.
        step: denoising-step index; the image goes to
            ``<root_dir>/step_<step:02d>/``.
        block: optional transformer-block index, prefixed to the file name.
        batch_index: which batch element to decode (default: the last one).
            With classifier-free guidance the pipeline presumably stacks
            ``[unconditional, conditional]`` along the batch, so the last
            element is the prompt-conditioned branch; index 0 would show the
            unconditional one.

    Raises:
        ValueError: if a token-major activation has a non-square token count.
    """
    if isinstance(lat, tuple):
        lat = lat[0]

    # Keep a single sample: running the whole batch through the VAE is wasted
    # work when only one image is saved.
    lat = lat[batch_index].unsqueeze(0)

    proj_out = pipe.transformer.proj_out
    width = proj_out.in_features  # transformer hidden size (2240 for Sana-1600M)
    # NOTE(review): only proj_out is applied; any output normalisation the
    # transformer performs before proj_out is skipped, so projected
    # activations are an approximate view — confirm against the model code.

    if lat.ndim == 3:                                   # (1, S, C)
        _, S, C = lat.shape
        if C == width:
            lat = proj_out(lat)                         # (1, S, latent_ch)
        size = math.isqrt(S)
        if size * size != S:
            raise ValueError(f"{tag}: cannot fold {S} tokens into a square grid")
        lat = lat.transpose(1, 2).reshape(1, -1, size, size)
    elif lat.shape[1] == width:                         # (1, C, H, W)
        _, _, H, W = lat.shape
        lat = lat.flatten(2).transpose(1, 2)            # (1, H*W, C), row-major
        lat = proj_out(lat).transpose(1, 2).reshape(1, -1, H, W)

    lat = lat.float() / pipe.vae.config.scaling_factor
    lat = lat.to(pipe.vae.decoder.conv_in.weight.dtype)
    rgb = pipe.vae.decode(lat).sample.float()
    # [-1, 1] -> [0, 255]; round rather than truncate before the uint8 cast.
    img = ((rgb[0] * 0.5 + 0.5).clamp(0, 1) * 255).round()
    img = img.permute(1, 2, 0).to(torch.uint8).cpu().numpy()

    sub = os.path.join(root_dir, f"step_{step:02d}")
    os.makedirs(sub, exist_ok=True)
    fn = f"{tag}.png" if block is None else f"block{block:02d}_{tag}.png"
    Image.fromarray(img).save(os.path.join(sub, fn))

# -----------------------------
# 4)  hooks on every block
# -----------------------------
# (submodule attribute, tag for its input, tag for its output)
_HOOK_SPECS = (
    ("ff", "ff_in", "ff_out"),           # FFN
    ("attn1", "attn1_in", "attn1_out"),  # self-attention
    ("attn2", "attn2_in", "attn2_out"),  # cross-attention
)


def _input_saver(tag, block_no):
    """Build a forward pre-hook that decodes the module's first positional input."""
    def hook(_module, args):
        decode_and_save(args[0], tag, step_idx["t"], block_no)
    return hook


def _output_saver(tag, block_no):
    """Build a forward hook that decodes the module's output."""
    def hook(_module, _args, output):
        decode_and_save(output, tag, step_idx["t"], block_no)
    return hook


for block_no, transformer_block in enumerate(pipe.transformer.transformer_blocks):
    for attr, in_tag, out_tag in _HOOK_SPECS:
        submodule = getattr(transformer_block, attr)
        submodule.register_forward_pre_hook(_input_saver(in_tag, block_no))
        submodule.register_forward_hook(_output_saver(out_tag, block_no))

# -----------------------------
# 5)  transformer output  (post‑hook)
# -----------------------------
def _save_transformer_output(_module, _args, output):
    """Forward hook on the whole transformer: decode what it returns this step."""
    prediction = output[0] if isinstance(output, tuple) else output
    decode_and_save(prediction, "transformer_out", step_idx["t"])


pipe.transformer.register_forward_hook(_save_transformer_output)

# -----------------------------
# 6)  run diffusion
# -----------------------------
num_steps = 20
pipe.scheduler.set_timesteps(num_steps)

prompt = "A banana on the left side and an apple on the right side."
generator = torch.Generator(device="cuda").manual_seed(42)

# Each transformer call made during this run fires the hooks registered on
# pipe.transformer and its blocks.
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=4.0,
        generator=generator,
    )
images = result.images

final_path = os.path.join(root_dir, "final.png")
images[0].save(final_path)
print(f"✔ Activations saved under “{root_dir}/”")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

✔ Activations saved under “sana_latent_vis_banana_apple/”


In [2]:
# -----------------------------------------------
#  SANA latent‑space visualiser  (with attn1/attn2 in+out)
# -----------------------------------------------
import os, math, torch
from PIL import Image
from diffusers import SanaPipeline

# -----------------------------
# 1)  load pipeline
# -----------------------------
# Every component is loaded in fp16 first; the VAE and the text encoder are
# then cast to bf16 in place.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")
for _component in (pipe.vae, pipe.text_encoder):
    _component.to(torch.bfloat16)

root_dir = "sana_latent_vis_flower_ballo"
os.makedirs(root_dir, exist_ok=True)

# -----------------------------
# 2)  bookkeeping
# -----------------------------
# step_idx["t"] holds the current denoising-step index. It starts at -1 and is
# incremented before every transformer forward pass, so the first pass is 0.
step_idx = {"t": -1}


def _next_step(*_hook_args):
    """Forward pre-hook on the transformer: advance the step counter."""
    step_idx["t"] += 1


pipe.transformer.register_forward_pre_hook(_next_step)

# -----------------------------
# 3)  helper – reshape, decode, save
# -----------------------------
@torch.no_grad()
def decode_and_save(lat, tag, step, block=None, batch_index=-1):
    """Map a hooked activation to latent space, VAE-decode it and save a PNG.

    Args:
        lat: hooked activation, either a tensor or a tuple whose first element
            is the tensor. Two layouts are handled: token-major ``(B, S, C)``
            (``S`` must be a perfect square) and channel-major ``(B, C, H, W)``.
            When ``C`` equals ``pipe.transformer.proj_out.in_features`` the
            activation is first projected to latent channels with ``proj_out``;
            otherwise it is decoded as-is.
        tag: file-name tag, e.g. ``"attn1_out"``.
        step: denoising-step index; the image goes to
            ``<root_dir>/step_<step:02d>/``.
        block: optional transformer-block index, prefixed to the file name.
        batch_index: which batch element to decode (default: the last one).
            With classifier-free guidance the pipeline presumably stacks
            ``[unconditional, conditional]`` along the batch, so the last
            element is the prompt-conditioned branch; index 0 would show the
            unconditional one.

    Raises:
        ValueError: if a token-major activation has a non-square token count.
    """
    if isinstance(lat, tuple):
        lat = lat[0]

    # Keep a single sample: running the whole batch through the VAE is wasted
    # work when only one image is saved.
    lat = lat[batch_index].unsqueeze(0)

    proj_out = pipe.transformer.proj_out
    width = proj_out.in_features  # transformer hidden size (2240 for Sana-1600M)
    # NOTE(review): only proj_out is applied; any output normalisation the
    # transformer performs before proj_out is skipped, so projected
    # activations are an approximate view — confirm against the model code.

    if lat.ndim == 3:                                   # (1, S, C)
        _, S, C = lat.shape
        if C == width:
            lat = proj_out(lat)                         # (1, S, latent_ch)
        size = math.isqrt(S)
        if size * size != S:
            raise ValueError(f"{tag}: cannot fold {S} tokens into a square grid")
        lat = lat.transpose(1, 2).reshape(1, -1, size, size)
    elif lat.shape[1] == width:                         # (1, C, H, W)
        _, _, H, W = lat.shape
        lat = lat.flatten(2).transpose(1, 2)            # (1, H*W, C), row-major
        lat = proj_out(lat).transpose(1, 2).reshape(1, -1, H, W)

    lat = lat.float() / pipe.vae.config.scaling_factor
    lat = lat.to(pipe.vae.decoder.conv_in.weight.dtype)
    rgb = pipe.vae.decode(lat).sample.float()
    # [-1, 1] -> [0, 255]; round rather than truncate before the uint8 cast.
    img = ((rgb[0] * 0.5 + 0.5).clamp(0, 1) * 255).round()
    img = img.permute(1, 2, 0).to(torch.uint8).cpu().numpy()

    sub = os.path.join(root_dir, f"step_{step:02d}")
    os.makedirs(sub, exist_ok=True)
    fn = f"{tag}.png" if block is None else f"block{block:02d}_{tag}.png"
    Image.fromarray(img).save(os.path.join(sub, fn))

# -----------------------------
# 4)  hooks on every block
# -----------------------------
# (submodule attribute, tag for its input, tag for its output)
_HOOK_SPECS = (
    ("ff", "ff_in", "ff_out"),           # FFN
    ("attn1", "attn1_in", "attn1_out"),  # self-attention
    ("attn2", "attn2_in", "attn2_out"),  # cross-attention
)


def _input_saver(tag, block_no):
    """Build a forward pre-hook that decodes the module's first positional input."""
    def hook(_module, args):
        decode_and_save(args[0], tag, step_idx["t"], block_no)
    return hook


def _output_saver(tag, block_no):
    """Build a forward hook that decodes the module's output."""
    def hook(_module, _args, output):
        decode_and_save(output, tag, step_idx["t"], block_no)
    return hook


for block_no, transformer_block in enumerate(pipe.transformer.transformer_blocks):
    for attr, in_tag, out_tag in _HOOK_SPECS:
        submodule = getattr(transformer_block, attr)
        submodule.register_forward_pre_hook(_input_saver(in_tag, block_no))
        submodule.register_forward_hook(_output_saver(out_tag, block_no))

# -----------------------------
# 5)  transformer output  (post‑hook)
# -----------------------------
def _save_transformer_output(_module, _args, output):
    """Forward hook on the whole transformer: decode what it returns this step."""
    prediction = output[0] if isinstance(output, tuple) else output
    decode_and_save(prediction, "transformer_out", step_idx["t"])


pipe.transformer.register_forward_hook(_save_transformer_output)

# -----------------------------
# 6)  run diffusion
# -----------------------------
num_steps = 20
pipe.scheduler.set_timesteps(num_steps)

prompt = "a flower at the top of the image and a balloon at the bottom of the image."
generator = torch.Generator(device="cuda").manual_seed(42)

# Each transformer call made during this run fires the hooks registered on
# pipe.transformer and its blocks.
with torch.inference_mode():
    result = pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        guidance_scale=4.0,
        generator=generator,
    )
images = result.images

final_path = os.path.join(root_dir, "final.png")
images[0].save(final_path)
print(f"✔ Activations saved under “{root_dir}/”")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

✔ Activations saved under “sana_latent_vis_flower_ballo/”
