Generalization to New Actions in Reinforcement Learning

Ayush Jain*, Andrew Szot*, Joseph J. Lim at USC CLVR lab
[Paper website]

Directories

The structure of the repository:

  • analysis: Scripts for the analysis figures and experiments.
  • envs: Contains the four environments, one per subfolder.
  • method: Implementation of all method and baseline details.
  • rlf: Reinforcement Learning Framework. General RL / PPO training code.
  • scripts: Miscellaneous scripts, including the script for generating the train / test action set splits.
  • main.py: Entry point for training the policy.
  • embedder.py: Entry point for training the embedder.

Log directories:

  • data/trained_model/ENV-NAME_PREFIX/: Trained models.
  • data/vids/ENV-NAME/: Evaluation videos.
  • data/logs/ENV-NAME/PREFIX/: Tensorboard summary.
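
As an illustration, training CreateLevelPush-v0 with --prefix main resolves these placeholders to the following paths (a sketch from direct substitution; exact names depend on your run):

data/trained_model/CreateLevelPush-v0_main/
data/vids/CreateLevelPush-v0/
data/logs/CreateLevelPush-v0/main/

The Tensorboard summaries can then be viewed with:

tensorboard --logdir data/logs/CreateLevelPush-v0/main/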

Prerequisites

  • Python 3.7
  • MuJoCo 2.0

Dependencies

All Python package requirements are listed in requirements.txt. If you are using conda, you can set up the environment with the following commands (tested with Python 3.7.3):

conda create -n [your_name] python=3.7
source activate [your_name]
pip install -r requirements.txt
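
After installing, a quick sanity check (assuming requirements.txt pulls in mujoco-py for MuJoCo 2.0 and gym; adjust to your setup):

python -c "import mujoco_py, gym"

If the mujoco_py import fails, check that MuJoCo 2.0 and its license key are installed in the locations mujoco-py expects.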

Experiments

The experiment flow is similar for each environment and always consists of the same steps:

  • Generate train and test action splits: python gen_action_sets.py --env-name $ENV_NAME

  • Generate Action Datasets for the environment: python embedder.py --env-name $PLAY_ENV_NAME --save-dataset

  • Train action embedder model: python embedder.py --env-name $PLAY_ENV_NAME --save-emb-model-file $EMB_FILE_NAME --train-embeddings

  • Generate embedding files: python main.py --env-name $ENV_NAME --play-env-name $PLAY_ENV_NAME --load-emb-model-file $EMB_MODEL_NAME --save-embeddings-file $EMB_FILE_NAME --prefix main

  • Train policy with saved embeddings: python main.py --env-name $ENV_NAME --load-embeddings-file $EMB_FILE_NAME

Notes:
(1) $EMB_MODEL_NAME must be $EMB_FILE_NAME-htvae-500.m if your model is trained for at least 500 epochs (the epoch count is specified by --emb-epochs).
(2) Use --n-trajectories 64 and --emb-epochs 500 for faster data generation and embedder training.
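
For example, the five steps can be chained in a single shell script (a sketch; the values below are placeholders to be replaced with the per-environment names listed in the following sections):

#!/bin/bash
# Placeholder values -- substitute the names for your environment.
ENV_NAME=CreateLevelPush-v0
PLAY_ENV_NAME=StateCreateGameN1PlayNew-v0
EMB_FILE_NAME=create_st
EMB_EPOCHS=500
# The checkpoint name follows the --emb-epochs value (see note (1) above).
EMB_MODEL_NAME=${EMB_FILE_NAME}-htvae-${EMB_EPOCHS}.m

python gen_action_sets.py --env-name $ENV_NAME
python embedder.py --env-name $PLAY_ENV_NAME --save-dataset
python embedder.py --env-name $PLAY_ENV_NAME --save-emb-model-file $EMB_FILE_NAME --train-embeddings --emb-epochs $EMB_EPOCHS
python main.py --env-name $ENV_NAME --play-env-name $PLAY_ENV_NAME --load-emb-model-file $EMB_MODEL_NAME --save-embeddings-file $EMB_FILE_NAME --prefix main
python main.py --env-name $ENV_NAME --load-embeddings-file $EMB_FILE_NAME --prefix main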

Below are the example commands for each environment and method.

Environments

CREATE (Chain Reaction Tool Environment)

$ENV_NAME = 'CreateLevelPush-v0' or 'CreateLevelNavigate-v0' or 'CreateLevelObstacle-v0'.
$PLAY_ENV_NAME = 'StateCreateGameN1PlayNew-v0' (state-based) or 'CreateGamePlay-v0' (video-based).
$EMB_FILE_NAME = 'create_st' (state-based) or 'create_im' (video-based).

(1) Train policy directly with:

python main.py --env-name CreateLevelPush-v0 --prefix main
python main.py --env-name CreateLevelNavigate-v0 --prefix main
python main.py --env-name CreateLevelObstacle-v0 --prefix main

OR

(2) For the full procedure, follow these commands:

  • Generate Splits: python gen_action_sets.py --env-name CreateLevelPush-v0
  • Generate Data: python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-dataset
  • Train Action Embedder: python embedder.py --env-name StateCreateGameN1PlayNew-v0 --save-emb-model-file create_st --train-embeddings
  • Generate embedding files: python main.py --env-name CreateLevelPush-v0 --play-env-name StateCreateGameN1PlayNew-v0 --load-emb-model-file create_st-htvae-5000.m --save-embeddings-file create_st --prefix main
  • Train policy with saved embeddings: python main.py --env-name CreateLevelPush-v0 --load-embeddings-file create_st --prefix main
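
The commands above use the state-based play environment. A video-based variant would substitute the video-based names from above (a sketch; the -htvae-5000.m suffix mirrors the state-based example, and the --o-dim 128 --z-dim 128 flags from the "Other embedding data formats" section below may also be needed when loading video-based embeddings):

python embedder.py --env-name CreateGamePlay-v0 --save-dataset
python embedder.py --env-name CreateGamePlay-v0 --save-emb-model-file create_im --train-embeddings
python main.py --env-name CreateLevelPush-v0 --play-env-name CreateGamePlay-v0 --load-emb-model-file create_im-htvae-5000.m --save-embeddings-file create_im --prefix main
python main.py --env-name CreateLevelPush-v0 --load-embeddings-file create_im --prefix main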

Reco

There is no data generation or embedding learning for the recommender system environment.

$ENV_NAME = 'RecoEnv-v0'

(1) Train policy directly with:

python main.py --env-name RecoEnv-v0 --prefix main

OR

(2) For the full procedure, follow these commands:

  • Generate Splits: python gen_action_sets.py --env-name RecoEnv-v0
  • Policy: python main.py --env-name RecoEnv-v0 --prefix main

Block Stacking

$ENV_NAME = 'StackEnv-v0'
$PLAY_ENV_NAME = 'BlockPlayImg-v0'
$EMB_FILE_NAME = 'stack_im'

(1) Train policy directly with:

python main.py --env-name StackEnv-v0 --prefix main

OR

(2) For the full procedure, follow these commands:

  • Generate Splits: python gen_action_sets.py --env-name StackEnv-v0
  • Generate Data: python embedder.py --env-name BlockPlayImg-v0 --save-dataset
  • Train Action Embedder: python embedder.py --env-name BlockPlayImg-v0 --save-emb-model-file stack_im --train-embeddings
  • Generate embedding files: python main.py --env-name StackEnv-v0 --play-env-name BlockPlayImg-v0 --load-emb-model-file stack_im-htvae-5000.m --save-embeddings-file stack_im --prefix main
  • Train policy with saved embeddings: python main.py --env-name StackEnv-v0 --load-embeddings-file stack_im --prefix main

Grid World

$ENV_NAME = 'MiniGrid-LavaCrossingS9N1-v0'
$PLAY_ENV_NAME = 'MiniGrid-Empty-Random-80x80-v0'
$EMB_FILE_NAME = 'gw_onehot_new'

(1) Train policy directly with:

python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --prefix main

OR

(2) For the full procedure, follow these commands:

  • Generate Splits: python gen_action_sets.py --env-name MiniGrid-LavaCrossingS9N1-v0
  • Generate Data: python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-dataset
  • Train Action Embedder: python embedder.py --env-name MiniGrid-Empty-Random-80x80-v0 --save-emb-model-file gw_onehot_new --train-embeddings
  • Generate embedding files: python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --play-env-name MiniGrid-Empty-Random-80x80-v0 --load-emb-model-file gw_onehot_new-htvae-5000.m --save-embeddings-file gw_onehot_new --prefix main
  • Train policy with saved embeddings: python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --load-embeddings-file gw_onehot_new --prefix main

Baselines and Ablations

To run the baselines for any environment, add the following flags to the main policy-training command:

Baselines

  • Nearest-Neighbor (NN): --nearest-neighbor --fixed-action-set --action-random-sample False --prefix NN
  • Distance-based Policy Architecture (Dist): --distance-based --prefix dist
  • Non-hierarchical embeddings (VAE): --load-embeddings-file $FILE --prefix vae, where $FILE is environment-dependent:
    • CREATE: create_fc_st_vae
    • Shape Stacking: stack_vae
    • Grid World: gw_onehot_vae
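
For example, a complete Distance-based Policy Architecture run on CREATE would look like the following (a sketch: the baseline flags above appended to the main policy command, assuming the state-based CREATE embeddings create_st):

python main.py --env-name CreateLevelPush-v0 --load-embeddings-file create_st --distance-based --prefix dist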

Ablations

  • Fixed Action Space (FX): --fixed-action-set --action-random-sample False --prefix FX
  • Random-Sampling without clustering (RS): --sample-clusters False --prefix RS
  • No-entropy (NE): --entropy-coef 0. --prefix NE
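
For example, the no-entropy ablation on Grid World (a sketch: the flag above appended to the main policy command):

python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --load-embeddings-file gw_onehot_new --entropy-coef 0. --prefix NE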

Other embedding data formats

  • CREATE: Video-based embeddings: --load-embeddings-file create_fc_im --o-dim 128 --z-dim 128 --prefix im
  • Grid World: (x,y) coordinate state-based embeddings: --load-embeddings-file gw_st --prefix st
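
As full commands, these would look like the following (a sketch: the flags above appended to the main policy command; the task choices here are arbitrary):

python main.py --env-name CreateLevelNavigate-v0 --load-embeddings-file create_fc_im --o-dim 128 --z-dim 128 --prefix im
python main.py --env-name MiniGrid-LavaCrossingS9N1-v0 --load-embeddings-file gw_st --prefix st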

Ground-truth embeddings

For CREATE and Grid World, add: --gt-embs --prefix GT
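
For example, on a CREATE task (a sketch; it is assumed here that no saved embeddings file is needed when ground-truth embeddings are used):

python main.py --env-name CreateLevelObstacle-v0 --gt-embs --prefix GT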

Analysis

To run the three analysis scripts, simply run:

  • analysis/analysis_dist.py
  • analysis/analysis_emb.py
  • analysis/analysis_ratio.py
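
For example, from the repository root (a sketch; the scripts presumably expect results from prior training runs under data/):

python analysis/analysis_dist.py
python analysis/analysis_emb.py
python analysis/analysis_ratio.py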
