This repo contains code for the paper Track2Act: Predicting Point Tracks from Internet Videos enables Diverse Zero-Shot Robot Manipulation.
Use the environment.yml file to create a conda environment and install the dependencies.
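For example (the environment name below is an assumption; check the name field inside environment.yml for the actual one):

conda env create -f environment.yml
conda activate track2act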
To train the point track prediction model, run the following command, adjusting the number of nodes, GPUs per node, and global batch size as needed:
torchrun --nnodes=1 --nproc_per_node=8 train_track_pred.py --global-batch-size=480 --data-path=<folder with data files>
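As an illustrative variant, a single-node, single-GPU run might scale the global batch size down proportionally (480 / 8 = 60 per GPU in the default command; the data folder path here is hypothetical):

torchrun --nnodes=1 --nproc_per_node=1 train_track_pred.py --global-batch-size=60 --data-path=./track_data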
To run inference, specify the paths to the initial image, the goal image, and the checkpoint (the trained model is available at this link). The visualization will be saved in the folder save_track_pred.
python inference_track_pred.py --ckpt=<path to model> --init=<path to initial image> --goal=<path to goal image>
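A concrete invocation might look like the following (all paths below are hypothetical placeholders for your own files):

python inference_track_pred.py --ckpt=checkpoints/track_pred.pt --init=examples/init.png --goal=examples/goal.png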
For any questions about the project, feel free to email Homanga Bharadhwaj at hbharadh@cs.cmu.edu.
The code is licensed under the CC-BY-NC license (see License.md).
The code in this repo is based on Diffusion Transformers (https://github.com/facebookresearch/DiT) and uses open-source packages including diffusers, scipy, opencv, numpy, and pytorch.
If you find the repository helpful, please consider citing our paper:
@misc{bharadhwaj2024track2act,
title={Track2Act: Predicting Point Tracks from Internet Videos enables Diverse Zero-shot Robot Manipulation},
author={Homanga Bharadhwaj and Roozbeh Mottaghi and Abhinav Gupta and Shubham Tulsiani},
year={2024},
eprint={2405.01527},
archivePrefix={arXiv},
primaryClass={cs.RO}
}