Commit d04366e

fix(gym): return node ids and action mask for np observations
1 parent 96c7d23 commit d04366e

1 file changed: +2 -1 lines changed

libraries/mathy_python/mathy/envs/gym/mathy_gym_env.py

Lines changed: 2 additions & 1 deletion
@@ -69,9 +69,10 @@ def _observe(self, state: MathyEnvState) -> Union[MathyObservation, np.ndarray]:
         self.action_space.mask = action_mask
         if self.np_observation:
             # convert mask to probabilities
+            nodes = np.array(pad_array(observation.nodes, 512, 0))
             mask = np.array(pad_array(observation.mask, 512, 0))
             mask = mask / np.sum(mask)
-            return mask
+            return np.vstack((nodes, mask))
         return observation
 
     def reset(self):
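
For reviewers, here is a minimal standalone sketch of what the new return value looks like. It is not the project's code. The `pad_array` below is a hypothetical stand-in that right-pads a list with a fill value (the real helper in mathy may behave differently), and `np_observation` is an illustrative name for the branch changed in this commit.

```python
import numpy as np


def pad_array(values, length, fill):
    """Hypothetical stand-in for mathy's pad_array: right-pad to `length`."""
    values = list(values)
    return values + [fill] * max(0, length - len(values))


def np_observation(nodes, mask, width=512):
    """Mirror the np_observation branch from the diff (illustrative name)."""
    nodes_row = np.array(pad_array(nodes, width, 0))
    mask_row = np.array(pad_array(mask, width, 0))
    # Convert the mask to probabilities, as in the diff.
    mask_row = mask_row / np.sum(mask_row)
    # Row 0 holds the padded node ids, row 1 the normalized mask.
    return np.vstack((nodes_row, mask_row))


obs = np_observation(nodes=[3, 1, 4], mask=[0, 1, 1, 0, 1, 0])
print(obs.shape)   # two rows, `width` columns
print(obs.dtype)   # vstack promotes the integer node ids to float
```

Two things follow from plain NumPy behavior here. Stacking an integer row with a float row yields a float array, so node ids arrive as floats. An all-zero mask would also make `np.sum(mask)` zero and the division would produce NaN values; whether that case can occur in this environment is not clear from the diff.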
