Minguk Kang<sup>1,2</sup>, Suha Kwak<sup>2</sup>
<sup>1</sup>Pika Labs, <sup>2</sup>POSTECH
Real-time video generation demands fast decoding as much as fast denoising, yet current latent video diffusion models rely on 3D convolutional decoders that are slow and memory-intensive at high resolutions or for long videos. We introduce FlashDecoder, a fast, memory-efficient pure-Transformer video decoder that decodes latents to pixels frame by frame. At each step, the current frame attends only to a fixed-size window of past frames through a rolling KV cache. Because the temporal window is fixed, per-frame decoding cost and memory stay bounded regardless of video length, enabling constant-latency streaming.
- Pure-Transformer decoder for latent-to-pixel video decoding with frame-by-frame streaming
- Rolling KV cache with a fixed temporal window for constant memory and latency, regardless of video length
- No explicit attention masks needed -- causality is enforced by sequential processing order, enabling high-resolution training up to 1080p
- Matches convolutional decoder quality (e.g., 41.55 vs. 41.49 dB PSNR at 1080p on Wan2.2) while being 3.6x-4.7x faster with up to 11x lower GPU memory
- With architecture-aware inference optimizations, the throughput advantage over the convolutional decoder grows to as much as 12x
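
The rolling KV cache described above can be illustrated with a minimal, pure-Python sketch. This is not the FlashDecoder implementation: `RollingKVCache`, `attend`, and `decode_stream` are hypothetical names, identity projections stand in for the learned query/key/value weights, and the window is assumed to count past frames only (the current frame always attends to itself as well). The point is only that the cache never holds more than `window` frames, and that no attention mask is needed because future frames are simply not in the cache yet.

```python
import math
from collections import deque


class RollingKVCache:
    """Holds keys/values for at most `window` past frames (hypothetical helper)."""

    def __init__(self, window):
        # deque(maxlen=...) drops the oldest frame automatically once it is full.
        self.frames = deque(maxlen=window)

    def append(self, keys, values):
        self.frames.append((keys, values))

    def context(self):
        # Flatten the cached frames into one list of keys and one list of values.
        keys = [k for ks, _ in self.frames for k in ks]
        values = [v for _, vs in self.frames for v in vs]
        return keys, values


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def decode_stream(latent_frames, window=2):
    """Process frames in order; each frame sees itself plus the cached past frames."""
    cache = RollingKVCache(window)
    outputs, cached_counts = [], []
    for tokens in latent_frames:
        past_keys, past_values = cache.context()
        # Assumption: identity projections replace the learned Q/K/V weights.
        keys = past_keys + tokens
        values = past_values + tokens
        outputs.append([attend(t, keys, values) for t in tokens])
        cache.append(tokens, tokens)
        cached_counts.append(len(cache.frames))
    return outputs, cached_counts


# Six toy frames, two 2-dimensional tokens each.
frames = [[[float(i), 1.0], [1.0, float(i)]] for i in range(6)]
outputs, cached_counts = decode_stream(frames, window=2)
print(cached_counts)  # [1, 2, 2, 2, 2, 2]
```

The cached frame count saturates at the window size, so the per-frame attention cost in this sketch does not grow with the number of frames already decoded.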
FlashDecoder processes one latent frame at a time through:
- Backbone Transformer: GQA-based Transformer with 3D-RoPE, processing spatial tokens with a rolling KV cache (window size W_frm=2)
- Temporal-First Upsampling: Channel expansion reinterpreted as temporal upsampling (factor r_t), followed by Transformer-based temporal refinement
- Spatial Upsampling: MLP + PixelShuffle for final spatial resolution recovery
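
The two upsampling rearrangements can be sketched on nested lists as follows. This is an illustration under assumptions, not the authors' code: `channels_to_frames` and `pixel_shuffle` are hypothetical helpers, the channel-to-frame layout (frame `t` takes channels `[t * C, (t + 1) * C)`) is assumed, the pixel shuffle follows the channel ordering used by PyTorch's `PixelShuffle`, and the learned parts (the temporal refinement Transformer and the MLP) are omitted.

```python
def channels_to_frames(frame, r_t):
    """Split a (C * r_t, H, W) feature map into r_t frames of shape (C, H, W).

    Assumed layout: output frame t takes channels [t * C, (t + 1) * C).
    """
    total = len(frame)
    if total % r_t != 0:
        raise ValueError("channel count must be divisible by r_t")
    c = total // r_t
    return [frame[t * c:(t + 1) * c] for t in range(r_t)]


def pixel_shuffle(frame, r):
    """Rearrange (C * r * r, H, W) into (C, H * r, W * r)."""
    total, h, w = len(frame), len(frame[0]), len(frame[0][0])
    if total % (r * r) != 0:
        raise ValueError("channel count must be divisible by r * r")
    c = total // (r * r)
    out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(c)]
    for ch in range(c):
        for y in range(h * r):
            for x in range(w * r):
                # The sub-pixel offset (y % r, x % r) selects the source channel.
                src = ch * r * r + (y % r) * r + (x % r)
                out[ch][y][x] = frame[src][y // r][x // r]
    return out


# One latent frame with 8 channels and a single 1x1 spatial location.
latent = [[[float(i)]] for i in range(8)]

# Temporal first: 8 channels -> 2 frames of 4 channels each (r_t = 2).
video = channels_to_frames(latent, r_t=2)

# Then spatial: each 4-channel 1x1 frame -> one 2x2 channel (r = 2).
pixels = [pixel_shuffle(f, r=2) for f in video]
print(pixels[0])  # [[[0.0, 1.0], [2.0, 3.0]]]
print(pixels[1])  # [[[4.0, 5.0], [6.0, 7.0]]]
```

Both steps are pure index rearrangements, so in this sketch all of the arithmetic cost sits in the layers that produce the expanded channels.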
Evaluated on UltraVideo at 480p, 720p, and 1080p with 25-frame clips on a single H100 GPU:
| Method | Params (M) | PSNR (720p) | LPIPS (720p) | FPS (720p) | Mem (720p) |
|---|---|---|---|---|---|
| Wan2.2 | 555.0 | 38.29 | 0.04 | 16.1 | 19.3 |
| FlashDecoder-XL | 769.3 | 37.08 | 0.05 | 166.0 | 1.9 |
| FlashDecoder-XL-Opt | 769.3 | 37.02 | 0.05 | 166.0 | 1.6 |

Model configurations:

| Model | Depth | Width | Heads | KV Groups | Params (M) |
|---|---|---|---|---|---|
| FlashDecoder-S | 12 | 512 | 8 | 2 | 56.8 |
| FlashDecoder-B | 16 | 768 | 12 | 3 | 161.7 |
| FlashDecoder-L | 20 | 1024 | 16 | 4 | 348.0 |
| FlashDecoder-XL | 20 | 1536 | 24 | 3 | 769.3 |
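
To make the GQA columns concrete, the short script below derives a few quantities from the configuration table. It assumes that "KV Groups" is the number of key/value heads and that the head dimension is `width / heads`. Neither is stated explicitly above, so treat the output as a back-of-the-envelope reading rather than a specification. `gqa_summary` is a hypothetical helper.

```python
# (width, heads, kv_groups) copied from the configuration table.
CONFIGS = {
    "FlashDecoder-S": (512, 8, 2),
    "FlashDecoder-B": (768, 12, 3),
    "FlashDecoder-L": (1024, 16, 4),
    "FlashDecoder-XL": (1536, 24, 3),
}


def gqa_summary(width, heads, kv_groups):
    """Return head dim, query heads per KV group, cached KV channels, and MHA ratio."""
    head_dim = width // heads                 # assumption: width is split evenly over heads
    queries_per_group = heads // kv_groups    # query heads sharing one K/V head
    kv_channels = 2 * kv_groups * head_dim    # keys + values cached per token per layer
    mha_channels = 2 * heads * head_dim       # the same quantity without grouping
    return head_dim, queries_per_group, kv_channels, mha_channels // kv_channels


for name, (width, heads, kv_groups) in CONFIGS.items():
    head_dim, per_group, kv_channels, reduction = gqa_summary(width, heads, kv_groups)
    print(f"{name}: head_dim={head_dim}, queries/group={per_group}, "
          f"kv_channels={kv_channels}, cache {reduction}x smaller than MHA")
```

Under these assumptions every variant uses a head dimension of 64. The S, B, and L variants share one K/V head among 4 query heads, while XL shares one among 8, so its per-token KV cache is 8x smaller than it would be with full multi-head attention.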