@@ -278,6 +278,7 @@ def get_next_state(
278278 return out_env , transition , ExpressionChangeRule (BaseRule ())
279279
280280 change = operation .apply_to (token .clone_from_root ())
281+ assert change .result is not None
281282 root = change .result .get_root ()
282283 change_name = operation .name
283284 out_problem = str (root )
@@ -303,11 +304,12 @@ def print_state(
303304 token_index : int = - 1 ,
304305 change : ExpressionChangeRule = None ,
305306 change_reward : float = 0.0 ,
307+ pretty : bool = False ,
306308 ):
307309 """Render the given state to stdout for visualization"""
308310 print (
309311 self .render_state (
310- env_state , action_name , token_index , change , change_reward
312+ env_state , action_name , token_index , change , change_reward , pretty
311313 )
312314 )
313315
@@ -322,7 +324,7 @@ def is_terminal_state(self, env_state: MathyEnvState) -> bool:
322324 """
323325 return is_terminal_transition (self .get_state_transition (env_state ))
324326
325- def print_history (self , env_state : MathyEnvState ) -> None :
327+ def print_history (self , env_state : MathyEnvState , pretty : bool = True ) -> None :
326328 """Render the history of an episode from a given state.
327329
328330 # Arguments
@@ -333,17 +335,19 @@ def print_history(self, env_state: MathyEnvState) -> None:
333335 curr_state : MathyEnvState = MathyEnvState (
334336 problem = initial_step .raw , max_moves = env_state .max_moves ,
335337 )
336- self .print_state (curr_state , "initial-state" )
337- while not self . is_terminal_state ( curr_state ) :
338+ self .print_state (curr_state , "initial-state" , pretty = pretty )
339+ while len ( history ) > 0 :
338340 step : MathyEnvStateStep = history .pop (0 )
339341 curr_state , transition , change = self .get_next_state (
340342 curr_state , step .action + (step .focus * len (self .rules ))
341343 )
342344 rule_idx , token_idx = self .get_action_indices (step .action )
343345 rule : BaseRule = self .rules [rule_idx ]
346+ rule_name : str = rule .name [:25 ].lower ()
344347 self .print_state (
348+ pretty = pretty ,
345349 env_state = curr_state ,
346- action_name = rule . name [: 25 ]. lower () ,
350+ action_name = rule_name ,
347351 token_index = int (f"{ step .focus } " .zfill (3 )),
348352 change = change ,
349353 change_reward = transition .reward ,
@@ -356,6 +360,7 @@ def render_state(
356360 token_index : int = - 1 ,
357361 change : ExpressionChangeRule = None ,
358362 change_reward : float = 0.0 ,
363+ pretty : bool = False ,
359364 ) -> str :
360365 """Render the given state to a string suitable for printing to a log"""
361366 changed_problem = env_state .agent .problem
@@ -379,6 +384,8 @@ def get_move_shortname(index, move):
379384 moves = " " .join (move_codes )
380385 reward = f"{ change_reward :.2} "
381386 reward = f"{ reward :<5} "
387+ if pretty :
388+ return output
382389 return f"{ num_moves } | { moves } | { moves_left } | { token } | { reward } | { output } "
383390
384391 def random_action (self , expression : MathExpression , rule : Type [BaseRule ]) -> int :
0 commit comments