
Commit 4e0cd8c

minor update
1 parent 8fc7ad5 commit 4e0cd8c

2 files changed (+80 lines, -75 lines)

keras_image_classifier_train/cnn_bi_classifier_train.py

Lines changed: 79 additions & 75 deletions
@@ -4,79 +4,83 @@
 from keras import backend as K
 from keras.callbacks import ModelCheckpoint

-img_width, img_height = 150, 150
-
-train_data_dir = 'bi_classifier_data/training'
-validation_data_dir = 'bi_classifier_data/validation'
-nb_train_samples = 2000
-nb_validation_samples = 800
-nb_test_samples = 400
-epochs = 50
-batch_size = 16
-WEIGHT_FILE_PATH = 'models/cnn_bi_classifier_weights.h5'
-
-if K.image_data_format() == 'channels_first':
-    input_shape = (3, img_width, img_height)
-else:
-    input_shape = (img_width, img_height, 3)
-
-model = Sequential()
-model.add(Conv2D(32, kernel_size=(3, 3), input_shape=input_shape))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-
-model.add(Conv2D(32, (3, 3)))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-
-model.add(Conv2D(64, (3, 3)))
-model.add(Activation('relu'))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-
-model.add(Flatten())
-model.add(Dense(64))
-model.add(Dropout(0.5))
-model.add(Dense(1))
-model.add(Activation('sigmoid'))
-
-model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
-
-model_json = model.to_json()
-open('models/cnn_bi_classifier_architecture.json','w').write(model_json)
-
-train_datagen = ImageDataGenerator(
-    rescale=1./255,
-    shear_range=0.2,
-    zoom_range=0.2,
-    horizontal_flip=True
-)
-
-validation_datagen = ImageDataGenerator(rescale=1./255)
-
-test_datagen = ImageDataGenerator(rescale=1./255)
-
-train_generator = train_datagen.flow_from_directory(
-    directory=train_data_dir,
-    target_size=(img_width, img_height),
-    batch_size=batch_size,
-    class_mode='binary'
-)
-
-validation_generator = validation_datagen.flow_from_directory(
-    directory=validation_data_dir,
-    target_size=(img_width, img_height),
-    batch_size=batch_size,
-    class_mode='binary'
-)
-
-checkpoint = ModelCheckpoint(filepath=WEIGHT_FILE_PATH, save_best_only=True)
-model.fit_generator(train_generator,
-                    steps_per_epoch=nb_train_samples // batch_size,
-                    epochs=epochs,
-                    validation_data=validation_generator,
-                    validation_steps=nb_validation_samples // batch_size,
-                    callbacks=[checkpoint])
-
-
-model.save_weights(WEIGHT_FILE_PATH, overwrite=True)

+def main():
+    img_width, img_height = 150, 150
+
+    train_data_dir = 'bi_classifier_data/training'
+    validation_data_dir = 'bi_classifier_data/validation'
+    nb_train_samples = 2000
+    nb_validation_samples = 800
+    nb_test_samples = 400
+    epochs = 50
+    batch_size = 16
+    WEIGHT_FILE_PATH = 'models/cnn_bi_classifier_weights.h5'
+
+    if K.image_data_format() == 'channels_first':
+        input_shape = (3, img_width, img_height)
+    else:
+        input_shape = (img_width, img_height, 3)
+
+    model = Sequential()
+    model.add(Conv2D(32, kernel_size=(3, 3), input_shape=input_shape))
+    model.add(Activation('relu'))
+    model.add(MaxPooling2D(pool_size=(2, 2)))
+
+    model.add(Conv2D(32, (3, 3)))
+    model.add(Activation('relu'))
+    model.add(MaxPooling2D(pool_size=(2, 2)))
+
+    model.add(Conv2D(64, (3, 3)))
+    model.add(Activation('relu'))
+    model.add(MaxPooling2D(pool_size=(2, 2)))
+
+    model.add(Flatten())
+    model.add(Dense(64))
+    model.add(Dropout(0.5))
+    model.add(Dense(1))
+    model.add(Activation('sigmoid'))
+
+    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
+
+    model_json = model.to_json()
+    open('models/cnn_bi_classifier_architecture.json', 'w').write(model_json)
+
+    train_datagen = ImageDataGenerator(
+        rescale=1. / 255,
+        shear_range=0.2,
+        zoom_range=0.2,
+        horizontal_flip=True
+    )
+
+    validation_datagen = ImageDataGenerator(rescale=1. / 255)
+
+    test_datagen = ImageDataGenerator(rescale=1. / 255)
+
+    train_generator = train_datagen.flow_from_directory(
+        directory=train_data_dir,
+        target_size=(img_width, img_height),
+        batch_size=batch_size,
+        class_mode='binary'
+    )
+
+    validation_generator = validation_datagen.flow_from_directory(
+        directory=validation_data_dir,
+        target_size=(img_width, img_height),
+        batch_size=batch_size,
+        class_mode='binary'
+    )
+
+    checkpoint = ModelCheckpoint(filepath=WEIGHT_FILE_PATH, save_best_only=True)
+    model.fit_generator(train_generator,
+                        steps_per_epoch=nb_train_samples // batch_size,
+                        epochs=epochs,
+                        validation_data=validation_generator,
+                        validation_steps=nb_validation_samples // batch_size,
+                        callbacks=[checkpoint])
+
+    model.save_weights(WEIGHT_FILE_PATH, overwrite=True)
+
+
+if __name__ == '__main__':
+    main()

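Both before and after the refactor, the script writes two artifacts: the architecture JSON (models/cnn_bi_classifier_architecture.json) and the best checkpointed weights (models/cnn_bi_classifier_weights.h5). As a minimal sketch (not part of this commit), assuming the same Keras-era API the script uses and a hypothetical input image, they can be consumed for inference like this:

import numpy as np
from keras.models import model_from_json
from keras.preprocessing.image import load_img, img_to_array

# Rebuild the network from the saved architecture, then restore the weights
# written by ModelCheckpoint / save_weights in the training script.
with open('models/cnn_bi_classifier_architecture.json') as f:
    model = model_from_json(f.read())
model.load_weights('models/cnn_bi_classifier_weights.h5')

# Preprocess one image the way the training generators do: resize to
# 150x150 and rescale pixel values into [0, 1].
img = load_img('sample.jpg', target_size=(150, 150))  # 'sample.jpg' is a placeholder
x = np.expand_dims(img_to_array(img) / 255., axis=0)

# Dense(1) + sigmoid outputs P(class 1); threshold at 0.5 for a hard label.
prob = model.predict(x)[0][0]
print('class 1' if prob > 0.5 else 'class 0', prob)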
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,4 +4,5 @@ keras
 numpy
 h5py
 pillow
+scikit-learn
 https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.11.0-cp27-none-linux_x86_64.whl

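The only dependency change is the new scikit-learn entry. The commit does not show its call site, so the snippet below is only a hypothetical illustration of how scikit-learn is commonly paired with a binary Keras classifier: computing precision and recall from thresholded sigmoid outputs.

import numpy as np
from sklearn.metrics import classification_report

# Hypothetical ground-truth labels and sigmoid outputs; this commit does not
# reveal where scikit-learn is actually used.
y_true = np.array([0, 1, 1, 0, 1])
y_prob = np.array([0.2, 0.9, 0.6, 0.4, 0.3])
y_pred = (y_prob > 0.5).astype(int)  # hard labels at the 0.5 threshold

print(classification_report(y_true, y_pred, target_names=['class 0', 'class 1']))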