This repository has been archived by the owner on Nov 3, 2022. It is now read-only.
/
cifar10_wide_resnet.py
61 lines (47 loc) · 2.01 KB
/
cifar10_wide_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
'''
Trains a WRN-28-8 model on the CIFAR-10 Dataset.
Performance is slightly less than the paper, since
they use WRN-28-10 model (95.83%).
Gets a 95.54% accuracy score after 300 epochs.
'''
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from keras.datasets import cifar10
import keras.callbacks as callbacks
import keras.utils.np_utils as kutils
from keras.preprocessing.image import ImageDataGenerator
from keras_contrib.applications.wide_resnet import WideResidualNetwork
batch_size = 64
epochs = 300
img_rows, img_cols = 32, 32
(trainX, trainY), (testX, testY) = cifar10.load_data()
trainX = trainX.astype('float32')
trainX /= 255.0
testX = testX.astype('float32')
testX /= 255.0
tempY = testY
trainY = kutils.to_categorical(trainY)
testY = kutils.to_categorical(testY)
generator = ImageDataGenerator(rotation_range=10,
width_shift_range=5. / 32,
height_shift_range=5. / 32,
horizontal_flip=True)
generator.fit(trainX, seed=0, augment=True)
# We will be training the model, therefore no need to load weights
model = WideResidualNetwork(depth=28, width=8, dropout_rate=0.0, weights=None)
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
print('Finished compiling')
model_checkpoint = callbacks.ModelCheckpoint('WRN-28-8 Weights.h5',
monitor='val_acc',
save_best_only=True,
save_weights_only=True)
model.fit_generator(generator.flow(trainX, trainY, batch_size=batch_size),
steps_per_epoch=len(trainX) // batch_size,
epochs=epochs,
callbacks=[model_checkpoint],
validation_data=(testX, testY))
scores = model.evaluate(testX, testY, batch_size)
print('Test loss : %0.5f' % (scores[0]))
print('Test accuracy = %0.5f' % (scores[1]))