Skip to content

Commit 3016c48

Browse files
feat(fragile): enable multiprocessing solver by default
- fragile == 0.0.47 - this is almost always going to be faster because mathy environments aren't optimized for speed to begin with, so they max out a single-process python app quickly when it comes to env iterations
1 parent a2dfde9 commit 3016c48

File tree

5 files changed

+23
-11
lines changed

5 files changed

+23
-11
lines changed

libraries/mathy_python/mathy/agents/fragile.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from fragile.core.swarm import Swarm
1414
from fragile.core.utils import StateDict
1515
from fragile.core.walkers import Walkers
16+
from fragile.distributed.env import ParallelEnv
1617
from gym import spaces
1718
from pydantic import BaseModel
1819
from wasabi import msg
@@ -29,6 +30,7 @@
2930

3031

3132
class SwarmConfig(BaseModel):
33+
use_mp: bool = True
3234
verbose: bool = False
3335
n_walkers: int = 512
3436
max_iters: int = 100
@@ -193,6 +195,8 @@ def swarm_solve(problem: str, config: SwarmConfig):
193195
name="mathy_v0", problem=problem, repeat_problem=True
194196
)
195197
mathy_env: MathyEnv = env_callable()._env._env.mathy
198+
if config.use_mp:
199+
env_callable = ParallelEnv(env_callable=env_callable)
196200
swarm = Swarm(
197201
model=lambda env: DiscreteMasked(env=env),
198202
env=env_callable,

libraries/mathy_python/mathy/cli.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ def cli_contribute():
3636
is_flag=True,
3737
help="Use swarm solver from fragile library without a trained model",
3838
)
39+
@click.option(
40+
"parallel",
41+
"--parallel",
42+
default=True,
43+
is_flag=True,
44+
help="Use parallel execution with the swarm solver",
45+
)
3946
@click.option(
4047
"model", "--model", default="mathy_alpha_sm", help="The path to a mathy model",
4148
)
@@ -47,7 +54,9 @@ def cli_contribute():
4754
help="The max number of steps before the episode is over",
4855
)
4956
@click.argument("problem", type=str)
50-
def cli_simplify(agent: str, problem: str, model: str, max_steps: int, swarm: bool):
57+
def cli_simplify(
58+
agent: str, problem: str, model: str, max_steps: int, swarm: bool, parallel: bool
59+
):
5160
"""Simplify an input polynomial expression."""
5261
setup_tf_env()
5362

@@ -57,12 +66,12 @@ def cli_simplify(agent: str, problem: str, model: str, max_steps: int, swarm: bo
5766

5867
mt: Mathy
5968
if swarm is True:
60-
mt = Mathy(config=SwarmConfig())
69+
mt = Mathy(config=SwarmConfig(use_mp=parallel))
6170
else:
6271
try:
6372
mt = load_model(model)
6473
except ValueError:
65-
mt = Mathy(config=SwarmConfig())
74+
mt = Mathy(config=SwarmConfig(use_mp=parallel))
6675

6776
mt.simplify(problem=problem, max_steps=max_steps)
6877

libraries/mathy_python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def setup_package():
2323
DEVELOPMENT_MODULES = [line.strip() for line in file if "-e" not in line]
2424

2525
extras = {
26-
"fragile": ["fragile==0.0.45"],
26+
"fragile": ["fragile==0.0.47"],
2727
"dev": DEVELOPMENT_MODULES,
2828
}
2929
extras["all"] = [item for group in extras.values() for item in group]

libraries/mathy_python/tests/test_api.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@ def test_mathy_with_model_and_config():
2424
mt = Mathy(model=model, config=config)
2525

2626

27-
def test_api_mathy_simplify():
28-
mt: Mathy = Mathy(config=SwarmConfig())
29-
mt.simplify(problem="2x+4x", max_steps=4)
30-
31-
3227
def test_api_mathy_constructor():
3328
# Defaults to swarm config
3429
assert isinstance(Mathy().state, MathyAPISwarmState)

libraries/mathy_python/tests/test_cli.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ def test_cli_simplify():
3232
assert result.exit_code == 0
3333

3434

35-
def test_cli_simplify_swarm():
35+
@pytest.mark.parametrize("use_mp", [True, False])
36+
def test_cli_simplify_swarm(use_mp: bool):
3637
runner = CliRunner()
37-
result = runner.invoke(cli, ["simplify", "4x + 2x", "--swarm"])
38+
args = ["simplify", "4x + 2x", "--swarm"]
39+
if use_mp:
40+
args.append("--parallel")
41+
result = runner.invoke(cli, args)
3842
assert result.exit_code == 0
3943

4044

0 commit comments

Comments
 (0)