Skip to content

Commit be3ab58

Browse files
feat: Add support for integrating with Fragile library
* chore: stash WIP fragile search mathy agent - add `mathy_fragile.py` in the root for now. It requires a local mathy installation to use the hacked up library compatible with plangym: https://mathy.ai/contributing/#use-your-local-version-of-mathy - requires plangym/fragile packages, installed by whatever means you use
* chore(env): add invalid move error scaling - calculate the distance from the selected invalid action to the nearest valid action and provide a negative reward based on that. - remove some extra state from the MathyGymEnv to ensure that fragile can cleanly reset all of the state during batch steps
* chore: finish cleaning up gym env for clean reset - I changed the env to reset to the same initial state when called. I think this is more in line with what an atari env might do.
* chore: use node types as observations - sequence of node type IDs - adjust scale of rewards to remove negatives
* test(state): fix to_string/from_string
* refactor(env): use positive reward signals for fragile search - normalize all rewards so that dead-states return 0.0 and positive ones return > 0.0 (as in paper) - treat invalid action selections as null states and continue the simulation
* chore: cleanup and fix test script
* chore: cleanup
* chore(fragile): add masked action selection - this is pretty slow by comparison. See if there's a way to avoid looping over the action masks. Maybe build the probabilities into the observation so the calculation is unnecessary?
* chore: faster numpy action selection - calculate mask probabilities and return as observation - use numpy trickery to avoid looping over the observations to choose random actions from probabilities in batches
* chore: fix fragile oob/terminal conditions
* chore: fix tests
* chore: drop fragile script
1 parent 478fad5 commit be3ab58

File tree

13 files changed

+204
-140
lines changed

13 files changed

+204
-140
lines changed
libraries/mathy_pydoc/requirements-dev.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mypy
2+
pytest
3+
pytest-cov
4+
black
5+
flake8

libraries/mathy_pydoc/tools/setup.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ fi
1010

1111
. .env/bin/activate
1212
echo "Installing/updating requirements..."
13-
pip install pytest pytest-cov
14-
pip install -e .[dev]
13+
pip install -e .
14+
pip install -r requirements-dev.txt
1515

libraries/mathy_python/mathy/agents/a3c/agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from ..policy_value_model import get_or_create_policy_model, PolicyValueModel
1111
from .config import A3CConfig
1212
from .worker import A3CWorker
13+
from ...envs.gym import MathyGymEnv
14+
from ...state import observations_to_window
1315

1416

1517
class A3CAgent:
@@ -33,17 +35,17 @@ def __init__(self, args: A3CConfig, env_extra: dict = None):
3335
win_threshold=self.args.teacher_promote_wins,
3436
lose_threshold=self.args.teacher_demote_wins,
3537
)
36-
env = gym.make(self.teacher.get_env(0, 0), **self.env_extra)
38+
env: MathyGymEnv = gym.make(self.teacher.get_env(0, 0), **self.env_extra)
3739
self.action_size = env.action_space.n
3840
self.log_dir = os.path.join(self.args.model_dir, "tensorboard")
3941
self.writer = tf.summary.create_file_writer(self.log_dir)
40-
init_window = env.initial_window()
42+
initial_window = observations_to_window([env.reset()])
4143
self.global_model = get_or_create_policy_model(
4244
args=args, predictions=self.action_size, is_main=True, env=env.mathy
4345
)
4446
with self.writer.as_default():
4547
tf.summary.trace_on(graph=True)
46-
inputs = init_window.to_inputs()
48+
inputs = initial_window.to_inputs()
4749

4850
@tf.function
4951
def trace_fn():

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,13 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
224224
last_observation: MathyObservation = env.reset()
225225
last_text = env.state.agent.problem
226226
last_action = -1
227-
last_reward = -1
227+
last_reward = 0.0
228228

229229
selector = self.build_episode_selector(env)
230230

231231
while not done and A3CWorker.request_quit is False:
232232
if self.args.print_training and self.worker_idx == 0:
233-
env.render(self.args.print_mode, None)
233+
env.render(last_action=last_action, last_reward=last_reward)
234234
window = episode_memory.to_window_observation(
235235
last_observation, window_size=self.args.prediction_window_size
236236
)
@@ -253,7 +253,7 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
253253
)
254254
if time_count == self.args.update_gradients_every or done:
255255
if done and self.args.print_training and self.worker_idx == 0:
256-
env.render(self.args.print_mode, None)
256+
env.render(last_action=last_action, last_reward=last_reward)
257257

258258
self.update_global_network(done, observation, episode_memory)
259259
self.maybe_write_histograms()

libraries/mathy_python/mathy/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ def simplify(
6565
assert env.state is not None
6666
last_text = env.state.agent.problem
6767
last_action = -1
68-
last_reward = -1
68+
last_reward = 0.0
6969
selector = GreedyActionSelector(model=self.model, episode=0, worker_id=0)
7070
done = False
7171
while not done:
72-
env.render(self.config.print_mode, None)
72+
env.render(last_action=last_action, last_reward=last_reward)
7373
window = episode_memory.to_window_observation(
7474
last_observation, window_size=self.config.prediction_window_size
7575
)
@@ -96,7 +96,7 @@ def simplify(
9696
if done:
9797
# Last timestep reward
9898
win = reward > 0.0
99-
env.render(self.config.print_mode, None)
99+
env.render(last_action=last_action, last_reward=last_reward)
100100
print(
101101
color(
102102
text="SOLVE" if win else "FAIL", fore="green" if win else "red",

libraries/mathy_python/mathy/core/rule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,13 @@ class ExpressionChangeRule:
8383

8484
rule: BaseRule
8585
node: Optional[MathExpression]
86-
result: MathExpression
86+
result: Optional[MathExpression]
8787
_save_parent: Optional[MathExpression]
8888

8989
def __init__(self, rule, node: MathExpression = None):
9090
self.rule = rule
9191
self.node = node
92+
self.result = None
9293
self._save_parent = None
9394

9495
def save_parent(

libraries/mathy_python/mathy/env.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ def __init__(
4343
rules: List[BaseRule] = None,
4444
max_moves: int = 20,
4545
verbose: bool = False,
46+
error_invalid: bool = False,
4647
reward_discount: float = 0.99,
4748
):
4849
self.discount = reward_discount
4950
self.verbose = verbose
5051
self.max_moves = max_moves
52+
self.error_invalid = error_invalid
5153
self.parser = ExpressionParser()
5254
if rules is None:
5355
self.rules = MathyEnv.core_rules()
@@ -73,18 +75,6 @@ def action_size(self) -> int:
7375
"""Return the number of available actions"""
7476
return len(self.rules)
7577

76-
def step(
77-
self, state: MathyEnvState, action: int, as_observation: bool = False
78-
) -> Tuple[Union[MathyEnvState, MathyObservation], float, bool, Any]:
79-
new_state, transition, change = self.get_next_state(state, action)
80-
observation = self.state_to_observation(state)
81-
info = {"transition": transition}
82-
done = is_terminal_transition(transition)
83-
self.last_action = action
84-
self.last_change = change
85-
self.last_reward = round(float(transition.reward), 4)
86-
return observation, transition.reward, done, info
87-
8878
def finalize_state(self, state: MathyEnvState):
8979
"""Perform final checks on a problem state, to ensure the episode yielded
9080
results that were uncorrupted by transformation errors."""
@@ -108,15 +98,17 @@ def get_rewarding_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
10898
# rewarding.
10999
return [
110100
ConstantsSimplifyRule,
111-
DistributiveMultiplyRule,
112101
DistributiveFactorOutRule,
113102
VariableMultiplyRule,
114103
]
115104

116105
def get_penalizing_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
117106
"""Get the list of penalizing action types. When these actions
118107
are selected, the agent gets a negative reward."""
119-
return [AssociativeSwapRule]
108+
return [
109+
AssociativeSwapRule,
110+
DistributiveMultiplyRule,
111+
]
120112

121113
def max_moves_fn(
122114
self, problem: MathyEnvProblem, config: MathyEnvProblemArgs
@@ -267,28 +259,41 @@ def get_next_state(
267259
op_not_rule = not isinstance(operation, BaseRule)
268260
op_cannot_apply = token is None or operation.can_apply_to(token) is False
269261
if token is None or op_not_rule or op_cannot_apply:
270-
steps = int(env_state.max_moves - agent.moves_remaining)
271-
msg = "Step: {} - Invalid action({}) '{}' for expression '{}'.".format(
272-
steps, action, type(operation), expression
273-
)
274-
raise_with_history("Invalid Action", msg, agent.history)
275-
raise ValueError(f"Invalid Action: {msg}")
262+
if self.error_invalid:
263+
steps = int(env_state.max_moves - agent.moves_remaining)
264+
msg = "Step: {} - Invalid action({}) '{}' for expression '{}'.".format(
265+
steps, action, type(operation), expression
266+
)
267+
raise_with_history("Invalid Action", msg, agent.history)
268+
raise ValueError(f"Invalid Action: {msg}")
269+
else:
270+
valid_mask = self.get_valid_moves(env_state)
271+
# Non-masked searches ignore invalid moves entirely
272+
out_env = MathyEnvState.copy(env_state)
273+
out_env.action = -1
274+
obs = out_env.to_observation(
275+
self.get_valid_moves(out_env), parser=self.parser
276+
)
277+
transition = time_step.transition(obs, EnvRewards.INVALID_MOVE)
278+
return out_env, transition, ExpressionChangeRule(BaseRule())
276279

277280
change = operation.apply_to(token.clone_from_root())
278281
root = change.result.get_root()
279282
change_name = operation.name
280283
out_problem = str(root)
281284
out_env = env_state.get_out_state(
282285
problem=out_problem,
283-
focus_index=token_index,
286+
focus=token_index,
284287
action=action_index,
285288
moves_remaining=agent.moves_remaining - 1,
286289
)
287290

291+
transition = self.get_state_transition(out_env, searching)
288292
if not searching and self.verbose:
289293
token_idx = int("{}".format(token_index).zfill(3))
290-
self.print_state(out_env, change_name[:25].lower(), token_idx, change)
291-
transition = self.get_state_transition(out_env, searching)
294+
self.print_state(
295+
out_env, change_name[:25].lower(), token_idx, change, transition.reward
296+
)
292297
return out_env, transition, change
293298

294299
def print_state(
@@ -316,7 +321,7 @@ def render_state(
316321
):
317322
"""Render the given state to a string suitable for printing to a log"""
318323
changed_problem = env_state.agent.problem
319-
if change is not None:
324+
if change is not None and change.result is not None:
320325
changed_problem = change.result.get_root().terminal_text
321326
output = """{:<25} | {}""".format(action_name.lower(), changed_problem)
322327

@@ -382,16 +387,16 @@ def get_agent_actions_count(self, env_state: MathyEnvState) -> int:
382387
return self.action_size * node_count
383388

384389
def get_token_at_index(
385-
self, expression: MathExpression, focus_index: int
390+
self, expression: MathExpression, index: int
386391
) -> Optional[MathExpression]:
387-
"""Get the token that is `focus_index` from the left of the expression"""
392+
"""Get the token that is `index` from the left of the expression"""
388393
count = 0
389394
result = None
390395

391396
def visit_fn(node, depth, data):
392397
nonlocal result, count
393398
result = node
394-
if count == focus_index:
399+
if count == index:
395400
return STOP
396401
count = count + 1
397402

0 commit comments

Comments (0)