1+ from keras .callbacks import ModelCheckpoint
12from keras .models import model_from_json
23from PIL import Image
34import numpy as np
45import os
6+ from keras .models import Sequential
7+ from keras .layers import Conv2D , MaxPooling2D , Dense , Activation , Flatten , Dropout
8+ from keras .utils import np_utils
9+ import tensorflow as tf
10+ import keras .backend as K
11+
512
613class Cifar10Classifier :
7- cifar10_model = None
14+ model_name = 'cnn_cifar10'
815
916 def __init__ (self ):
10- # load and configure the cifar19 classifier model
11- self .cifar10_model = model_from_json (
12- open (os .path .join ('../training/models' , 'cnn_cifar10_architecture.json' )).read ())
13- self .cifar10_model .load_weights (os .path .join ('../training/models' , 'cnn_cifar10_weights.h5' ))
14- self .cifar10_model .compile (optimizer = 'rmsprop' , loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
17+ self .model = None
18+ self .input_shape = None
19+ self .nb_classes = None
20+
21+ @staticmethod
22+ def get_architecture_file_path (model_dir_path ):
23+ return os .path .join (model_dir_path , Cifar10Classifier .model_name + '_architecture.json' )
24+
25+ @staticmethod
26+ def get_weight_file_path (model_dir_path ):
27+ return os .path .join (model_dir_path , Cifar10Classifier .model_name + '_weights.h5' )
28+
29+ @staticmethod
30+ def get_config_file_path (model_dir_path ):
31+ return os .path .join (model_dir_path , Cifar10Classifier .model_name + '_config.npy' )
32+
33+ def load_model (self , model_dir_path ):
34+
35+ config_file_path = self .get_config_file_path (model_dir_path )
1536
16- def predict (self , filename ):
37+ config = np .load (config_file_path ).item ()
38+
39+ self .input_shape = config ['input_shape' ]
40+ self .nb_classes = config ['nb_classes' ]
41+
42+ self .model = model_from_json (open (self .get_architecture_file_path (model_dir_path )).read ())
43+ self .model .load_weights (self .get_weight_file_path (model_dir_path ))
44+ self .model .compile (optimizer = 'rmsprop' , loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
45+
46+ def predict_label (self , filename ):
1747 img = Image .open (filename )
1848 img = img .resize ((32 , 32 ), Image .ANTIALIAS )
1949
@@ -23,7 +53,7 @@ def predict(self, filename):
2353
2454 print (input .shape )
2555
26- predicted_class = self .cifar10_model .predict_classes (input )[0 ]
56+ predicted_class = self .model .predict_classes (input )[0 ]
2757
2858 labels = [
2959 "airplane" ,
@@ -39,5 +69,137 @@ def predict(self, filename):
3969 ]
4070 return predicted_class , labels [predicted_class ]
4171
72+ @staticmethod
73+ def create_model (input_shape , nb_classes ):
74+ model = Sequential ()
75+ model .add (Conv2D (filters = 32 , input_shape = input_shape , padding = 'same' , kernel_size = (3 , 3 )))
76+ model .add (Activation ('relu' ))
77+ model .add (MaxPooling2D (pool_size = (2 , 2 )))
78+
79+ model .add (Conv2D (filters = 32 , padding = 'same' , kernel_size = (3 , 3 )))
80+ model .add (Activation ('relu' ))
81+ model .add (MaxPooling2D (pool_size = (2 , 2 )))
82+
83+ model .add (Dropout (rate = 0.25 ))
84+
85+ model .add (Conv2D (filters = 64 , kernel_size = (3 , 3 ), padding = 'same' , input_shape = input_shape ))
86+ model .add (Activation ('relu' ))
87+ model .add (MaxPooling2D (pool_size = (2 , 2 )))
88+
89+ model .add (Conv2D (filters = 64 , padding = 'same' , kernel_size = (3 , 3 )))
90+ model .add (Activation ('relu' ))
91+ model .add (MaxPooling2D (pool_size = (2 , 2 )))
92+
93+ model .add (Dropout (rate = 0.25 ))
94+
95+ model .add (Flatten ())
96+ model .add (Dense (units = 512 ))
97+ model .add (Activation ('relu' ))
98+ model .add (Dropout (rate = 0.5 ))
99+ model .add (Dense (units = nb_classes ))
100+ model .add (Activation ('softmax' ))
101+
102+ model .compile (optimizer = 'rmsprop' , loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
103+
104+ return model
105+
42106 def run_test (self ):
43- print (self .predict ('../training/bi_classifier_data/training/cat/cat.2.jpg' ))
107+ print (self .predict_label ('../training/bi_classifier_data/training/cat/cat.2.jpg' ))
108+
109+ def fit (self , Xtrain , Ytrain , model_dir_path , input_shape = None , nb_classes = None , test_size = None , batch_size = None ,
110+ epochs = None ):
111+
112+ if batch_size is None :
113+ batch_size = 64
114+ if epochs is None :
115+ epochs = 20
116+ if test_size is None :
117+ test_size = 0.2
118+
119+ if input_shape is None :
120+ input_shape = (32 , 32 , 3 )
121+
122+ if nb_classes is None :
123+ nb_classes = 10
124+
125+ Xtrain = Xtrain .astype ('float32' ) / 255
126+ Ytrain = np_utils .to_categorical (Ytrain , nb_classes )
127+
128+ self .input_shape = input_shape
129+ self .nb_classes = nb_classes
130+
131+ config_file_path = self .get_config_file_path (model_dir_path )
132+
133+ config = dict ()
134+ config ['input_shape' ] = input_shape
135+ config ['nb_classes' ] = nb_classes
136+
137+ np .save (config_file_path , config )
138+
139+ weight_file_path = self .get_weight_file_path (model_dir_path )
140+
141+ self .model = self .create_model (input_shape , nb_classes )
142+
143+ checkpoint = ModelCheckpoint (filepath = weight_file_path , save_best_only = True )
144+ history = self .model .fit (x = Xtrain , y = Ytrain , batch_size = batch_size , epochs = epochs , verbose = 1 ,
145+ validation_split = test_size ,
146+ callbacks = [checkpoint ])
147+ self .model .save_weights (weight_file_path )
148+
149+ np .save (os .path .join (model_dir_path , Cifar10Classifier .model_name + '-history.npy' ), history .history )
150+
151+ return history
152+
153+ def evaluate (self , Xtest , Ytest , batch_size = None ):
154+
155+ if batch_size is None :
156+ batch_size = 64
157+
158+ Xtest = Xtest .astype ('float32' ) / 255
159+ Ytest = np_utils .to_categorical (Ytest , self .nb_classes )
160+
161+ return self .model .evaluate (x = Xtest , y = Ytest , batch_size = batch_size , verbose = 1 )
162+
163+ def export_tensorflow_model (self , output_fld , output_model_file = None ,
164+ output_graphdef_file = None ,
165+ num_output = None ,
166+ quantize = False ,
167+ save_output_graphdef_file = False ,
168+ output_node_prefix = None ):
169+
170+ K .set_learning_phase (0 )
171+
172+ if output_model_file is None :
173+ output_model_file = Cifar10Classifier .model_name + '.pb'
174+
175+ if output_graphdef_file is None :
176+ output_graphdef_file = 'model.ascii'
177+ if num_output is None :
178+ num_output = 1
179+ if output_node_prefix is None :
180+ output_node_prefix = 'output_node'
181+
182+ pred = [None ] * num_output
183+ pred_node_names = [None ] * num_output
184+ for i in range (num_output ):
185+ pred_node_names [i ] = output_node_prefix + str (i )
186+ pred [i ] = tf .identity (self .model .outputs [i ], name = pred_node_names [i ])
187+ print ('output nodes names are: ' , pred_node_names )
188+
189+ sess = K .get_session ()
190+
191+ if save_output_graphdef_file :
192+ tf .train .write_graph (sess .graph .as_graph_def (), output_fld , output_graphdef_file , as_text = True )
193+ print ('saved the graph definition in ascii format at: ' , output_graphdef_file )
194+
195+ from tensorflow .python .framework import graph_util
196+ from tensorflow .python .framework import graph_io
197+ from tensorflow .tools .graph_transforms import TransformGraph
198+ if quantize :
199+ transforms = ["quantize_weights" , "quantize_nodes" ]
200+ transformed_graph_def = TransformGraph (sess .graph .as_graph_def (), [], pred_node_names , transforms )
201+ constant_graph = graph_util .convert_variables_to_constants (sess , transformed_graph_def , pred_node_names )
202+ else :
203+ constant_graph = graph_util .convert_variables_to_constants (sess , sess .graph .as_graph_def (), pred_node_names )
204+ graph_io .write_graph (constant_graph , output_fld , output_model_file , as_text = False )
205+ print ('saved the freezed graph (ready for inference) at: ' , output_model_file )
0 commit comments