In [54]:
import numpy as np
np.random.seed(2066)

import os
import glob
import cv2
import math
import pickle

from keras.utils import np_utils

In [55]:
# Cache flag: 1 = reuse the pickled datasets under ./cache when the file exists,
# 0 = always re-read the images (and overwrite the cache file).
# Colour mode for image loading: 1 = single-channel grayscale,
# 3 = 3-channel colour (OpenCV returns colour images in BGR order).
use_cache, color_type_global = 1, 1

In [56]:
# Read one image from disk with OpenCV.
def get_image_cv2(path, image_rows, image_cols, color_type=1):
    """Load one image with OpenCV and resize it to ``image_rows`` x ``image_cols``.

    Args:
        path: image file path.
        image_rows: target height in pixels.
        image_cols: target width in pixels.
        color_type: 1 to read as single-channel grayscale, 3 to read as
            3-channel colour (OpenCV's BGR channel order).

    Returns:
        The resized image array: shape (image_rows, image_cols) for grayscale,
        (image_rows, image_cols, 3) for colour.

    Raises:
        ValueError: if ``color_type`` is neither 1 nor 3.
        IOError: if OpenCV cannot read the file (``cv2.imread`` returns None
            rather than raising, which would otherwise fail inside resize).
    """
    # Validate before touching OpenCV so a bad flag fails fast with a clear message.
    if color_type not in (1, 3):
        raise ValueError('color_type must be 1 (grayscale) or 3 (colour), got {!r}'.format(color_type))
    read_flag = cv2.IMREAD_GRAYSCALE if color_type == 1 else cv2.IMREAD_COLOR
    image = cv2.imread(path, read_flag)
    if image is None:
        raise IOError('Could not read image: {}'.format(path))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (image_cols, image_rows))

In [57]:
# Read the driver <-> image mapping.
def get_driver_data(path=None):
    """Map each training image filename to the driver listed for it.

    The CSV's first line is treated as a header and skipped. Every other
    non-blank line is split on ',' and field 2 is mapped to field 0. For
    ``data/driver_imgs_list.csv`` the columns are presumably
    ``subject,classname,img`` (TODO confirm against the dataset).

    Args:
        path: CSV location. Defaults to ``data/driver_imgs_list.csv``.

    Returns:
        dict mapping field 2 (image filename) -> field 0 (driver id).
    """
    if path is None:
        path = os.path.join('data', 'driver_imgs_list.csv')
    print('Read drivers data')
    driver_by_image = {}
    # `with` closes the file even if a malformed row raises.
    with open(path, 'r') as f:
        next(f, None)  # skip the header row
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            arr = line.split(',')
            driver_by_image[arr[2]] = arr[0]
    return driver_by_image

In [58]:
# Load the training images.
def load_train(image_rows, image_cols, color_type=1):
    """Load every ``*.jpg`` from ``data/train/c0`` ... ``data/train/c9``.

    Args:
        image_rows: height each image is resized to.
        image_cols: width each image is resized to.
        color_type: 1 for grayscale, 3 for colour (passed to ``get_image_cv2``).

    Returns:
        (X_train, y_train, driver_id, unique_drivers):
        ``X_train`` is the list of images.
        ``y_train`` is the class index (0-9) taken from the folder name.
        ``driver_id`` is each image's driver as given by ``get_driver_data()``.
        ``unique_drivers`` is the sorted list of distinct drivers.

    Raises:
        KeyError: if an image's filename is missing from the drivers CSV.
    """
    X_train = []
    y_train = []
    driver_id = []

    driver_data = get_driver_data()

    print('Read train images')
    for class_idx in range(10):
        pattern = os.path.join('data', 'train', 'c' + str(class_idx), '*.jpg')
        # glob order is filesystem-dependent; sort so the sample order (and the
        # cache pickled from it) is reproducible across machines.
        for file in sorted(glob.glob(pattern)):
            filebase = os.path.basename(file)
            X_train.append(get_image_cv2(file, image_rows, image_cols, color_type))
            y_train.append(class_idx)
            driver_id.append(driver_data[filebase])

    unique_drivers = sorted(set(driver_id))
    return X_train, y_train, driver_id, unique_drivers

In [59]:
def load_test(image_rows, image_cols, color_type=1):
    """Load every ``*.jpg`` in ``data/test``, printing progress about every 10%.

    Args:
        image_rows: height each image is resized to.
        image_cols: width each image is resized to.
        color_type: 1 for grayscale, 3 for colour (passed to ``get_image_cv2``).

    Returns:
        (X_test, X_test_id): the list of images and the matching file
        basenames, in sorted path order.
    """
    print('Read test images')
    path = os.path.join('data', 'test', '*.jpg')
    # Sort for a reproducible order (glob order is filesystem-dependent).
    files = sorted(glob.glob(path))
    X_test = []
    X_test_id = []
    # Clamp to at least 1: with fewer than 10 files, len(files) // 10 is 0 and
    # the modulo below would raise ZeroDivisionError.
    report_every = max(1, len(files) // 10)
    for total, fl in enumerate(files, start=1):
        X_test.append(get_image_cv2(fl, image_rows, image_cols, color_type))
        X_test_id.append(os.path.basename(fl))
        if total % report_every == 0:
            print('Read {} images from {}'.format(total, len(files)))

    return X_test, X_test_id

In [60]:
# Cache data to disk.
def cache_data(data, path):
    """Serialize ``data`` with pickle to ``path``.

    The parent directory of ``path`` is created when missing, so any location
    works (not only ``./cache``). A bare filename is written to the current
    working directory.

    Args:
        data: any picklable object.
        path: destination file; overwritten if it exists.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # `with` guarantees the handle is closed even if pickling fails.
    with open(path, 'wb') as file:
        pickle.dump(data, file)

In [61]:
# Restore cached data from disk.
def restore_data(path):
    """Load and return the pickled object stored at ``path``.

    Returns an empty dict when ``path`` is not an existing file.

    Security: ``pickle.load`` can execute arbitrary code; only use this on
    cache files this notebook wrote itself, never on untrusted input.
    """
    if not os.path.isfile(path):
        return dict()
    print('Restore data from pickle........')
    # Context manager closes the handle after loading.
    with open(path, 'rb') as file:
        return pickle.load(file)

In [62]:
# Read the training data (optionally from cache) and normalize it.
def read_and_normalize_train_data(image_rows, image_cols, color_type=1):
    """Load training images and convert them to channels-first float input.

    Reads the module-level ``use_cache`` flag. When it is non-zero and a cache
    file for this (rows, cols, color_type) exists, the raw lists are restored
    from it. Otherwise the images are read with ``load_train`` and the cache
    is (re)written.

    Returns:
        train_data: float32 array (N, color_type, image_rows, image_cols)
            scaled to [0, 1].
        train_target: ``np_utils.to_categorical(labels, 10)``.
        driver_id: driver of each sample.
        unique_drivers: distinct drivers as stored by ``load_train``/the cache.
    """
    cache_path = os.path.join(
        'cache', 'train_r_{}_c_{}_t_{}.dat'.format(image_rows, image_cols, color_type))
    if not os.path.isfile(cache_path) or use_cache == 0:
        train_data, train_target, driver_id, unique_drivers = load_train(image_rows, image_cols, color_type)
        cache_data((train_data, train_target, driver_id, unique_drivers), cache_path)
    else:
        print('Restore train from cache!')
        (train_data, train_target, driver_id, unique_drivers) = restore_data(cache_path)

    train_data = np.array(train_data, dtype=np.uint8)
    train_target = np.array(train_target, dtype=np.uint8)
    # Images are (rows, cols) for grayscale or (rows, cols, channels) for colour.
    # Move the channel axis first with a transpose: a plain reshape to
    # (N, channels, rows, cols) would scramble the pixels of colour images.
    # reshape(-1, ...) also keeps an empty dataset working.
    train_data = np.ascontiguousarray(
        train_data.reshape(-1, image_rows, image_cols, color_type).transpose(0, 3, 1, 2),
        dtype=np.float32)
    train_target = np_utils.to_categorical(train_target, 10)
    train_data /= 255
    print('Train shape:', train_data.shape)
    print(train_data.shape[0], 'train samples')
    return train_data, train_target, driver_id, unique_drivers


In [63]:
# Read the test data (optionally from cache) and normalize it.
def read_and_normalize_test_data(image_rows, image_cols, color_type=1):
    """Load test images and convert them to channels-first float input.

    Reads the module-level ``use_cache`` flag. When it is non-zero and a cache
    file for this (rows, cols, color_type) exists, the raw lists are restored
    from it. Otherwise the images are read with ``load_test`` and the cache
    is (re)written.

    Returns:
        test_data: float32 array (N, color_type, image_rows, image_cols)
            scaled to [0, 1].
        test_id: image file basename for each sample.
    """
    cache_path = os.path.join(
        'cache', 'test_r_{}_c_{}_t_{}.dat'.format(image_rows, image_cols, color_type))
    if not os.path.isfile(cache_path) or use_cache == 0:
        test_data, test_id = load_test(image_rows, image_cols, color_type)
        cache_data((test_data, test_id), cache_path)
    else:
        print('Restore test from cache!')
        (test_data, test_id) = restore_data(cache_path)

    test_data = np.array(test_data, dtype=np.uint8)
    # Images are (rows, cols) for grayscale or (rows, cols, channels) for colour.
    # Move the channel axis first with a transpose: a plain reshape to
    # (N, channels, rows, cols) would scramble the pixels of colour images.
    # reshape(-1, ...) also keeps an empty dataset working.
    test_data = np.ascontiguousarray(
        test_data.reshape(-1, image_rows, image_cols, color_type).transpose(0, 3, 1, 2),
        dtype=np.float32)
    test_data /= 255
    print('Test shape:', test_data.shape)
    print(test_data.shape[0], 'test samples')
    return test_data, test_id

In [64]:
# Copy out the training samples that belong to the selected drivers.
def copy_selected_drivers(train_data, train_target, driver_id, driver_list):
    """Select the samples whose driver is in ``driver_list``.

    Args:
        train_data: array-like of samples, indexed along axis 0.
        train_target: array-like of targets aligned with ``train_data``.
        driver_id: driver of each sample, aligned with ``train_data``.
        driver_list: drivers to keep.

    Returns:
        (data, target, index): float32 copies of the selected samples and
        targets, plus the integer positions they were taken from, in the
        original sample order.
    """
    wanted = set(driver_list)  # O(1) membership instead of scanning a list per sample
    # Integer (not float) positions: usable directly for numpy indexing and
    # exact for any dataset size.
    index = np.array([i for i, d in enumerate(driver_id) if d in wanted], dtype=np.intp)
    data = np.asarray(train_data)[index].astype(np.float32, copy=False)
    target = np.asarray(train_target)[index].astype(np.float32, copy=False)
    return data, target, index

In [65]:
def cache_preprocess(data, path):
    """Serialize the preprocessed dataset ``data`` with pickle to ``path``.

    The parent directory of ``path`` is created when missing, so any location
    works (not only ``./preprocess``). A bare filename is written to the
    current working directory.

    Args:
        data: any picklable object.
        path: destination file; overwritten if it exists.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # `with` guarantees the handle is closed even if pickling fails.
    with open(path, 'wb') as file:
        pickle.dump(data, file)

In [70]:
# Size every image is resized to (height, width in pixels).
image_rows, image_cols = 56, 56

train_data, train_target, driver_id, unique_drivers = read_and_normalize_train_data(image_rows, image_cols, color_type_global)
test_data, test_id = read_and_normalize_test_data(image_rows, image_cols, color_type_global)

# NOTE(review): unused in this section; presumably filled by a later training
# loop. Remove them if nothing consumes them.
yfull_train = dict()
yfull_test = []

# Split by driver: one driver is held out for validation, and every other
# driver found in the data is used for training. Deriving the train list
# (instead of hard-coding it) keeps the two sets disjoint and covering all drivers.
unique_list_valid = ['p081']
unique_list_train = [d for d in unique_drivers if d not in unique_list_valid]
X_train, Y_train, train_index = copy_selected_drivers(train_data, train_target, driver_id, unique_list_train)
X_valid, Y_valid, valid_index = copy_selected_drivers(train_data, train_target, driver_id, unique_list_valid)
assert len(X_train) + len(X_valid) == len(train_data), 'train/valid split lost samples'

print('Split train: ', len(X_train), len(Y_train))
print('Split valid: ', len(X_valid), len(Y_valid))
print('Train drivers: ', unique_list_train)
# These are the held-out validation drivers, not the test set.
print('Valid drivers: ', unique_list_valid)


Read drivers data
Read train images
Train shape: (22424, 1, 56, 56)
22424 train samples
Read test images
Read 100 images from 1000
Read 200 images from 1000
Read 300 images from 1000
Read 400 images from 1000
Read 500 images from 1000
Read 600 images from 1000
Read 700 images from 1000
Read 800 images from 1000
Read 900 images from 1000
Read 1000 images from 1000
Test shape: (1000, 1, 56, 56)
1000 test samples
Split train:  21601 21601
Split valid:  823 823
Train drivers:  ['p002', 'p012', 'p014', 'p015', 'p016', 'p021', 'p022', 'p024', 'p026', 'p035', 'p039', 'p041', 'p042', 'p045', 'p047', 'p049', 'p050', 'p051', 'p052', 'p056', 'p061', 'p064', 'p066', 'p072', 'p075']
Test drivers:  ['p081']


In [72]:
# Persist the train/valid/test split to ./preprocess. The tuple order is the
# on-disk layout, so keep it stable for whatever loads this file.
cache_preprocess_path = os.path.join('preprocess', 'preprocess.dat')
cache_preprocess(
    (X_train, Y_train, train_index,
     X_valid, Y_valid, valid_index,
     test_data, test_id, unique_drivers),
    cache_preprocess_path,
)