Attend Locally, Remember Linearly: Linear Attention as Cross-Frame Memory for Autoregressive Video Diffusion
ARL² replaces quadratic cross-frame softmax attention in autoregressive video diffusion with a fixed-size recurrent linear state, achieving 2.26x wall-clock speedup and 54% memory reduction while maintaining comparable quality and improved temporal consistency. Our hybrid attention decomposes self-attention into an intra-frame softmax branch (spatial detail) and an inter-frame gated recurrent linear branch (temporal memory).
- Hybrid Attention: Intra-block softmax + inter-block Gated Delta Network (GDN) with block-level query
- Clean-state update: Recurrent state updated only post-denoising to prevent noise corruption
- Scalable: Constant memory for cross-frame attention regardless of video length
- Two-stage training: Per-layer distillation (Stage 1) followed by end-to-end teacher distillation (Stage 2)
- Compatible: Built on Causal Forcing + Wan 2.1 backbone
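To make the constant-memory point concrete, here is a rough per-layer element count comparing a softmax KV cache, which grows with the number of cached frames, against a fixed-size recurrent state. This is a back-of-the-envelope sketch, not code from this repository: the function names and the sizes in the demo are illustrative assumptions, and the state is assumed to be one `head_dim x head_dim` matrix per head.

```python
def kv_cache_elements(num_frames, tokens_per_frame, num_heads, head_dim):
    """Elements held by a softmax KV cache over all past frames (keys + values)."""
    return 2 * num_frames * tokens_per_frame * num_heads * head_dim


def recurrent_state_elements(num_heads, head_dim):
    """Elements held by one (head_dim x head_dim) linear-attention state per head."""
    return num_heads * head_dim * head_dim


if __name__ == "__main__":
    # Illustrative sizes only (assumptions, not measured from ARL² or Wan 2.1).
    heads, dim, tokens = 12, 128, 1560
    state = recurrent_state_elements(heads, dim)
    for frames in (8, 32, 128):
        cache = kv_cache_elements(frames, tokens, heads, dim)
        print(f"{frames:4d} frames: KV cache {cache:,d} elements, state {state:,d} elements")
```

The cache term scales linearly with `num_frames`, while the state term does not depend on it at all.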
conda create -n arl2 python=3.10 -y
conda activate arl2
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install flash-attn --no-build-isolation
pip install triton fla
python setup.py develop
Base models (Wan 2.1 + Causal Forcing):
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
huggingface-cli download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints
ARL² hybrid attention checkpoints (coming soon):
# huggingface-cli download lky-ang/ARL2 ... --local-dir checkpoints
python inference.py \
    --config_path configs/causal_forcing_dmd_chunkwise.yaml \
    --output_folder output/chunkwise \
    --checkpoint_path checkpoints/chunkwise/causal_forcing.pt \
    --data_path prompts/test.txt \
    --hybrid_layers 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
Key design choices:
- Block-level query: All tokens in a block see the same recurrent state, ensuring spatial consistency
- Clean-state update: State is updated only after denoising completes, preventing noise from corrupting the temporal memory
- Headwise sigmoid gate: Learned per-head gate balances intra vs. inter contributions
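The three choices above can be sketched for a single head in NumPy. This is a minimal illustration under stated assumptions, not the repository's implementation (which uses FLA kernels): the function names are hypothetical, the gate is assumed to mix the branches as `g * intra + (1 - g) * inter`, keys are L2-normalized before the state update, and `alpha` and `beta` are fixed scalars rather than learned, data-dependent values.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def hybrid_attention(q, k, v, state, gate_logit):
    """q, k, v: (tokens, d) for one block; state: (d, d); gate_logit: scalar per head."""
    d = q.shape[-1]
    intra = softmax(q @ k.T / np.sqrt(d)) @ v   # softmax within the block
    inter = q @ state                           # every token reads the same state
    g = 1.0 / (1.0 + np.exp(-gate_logit))       # headwise sigmoid gate (assumed mixing form)
    return g * intra + (1.0 - g) * inter


def gated_delta_update(state, k, v, alpha=1.0, beta=1.0):
    """Gated delta rule, one token at a time; returns a new state array."""
    for k_t, v_t in zip(k, v):
        k_t = k_t / np.linalg.norm(k_t)
        recalled = k_t @ state                  # what the state currently stores at k_t
        state = alpha * (state - beta * np.outer(k_t, recalled)) + beta * np.outer(k_t, v_t)
    return state


def generate_block(noisy_steps, clean_k, clean_v, state, gate_logit):
    """Read the state at every denoising step; write it once, from clean keys/values."""
    outputs = [hybrid_attention(q, k, v, state, gate_logit) for q, k, v in noisy_steps]
    new_state = gated_delta_update(state, clean_k, clean_v)
    return outputs, new_state
```

In `generate_block`, the state is read-only while a block is being denoised and is written exactly once afterwards, so intermediate noisy keys and values never enter the temporal memory.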
Training code will be released in a future update. The training pipeline consists of:
- Stage 1: Per-layer distillation — align hybrid attention outputs to original softmax attention via MSE
- Stage 2: End-to-end teacher distillation — frozen teacher (original model) supervises the hybrid student
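Since the training code is not yet released, the following is only a sketch of what a Stage 1 objective could look like: a mean-squared error between each hybrid layer's attention output and the original softmax layer's output on the same input, averaged over layers. The function name and the uniform averaging over layers are assumptions, and the Stage 2 loss is not shown because its exact form is not specified here.

```python
import numpy as np


def per_layer_distill_loss(student_outputs, teacher_outputs):
    """Average over layers of the MSE between hybrid and softmax attention outputs.

    Both arguments are sequences of same-shaped arrays, one entry per layer.
    """
    losses = [np.mean((s - t) ** 2) for s, t in zip(student_outputs, teacher_outputs)]
    return float(np.mean(losses))
```

In practice the teacher outputs would be computed without gradients, so only the hybrid attention parameters are updated.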
This codebase is built on top of:
- Causal Forcing (Zhu et al.) — autoregressive diffusion distillation framework
- Wan 2.1 (Alibaba) — base video diffusion model
- CausVid / Self Forcing — distillation infrastructure
- FLA — efficient linear attention kernels
If you find this work useful, please cite:
@article{li2026attend,
  title={Attend Locally, Remember Linearly: Linear Attention as Cross-Frame Memory for Autoregressive Video Diffusion},
  author={Li, Kunyang and Shah, Mubarak and Shang, Yuzhang},
  journal={arXiv preprint arXiv:2605.16579},
  year={2026}
}
This project is licensed under the Apache License 2.0. See LICENSE for details.
Portions of this codebase are derived from Causal Forcing (Apache 2.0) and Wan 2.1 (Apache 2.0). See NOTICE for third-party attributions.
