ant_env_rand_direc.py

from .mujoco_env import MujocoEnv
from rllab.core.serializable import Serializable
import numpy as np
from rllab.envs.base import Step
from rllab.misc.overrides import overrides
from rllab.misc import logger


class AntEnvRandDirec(MujocoEnv, Serializable):
    """Ant locomotion task whose goal direction along the x-axis
    (forward or backward) is resampled per task."""

    FILE = 'ant.xml'

    def __init__(self, goal=None, *args, **kwargs):
        self._goal_vel = goal
        super(AntEnvRandDirec, self).__init__(*args, **kwargs)
        Serializable.__init__(self, *args, **kwargs)

    def get_current_obs(self):
        # Joint positions/velocities, clipped external contact forces,
        # torso rotation matrix, and torso center of mass (last 3 entries).
        return np.concatenate([
            self.model.data.qpos.flat,
            self.model.data.qvel.flat,
            np.clip(self.model.data.cfrc_ext, -1, 1).flat,
            self.get_body_xmat("torso").flat,
            self.get_body_com("torso"),
        ]).reshape(-1)

    def sample_goals(self, num_goals):
        # For the fwd/bwd env, the goal direction is backward if the
        # sampled value is < 1.5 and forward if >= 1.5 (see reset below).
        return np.random.uniform(0.0, 3.0, (num_goals, ))

    @overrides
    def reset(self, init_state=None, reset_args=None, **kwargs):
        # reset_args carries the task's goal; if none is given and no goal
        # was set at construction time, sample one uniformly from [0, 3).
        goal_vel = reset_args
        if goal_vel is not None:
            self._goal_vel = goal_vel
        elif self._goal_vel is None:
            self._goal_vel = np.random.uniform(0.0, 3.0)
        self.goal_direction = -1.0 if self._goal_vel < 1.5 else 1.0
        self.reset_mujoco(init_state)
        self.model.forward()
        self.current_com = self.model.data.com_subtree[0]
        self.dcom = np.zeros_like(self.current_com)
        return self.get_current_obs()

    def step(self, action):
        self.forward_dynamics(action)
        comvel = self.get_body_comvel("torso")
        # Reward x-velocity along the task's goal direction.
        forward_reward = self.goal_direction * comvel[0]
        lb, ub = self.action_bounds
        scaling = (ub - lb) * 0.5
        ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / scaling))
        contact_cost = 0.5 * 1e-3 * np.sum(
            np.square(np.clip(self.model.data.cfrc_ext, -1, 1)))
        survive_reward = 0.05
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward
        state = self._state
        # Episode ends if the state diverges or the torso height
        # leaves the [0.2, 1.0] band.
        notdone = np.isfinite(state).all() \
            and 0.2 <= state[2] <= 1.0
        done = not notdone
        ob = self.get_current_obs()
        return Step(ob, float(reward), done)

    @overrides
    def log_diagnostics(self, paths, prefix=''):
        # Index -3 is the torso x-coordinate (first entry of the torso
        # center of mass), so each entry is net displacement along x.
        progs = [
            path["observations"][-1][-3] - path["observations"][0][-3]
            for path in paths
        ]
        logger.record_tabular(prefix + 'AverageForwardProgress', np.mean(progs))
        logger.record_tabular(prefix + 'MaxForwardProgress', np.max(progs))
        logger.record_tabular(prefix + 'MinForwardProgress', np.min(progs))
        logger.record_tabular(prefix + 'StdForwardProgress', np.std(progs))
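

# Minimal usage sketch (not part of the original file): assumes a working
# rllab + MuJoCo setup. It exercises the multi-task protocol above, where
# each call to reset(reset_args=goal) re-tasks the env with a new goal.
if __name__ == "__main__":
    env = AntEnvRandDirec()
    goals = env.sample_goals(2)  # e.g. one goal per meta-batch task
    for goal in goals:
        obs = env.reset(reset_args=goal)  # goal < 1.5: run backward; else forward
        for _ in range(5):
            action = env.action_space.sample()  # random actions, just to step the env
            obs, reward, done, _ = env.step(action)
            if done:
                break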