Skip to content

Commit b699d6b

Browse files
committed
add in residual network
1 parent 1f89d4b commit b699d6b

File tree

8 files changed

+84
-0
lines changed

8 files changed

+84
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ trained classifiers:
4343
* multi-class DCNN classifier trained with CIFAR-10 data
4444
* multi-class VGG16 classifier trained with ImageNet data
4545
* multi-class VGG19 classifier trained with ImageNet data
46+
* multi-class Residual Network classifier trained with ImageNet data
4647

4748

4849

keras_image_classifier_web/flaskr.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from keras_image_classifier_web.cifar10_classifier import Cifar10Classifier
44
from keras_image_classifier_web.vgg16_classifier import VGG16Classifier
55
from keras_image_classifier_web.vgg19_classifier import VGG19Classifier
6+
from keras_image_classifier_web.resnet50_classifier import ResNet50Classifier
67

78
from flask import Flask, request, session, g, redirect, url_for, abort, \
89
render_template, flash
@@ -31,6 +32,10 @@
3132
vgg19_classifier = VGG19Classifier()
3233
vgg19_classifier.run_test()
3334

35+
resnet50_classifier = ResNet50Classifier()
36+
resnet50_classifier.run_test()
37+
38+
3439
@app.route('/')
3540
def classifiers():
3641
return render_template('classifiers.html')
@@ -92,6 +97,13 @@ def vgg19():
9297
return render_template('vgg19.html')
9398

9499

100+
@app.route('/resnet50', methods=['GET', 'POST'])
101+
def resnet50():
102+
if request.method == 'POST':
103+
return store_uploaded_image('resnet50_result')
104+
return render_template('resnet50.html')
105+
106+
95107
@app.route('/cats_vs_dogs_result/<filename>')
96108
def cats_vs_dogs_result(filename):
97109
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
@@ -124,6 +136,14 @@ def vgg19_result(filename):
124136
top3=top3)
125137

126138

139+
@app.route('/resnet50_result/<filename>')
140+
def resnet50_result(filename):
141+
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
142+
top3 = resnet50_classifier.predict(filepath)
143+
return render_template('resnet50_result.html', filename=filename,
144+
top3=top3)
145+
146+
127147
@app.route('/images/<filename>')
128148
def get_image(filename):
129149
return send_from_directory(app.config['UPLOAD_FOLDER'],
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from keras.models import Model
2+
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
3+
from keras.optimizers import SGD
4+
from PIL import Image
5+
from keras.preprocessing.image import img_to_array
6+
import numpy as np
7+
8+
9+
class ResNet50Classifier:
10+
model = None
11+
12+
def __init__(self):
13+
self.model = ResNet50(include_top=True, weights='imagenet')
14+
self.model.compile(optimizer=SGD(), loss='categorical_crossentropy', metrics=['accuracy'])
15+
16+
def predict(self, filename):
17+
img = Image.open(filename)
18+
img = img.resize((224, 224), Image.ANTIALIAS)
19+
input = img_to_array(img)
20+
input = np.expand_dims(input, axis=0)
21+
input = preprocess_input(input)
22+
output = decode_predictions(self.model.predict(input), top=3)
23+
return output[0]
24+
25+
def run_test(self):
26+
print(self.predict('../keras_image_classifier/bi_classifier_data/training/cat/cat.3.jpg'))
27+
28+
29+
if __name__ == '__main__':
30+
classifier = ResNet50Classifier()
31+
classifier.run_test()

keras_image_classifier_web/templates/classifiers.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
<li><a href="{{ url_for('cifar10') }}">CIFAR-10</a></li>
66
<li><a href="{{ url_for('vgg16') }}">VGG-16</a></li>
77
<li><a href="{{ url_for('vgg19') }}">VGG-19</a></li>
8+
<li><a href="{{ url_for('resnet50') }}">ResNet-50</a></li>
89
</ul>
910
{% endblock %}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{% extends "layout.html" %}
2+
{% block body %}
3+
<p>Upload your picture and we will try to what it is using ResNet50 classifier</p>
4+
<form method=post enctype=multipart/form-data>
5+
<p><input type=file name=file>
6+
<input type=submit value=Upload>
7+
</form>
8+
{% endblock %}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{% extends "layout.html" %}
2+
{% block body %}
3+
<img src="/images/{{ filename }}" />
4+
<br />
5+
<hr />
6+
Below is my top 3 guesses of what is inside your picture:
7+
<ol>
8+
{% for pred in top3 %}
9+
<li>{{ pred[1] }} ({{ pred[2] }})</li>
10+
{% endfor %}
11+
</ol>
12+
<hr />
13+
<a href="{{ url_for('resnet50') }}">Try another picture</a>
14+
{% endblock %}

keras_image_classifier_web/vgg16_classifier.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ def predict(self, filename):
2424

2525
def run_test(self):
2626
print(self.predict('../keras_image_classifier/bi_classifier_data/training/cat/cat.3.jpg'))
27+
28+
if __name__ == '__main__':
29+
classifier = VGG16Classifier()
30+
classifier.run_test()

keras_image_classifier_web/vgg19_classifier.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,8 @@ def predict(self, filename):
2424

2525
def run_test(self):
2626
print(self.predict('../keras_image_classifier/bi_classifier_data/training/cat/cat.3.jpg'))
27+
28+
29+
if __name__ == '__main__':
30+
classifier = VGG19Classifier()
31+
classifier.run_test()

0 commit comments

Comments
 (0)