99import numpy as np
1010from gym import spaces
1111from pydantic import BaseModel
12+ from wasabi import msg
1213
1314from mathy import EnvRewards , MathTypeKeysMax , MathyEnvState , is_terminal_transition
1415from mathy .envs .gym import MathyGymEnv
@@ -28,15 +29,31 @@ def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
2829
2930
3031def swarm_solve (problem : str , config : SwarmConfig ):
31- from plangym import ParallelEnvironment
32- from plangym .env import Environment
33-
32+ from contextlib import contextmanager
33+ import sys , os
34+
@contextmanager
def suppress_stdout():
    """Temporarily redirect ``sys.stdout`` to ``os.devnull``.

    Used to hide noisy import-time output (e.g. plangym's "Please install
    OpenAI retro" message). ``contextlib.redirect_stdout`` guarantees the
    original stdout is restored even if the wrapped body raises, replacing
    the manual save/swap/finally dance.
    """
    # Local import keeps this helper self-contained inside swarm_solve,
    # matching the surrounding local-import style.
    from contextlib import redirect_stdout

    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        yield
44+
45+ # Suppress the plangym package error "Please install OpenAI retro" that happens
46+ # on import. We don't use it, so it's safe to ignore.
47+ with suppress_stdout ():
48+ from plangym import ParallelEnvironment
49+ from plangym .core import BaseEnvironment
3450 from fragile .core .env import DiscreteEnv
3551 from fragile .core .models import DiscreteModel
3652 from fragile .core .states import StatesEnv , StatesModel , StatesWalkers
3753 from fragile .core .swarm import Swarm
3854 from fragile .core .tree import HistoryTree
3955 from fragile .core .utils import StateDict
56+ from fragile .core .walkers import Walkers
4057
4158 class DiscreteMasked (DiscreteModel ):
4259 def sample (
@@ -69,14 +86,8 @@ def random_choice_prob_index(a, axis=1):
6986 ** kwargs ,
7087 )
7188
72- class MathySwarm (Swarm ):
73- def calculate_end_condition (self ) -> bool :
74- """Stop when a walker receives a positive terminal reward."""
75- max_reward = self .walkers .env_states .rewards .max ()
76- return max_reward > EnvRewards .WIN or self .walkers .calculate_end_condition ()
77-
7889 class FragileMathyEnv (DiscreteEnv ):
79- """FragileMathyEnv is an interface between the `plangym.Environment ` and a
90+ """FragileMathyEnv is an interface between the `plangym.BaseEnvironment ` and a
8091 Mathy environment."""
8192
8293 STATE_CLASS = StatesEnv
@@ -120,7 +131,7 @@ def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
120131 )
121132 return new_states
122133
123- class FragileEnvironment (Environment ):
134+ class FragileEnvironment (BaseEnvironment ):
124135 """Fragile Environment for solving Mathy problems."""
125136
126137 problem : Optional [str ]
@@ -216,22 +227,24 @@ def reset(self, return_state: bool = False):
# Build the swarm: masked discrete action sampling (DiscreteMasked) over the
# fragile-wrapped Mathy environment. The custom walkers factory sets
# reward_limit so the search halts as soon as any walker crosses the WIN
# reward threshold, replacing the old MathySwarm.calculate_end_condition
# override. NOTE: the previously unused `print_every = 1e6` was removed —
# swarm.run no longer takes a print interval (it uses show_pbar instead).
swarm = Swarm(
    model=lambda env: DiscreteMasked(env=env),
    env=lambda: FragileMathyEnv(env=env),
    walkers=lambda **kwargs: Walkers(reward_limit=EnvRewards.WIN, **kwargs),
    n_walkers=config.n_walkers,
    max_epochs=config.max_iters,
    reward_scale=1,
    distance_scale=3,
    distance_function=mathy_dist,
    minimize=False,
)

# Run the search behind a wasabi loading spinner; the internal progress bar
# is disabled because the spinner already communicates activity.
with msg.loading(f"Solving {problem} ..."):
    _ = swarm.run(show_pbar=False)

# Report the outcome. A best_reward above WIN means some walker reached a
# positive terminal state; rebuild that walker's final MathyEnvState from
# its serialized numpy form and print the step-by-step solution history.
if swarm.walkers.best_reward > EnvRewards.WIN:
    last_state = MathyEnvState.from_np(swarm.walkers.states.best_state)
    msg.good(f"Solved! {problem} = {last_state.agent.problem}")
    env._env.mathy.print_history(last_state)
else:
    # Plain string: the original used an f-string with no placeholders.
    msg.fail("Failed to find a solution! :(")
0 commit comments