# Strengthening Generative Robot Policies through Predictive World Modeling

Han Qi, Haocheng Yin, Aris Zhu, Yilun Du, Heng Yang
arXiv 2025

Paper: https://arxiv.org/abs/2502.00622
This repository provides the official implementation of the framework proposed in the paper: we strengthen diffusion-based generative robot policies by integrating a predictive world model, enabling long-horizon reasoning and improved robustness.
The framework consists of two main components:

- **Diffusion-based Action Policy**: generates action sequences using a generative diffusion model.
- **Predictive World Model**: learns environment dynamics to evaluate and refine candidate action trajectories.
At inference time, the world model enhances policy performance through trajectory prediction followed by ranking (GPC-RANK) or optimization (GPC-OPT).
The code uses the Push-T experiment as a running example.
## Repository Structure

```
.
├── all_checkpoint/                # Pretrained checkpoints (policy + world model)
├── diffusion_policy_data/         # Training data for the diffusion action policy
├── diffusion_policy_training/     # Training code for the diffusion-based action policy
├── gpc_opt_evaluation/            # Evaluation with GPC-OPT (trajectory optimization)
├── gpc_rank_evaluation/           # Evaluation with GPC-RANK (trajectory ranking)
├── world_model_data/              # Training data for the predictive world model
├── world_model_train_phase_one/   # Phase I: single-step world model warmup training
└── world_model_train_phase_two/   # Phase II: multi-step world model training
```
## Installation

```bash
git clone https://github.com/han20192019/gpc_code.git
cd gpc_code
```

We recommend using a clean conda environment:

```bash
conda env create -f environment.yml
conda activate gpc
```

## Checkpoints and Datasets

- Pretrained checkpoints: https://huggingface.co/han2019/gpc_checkpoints/tree/main
  Download the checkpoints into a folder named `all_checkpoint` in the repository root.
- Diffusion policy training dataset: https://huggingface.co/datasets/han2019/gpc_pushT_data/tree/main/diffusion_policy_data
- World model training dataset: https://huggingface.co/datasets/han2019/gpc_pushT_data/tree/main/world_model_data

Download the datasets and place them in the repository root.
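If you prefer scripting the downloads, the sketch below uses the `huggingface_hub` Python package (assumed installed via `pip install huggingface_hub`; it is not part of this repository's documented workflow):

```python
# Sketch: fetch checkpoints and datasets with huggingface_hub.
from huggingface_hub import snapshot_download

# Pretrained policy + world model checkpoints -> ./all_checkpoint
snapshot_download(repo_id="han2019/gpc_checkpoints", local_dir="all_checkpoint")

# Push-T data (diffusion_policy_data/ and world_model_data/) -> repo root
snapshot_download(repo_id="han2019/gpc_pushT_data",
                  repo_type="dataset", local_dir=".")
```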
## Training

There are two independent modules to train: the diffusion policy and the predictive world model.
### Diffusion policy

Directory: `diffusion_policy_training/`

Run:

```bash
python train_model.py
```

This trains the diffusion-based generative policy that produces candidate action sequences.
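For orientation, a standard diffusion-policy training step minimizes a denoising loss over action sequences. The sketch below is illustrative only; `policy`, `noise_scheduler`, and the tensor shapes are assumptions, not the identifiers used in `train_model.py`:

```python
# Sketch of a DDPM-style denoising loss for an action diffusion policy.
# All names here are illustrative assumptions, not the repository's API.
import torch
import torch.nn.functional as F

def diffusion_loss(policy, noise_scheduler, obs, actions):
    noise = torch.randn_like(actions)                 # noise to be predicted
    t = torch.randint(0, noise_scheduler.num_steps,
                      (actions.shape[0],), device=actions.device)
    noisy_actions = noise_scheduler.add_noise(actions, noise, t)
    pred = policy(noisy_actions, t, obs)              # conditional noise prediction
    return F.mse_loss(pred, noise)
```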
### World model

World model training is performed in two stages, as described in the paper.
**Phase I: single-step warmup**

Directory: `world_model_train_phase_one/`

Run:

```bash
python train.py
```

This stage trains the world model for single-step prediction, which stabilizes early training and improves multi-step rollout performance.
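Conceptually, Phase I fits a one-step dynamics model. A minimal sketch of such an objective, with illustrative names rather than the code's actual interface:

```python
# Sketch of a single-step dynamics objective: predict s_{t+1} from (s_t, a_t).
import torch.nn.functional as F

def single_step_loss(world_model, state, action, next_state):
    pred_next = world_model(state, action)    # one-step prediction
    return F.mse_loss(pred_next, next_state)
```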
**Phase II: multi-step training**

Directory: `world_model_train_phase_two/`

Run:

```bash
python train.py
```

This stage trains the model for multi-step rollouts, enabling long-horizon trajectory evaluation.
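Phase II rolls the model forward on its own predictions so training sees the compounding error it will face at evaluation time. A hedged sketch, assuming batched tensors of shape `(B, H, ...)` and placeholder names:

```python
# Sketch of a multi-step rollout objective over horizon H.
# The model consumes its own predictions, as in closed-loop rollouts.
import torch.nn.functional as F

def multi_step_loss(world_model, state, actions, target_states):
    loss, pred = 0.0, state
    horizon = actions.shape[1]
    for h in range(horizon):
        pred = world_model(pred, actions[:, h])        # feed back own prediction
        loss = loss + F.mse_loss(pred, target_states[:, h])
    return loss / horizon
```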
## Evaluation

After training both the policy and the world model, you can evaluate the integrated system.
### GPC-RANK

Directory: `gpc_rank_evaluation/`

Run:

```bash
python gpc_rank_evaluation.py
```

This mode:
- Samples candidate action sequences from the diffusion policy
- Uses the predictive world model to simulate future states
- Ranks trajectories
- Executes the highest-scoring candidate
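The steps above reduce to a short sample-simulate-score-select loop. The sketch below is a conceptual rendering with placeholder names (`policy.sample`, `world_model.rollout`, `score`), not the script's actual API:

```python
# Sketch of the GPC-RANK loop; all names are placeholders.
import torch

def gpc_rank_step(policy, world_model, score, obs, num_candidates=16):
    # (N, H, action_dim) candidate action sequences from the diffusion policy
    candidates = policy.sample(obs, num_candidates)
    # Simulate each candidate forward and score the predicted trajectory
    scores = torch.stack([score(world_model.rollout(obs, a)) for a in candidates])
    return candidates[scores.argmax()]   # execute the highest-scoring candidate
```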
### GPC-OPT

Directory: `gpc_opt_evaluation/`

Run:

```bash
python gpc_opt_evaluation.py
```

This mode:
- Uses the world model to iteratively optimize action sequences
- Improves performance through predictive refinement
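One common realization of such predictive refinement is gradient ascent on a trajectory score through a differentiable world model; whether this repository uses gradients or a sampling-based optimizer should be checked against `gpc_opt_evaluation.py`. A sketch under the gradient assumption, with placeholder names:

```python
# Sketch of gradient-based action refinement through a world model.
# Assumes `world_model.rollout` and `score` are differentiable; all names
# are placeholders, not the repository's API.
import torch

def gpc_opt_step(policy, world_model, score, obs, steps=10, lr=1e-2):
    actions = policy.sample(obs, 1)[0].detach().clone().requires_grad_(True)
    opt = torch.optim.Adam([actions], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        states = world_model.rollout(obs, actions)  # predicted future states
        (-score(states)).backward()                 # ascend the score
        opt.step()
    return actions.detach()
```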
## Citation

If you find this work useful, please cite:

```bibtex
@article{qi2025strengthening,
  title={Strengthening generative robot policies through predictive world modeling},
  author={Qi, Han and Yin, Haocheng and Zhu, Aris and Du, Yilun and Yang, Heng},
  journal={arXiv preprint arXiv:2502.00622},
  year={2025}
}
```