Skip to content

Commit d4095c5

Browse files
feat(config): add prediction_window_size
- Controls the number of timesteps to include when calling the policy/value model. Defaults to 7, a somewhat arbitrary choice; I think it's the number of frames that AlphaZero stacked..? 🤷
1 parent ee77ae5 commit d4095c5

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

libraries/mathy_python/mathy/agents/a3c/worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,9 @@ def run_episode(self, episode_memory: EpisodeMemory) -> float:
273273
# before_rnn_state_h = selector.model.unwrapped.embedding.state_h.numpy()
274274
# before_rnn_state_c = selector.model.unwrapped.embedding.state_c.numpy()
275275

276-
window = episode_memory.to_window_observation(last_observation)
276+
window = episode_memory.to_window_observation(
277+
last_observation, window_size=self.args.prediction_window_size
278+
)
277279
try:
278280
action, value = selector.select(
279281
last_state=env.state,

libraries/mathy_python/mathy/agents/base_config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class Config:
2020

2121
# One of "batch" or "layer"
2222
normalization_style = "layer"
23+
# The number of previous timesteps to pass in with the current one
24+
# when making predictions.
25+
prediction_window_size: int = 7
2326

2427
# Whether to use the LSTM or non-recurrent architecture
2528
use_lstm: bool = True
@@ -62,7 +65,7 @@ class Config:
6265
clip_grouping_control = True
6366

6467
# Include the time/type environment features in the embeddings
65-
use_env_features = False
68+
use_env_features = True
6669

6770
# Include the node values floating point features in the embeddings
6871
use_node_values = True

0 commit comments

Comments
 (0)