Skip to content

Commit 8515273

Browse files
feat(fragile): add swarm agent as untrained solver
- Uses the FractalAI `fragile` library to search the Mathy environment state spaces and find solutions without a trained neural network. Check it out: https://github.com/FragileTech/fragile
1 parent bc5648f commit 8515273

File tree

5 files changed

+163
-69
lines changed

5 files changed

+163
-69
lines changed

mathy_fragile.py renamed to libraries/mathy_python/mathy/agents/fragile.py

Lines changed: 83 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,34 @@
1-
from typing import List
1+
"""Use Fractal Monte Carlo search in order to solve mathy problems without a
2+
trained neural network."""
3+
24
import copy
3-
import sys
45
import time
5-
import traceback
66

77
import gym
8-
import holoviews
98
import numpy as np
10-
import srsly
119
from gym import spaces
10+
from plangym import ParallelEnvironment
11+
from plangym.env import Environment
1212

1313
from fragile.core.env import DiscreteEnv
14-
from fragile.core.models import Bounds, DiscreteModel
14+
from fragile.core.models import DiscreteModel
1515
from fragile.core.states import StatesEnv, StatesModel, StatesWalkers
1616
from fragile.core.swarm import Swarm
1717
from fragile.core.tree import HistoryTree
1818
from fragile.core.utils import StateDict
19-
from fragile.core.walkers import Walkers
20-
from fragile.dataviz import AtariViz, LandscapeViz, Summary, SwarmViz, SwarmViz1D
21-
from mathy import MathTypeKeysMax, MathyEnvState, is_terminal_transition, EnvRewards
19+
from mathy import EnvRewards, MathTypeKeysMax, MathyEnvState, is_terminal_transition
2220
from mathy.envs.gym import MathyGymEnv
23-
from plangym import ParallelEnvironment
24-
from plangym.env import Environment
25-
import os
21+
from typing import List, Optional
22+
23+
from pydantic import BaseModel
24+
from .. import about
2625

2726

28-
# Print explored mathy states when True
29-
verbose = False
30-
use_mp = True
31-
prune_tree = True
32-
max_iters = 100
33-
reward_scale = 5
34-
distance_scale = 10
35-
minimize = False
36-
use_vis = False
27+
class SwarmConfig(BaseModel):
28+
verbose: bool = False
29+
use_mp: bool = True
30+
n_walkers: int = 512
31+
max_iters: int = 100
3732

3833

3934
class DiscreteMasked(DiscreteModel):
@@ -80,23 +75,23 @@ class FragileMathyEnv(DiscreteEnv):
8075

8176
def get_params_dict(self) -> StateDict:
8277
super_params = super(FragileMathyEnv, self).get_params_dict()
83-
params = {"game_ends": {"dtype": np.bool_}}
78+
params = {"terminals": {"dtype": np.bool_}}
8479
params.update(super_params)
8580
return params
8681

8782
def step(self, model_states: StatesModel, env_states: StatesEnv) -> StatesEnv:
8883
actions = model_states.actions.astype(np.int32)
89-
new_states, observs, rewards, game_ends, infos = self._env.step_batch(
84+
new_states, observs, rewards, terminals, infos = self._env.step_batch(
9085
actions=actions, states=env_states.states
9186
)
92-
ends = [not inf.get("valid", False) for inf in infos]
87+
oobs = [not inf.get("valid", False) for inf in infos]
9388
new_state = self.states_from_data(
9489
states=new_states,
9590
observs=observs,
9691
rewards=rewards,
97-
ends=ends,
92+
oobs=oobs,
9893
batch_size=len(actions),
99-
game_ends=game_ends,
94+
terminals=terminals,
10095
)
10196
return new_state
10297

@@ -105,40 +100,47 @@ def reset(self, batch_size: int = 1, **kwargs) -> StatesEnv:
105100
states = np.array([copy.deepcopy(state) for _ in range(batch_size)])
106101
observs = np.array([copy.deepcopy(obs) for _ in range(batch_size)])
107102
rewards = np.zeros(batch_size, dtype=np.float32)
108-
ends = np.zeros(batch_size, dtype=np.bool_)
109-
game_ends = np.zeros(batch_size, dtype=np.bool_)
103+
oobs = np.zeros(batch_size, dtype=np.bool_)
104+
terminals = np.zeros(batch_size, dtype=np.bool_)
110105
new_states = self.states_from_data(
111106
states=states,
112107
observs=observs,
113108
rewards=rewards,
114-
ends=ends,
109+
oobs=oobs,
115110
batch_size=batch_size,
116-
game_ends=game_ends,
111+
terminals=terminals,
117112
)
118113
return new_states
119114

120115

121116
class FragileEnvironment(Environment):
122117
"""Fragile Environment for solving Mathy problems."""
123118

119+
problem: Optional[str]
120+
124121
def __init__(
125122
self,
126123
name: str,
127124
environment: str = "poly",
128125
difficulty: str = "normal",
126+
problem: str = None,
127+
max_steps: int = 64,
129128
wrappers=None,
130129
**kwargs,
131130
):
132131
super(FragileEnvironment, self).__init__(name=name)
133132
self._env: MathyGymEnv = gym.make(
134133
f"mathy-{environment}-{difficulty}-v0",
135-
verbose=verbose,
136134
np_observation=True,
135+
error_invalid=False,
136+
env_problem=problem,
137137
**kwargs,
138138
)
139139
self.observation_space = spaces.Box(
140140
low=0, high=MathTypeKeysMax, shape=(256, 256, 1), dtype=np.uint8,
141141
)
142+
self.problem = problem
143+
self.max_steps = max_steps
142144
self.wrappers = wrappers
143145
self.init_env()
144146

@@ -159,7 +161,7 @@ def __getattr__(self, item):
159161
return getattr(self._env, item)
160162

161163
def get_state(self) -> np.ndarray:
162-
assert self._env is not None, "env required to get_state"
164+
assert self._env.state is not None, "env required to get_state"
163165
return self._env.state.to_np()
164166

165167
def set_state(self, state: np.ndarray):
@@ -238,38 +240,69 @@ def solve_problem(problem_env: str, problem_difficulty: str = "easy"):
238240
# print_every = 5 if problem_difficulty == "hard" else 10
239241
print_every = 1e6
240242
n_walkers = 768 if problem_difficulty == "hard" else 256
243+
max_iters = 100
241244

242245
swarm = MathySwarm(
243246
model=lambda env: DiscreteMasked(env=env),
244247
env=lambda: FragileMathyEnv(env=env),
245248
tree=HistoryTree,
246249
n_walkers=n_walkers,
247250
max_iters=max_iters,
248-
prune_tree=prune_tree,
249-
reward_scale=reward_scale,
250-
distance_scale=distance_scale,
251+
prune_tree=True,
252+
reward_scale=5,
253+
distance_scale=10,
251254
distance_function=mathy_dist,
252-
minimize=minimize,
255+
minimize=False,
253256
)
254257

255-
if not use_vis:
256-
_ = swarm.run(print_every=print_every)
258+
_ = swarm.run(print_every=print_every)
259+
best_ix = swarm.walkers.states.cum_rewards.argmax()
260+
best_id = swarm.walkers.states.id_walkers[best_ix]
261+
path = swarm.tree.get_branch(best_id, from_hash=True)
262+
last_state = MathyEnvState.from_np(path[0][-1])
263+
env._env.mathy.print_history(last_state)
264+
265+
266+
def swarm_solve(problem: str, config: SwarmConfig):
267+
if config.use_mp:
268+
env = ParallelEnvironment(
269+
env_class=FragileEnvironment,
270+
name="mathy_v0",
271+
problem=problem,
272+
repeat_problem=True,
273+
)
257274
else:
258-
holoviews.extension("bokeh")
259-
viz = SwarmViz1D(swarm, stream_interval=print_every)
260-
viz.plot()
261-
viz.run(print_every=print_every)
275+
env = FragileEnvironment(name="mathy_v0", problem=problem, repeat_problem=True)
262276

277+
print_every = 1e6
278+
279+
swarm = MathySwarm(
280+
model=lambda env: DiscreteMasked(env=env),
281+
env=lambda: FragileMathyEnv(env=env),
282+
tree=HistoryTree,
283+
n_walkers=config.n_walkers,
284+
max_iters=config.max_iters,
285+
prune_tree=True,
286+
reward_scale=5,
287+
distance_scale=10,
288+
distance_function=mathy_dist,
289+
minimize=False,
290+
)
291+
292+
_ = swarm.run(print_every=print_every)
263293
best_ix = swarm.walkers.states.cum_rewards.argmax()
264294
best_id = swarm.walkers.states.id_walkers[best_ix]
265295
path = swarm.tree.get_branch(best_id, from_hash=True)
266-
env.render(last_action=-1, last_reward=0.0)
267-
for s, a in zip(path[0][1:], path[1]):
268-
_, _, r, _, info = env.step(state=s, action=a)
269-
env.render(last_action=a, last_reward=r)
270-
time.sleep(0.01)
296+
last_state = MathyEnvState.from_np(path[0][-1])
297+
env._env.mathy.print_history(last_state)
298+
271299

300+
if __name__ == "__main__":
301+
# Solve one problem of each type and difficulty
302+
for e_type in ["complex", "binomial", "poly", "poly-blockers"]:
303+
for e_difficulty in ["easy", "normal"]:
304+
solve_problem(e_type, e_difficulty)
272305

273-
for e_type in ["complex", "binomial", "poly", "poly-blockers"]:
274-
for e_difficulty in ["easy", "normal", "hard"]:
275-
solve_problem(e_type, e_difficulty)
306+
# # Solve a bunch of problems of the same type
307+
# for i in range(10):
308+
# solve_problem("poly", "easy")

libraries/mathy_python/mathy/api.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,66 @@
1-
from typing import Optional
1+
from dataclasses import dataclass
2+
from typing import Optional, Union
23

34
from .agents.base_config import BaseConfig
45
from .agents.episode_memory import EpisodeMemory
6+
from .agents.fragile import SwarmConfig, swarm_solve
57
from .agents.policy_value_model import PolicyValueModel, load_policy_value_model
8+
from .state import MathyEnvState
9+
10+
11+
@dataclass
12+
class MathyAPIModelState:
13+
config: BaseConfig
14+
model: PolicyValueModel
15+
16+
17+
@dataclass
18+
class MathyAPISwarmState:
19+
config: SwarmConfig
620

721

822
class Mathy:
923
"""The standard interface for working with Mathy models and agents."""
1024

11-
config: BaseConfig
12-
model: PolicyValueModel
25+
state: Union[MathyAPIModelState, MathyAPISwarmState]
1326

1427
def __init__(
1528
self,
1629
*,
1730
model_path: str = None,
1831
model: PolicyValueModel = None,
19-
config: BaseConfig = None,
32+
config: Union[BaseConfig, SwarmConfig] = None,
2033
silent: bool = False,
2134
):
2235
if model_path is not None:
23-
self.model, self.config = load_policy_value_model(model_path, silent=silent)
36+
model, config = load_policy_value_model(model_path, silent=silent)
37+
self.state = MathyAPIModelState(model=model, config=config)
2438
elif model is not None and config is not None:
2539
if not isinstance(model, PolicyValueModel):
2640
raise ValueError("model must derive PolicyValueModel for compatibility")
27-
self.model = model
28-
self.config = config
41+
if not isinstance(config, BaseConfig):
42+
raise ValueError("config must be a BaseConfig instance")
43+
self.state = MathyAPIModelState(model=model, config=config)
2944
else:
30-
raise ValueError(
31-
"Either 'model_path' or ('model' and 'config') must be provided"
32-
)
45+
if not isinstance(config, SwarmConfig):
46+
raise ValueError("config must be a BaseConfig instance")
47+
self.state = MathyAPISwarmState(config=config)
3348

3449
def simplify(
3550
self, *, model: str = "mathy_alpha_sm", problem: str, max_steps: int = 128,
51+
) -> EpisodeMemory:
52+
if isinstance(self.state, MathyAPISwarmState):
53+
return self.simplify_swarm(problem=problem, max_steps=max_steps)
54+
if not isinstance(self.state, MathyAPIModelState):
55+
raise ValueError(f"unknown state type: {type(self.state)}!")
56+
return self.simplify_model(model=model, problem=problem, max_steps=max_steps)
57+
58+
def simplify_swarm(self, *, problem: str, max_steps: int) -> EpisodeMemory:
59+
assert isinstance(self.state, MathyAPISwarmState), "not configured for swarm"
60+
return swarm_solve(problem, self.state.config)
61+
62+
def simplify_model(
63+
self, *, model: str = "mathy_alpha_sm", problem: str, max_steps: int,
3664
) -> EpisodeMemory:
3765
"""Simplify an input problem using the PolySimplify environment.
3866
@@ -47,6 +75,7 @@ def simplify(
4775
to the solution for the input problem.
4876
4977
"""
78+
assert isinstance(self.state, MathyAPIModelState), "not configured for model"
5079
import gym
5180
import tensorflow as tf
5281
from colr import color
@@ -66,12 +95,12 @@ def simplify(
6695
last_text = env.state.agent.problem
6796
last_action = -1
6897
last_reward = 0.0
69-
selector = GreedyActionSelector(model=self.model, episode=0, worker_id=0)
98+
selector = GreedyActionSelector(model=self.state.model, episode=0, worker_id=0)
7099
done = False
71100
while not done:
72101
env.render(last_action=last_action, last_reward=last_reward)
73102
window = episode_memory.to_window_observation(
74-
last_observation, window_size=self.config.prediction_window_size
103+
last_observation, window_size=self.state.config.prediction_window_size
75104
)
76105
try:
77106
action, value = selector.select(

libraries/mathy_python/mathy/cli.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def cli_contribute():
2929

3030

3131
@cli.command("simplify")
32+
@click.option(
33+
"swarm",
34+
"--swarm",
35+
default=False,
36+
is_flag=True,
37+
help="Use swarm solver from fragile library without a trained model",
38+
)
3239
@click.option(
3340
"model", "--model", default="mathy_alpha_sm", help="The path to a mathy model",
3441
)
@@ -40,14 +47,23 @@ def cli_contribute():
4047
help="The max number of steps before the episode is over",
4148
)
4249
@click.argument("problem", type=str)
43-
def cli_simplify(agent: str, problem: str, model: str, max_steps: int):
50+
def cli_simplify(agent: str, problem: str, model: str, max_steps: int, swarm: bool):
4451
"""Simplify an input polynomial expression."""
4552
setup_tf_env()
4653

4754
from .models import load_model
4855
from .api import Mathy
56+
from .agents.fragile import SwarmConfig
57+
58+
mt: Mathy
59+
if swarm is True:
60+
mt = Mathy(config=SwarmConfig())
61+
else:
62+
try:
63+
mt = load_model(model)
64+
except ValueError:
65+
mt = Mathy(config=SwarmConfig())
4966

50-
mt: Mathy = load_model(model)
5167
mt.simplify(problem=problem, max_steps=max_steps)
5268

5369

libraries/mathy_python/mathy/envs/gym/mathy_gym_env.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(
2828
self,
2929
env_class: Type[MathyEnv] = MathyEnv,
3030
env_problem_args: Optional[MathyEnvProblemArgs] = None,
31+
env_problem: Optional[str] = None,
32+
env_max_moves: int = 64,
3133
np_observation: bool = False,
3234
repeat_problem: bool = False,
3335
**env_kwargs,
@@ -38,7 +40,12 @@ def __init__(
3840
self.mathy = env_class(**env_kwargs)
3941
self.env_class = env_class
4042
self.env_problem_args = env_problem_args
41-
self._challenge, _ = self.mathy.get_initial_state(env_problem_args)
43+
if env_problem is not None:
44+
self._challenge = MathyEnvState(
45+
problem=env_problem, max_moves=env_max_moves
46+
)
47+
else:
48+
self._challenge, _ = self.mathy.get_initial_state(env_problem_args)
4249
self.action_space = MaskedDiscrete(self.action_size, [1] * self.action_size)
4350

4451
@property

0 commit comments

Comments (0)