In [2]:
import torch
from diffusers import AutoencoderDC
# Load the DC-AE f32c32 (SANA 1.1) checkpoint from local disk with fp16 weights,
# then move the model onto the GPU.
dc_vae = AutoencoderDC.from_pretrained(
    r"G:\code\model\dc-ae-f32c32-sana-1.1-diffusers",
    torch_dtype=torch.float16,
)
dc_vae = dc_vae.to('cuda')


SANA 1.1 中使用的 DC-AE f32c32 参数量约为 312M。

In [None]:
# Count the parameters of the DC-AE by summing the element count of every tensor.
total_params = 0
for weight in dc_vae.parameters():
    total_params += weight.numel()
print(f'{total_params:,} total parameters.')

In [3]:
from diffusers import AutoencoderKL
import torch
# Load only the VAE sub-model ("vae" subfolder) of the local SDXL base checkpoint
# with fp16 weights, then move it onto the GPU.
# The path is written as a fully qualified raw string: the previous "g:code/..."
# form had no separator after the drive letter, which Windows resolves relative
# to the current directory on drive G:, so it only worked when the kernel
# happened to be started from G:\.
xl_vae = AutoencoderKL.from_pretrained(
    r"G:\code\model\stable-diffusion-xl-base-1.0",
    subfolder="vae",
    torch_dtype=torch.float16,
).to('cuda')


SDXL 的 VAE(0.9 版本)参数量约为 83M。

In [None]:
# Count the parameters of the SDXL VAE by summing per-tensor element counts.
xl_param_sizes = (weight.numel() for weight in xl_vae.parameters())
total_params = sum(xl_param_sizes)
print(f'{total_params:,} total parameters.')

In [21]:
# Compare the module layout of the SDXL VAE and the DC-AE by listing the names
# of their direct children (and of the DC-AE encoder/decoder children).
# Each entry: (header lines to print, module to inspect, label prefix per child).
sections = [
    (("=== SDXL-VAE结构 ===",), xl_vae, "SDXL-VAE"),
    (("=== DC-AE结构 ===", "整体结构:"), dc_vae, "DC-AE"),
    (("编码器结构:",), dc_vae.encoder, "DC-AE Encoder"),
    (("解码器结构:",), dc_vae.decoder, "DC-AE Decoder"),
]

for position, (headers, module, label) in enumerate(sections):
    # A separator goes between sections only, never after the last one.
    if position > 0:
        print('-' * 50)
    for header in headers:
        print(header)
    for child_name, _ in module.named_children():
        print(f"{label}: {child_name}")

=== SDXL-VAE结构 ===
SDXL-VAE: encoder
SDXL-VAE: decoder
SDXL-VAE: quant_conv
SDXL-VAE: post_quant_conv
--------------------------------------------------
=== DC-AE结构 ===
整体结构:
DC-AE: encoder
DC-AE: decoder
--------------------------------------------------
编码器结构:
DC-AE Encoder: conv_in
DC-AE Encoder: down_blocks
DC-AE Encoder: conv_out
--------------------------------------------------
解码器结构:
DC-AE Decoder: conv_in
DC-AE Decoder: up_blocks
DC-AE Decoder: norm_out
DC-AE Decoder: conv_act
DC-AE Decoder: conv_out


- **SDXL-VAE**
    - SDXL-VAE.encoder
        - conv_in
        - down_blocks
        - mid_block
        - conv_norm_out
        - conv_act
        - conv_out
    - SDXL-VAE.decoder
        - conv_in
        - up_blocks
        - mid_block
        - conv_norm_out
        - conv_act
        - conv_out
    - SDXL-VAE.quant_conv:Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
    - SDXL-VAE.post_quant_conv:Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))

- **DC-AE f32c32**
    - DC-AE Encoder
        - conv_in
        - down_blocks
            - 0-2:ResBlock*2+DCDownBlock
            - 3-4:EfficientViTBlock*3+DCDownBlock
            - 5:EfficientViTBlock*3
        - conv_out
    - DC-AE Decoder
        - conv_in
        - up_blocks
            - 0-2:DCUpBlock2d+ResBlock*2
            - 3-4:DCUpBlock2d+EfficientViTBlock*3
            - 5:EfficientViTBlock*3
        - norm_out:RMSNorm()
        - conv_act:ReLU()
        - conv_out

In [91]:
# Show the first layer of encoder down block 3; the cell's last expression is
# rendered as its repr.
encoder_stage = dc_vae.encoder.down_blocks[3]
encoder_stage[0]

EfficientViTBlock(
  (attn): SanaMultiscaleLinearAttention(
    (to_q): Linear(in_features=512, out_features=512, bias=False)
    (to_k): Linear(in_features=512, out_features=512, bias=False)
    (to_v): Linear(in_features=512, out_features=512, bias=False)
    (to_qkv_multiscale): ModuleList(
      (0): SanaMultiscaleAttentionProjection(
        (proj_in): Conv2d(1536, 1536, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1536, bias=False)
        (proj_out): Conv2d(1536, 1536, kernel_size=(1, 1), stride=(1, 1), groups=48, bias=False)
      )
    )
    (nonlinearity): ReLU()
    (to_out): Linear(in_features=1024, out_features=512, bias=False)
    (norm_out): RMSNorm()
  )
  (conv_out): GLUMBConv(
    (nonlinearity): SiLU()
    (conv_inverted): Conv2d(512, 4096, kernel_size=(1, 1), stride=(1, 1))
    (conv_depth): Conv2d(4096, 4096, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4096)
    (conv_point): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False

In [5]:
from diffusers import AutoencoderDC
# Goal: find out which models AutoencoderDC is meant to be used with.
# Start with the class's own docstring to see whether it mentions this.
class_doc = AutoencoderDC.__doc__
print(class_doc)


    An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
    [SANA](https://arxiv.org/abs/2410.10629).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Args:
        in_channels (`int`, defaults to `3`):
            The number of input channels in samples.
        latent_channels (`int`, defaults to `32`):
            The number of channels in the latent space representation.
        encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
            The type(s) of block to use in the encoder.
        decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
            The type(s) of block to use in the decoder.
        encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
            The number of output channels for each block in the encoder.
      