Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Merge pull request #2 from Aladoro/main
Browse files Browse the repository at this point in the history
Simple modular configuration refactoring
  • Loading branch information
denisyarats committed Sep 7, 2021
2 parents 568340c + ccb9a4d commit c2eab01
Show file tree
Hide file tree
Showing 29 changed files with 156 additions and 12 deletions.
19 changes: 11 additions & 8 deletions config.yaml → cfgs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
defaults:
- _self_
- task@_global_: quadruped_walk
- override hydra/launcher: submitit_local

# task settings
task: quadruped_walk
frame_stack: 3
action_repeat: 2
discount: 0.99
# train settings
num_train_frames: 1000000
num_seed_frames: 4000
# eval
eval_every_frames: 10000
Expand All @@ -24,30 +24,33 @@ seed: 1
device: cuda
save_video: true
save_train_video: false
use_tb: false
use_tb: true
# experiment
experiment: exp
# agent
lr: 1e-4
feature_dim: 50

agent:
_target_: drqv2.DrQV2Agent
obs_shape: ??? # to be specified later
action_shape: ??? # to be specified later
device: ${device}
lr: 1e-4
lr: ${lr}
critic_target_tau: 0.01
update_every_steps: 2
use_tb: ${use_tb}
num_expl_steps: 2000
hidden_dim: 1024
feature_dim: 50
stddev_schedule: 'linear(1.0,0.1,500000)'
feature_dim: ${feature_dim}
stddev_schedule: ${stddev_schedule}
stddev_clip: 0.3

hydra:
run:
dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${hydra.job.override_dirname}
sweep:
dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${experiment}
dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment}
subdir: ${hydra.job.num}
launcher:
timeout_min: 4300
Expand All @@ -56,4 +59,4 @@ hydra:
tasks_per_node: 1
mem_gb: 160
nodes: 1
submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}/.slurm
submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm
5 changes: 5 additions & 0 deletions cfgs/task/acrobot_swingup.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: acrobot_swingup
5 changes: 5 additions & 0 deletions cfgs/task/cartpole_balance.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: cartpole_balance
5 changes: 5 additions & 0 deletions cfgs/task/cartpole_balance_sparse.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: cartpole_balance_sparse
5 changes: 5 additions & 0 deletions cfgs/task/cartpole_swingup.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: cartpole_swingup
5 changes: 5 additions & 0 deletions cfgs/task/cartpole_swingup_sparse.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: cartpole_swingup_sparse
5 changes: 5 additions & 0 deletions cfgs/task/cheetah_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: cheetah_run
5 changes: 5 additions & 0 deletions cfgs/task/cup_catch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: cup_catch
2 changes: 2 additions & 0 deletions cfgs/task/easy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
num_train_frames: 1100000
stddev_schedule: 'linear(1.0,0.1,100000)'
5 changes: 5 additions & 0 deletions cfgs/task/finger_spin.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: finger_spin
5 changes: 5 additions & 0 deletions cfgs/task/finger_turn_easy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: finger_turn_easy
5 changes: 5 additions & 0 deletions cfgs/task/finger_turn_hard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: finger_turn_hard
2 changes: 2 additions & 0 deletions cfgs/task/hard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
num_train_frames: 30100000
stddev_schedule: 'linear(1.0,0.1,2000000)'
5 changes: 5 additions & 0 deletions cfgs/task/hopper_hop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: hopper_hop
5 changes: 5 additions & 0 deletions cfgs/task/hopper_stand.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: hopper_stand
7 changes: 7 additions & 0 deletions cfgs/task/humanoid_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- hard
- _self_

task_name: humanoid_run
lr: 8e-5
feature_dim: 100
7 changes: 7 additions & 0 deletions cfgs/task/humanoid_stand.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- hard
- _self_

task_name: humanoid_stand
lr: 8e-5
feature_dim: 100
7 changes: 7 additions & 0 deletions cfgs/task/humanoid_walk.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- hard
- _self_

task_name: humanoid_walk
lr: 8e-5
feature_dim: 100
2 changes: 2 additions & 0 deletions cfgs/task/medium.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
num_train_frames: 3100000
stddev_schedule: 'linear(1.0,0.1,500000)'
5 changes: 5 additions & 0 deletions cfgs/task/pendulum_swingup.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- easy
- _self_

task_name: pendulum_swingup
6 changes: 6 additions & 0 deletions cfgs/task/quadruped_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
defaults:
- medium
- _self_

task_name: quadruped_run
replay_buffer_size: 100000
5 changes: 5 additions & 0 deletions cfgs/task/quadruped_walk.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: quadruped_walk
5 changes: 5 additions & 0 deletions cfgs/task/reach_duplo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: reach_duplo
5 changes: 5 additions & 0 deletions cfgs/task/reacher_easy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: reacher_easy
5 changes: 5 additions & 0 deletions cfgs/task/reacher_hard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defaults:
- medium
- _self_

task_name: reacher_hard
7 changes: 7 additions & 0 deletions cfgs/task/walker_run.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- medium
- _self_

task_name: walker_run
nstep: 1
batch_size: 512
7 changes: 7 additions & 0 deletions cfgs/task/walker_stand.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- easy
- _self_

task_name: walker_stand
nstep: 1
batch_size: 512
7 changes: 7 additions & 0 deletions cfgs/task/walker_walk.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- easy
- _self_

task_name: walker_walk
nstep: 1
batch_size: 512
10 changes: 6 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def __init__(self, cfg):
self.setup()

self.agent = make_agent(self.train_env.observation_spec(),
self.train_env.action_spec(), cfg.agent)
self.train_env.action_spec(),
self.cfg.agent)
self.timer = utils.Timer()
self._global_step = 0
self._global_episode = 0
Expand All @@ -51,9 +52,9 @@ def setup(self):
# create logger
self.logger = Logger(self.work_dir, use_tb=self.cfg.use_tb)
# create envs
self.train_env = dmc.make(self.cfg.task, self.cfg.frame_stack,
self.train_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
self.cfg.action_repeat, self.cfg.seed)
self.eval_env = dmc.make(self.cfg.task, self.cfg.frame_stack,
self.eval_env = dmc.make(self.cfg.task_name, self.cfg.frame_stack,
self.cfg.action_repeat, self.cfg.seed)
# create replay buffer
data_specs = (self.train_env.observation_spec(),
Expand All @@ -75,6 +76,7 @@ def setup(self):
self.train_video_recorder = TrainVideoRecorder(
self.work_dir if self.cfg.save_train_video else None)


@property
def global_step(self):
return self._global_step
Expand Down Expand Up @@ -202,7 +204,7 @@ def load_snapshot(self):
self.__dict__[k] = v


@hydra.main(config_path='.', config_name='config')
@hydra.main(config_path='cfgs', config_name='config')
def main(cfg):
from train import Workspace as W
root_dir = Path.cwd()
Expand Down

0 comments on commit c2eab01

Please sign in to comment.