In [1]:
from diffusers import (
    AutoencoderTiny,
    UNet2DModel,
    AutoModel,
    FlowMatchEulerDiscreteScheduler,
)
import torch

unet2d_config = {
    "sample_size": 32,
    "in_channels": 128,
    "out_channels": 128,
    "center_input_sample": False,
    "time_embedding_type": "positional",
    "freq_shift": 0,
    "flip_sin_to_cos": True,
    "down_block_types": ("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "UpBlock2D", "UpBlock2D"),
    "block_out_channels": [320, 640, 1280],
    "layers_per_block": 1,
    "mid_block_scale_factor": 1,
    "downsample_padding": 1,
    "downsample_type": "conv",
    "upsample_type": "conv",
    "dropout": 0.0,
    "act_fn": "silu",
    "norm_num_groups": 32,
    "norm_eps": 1e-05,
    "resnet_time_scale_shift": "default",
    "add_attention": False,
}
device = "cuda"
dtype = torch.float16
unet = UNet2DModel(**unet2d_config).to(device, dtype=dtype)
vae_name = "fal/FLUX.2-Tiny-AutoEncoder"
vae = AutoModel.from_pretrained(vae_name, trust_remote_code=True).to(
    device, dtype=dtype
)

unet.requires_grad_(False)
vae.requires_grad_(False)
None

  from .autonotebook import tqdm as notebook_tqdm
`trust_remote_code` is enabled. Downloading code from fal/FLUX.2-Tiny-AutoEncoder. Please ensure you trust the contents of this repository
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: e2bf7981-47be-4167-868c-99e6cef56d67)')' thrown while requesting HEAD https://huggingface.co/fal/FLUX.2-Tiny-AutoEncoder/resolve/main/flux2_tiny_autoencoder.py
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: c9d5f958-208d-4d27-9081-c4e47dfbf99e)')' thrown while requesting HEAD https://huggingface.co/fal/FLUX.2-Tiny-AutoEncoder/resolve/main/config.json
Retrying in 1s [Retry 1/5].
The config attributes {'auto_map': {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}} were passed to Flux2TinyAutoEncoder, but are not expected and will be ignored. Please verify your conf

In [2]:
image_tensor = torch.randn(
    1,
    3,
    512,
    512,
    device=device,
    dtype=dtype,
)
latents = vae.encode(image_tensor).latent
print("latents", latents.shape)
noise_scheduler = FlowMatchEulerDiscreteScheduler()
t = noise_scheduler.timesteps[0].to(device)
unet_pred = unet(
    latents,
    t,
    return_dict=False,
)[0]
print("unet_pred", unet_pred.shape)
decoded = vae.decode(unet_pred, return_dict=False)[0]
print(decoded.shape)

latents torch.Size([1, 128, 32, 32])
unet_pred torch.Size([1, 128, 32, 32])
torch.Size([3, 512, 512])
