Skip to content

Commit

Permalink
fix: Reset was calling States(States(...)) causing invalid shaped arr…
Browse files Browse the repository at this point in the history
…ays)
  • Loading branch information
iwishiwasaneagle committed Apr 5, 2023
1 parent 3aa4e30 commit f8c3233
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
36 changes: 24 additions & 12 deletions src/jdrones/envs/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ def __init__(
low=np.array([0, 0, 1]), high=np.array([10, 10, 10])
)

def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[States, dict[str, Any]]:
super().reset(seed=seed, options=options)

obs, _ = self.env.reset(seed=seed, options=options)

return States([np.copy(obs)]), {}

@staticmethod
def get_reward(states: States) -> float:
"""
Expand Down Expand Up @@ -161,6 +149,18 @@ def _validate_action_input(
)
return action_as_state

def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[States, dict[str, Any]]:
super().reset(seed=seed, options=options)

obs, _ = self.env.reset(seed=seed, options=options)

return States([np.copy(obs)]), {}

def step(
self, action: PositionAction | PositionVelocityAction
) -> tuple[States, float, bool, bool, dict[str, Any]]:
Expand Down Expand Up @@ -365,6 +365,18 @@ def __init__(
low=np.array([[0, 0, 1], [0, 0, 1]]), high=10, shape=(2, 3)
)

def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[States, dict[str, Any]]:
super().reset(seed=seed, options=options)

obs, _ = self.env.reset(seed=seed, options=options)

return obs, {}

@staticmethod
def calc_v_at_B(A: VEC3, B: VEC3, C: VEC3, *, V: float, N: float = 3):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/envs/position/test_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
def run(T, dt, env):
trunc = False
t, c = 0, 0
env.reset()
while not trunc and t <= T:
setpoint = env.action_space.sample()
obs, _, term, trunc, _ = env.step(setpoint)
Expand Down

0 comments on commit f8c3233

Please sign in to comment.