
Commit 910bcd6

fix(a3c): use episode outcome for log coloring
- Previously a win was inferred from the total episode reward, but a winning episode does not always finish with a reward > 0, especially for long episodes.
- Pass the win info along and use it to control the log color instead of the total reward.
1 parent d000a88 commit 910bcd6
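To make the first point concrete (the numbers below are made up for illustration, not taken from mathy): a long episode can end in a win and still have a negative total reward once many small per-step penalties are summed, so coloring on episode_reward > 0.0 paints a solved problem red, while the terminal win flag gets it right.

    # Illustrative values only: a long winning episode with small step penalties.
    step_rewards = [-0.05] * 40 + [1.0]  # 40 timestep penalties plus a terminal bonus
    episode_reward = sum(step_rewards)   # -1.0: negative even though the episode was won
    is_win = True

    old_color = "green" if episode_reward > 0.0 else "red"  # "red" -- misleading
    new_color = "green" if is_win else "red"                 # "green" -- matches the outcome
    print(episode_reward, old_color, new_color)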

File tree

3 files changed, with 36 additions and 21 deletions


libraries/mathy_python/mathy/agents/a3c/util.py

Lines changed: 15 additions & 12 deletions
@@ -1,9 +1,9 @@
-from typing import Dict
-from colr import color
 import datetime
-
-
+import multiprocessing
 from dataclasses import dataclass, field
+from typing import Dict
+
+from colr import color
 
 
 @dataclass
@@ -35,14 +35,15 @@ def truncate(value):
 
 
 def record(
-    episode,
-    episode_reward,
-    worker_idx,
-    global_ep_reward,
-    result_queue,
+    episode: int,
+    is_win: bool,
+    episode_reward: float,
+    worker_idx: int,
+    global_ep_reward: float,
+    result_queue: multiprocessing.Queue,
     losses: EpisodeLosses,
-    num_steps,
-    env_name,
+    num_steps: int,
+    env_name: str,
 ):
     """Helper function to store score and print statistics.
     Arguments:
@@ -56,10 +57,12 @@ def record(
     """
 
     now = datetime.datetime.now().strftime("%H:%M:%S")
+    # Clamp to range -2, 2
+    episode_reward = min(2.0, max(-2.0, episode_reward))
 
     global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01
 
-    fore = "green" if episode_reward > 0.0 else "red"
+    fore = "green" if is_win else "red"
     heading = "{:<8} {:<3} {:<8} {:<10}".format(
         now, f"w{worker_idx}", f"ep: {episode}", f"steps: {num_steps}"
    )
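Pulled out of the diff above, the updated coloring and averaging logic in record() reads roughly as follows. This is a simplified sketch: the result_queue, losses, and printing bookkeeping are omitted, so record_sketch and its trimmed signature are not mathy's actual API.

    import datetime

    def record_sketch(episode, is_win, episode_reward, worker_idx, global_ep_reward, num_steps):
        # Simplified view of record() after this commit (queue/loss handling omitted).
        now = datetime.datetime.now().strftime("%H:%M:%S")
        # Clamp the reward used for the moving average to [-2, 2]
        episode_reward = min(2.0, max(-2.0, episode_reward))
        global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01
        # Color by episode outcome rather than by the (possibly negative) total reward
        fore = "green" if is_win else "red"
        heading = "{:<8} {:<3} {:<8} {:<10}".format(
            now, f"w{worker_idx}", f"ep: {episode}", f"steps: {num_steps}"
        )
        return heading, fore, global_ep_reward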

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 18 additions & 8 deletions
@@ -246,7 +246,7 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
                 continue
 
             # Take an env step
-            observation, reward, done, _ = env.step(action)
+            observation, reward, done, last_obs_info = env.step(action)
 
             observation = MathyObservation(
                 nodes=observation.nodes,
@@ -262,10 +262,7 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
             )
             ep_reward += reward
             episode_memory.store(
-                observation=last_observation,
-                action=action,
-                reward=reward,
-                value=value,
+                observation=last_observation, action=action, reward=reward, value=value,
             )
             if time_count == self.args.update_gradients_every or done:
                 if done and self.args.print_training and self.worker_idx == 0:
@@ -275,7 +272,9 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
                     self.maybe_write_histograms()
                 time_count = 0
                 if done:
-                    self.finish_episode(ep_reward, ep_steps, env.state)
+                    self.finish_episode(
+                        last_obs_info.get("win", False), ep_reward, ep_steps, env.state
+                    )
 
             ep_steps += 1
             time_count += 1
@@ -383,15 +382,24 @@ def update_global_network(
 
         if done:
             episode_memory.clear()
+        else:
+            episode_memory.clear_except_window(self.args.prediction_window_size)
 
-    def finish_episode(self, episode_reward, episode_steps, last_state: MathyEnvState):
+    def finish_episode(
+        self,
+        is_win: bool,
+        episode_reward: float,
+        episode_steps: int,
+        last_state: MathyEnvState,
+    ):
         env_name = self.teacher.get_env(self.worker_idx, self.iteration)
 
         # Only observe/track the most-greedy worker (high epsilon exploration
         # stats are unlikely to be consistent or informative)
         if self.worker_idx == 0:
             A3CWorker.global_moving_average_reward = record(
                 A3CWorker.global_episode,
+                is_win,
                 episode_reward,
                 self.worker_idx,
                 A3CWorker.global_moving_average_reward,
@@ -490,7 +498,9 @@ def compute_policy_value_loss(
 
         policy_loss *= advantage
         policy_loss = tf.reduce_mean(policy_loss)
-
+        if self.args.normalize_pi_loss:
+            policy_loss /= sequence_length
+        # Scale the policy loss down by the seq_len to make invariant to length
         total_loss = value_loss + policy_loss + entropy_loss + rp_loss
         prefix = self.tb_prefix
         tf.summary.scalar(f"{prefix}/policy_loss", data=policy_loss, step=step)

libraries/mathy_python/mathy/envs/gym/mathy_gym_env.py

Lines changed: 3 additions & 1 deletion
@@ -66,8 +66,10 @@ def action_size(self) -> int:
     def step(self, action):
         self.state, transition, change = self.mathy.get_next_state(self.state, action)
         observation = self._observe(self.state)
-        info = {"transition": transition}
         done = is_terminal_transition(transition)
+        info = {"transition": transition, "done": done}
+        if done:
+            info["win"] = transition.reward > 0.0
         self.last_action = action
         self.last_change = change
         self.last_reward = round(float(transition.reward), 4)
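Downstream code can rely on the step() contract above: info always carries "done", and "win" only appears on terminal transitions, which is why the worker reads it with .get("win", False). A minimal sketch of that consumer pattern, using a stub environment rather than mathy's real gym env:

    class StubEnv:
        # Stand-in that mimics only the info contract added above; not mathy's env.
        def __init__(self, steps_to_win=3):
            self.remaining = steps_to_win

        def step(self, action):
            self.remaining -= 1
            done = self.remaining <= 0
            reward = 1.0 if done else -0.01
            info = {"transition": None, "done": done}
            if done:
                info["win"] = reward > 0.0
            return None, reward, done, info

    env = StubEnv()
    done = False
    ep_reward = 0.0
    while not done:
        observation, reward, done, info = env.step(action=0)
        ep_reward += reward

    is_win = info.get("win", False)  # same .get() default the A3C worker uses
    print("win" if is_win else "loss", round(ep_reward, 2))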
