|
39 | 39 |
|
40 | 40 |
|
41 | 41 | @keras_subclass( |
42 | | - "TFPVModel.v0", X=eg.to_inputs(), Y=eg.mask, input_shape=eg.to_input_shapes() |
| 42 | + "TFPVModel.v0", |
| 43 | + X=eg.to_inputs(as_tf_tensor=False), |
| 44 | + Y=eg.mask, |
| 45 | + input_shape=eg.to_input_shapes(), |
| 46 | + compile_args={"loss": "binary_crossentropy", "metrics": ["accuracy"]}, |
43 | 47 | ) |
44 | 48 | class TFPVModel(tf.keras.Model): |
45 | 49 | args: BaseConfig |
@@ -183,23 +187,16 @@ def PolicyValueModel( |
183 | 187 | tf_model = TFPVModel(args, predictions, **kwargs) |
184 | 188 | return TensorFlowWrapper( |
185 | 189 | tf_model, |
186 | | - build_model=True, |
187 | | - input_shape=eg.to_input_shapes(), |
188 | 190 | model_class=ThincPolicyValueModel, |
189 | 191 | model_name="agent", |
| 192 | + optimizer=tf_model.optimizer, |
190 | 193 | ) |
191 | 194 |
|
192 | 195 |
|
193 | 196 | def _load_model( |
194 | | - model: ThincPolicyValueModel, |
195 | | - model_file: str, |
196 | | - optimizer_file: str, |
197 | | - build_fn: Callable[[ThincPolicyValueModel], None] = None, |
| 197 | + model: ThincPolicyValueModel, model_file: str, optimizer_file: str, |
198 | 198 | ) -> ThincPolicyValueModel: |
199 | 199 | model.from_disk(model_file) |
200 | | - if build_fn is not None: |
201 | | - build_fn(model) |
202 | | - model.unwrapped._make_train_function() |
203 | 200 | with open(optimizer_file, "rb") as f: |
204 | 201 | weight_values = pickle.load(f) |
205 | 202 | model.unwrapped.optimizer.set_weights(weight_values) |
@@ -248,28 +245,15 @@ def get_or_create_policy_model( |
248 | 245 | model = PolicyValueModel(args=args, predictions=predictions, name="agent") |
249 | 246 | init_inputs = initial_state.to_inputs() |
250 | 247 |
|
251 | | - def handshake_keras(m: ThincPolicyValueModel): |
252 | | - |
253 | | - m.unwrapped.compile( |
254 | | - optimizer=m.unwrapped.optimizer, |
255 | | - loss="binary_crossentropy", |
256 | | - metrics=["accuracy"], |
257 | | - ) |
258 | | - m.unwrapped.build(initial_state.to_input_shapes()) |
259 | | - m.unwrapped.predict(init_inputs) |
260 | | - m.predict_next(init_inputs) |
261 | | - |
262 | | - handshake_keras(model) |
263 | | - |
264 | 248 | opt = f"{model_path}.optimizer" |
265 | 249 | mod = f"{model_path}.bytes" |
266 | 250 | if os.path.exists(mod): |
267 | 251 | if is_main and args.verbose: |
268 | 252 | with msg.loading(f"Loading model: {mod}..."): |
269 | | - _load_model(model, mod, opt, build_fn=handshake_keras) |
| 253 | + _load_model(model, mod, opt) |
270 | 254 | msg.good(f"Loaded model: {mod}") |
271 | 255 | else: |
272 | | - _load_model(model, mod, opt, build_fn=handshake_keras) |
| 256 | + _load_model(model, mod, opt) |
273 | 257 |
|
274 | 258 | # If we're doing transfer, reset optimizer steps |
275 | 259 | if is_main and args.init_model_from is not None: |
@@ -303,35 +287,17 @@ def load_policy_value_model( |
303 | 287 | raise ValueError(f"model not found: {model_file}") |
304 | 288 | if not optimizer_file.exists(): |
305 | 289 | raise ValueError(f"optimizer not found: {optimizer_file}") |
306 | | - |
307 | 290 | env: MathyEnv = PolySimplify() |
308 | 291 | observation: MathyObservation = env.state_to_observation( |
309 | 292 | env.get_initial_state()[0], rnn_size=args.lstm_units |
310 | 293 | ) |
311 | 294 | initial_state: MathyWindowObservation = observations_to_window([observation]) |
312 | 295 | model = PolicyValueModel(args=args, predictions=env.action_size, name="agent") |
313 | 296 | init_inputs = initial_state.to_inputs() |
314 | | - |
315 | | - def handshake_keras(m: ThincPolicyValueModel): |
316 | | - |
317 | | - m.unwrapped.compile( |
318 | | - optimizer=m.unwrapped.optimizer, |
319 | | - loss="binary_crossentropy", |
320 | | - metrics=["accuracy"], |
321 | | - ) |
322 | | - m.unwrapped.build(initial_state.to_input_shapes()) |
323 | | - m.unwrapped.predict(init_inputs) |
324 | | - m.predict_next(init_inputs) |
325 | | - |
326 | | - handshake_keras(model) |
327 | 297 | if not silent: |
328 | 298 | with msg.loading(f"Loading model: {model_file}..."): |
329 | | - _load_model( |
330 | | - model, str(model_file), str(optimizer_file), build_fn=handshake_keras |
331 | | - ) |
| 299 | + _load_model(model, str(model_file), str(optimizer_file)) |
332 | 300 | msg.good(f"Loaded model: {model_file}") |
333 | 301 | else: |
334 | | - _load_model( |
335 | | - model, str(model_file), str(optimizer_file), build_fn=handshake_keras |
336 | | - ) |
| 302 | + _load_model(model, str(model_file), str(optimizer_file)) |
337 | 303 | return model, args |
0 commit comments