Skip to content

Commit b6203fd

Browse files
fix(pypi): optional swarm install with mathy[fragile]
- This is because the current release of the `fragile` package requires `pillow-simd`, which doesn't install cleanly on macOS (OSX).
- TODO: review and submit a PR to `fragile` that removes these dependencies from its core package. It seems Pillow isn't needed for Mathy use cases. /fyi @Guillemdb
1 parent f929081 commit b6203fd

File tree

7 files changed

+191
-185
lines changed

7 files changed

+191
-185
lines changed

libraries/mathy_alpha_sm/tools/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ fi
1111
. .env/bin/activate
1212
echo "Installing/updating requirements..."
1313
pip install -e ../mathy_python
14-
pip install -e .[dev]
14+
pip install -e .
1515

libraries/mathy_python/mathy/agents/fragile.py

Lines changed: 179 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,16 @@
33

44
import copy
55
import time
6+
from typing import List, Optional, Union
67

78
import gym
89
import numpy as np
910
from gym import spaces
10-
from plangym import ParallelEnvironment
11-
from plangym.env import Environment
12-
13-
from fragile.core.env import DiscreteEnv
14-
from fragile.core.models import DiscreteModel
15-
from fragile.core.states import StatesEnv, StatesModel, StatesWalkers
16-
from fragile.core.swarm import Swarm
17-
from fragile.core.tree import HistoryTree
18-
from fragile.core.utils import StateDict
11+
from pydantic import BaseModel
12+
1913
from mathy import EnvRewards, MathTypeKeysMax, MathyEnvState, is_terminal_transition
2014
from mathy.envs.gym import MathyGymEnv
21-
from typing import List, Optional
2215

23-
from pydantic import BaseModel
2416
from .. import about
2517

2618

@@ -31,175 +23,187 @@ class SwarmConfig(BaseModel):
3123
max_iters: int = 100
3224

3325

34-
class DiscreteMasked(DiscreteModel):
35-
def sample(
36-
self,
37-
batch_size: int,
38-
model_states: StatesModel = None,
39-
env_states: StatesEnv = None,
40-
walkers_states: StatesWalkers = None,
41-
**kwargs,
42-
) -> StatesModel:
43-
def random_choice_prob_index(a, axis=1):
44-
"""Select random actions with probabilities across a batch.
45-
46-
Source: https://stackoverflow.com/a/47722393/287335"""
47-
r = np.expand_dims(self.random_state.rand(a.shape[1 - axis]), axis=axis)
48-
return (a.cumsum(axis=axis) > r).argmax(axis=axis)
49-
50-
if env_states is not None:
51-
# Each state is a vstack([node_ids, mask]) and we only want the mask.
52-
#
53-
# Swap columns and slice the last element to get it.
54-
masks = np.transpose(env_states.observs, [0, 2, 1])[:, :, -1]
55-
actions = random_choice_prob_index(masks)
56-
else:
57-
actions = self.random_state.randint(0, self.n_actions, size=batch_size)
58-
return self.update_states_with_critic(
59-
actions=actions, model_states=model_states, batch_size=batch_size, **kwargs
60-
)
61-
62-
63-
class MathySwarm(Swarm):
64-
def calculate_end_condition(self) -> bool:
65-
"""Stop when a walker receives a positive terminal reward."""
66-
max_reward = self.walkers.env_states.rewards.max()
67-
return max_reward > EnvRewards.WIN or self.walkers.calculate_end_condition()
68-
69-
70-
class FragileMathyEnv(DiscreteEnv):
71-
"""FragileMathyEnv is an interface between the `plangym.Environment` and a
72-
Mathy environment."""
73-
74-
STATE_CLASS = StatesEnv
75-
76-
def get_params_dict(self) -> StateDict:
77-
super_params = super(FragileMathyEnv, self).get_params_dict()
78-
params = {"terminals": {"dtype": np.bool_}}
79-
params.update(super_params)
80-
return params
81-
82-
def step(self, model_states: StatesModel, env_states: StatesEnv) -> StatesEnv:
83-
actions = model_states.actions.astype(np.int32)
84-
new_states, observs, rewards, terminals, infos = self._env.step_batch(
85-
actions=actions, states=env_states.states
86-
)
87-
oobs = [not inf.get("valid", False) for inf in infos]
88-
new_state = self.states_from_data(
89-
states=new_states,
90-
observs=observs,
91-
rewards=rewards,
92-
oobs=oobs,
93-
batch_size=len(actions),
94-
terminals=terminals,
95-
)
96-
return new_state
97-
98-
def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
99-
state, obs = self._env.reset()
100-
states = np.array([copy.deepcopy(state) for _ in range(batch_size)])
101-
observs = np.array([copy.deepcopy(obs) for _ in range(batch_size)])
102-
rewards = np.zeros(batch_size, dtype=np.float32)
103-
oobs = np.zeros(batch_size, dtype=np.bool_)
104-
terminals = np.zeros(batch_size, dtype=np.bool_)
105-
new_states = self.states_from_data(
106-
states=states,
107-
observs=observs,
108-
rewards=rewards,
109-
oobs=oobs,
110-
batch_size=batch_size,
111-
terminals=terminals,
112-
)
113-
return new_states
114-
115-
116-
class FragileEnvironment(Environment):
117-
"""Fragile Environment for solving Mathy problems."""
118-
119-
problem: Optional[str]
120-
121-
def __init__(
122-
self,
123-
name: str,
124-
environment: str = "poly",
125-
difficulty: str = "normal",
126-
problem: str = None,
127-
max_steps: int = 64,
128-
**kwargs,
129-
):
130-
super(FragileEnvironment, self).__init__(name=name)
131-
self._env: MathyGymEnv = gym.make(
132-
f"mathy-{environment}-{difficulty}-v0",
133-
np_observation=True,
134-
error_invalid=False,
135-
env_problem=problem,
136-
**kwargs,
137-
)
138-
self.observation_space = spaces.Box(
139-
low=0, high=MathTypeKeysMax, shape=(256, 256, 1), dtype=np.uint8,
140-
)
141-
self.problem = problem
142-
self.max_steps = max_steps
143-
self.init_env()
144-
145-
def init_env(self):
146-
env = self._env
147-
env.reset()
148-
self.action_space = spaces.Discrete(self._env.action_size)
149-
self.observation_space = (
150-
self._env.observation_space
151-
if self.observation_space is None
152-
else self.observation_space
153-
)
154-
155-
def __getattr__(self, item):
156-
return getattr(self._env, item)
157-
158-
def get_state(self) -> np.ndarray:
159-
assert self._env.state is not None, "env required to get_state"
160-
return self._env.state.to_np()
161-
162-
def set_state(self, state: np.ndarray):
163-
assert self._env is not None, "env required to set_state"
164-
self._env.state = MathyEnvState.from_np(state)
165-
return state
166-
167-
def step(
168-
self, action: np.ndarray, state: np.ndarray = None, n_repeat_action: int = None
169-
) -> tuple:
170-
assert self._env is not None, "env required to step"
171-
assert state is not None, "only works with state stepping"
172-
self.set_state(state)
173-
obs, reward, _, info = self._env.step(action)
174-
terminal = info.get("done", False)
175-
new_state = self.get_state()
176-
return new_state, obs, reward, terminal, info
177-
178-
def step_batch(
179-
self, actions, states=None, n_repeat_action: [int, np.ndarray] = None
180-
) -> tuple:
181-
data = [self.step(action, state) for action, state in zip(actions, states)]
182-
new_states, observs, rewards, terminals, infos = [], [], [], [], []
183-
for d in data:
184-
new_state, obs, _reward, end, info = d
185-
new_states.append(new_state)
186-
observs.append(obs)
187-
rewards.append(_reward)
188-
terminals.append(end)
189-
infos.append(info)
190-
return new_states, observs, rewards, terminals, infos
191-
192-
def reset(self, return_state: bool = False):
193-
assert self._env is not None, "env required to reset"
194-
obs = self._env.reset()
195-
return self.get_state(), obs
196-
197-
19826
def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
19927
return np.linalg.norm(x - y, axis=1)
20028

20129

20230
def swarm_solve(problem: str, config: SwarmConfig):
31+
from plangym import ParallelEnvironment
32+
from plangym.env import Environment
33+
34+
from fragile.core.env import DiscreteEnv
35+
from fragile.core.models import DiscreteModel
36+
from fragile.core.states import StatesEnv, StatesModel, StatesWalkers
37+
from fragile.core.swarm import Swarm
38+
from fragile.core.tree import HistoryTree
39+
from fragile.core.utils import StateDict
40+
41+
class DiscreteMasked(DiscreteModel):
42+
def sample(
43+
self,
44+
batch_size: int,
45+
model_states: StatesModel = None,
46+
env_states: StatesEnv = None,
47+
walkers_states: StatesWalkers = None,
48+
**kwargs,
49+
) -> StatesModel:
50+
def random_choice_prob_index(a, axis=1):
51+
"""Select random actions with probabilities across a batch.
52+
53+
Source: https://stackoverflow.com/a/47722393/287335"""
54+
r = np.expand_dims(self.random_state.rand(a.shape[1 - axis]), axis=axis)
55+
return (a.cumsum(axis=axis) > r).argmax(axis=axis)
56+
57+
if env_states is not None:
58+
# Each state is a vstack([node_ids, mask]) and we only want the mask.
59+
#
60+
# Swap columns and slice the last element to get it.
61+
masks = np.transpose(env_states.observs, [0, 2, 1])[:, :, -1]
62+
actions = random_choice_prob_index(masks)
63+
else:
64+
actions = self.random_state.randint(0, self.n_actions, size=batch_size)
65+
return self.update_states_with_critic(
66+
actions=actions,
67+
model_states=model_states,
68+
batch_size=batch_size,
69+
**kwargs,
70+
)
71+
72+
class MathySwarm(Swarm):
73+
def calculate_end_condition(self) -> bool:
74+
"""Stop when a walker receives a positive terminal reward."""
75+
max_reward = self.walkers.env_states.rewards.max()
76+
return max_reward > EnvRewards.WIN or self.walkers.calculate_end_condition()
77+
78+
class FragileMathyEnv(DiscreteEnv):
79+
"""FragileMathyEnv is an interface between the `plangym.Environment` and a
80+
Mathy environment."""
81+
82+
STATE_CLASS = StatesEnv
83+
84+
def get_params_dict(self) -> StateDict:
85+
super_params = super(FragileMathyEnv, self).get_params_dict()
86+
params = {"terminals": {"dtype": np.bool_}}
87+
params.update(super_params)
88+
return params
89+
90+
def step(self, model_states: StatesModel, env_states: StatesEnv) -> StatesEnv:
91+
actions = model_states.actions.astype(np.int32)
92+
new_states, observs, rewards, terminals, infos = self._env.step_batch(
93+
actions=actions, states=env_states.states
94+
)
95+
oobs = [not inf.get("valid", False) for inf in infos]
96+
new_state = self.states_from_data(
97+
states=new_states,
98+
observs=observs,
99+
rewards=rewards,
100+
oobs=oobs,
101+
batch_size=len(actions),
102+
terminals=terminals,
103+
)
104+
return new_state
105+
106+
def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
107+
state, obs = self._env.reset()
108+
states = np.array([copy.deepcopy(state) for _ in range(batch_size)])
109+
observs = np.array([copy.deepcopy(obs) for _ in range(batch_size)])
110+
rewards = np.zeros(batch_size, dtype=np.float32)
111+
oobs = np.zeros(batch_size, dtype=np.bool_)
112+
terminals = np.zeros(batch_size, dtype=np.bool_)
113+
new_states = self.states_from_data(
114+
states=states,
115+
observs=observs,
116+
rewards=rewards,
117+
oobs=oobs,
118+
batch_size=batch_size,
119+
terminals=terminals,
120+
)
121+
return new_states
122+
123+
class FragileEnvironment(Environment):
124+
"""Fragile Environment for solving Mathy problems."""
125+
126+
problem: Optional[str]
127+
128+
def __init__(
129+
self,
130+
name: str,
131+
environment: str = "poly",
132+
difficulty: str = "normal",
133+
problem: str = None,
134+
max_steps: int = 64,
135+
**kwargs,
136+
):
137+
super(FragileEnvironment, self).__init__(name=name)
138+
self._env: MathyGymEnv = gym.make(
139+
f"mathy-{environment}-{difficulty}-v0",
140+
np_observation=True,
141+
error_invalid=False,
142+
env_problem=problem,
143+
**kwargs,
144+
)
145+
self.observation_space = spaces.Box(
146+
low=0, high=MathTypeKeysMax, shape=(256, 256, 1), dtype=np.uint8,
147+
)
148+
self.problem = problem
149+
self.max_steps = max_steps
150+
self.init_env()
151+
152+
def init_env(self):
153+
env = self._env
154+
env.reset()
155+
self.action_space = spaces.Discrete(self._env.action_size)
156+
self.observation_space = (
157+
self._env.observation_space
158+
if self.observation_space is None
159+
else self.observation_space
160+
)
161+
162+
def __getattr__(self, item):
163+
return getattr(self._env, item)
164+
165+
def get_state(self) -> np.ndarray:
166+
assert self._env.state is not None, "env required to get_state"
167+
return self._env.state.to_np()
168+
169+
def set_state(self, state: np.ndarray):
170+
assert self._env is not None, "env required to set_state"
171+
self._env.state = MathyEnvState.from_np(state)
172+
return state
173+
174+
def step(
175+
self,
176+
action: np.ndarray,
177+
state: np.ndarray = None,
178+
n_repeat_action: int = None,
179+
) -> tuple:
180+
assert self._env is not None, "env required to step"
181+
assert state is not None, "only works with state stepping"
182+
self.set_state(state)
183+
obs, reward, _, info = self._env.step(action)
184+
terminal = info.get("done", False)
185+
new_state = self.get_state()
186+
return new_state, obs, reward, terminal, info
187+
188+
def step_batch(
189+
self, actions, states=None, n_repeat_action: Union[int, np.ndarray] = None
190+
) -> tuple:
191+
data = [self.step(action, state) for action, state in zip(actions, states)]
192+
new_states, observs, rewards, terminals, infos = [], [], [], [], []
193+
for d in data:
194+
new_state, obs, _reward, end, info = d
195+
new_states.append(new_state)
196+
observs.append(obs)
197+
rewards.append(_reward)
198+
terminals.append(end)
199+
infos.append(info)
200+
return new_states, observs, rewards, terminals, infos
201+
202+
def reset(self, return_state: bool = False):
203+
assert self._env is not None, "env required to reset"
204+
obs = self._env.reset()
205+
return self.get_state(), obs
206+
203207
if config.use_mp:
204208
env = ParallelEnvironment(
205209
env_class=FragileEnvironment,

libraries/mathy_python/requirements.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,3 @@ gym<=0.12.5
1515
tensorflow>=2.1.0
1616
tensorflow_probability
1717
keras-self-attention
18-
19-
# Fractal Monte Carlo agent
20-
fragile==0.0.27
21-
plangym==0.0.2

0 commit comments

Comments
 (0)