ICML 2026
Paper | Project Page | Code
Sparse ActionGen (SAG) accelerates transformer-based Diffusion Policy by predicting rollout-adaptive pruning masks before the denoising process and reusing cached activations for pruned computations. SAG targets the closed-loop action generation bottleneck in visuomotor control, where fixed caching schedules cannot adapt to changing robot-environment observations.
- Real-time Diffusion Pruner: Observation-conditioned pruner that predicts a global timestep-by-block sparsity pattern in one forward pass
- One-for-All Reusing Strategy: Shared activation cache that reuses residual computations across both denoising timesteps and transformer blocks
- Global Sparsity Objective: End-to-end sparsity loss that allocates computation under a strict global pruning budget
- Plug-and-Play Acceleration: Wraps pretrained Diffusion Policy checkpoints without updating the policy parameters
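As a minimal sketch of the pruner idea (all names, shapes, and the linear head here are illustrative assumptions, not the repo's API): one forward pass on the observation embedding scores every (timestep, block) slot, the top fraction under the budget is kept, and pruned slots fall back to cached activations.

```python
import numpy as np

def predict_keep_mask(obs, w, num_timesteps, num_blocks, target_ratio=0.91):
    """Illustrative pruner head: score every (timestep, block) slot from the
    observation, then keep only the top (1 - target_ratio) fraction.
    Pruned slots would reuse activations from the shared cache."""
    scores = obs @ w                      # (batch, num_timesteps * num_blocks)
    total = num_timesteps * num_blocks
    k = max(1, round((1.0 - target_ratio) * total))
    keep = np.zeros_like(scores, dtype=bool)
    top = np.argsort(-scores, axis=1)[:, :k]  # indices of highest-scoring slots
    np.put_along_axis(keep, top, True, axis=1)
    return keep.reshape(-1, num_timesteps, num_blocks)

rng = np.random.default_rng(0)
obs = rng.standard_normal((4, 32))        # toy batch of observation embeddings
w = rng.standard_normal((32, 100 * 8))    # toy linear head weights
mask = predict_keep_mask(obs, w, num_timesteps=100, num_blocks=8)
```

With a 0.91 target ratio, only 9% of the (timestep, block) slots are recomputed each rollout step; the remainder read from the one-for-all cache.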
```bash
git clone --recursive https://github.com/ky-ji/SAG.git
cd SAG
git submodule update --init --recursive
conda env create -f conda_environment.yaml
conda activate robodiff
pip install -e diffusion_policy
pip install -e .
```

This repository keeps the SAG inference and pruner training code in `SAGInfer/`. It is designed to work with pretrained Diffusion Policy transformer checkpoints and datasets from the original Diffusion Policy project.
Supported benchmark tasks include can_ph, can_mh, lift_ph, lift_mh, square_ph, square_mh, transport_ph, transport_mh, tool_hang_ph, kitchen, block_pushing, and pusht.
Download datasets and pretrained checkpoints following the Diffusion Policy instructions. By default, this code expects checkpoints under:
```
checkpoint/<task_name>/diffusion_policy_transformer/train_<id>/checkpoints/latest.ckpt
checkpoint/low_dim/<task_name>/diffusion_policy_transformer/train_<id>/checkpoints/latest.ckpt
```
You can also pass an explicit checkpoint path to each script.
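A resolver for this layout could look like the following hedged sketch (the helper name and search order are assumptions, not the repo's actual code):

```python
from pathlib import Path

def resolve_checkpoint(task_name: str, root: str = "checkpoint") -> Path:
    # Hypothetical helper: search both the state-based and low_dim trees
    # of the default layout and return the first latest.ckpt found.
    patterns = [
        f"{task_name}/diffusion_policy_transformer/train_*/checkpoints/latest.ckpt",
        f"low_dim/{task_name}/diffusion_policy_transformer/train_*/checkpoints/latest.ckpt",
    ]
    for pattern in patterns:
        hits = sorted(Path(root).glob(pattern))
        if hits:
            return hits[0]
    raise FileNotFoundError(f"No checkpoint for task '{task_name}' under '{root}'")
```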
```bash
python -m SAGInfer.scripts.create_pruner_dataset \
    --task can_ph \
    --checkpoint auto \
    --num_train 5600 \
    --num_valid 640 \
    --output_dir pruner_data
```

The generated data is saved under `pruner_data/<task>/seed<seed>/train<num_train>/`.
```bash
python -m SAGInfer.scripts.train_pruner \
    --task_name can_ph \
    --config SAGInfer/pruner_config/training_config.yaml \
    --device cuda:0 \
    --train_version 0
```

The default config follows the paper setting: target pruning ratio 0.91, batch size 32, learning rate 1e-4, 30 epochs, and 5,600 train / 640 validation samples.
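The global sparsity objective from the feature list can be pictured as a penalty that pushes the mean keep probability toward the global budget implied by the target pruning ratio. This is a hedged sketch under assumed names; the paper's exact loss may differ.

```python
import numpy as np

def global_sparsity_loss(keep_probs, target_ratio=0.91):
    """Sketch of a global sparsity penalty: drive the mean keep probability
    over all (timestep, block) slots toward the budget (1 - target_ratio).
    End-to-end training would additionally need a differentiable relaxation
    of the discrete mask; only the budget term is shown here."""
    budget = 1.0 - target_ratio
    return float((np.mean(keep_probs) - budget) ** 2)
```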
```bash
python -m SAGInfer.scripts.eval_pruner \
    --task_name can_ph \
    --checkpoint auto \
    --timestamp <train_timestamp> \
    --epoch <best_epoch> \
    --device cuda:0 \
    --skip_video
```

The evaluation script reports latency, speedup, FLOPs estimates, rollout success metrics, and optional pruning visualizations.
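As a back-of-envelope check on FLOPs-based speedup numbers (the `prunable_frac` overhead term below is a hypothetical knob, not the repo's actual accounting):

```python
def theoretical_flops_ratio(target_ratio: float, prunable_frac: float = 1.0) -> float:
    # Amdahl-style estimate: pruned (timestep, block) slots cost ~0 FLOPs
    # thanks to cache reuse, so the retained fraction of total FLOPs is the
    # unprunable overhead plus the kept share of prunable work.
    return (1.0 - prunable_frac) + prunable_frac * (1.0 - target_ratio)

# With the paper's target pruning ratio of 0.91 and a hypothetical 10%
# unprunable overhead (obs encoder, pruner forward pass), the upper-bound
# speedup is about 5.5x.
speedup = 1.0 / theoretical_flops_ratio(0.91, prunable_frac=0.9)
```

Measured wall-clock speedup will sit below this bound, since cache reads and mask prediction are not free.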
```
SAG/
|-- SAGInfer/                # SAG acceleration and pruner code
|   |-- acceleration/        # One-for-all cache/pruner wrapper
|   |-- pruner/              # Pruner model, training, and evaluation utilities
|   |-- pruner_config/       # Public default training config
|   `-- scripts/             # Dataset, training, and evaluation entrypoints
|-- diffusion_policy/        # Base Diffusion Policy implementation (submodule)
|-- figure/                  # Method and sparsity visualizations
|-- checkpoint/              # Diffusion Policy checkpoints (not included)
|-- pruner_data/             # Generated pruner datasets (not included)
|-- output/                  # Training/evaluation outputs (not included)
`-- conda_environment.yaml   # Reproduction environment
```
If you find SAG useful for your research, please consider citing our paper:
```bibtex
@article{ji2026sparse,
  title={Sparse ActionGen: Accelerating Diffusion Policy with Real-time Pruning},
  author={Ji, Kangye and Meng, Yuan and Zhou, Jianbo and Li, Ye and Cui, Hanyun and Wang, Zhi},
  journal={arXiv preprint arXiv:2601.12894},
  year={2026}
}
```

This work builds on:
- Diffusion Policy - Base diffusion policy implementation
- robomimic - Robot learning framework and benchmarks
We thank the authors of these projects for their open-source contributions.
This project is licensed under the MIT License. See LICENSE for details.
