# 毕业项目：检测分神司机

In [3]:
%config InlineBackend.figure_format = 'retina'

from urllib.request import urlretrieve
from os.path import isfile, isdir, join, pardir
from IPython.display import SVG, Image
import os, shutil
import random
import glob
import zipfile

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from keras.preprocessing import image
from keras.models import Model, load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Dense, Flatten, GlobalAveragePooling2D, Dropout
from keras import optimizers
from keras.utils.vis_utils import model_to_dot, plot_model
from keras.utils import to_categorical

from keras.applications import vgg16
from keras.applications import resnet50
from keras.applications.vgg16 import VGG16


import pydot
import cv2
import h5py

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


### 一、图像预处理（构建数据生成器，未做数据增强）

In [37]:
# Build the image generators. No augmentation is applied: each generator only
# resizes, batches and runs the backbone's own `preprocess_input` on the images.
IMAGE_SIZE = (224, 224)  # resolution fed to the network (matches the model's Input shape)
BATCH_SIZE = 16

vgg16_train_datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)
vgg16_valid_datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)

# ResNet50 data generators (requires `resnet50` imported in the imports cell).
resnet50_train_datagen = ImageDataGenerator(preprocessing_function=resnet50.preprocess_input)
resnet50_valid_datagen = ImageDataGenerator(preprocessing_function=resnet50.preprocess_input)

vgg16_train_generator = vgg16_train_datagen.flow_from_directory(
    image_train_folder_path,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical')

# shuffle=False keeps validation batches in a fixed order between epochs/runs.
vgg16_valid_generator = vgg16_valid_datagen.flow_from_directory(
    image_valid_folder_path,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False)

Found 20097 images belonging to 10 classes.
Found 2327 images belonging to 10 classes.
Found 20097 images belonging to 10 classes.
Found 2327 images belonging to 10 classes.


### 二、创建基准模型

In [10]:
def create_model_vgg16(num_classes=10, input_shape=(224, 224, 3),
                       dropout_rate=0.2, learning_rate=1e-4, momentum=0.9):
    """Build and compile the VGG16 transfer-learning baseline.

    An ImageNet-pretrained VGG16 convolutional base (without its fully
    connected top) is followed by global average pooling, dropout and a
    softmax classifier. No layers are frozen here, so the whole network is
    fine-tuned during training.

    Args:
        num_classes: number of output classes (the driver dataset has 10, c0-c9).
        input_shape: shape of one input image, (height, width, channels).
        dropout_rate: dropout rate applied before the classifier layer.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        A keras Model compiled with SGD, categorical cross-entropy loss and
        an accuracy metric.
    """
    base_model = VGG16(include_top=False, weights='imagenet')

    # Named `image_input` rather than `input` to avoid shadowing the builtin.
    image_input = Input(shape=input_shape, name='image_input')
    features = base_model(image_input)

    x = GlobalAveragePooling2D()(features)
    x = Dropout(dropout_rate)(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=image_input, outputs=predictions)

    # `lr` is the argument name used by the Keras version this notebook targets.
    sgd = optimizers.SGD(lr=learning_rate, momentum=momentum)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

    return model

### 三、基准模型可视化

In [4]:
# Build the compiled baseline and print its layer-by-layer summary.
model_vgg16 = create_model_vgg16()
model_vgg16.summary()

5


### 五、基准模型训练

In [17]:
# Fine-tune the baseline for 10 epochs; the validation generator is scored
# after each epoch. The returned training history is kept in `history_vgg16`.
history_vgg16 = model_vgg16.fit_generator(
    generator=vgg16_train_generator,
    validation_data=vgg16_valid_generator,
    epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [18]:
# Persist the fine-tuned baseline so evaluation can run without retraining.
vgg16_model_dir = join(pardir, 'model')
# Saving to HDF5 fails if the target folder does not exist yet.
os.makedirs(vgg16_model_dir, exist_ok=True)
model_vgg16.save(join(vgg16_model_dir, 'vgg16.h5'))
print("Vgg16 model saved.")

Vgg16 model saved.


### 六、基准模型评估

In [20]:
# Predict class probabilities for the unlabeled Kaggle test set.
# The test images sit in a single placeholder sub-folder ('0'), so the
# generator should yield images only (class_mode=None). shuffle=False keeps the
# prediction rows in the same order as `vgg16_test_generator.filenames`, which
# is what lets the submission pair each prediction with its image name.
test_image_path = join(driver_dataset_folder_path, 'test')
vgg16_test_datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)
vgg16_test_generator = vgg16_test_datagen.flow_from_directory(
    test_image_path,
    target_size=(224, 224),
    batch_size=32,
    shuffle=False,
    class_mode=None)

model_vgg16 = load_model(join(pardir, 'model', 'vgg16.h5'))
print("Model loaded.")
pred_vgg16 = model_vgg16.predict_generator(vgg16_test_generator, verbose=1)
# One row of 10 class probabilities per test image.
assert pred_vgg16.shape == (vgg16_test_generator.samples, 10), pred_vgg16.shape
print(pred_vgg16.shape)

Found 79726 images belonging to 1 classes.
Model loaded.
(79726, 10)


### 七、生成 Kaggle 提交文件

In [16]:
# Image names for the submission, taken from the test generator itself so that
# row i matches prediction row i. `os.walk` makes no ordering guarantee, so
# names listed from the file system can silently be paired with the wrong
# predictions.
image_names = np.array(
    [os.path.basename(path) for path in vgg16_test_generator.filenames]
).reshape(-1, 1)  # column vector, shape (n_images, 1)
assert len(image_names) == len(pred_vgg16), (len(image_names), len(pred_vgg16))

In [23]:
# Assemble the Kaggle submission: an `img` column followed by the predicted
# probability for each class c0-c9. Building the frame from the float array
# (instead of np.append with the string names) keeps the probabilities numeric
# rather than casting every value to a string.
class_columns = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']
predict_result = pd.DataFrame(pred_vgg16, columns=class_columns)
predict_result.insert(0, 'img', image_names.ravel())
# Same (n_images, 11) layout as before: image name followed by 10 probabilities.
result_vgg16 = predict_result.values
predict_result.to_csv('result_vgg16.csv', index=False)