Official PyTorch implementation of SARA, a representation-alignment objective for video diffusion models that uses a text-conditioned saliency map to decide which token pairs receive token-relation distillation supervision.
TL;DR. SARA recasts representation alignment as a pair-routing problem. A frozen text-conditioned saliency aligner (Stage 1, supervised by per-entity SAM 3.1 masks + InfoNCE) tells token-relation distillation (Stage 2) which token pairs to weight, focusing the loss on subject-subject and subject-background relations rather than background filler.
SARA routes representation-alignment supervision by the prompt, not by pixels. (1) Token-relation distillation matches the pairwise token relations of the DiT features $V_p$ to those of a frozen VFM $V_y$, weighting all $O(N^2)$ pairs equally, so ~27% of the budget is spent on background-background pairs that seldom carry the subject interactions the prompt describes. (2) A frozen text-conditioned Semantic Aligner $\Phi$ predicts a per-token saliency $w_i$, and the OR pair-routing operator $W_{ij} = w_i + w_j - w_i w_j$ keeps adaptive weights on FG-FG and FG-BG pairs while dropping BG-BG pairs.
SARA is a two-stage framework:
Stage 1 trains a lightweight Semantic Aligner on top of frozen V-JEPA 2.1 visual features. Given per-entity captions, it learns a saliency head (supervised by SAM 3.1 masks) and a global alignment head (supervised by InfoNCE against Qwen3-VL-Embedding). The saliency head outputs a per-patch mask $M_p$ that identifies which spatiotemporal tokens belong to foreground entities.
Stage 2 uses the frozen Semantic Aligner to guide Token-Relation Distillation (TRD) during continual training of a Wan2.2 DiT. For every token pair $(i, j)$, the OR-routing weight $W_{ij} = w_i + w_j - w_i w_j$, where $w_i$ is the saliency of token $i$, ensures that every pair involving at least one foreground token receives supervision, while pure background-background pairs are down-weighted.
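The Stage 2 routing can be illustrated with a minimal pure-Python sketch. Note that $W_{ij} = 1 - (1 - w_i)(1 - w_j)$, i.e. a probabilistic OR of the two token saliencies. The sketch assumes that token relations are pairwise cosine similarities, that the per-pair discrepancy is an absolute difference, and that the loss is normalized by the total routing weight; these choices and the helper names (`or_routing_weights`, `relation_matrix`, `masked_trd_loss`) are hypothetical and are not taken from sara/stage2/models/repa_loss.py.

```python
import math


def or_routing_weights(w):
    """OR routing: W[i][j] = w_i + w_j - w_i * w_j for per-token saliency w."""
    return [[wi + wj - wi * wj for wj in w] for wi in w]


def relation_matrix(feats):
    """Pairwise cosine similarities between token feature vectors (N x N)."""
    normed = []
    for f in feats:
        n = max(math.sqrt(sum(x * x for x in f)), 1e-12)  # guard zero vectors
        normed.append([x / n for x in f])
    return [[sum(a * b for a, b in zip(u, v)) for v in normed] for u in normed]


def masked_trd_loss(dit_feats, vfm_feats, saliency):
    """Hypothetical OR-routed TRD loss: weighted mean |R_p - R_y| over all pairs.

    dit_feats and vfm_feats hold one vector per token (same token order);
    their widths may differ because only the N x N relations are compared.
    """
    rp = relation_matrix(dit_feats)
    ry = relation_matrix(vfm_feats)
    W = or_routing_weights(saliency)
    num = den = 0.0
    n = len(saliency)
    for i in range(n):
        for j in range(n):
            num += W[i][j] * abs(rp[i][j] - ry[i][j])
            den += W[i][j]
    return num / max(den, 1e-8)  # all-background input yields 0.0
```

With a binary saliency such as `[1.0, 0.0, 0.0]`, a disagreement between the two feature spaces that only affects the background pair (1, 2) contributes nothing to this loss, whereas marking token 1 as foreground makes the same disagreement count.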
(a) Foreground tokens make up roughly half of the V-JEPA grid. (b) OR routing allocates ~70% of the pair budget to semantically meaningful FG-FG and FG-BG pairs, compared to only 26% with AND masking, while avoiding the sparsity of XOR.
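A back-of-the-envelope version of panel (b) follows from counting ordered token pairs when a fraction $f$ of tokens is foreground and masks are binary. This is an idealized calculation under that assumption (`pair_budget` is a hypothetical helper), not how the figure was measured.

```python
def pair_budget(fg_fraction):
    """Share of the N^2 ordered token pairs in each category, and the share
    kept by each binary pair mask, if a fraction fg_fraction of tokens is FG."""
    f, b = fg_fraction, 1.0 - fg_fraction
    return {
        "FG-FG": f * f,
        "FG-BG": 2.0 * f * b,  # counts both (FG, BG) and (BG, FG) orderings
        "BG-BG": b * b,
        "AND": f * f,          # keep pairs where both tokens are FG
        "OR": 1.0 - b * b,     # keep pairs where at least one token is FG
        "XOR": 2.0 * f * b,    # keep pairs where exactly one token is FG
    }


print(pair_budget(0.5))
```

At $f = 0.5$ this gives 25% for AND, 75% for OR, and 50% for XOR, and the uniform baseline spends 25% on BG-BG pairs. These are in the same range as the measured ~26% (AND), ~70% (OR), and ~27% (BG-BG) figures, which come from real masks and should be preferred. The count also makes clear that XOR keeps only mixed pairs and discards every FG-FG pair.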
Stage 1 uses SAM 3.1 Multiplex to extract per-entity segmentation masks across all video frames. Each entity (person, object, background) is tracked independently, providing spatially precise supervision for the saliency head.
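One plausible way to turn per-frame SAM 3.1 masks into patch-level targets for the saliency head is to average each mask over the block of pixels (and frames) that a single visual token covers. The 16 × 16 patch and 2-frame tubelet defaults below are assumptions about the V-JEPA tokenizer, and `mask_to_patch_targets` is a hypothetical helper, not the repository's data loader.

```python
def mask_to_patch_targets(masks, patch=16, tubelet=2):
    """Average binary masks over token footprints.

    masks: list of T frames, each an H x W nested list of 0/1 values.
    Returns a (T // tubelet) x (H // patch) x (W // patch) grid holding the
    fraction of foreground pixels per token. patch and tubelet are ASSUMED
    tokenizer sizes; trailing pixels/frames that do not fill a token are dropped.
    """
    T, H, W = len(masks), len(masks[0]), len(masks[0][0])
    out = []
    for t0 in range(0, T - tubelet + 1, tubelet):
        grid = []
        for y0 in range(0, H - patch + 1, patch):
            row = []
            for x0 in range(0, W - patch + 1, patch):
                s = sum(masks[t][y][x]
                        for t in range(t0, t0 + tubelet)
                        for y in range(y0, y0 + patch)
                        for x in range(x0, x0 + patch))
                row.append(s / (tubelet * patch * patch))
            grid.append(row)
        out.append(grid)
    return out
```

Thresholding these fractions (for example at 0.5) would give binary per-token targets; keeping them soft is an equally valid design choice.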
For each entity, the Stage 1 visualization shows: (row 1) video frames with the SAM 3.1 GT mask overlaid, (row 2) a PCA of the enhanced V-JEPA features after cross-attention with the entity caption, and (row 3) the predicted saliency heatmap from the saliency head. At test time the saliency head highlights the correct entity region from the caption alone, with no mask input.
During Stage 2 training, the frozen Semantic Aligner runs in inference mode and produces saliency masks conditioned on the full MTSS caption (including the global setup, cast, scene, and shot descriptions). The resulting mask routes the TRD loss toward semantically rich token pairs.
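The two Stage 1 training signals can be sketched as follows, assuming per-patch binary cross-entropy for the saliency head and a one-directional InfoNCE (temperature 0.07) between pooled video embeddings and caption embeddings for the global alignment head. The loss forms, the temperature, and the helper names are assumptions, not read from sara/stage1/train.py; a real implementation would use batched tensor operations.

```python
import math


def bce_saliency_loss(logits, targets):
    """Mean binary cross-entropy between per-patch logits and 0/1 targets."""
    total = 0.0
    for z, y in zip(logits, targets):
        # numerically stable BCE-with-logits: max(z, 0) - z*y + log(1 + e^-|z|)
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(video_emb, text_emb, temperature=0.07):
    """InfoNCE: video_emb[k] and text_emb[k] are positives; other texts are negatives."""
    loss = 0.0
    for k, v in enumerate(video_emb):
        logits = [cosine(v, t) / temperature for t in text_emb]
        m = max(logits)  # log-sum-exp with max subtraction for stability
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        loss += log_denom - logits[k]
    return loss / len(video_emb)
```

A combined Stage 1 objective would then look like `bce + lam * info_nce`, where the weight `lam` is likewise a hypothetical tuning parameter.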
| Capability | Entry point | Description |
|---|---|---|
| SARA Stage 1 | sara/stage1/train.py | Train the text-conditioned saliency aligner from per-entity SAM 3.1 masks + InfoNCE on top of frozen V-JEPA + Qwen3-VL-Embedding backbones. |
| SARA Stage 2 | sara/stage2/train.py | FSDP + Sequence Parallel continual training of Wan2.2 with the OR-routed masked TRD loss, using the frozen Stage 1 aligner. |
| Stage 1 Inference | sara/stage1/infer.py | Saliency / cross-attention visualization for the trained Stage 1 aligner. |
| Stage 2 Inference | sara/stage2/inference.py | Distributed text-to-video sampling with the trained DiT. |
```bash
git clone https://github.com/lian700/SARA.git
cd SARA
python3 -m venv .venv && source .venv/bin/activate
pip install -e .
# Install GPU extras manually (versions depend on your CUDA toolkit):
# pip install flash-attn --no-build-isolation
# pip install xformers
```

See docs/ENVIRONMENT.md for full instructions, including how to fetch V-JEPA 2.1, SAM 3.1, Wan2.2, T5, and Qwen3-VL-Embedding weights.
All paths (data, weights, source trees, outputs) are read from environment variables. Copy the template and edit it for your machine:
```bash
cp .env.example .env
# then edit .env: SARA_DATA_ROOT, SARA_CKPT_ROOT, VJEPA2_ROOT, SARA_WAN_SRC, ...
```

scripts/_common.sh sources .env automatically, and the YAML configs in configs/ expand the `${VAR}` placeholders. See docs/ENVIRONMENT.md for the full variable reference and docs/DATA.md for the dataset manifest layout.
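The `${VAR}` convention can be illustrated with a small sketch that walks an already-parsed config and expands environment variables in every string. It uses the standard-library `os.path.expandvars`, which leaves unknown variables untouched, so the sketch adds an optional strict check that fails early on anything left unresolved. `expand_env` is a hypothetical helper; the repository's own config loader may behave differently.

```python
import os
import re

_UNRESOLVED = re.compile(r"\$\{[^}]+\}")


def expand_env(node, strict=True):
    """Recursively expand ${VAR} placeholders in the strings of a parsed config."""
    if isinstance(node, dict):
        return {k: expand_env(v, strict) for k, v in node.items()}
    if isinstance(node, list):
        return [expand_env(v, strict) for v in node]
    if isinstance(node, str):
        out = os.path.expandvars(node)
        # expandvars silently keeps unset variables; surface them instead
        if strict and _UNRESOLVED.search(out):
            raise KeyError(f"unset environment variable in config value: {node!r}")
        return out
    return node  # numbers, booleans, None pass through unchanged
```

Failing on unset variables is a deliberate choice here: a silently literal `${SARA_DATA_ROOT}` path tends to surface much later as a confusing file-not-found error.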
```bash
# Optional: start a Qwen3-VL prompt-simplifier server on a separate node
bash scripts/run_sglang_server.sh

# Stage 1: saliency aligner (single-node 8 GPUs, ~24 h on H100)
bash scripts/train_stage1.sh

# Stage 2: SARA (multi-node, requires SARA_STAGE1_CKPT from Stage 1)
export SARA_STAGE1_CKPT=/path/to/stage1/checkpoint_step0003000
bash scripts/train_stage2_sara.sh

# Stage 1 inference: saliency visualization
export SARA_STAGE1_CKPT=/path/to/stage1/checkpoint_step0003000
bash scripts/infer_stage1.sh

# Stage 2 inference: text-to-video generation with the trained DiT
bash scripts/infer_stage2.sh
```

Resume training from a previous checkpoint:

```bash
RESUME_CKPT=/path/to/checkpoint bash scripts/train_stage1.sh
```

Multi-node training is automatic via a hostfile (one IP per line). Place it at ./hosts (or export HOSTFILE=/path/to/hosts), then run the same script on every node:

```bash
pssh -i -t 0 -h ./hosts bash scripts/train_stage2_sara.sh
```

The script auto-detects each node's rank, sets MASTER_ADDR to the first host, and configures NCCL parameters. On multi-node runs the training process is launched in the background via nohup; monitor it with `tail -f <output_dir>/logs/*.log`.
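The rank detection described above can be sketched in Python for illustration; the actual logic lives in scripts/_common.sh and may differ. This sketch assumes a node's rank is the index of the first hostfile line matching one of its local IPs, and that blank lines are ignored; `resolve_rank` is a hypothetical name.

```python
def resolve_rank(host_lines, local_ips):
    """Return (node_rank, master_addr, num_nodes) from hostfile lines.

    host_lines: lines of the hostfile (one IP per line).
    local_ips: set of IP addresses assigned to this machine.
    """
    hosts = [h.strip() for h in host_lines if h.strip()]  # skip blank lines
    if not hosts:
        raise ValueError("hostfile is empty")
    for rank, host in enumerate(hosts):
        if host in local_ips:
            # MASTER_ADDR is the first host, as described in this README
            return rank, hosts[0], len(hosts)
    raise ValueError("none of this node's IPs appear in the hostfile")
```

On a Linux node, the local addresses could be gathered with `hostname -I` and passed in as a set.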
Each script is a thin wrapper around `python -m sara.cli` that expands the matching YAML in configs/ into argparse flags. Any flag can be overridden on the command line; for example, `bash scripts/train_stage2_sara.sh --mask_pair_mode and` reproduces the AND ablation row.
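The override behavior follows from a common argparse pattern, sketched here under the assumption that the YAML has already been parsed into a flat dict: config-derived flags are placed first and user flags last, and argparse's default `store` action keeps the last value seen for a repeated option. Only `--mask_pair_mode` and its `and` value come from this README; the `or` config value, the `--max_steps` flag, and the helper names are hypothetical, and sara.cli may be implemented differently.

```python
import argparse


def config_to_argv(cfg):
    """Flatten a {flag: value} config into ['--flag', 'value', ...]."""
    argv = []
    for key, value in cfg.items():
        argv += [f"--{key}", str(value)]
    return argv


def parse_with_overrides(cfg, cli_args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--mask_pair_mode", type=str)  # flag named in this README
    parser.add_argument("--max_steps", type=int)       # hypothetical flag
    # Config flags first, command-line flags last: for a repeated option the
    # default "store" action overwrites, so the user's value wins.
    return parser.parse_args(config_to_argv(cfg) + list(cli_args))


ns = parse_with_overrides({"mask_pair_mode": "or", "max_steps": 3000},
                          ["--mask_pair_mode", "and"])
print(ns.mask_pair_mode, ns.max_steps)
```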
```
.
├── sara/
│ ├── stage1/ # Saliency aligner training
│ │ ├── train.py
│ │ ├── infer.py
│ │ ├── models/ # SAM3TextAligner + cross-attention fusion
│ │ ├── data/ # video + caption + per-entity prompt loader
│ │ └── utils/ # SAM/encoder/checkpoint/lr/bucket/qwen-sglang
│ │
│ ├── stage2/ # Wan2.2 continual training (SARA)
│ │ ├── train.py # FSDP + Ulysses Sequence Parallel
│ │ ├── inference.py
│ │ ├── models/
│ │ │ ├── repa_loss.py # SARA OR-routed TRD objective
│ │ │ ├── repa_encoder.py
│ │ │ ├── semantic_encoder.py
│ │ │ ├── sam3_mask_extractor.py
│ │ │ └── model_sp.py # Wan model wrapper with REPA hooks
│ │ ├── data/
│ │ └── utils/ # parallel_states, communication, wan_imports
│
├── scripts/ # Entry-point shell launchers
│ ├── train_stage1.sh # Stage 1 training
│ ├── train_stage2_sara.sh # Stage 2 training (requires SARA_STAGE1_CKPT)
│ ├── infer_stage1.sh # Stage 1 saliency visualization
│ ├── infer_stage2.sh # Stage 2 T2V inference
│ ├── run_sglang_server.sh # Qwen3-VL prompt simplifier
│ └── _common.sh # Shared bootstrap (multi-node, NCCL, etc.)
│
├── configs/ # YAML configs for every experiment
├── assets/ # Sample data + figures
├── docs/ # ENVIRONMENT, DATA
├── pyproject.toml
├── requirements.txt
└── .env.example
```
See docs/DATA.md for the full JSON manifest schema.
A minimal example is provided in assets/test.json.
This codebase builds on Wan2.2 (Tongyi-Wanxiang), V-JEPA 2.1 (Meta FAIR), SAM 3.1 (Meta FAIR), Qwen3-VL-Embedding (Qwen), MTSS (Multi-Stream Scene Script, Tencent Hunyuan), and the REPA and MoAlign family of representation-alignment objectives. We thank the authors of all these projects for releasing their weights and code.
```bibtex
@article{lian2026sara,
  title   = {SARA: Semantically Adaptive Relational Alignment for Video Diffusion Models},
  author  = {Lian, Jiesong and Zhou, Zixiang and Zhong, Ruizhe and Zhou, Yuan and
             Lu, Qinglin and Wang, Rui and Hu, Long and Hao, Yixue and Huang, Baoru},
  journal = {arXiv preprint arXiv:2605.07800},
  year    = {2026}
}
```




