Sparse ActionGen: Accelerating Diffusion Policy with Real-time Pruning

Accelerating Diffusion Policy with Rollout-Adaptive Sparse Action Generation

ICML 2026

Paper | Project Page | Code

[Figure: Sparse ActionGen method overview]


Overview

Sparse ActionGen (SAG) accelerates transformer-based Diffusion Policy by predicting rollout-adaptive pruning masks before the denoising process and reusing cached activations for pruned computations. SAG targets the closed-loop action generation bottleneck in visuomotor control, where fixed caching schedules cannot adapt to changing robot-environment observations.

Key Features

  • Real-time Diffusion Pruner: Observation-conditioned pruner that predicts a global timestep-by-block sparsity pattern in one forward pass
  • One-for-All Reusing Strategy: Shared activation cache that reuses residual computations across both denoising timesteps and transformer blocks
  • Global Sparsity Objective: End-to-end sparsity loss that allocates computation under a strict global pruning budget
  • Plug-and-Play Acceleration: Wraps pretrained Diffusion Policy checkpoints without updating the policy parameters

Installation

git clone --recursive https://github.com/ky-ji/SAG.git
cd SAG

git submodule update --init --recursive

conda env create -f conda_environment.yaml
conda activate robodiff

pip install -e diffusion_policy
pip install -e .
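
A quick way to confirm the environment is working is to import both packages in one line (module names assumed from the install commands above and the script invocations below):

python -c "import diffusion_policy, SAGInfer"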

Simulation Reproduction

The SAG inference and pruner-training code lives in SAGInfer/. It is designed to work with pretrained Diffusion Policy transformer checkpoints and datasets from the original Diffusion Policy project.

Supported benchmark tasks include can_ph, can_mh, lift_ph, lift_mh, square_ph, square_mh, transport_ph, transport_mh, tool_hang_ph, kitchen, block_pushing, and pusht.

Step 1: Prepare Diffusion Policy Data and Checkpoints

Download datasets and pretrained checkpoints following the Diffusion Policy instructions. By default, this code expects checkpoints under:

checkpoint/<task_name>/diffusion_policy_transformer/train_<id>/checkpoints/latest.ckpt
checkpoint/low_dim/<task_name>/diffusion_policy_transformer/train_<id>/checkpoints/latest.ckpt

You can also pass an explicit checkpoint path to each script.
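For example, to point the Step 2 dataset script at a specific checkpoint instead of the auto-discovered one (train_0 is a hypothetical run id following the default layout above):

python -m SAGInfer.scripts.create_pruner_dataset \
  --task can_ph \
  --checkpoint checkpoint/can_ph/diffusion_policy_transformer/train_0/checkpoints/latest.ckpt \
  --num_train 5600 \
  --num_valid 640 \
  --output_dir pruner_data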

Step 2: Create Pruner Training Data

python -m SAGInfer.scripts.create_pruner_dataset \
  --task can_ph \
  --checkpoint auto \
  --num_train 5600 \
  --num_valid 640 \
  --output_dir pruner_data

The generated data is saved under pruner_data/<task>/seed<seed>/train<num_train>/.
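
To generate pruner data for several tasks in one pass, the same command loops cleanly; a convenience sketch using only the flags documented above:

for task in can_ph lift_ph square_ph pusht; do
  python -m SAGInfer.scripts.create_pruner_dataset \
    --task "$task" \
    --checkpoint auto \
    --num_train 5600 \
    --num_valid 640 \
    --output_dir pruner_data
done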

Step 3: Train the SAG Pruner

python -m SAGInfer.scripts.train_pruner \
  --task_name can_ph \
  --config SAGInfer/pruner_config/training_config.yaml \
  --device cuda:0 \
  --train_version 0

The default config follows the paper settings: target pruning ratio 0.91, batch size 32, learning rate 1e-4, 30 epochs, and 5,600 training / 640 validation samples.
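
To train an additional pruner variant for the same task, bump --train_version; the assumption here is that, as the flag name suggests, it keeps the runs' outputs separate:

python -m SAGInfer.scripts.train_pruner \
  --task_name can_ph \
  --config SAGInfer/pruner_config/training_config.yaml \
  --device cuda:0 \
  --train_version 1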

Step 4: Evaluate SAG

python -m SAGInfer.scripts.eval_pruner \
  --task_name can_ph \
  --checkpoint auto \
  --timestamp <train_timestamp> \
  --epoch <best_epoch> \
  --device cuda:0 \
  --skip_video

The evaluation script reports latency, speedup, FLOPs estimates, rollout success metrics, and optional pruning visualizations.
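
To also record rollout videos, drop --skip_video; this reading is inferred from the flag name, and the other arguments stay unchanged:

python -m SAGInfer.scripts.eval_pruner \
  --task_name can_ph \
  --checkpoint auto \
  --timestamp <train_timestamp> \
  --epoch <best_epoch> \
  --device cuda:0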


Project Structure

SAG/
|-- SAGInfer/                         # SAG acceleration and pruner code
|   |-- acceleration/                 # One-for-all cache/pruner wrapper
|   |-- pruner/                       # Pruner model, training, and evaluation utilities
|   |-- pruner_config/                # Public default training config
|   `-- scripts/                      # Dataset, training, and evaluation entrypoints
|-- diffusion_policy/                 # Base Diffusion Policy implementation (submodule)
|-- figure/                           # Method and sparsity visualizations
|-- checkpoint/                       # Diffusion Policy checkpoints (not included)
|-- pruner_data/                      # Generated pruner datasets (not included)
|-- output/                           # Training/evaluation outputs (not included)
`-- conda_environment.yaml            # Reproduction environment

Citation

If you find SAG useful for your research, please consider citing our paper:

@article{ji2026sparse,
  title={Sparse ActionGen: Accelerating Diffusion Policy with Real-time Pruning},
  author={Ji, Kangye and Meng, Yuan and Zhou, Jianbo and Li, Ye and Cui, Hanyun and Wang, Zhi},
  journal={arXiv preprint arXiv:2601.12894},
  year={2026}
}

Acknowledgments

This work builds on:

  • Diffusion Policy: the base policy implementation, included here as the diffusion_policy submodule

We thank the authors for their open-source contributions.


License

This project is licensed under the MIT License. See LICENSE for details.
