Skip to content

Commit

Permalink
feat: Extend to any n env
Browse files Browse the repository at this point in the history
  • Loading branch information
iwishiwasaneagle committed Apr 9, 2024
1 parent 29bbe52 commit 04d3e72
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions docs/examples/drl_3d_wp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ class Dual_DRL_WP_Env_LQR(gymnasium.Env):
@staticmethod
def make_sub_env(angle: float, T: float, dt: float) -> gymnasium.Env:
state = JState()
state.pos = [np.cos(angle), np.sin(angle), 0]
state.pos = [2.5 * np.cos(angle), 2.5 * np.sin(angle), 0]
return DRL_WP_Env_LQR(T=T, dt=dt, env_cls_kwargs=dict(initial_state=state))

def __init__(self, *, T, dt):
def __init__(self, *, T, dt, N_envs: int = 2):
super().__init__()

self.T = T
Expand All @@ -321,21 +321,21 @@ def __init__(self, *, T, dt):
self.envs = DummyVecEnv(
[
functools.partial(self.make_sub_env, f, self.T, self.dt)
for f in np.linspace(0, np.pi, 2)
for f in np.linspace(0, 2 * np.pi, N_envs + 1)[1:]
]
)

obs = self.envs.observation_space
self.observation_space = gymnasium.spaces.Box(
low=np.tile(obs.low, (2, 1)),
high=np.tile(obs.high, (2, 1)),
shape=(2, *obs.shape),
low=np.tile(obs.low, (N_envs, 1)),
high=np.tile(obs.high, (N_envs, 1)),
shape=(N_envs, *obs.shape),
)
act = self.envs.action_space
self.action_space = gymnasium.spaces.Box(
low=np.tile(act.low, (2, 1)),
high=np.tile(act.high, (2, 1)),
shape=(2, *act.shape),
low=np.tile(act.low, (N_envs, 1)),
high=np.tile(act.high, (N_envs, 1)),
shape=(N_envs, *act.shape),
)

@staticmethod
Expand Down Expand Up @@ -377,16 +377,20 @@ def step(self, action):
done = np.any(dones)
reward = np.sum(rew) / self.envs.num_envs

distance_between_envs = np.linalg.norm(
np.subtract(*[f.pos for f in self.envs.get_attr("state")])
)
pos = np.array([f.pos for f in self.envs.get_attr("state")])
dists = np.linalg.norm(pos[np.newaxis, :, :] - pos[:, np.newaxis, :], axis=2)[
~np.eye(pos.shape[0], dtype=bool)
]

min_distance_between_envs = dists.min()

if distance_between_envs < 0.5:
collision = min_distance_between_envs < 0.5
if collision:
reward = -100
done = True
info[0]["collision"] = info[1]["collision"] = True
else:
info[0]["collision"] = info[1]["collision"] = False

for i in range(len(info)):
info[i]["collision"] = collision

info = self.merge_infos(*info)
return obs, reward, done, done, info

0 comments on commit 04d3e72

Please sign in to comment.