11"""Use Fractal Monte Carlo search in order to solve mathy problems without a
22trained neural network."""
3- from typing import Dict , List , Optional , Union
3+ from typing import Any , Dict , List , Optional , Union
44
55import numpy as np
66from fragile .core .env import DiscreteEnv
@@ -70,7 +70,7 @@ def __init__(
7070 name : str ,
7171 environment : str = "poly" ,
7272 difficulty : str = "easy" ,
73- problem : str = None ,
73+ problem : Optional [ str ] = None ,
7474 max_steps : int = 64 ,
7575 ** kwargs ,
7676 ):
@@ -122,7 +122,7 @@ def __init__(
122122 name : str ,
123123 environment : str = "poly" ,
124124 difficulty : str = "normal" ,
125- problem : str = None ,
125+ problem : Optional [ str ] = None ,
126126 max_steps : int = 64 ,
127127 ** kwargs ,
128128 ):
@@ -132,10 +132,9 @@ def __init__(
132132
133133 self ._env : MathyGymEnv = gym .make (
134134 f"mathy-{ environment } -{ difficulty } -v0" ,
135- np_observation = True ,
136- mask_as_probabilities = True ,
137135 invalid_action_response = "terminal" ,
138136 env_problem = problem ,
137+ mask_as_probabilities = True ,
139138 ** kwargs ,
140139 )
141140 self .observation_space = spaces .Box (
@@ -148,7 +147,7 @@ def __init__(
148147
149148 def get_state (self ) -> np .ndarray :
150149 assert self ._env .state is not None , "env required to get_state"
151- return self ._env .state .to_np ()
150+ return self ._env .state .to_np (768 )
152151
153152 def set_state (self , state : np .ndarray ):
154153 assert self ._env is not None , "env required to set_state"
@@ -165,7 +164,7 @@ def step(self, action: int, state: np.ndarray = None) -> tuple:
165164 return new_state , obs , reward , oob , info
166165
167166 def step_batch (
168- self , actions , states = None , n_repeat_action : Union [int , np .ndarray ] = None
167+ self , actions , states : Optional [ Any ] = None , n_repeat_action : Optional [ Union [int , np .ndarray ] ] = None
169168 ) -> tuple :
170169 data = [self .step (action , state ) for action , state in zip (actions , states )]
171170 new_states , observs , rewards , terminals , infos = [], [], [], [], []
@@ -212,7 +211,7 @@ def mathy_swarm(config: SwarmConfig, env_callable=None) -> Swarm:
212211def swarm_solve (
213212 problems : Union [List [str ], str ],
214213 config : SwarmConfig ,
215- max_steps : Union [List [int ], int ] = 128 ,
214+ max_steps : Union [List [int ], int ] = 256 ,
216215 silent : bool = False ,
217216) -> Swarm :
218217 single_problem : bool = isinstance (problems , str )
0 commit comments