88
99import numpy as np
1010from pydantic import BaseModel
11+ from wasabi import msg
1112
1213from ...agents .episode_memory import rnn_weighted_history
1314from ...agents .mcts import MCTS
@@ -56,6 +57,7 @@ def step(
5657 model : PolicyValueModel ,
5758 move_count ,
5859 history : List [EpisodeHistory ],
60+ is_verbose_worker : bool ,
5961 ):
6062 import tensorflow as tf
6163
@@ -88,8 +90,12 @@ def step(
8890 mcts_state_copy , [rnn_state_h , rnn_state_c ], temp = temp
8991 )
9092 action = np .random .choice (len (predicted_policy ), p = predicted_policy )
91-
92- next_state , transition , change = game .mathy .get_next_state (env_state , action )
93+ observation , reward , done , meta = game .step (action )
94+ next_state = game .state
95+ assert next_state is not None
96+ transition = meta ["transition" ]
97+ if is_verbose_worker and self .config .print_training is True :
98+ game .render ()
9399 example_text = next_state .agent .problem
94100 r = float (transition .reward )
95101 is_term = is_terminal_transition (transition )
@@ -183,7 +189,12 @@ def execute_episodes(
183189 return examples , results
184190
185191 def execute_episode (
186- self , episode , game : MathyGymEnv , predictor : PolicyValueModel , model_dir : str
192+ self ,
193+ episode ,
194+ game : MathyGymEnv ,
195+ predictor : PolicyValueModel ,
196+ model_dir : str ,
197+ is_verbose_worker : bool = False ,
187198 ):
188199 """
189200 This function executes one episode.
@@ -203,12 +214,24 @@ def execute_episode(
203214 episode_history : List [Any ] = []
204215 move_count = 0
205216 mcts = MCTS (game .mathy , predictor , self .config .cpuct , self .config .mcts_sims )
217+ if is_verbose_worker and self .config .print_training is True :
218+ game .render ()
219+
206220 while True :
207221 move_count += 1
208222 env_state , result = self .step (
209- game , env_state , mcts , predictor , move_count , episode_history
223+ game = game ,
224+ env_state = env_state ,
225+ mcts = mcts ,
226+ model = predictor ,
227+ move_count = move_count ,
228+ history = episode_history ,
229+ is_verbose_worker = is_verbose_worker ,
210230 )
211231 if result is not None :
232+ if is_verbose_worker and self .config .print_training is True :
233+ game .render ()
234+
212235 return result + (game .problem ,)
213236
214237 def episode_complete (self , episode : int , summary : EpisodeSummary ):
@@ -251,11 +274,13 @@ class ParallelPracticeRunner(PracticeRunner):
251274 def execute_episodes (
252275 self , episode_args_list
253276 ) -> Tuple [List [EpisodeHistory ], List [EpisodeSummary ]]:
254- def worker (work_queue : Queue , result_queue : Queue ):
277+ def worker (worker_idx : int , work_queue : Queue , result_queue : Queue ):
255278 """Pull items out of the work queue and execute episodes until there are
256- no items left"""
279+ no items left """
257280 game = self .get_game ()
258281 predictor = self .get_predictor (game )
282+ msg .good (f"Worker { worker_idx } started." )
283+
259284 while (
260285 ParallelPracticeRunner .request_quit is False
261286 and work_queue .empty () is False
@@ -268,7 +293,13 @@ def worker(work_queue: Queue, result_queue: Queue):
268293 episode_reward ,
269294 is_win ,
270295 problem ,
271- ) = self .execute_episode (episode , game , predictor , ** args )
296+ ) = self .execute_episode (
297+ episode ,
298+ game ,
299+ predictor ,
300+ is_verbose_worker = worker_idx == 0 ,
301+ ** args ,
302+ )
272303 except KeyboardInterrupt :
273304 break
274305 except Exception as e :
@@ -292,7 +323,7 @@ def worker(work_queue: Queue, result_queue: Queue):
292323 for i , args in enumerate (episode_args_list ):
293324 work_queue .put ((i , args ))
294325 processes = [
295- Process (target = worker , args = (work_queue , result_queue ), daemon = True )
326+ Process (target = worker , args = (i , work_queue , result_queue ))
296327 for i in range (self .config .num_workers )
297328 ]
298329 for proc in processes :
0 commit comments