1- from typing import List
1+ """Use Fractal Monte Carlo search in order to solve mathy problems without a
2+ trained neural network."""
3+
24import copy
3- import sys
45import time
5- import traceback
66
77import gym
8- import holoviews
98import numpy as np
10- import srsly
119from gym import spaces
10+ from plangym import ParallelEnvironment
11+ from plangym .env import Environment
1212
1313from fragile .core .env import DiscreteEnv
14- from fragile .core .models import Bounds , DiscreteModel
14+ from fragile .core .models import DiscreteModel
1515from fragile .core .states import StatesEnv , StatesModel , StatesWalkers
1616from fragile .core .swarm import Swarm
1717from fragile .core .tree import HistoryTree
1818from fragile .core .utils import StateDict
19- from fragile .core .walkers import Walkers
20- from fragile .dataviz import AtariViz , LandscapeViz , Summary , SwarmViz , SwarmViz1D
21- from mathy import MathTypeKeysMax , MathyEnvState , is_terminal_transition , EnvRewards
19+ from mathy import EnvRewards , MathTypeKeysMax , MathyEnvState , is_terminal_transition
2220from mathy .envs .gym import MathyGymEnv
23- from plangym import ParallelEnvironment
24- from plangym .env import Environment
25- import os
21+ from typing import List , Optional
22+
23+ from pydantic import BaseModel
24+ from .. import about
2625
2726
28- # Print explored mathy states when True
29- verbose = False
30- use_mp = True
31- prune_tree = True
32- max_iters = 100
33- reward_scale = 5
34- distance_scale = 10
35- minimize = False
36- use_vis = False
class SwarmConfig(BaseModel):
    """User-tunable settings for a Fractal Monte Carlo problem-solving swarm."""

    # Print explored mathy states while searching.
    verbose: bool = False
    # Run the environment in parallel worker processes.
    use_mp: bool = True
    # Number of concurrent walkers in the swarm.
    n_walkers: int = 512
    # Maximum number of swarm iterations before giving up on a problem.
    max_iters: int = 100
3732
3833
3934class DiscreteMasked (DiscreteModel ):
@@ -80,23 +75,23 @@ class FragileMathyEnv(DiscreteEnv):
8075
8176 def get_params_dict (self ) -> StateDict :
8277 super_params = super (FragileMathyEnv , self ).get_params_dict ()
83- params = {"game_ends " : {"dtype" : np .bool_ }}
78+ params = {"terminals " : {"dtype" : np .bool_ }}
8479 params .update (super_params )
8580 return params
8681
8782 def step (self , model_states : StatesModel , env_states : StatesEnv ) -> StatesEnv :
8883 actions = model_states .actions .astype (np .int32 )
89- new_states , observs , rewards , game_ends , infos = self ._env .step_batch (
84+ new_states , observs , rewards , terminals , infos = self ._env .step_batch (
9085 actions = actions , states = env_states .states
9186 )
92- ends = [not inf .get ("valid" , False ) for inf in infos ]
87+ oobs = [not inf .get ("valid" , False ) for inf in infos ]
9388 new_state = self .states_from_data (
9489 states = new_states ,
9590 observs = observs ,
9691 rewards = rewards ,
97- ends = ends ,
92+ oobs = oobs ,
9893 batch_size = len (actions ),
99- game_ends = game_ends ,
94+ terminals = terminals ,
10095 )
10196 return new_state
10297
@@ -105,40 +100,47 @@ def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
105100 states = np .array ([copy .deepcopy (state ) for _ in range (batch_size )])
106101 observs = np .array ([copy .deepcopy (obs ) for _ in range (batch_size )])
107102 rewards = np .zeros (batch_size , dtype = np .float32 )
108- ends = np .zeros (batch_size , dtype = np .bool_ )
109- game_ends = np .zeros (batch_size , dtype = np .bool_ )
103+ oobs = np .zeros (batch_size , dtype = np .bool_ )
104+ terminals = np .zeros (batch_size , dtype = np .bool_ )
110105 new_states = self .states_from_data (
111106 states = states ,
112107 observs = observs ,
113108 rewards = rewards ,
114- ends = ends ,
109+ oobs = oobs ,
115110 batch_size = batch_size ,
116- game_ends = game_ends ,
111+ terminals = terminals ,
117112 )
118113 return new_states
119114
120115
121116class FragileEnvironment (Environment ):
122117 """Fragile Environment for solving Mathy problems."""
123118
119+ problem : Optional [str ]
120+
124121 def __init__ (
125122 self ,
126123 name : str ,
127124 environment : str = "poly" ,
128125 difficulty : str = "normal" ,
126+ problem : str = None ,
127+ max_steps : int = 64 ,
129128 wrappers = None ,
130129 ** kwargs ,
131130 ):
132131 super (FragileEnvironment , self ).__init__ (name = name )
133132 self ._env : MathyGymEnv = gym .make (
134133 f"mathy-{ environment } -{ difficulty } -v0" ,
135- verbose = verbose ,
136134 np_observation = True ,
135+ error_invalid = False ,
136+ env_problem = problem ,
137137 ** kwargs ,
138138 )
139139 self .observation_space = spaces .Box (
140140 low = 0 , high = MathTypeKeysMax , shape = (256 , 256 , 1 ), dtype = np .uint8 ,
141141 )
142+ self .problem = problem
143+ self .max_steps = max_steps
142144 self .wrappers = wrappers
143145 self .init_env ()
144146
@@ -159,7 +161,7 @@ def __getattr__(self, item):
159161 return getattr (self ._env , item )
160162
161163 def get_state (self ) -> np .ndarray :
162- assert self ._env is not None , "env required to get_state"
164+ assert self ._env . state is not None , "env required to get_state"
163165 return self ._env .state .to_np ()
164166
165167 def set_state (self , state : np .ndarray ):
@@ -238,38 +240,69 @@ def solve_problem(problem_env: str, problem_difficulty: str = "easy"):
238240 # print_every = 5 if problem_difficulty == "hard" else 10
239241 print_every = 1e6
240242 n_walkers = 768 if problem_difficulty == "hard" else 256
243+ max_iters = 100
241244
242245 swarm = MathySwarm (
243246 model = lambda env : DiscreteMasked (env = env ),
244247 env = lambda : FragileMathyEnv (env = env ),
245248 tree = HistoryTree ,
246249 n_walkers = n_walkers ,
247250 max_iters = max_iters ,
248- prune_tree = prune_tree ,
249- reward_scale = reward_scale ,
250- distance_scale = distance_scale ,
251+ prune_tree = True ,
252+ reward_scale = 5 ,
253+ distance_scale = 10 ,
251254 distance_function = mathy_dist ,
252- minimize = minimize ,
255+ minimize = False ,
253256 )
254257
255- if not use_vis :
256- _ = swarm .run (print_every = print_every )
258+ _ = swarm .run (print_every = print_every )
259+ best_ix = swarm .walkers .states .cum_rewards .argmax ()
260+ best_id = swarm .walkers .states .id_walkers [best_ix ]
261+ path = swarm .tree .get_branch (best_id , from_hash = True )
262+ last_state = MathyEnvState .from_np (path [0 ][- 1 ])
263+ env ._env .mathy .print_history (last_state )
264+
265+
def swarm_solve(problem: str, config: SwarmConfig):
    """Solve a single mathy problem with a Fractal Monte Carlo swarm.

    Builds a parallel or single-process environment according to
    ``config.use_mp``, runs the swarm search, then reconstructs the best
    walker's path and prints its transformation history.
    """
    if config.use_mp:
        env = ParallelEnvironment(
            env_class=FragileEnvironment,
            name="mathy_v0",
            problem=problem,
            repeat_problem=True,
        )
    else:
        env = FragileEnvironment(name="mathy_v0", problem=problem, repeat_problem=True)

    # A huge interval effectively disables progress printing during the run.
    print_every = 1e6
    swarm = MathySwarm(
        model=lambda swarm_env: DiscreteMasked(env=swarm_env),
        env=lambda: FragileMathyEnv(env=env),
        tree=HistoryTree,
        n_walkers=config.n_walkers,
        max_iters=config.max_iters,
        prune_tree=True,
        reward_scale=5,
        distance_scale=10,
        distance_function=mathy_dist,
        minimize=False,
    )
    swarm.run(print_every=print_every)

    # Trace the highest-cumulative-reward walker back through the history
    # tree and print the sequence of states it visited.
    walker_states = swarm.walkers.states
    best_idx = walker_states.cum_rewards.argmax()
    best_hash = walker_states.id_walkers[best_idx]
    states_path = swarm.tree.get_branch(best_hash, from_hash=True)
    final_state = MathyEnvState.from_np(states_path[0][-1])
    env._env.mathy.print_history(final_state)
298+
271299
if __name__ == "__main__":
    # Demo entry point: run one solve for every environment-type/difficulty
    # pairing supported by the solver.
    for env_type in ["complex", "binomial", "poly", "poly-blockers"]:
        for env_difficulty in ["easy", "normal"]:
            solve_problem(env_type, env_difficulty)