Skip to content

Commit 4cf6255

Browse files
feat(print_history): add pretty print flag (#38)
- excludes the debug information like node/action/rule validity
1 parent 0620439 commit 4cf6255

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

libraries/mathy_python/mathy/env.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def get_next_state(
278278
return out_env, transition, ExpressionChangeRule(BaseRule())
279279

280280
change = operation.apply_to(token.clone_from_root())
281+
assert change.result is not None
281282
root = change.result.get_root()
282283
change_name = operation.name
283284
out_problem = str(root)
@@ -303,11 +304,12 @@ def print_state(
303304
token_index: int = -1,
304305
change: ExpressionChangeRule = None,
305306
change_reward: float = 0.0,
307+
pretty: bool = False,
306308
):
307309
"""Render the given state to stdout for visualization"""
308310
print(
309311
self.render_state(
310-
env_state, action_name, token_index, change, change_reward
312+
env_state, action_name, token_index, change, change_reward, pretty
311313
)
312314
)
313315

@@ -322,7 +324,7 @@ def is_terminal_state(self, env_state: MathyEnvState) -> bool:
322324
"""
323325
return is_terminal_transition(self.get_state_transition(env_state))
324326

325-
def print_history(self, env_state: MathyEnvState) -> None:
327+
def print_history(self, env_state: MathyEnvState, pretty: bool = True) -> None:
326328
"""Render the history of an episode from a given state.
327329
328330
# Arguments
@@ -333,17 +335,19 @@ def print_history(self, env_state: MathyEnvState) -> None:
333335
curr_state: MathyEnvState = MathyEnvState(
334336
problem=initial_step.raw, max_moves=env_state.max_moves,
335337
)
336-
self.print_state(curr_state, "initial-state")
337-
while not self.is_terminal_state(curr_state):
338+
self.print_state(curr_state, "initial-state", pretty=pretty)
339+
while len(history) > 0:
338340
step: MathyEnvStateStep = history.pop(0)
339341
curr_state, transition, change = self.get_next_state(
340342
curr_state, step.action + (step.focus * len(self.rules))
341343
)
342344
rule_idx, token_idx = self.get_action_indices(step.action)
343345
rule: BaseRule = self.rules[rule_idx]
346+
rule_name: str = rule.name[:25].lower()
344347
self.print_state(
348+
pretty=pretty,
345349
env_state=curr_state,
346-
action_name=rule.name[:25].lower(),
350+
action_name=rule_name,
347351
token_index=int(f"{step.focus}".zfill(3)),
348352
change=change,
349353
change_reward=transition.reward,
@@ -356,6 +360,7 @@ def render_state(
356360
token_index: int = -1,
357361
change: ExpressionChangeRule = None,
358362
change_reward: float = 0.0,
363+
pretty: bool = False,
359364
) -> str:
360365
"""Render the given state to a string suitable for printing to a log"""
361366
changed_problem = env_state.agent.problem
@@ -379,6 +384,8 @@ def get_move_shortname(index, move):
379384
moves = " ".join(move_codes)
380385
reward = f"{change_reward:.2}"
381386
reward = f"{reward:<5}"
387+
if pretty:
388+
return output
382389
return f"{num_moves} | {moves} | {moves_left} | {token} | {reward} | {output}"
383390

384391
def random_action(self, expression: MathExpression, rule: Type[BaseRule]) -> int:
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,15 @@ def test_mathy_env_terminal_conditions():
7171
assert text == text and is_terminal_transition(reward) == bool(is_win)
7272

7373

74-
def test_print_history():
74+
@pytest.mark.parametrize("pretty", [True, False])
75+
def test_print_history(pretty: bool):
7576
env = PolySimplify()
7677
env_state = MathyEnvState(problem="4x+2")
7778
for i in range(10):
7879
env_state = env_state.get_out_state(
7980
problem="2+4x", focus=i, moves_remaining=10 - i, action=i
8081
)
81-
assert env.print_history(env_state) is None
82+
env.print_history(env_state, pretty=pretty)
8283

8384

8485
def test_env_finalize_state():

0 commit comments

Comments
 (0)