PyTorch implementation of Generating Multi-Agent Trajectories using Programmatic Weak Supervision

ezhan94/multiagent-programmatic-supervision

Generating Multi-Agent Trajectories using Programmatic Weak Supervision

Code for the paper Generating Multi-Agent Trajectories using Programmatic Weak Supervision by Zhan et al., ICLR 2019.

Installation & Setup

Code is written using PyTorch version 1.0.0.

After cloning the repository, you need to download the data.

[Update 11/25/20] The basketball dataset is now available on AWS Data Exchange. Please make sure to acknowledge Stats Perform if you use the data for your research.

The Boids dataset can be generated by running:

$ python datasets/boids/generate_data.py

This may take a while, so a pre-generated Boids dataset is included here.

Running the Code

To train a model, edit the parameters in train_model.sh and run the script from the command line:

$ ./train_model.sh

After training a model, run the following to generate and plot samples from it:

$ python sample.py -t <trial_id> -n <num_samples> -b <burn_in> --run --plot

The samples are saved in saved/<trial_id>/experiments/sample/.
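
When sampling from several trials, it can be convenient to assemble the sampling command from Python rather than typing it by hand. The following is a minimal sketch, not part of the repository: the helper names build_sample_command and sample_output_dir are hypothetical, and it uses only the flags and output path shown above. It assumes the script is launched from the repository root.

```python
import subprocess
import sys


def build_sample_command(trial_id, num_samples, burn_in, plot=True):
    """Build the argument list for sample.py (hypothetical helper).

    Only the flags documented in this README are used: -t, -n, -b, --run, --plot.
    """
    cmd = [
        sys.executable, "sample.py",
        "-t", str(trial_id),
        "-n", str(num_samples),
        "-b", str(burn_in),
        "--run",
    ]
    if plot:
        cmd.append("--plot")
    return cmd


def sample_output_dir(trial_id):
    """Directory where the README says samples are saved (hypothetical helper)."""
    return f"saved/{trial_id}/experiments/sample/"


if __name__ == "__main__":
    cmd = build_sample_command(trial_id=104, num_samples=10, burn_in=10)
    print(" ".join(cmd))
    print("samples expected in:", sample_output_dir(104))
    # Uncomment to actually run sampling (assumes the current directory
    # is the repository root and the trial exists in saved/):
    # subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues, and check=True makes a failed sampling run raise an exception instead of passing silently.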

For full usage, pass the --help flag.

Scripts

To see the parameters of a past experiment (for reproducibility), run:

$ python scripts/print_params.py -t <trial_id>

To visualize examples from a test dataset, run:

$ python scripts/show_groundtruth.py -d <dataset> -n <num_examples>

The examples are saved in datasets/<dataset>/data/examples/.

To compute and compare domain statistics for basketball, run:

$ python sample.py -t <trial_id> -n 1000 -b 10 --run
$ python scripts/compute_bball_stats.py -t <trial_id>
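
The two commands above are meant to run in sequence, presumably because the statistics script reads the samples that the first command generates. As a rough sketch under that assumption, the pair could be wrapped in Python as follows. The helper names bball_stats_commands and run_bball_stats are hypothetical, and the defaults (1000 samples, burn-in of 10) are copied from the commands above.

```python
import subprocess
import sys


def bball_stats_commands(trial_id, num_samples=1000, burn_in=10):
    """Return the two README commands as argument lists, in README order."""
    tid = str(trial_id)
    return [
        [sys.executable, "sample.py", "-t", tid,
         "-n", str(num_samples), "-b", str(burn_in), "--run"],
        [sys.executable, "scripts/compute_bball_stats.py", "-t", tid],
    ]


def run_bball_stats(trial_id, dry_run=True):
    """Print the commands (dry run) or execute them one after the other.

    With dry_run=False, check=True stops the sequence if sampling fails,
    so the statistics script is not run without fresh samples.
    """
    commands = bball_stats_commands(trial_id)
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # Dry run for trial 104; set dry_run=False from the repository root
    # to actually execute both steps.
    run_bball_stats(104, dry_run=True)
```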

Pretrained Models

Included in this repository in saved/ are four pretrained models for basketball as discussed in the paper:

Trial ID    Model
101         RNN_GAUSS
102         VRNN_SINGLE
103         VRNN_INDEP
104         MACRO_VRNN
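
To refer to the pretrained models by name rather than by trial ID, the table can be mirrored in a small lookup. This is an illustrative sketch only: PRETRAINED_MODELS, trial_id_for, and print_params_command are hypothetical names, not part of the repository. The IDs and model names come from the table, and the script path from the Scripts section.

```python
# Trial IDs and model names, copied from the table of pretrained models.
PRETRAINED_MODELS = {
    101: "RNN_GAUSS",
    102: "VRNN_SINGLE",
    103: "VRNN_INDEP",
    104: "MACRO_VRNN",
}


def trial_id_for(model_name):
    """Look up the trial ID for a model name (case-insensitive)."""
    wanted = model_name.upper()
    for trial_id, name in PRETRAINED_MODELS.items():
        if name == wanted:
            return trial_id
    raise KeyError(f"unknown model: {model_name}")


def print_params_command(model_name):
    """Build the scripts/print_params.py command for a pretrained model."""
    return ["python", "scripts/print_params.py", "-t", str(trial_id_for(model_name))]


if __name__ == "__main__":
    for name in PRETRAINED_MODELS.values():
        print(" ".join(print_params_command(name)))
```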
