Related papers
- "Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer"
- "Sana 1.5: Efficient Scaling of Training-Time and Inference-Time Compute in Linear Diffusion Transformer"

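# Summary of the GitHub repository README
- Standard models: the 0.6B model needs 9 GB of VRAM and the 1.6B model needs 12 GB. Training needs 32 GB of VRAM. The 4-bit quantized model needs at least 8 GB.
- Model
    - Sana-1.5 implements an efficient model scaling strategy that improves quality while keeping compute requirements reasonable. It introduces a depth-growth paradigm and model pruning.
    - Sana-Sprint is a specialized variant focused on timestep distillation. It produces high-quality images in only 1-4 inference steps, which significantly reduces generation time.
- Inference & Test Metrics (FID, CLIP Score, GenEval, DPG-Bench, etc...)
- **Inference Scaling**: Sana can use a specialized NVILA-2B model (called VISA) to score candidate images and pick the highest-quality result out of several generations, which significantly improves the benchmark metrics. (Whether this really counts as inference scaling is debatable.)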
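
The selection step described in the last bullet is essentially best-of-N sampling: generate several candidates, score each one, keep the top result. A minimal sketch follows, assuming a generic scoring callable. `best_of_n` and the toy scores are hypothetical stand-ins, not the repository's API, and no real judge model is called.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def best_of_n(candidates: Sequence[T], score_fn: Callable[[T], float], keep: int = 1) -> List[T]:
    """Return the `keep` highest-scoring candidates (hypothetical helper)."""
    ranked = sorted(candidates, key=score_fn, reverse=True)
    return ranked[:keep]


# Toy stand-in for judge scores; in practice a vision-language model would
# produce one score per generated image.
toy_scores = {"img_a": 0.41, "img_b": 0.87, "img_c": 0.63}
selected = best_of_n(list(toy_scores), lambda name: toy_scores[name], keep=1)
print(selected)
```

The cost grows linearly with the number of candidates, since every candidate needs a full generation plus one scoring pass.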


| Model | Resolution | Params (B) | Latency (s) | Throughput (img/s) | FID ↓ | CLIP Score ↑ | GenEval ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Sana-0.6B | 1024×1024 | 0.6 | 0.9 | 1.7 | 5.61 | 28.80 | 0.68 |
| Sana-1.6B | 1024×1024 | 1.6 | 1.2 | 1.0 | 5.76 | 28.67 | 0.66 |
| Sana-1.5-1.6B | 1024×1024 | 1.6 | 1.2 | 1.0 | 5.70 | 29.12 | 0.82 |
| Sana-1.5-4.8B | 1024×1024 | 4.8 | 4.2 | 0.26 | 5.99 | 29.23 | 0.81 |

- SANA-0.6B
    - 512px
    - 1024px
    - ControlNet
- SANA-1.6B
    - 512px
    - 1024px
    - 2Kpx
    - 4Kpx
    - ControlNet
- SANA1.5-1.6B
- SANA1.5-4.8B
- SANA-Sprint
    - Sana-Sprint-0.6B
    - Sana-Sprint-1.6B

| Model Variant | Depth | Hidden Size | Patch Size | Num Heads | Parameters |
|---|---|---|---|---|---|
| Sana-0.6B | 28 | 1152 | 1 or 2 | 16 | 600M |
| Sana-1.6B | 20 | 2240 | 1 or 2 | 20 | 1.6B |
| Sana-4.8B | 60 | 2240 | 1 | 20 | 4.8B |
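
As a quick sanity check on the table, the per-head dimension is the hidden size divided by the number of heads. The sketch below uses only the values from the table, plus the defaults of `SanaTransformerBlock` quoted later in these notes (70 self-attention heads of size 32, 20 cross-attention heads of size 112). Reading "Num Heads" as the cross-attention head count is an assumption based on those defaults.

```python
# Values copied from the model variant table above.
variants = {
    "Sana-0.6B": {"depth": 28, "hidden": 1152, "heads": 16},
    "Sana-1.6B": {"depth": 20, "hidden": 2240, "heads": 20},
    "Sana-4.8B": {"depth": 60, "hidden": 2240, "heads": 20},
}

head_dims = {}
for name, cfg in variants.items():
    # The hidden size must split evenly across the heads.
    assert cfg["hidden"] % cfg["heads"] == 0
    head_dims[name] = cfg["hidden"] // cfg["heads"]

print(head_dims)

# Defaults of SanaTransformerBlock for dim=2240: both head layouts cover the full width.
self_attn_width = 70 * 32    # num_attention_heads * attention_head_dim
cross_attn_width = 20 * 112  # num_cross_attention_heads * cross_attention_head_dim
print(self_attn_width, cross_attn_width)
```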

# Model architecture analysis

![alt text](../Image/sana.png)

## Linear DiT

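**Time complexity of standard self-attention**: $O(n^2 d)$ <p>
**Time complexity of linear attention**: $O(nd^2)$, which is linear in $n$ when $d$ is held fixed <p>
Here $n$ is the sequence length (for example, the number of image patches) and $d$ is the hidden dimension.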
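
The saving comes from reassociating the matrix product: instead of forming the $n \times n$ score matrix, $\phi(K)^\top V$ (a $d \times d$ matrix) is computed first. Below is a minimal NumPy sketch, assuming a ReLU feature map $\phi$ as described in the Sana paper. It is a single-head illustration, not the diffusers implementation, and the function names are made up for this example.

```python
import numpy as np


def relu_linear_attention(q, k, v, eps=1e-6):
    """O(n d^2): build the (d, d) summary phi(K)^T V before touching Q."""
    q, k = np.maximum(q, 0.0), np.maximum(k, 0.0)  # phi = ReLU
    kv = k.T @ v                    # (d, d)
    k_sum = k.sum(axis=0)           # (d,)
    numerator = q @ kv              # (n, d)
    denominator = q @ k_sum + eps   # (n,)
    return numerator / denominator[:, None]


def relu_quadratic_attention(q, k, v, eps=1e-6):
    """O(n^2 d): same result, but materializes the (n, n) score matrix."""
    q, k = np.maximum(q, 0.0), np.maximum(k, 0.0)
    scores = q @ k.T                # (n, n)
    return (scores @ v) / (scores.sum(axis=1, keepdims=True) + eps)


rng = np.random.default_rng(0)
n, d = 256, 32
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

out_linear = relu_linear_attention(q, k, v)
out_quadratic = relu_quadratic_attention(q, k, v)
max_abs_diff = float(np.abs(out_linear - out_quadratic).max())
print(out_linear.shape, max_abs_diff)
```

Both functions return the same values up to floating-point error. Only the order of the multiplications changes, and that is what moves the cost from quadratic to linear in $n$.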


In [None]:
from diffusers import SanaTransformer2DModel
import torch
# Following an earlier error message, low_cpu_mem_usage and device_map=None are passed explicitly
model=SanaTransformer2DModel.from_pretrained(r"G:\code\model\SANA1.5_1.6B_1024px_diffusers\transformer",
                                            torch_dtype=torch.bfloat16,
                                            low_cpu_mem_usage=True,
                                            device_map=None).to("cuda")


In [2]:
# Count the model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters in the loaded DiT model: {total_params:,}")

Total parameters in the loaded DiT model: 1,604,462,752


As the module listing below shows, the 1.6B Sana model is a stack of 20 identical Transformer blocks.

In [5]:
for name, child in model.named_children():
    print(name)  # name of the current top-level module
    # Walk the direct children of the current module
    for sub_name, sub_child in child.named_children():
        print(f"  - {sub_name}")  # child name, indented to show the hierarchy

patch_embed
  - proj
time_embed
  - emb
  - silu
  - linear
caption_projection
  - linear_1
  - act_1
  - linear_2
caption_norm
transformer_blocks
  - 0
  - 1
  - 2
  - 3
  - 4
  - 5
  - 6
  - 7
  - 8
  - 9
  - 10
  - 11
  - 12
  - 13
  - 14
  - 15
  - 16
  - 17
  - 18
  - 19
norm_out
proj_out


In [6]:
for name,child in model.transformer_blocks.named_children():
    print(type(child))

<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
<class 'diffusers.models.transformers.sana_transformer.SanaTransformerBlock'>
...

Most of the parameters also come from these Transformer blocks.

In [7]:
# Count the parameters of the Transformer blocks only
total_params = sum(p.numel() for p in model.transformer_blocks.parameters())
print(f"Total parameters in the DiT Transformer blocks: {total_params:,}")

Total parameters in the DiT Transformer blocks: 1,558,412,800
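
To put the two counts in proportion, the short calculation below takes the numbers printed above and derives the share held by the Transformer blocks, the size of one block, and what remains for the embedding and output layers. It is plain arithmetic on the reported figures; the variable names are only illustrative.

```python
total_params = 1_604_462_752   # whole DiT, from the first count
block_params = 1_558_412_800   # transformer_blocks only, from the second count
num_blocks = 20

share = block_params / total_params
per_block = block_params // num_blocks
outside_blocks = total_params - block_params

print(f"share of parameters in blocks: {share:.2%}")
print(f"parameters per block: {per_block:,}")
print(f"parameters outside the blocks: {outside_blocks:,}")
```

The block total divides evenly by 20, which is consistent with all blocks being identical.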


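In [17]:
# Imports needed by this excerpt (module paths as in the diffusers source; they may differ between versions)
from typing import Optional

import torch
from torch import nn
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, SanaLinearAttnProcessor2_0
from diffusers.models.transformers.sana_transformer import GLUMBConv


class SanaTransformerBlock(nn.Module):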
    r"""
    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
    """

    def __init__(
        self,
        dim: int = 2240,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        dropout: float = 0.0,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        attention_bias: bool = True,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
    ) -> None:
        super().__init__()

        # 1. Self Attention
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
            processor=SanaLinearAttnProcessor2_0(),
        )

        # 2. Cross Attention
        if cross_attention_dim is not None:
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_cross_attention_heads,
                dim_head=cross_attention_head_dim,
                dropout=dropout,
                bias=True,
                out_bias=attention_out_bias,
                processor=AttnProcessor2_0(),
            )

        # 3. Feed-forward
        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: int = None,
        width: int = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

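        # 1. Modulation: derive the modulation parameters from the timestep
        # The timestep embedding yields six parameters: two shifts, two scales and two gates
        # gate_msa and gate_mlp control how much the self-attention and feed-forward outputs contribute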
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)

        # 2. Self Attention
        norm_hidden_states = self.norm1(hidden_states)
        # Apply the modulation parameters (scale and shift)
        # ---------------------------------------------------------
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        #-------------------------------------------------------------------
        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

        attn_output = self.attn1(norm_hidden_states)
        # Apply the modulation parameters (gate)
        # ---------------------------------------------------------
        hidden_states = hidden_states + gate_msa * attn_output
        #-------------------------------------------------------------------

        # 3. Cross Attention
        if self.attn2 is not None:
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        # Apply the modulation parameters (scale and shift)
        # ---------------------------------------------------------
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        # ---------------------------------------------------------

        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
        ff_output = self.ff(norm_hidden_states)
        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
        hidden_states = hidden_states + gate_mlp * ff_output

        return hidden_states
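
The per-block size implied by the earlier count (1,558,412,800 / 20 = 77,920,640) can be reproduced from the layer shapes. The sketch below is a back-of-the-envelope tally under several assumptions that are not stated in these notes: `GLUMBConv` is taken to be a 1×1 expansion conv with bias, a 3×3 depthwise conv with bias, and a 1×1 projection conv without bias; the q/k/v projections of self-attention are taken to have no bias in this checkpoint (unlike the `attention_bias=True` default shown above). With those assumptions the tally matches the measured number exactly.

```python
dim = 2240
hidden = int(2.5 * dim)  # mlp_ratio * dim = 5600


def linear(n_in, n_out, bias=True):
    """Parameter count of a dense layer (or a 1x1 conv)."""
    return n_in * n_out + (n_out if bias else 0)


# Self-attention: q, k, v assumed bias-free in this checkpoint; output projection with bias.
self_attn = 3 * linear(dim, dim, bias=False) + linear(dim, dim)

# Cross-attention: q, k, v and output projection, all with bias (bias=True in the block).
cross_attn = 4 * linear(dim, dim)

# GLUMBConv (assumed layout): 1x1 expand to 2*hidden, 3x3 depthwise, 1x1 project back.
conv_inverted = linear(dim, 2 * hidden)
conv_depth = 2 * hidden * 3 * 3 + 2 * hidden   # one 3x3 filter plus bias per channel
conv_point = linear(hidden, dim, bias=False)
feed_forward = conv_inverted + conv_depth + conv_point

scale_shift_table = 6 * dim  # LayerNorms have no affine parameters

per_block = self_attn + cross_attn + feed_forward + scale_shift_table
print(f"{per_block:,}")
```

Under this tally the feed-forward part holds roughly half of each block's parameters, with the two attention layers sharing most of the rest.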


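- Features are modulated with shift and scale parameters in the style of PixArt; the parameters are generated dynamically from the timestep.
- Gating: two gate parameters, gate_msa and gate_mlp, control how much the self-attention and feed-forward outputs contribute to the final features.
- The components inside a Transformer block (self-attention and feed-forward) are modulated separately, which is finer-grained and more effective than a single global modulation.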
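
The modulation pattern summarized above can be isolated in a few lines. This is a minimal NumPy sketch rather than the diffusers code: `layer_norm`, `modulated_residual` and the `tanh` sublayer are hypothetical stand-ins, and the random arrays only mimic the shapes of `scale_shift_table` and the timestep embedding.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis without learnable affine parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def modulated_residual(hidden, shift, scale, gate, sublayer):
    """norm -> scale/shift -> sublayer -> gated residual add."""
    normed = layer_norm(hidden) * (1.0 + scale) + shift
    return hidden + gate * sublayer(normed)


rng = np.random.default_rng(0)
batch, tokens, dim = 2, 16, 8
hidden = rng.standard_normal((batch, tokens, dim))
table = rng.standard_normal((6, dim)) / dim**0.5      # mimics scale_shift_table
timestep_emb = rng.standard_normal((batch, 6 * dim))  # mimics the timestep embedding

# Six (batch, 1, dim) tensors that broadcast over the token axis.
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = np.split(
    table[None] + timestep_emb.reshape(batch, 6, -1), 6, axis=1
)

out = modulated_residual(hidden, shift_msa, scale_msa, gate_msa, np.tanh)
# With the gate forced to zero, the sublayer is switched off entirely.
out_closed = modulated_residual(hidden, shift_msa, scale_msa, np.zeros_like(gate_msa), np.tanh)
print(shift_msa.shape, out.shape)
```

A zero gate turns the sublayer into an identity mapping, which is why the gate can be read as a per-channel, per-timestep switch on each component.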

## DC-AE

See `dc-ae.ipynb` for details.

## Flow-DPM-Solver