Skip to content

Commit 82b3f66

Browse files
feat(zero): support --show argument for worker 0
- print steps from worker_0 when arg is present
1 parent 4d5fd37 commit 82b3f66

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

libraries/mathy_python/mathy/agents/zero/practice_runner.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
from pydantic import BaseModel
11+
from wasabi import msg
1112

1213
from ...agents.episode_memory import rnn_weighted_history
1314
from ...agents.mcts import MCTS
@@ -56,6 +57,7 @@ def step(
5657
model: PolicyValueModel,
5758
move_count,
5859
history: List[EpisodeHistory],
60+
is_verbose_worker: bool,
5961
):
6062
import tensorflow as tf
6163

@@ -88,8 +90,12 @@ def step(
8890
mcts_state_copy, [rnn_state_h, rnn_state_c], temp=temp
8991
)
9092
action = np.random.choice(len(predicted_policy), p=predicted_policy)
91-
92-
next_state, transition, change = game.mathy.get_next_state(env_state, action)
93+
observation, reward, done, meta = game.step(action)
94+
next_state = game.state
95+
assert next_state is not None
96+
transition = meta["transition"]
97+
if is_verbose_worker and self.config.print_training is True:
98+
game.render()
9399
example_text = next_state.agent.problem
94100
r = float(transition.reward)
95101
is_term = is_terminal_transition(transition)
@@ -183,7 +189,12 @@ def execute_episodes(
183189
return examples, results
184190

185191
def execute_episode(
186-
self, episode, game: MathyGymEnv, predictor: PolicyValueModel, model_dir: str
192+
self,
193+
episode,
194+
game: MathyGymEnv,
195+
predictor: PolicyValueModel,
196+
model_dir: str,
197+
is_verbose_worker: bool = False,
187198
):
188199
"""
189200
This function executes one episode.
@@ -203,12 +214,24 @@ def execute_episode(
203214
episode_history: List[Any] = []
204215
move_count = 0
205216
mcts = MCTS(game.mathy, predictor, self.config.cpuct, self.config.mcts_sims)
217+
if is_verbose_worker and self.config.print_training is True:
218+
game.render()
219+
206220
while True:
207221
move_count += 1
208222
env_state, result = self.step(
209-
game, env_state, mcts, predictor, move_count, episode_history
223+
game=game,
224+
env_state=env_state,
225+
mcts=mcts,
226+
model=predictor,
227+
move_count=move_count,
228+
history=episode_history,
229+
is_verbose_worker=is_verbose_worker,
210230
)
211231
if result is not None:
232+
if is_verbose_worker and self.config.print_training is True:
233+
game.render()
234+
212235
return result + (game.problem,)
213236

214237
def episode_complete(self, episode: int, summary: EpisodeSummary):
@@ -251,11 +274,13 @@ class ParallelPracticeRunner(PracticeRunner):
251274
def execute_episodes(
252275
self, episode_args_list
253276
) -> Tuple[List[EpisodeHistory], List[EpisodeSummary]]:
254-
def worker(work_queue: Queue, result_queue: Queue):
277+
def worker(worker_idx: int, work_queue: Queue, result_queue: Queue):
255278
"""Pull items out of the work queue and execute episodes until there are
256-
no items left"""
279+
no items left """
257280
game = self.get_game()
258281
predictor = self.get_predictor(game)
282+
msg.good(f"Worker {worker_idx} started.")
283+
259284
while (
260285
ParallelPracticeRunner.request_quit is False
261286
and work_queue.empty() is False
@@ -268,7 +293,13 @@ def worker(work_queue: Queue, result_queue: Queue):
268293
episode_reward,
269294
is_win,
270295
problem,
271-
) = self.execute_episode(episode, game, predictor, **args)
296+
) = self.execute_episode(
297+
episode,
298+
game,
299+
predictor,
300+
is_verbose_worker=worker_idx == 0,
301+
**args,
302+
)
272303
except KeyboardInterrupt:
273304
break
274305
except Exception as e:
@@ -292,7 +323,7 @@ def worker(work_queue: Queue, result_queue: Queue):
292323
for i, args in enumerate(episode_args_list):
293324
work_queue.put((i, args))
294325
processes = [
295-
Process(target=worker, args=(work_queue, result_queue), daemon=True)
326+
Process(target=worker, args=(i, work_queue, result_queue))
296327
for i in range(self.config.num_workers)
297328
]
298329
for proc in processes:

libraries/mathy_python/mathy/agents/zero/practice_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def episode_complete(self, episode, summary):
7979
bar.next()
8080

8181
episodes_with_args = []
82-
for _ in range(1, num_episodes + 1):
82+
for i in range(1, num_episodes + 1):
8383
episodes_with_args.append(dict(model_dir=self.runner.config.model_dir))
8484

8585
old_update = self.runner.episode_complete

0 commit comments

Comments
 (0)