Skip to content

Commit 5dd5821

Browse files
committed
refactor cifar10 classifier
1 parent c5d746e commit 5dd5821

File tree

4 files changed

+195
-67
lines changed

4 files changed

+195
-67
lines changed

demo/cnn_cifar10_train.py

Lines changed: 23 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,38 @@
1-
from keras.models import Sequential
2-
from keras.layers import Conv2D, MaxPooling2D, Dense, Activation, Flatten, Dropout
3-
from keras.utils import np_utils
4-
from keras.preprocessing.image import ImageDataGenerator
51
from keras.datasets import cifar10
62
import keras.backend as K
7-
from keras.callbacks import ModelCheckpoint
83

9-
train_data_dir = 'multi_classifier_data/training'
10-
validation_data_dir = 'multi_classifier_data/validation'
11-
img_width, img_height = 32, 32
12-
batch_size = 128
13-
epochs= 20
14-
nb_classes = 10
15-
WEIGHT_FILE_PATH = 'models/cnn_cifar10_weights.h5'
4+
from keras_image_classifier.library.cifar10_classifier import Cifar10Classifier
165

17-
(Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()
186

19-
Xtrain = Xtrain.astype('float32') / 255
20-
Xtest = Xtest.astype('float32') / 255
7+
def main():
8+
img_width, img_height = 32, 32
9+
batch_size = 128
10+
epochs = 20
11+
nb_classes = 10
12+
output_dir_path = './models'
2113

22-
Ytrain = np_utils.to_categorical(Ytrain, nb_classes)
23-
Ytest = np_utils.to_categorical(Ytest, nb_classes)
14+
(Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()
2415

25-
if K.image_data_format() == 'channels_first':
26-
input_shape = (3, img_width, img_height)
27-
else:
28-
input_shape = (img_width, img_height, 3)
16+
if K.image_data_format() == 'channels_first':
17+
input_shape = (3, img_width, img_height)
18+
else:
19+
input_shape = (img_width, img_height, 3)
2920

30-
model = Sequential()
31-
model.add(Conv2D(filters=32, input_shape=input_shape, padding='same', kernel_size=(3, 3)))
32-
model.add(Activation('relu'))
33-
model.add(MaxPooling2D(pool_size=(2, 2)))
21+
classifier = Cifar10Classifier()
3422

35-
model.add(Conv2D(filters=32, padding='same', kernel_size=(3, 3)))
36-
model.add(Activation('relu'))
37-
model.add(MaxPooling2D(pool_size=(2, 2)))
23+
classifier.fit(Xtrain, Ytrain, model_dir_path=output_dir_path,
24+
batch_size=batch_size,
25+
epochs=epochs,
26+
input_shape=input_shape, nb_classes=nb_classes)
3827

39-
model.add(Dropout(rate=0.25))
28+
score = classifier.evaluate(Xtest, Ytest, batch_size=batch_size)
4029

41-
model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', input_shape=input_shape))
42-
model.add(Activation('relu'))
43-
model.add(MaxPooling2D(pool_size=(2, 2)))
30+
print('score: ', score[0])
31+
print('accurarcy: ', score[1])
4432

45-
model.add(Conv2D(filters=64, padding='same', kernel_size=(3, 3)))
46-
model.add(Activation('relu'))
47-
model.add(MaxPooling2D(pool_size=(2, 2)))
33+
classifier.export_tensorflow_model(output_fld='./models/tf')
4834

49-
model.add(Dropout(rate=0.25))
5035

51-
model.add(Flatten())
52-
model.add(Dense(units=512))
53-
model.add(Activation('relu'))
54-
model.add(Dropout(rate=0.5))
55-
model.add(Dense(units=nb_classes))
56-
model.add(Activation('softmax'))
57-
58-
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
59-
60-
json = model.to_json()
61-
open('models/cnn_cifar10_architecture.json', 'w').write(json)
62-
63-
checkpoint = ModelCheckpoint(filepath=WEIGHT_FILE_PATH, save_best_only=True)
64-
model.fit(x=Xtrain, y=Ytrain, batch_size=batch_size, epochs=epochs, verbose=1, validation_split=0.2, callbacks=[checkpoint])
65-
66-
score = model.evaluate(x=Xtest, y=Ytest, batch_size=batch_size, verbose=1)
67-
68-
print('score: ', score[0])
69-
print('accurarcy: ', score[1])
70-
71-
model.save_weights(WEIGHT_FILE_PATH, overwrite=True)
36+
if __name__ == '__main__':
37+
main()
7238

demo/models/cnn_cifar10_weights.h5

803 KB
Binary file not shown.

demo_web/flaskr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def cats_vs_dogs_result(filename):
106106
@app.route('/cifar10_result/<filename>')
107107
def cifar10_result(filename):
108108
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
109-
predicted_class, predicted_label = cifar10_classifier.predict(filepath)
109+
predicted_class, predicted_label = cifar10_classifier.predict_label(filepath)
110110
return render_template('cifar10_result.html', filename=filename,
111111
predicted_class=predicted_class, predicted_label=predicted_label)
112112

keras_image_classifier/library/cifar10_classifier.py

Lines changed: 171 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,49 @@
1+
from keras.callbacks import ModelCheckpoint
12
from keras.models import model_from_json
23
from PIL import Image
34
import numpy as np
45
import os
6+
from keras.models import Sequential
7+
from keras.layers import Conv2D, MaxPooling2D, Dense, Activation, Flatten, Dropout
8+
from keras.utils import np_utils
9+
import tensorflow as tf
10+
import keras.backend as K
11+
512

613
class Cifar10Classifier:
7-
cifar10_model = None
14+
model_name = 'cnn_cifar10'
815

916
def __init__(self):
10-
# load and configure the cifar19 classifier model
11-
self.cifar10_model = model_from_json(
12-
open(os.path.join('../training/models', 'cnn_cifar10_architecture.json')).read())
13-
self.cifar10_model.load_weights(os.path.join('../training/models', 'cnn_cifar10_weights.h5'))
14-
self.cifar10_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
17+
self.model = None
18+
self.input_shape = None
19+
self.nb_classes = None
20+
21+
@staticmethod
22+
def get_architecture_file_path(model_dir_path):
23+
return os.path.join(model_dir_path, Cifar10Classifier.model_name + '_architecture.json')
24+
25+
@staticmethod
26+
def get_weight_file_path(model_dir_path):
27+
return os.path.join(model_dir_path, Cifar10Classifier.model_name + '_weights.h5')
28+
29+
@staticmethod
30+
def get_config_file_path(model_dir_path):
31+
return os.path.join(model_dir_path, Cifar10Classifier.model_name + '_config.npy')
32+
33+
def load_model(self, model_dir_path):
34+
35+
config_file_path = self.get_config_file_path(model_dir_path)
1536

16-
def predict(self, filename):
37+
config = np.load(config_file_path).item()
38+
39+
self.input_shape = config['input_shape']
40+
self.nb_classes = config['nb_classes']
41+
42+
self.model = model_from_json(open(self.get_architecture_file_path(model_dir_path)).read())
43+
self.model.load_weights(self.get_weight_file_path(model_dir_path))
44+
self.model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
45+
46+
def predict_label(self, filename):
1747
img = Image.open(filename)
1848
img = img.resize((32, 32), Image.ANTIALIAS)
1949

@@ -23,7 +53,7 @@ def predict(self, filename):
2353

2454
print(input.shape)
2555

26-
predicted_class = self.cifar10_model.predict_classes(input)[0]
56+
predicted_class = self.model.predict_classes(input)[0]
2757

2858
labels = [
2959
"airplane",
@@ -39,5 +69,137 @@ def predict(self, filename):
3969
]
4070
return predicted_class, labels[predicted_class]
4171

72+
@staticmethod
73+
def create_model(input_shape, nb_classes):
74+
model = Sequential()
75+
model.add(Conv2D(filters=32, input_shape=input_shape, padding='same', kernel_size=(3, 3)))
76+
model.add(Activation('relu'))
77+
model.add(MaxPooling2D(pool_size=(2, 2)))
78+
79+
model.add(Conv2D(filters=32, padding='same', kernel_size=(3, 3)))
80+
model.add(Activation('relu'))
81+
model.add(MaxPooling2D(pool_size=(2, 2)))
82+
83+
model.add(Dropout(rate=0.25))
84+
85+
model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same', input_shape=input_shape))
86+
model.add(Activation('relu'))
87+
model.add(MaxPooling2D(pool_size=(2, 2)))
88+
89+
model.add(Conv2D(filters=64, padding='same', kernel_size=(3, 3)))
90+
model.add(Activation('relu'))
91+
model.add(MaxPooling2D(pool_size=(2, 2)))
92+
93+
model.add(Dropout(rate=0.25))
94+
95+
model.add(Flatten())
96+
model.add(Dense(units=512))
97+
model.add(Activation('relu'))
98+
model.add(Dropout(rate=0.5))
99+
model.add(Dense(units=nb_classes))
100+
model.add(Activation('softmax'))
101+
102+
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
103+
104+
return model
105+
42106
def run_test(self):
43-
print(self.predict('../training/bi_classifier_data/training/cat/cat.2.jpg'))
107+
print(self.predict_label('../training/bi_classifier_data/training/cat/cat.2.jpg'))
108+
109+
def fit(self, Xtrain, Ytrain, model_dir_path, input_shape=None, nb_classes=None, test_size=None, batch_size=None,
110+
epochs=None):
111+
112+
if batch_size is None:
113+
batch_size = 64
114+
if epochs is None:
115+
epochs = 20
116+
if test_size is None:
117+
test_size = 0.2
118+
119+
if input_shape is None:
120+
input_shape = (32, 32, 3)
121+
122+
if nb_classes is None:
123+
nb_classes = 10
124+
125+
Xtrain = Xtrain.astype('float32') / 255
126+
Ytrain = np_utils.to_categorical(Ytrain, nb_classes)
127+
128+
self.input_shape = input_shape
129+
self.nb_classes = nb_classes
130+
131+
config_file_path = self.get_config_file_path(model_dir_path)
132+
133+
config = dict()
134+
config['input_shape'] = input_shape
135+
config['nb_classes'] = nb_classes
136+
137+
np.save(config_file_path, config)
138+
139+
weight_file_path = self.get_weight_file_path(model_dir_path)
140+
141+
self.model = self.create_model(input_shape, nb_classes)
142+
143+
checkpoint = ModelCheckpoint(filepath=weight_file_path, save_best_only=True)
144+
history = self.model.fit(x=Xtrain, y=Ytrain, batch_size=batch_size, epochs=epochs, verbose=1,
145+
validation_split=test_size,
146+
callbacks=[checkpoint])
147+
self.model.save_weights(weight_file_path)
148+
149+
np.save(os.path.join(model_dir_path, Cifar10Classifier.model_name + '-history.npy'), history.history)
150+
151+
return history
152+
153+
def evaluate(self, Xtest, Ytest, batch_size=None):
154+
155+
if batch_size is None:
156+
batch_size = 64
157+
158+
Xtest = Xtest.astype('float32') / 255
159+
Ytest = np_utils.to_categorical(Ytest, self.nb_classes)
160+
161+
return self.model.evaluate(x=Xtest, y=Ytest, batch_size=batch_size, verbose=1)
162+
163+
def export_tensorflow_model(self, output_fld, output_model_file=None,
164+
output_graphdef_file=None,
165+
num_output=None,
166+
quantize=False,
167+
save_output_graphdef_file=False,
168+
output_node_prefix=None):
169+
170+
K.set_learning_phase(0)
171+
172+
if output_model_file is None:
173+
output_model_file = Cifar10Classifier.model_name + '.pb'
174+
175+
if output_graphdef_file is None:
176+
output_graphdef_file = 'model.ascii'
177+
if num_output is None:
178+
num_output = 1
179+
if output_node_prefix is None:
180+
output_node_prefix = 'output_node'
181+
182+
pred = [None] * num_output
183+
pred_node_names = [None] * num_output
184+
for i in range(num_output):
185+
pred_node_names[i] = output_node_prefix + str(i)
186+
pred[i] = tf.identity(self.model.outputs[i], name=pred_node_names[i])
187+
print('output nodes names are: ', pred_node_names)
188+
189+
sess = K.get_session()
190+
191+
if save_output_graphdef_file:
192+
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, output_graphdef_file, as_text=True)
193+
print('saved the graph definition in ascii format at: ', output_graphdef_file)
194+
195+
from tensorflow.python.framework import graph_util
196+
from tensorflow.python.framework import graph_io
197+
from tensorflow.tools.graph_transforms import TransformGraph
198+
if quantize:
199+
transforms = ["quantize_weights", "quantize_nodes"]
200+
transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
201+
constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
202+
else:
203+
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
204+
graph_io.write_graph(constant_graph, output_fld, output_model_file, as_text=False)
205+
print('saved the freezed graph (ready for inference) at: ', output_model_file)

0 commit comments

Comments
 (0)