In [1]:
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    #"Efficient-Large-Model/Sana_600M_1024px_diffusers",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

pipe.vae.to(torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)



  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_feedforward_layernorm): Gemma2RMSNorm((2304,),

In [8]:
pipe.transformer_blocks

SanaTransformer2DModel(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(32, 2240, kernel_size=(1, 1), stride=(1, 1))
  )
  (time_embed): AdaLayerNormSingle(
    (emb): PixArtAlphaCombinedTimestepSizeEmbeddings(
      (time_proj): Timesteps()
      (timestep_embedder): TimestepEmbedding(
        (linear_1): Linear(in_features=256, out_features=2240, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=2240, out_features=2240, bias=True)
      )
    )
    (silu): SiLU()
    (linear): Linear(in_features=2240, out_features=13440, bias=True)
  )
  (caption_projection): PixArtAlphaTextProjection(
    (linear_1): Linear(in_features=2304, out_features=2240, bias=True)
    (act_1): GELU(approximate='tanh')
    (linear_2): Linear(in_features=2240, out_features=2240, bias=True)
  )
  (caption_norm): RMSNorm()
  (transformer_blocks): ModuleList(
    (0-19): 20 x SanaTransformerBlock(
      (norm1): LayerNorm((2240,), eps=1e-06, elementwise_affine=False)
      (attn1): Attention(

In [19]:
from torch.nn import Identity

pipe.transformer.identity_after_attn = Identity()
pipe.transformer.identity_after_ff = Identity()


In [31]:
from torch.nn import Identity
import types

def install_forward_with_identities(block):
    block.identity_after_attn = Identity()
    block.identity_after_ff = Identity()

    def forward2(self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: int = None,
        width: int = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)

        # Self-Attention
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

        attn_output = self.attn1(norm_hidden_states)
        hidden_states = hidden_states + self.identity_after_attn(gate_msa * attn_output)

        # Cross-Attention
        if self.attn2 is not None:
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = hidden_states + attn_output

        # Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)

        ff_output = self.ff(norm_hidden_states)
        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
        hidden_states = hidden_states + self.identity_after_ff(gate_mlp * ff_output)

        return hidden_states

    block.forward = types.MethodType(forward2, block)


In [32]:
for i, block in enumerate(pipe.transformer.transformer_blocks):
    install_forward_with_identities(block)
    print(f"✅ Patched transformer block {i}")


✅ Patched transformer block 0
✅ Patched transformer block 1
✅ Patched transformer block 2
✅ Patched transformer block 3
✅ Patched transformer block 4
✅ Patched transformer block 5
✅ Patched transformer block 6
✅ Patched transformer block 7
✅ Patched transformer block 8
✅ Patched transformer block 9
✅ Patched transformer block 10
✅ Patched transformer block 11
✅ Patched transformer block 12
✅ Patched transformer block 13
✅ Patched transformer block 14
✅ Patched transformer block 15
✅ Patched transformer block 16
✅ Patched transformer block 17
✅ Patched transformer block 18
✅ Patched transformer block 19


In [43]:
def print_shape_hook(name):
    def hook(module, input, output):
        print(f"name: {name}")
        print(type(input))
        print(len(input))
        print(output)
        if isinstance(output, tuple):
            for i, o in enumerate(output):
                if isinstance(o, torch.Tensor):
                    print(f"{module.__class__.__name__} output[{i}] shape: {tuple(o.shape)}")
        elif isinstance(output, torch.Tensor):
            print(f"{module.__class__.__name__} output shape: {tuple(output.shape)}")
        else:
             print(f"{module.__class__.__name__}, {type(output)}")
        print("\n")
        
    return hook

    

In [44]:
hook_handles = []

hook_handles.append(pipe.transformer.transformer_blocks[19].identity_after_attn.register_forward_hook(print_shape_hook("identity_after_attn")))
hook_handles.append(pipe.transformer.transformer_blocks[19].identity_after_ff.register_forward_hook(print_shape_hook("identity_after_ff")))

In [45]:
prompt = 'A red apple floating below a yellow banana on a white background.'
images = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),)[0]

#image[0].save("sana4.png")

  0%|          | 0/20 [00:00<?, ?it/s]

name: identity_after_attn
<class 'tuple'>
1
Identity output shape: (2, 1024, 2240)


name: identity_after_attn
<class 'tuple'>
1
tensor([[[-1.3878e-02,  1.8644e-04, -2.1805e-02,  ..., -1.0567e-02,
          -1.2139e+00,  3.5919e-02],
         [ 2.8431e-05, -1.0471e-03, -4.3411e-03,  ..., -1.4725e-02,
          -1.3008e+00,  4.9164e-02],
         [-1.5507e-03,  2.8172e-03, -1.4629e-03,  ...,  1.5574e-03,
          -1.1680e+00,  2.4734e-02],
         ...,
         [-1.4511e-02, -3.4666e-04, -9.5825e-03,  ...,  7.3776e-03,
          -8.8379e-01,  3.8940e-02],
         [ 9.7122e-03,  8.9598e-04, -1.2833e-02,  ...,  1.1124e-02,
          -1.1113e+00,  4.4220e-02],
         [-9.3765e-03, -3.7384e-03, -9.5596e-03,  ..., -3.9368e-02,
          -6.9238e-01,  1.6434e-02]],

        [[-1.3878e-02,  1.8644e-04, -2.1805e-02,  ..., -1.0567e-02,
          -1.2139e+00,  3.5919e-02],
         [ 2.8431e-05, -1.0471e-03, -4.3411e-03,  ..., -1.4725e-02,
          -1.3008e+00,  4.9164e-02],
         [-1.55