
frt03/mxt_bench


A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation

Accepted to ICLR2023 (notable-top-25%, Spotlight) [arxiv] [Website]

Citation

If you use this codebase for your research, please cite the paper:

@inproceedings{furuta2023asystem,
  title={A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation},
  author={Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo and Shixiang Shane Gu},
  booktitle={International Conference on Learning Representations},
  year={2023},
}

Installation

pip install -r requirements.txt

Behavior Distillation Pipeline

  1. Train a single-task, single-morphology PPO policy on the environment:
CUDA_VISIBLE_DEVICES=0 python train_ppo_mlp.py --logdir ../results --seed 0 --env ant_reach_4
  2. Pick a trained policy weight and collect expert brax.QP:
CUDA_VISIBLE_DEVICES=0,1 python generate_behavior_and_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --params_path ../results/ao_ppo_mlp_single_pro_ant_reach_4_20220707_174507/ppo_mlp_98304000.pkl
  3. Register qp_path (the path to the saved brax.QP) in dataset_config.py.

  4. Convert brax.QP to the morphology-task graph representation (e.g. mtg_v2_base_m):

CUDA_VISIBLE_DEVICES=0 python generate_behavior_from_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --data_name ant_reach_4_mtg_v2_base_m --obs_config2 mtg_v2_base_m
  5. Register dataset_path (the path to the saved observations) in dataset_config.py and task_config.py.

  6. Train a Transformer policy via multi-task behavior cloning:

CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer.py --task_name example --seed 0
# zero-shot evaluation
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_zs.py --task_name example --seed 0
# fine-tuning on multi-task imitation learning
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_fs.py --task_name example --seed 0 --params_path ../results/bc_transformer_zs/policy.pkl
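Picking the policy weight and assembling the collection and conversion commands above can be scripted. The following is a minimal sketch that is not part of the repository: it assumes checkpoints are saved as ppo_mlp_<steps>.pkl (as in the example path) and that the dataset name follows the <env>_<obs_config> pattern of the conversion example. The helper names latest_checkpoint, collect_qp_command and convert_qp_command are hypothetical, and the script only builds and prints commands; it does not run them.

```python
import re
import shlex
import tempfile
from pathlib import Path

# Assumption: checkpoint files follow the ppo_mlp_<steps>.pkl pattern seen in the example path.
CKPT_PATTERN = re.compile(r"^ppo_mlp_(\d+)\.pkl$")


def latest_checkpoint(run_dir):
    """Return the ppo_mlp_<steps>.pkl file in run_dir with the largest step count, or None."""
    best_steps, best_path = -1, None
    for path in Path(run_dir).iterdir():
        match = CKPT_PATTERN.match(path.name)
        if match and int(match.group(1)) > best_steps:
            best_steps, best_path = int(match.group(1)), path
    return best_path


def collect_qp_command(env, task_name, params_path, seed=0):
    """Build the command that rolls out a trained PPO policy and saves expert brax.QP."""
    return ["python", "generate_behavior_and_qp.py", "--seed", str(seed),
            "--env", env, "--task_name", task_name,
            "--params_path", str(params_path)]


def convert_qp_command(env, task_name, obs_config, seed=0):
    """Build the command that converts saved brax.QP into a morphology-task graph dataset."""
    # Assumed naming convention, taken from the example: <env>_<obs_config>.
    data_name = f"{env}_{obs_config}"
    return ["python", "generate_behavior_from_qp.py", "--seed", str(seed),
            "--env", env, "--task_name", task_name,
            "--data_name", data_name, "--obs_config2", obs_config]


if __name__ == "__main__":
    # Fake run directory so the sketch also runs outside the repository.
    with tempfile.TemporaryDirectory() as run_dir:
        for steps in (9, 100, 2000):
            Path(run_dir, f"ppo_mlp_{steps}.pkl").touch()
        ckpt = latest_checkpoint(run_dir)
        print(shlex.join(collect_qp_command("ant_reach_4", "ant_reach", ckpt)))
        print(shlex.join(convert_qp_command("ant_reach_4", "ant_reach", "mtg_v2_base_m")))
```

Selecting by the parsed step count rather than by file name matters because names do not sort numerically: as strings, ppo_mlp_9.pkl sorts after ppo_mlp_100.pkl.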

How to Register New Morphology

How to Register New Task

import functools

# load_desc is defined in the mxt_bench codebase (not shown here).
ENV_DESCS = dict()

# add environments: ant_reach_2 ... ant_reach_6 and their "hard" variants (r_min=10.5, r_max=11.5)
for i in range(2, 7, 1):
  ENV_DESCS[f'ant_reach_{i}'] = functools.partial(load_desc, num_legs=i)
  ENV_DESCS[f'ant_reach_hard_{i}'] = functools.partial(load_desc, num_legs=i, r_min=10.5, r_max=11.5)

# missing-leg variants: agent='broken_ant' with leg broken_id=j of an i-legged ant
for i in range(3, 7, 1):
  for j in range(i):
    ENV_DESCS[f'ant_reach_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j)
    ENV_DESCS[f'ant_reach_hard_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j, r_min=10.5, r_max=11.5)
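To check which names the registration loops produce, the same loops can be run against a stand-in for load_desc. This is a sketch only: load_desc_stub and build_env_descs are hypothetical helpers, and the stub merely echoes its keyword arguments instead of building a real environment description. It also shows that each entry is a lazy functools.partial that is only evaluated when called.

```python
import functools


def load_desc_stub(**kwargs):
    """Hypothetical stand-in for mxt_bench's load_desc: returns its keyword arguments."""
    return kwargs


def build_env_descs(load_desc):
    """Re-run the registration loops with a caller-supplied load_desc."""
    env_descs = {}
    hard = dict(r_min=10.5, r_max=11.5)
    for i in range(2, 7):
        env_descs[f'ant_reach_{i}'] = functools.partial(load_desc, num_legs=i)
        env_descs[f'ant_reach_hard_{i}'] = functools.partial(load_desc, num_legs=i, **hard)
    for i in range(3, 7):
        for j in range(i):  # one entry per leg that can be broken
            broken = dict(agent='broken_ant', num_legs=i, broken_id=j)
            env_descs[f'ant_reach_{i}_b_{j}'] = functools.partial(load_desc, **broken)
            env_descs[f'ant_reach_hard_{i}_b_{j}'] = functools.partial(load_desc, **broken, **hard)
    return env_descs


descs = build_env_descs(load_desc_stub)
print(len(descs))  # 46: 5 + 5 intact-leg names, plus 2 * (3 + 4 + 5 + 6) broken-leg names
print(descs['ant_reach_hard_4_b_2']())
```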
  • If you would like to avoid immediate termination after the agent reaches the goal, please set min_dist=0 in each reward function dict.
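If many reward function dicts need this change, it can be applied programmatically. The sketch below assumes reward functions are described as plain Python dicts (possibly nested in lists or tuples) that carry a min_dist key; the config layout and the helper name disable_goal_termination are hypothetical, so adapt it to the actual task config before use.

```python
import copy


def disable_goal_termination(config):
    """Return a deep copy of config with every 'min_dist' entry set to 0.

    Assumes a hypothetical layout of nested dicts/lists/tuples; the input is not modified.
    """
    config = copy.deepcopy(config)

    def _visit(node):
        if isinstance(node, dict):
            if 'min_dist' in node:
                node['min_dist'] = 0  # 0 disables early termination at the goal
            for value in node.values():
                _visit(value)
        elif isinstance(node, (list, tuple)):
            for item in node:
                _visit(item)

    _visit(config)
    return config


# Hypothetical reward config, for illustration only.
example = {
    'reward_fns': [
        {'name': 'goal', 'min_dist': 0.5},
        {'name': 'ctrl_cost'},
    ],
}
patched = disable_goal_termination(example)
print(patched['reward_fns'][0]['min_dist'])  # 0
print(example['reward_fns'][0]['min_dist'])  # 0.5 (original left untouched)
```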

Structure

Reference
