Skip to content

Commit 95e764e

Browse files
feat(tensorflow): update to 2.1.0
1 parent 7f0aac2 commit 95e764e

10 files changed

Lines changed: 174 additions & 95 deletions

File tree

libraries/mathy_python/mathy/agents/a3c/agent.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,15 +42,16 @@ def __init__(self, args: A3CConfig, env_extra: dict = None):
4242
args=args, env_actions=self.action_size, is_main=True, env=env.mathy
4343
)
4444
with self.writer.as_default():
45+
tf_model: tf.keras.Model = self.global_model.unwrapped
4546
tf.summary.trace_on(graph=True)
4647
inputs = init_window.to_inputs()
47-
self.global_model.call_graph(inputs)
48+
tf_model.call_graph(inputs)
4849
tf.summary.trace_export(
4950
name="PolicyValueModel", step=0, profiler_outdir=self.log_dir
5051
)
5152
tf.summary.trace_off()
5253
if self.args.verbose:
53-
print(self.global_model.summary())
54+
print(tf_model.summary())
5455

5556
def train(self):
5657
res_queue = Queue()
@@ -68,7 +69,7 @@ def train(self):
6869
args=self.args,
6970
teacher=self.teacher,
7071
worker_idx=i,
71-
optimizer=self.global_model.optimizer,
72+
optimizer=self.global_model.unwrapped.optimizer,
7273
result_queue=res_queue,
7374
writer=self.writer,
7475
)
@@ -93,3 +94,4 @@ def train(self):
9394
# Do a final save after joining to get the very latest model
9495
self.global_model.save()
9596
print("Done. Bye!")
97+

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 34 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
from .. import action_selectors
2727
from ..episode_memory import EpisodeMemory
2828
from ..mcts import MCTS
29-
from ..policy_value_model import PolicyValueModel, get_or_create_policy_model
29+
from ..policy_value_model import ThincPolicyValueModel, get_or_create_policy_model
3030
from ..trfl import discrete_policy_entropy_loss, td_lambda
3131
from .config import A3CConfig
3232
from .util import record, truncate
@@ -50,7 +50,7 @@ def __init__(
5050
self,
5151
args: A3CConfig,
5252
action_size: int,
53-
global_model: PolicyValueModel,
53+
global_model: ThincPolicyValueModel,
5454
optimizer,
5555
greedy_epsilon: Union[float, List[float]],
5656
result_queue: Queue,
@@ -132,7 +132,7 @@ def run(self):
132132
if win_pct is not None:
133133
with self.writer.as_default():
134134
student = self.teacher.students[self.worker_idx]
135-
step = self.global_model.optimizer.iterations
135+
step = self.global_model.unwrapped.optimizer.iterations
136136
if self.worker_idx == 0:
137137
tf.summary.scalar(
138138
f"win_rate/{student.topic}", data=win_pct, step=step
@@ -216,7 +216,7 @@ def build_episode_selector(
216216
)
217217
return selector
218218

219-
def run_episode(self, episode_memory: EpisodeMemory):
219+
def run_episode(self, episode_memory: EpisodeMemory) -> float:
220220
env_name = self.teacher.get_env(self.worker_idx, self.iteration)
221221
env = gym.make(env_name, **self.env_extra)
222222
episode_memory.clear()
@@ -233,16 +233,16 @@ def run_episode(self, episode_memory: EpisodeMemory):
233233
selector = self.build_episode_selector(env)
234234

235235
# Set RNN to 0 state for start of episode
236-
selector.model.embedding.reset_rnn_state()
236+
selector.model.unwrapped.embedding.reset_rnn_state()
237237

238238
# Start with the "init" sequence [n] times
239239
for i in range(self.args.num_thinking_steps_begin):
240-
rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
241-
rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
240+
rnn_state_h = tf.squeeze(selector.model.unwrapped.embedding.state_h.numpy())
241+
rnn_state_c = tf.squeeze(selector.model.unwrapped.embedding.state_c.numpy())
242242
seq_start = env.state.to_start_observation(rnn_state_h, rnn_state_c)
243243
try:
244244
window = observations_to_window([seq_start, last_observation])
245-
selector.model.call(window.to_inputs())
245+
selector.model([window.to_inputs()], is_train=True)
246246
except BaseException as err:
247247
print_error(
248248
err, f"Episode begin thinking steps prediction failed.",
@@ -253,8 +253,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
253253
if self.args.print_training and self.worker_idx == 0:
254254
env.render(self.args.print_mode, None)
255255
# store rnn state for replay training
256-
rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
257-
rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
256+
rnn_state_h = tf.squeeze(selector.model.unwrapped.embedding.state_h.numpy())
257+
rnn_state_c = tf.squeeze(selector.model.unwrapped.embedding.state_c.numpy())
258258
rnn_history_h = episode_memory.rnn_weighted_history(self.args.lstm_units)[0]
259259
last_rnn_state = [rnn_state_h, rnn_state_c]
260260

@@ -269,8 +269,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
269269
rnn_state_c=tf.squeeze(rnn_state_c),
270270
rnn_history_h=rnn_history_h,
271271
)
272-
# before_rnn_state_h = selector.model.embedding.state_h.numpy()
273-
# before_rnn_state_c = selector.model.embedding.state_c.numpy()
272+
# before_rnn_state_h = selector.model.unwrapped.embedding.state_h.numpy()
273+
# before_rnn_state_c = selector.model.unwrapped.embedding.state_c.numpy()
274274

275275
window = episode_memory.to_window_observation(last_observation)
276276
try:
@@ -287,8 +287,8 @@ def run_episode(self, episode_memory: EpisodeMemory):
287287

288288
# Take an env step
289289
observation, reward, done, _ = env.step(action)
290-
rnn_state_h = tf.squeeze(selector.model.embedding.state_h.numpy())
291-
rnn_state_c = tf.squeeze(selector.model.embedding.state_c.numpy())
290+
rnn_state_h = tf.squeeze(selector.model.unwrapped.embedding.state_h.numpy())
291+
rnn_state_c = tf.squeeze(selector.model.unwrapped.embedding.state_c.numpy())
292292

293293
# TODO: make this a unit test, check that EpisodeMemory states are not equal
294294
# across time steps.
@@ -379,7 +379,7 @@ def maybe_write_episode_summaries(
379379
assert self.worker_idx == 0, "only write summaries for greedy worker"
380380
# Track metrics for all workers
381381
name = self.teacher.get_env(self.worker_idx, self.iteration)
382-
step = self.global_model.optimizer.iterations
382+
step = self.global_model.unwrapped.optimizer.iterations
383383
with self.writer.as_default():
384384
agent_state = last_state.agent
385385
steps = int(last_state.max_moves - agent_state.moves_remaining)
@@ -396,28 +396,30 @@ def maybe_write_episode_summaries(
396396
step=step,
397397
)
398398

399-
def maybe_write_histograms(self):
399+
def maybe_write_histograms(self) -> None:
400400
if self.worker_idx != 0:
401401
return
402-
step = self.global_model.optimizer.iterations.numpy()
402+
step = self.global_model.unwrapped.optimizer.iterations.numpy()
403403
next_write = self.last_histogram_write + self.args.summary_interval
404404
if step >= next_write or self.last_histogram_write == -1:
405405
with self.writer.as_default():
406406
self.last_histogram_write = step
407-
for var in self.local_model.trainable_variables:
407+
for var in self.local_model.unwrapped.trainable_variables:
408408
tf.summary.histogram(
409-
var.name, var, step=self.global_model.optimizer.iterations
409+
var.name,
410+
var,
411+
step=self.global_model.unwrapped.optimizer.iterations,
410412
)
411413
# Write out current LSTM hidden/cell states
412414
tf.summary.histogram(
413415
"memory/lstm_c",
414-
self.local_model.embedding.state_c,
415-
step=self.global_model.optimizer.iterations,
416+
self.local_model.unwrapped.embedding.state_c,
417+
step=self.global_model.unwrapped.optimizer.iterations,
416418
)
417419
tf.summary.histogram(
418420
"memory/lstm_h",
419-
self.local_model.embedding.state_h,
420-
step=self.global_model.optimizer.iterations,
421+
self.local_model.unwrapped.embedding.state_h,
422+
step=self.global_model.unwrapped.optimizer.iterations,
421423
)
422424

423425
def update_global_network(
@@ -442,10 +444,10 @@ def update_global_network(
442444
self.ep_aux_loss[k] = 0.0
443445
self.ep_aux_loss[k] += aux_losses[k].numpy()
444446
# Calculate local gradients
445-
grads = tape.gradient(total_loss, self.local_model.trainable_weights)
447+
grads = tape.gradient(total_loss, self.local_model.unwrapped.trainable_weights)
446448
# Push local gradients to global model
447449

448-
zipped_gradients = zip(grads, self.global_model.trainable_weights)
450+
zipped_gradients = zip(grads, self.global_model.unwrapped.trainable_weights)
449451
# Assert that we always have some gradient flow in each trainable var
450452

451453
# TODO: Make this a unit test. It degrades performance at train time
@@ -460,7 +462,7 @@ def update_global_network(
460462

461463
self.optimizer.apply_gradients(zipped_gradients)
462464
# Update local model with new weights
463-
self.local_model.set_weights(self.global_model.get_weights())
465+
self.local_model.unwrapped.set_weights(self.global_model.unwrapped.get_weights())
464466
episode_memory.clear()
465467

466468
def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvState):
@@ -487,7 +489,7 @@ def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvStat
487489
episode_reward, episode_steps, last_state
488490
)
489491

490-
step = self.global_model.optimizer.iterations.numpy()
492+
step = self.global_model.unwrapped.optimizer.iterations.numpy()
491493
next_write = self.last_model_write + A3CWorker.save_every_n_episodes
492494
if step >= next_write or self.last_model_write == -1:
493495
self.last_model_write = step
@@ -512,12 +514,12 @@ def compute_policy_value_loss(
512514
episode_memory: EpisodeMemory,
513515
gamma=0.99,
514516
):
515-
step = self.global_model.optimizer.iterations
517+
step = self.global_model.unwrapped.optimizer.iterations
516518
if done:
517519
bootstrap_value = 0.0 # terminal
518520
else:
519521
# Predict the reward using the local network
520-
_, values, _ = self.local_model.call(
522+
_, values, _ = self.local_model.unwrapped.call(
521523
observations_to_window([observation]).to_inputs()
522524
)
523525
# Select the last timestep
@@ -536,7 +538,8 @@ def compute_policy_value_loss(
536538
batch_size = len(episode_memory.actions)
537539
sequence_length = len(episode_memory.observations[0].nodes)
538540
inputs = episode_memory.to_episode_window().to_inputs()
539-
logits, values, trimmed_logits = self.local_model(inputs, apply_mask=False)
541+
logits, values, trimmed_logits = self.local_model.unwrapped(inputs, apply_mask=False)
542+
# TODO: don't call unwrapped here
540543

541544
logits = tf.reshape(logits, [batch_size, -1])
542545
policy_logits = tf.reshape(trimmed_logits, [batch_size, -1])
@@ -615,7 +618,7 @@ def compute_loss(
615618
gamma=0.99,
616619
):
617620
with self.writer.as_default():
618-
step = self.global_model.optimizer.iterations
621+
step = self.global_model.unwrapped.optimizer.iterations
619622
loss_tuple = self.compute_policy_value_loss(
620623
done, observation, episode_memory
621624
)

libraries/mathy_python/mathy/agents/action_selectors.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,15 @@
33
import numpy as np
44

55
from ..state import MathyEnvState, MathyWindowObservation
6-
from .policy_value_model import PolicyValueModel
6+
from .policy_value_model import ThincPolicyValueModel
77
from .mcts import MCTS
8+
import thinc
89

910

1011
class ActionSelector:
1112
"""An episode-specific selector of actions"""
1213

13-
def __init__(self, *, model: PolicyValueModel, episode: int, worker_id: int):
14+
def __init__(self, *, model: ThincPolicyValueModel, episode: int, worker_id: int):
1415
self.model = model
1516
self.worker_id = worker_id
1617
self.episode = episode

libraries/mathy_python/mathy/agents/mcts.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy
55

66
from ..env import MathyEnv
7+
from .policy_value_model import ThincPolicyValueModel
78
from ..state import MathyEnvState, observations_to_window
89
from ..util import is_terminal_transition
910

@@ -28,7 +29,7 @@ class MCTS:
2829
def __init__(
2930
self,
3031
env: MathyEnv,
31-
model: Any,
32+
model: ThincPolicyValueModel,
3233
cpuct: float = 1.0,
3334
num_mcts_sims: int = 15,
3435
epsilon: float = 0.25,
@@ -130,8 +131,8 @@ def search(self, env_state: MathyEnvState, rnn_state: List[Any], isRootNode=Fals
130131
observations = observations_to_window([obs]).to_inputs()
131132
out_policy, state_v = self.model.predict_next(observations, use_graph=False)
132133
out_rnn_state = [
133-
tf.squeeze(self.model.embedding.state_h).numpy(),
134-
tf.squeeze(self.model.embedding.state_c).numpy(),
134+
tf.squeeze(self.model.unwrapped.embedding.state_h).numpy(),
135+
tf.squeeze(self.model.unwrapped.embedding.state_c).numpy(),
135136
]
136137
self.Rs[s] = out_rnn_state
137138
self.Ps[s] = out_policy

0 commit comments

Comments
 (0)