In [None]:
from diffusers import AutoencoderDC
import torch

# Construct a small DC autoencoder directly from hyper-parameters. No checkpoint is
# loaded in this cell, so the weights are whatever the constructor initialises.
tiny_ae = AutoencoderDC(
    # --- image / latent interface ---
    in_channels=3,
    latent_channels=4,
    scaling_factor=0.41407,
    attention_head_dim=32,
    # --- encoder: four stages, ResBlocks first, EfficientViT blocks last ---
    encoder_block_types=["ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock"],
    encoder_block_out_channels=[64, 64, 64, 64],
    encoder_layers_per_block=(1, 2, 3, 3),
    encoder_qkv_multiscales=((), (), (5,), (5,)),
    downsample_block_type="Conv",
    # --- decoder: same four block types, different per-stage depths ---
    decoder_block_types=["ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock"],
    decoder_block_out_channels=(64, 64, 64, 64),
    decoder_layers_per_block=(3, 3, 3, 1),
    decoder_qkv_multiscales=((), (), (5,), (5,)),
    upsample_block_type="interpolate",
    decoder_norm_types="rms_norm",
    decoder_act_fns="silu",
)

# Inference-only setup: eval mode, moved to the GPU in half precision, and with
# gradient tracking switched off on every parameter.
tiny_ae = tiny_ae.eval().to(device="cuda", dtype=torch.float16)
tiny_ae.requires_grad_(False)


In [None]:
from diffusers import UNet2DModel

# Keyword arguments for UNet2DModel, kept in one mapping so the model cell can
# unpack it with ** and a reader can tune everything in one place.
unet2d_config = dict(
    # Input / output geometry: 4-channel 64x64 samples in, 4 channels out.
    sample_size=64,
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    # Timestep embedding.
    time_embedding_type="positional",
    freq_shift=0,
    flip_sin_to_cos=True,
    # Three plain (attention-free) resolution levels on each side.
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
    block_out_channels=[320, 640, 1280],
    layers_per_block=1,
    add_attention=False,
    # Resampling between levels.
    mid_block_scale_factor=1,
    downsample_padding=1,
    downsample_type="conv",
    upsample_type="conv",
    # ResNet block internals.
    dropout=0.0,
    act_fn="silu",
    norm_num_groups=32,
    norm_eps=1e-05,
    resnet_time_scale_shift="default",
)


# Build the UNet straight from the config above (no checkpoint is loaded here)
# and prepare it for inference: eval mode, to match how tiny_ae is set up, then
# a single move to the GPU in float16 instead of two chained .to() calls.
unet = UNet2DModel(**unet2d_config).eval().to(device="cuda", dtype=torch.float16)
# The trailing ";" stops Jupyter from echoing the returned module. It replaces
# the stray `""` expression, which still printed '' as the cell output.
unet.requires_grad_(False);

In [17]:
# Random half-precision image batch on the GPU; it is the input handed to
# tiny_ae.encode in the next cell. Unseeded, so values differ on every run.
c_t = torch.randn(
    (1, 3, 512, 512),  # alternative size tried previously: (1, 3, 720, 1280)
    dtype=torch.float16,
    device="cuda",
)
# One int64 timestep (999) for the single batch element.
timesteps = torch.tensor([999], dtype=torch.long, device="cuda")

In [18]:
# Encode the image batch and multiply by the autoencoder's configured scaling
# factor. With return_dict=False the result is indexable and element 0 is used.
encoded_control = (
    tiny_ae.encode(c_t, return_dict=False)[0] * tiny_ae.config.scaling_factor
)
# Disabled call kept for reference; `model` is not defined in the visible cells.
# model_pred = model.simple_unet(
#     encoded_control,
#     timesteps,
# )

In [15]:
# Sanity check: display the shape of the scaled latent. The output recorded below
# is (1, 4, 64, 64); NOTE(review): this cell's execution count is lower than the
# cells that define its input, so re-run in order to confirm it is current.
encoded_control.shape

torch.Size([1, 4, 64, 64])

In [23]:
# Run the UNet on the scaled latent at the chosen timestep. return_dict=False
# makes the call return an indexable result; element 0 is shown as cell output.
unet(sample=encoded_control, timestep=timesteps, return_dict=False)[0]

tensor([[[[-0.0953,  0.1296, -0.2114,  ..., -0.1379,  0.0031, -0.0450],
          [-0.1094,  0.2783,  0.0085,  ...,  0.1550,  0.1338,  0.0625],
          [ 0.1063,  0.4536,  0.0635,  ...,  0.3228,  0.1267,  0.1602],
          ...,
          [ 0.0596,  0.5068,  0.1382,  ...,  0.3899,  0.4070,  0.1794],
          [ 0.0935,  0.0449,  0.0568,  ...,  0.2219,  0.3391,  0.1384],
          [ 0.2146,  0.0825,  0.2273,  ...,  0.1475,  0.3501,  0.0595]],

         [[ 0.0799,  0.2280,  0.2698,  ..., -0.1169, -0.0296,  0.0739],
          [-0.1298, -0.1722, -0.3428,  ..., -0.3320, -0.0090, -0.2155],
          [ 0.0288, -0.0420, -0.0370,  ..., -0.0316, -0.3037, -0.0642],
          ...,
          [-0.1481,  0.0814, -0.0603,  ...,  0.0187,  0.0584, -0.1100],
          [-0.1915,  0.0198,  0.0429,  ...,  0.1971, -0.0712, -0.0511],
          [-0.0853,  0.0863,  0.1594,  ...,  0.0923,  0.0843,  0.0395]],

         [[ 0.1591,  0.3323,  0.1899,  ...,  0.3755,  0.1542,  0.2075],
          [ 0.1565,  0.0635,  