11import os
22import pickle
33import time
4- from shutil import copyfile
5- from typing import Any , Dict , Optional , Tuple , List , Callable
64from pathlib import Path
5+ from shutil import copyfile
6+ from typing import Any , Callable , Dict , List , Optional , Tuple , cast
7+ from memory_profiler import profile
8+
79import numpy as np
810import srsly
911import tensorflow as tf
12+ import thinc
1013from tensorflow .keras import backend as K
14+ from thinc .api import TensorFlowWrapper , keras_subclass , tensorflow2xp , xp2tensorflow
15+ from thinc .backends import Ops , get_current_ops
16+ from thinc .layers import Linear
17+ from thinc .model import Model
18+ from thinc .optimizers import Adam
19+ from thinc .shims .tensorflow import TensorFlowShim
20+ from thinc .types import Array , Array1d , Array2d , ArrayNd
21+ from thinc .util import has_tensorflow , to_categorical
1122from wasabi import msg
1223
24+ from mathy .agents .example import example
25+
1326from ..env import MathyEnv
1427from ..envs import PolySimplify
1528from ..state import (
2235from ..util import print_error
2336from .base_config import BaseConfig
2437from .embedding import MathyEmbedding
25- import thinc
26- from thinc .layers import Linear
27- from thinc .api import TensorFlowWrapper , tensorflow2xp , xp2tensorflow
28- from thinc .backends import Ops , get_current_ops
29- from thinc .model import Model
30- from thinc .optimizers import Adam
31- from thinc .types import Array , Array1d , Array2d , ArrayNd
32- from thinc .shims .tensorflow import TensorFlowShim
33- from thinc .util import has_tensorflow , to_categorical
3438
# Module-level example instance. Its inputs, mask and input shapes are handed
# to `keras_subclass` when registering TFPVModel, and its input shapes are
# passed to `TensorFlowWrapper` in `PolicyValueModel`.
# NOTE(review): this runs at import time — confirm `example()` is cheap.
eg = example()
3540
41+
42+ @keras_subclass (
43+ "TFPVModel.v0" , X = eg .to_inputs (), Y = eg .mask , input_shape = eg .to_input_shapes ()
44+ )
3645class TFPVModel (tf .keras .Model ):
3746 args : BaseConfig
3847 optimizer : tf .optimizers .Optimizer
@@ -83,7 +92,13 @@ def compute_output_shape(
8392 )
8493
def call(
    self, features_window: MathyInputsType
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Run the model's forward pass on a window of observation features.

    This is a thin wrapper that forwards to `_call`, keeping the real
    implementation in a separate method (which has a commented-out
    `@profile` decorator above it).

    Args:
        features_window: the observation features to evaluate.

    Returns:
        The 3-tuple produced by `_call`: (logits, values, masked logits).
        The mask is always applied; there is no `apply_mask` switch.
    """
    return self._call(features_window)
98+
99+ # @profile
100+ def _call (
101+ self , features_window : MathyInputsType
87102 ) -> Tuple [tf .Tensor , tf .Tensor , tf .Tensor ]:
88103 call_print = self .args .print_model_call_times
89104 nodes = features_window [ObservationFeatureIndices .nodes ]
@@ -99,10 +114,13 @@ def call(
99114 values = self .normalize_v (self .value_logits (self .embedding .state_h ))
100115 logits = self .normalize_pi (self .policy_logits (sequence_inputs ))
101116 mask_logits = self .apply_pi_mask (logits , features_window )
102- mask_result = logits if not apply_mask else mask_logits
103117 if call_print is True :
104- print ("call took : {0:03f}" .format (time .time () - start ))
105- return logits , values , mask_result
118+ print (
119+ "call took : {0:03f} for batch {1}" .format (
120+ time .time () - start , batch_size
121+ )
122+ )
123+ return logits , values , mask_logits
106124
107125 def apply_pi_mask (
108126 self , logits : tf .Tensor , features_window : MathyInputsType ,
@@ -125,14 +143,18 @@ def apply_pi_mask(
125143 return negative_mask_logits
126144
127145
128- class ThincPolicyValueModel (thinc .model .Model [ArrayNd , Tuple [Array1d , Array2d ]]):
146+ class ThincPolicyValueModel (
147+ thinc .model .Model [ArrayNd , Tuple [Array2d , Array1d , Array2d ]]
148+ ):
@property
def unwrapped(self) -> TFPVModel:
    """Return the underlying keras model held by this model's first shim.

    Raises:
        AssertionError: if the first shim is not a `TensorFlowShim`.
    """
    # `cast` only informs static type checkers; the assert below is the
    # actual runtime check (note asserts are stripped under `python -O`).
    tf_shim = cast(TensorFlowShim, self.shims[0])
    assert isinstance(tf_shim, TensorFlowShim), "only tensorflow shim is supported"
    return tf_shim._model
134154
135- def predict_next (self , inputs : MathyInputsType ) -> Tuple [tf .Tensor , tf .Tensor ]:
155+ def predict_next (
156+ self , inputs : MathyInputsType , is_train : bool = False
157+ ) -> Tuple [tf .Tensor , tf .Tensor ]:
136158 """Predict one probability distribution and value for the
137159 given sequence of inputs """
138160 logits , values , masked = self .unwrapped .call (inputs )
@@ -148,21 +170,22 @@ def save(self) -> None:
148170 model_path = os .path .join (
149171 self .unwrapped .args .model_dir , self .unwrapped .args .model_name
150172 )
173+ save_model_file = f"{ model_path } .bytes"
174+ self .to_disk (save_model_file )
151175 with open (f"{ model_path } .optimizer" , "wb" ) as f :
152176 pickle .dump (self .unwrapped .optimizer .get_weights (), f )
153- model_path += ".h5"
154- self .unwrapped .save_weights (model_path , save_format = "keras" )
155177 step = self .unwrapped .optimizer .iterations .numpy ()
156- print (f"[save] step({ step } ) model({ model_path } )" )
178+ print (f"[save] step({ step } ) model({ save_model_file } )" )
157179
158180
def PolicyValueModel(
    args: Optional[BaseConfig] = None, predictions=2, **kwargs,
):
    """Create a TFPVModel and wrap it as a thinc `ThincPolicyValueModel`.

    Args:
        args: model configuration forwarded to `TFPVModel`.
        predictions: forwarded positionally to `TFPVModel`.
        **kwargs: any extra keyword arguments, forwarded to `TFPVModel`.

    Returns:
        The result of `TensorFlowWrapper(...)` using
        `model_class=ThincPolicyValueModel` and `model_name="agent"`.
    """
    tf_model = TFPVModel(args, predictions, **kwargs)
    # build_model=True together with the example's input shapes — presumably
    # lets the wrapper build the keras model up front so it can be
    # serialized with to_disk/from_disk; confirm against the thinc docs.
    return TensorFlowWrapper(
        tf_model,
        build_model=True,
        input_shape=eg.to_input_shapes(),
        model_class=ThincPolicyValueModel,
        model_name="agent",
    )
@@ -174,7 +197,7 @@ def _load_model(
174197 optimizer_file : str ,
175198 build_fn : Callable [[ThincPolicyValueModel ], None ] = None ,
176199) -> ThincPolicyValueModel :
177- model .unwrapped . load_weights (model_file )
200+ model .from_disk (model_file )
178201 if build_fn is not None :
179202 build_fn (model )
180203 model .unwrapped ._make_train_function ()
@@ -205,8 +228,8 @@ def get_or_create_policy_model(
205228 if is_main and args .init_model_from is not None :
206229 init_model_path = os .path .join (args .init_model_from , args .model_name )
207230 opt = f"{ init_model_path } .optimizer"
208- mod = f"{ init_model_path } .h5 "
209- if os .path .exists (f"{ model_path } .h5 " ):
231+ mod = f"{ init_model_path } .bytes "
232+ if os .path .exists (f"{ model_path } .bytes " ):
210233 print_error (
211234 ValueError ("Model Exists" ),
212235 f"Cannot initialize on top of model: { model_path } " ,
@@ -215,7 +238,7 @@ def get_or_create_policy_model(
215238 if os .path .exists (opt ) and os .path .exists (mod ):
216239 print (f"initialize model from: { init_model_path } " )
217240 copyfile (opt , f"{ model_path } .optimizer" )
218- copyfile (mod , f"{ model_path } .h5 " )
241+ copyfile (mod , f"{ model_path } .bytes " )
219242 else :
220243 print_error (
221244 ValueError ("Model Exists" ),
@@ -240,7 +263,7 @@ def handshake_keras(m: ThincPolicyValueModel):
240263 handshake_keras (model )
241264
242265 opt = f"{ model_path } .optimizer"
243- mod = f"{ model_path } .h5 "
266+ mod = f"{ model_path } .bytes "
244267 if os .path .exists (mod ):
245268 if is_main and args .verbose :
246269 with msg .loading (f"Loading model: { mod } ..." ):
@@ -275,7 +298,7 @@ def load_policy_value_model(
275298 if not meta_file .exists ():
276299 raise ValueError (f"model meta not found: { meta_file } " )
277300 args = BaseConfig (** srsly .read_json (str (meta_file )))
278- model_file = Path (model_data_folder ) / "model.h5 "
301+ model_file = Path (model_data_folder ) / "model.bytes "
279302 optimizer_file = Path (model_data_folder ) / "model.optimizer"
280303 if not model_file .exists ():
281304 raise ValueError (f"model not found: { model_file } " )
0 commit comments