Skip to content

Commit 703a977

Browse files
committed
working bp generation
1 parent 6d0e8f0 commit 703a977

File tree

7 files changed

+52
-3
lines changed

7 files changed

+52
-3
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
from keras.datasets import cifar10
4+
5+
from keras_image_classifier.library.cifar10_classifier import Cifar10Classifier
6+
7+
8+
def main():
9+
10+
11+
(Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()
12+
13+
Xtest = Xtest.astype('float32') / 255
14+
15+
with tf.gfile.FastGFile('./models/tf/cnn_cifar10.pb', 'rb') as f:
16+
graph_def = tf.GraphDef()
17+
graph_def.ParseFromString(f.read())
18+
_ = tf.import_graph_def(graph_def, name='')
19+
20+
with tf.Session() as sess:
21+
[print(n.name) for n in sess.graph.as_graph_def().node]
22+
predict_op = sess.graph.get_tensor_by_name('output_node0:0')
23+
24+
for i in range(Xtest.shape[0]):
25+
x = Xtest[i]
26+
x = x.astype(np.float) / 255
27+
x = np.expand_dims(x, axis=0)
28+
y = Ytest[i]
29+
predicted = sess.run(predict_op, feed_dict={"conv2d_1_input:0": x,
30+
'dropout_1/keras_learning_phase:0': 0})
31+
predicted_y = np.argmax(predicted, axis=1)
32+
print('actual: ', y, '\tpredicted: ', predicted_y)
33+
34+
35+
if __name__ == '__main__':
36+
main()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from keras_image_classifier.library.cifar10_classifier import Cifar10Classifier
2+
3+
4+
def main():
5+
classifier = Cifar10Classifier()
6+
classifier.load_model('./models')
7+
classifier.export_tensorflow_model(output_fld='./models/tf')
8+
9+
if __name__ == '__main__':
10+
main()
11+

demo/cnn_cifar10_predict.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def predict_cifar10(filename):
2626
output = cifar10_model.predict_classes(input)[0]
2727
return output
2828

29+
2930
(Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()
3031

3132
Xtest = Xtest.astype('float32') / 255

demo/cnn_cifar10_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def main():
3030
print('score: ', score[0])
3131
print('accurarcy: ', score[1])
3232

33-
classifier.export_tensorflow_model(output_fld='./models/tf')
33+
3434

3535

3636
if __name__ == '__main__':

demo/models/tf/cnn_cifar10.pb

92 Bytes
Binary file not shown.

keras_image_classifier/library/cifar10_classifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def load_model(self, model_dir_path):
4040
self.nb_classes = config['nb_classes']
4141

4242
self.model = model_from_json(open(self.get_architecture_file_path(model_dir_path)).read())
43-
self.model.load_weights(self.get_weight_file_path(model_dir_path))
4443
self.model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
44+
self.model.load_weights(self.get_weight_file_path(model_dir_path))
45+
4546

4647
def predict_label(self, filename):
4748
img = Image.open(filename)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ numpy
55
h5py
66
pillow
77
scikit-learn
8-
tensorflow
8+
tensorflow == 1.6.0rc1

0 commit comments

Comments
 (0)