Added code on required files, no testing yet #3

Merged
merged 1 commit into from
Apr 25, 2022
40 changes: 20 additions & 20 deletions experiments/gp_rc/gp_rc_cartpole.yaml
@@ -1,18 +1,14 @@
cartpole_config:
task_config:
info_in_reset: True
ctrl_freq: 50
pyb_freq: 50
gui: False
normalized_rl_action_space: False
episode_len_sec: 5
# State initialization
init_state: null
randomized_init: True
init_state_randomization_info: null
# Randomization
inertial_prop: null
randomized_inertial_prop: False
inertial_prop_randomization_info: null

# Task
task: stabilization
task_info: null
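
The timing fields in the hunk above (`ctrl_freq`, `pyb_freq`, `episode_len_sec`) jointly fix how many control steps an episode contains. A minimal sketch of that arithmetic, using the values shown in this diff; the variable names on the left are illustrative only, and the assumption that one control step spans `pyb_freq / ctrl_freq` physics steps is mine, not something the diff confirms:

```python
# Values copied from the task_config shown in the diff above.
ctrl_freq = 50        # control steps per second (Hz)
pyb_freq = 50         # physics simulation steps per second (Hz)
episode_len_sec = 5   # episode duration in seconds

# Assumption (not confirmed by the diff): each control step spans
# pyb_freq / ctrl_freq physics steps.
physics_steps_per_ctrl_step = pyb_freq // ctrl_freq

# Number of control steps in one episode and the control time step.
ctrl_steps_per_episode = ctrl_freq * episode_len_sec
ctrl_dt = 1.0 / ctrl_freq

print(physics_steps_per_ctrl_step, ctrl_steps_per_episode, ctrl_dt)
```

With these settings an episode lasts 250 control steps of 0.02 s each, which may be worth keeping in mind when comparing against `train_samples` and `validation_samples` in `algo_config`.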
@@ -22,19 +18,23 @@ cartpole_config:
dynamics:
- disturbance_func: white_noise
std: 0.01
adversary_disturbance: null
adversary_disturbance_offset: 0.0
adversary_disturbance_scale: 0.01
# Constraints
constraints: null
done_on_violation: False
use_constraint_penalty: False
constraint_penalty: -1
# Misc
verbose: False
# RL Hyper-parameters
obs_wrap_angle: False
rew_state_weight: 1.0
rew_act_weight: 0.0001
rew_exponential: True
done_on_out_of_bound: True


algo_config:
q: [1]
r: [0.1]
# GP training args
train_samples: 500
validation_samples: 200
train_iterations: [1000]
learning_rate: [0.1]

#H2 optimization args
step_size: 0.1
max_optim_tries: 100

# Runner args
deque_size: 10
eval_batch_size: 1
3 changes: 0 additions & 3 deletions experiments/gp_rc/gp_rc_experiment.py
@@ -33,8 +33,6 @@ def train(config):
# Create the controller/control_agent.
control_agent = make(config.algo,
env_func,
training=True,
checkpoint_path=os.path.join(config.output_dir, "model_latest.pt"),
output_dir=config.output_dir,
device=config.device,
seed=config.seed,
@@ -86,7 +84,6 @@ def test_policy(config):
# Create the controller/control_agent.
control_agent = make(config.algo,
env_func,
training=False,
checkpoint_path=os.path.join(config.output_dir, "model_latest.pt"),
output_dir=config.output_dir,
device=config.device,