2626from .. import action_selectors
2727from ..episode_memory import EpisodeMemory
2828from ..mcts import MCTS
29- from ..policy_value_model import PolicyValueModel , get_or_create_policy_model
29+ from ..policy_value_model import ThincPolicyValueModel , get_or_create_policy_model
3030from ..trfl import discrete_policy_entropy_loss , td_lambda
3131from .config import A3CConfig
3232from .util import record , truncate
@@ -50,7 +50,7 @@ def __init__(
5050 self ,
5151 args : A3CConfig ,
5252 action_size : int ,
53- global_model : PolicyValueModel ,
53+ global_model : ThincPolicyValueModel ,
5454 optimizer ,
5555 greedy_epsilon : Union [float , List [float ]],
5656 result_queue : Queue ,
@@ -132,7 +132,7 @@ def run(self):
132132 if win_pct is not None :
133133 with self .writer .as_default ():
134134 student = self .teacher .students [self .worker_idx ]
135- step = self .global_model .optimizer .iterations
135+ step = self .global_model .unwrapped . optimizer .iterations
136136 if self .worker_idx == 0 :
137137 tf .summary .scalar (
138138 f"win_rate/{ student .topic } " , data = win_pct , step = step
@@ -216,7 +216,7 @@ def build_episode_selector(
216216 )
217217 return selector
218218
219- def run_episode (self , episode_memory : EpisodeMemory ):
219+ def run_episode (self , episode_memory : EpisodeMemory ) -> float :
220220 env_name = self .teacher .get_env (self .worker_idx , self .iteration )
221221 env = gym .make (env_name , ** self .env_extra )
222222 episode_memory .clear ()
@@ -233,16 +233,16 @@ def run_episode(self, episode_memory: EpisodeMemory):
233233 selector = self .build_episode_selector (env )
234234
235235 # Set RNN to 0 state for start of episode
236- selector .model .embedding .reset_rnn_state ()
236+ selector .model .unwrapped . embedding .reset_rnn_state ()
237237
238238 # Start with the "init" sequence [n] times
239239 for i in range (self .args .num_thinking_steps_begin ):
240- rnn_state_h = tf .squeeze (selector .model .embedding .state_h .numpy ())
241- rnn_state_c = tf .squeeze (selector .model .embedding .state_c .numpy ())
240+ rnn_state_h = tf .squeeze (selector .model .unwrapped . embedding .state_h .numpy ())
241+ rnn_state_c = tf .squeeze (selector .model .unwrapped . embedding .state_c .numpy ())
242242 seq_start = env .state .to_start_observation (rnn_state_h , rnn_state_c )
243243 try :
244244 window = observations_to_window ([seq_start , last_observation ])
245- selector .model . call ( window .to_inputs ())
245+ selector .model ([ window .to_inputs ()], is_train = True )
246246 except BaseException as err :
247247 print_error (
248248 err , f"Episode begin thinking steps prediction failed." ,
@@ -253,8 +253,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
253253 if self .args .print_training and self .worker_idx == 0 :
254254 env .render (self .args .print_mode , None )
255255 # store rnn state for replay training
256- rnn_state_h = tf .squeeze (selector .model .embedding .state_h .numpy ())
257- rnn_state_c = tf .squeeze (selector .model .embedding .state_c .numpy ())
256+ rnn_state_h = tf .squeeze (selector .model .unwrapped . embedding .state_h .numpy ())
257+ rnn_state_c = tf .squeeze (selector .model .unwrapped . embedding .state_c .numpy ())
258258 rnn_history_h = episode_memory .rnn_weighted_history (self .args .lstm_units )[0 ]
259259 last_rnn_state = [rnn_state_h , rnn_state_c ]
260260
@@ -269,8 +269,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
269269 rnn_state_c = tf .squeeze (rnn_state_c ),
270270 rnn_history_h = rnn_history_h ,
271271 )
272- # before_rnn_state_h = selector.model.embedding.state_h.numpy()
273- # before_rnn_state_c = selector.model.embedding.state_c.numpy()
272+ # before_rnn_state_h = selector.model.unwrapped. embedding.state_h.numpy()
273+ # before_rnn_state_c = selector.model.unwrapped. embedding.state_c.numpy()
274274
275275 window = episode_memory .to_window_observation (last_observation )
276276 try :
@@ -287,8 +287,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
287287
288288 # Take an env step
289289 observation , reward , done , _ = env .step (action )
290- rnn_state_h = tf .squeeze (selector .model .embedding .state_h .numpy ())
291- rnn_state_c = tf .squeeze (selector .model .embedding .state_c .numpy ())
290+ rnn_state_h = tf .squeeze (selector .model .unwrapped . embedding .state_h .numpy ())
291+ rnn_state_c = tf .squeeze (selector .model .unwrapped . embedding .state_c .numpy ())
292292
293293 # TODO: make this a unit test, check that EpisodeMemory states are not equal
294294 # across time steps.
@@ -379,7 +379,7 @@ def maybe_write_episode_summaries(
379379 assert self .worker_idx == 0 , "only write summaries for greedy worker"
380380 # Track metrics for all workers
381381 name = self .teacher .get_env (self .worker_idx , self .iteration )
382- step = self .global_model .optimizer .iterations
382+ step = self .global_model .unwrapped . optimizer .iterations
383383 with self .writer .as_default ():
384384 agent_state = last_state .agent
385385 steps = int (last_state .max_moves - agent_state .moves_remaining )
@@ -396,28 +396,30 @@ def maybe_write_episode_summaries(
396396 step = step ,
397397 )
398398
399- def maybe_write_histograms (self ):
399+ def maybe_write_histograms (self ) -> None :
400400 if self .worker_idx != 0 :
401401 return
402- step = self .global_model .optimizer .iterations .numpy ()
402+ step = self .global_model .unwrapped . optimizer .iterations .numpy ()
403403 next_write = self .last_histogram_write + self .args .summary_interval
404404 if step >= next_write or self .last_histogram_write == - 1 :
405405 with self .writer .as_default ():
406406 self .last_histogram_write = step
407- for var in self .local_model .trainable_variables :
407+ for var in self .local_model .unwrapped . trainable_variables :
408408 tf .summary .histogram (
409- var .name , var , step = self .global_model .optimizer .iterations
409+ var .name ,
410+ var ,
411+ step = self .global_model .unwrapped .optimizer .iterations ,
410412 )
411413 # Write out current LSTM hidden/cell states
412414 tf .summary .histogram (
413415 "memory/lstm_c" ,
414- self .local_model .embedding .state_c ,
415- step = self .global_model .optimizer .iterations ,
416+ self .local_model .unwrapped . embedding .state_c ,
417+ step = self .global_model .unwrapped . optimizer .iterations ,
416418 )
417419 tf .summary .histogram (
418420 "memory/lstm_h" ,
419- self .local_model .embedding .state_h ,
420- step = self .global_model .optimizer .iterations ,
421+ self .local_model .unwrapped . embedding .state_h ,
422+ step = self .global_model .unwrapped . optimizer .iterations ,
421423 )
422424
423425 def update_global_network (
@@ -442,10 +444,10 @@ def update_global_network(
442444 self .ep_aux_loss [k ] = 0.0
443445 self .ep_aux_loss [k ] += aux_losses [k ].numpy ()
444446 # Calculate local gradients
445- grads = tape .gradient (total_loss , self .local_model .trainable_weights )
447+ grads = tape .gradient (total_loss , self .local_model .unwrapped . trainable_weights )
446448 # Push local gradients to global model
447449
448- zipped_gradients = zip (grads , self .global_model .trainable_weights )
450+ zipped_gradients = zip (grads , self .global_model .unwrapped . trainable_weights )
449451 # Assert that we always have some gradient flow in each trainable var
450452
451453 # TODO: Make this a unit test. It degrades performance at train time
@@ -460,7 +462,7 @@ def update_global_network(
460462
461463 self .optimizer .apply_gradients (zipped_gradients )
462464 # Update local model with new weights
463- self .local_model .set_weights (self .global_model .get_weights ())
465+ self .local_model .unwrapped . set_weights (self .global_model . unwrapped .get_weights ())
464466 episode_memory .clear ()
465467
466468 def finish_episode (self , episode_reward , episode_steps , last_state : MathyEnvState ):
@@ -487,7 +489,7 @@ def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvStat
487489 episode_reward , episode_steps , last_state
488490 )
489491
490- step = self .global_model .optimizer .iterations .numpy ()
492+ step = self .global_model .unwrapped . optimizer .iterations .numpy ()
491493 next_write = self .last_model_write + A3CWorker .save_every_n_episodes
492494 if step >= next_write or self .last_model_write == - 1 :
493495 self .last_model_write = step
@@ -512,12 +514,12 @@ def compute_policy_value_loss(
512514 episode_memory : EpisodeMemory ,
513515 gamma = 0.99 ,
514516 ):
515- step = self .global_model .optimizer .iterations
517+ step = self .global_model .unwrapped . optimizer .iterations
516518 if done :
517519 bootstrap_value = 0.0 # terminal
518520 else :
519521 # Predict the reward using the local network
520- _ , values , _ = self .local_model .call (
522+ _ , values , _ = self .local_model .unwrapped . call (
521523 observations_to_window ([observation ]).to_inputs ()
522524 )
523525 # Select the last timestep
@@ -536,7 +538,8 @@ def compute_policy_value_loss(
536538 batch_size = len (episode_memory .actions )
537539 sequence_length = len (episode_memory .observations [0 ].nodes )
538540 inputs = episode_memory .to_episode_window ().to_inputs ()
539- logits , values , trimmed_logits = self .local_model (inputs , apply_mask = False )
541+ logits , values , trimmed_logits = self .local_model .unwrapped (inputs , apply_mask = False )
542+ # TODO: don't call unwrapped here
540543
541544 logits = tf .reshape (logits , [batch_size , - 1 ])
542545 policy_logits = tf .reshape (trimmed_logits , [batch_size , - 1 ])
@@ -615,7 +618,7 @@ def compute_loss(
615618 gamma = 0.99 ,
616619 ):
617620 with self .writer .as_default ():
618- step = self .global_model .optimizer .iterations
621+ step = self .global_model .unwrapped . optimizer .iterations
619622 loss_tuple = self .compute_policy_value_loss (
620623 done , observation , episode_memory
621624 )
0 commit comments