33
44import copy
55import time
6+ from typing import List , Optional , Union
67
78import gym
89import numpy as np
910from gym import spaces
10- from plangym import ParallelEnvironment
11- from plangym .env import Environment
12-
13- from fragile .core .env import DiscreteEnv
14- from fragile .core .models import DiscreteModel
15- from fragile .core .states import StatesEnv , StatesModel , StatesWalkers
16- from fragile .core .swarm import Swarm
17- from fragile .core .tree import HistoryTree
18- from fragile .core .utils import StateDict
11+ from pydantic import BaseModel
12+
1913from mathy import EnvRewards , MathTypeKeysMax , MathyEnvState , is_terminal_transition
2014from mathy .envs .gym import MathyGymEnv
21- from typing import List , Optional
2215
23- from pydantic import BaseModel
2416from .. import about
2517
2618
@@ -31,175 +23,187 @@ class SwarmConfig(BaseModel):
3123 max_iters : int = 100
3224
3325
34- class DiscreteMasked (DiscreteModel ):
35- def sample (
36- self ,
37- batch_size : int ,
38- model_states : StatesModel = None ,
39- env_states : StatesEnv = None ,
40- walkers_states : StatesWalkers = None ,
41- ** kwargs ,
42- ) -> StatesModel :
43- def random_choice_prob_index (a , axis = 1 ):
44- """Select random actions with probabilities across a batch.
45-
46- Source: https://stackoverflow.com/a/47722393/287335"""
47- r = np .expand_dims (self .random_state .rand (a .shape [1 - axis ]), axis = axis )
48- return (a .cumsum (axis = axis ) > r ).argmax (axis = axis )
49-
50- if env_states is not None :
51- # Each state is a vstack([node_ids, mask]) and we only want the mask.
52- #
53- # Swap columns and slice the last element to get it.
54- masks = np .transpose (env_states .observs , [0 , 2 , 1 ])[:, :, - 1 ]
55- actions = random_choice_prob_index (masks )
56- else :
57- actions = self .random_state .randint (0 , self .n_actions , size = batch_size )
58- return self .update_states_with_critic (
59- actions = actions , model_states = model_states , batch_size = batch_size , ** kwargs
60- )
61-
62-
63- class MathySwarm (Swarm ):
64- def calculate_end_condition (self ) -> bool :
65- """Stop when a walker receives a positive terminal reward."""
66- max_reward = self .walkers .env_states .rewards .max ()
67- return max_reward > EnvRewards .WIN or self .walkers .calculate_end_condition ()
68-
69-
70- class FragileMathyEnv (DiscreteEnv ):
71- """FragileMathyEnv is an interface between the `plangym.Environment` and a
72- Mathy environment."""
73-
74- STATE_CLASS = StatesEnv
75-
76- def get_params_dict (self ) -> StateDict :
77- super_params = super (FragileMathyEnv , self ).get_params_dict ()
78- params = {"terminals" : {"dtype" : np .bool_ }}
79- params .update (super_params )
80- return params
81-
82- def step (self , model_states : StatesModel , env_states : StatesEnv ) -> StatesEnv :
83- actions = model_states .actions .astype (np .int32 )
84- new_states , observs , rewards , terminals , infos = self ._env .step_batch (
85- actions = actions , states = env_states .states
86- )
87- oobs = [not inf .get ("valid" , False ) for inf in infos ]
88- new_state = self .states_from_data (
89- states = new_states ,
90- observs = observs ,
91- rewards = rewards ,
92- oobs = oobs ,
93- batch_size = len (actions ),
94- terminals = terminals ,
95- )
96- return new_state
97-
98- def reset (self , batch_size : int = 1 , ** kwargs ) -> StatesEnv :
99- state , obs = self ._env .reset ()
100- states = np .array ([copy .deepcopy (state ) for _ in range (batch_size )])
101- observs = np .array ([copy .deepcopy (obs ) for _ in range (batch_size )])
102- rewards = np .zeros (batch_size , dtype = np .float32 )
103- oobs = np .zeros (batch_size , dtype = np .bool_ )
104- terminals = np .zeros (batch_size , dtype = np .bool_ )
105- new_states = self .states_from_data (
106- states = states ,
107- observs = observs ,
108- rewards = rewards ,
109- oobs = oobs ,
110- batch_size = batch_size ,
111- terminals = terminals ,
112- )
113- return new_states
114-
115-
116- class FragileEnvironment (Environment ):
117- """Fragile Environment for solving Mathy problems."""
118-
119- problem : Optional [str ]
120-
121- def __init__ (
122- self ,
123- name : str ,
124- environment : str = "poly" ,
125- difficulty : str = "normal" ,
126- problem : str = None ,
127- max_steps : int = 64 ,
128- ** kwargs ,
129- ):
130- super (FragileEnvironment , self ).__init__ (name = name )
131- self ._env : MathyGymEnv = gym .make (
132- f"mathy-{ environment } -{ difficulty } -v0" ,
133- np_observation = True ,
134- error_invalid = False ,
135- env_problem = problem ,
136- ** kwargs ,
137- )
138- self .observation_space = spaces .Box (
139- low = 0 , high = MathTypeKeysMax , shape = (256 , 256 , 1 ), dtype = np .uint8 ,
140- )
141- self .problem = problem
142- self .max_steps = max_steps
143- self .init_env ()
144-
145- def init_env (self ):
146- env = self ._env
147- env .reset ()
148- self .action_space = spaces .Discrete (self ._env .action_size )
149- self .observation_space = (
150- self ._env .observation_space
151- if self .observation_space is None
152- else self .observation_space
153- )
154-
155- def __getattr__ (self , item ):
156- return getattr (self ._env , item )
157-
158- def get_state (self ) -> np .ndarray :
159- assert self ._env .state is not None , "env required to get_state"
160- return self ._env .state .to_np ()
161-
162- def set_state (self , state : np .ndarray ):
163- assert self ._env is not None , "env required to set_state"
164- self ._env .state = MathyEnvState .from_np (state )
165- return state
166-
167- def step (
168- self , action : np .ndarray , state : np .ndarray = None , n_repeat_action : int = None
169- ) -> tuple :
170- assert self ._env is not None , "env required to step"
171- assert state is not None , "only works with state stepping"
172- self .set_state (state )
173- obs , reward , _ , info = self ._env .step (action )
174- terminal = info .get ("done" , False )
175- new_state = self .get_state ()
176- return new_state , obs , reward , terminal , info
177-
178- def step_batch (
179- self , actions , states = None , n_repeat_action : [int , np .ndarray ] = None
180- ) -> tuple :
181- data = [self .step (action , state ) for action , state in zip (actions , states )]
182- new_states , observs , rewards , terminals , infos = [], [], [], [], []
183- for d in data :
184- new_state , obs , _reward , end , info = d
185- new_states .append (new_state )
186- observs .append (obs )
187- rewards .append (_reward )
188- terminals .append (end )
189- infos .append (info )
190- return new_states , observs , rewards , terminals , infos
191-
192- def reset (self , return_state : bool = False ):
193- assert self ._env is not None , "env required to reset"
194- obs = self ._env .reset ()
195- return self .get_state (), obs
196-
197-
def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Per-row Euclidean distance between two batches of observation vectors."""
    delta = x - y
    return np.sqrt((delta * delta).sum(axis=1))
20028
20129
20230def swarm_solve (problem : str , config : SwarmConfig ):
31+ from plangym import ParallelEnvironment
32+ from plangym .env import Environment
33+
34+ from fragile .core .env import DiscreteEnv
35+ from fragile .core .models import DiscreteModel
36+ from fragile .core .states import StatesEnv , StatesModel , StatesWalkers
37+ from fragile .core .swarm import Swarm
38+ from fragile .core .tree import HistoryTree
39+ from fragile .core .utils import StateDict
40+
class DiscreteMasked(DiscreteModel):
    """Discrete model that samples actions weighted by the env's action mask."""

    def sample(
        self,
        batch_size: int,
        model_states: StatesModel = None,
        env_states: StatesEnv = None,
        walkers_states: StatesWalkers = None,
        **kwargs,
    ) -> StatesModel:
        """Draw one action per walker, honoring the mask when observations exist."""
        if env_states is None:
            # Nothing to mask against: fall back to uniform random actions.
            chosen = self.random_state.randint(0, self.n_actions, size=batch_size)
        else:
            # Each observation is vstack([node_ids, mask]); swapping the two
            # trailing axes and slicing the last column recovers the mask row.
            mask_batch = np.transpose(env_states.observs, [0, 2, 1])[:, :, -1]
            # Vectorized weighted choice per row via the cumulative-sum trick:
            # https://stackoverflow.com/a/47722393/287335
            # NOTE(review): this assumes each mask row behaves like a
            # probability weighting — confirm masks are normalized upstream.
            thresholds = self.random_state.rand(mask_batch.shape[0])[:, None]
            chosen = (mask_batch.cumsum(axis=1) > thresholds).argmax(axis=1)
        return self.update_states_with_critic(
            actions=chosen,
            model_states=model_states,
            batch_size=batch_size,
            **kwargs,
        )
71+
class MathySwarm(Swarm):
    """Swarm that halts as soon as any walker has solved the problem."""

    def calculate_end_condition(self) -> bool:
        """Stop when a walker receives a positive terminal reward."""
        best_reward = self.walkers.env_states.rewards.max()
        if best_reward > EnvRewards.WIN:
            return True
        return self.walkers.calculate_end_condition()
77+
class FragileMathyEnv(DiscreteEnv):
    """FragileMathyEnv is an interface between the `plangym.Environment` and a
    Mathy environment."""

    STATE_CLASS = StatesEnv

    def get_params_dict(self) -> StateDict:
        """Extend the parent parameter spec with a boolean `terminals` entry."""
        params: StateDict = {"terminals": {"dtype": np.bool_}}
        params.update(super(FragileMathyEnv, self).get_params_dict())
        return params

    def step(self, model_states: StatesModel, env_states: StatesEnv) -> StatesEnv:
        """Advance every walker one step and repackage the batched results."""
        acts = model_states.actions.astype(np.int32)
        states, observs, rewards, terminals, infos = self._env.step_batch(
            actions=acts, states=env_states.states
        )
        # A transition counts as out-of-bounds unless the env flagged it valid.
        out_of_bounds = [not info.get("valid", False) for info in infos]
        return self.states_from_data(
            states=states,
            observs=observs,
            rewards=rewards,
            oobs=out_of_bounds,
            batch_size=len(acts),
            terminals=terminals,
        )

    def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
        """Reset the wrapped env and replicate its state across the batch."""
        state, obs = self._env.reset()
        return self.states_from_data(
            states=np.array([copy.deepcopy(state) for _ in range(batch_size)]),
            observs=np.array([copy.deepcopy(obs) for _ in range(batch_size)]),
            rewards=np.zeros(batch_size, dtype=np.float32),
            oobs=np.zeros(batch_size, dtype=np.bool_),
            batch_size=batch_size,
            terminals=np.zeros(batch_size, dtype=np.bool_),
        )
122+
class FragileEnvironment(Environment):
    """Fragile Environment for solving Mathy problems."""

    # Optional problem text the env was constructed with; None means generated.
    problem: Optional[str]

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: str = None,
        max_steps: int = 64,
        **kwargs,
    ):
        super(FragileEnvironment, self).__init__(name=name)
        self._env: MathyGymEnv = gym.make(
            f"mathy-{environment}-{difficulty}-v0",
            np_observation=True,
            error_invalid=False,
            env_problem=problem,
            **kwargs,
        )
        self.observation_space = spaces.Box(
            low=0, high=MathTypeKeysMax, shape=(256, 256, 1), dtype=np.uint8,
        )
        self.problem = problem
        self.max_steps = max_steps
        self.init_env()

    def init_env(self):
        """Reset the wrapped gym env and mirror its action space."""
        self._env.reset()
        self.action_space = spaces.Discrete(self._env.action_size)
        # Keep the Box space set in __init__; adopt the env's own space only
        # if no observation space has been set yet.
        if self.observation_space is None:
            self.observation_space = self._env.observation_space

    def __getattr__(self, item):
        # Delegate unknown attribute lookups to the wrapped gym env.
        return getattr(self._env, item)

    def get_state(self) -> np.ndarray:
        """Serialize the current MathyEnvState to a numpy array."""
        assert self._env.state is not None, "env required to get_state"
        return self._env.state.to_np()

    def set_state(self, state: np.ndarray):
        """Restore env state from a serialized array and return the array."""
        assert self._env is not None, "env required to set_state"
        self._env.state = MathyEnvState.from_np(state)
        return state

    def step(
        self,
        action: np.ndarray,
        state: np.ndarray = None,
        n_repeat_action: int = None,
    ) -> tuple:
        """Set the given state, apply one action, and return the transition."""
        assert self._env is not None, "env required to step"
        assert state is not None, "only works with state stepping"
        self.set_state(state)
        obs, reward, _, info = self._env.step(action)
        # The env's own done flag is ignored; terminal comes from the info dict.
        done = info.get("done", False)
        return self.get_state(), obs, reward, done, info

    def step_batch(
        self, actions, states=None, n_repeat_action: Union[int, np.ndarray] = None
    ) -> tuple:
        """Step each (action, state) pair and collate results column-wise."""
        buckets = ([], [], [], [], [])
        for action, state in zip(actions, states):
            for bucket, value in zip(buckets, self.step(action, state)):
                bucket.append(value)
        # (new_states, observs, rewards, terminals, infos)
        return buckets

    def reset(self, return_state: bool = False):
        """Reset the wrapped env; always returns (state, observation)."""
        assert self._env is not None, "env required to reset"
        obs = self._env.reset()
        return self.get_state(), obs
206+
203207 if config .use_mp :
204208 env = ParallelEnvironment (
205209 env_class = FragileEnvironment ,
0 commit comments