Skip to content

Commit 0c76609

Browse files
feat(MathyWindowObservation): add option to return inputs using numpy instead of tf.Tensor
1 parent 95e764e commit 0c76609

File tree

12 files changed

+295
-42
lines changed

12 files changed

+295
-42
lines changed

libraries/mathy_python/mathy/agents/a3c/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,18 @@ def __init__(self, args: A3CConfig, env_extra: dict = None):
3939
self.writer = tf.summary.create_file_writer(self.log_dir)
4040
init_window = env.initial_window(self.args.lstm_units)
4141
self.global_model = get_or_create_policy_model(
42-
args=args, env_actions=self.action_size, is_main=True, env=env.mathy
42+
args=args, predictions=self.action_size, is_main=True, env=env.mathy
4343
)
4444
with self.writer.as_default():
4545
tf_model: tf.keras.Model = self.global_model.unwrapped
4646
tf.summary.trace_on(graph=True)
4747
inputs = init_window.to_inputs()
48-
tf_model.call_graph(inputs)
48+
49+
@tf.function
50+
def trace_fn():
51+
return tf_model.call(inputs)
52+
53+
trace_fn()
4954
tf.summary.trace_export(
5055
name="PolicyValueModel", step=0, profiler_outdir=self.log_dir
5156
)

libraries/mathy_python/mathy/agents/embedding.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,11 @@ def reset_rnn_state(self, force: bool = False):
135135
if self.episode_reset_state_c or force is True:
136136
self.state_c.assign(tf.zeros([1, self.config.lstm_units]))
137137

138-
@tf.function
139-
def call_graph(self, inputs: MathyInputsType) -> tf.Tensor:
140-
"""Autograph optimized function"""
141-
return self.call(inputs)
142-
143-
def call(self, features: MathyInputsType) -> tf.Tensor:
144-
nodes = tf.convert_to_tensor(features[ObservationFeatureIndices.nodes])
145-
values = tf.convert_to_tensor(features[ObservationFeatureIndices.values])
146-
type = tf.cast(features[ObservationFeatureIndices.type], dtype=tf.float32)
147-
time = tf.cast(features[ObservationFeatureIndices.time], dtype=tf.float32)
138+
def call(self, features: MathyInputsType, train: tf.Tensor = None) -> tf.Tensor:
139+
nodes = features[ObservationFeatureIndices.nodes]
140+
values = features[ObservationFeatureIndices.values]
141+
type = features[ObservationFeatureIndices.type]
142+
time = features[ObservationFeatureIndices.time]
148143
nodes_shape = tf.shape(features[ObservationFeatureIndices.nodes])
149144
batch_size = nodes_shape[0] # noqa
150145
sequence_length = nodes_shape[1]

libraries/mathy_python/mathy/agents/mcts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def search(self, env_state: MathyEnvState, rnn_state: List[Any], isRootNode=False):
129129
rnn_history_h=rnn_state[0],
130130
)
131131
observations = observations_to_window([obs]).to_inputs()
132-
out_policy, state_v = self.model.predict_next(observations, use_graph=False)
132+
out_policy, state_v = self.model.predict_next(observations)
133133
out_rnn_state = [
134134
tf.squeeze(self.model.unwrapped.embedding.state_h).numpy(),
135135
tf.squeeze(self.model.unwrapped.embedding.state_c).numpy(),

libraries/mathy_python/mathy/agents/policy_value_model.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88
import srsly
99
import tensorflow as tf
10-
from tensorflow.python.keras import backend as K
10+
from tensorflow.keras import backend as K
1111
from wasabi import msg
1212

1313
from ..env import MathyEnv
@@ -124,13 +124,6 @@ def apply_pi_mask(
124124
)
125125
return negative_mask_logits
126126

127-
@tf.function
128-
def call_graph(
129-
self, inputs: MathyInputsType
130-
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
131-
"""Autograph optimized function"""
132-
return self.call(inputs)
133-
134127

135128
class ThincPolicyValueModel(thinc.model.Model[ArrayNd, Tuple[Array1d, Array2d]]):
136129
@property
@@ -139,15 +132,10 @@ def unwrapped(self) -> TFPVModel:
139132
assert isinstance(tf_shim, TensorFlowShim), "only tensorflow shim is supported"
140133
return tf_shim._model
141134

142-
def predict_next(
143-
self, inputs: MathyInputsType, use_graph: bool = False
144-
) -> Tuple[tf.Tensor, tf.Tensor]:
135+
def predict_next(self, inputs: MathyInputsType) -> Tuple[tf.Tensor, tf.Tensor]:
145136
"""Predict one probability distribution and value for the
146137
given sequence of inputs """
147-
if use_graph:
148-
logits, values, masked = self.unwrapped.call_graph(inputs)
149-
else:
150-
logits, values, masked = self.unwrapped.call(inputs)
138+
logits, values, masked = self.unwrapped.call(inputs)
151139
# take the last timestep
152140
masked = masked[-1][:]
153141
flat_logits = tf.reshape(tf.squeeze(masked), [-1])
@@ -198,7 +186,7 @@ def _load_model(
198186

199187
def get_or_create_policy_model(
200188
args: BaseConfig,
201-
env_actions: int,
189+
predictions: int,
202190
is_main=False,
203191
required=False,
204192
env: MathyEnv = None,
@@ -235,7 +223,7 @@ def get_or_create_policy_model(
235223
print_error=False,
236224
)
237225

238-
model = PolicyValueModel(args=args, predictions=env_actions, name="agent")
226+
model = PolicyValueModel(args=args, predictions=predictions, name="agent")
239227
init_inputs = initial_state.to_inputs()
240228

241229
def handshake_keras(m: ThincPolicyValueModel):

libraries/mathy_python/mathy/agents/zero/self_play_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_model(self, game):
3939
)
4040

4141
model: PolicyValueModel = get_or_create_policy_model(
42-
args=config, env_actions=game.action_space.n, is_main=True
42+
args=config, predictions=game.action_space.n, is_main=True
4343
)
4444
return model
4545

libraries/mathy_python/mathy/state.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,23 @@ class MathyWindowObservation(NamedTuple):
7777
rnn_state_c: WindowRNNStateFloatList
7878
rnn_history_h: WindowRNNStateFloatList
7979

80-
def to_inputs(self) -> MathyInputsType:
80+
def to_inputs(self, as_tf_tensor: bool = True) -> MathyInputsType:
8181
import tensorflow as tf
8282

83+
def to_res(in_value):
84+
if as_tf_tensor is True:
85+
return tf.convert_to_tensor(in_value, dtype=tf.float32)
86+
return np.asarray(in_value, dtype="float32")
87+
8388
result = [
84-
tf.convert_to_tensor(self.nodes),
85-
tf.convert_to_tensor(self.mask),
86-
tf.convert_to_tensor(self.values),
87-
tf.convert_to_tensor(self.type),
88-
tf.convert_to_tensor(self.time),
89-
tf.convert_to_tensor(self.rnn_state_h),
90-
tf.convert_to_tensor(self.rnn_state_c),
91-
tf.convert_to_tensor(self.rnn_history_h),
89+
to_res(self.nodes),
90+
to_res(self.mask),
91+
to_res(self.values),
92+
to_res(self.type),
93+
to_res(self.time),
94+
to_res(self.rnn_state_h),
95+
to_res(self.rnn_state_c),
96+
to_res(self.rnn_history_h),
9297
]
9398
for r in result:
9499
for s in r.shape:

libraries/website/docs/snippets/ml/policy_value_serialization.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
"instance.train()\n",
3939
"# Load the model back in\n",
4040
"model_two = get_or_create_policy_model(\n",
41-
" args=args, env_actions=PolySimplify().action_size, is_main=True\n",
41+
" args=args, predictions=PolySimplify().action_size, is_main=True\n",
4242
")\n",
4343
"# Comment this out to keep your model\n",
4444
"shutil.rmtree(model_folder)"

libraries/website/docs/snippets/ml/policy_value_serialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
instance.train()
2727
# Load the model back in
2828
model_two = get_or_create_policy_model(
29-
args=args, env_actions=PolySimplify().action_size, is_main=True
29+
args=args, predictions=PolySimplify().action_size, is_main=True
3030
)
3131
# Comment this out to keep your model
3232
shutil.rmtree(model_folder)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"# This file is generated from a Mathy (https://mathy.ai) code example.\n",
12+
"!pip install mathy --upgrade\n",
13+
"import os\n",
14+
"import shutil\n",
15+
"import tempfile\n",
16+
"from typing import List, Tuple\n",
17+
"\n",
18+
"import numpy as np\n",
19+
"\n",
20+
"from mathy import envs\n",
21+
"from mathy.agents.base_config import BaseConfig\n",
22+
"from mathy.agents.embedding import MathyEmbedding\n",
23+
"from mathy.env import MathyEnv\n",
24+
"from mathy.state import MathyEnvState, MathyObservation, observations_to_window\n",
25+
"from thinc.api import TensorFlowWrapper\n",
26+
"from thinc.layers import (\n",
27+
" Embed,\n",
28+
" Linear,\n",
29+
" MeanPool,\n",
30+
" ReLu,\n",
31+
" Softmax,\n",
32+
" chain,\n",
33+
" list2ragged,\n",
34+
" with_array,\n",
35+
" with_list,\n",
36+
")\n",
37+
"from thinc.model import Model\n",
38+
"from thinc.shims.tensorflow import TensorFlowShim\n",
39+
"from thinc.types import Array, Array1d, Array2d, ArrayNd\n",
40+
"\n",
41+
"# Mathy env setup and initial observations\n",
42+
"args = BaseConfig()\n",
43+
"env: MathyEnv = envs.PolySimplify()\n",
44+
"state: MathyEnvState = env.get_initial_state()[0]\n",
45+
"observation: MathyObservation = env.state_to_observation(\n",
46+
" state, rnn_size=args.lstm_units\n",
47+
")\n",
48+
"window = observations_to_window([observation, observation])\n",
49+
"inputs = window.to_inputs()\n",
50+
"\n",
51+
"X = [inputs] # TODO: why do I need to wrap inputs for the tf wrapper?\n",
52+
"input_shape = window.to_input_shapes()\n",
53+
"embeddings = TensorFlowWrapper(MathyEmbedding(args), input_shape=input_shape)\n",
54+
"embeddings.initialize([inputs])\n",
55+
"\n",
56+
"embed_Y = embeddings.predict([inputs])\n",
57+
"# Shape = (2, 23, 128) = (num_observations, padded_sequence_len, vector_width)\n",
58+
"\n",
59+
"# The policy head is a softmax(actions) for each node in the sequence.\n",
60+
"policy_head = chain(embeddings, Softmax(6))\n",
61+
"policy_head.initialize([inputs])\n",
62+
"policy_Y = policy_head.predict([inputs])\n",
63+
"# Shape (desired) = (2, 23, 6)\n",
64+
"\n",
65+
"# The value head is normally a linear transformation from the\n",
66+
"# output embedding layer's RNN state. I haven't tried mixing\n",
67+
"# that tensor in here. TODO: try that\n",
68+
"value_head = chain(embeddings, MeanPool(), Linear(1))\n",
69+
"value_head.initialize([inputs])\n",
70+
"value_Y = value_head.predict([inputs])\n",
71+
"# Shape (desired) = (2, 1)\n",
72+
"\n",
73+
"# Combined [policy_head, value_head] outputs without invoking embeddings twice?\n",
74+
"model: Model[ArrayNd, Tuple[Array2d, Array1d]] = ...\n",
75+
"\n",
76+
"model.initialize([inputs])\n",
77+
"\n",
78+
"Y = model.predict([inputs])\n",
79+
"model.to_disk(f\"training/model\")"
80+
]
81+
}
82+
],
83+
"metadata": {},
84+
"nbformat": 4,
85+
"nbformat_minor": 2
86+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import os
2+
import shutil
3+
import tempfile
4+
from typing import List, Tuple
5+
6+
import numpy as np
7+
8+
from mathy import envs
9+
from mathy.agents.base_config import BaseConfig
10+
from mathy.agents.embedding import MathyEmbedding
11+
from mathy.env import MathyEnv
12+
from mathy.state import MathyEnvState, MathyObservation, observations_to_window
13+
from thinc.api import TensorFlowWrapper
14+
from thinc.layers import (
15+
Embed,
16+
Linear,
17+
MeanPool,
18+
ReLu,
19+
Softmax,
20+
chain,
21+
list2ragged,
22+
with_array,
23+
with_list,
24+
)
25+
from thinc.model import Model
26+
from thinc.shims.tensorflow import TensorFlowShim
27+
from thinc.types import Array, Array1d, Array2d, ArrayNd
28+
29+
# Mathy env setup and initial observations
30+
args = BaseConfig()
31+
env: MathyEnv = envs.PolySimplify()
32+
state: MathyEnvState = env.get_initial_state()[0]
33+
observation: MathyObservation = env.state_to_observation(
34+
state, rnn_size=args.lstm_units
35+
)
36+
window = observations_to_window([observation, observation])
37+
inputs = window.to_inputs()
38+
39+
X = [inputs] # TODO: why do I need to wrap inputs for the tf wrapper?
40+
input_shape = window.to_input_shapes()
41+
embeddings = TensorFlowWrapper(MathyEmbedding(args), input_shape=input_shape)
42+
embeddings.initialize([inputs])
43+
44+
embed_Y = embeddings.predict([inputs])
45+
# Shape = (2, 23, 128) = (num_observations, padded_sequence_len, vector_width)
46+
47+
# The policy head is a softmax(actions) for each node in the sequence.
48+
policy_head = chain(embeddings, Softmax(6))
49+
policy_head.initialize([inputs])
50+
policy_Y = policy_head.predict([inputs])
51+
# Shape (desired) = (2, 23, 6)
52+
53+
# The value head is normally a linear transformation from the
54+
# output embedding layer's RNN state. I haven't tried mixing
55+
# that tensor in here. TODO: try that
56+
value_head = chain(embeddings, MeanPool(), Linear(1))
57+
value_head.initialize([inputs])
58+
value_Y = value_head.predict([inputs])
59+
# Shape (desired) = (2, 1)
60+
61+
# Combined [policy_head, value_head] outputs without invoking embeddings twice?
62+
model: Model[ArrayNd, Tuple[Array2d, Array1d]] = ...
63+
64+
model.initialize([inputs])
65+
66+
Y = model.predict([inputs])
67+
model.to_disk(f"training/model")

0 commit comments

Comments
 (0)