Skip to content

Commit af99730

Browse files
feat(fragile): update to 0.0.44 (#35)
* feat(fragile): update to 0.0.40 - depends on a PR to fix `get_best_index` so that the best state and rewards are identified properly. * feat(fragile): update to 0.0.44 * chore: drop history tree and update r/d scales - the tree isn't needed for accessing the best attributes - bring the reward/distance scales back down inline with guillem's recommendations - suppress the openai retro install print during import of plangym (hacks)
1 parent bc019e2 commit af99730

File tree

3 files changed

+38
-27
lines changed

3 files changed

+38
-27
lines changed

libraries/mathy_python/mathy/agents/fragile.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from gym import spaces
1111
from pydantic import BaseModel
12+
from wasabi import msg
1213

1314
from mathy import EnvRewards, MathTypeKeysMax, MathyEnvState, is_terminal_transition
1415
from mathy.envs.gym import MathyGymEnv
@@ -28,15 +29,31 @@ def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
2829

2930

3031
def swarm_solve(problem: str, config: SwarmConfig):
31-
from plangym import ParallelEnvironment
32-
from plangym.env import Environment
33-
32+
from contextlib import contextmanager
33+
import sys, os
34+
35+
@contextmanager
36+
def suppress_stdout():
37+
with open(os.devnull, "w") as devnull:
38+
old_stdout = sys.stdout
39+
sys.stdout = devnull
40+
try:
41+
yield
42+
finally:
43+
sys.stdout = old_stdout
44+
45+
# Suppress the plangym package error "Please install OpenAI retro" that happens
46+
# on import. We don't use it, so it's safe to ignore.
47+
with suppress_stdout():
48+
from plangym import ParallelEnvironment
49+
from plangym.core import BaseEnvironment
3450
from fragile.core.env import DiscreteEnv
3551
from fragile.core.models import DiscreteModel
3652
from fragile.core.states import StatesEnv, StatesModel, StatesWalkers
3753
from fragile.core.swarm import Swarm
3854
from fragile.core.tree import HistoryTree
3955
from fragile.core.utils import StateDict
56+
from fragile.core.walkers import Walkers
4057

4158
class DiscreteMasked(DiscreteModel):
4259
def sample(
@@ -69,14 +86,8 @@ def random_choice_prob_index(a, axis=1):
6986
**kwargs,
7087
)
7188

72-
class MathySwarm(Swarm):
73-
def calculate_end_condition(self) -> bool:
74-
"""Stop when a walker receives a positive terminal reward."""
75-
max_reward = self.walkers.env_states.rewards.max()
76-
return max_reward > EnvRewards.WIN or self.walkers.calculate_end_condition()
77-
7889
class FragileMathyEnv(DiscreteEnv):
79-
"""FragileMathyEnv is an interface between the `plangym.Environment` and a
90+
"""FragileMathyEnv is an interface between the `plangym.BaseEnvironment` and a
8091
Mathy environment."""
8192

8293
STATE_CLASS = StatesEnv
@@ -120,7 +131,7 @@ def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
120131
)
121132
return new_states
122133

123-
class FragileEnvironment(Environment):
134+
class FragileEnvironment(BaseEnvironment):
124135
"""Fragile Environment for solving Mathy problems."""
125136

126137
problem: Optional[str]
@@ -216,22 +227,24 @@ def reset(self, return_state: bool = False):
216227

217228
print_every = 1e6
218229

219-
swarm = MathySwarm(
230+
swarm = Swarm(
220231
model=lambda env: DiscreteMasked(env=env),
221232
env=lambda: FragileMathyEnv(env=env),
222-
tree=HistoryTree,
233+
walkers=lambda **kwargs: Walkers(reward_limit=EnvRewards.WIN, **kwargs),
223234
n_walkers=config.n_walkers,
224-
max_iters=config.max_iters,
225-
prune_tree=True,
226-
reward_scale=5,
227-
distance_scale=10,
235+
max_epochs=config.max_iters,
236+
reward_scale=1,
237+
distance_scale=3,
228238
distance_function=mathy_dist,
229239
minimize=False,
230240
)
231241

232-
_ = swarm.run(print_every=print_every)
233-
best_ix = swarm.walkers.states.cum_rewards.argmax()
234-
best_id = swarm.walkers.states.id_walkers[best_ix]
235-
path = swarm.tree.get_branch(best_id, from_hash=True)
236-
last_state = MathyEnvState.from_np(path[0][-1])
237-
env._env.mathy.print_history(last_state)
242+
with msg.loading(f"Solving {problem} ..."):
243+
_ = swarm.run(show_pbar=False)
244+
245+
if swarm.walkers.best_reward > EnvRewards.WIN:
246+
last_state = MathyEnvState.from_np(swarm.walkers.states.best_state)
247+
msg.good(f"Solved! {problem} = {last_state.agent.problem}")
248+
env._env.mathy.print_history(last_state)
249+
else:
250+
msg.fail(f"Failed to find a solution! :(")

libraries/mathy_python/requirements.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ click
77
wasabi
88

99
# OpenAI Gym environments
10-
#
11-
# > 0.12.5 has a conflicting version of cloudpickle with TensorFlow 2.0.0
12-
gym<=0.12.5
10+
gym
1311

1412
# ML model agents
1513
tensorflow>=2.1.0

libraries/mathy_python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def setup_package():
2323
DEVELOPMENT_MODULES = [line.strip() for line in file if "-e" not in line]
2424

2525
extras = {
26-
"fragile": ["fragile==0.0.27", "plangym==0.0.2"],
26+
"fragile": ["fragile==0.0.44"],
2727
"dev": DEVELOPMENT_MODULES,
2828
}
2929
extras["all"] = [item for group in extras.values() for item in group]

0 commit comments

Comments (0)