In [1]:
import numpy as np
from skimage import io, color, exposure, transform
from sklearn.cross_validation import train_test_split
import os
import glob
import h5py

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, model_from_json
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D

from keras.optimizers import SGD
from keras.utils import np_utils
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras import backend as K
K.set_image_data_format('channels_first')

from matplotlib import pyplot as plt
%matplotlib inline

NUM_CLASSES = 43
IMG_SIZE = 48

Using TensorFlow backend.


## Function to preprocess the image:

In [23]:
def preprocess_img(img):
    # Histogram normalization in y
    hsv = color.rgb2hsv(img)
    hsv[:,:,2] = exposure.equalize_hist(hsv[:,:,2])
    img = color.hsv2rgb(hsv)

    # central scrop
    min_side = min(img.shape[:-1])
    centre = img.shape[0]//2, img.shape[1]//2
    img = img[centre[0]-min_side//2:centre[0]+min_side//2,
              centre[1]-min_side//2:centre[1]+min_side//2,
              :]

    # rescale to standard size
    img = transform.resize(img, (IMG_SIZE, IMG_SIZE))

    # roll color axis to axis 0
    img = np.rollaxis(img,-1)

    return img

import os
files=os.listdir(os.getcwd()+"/Images")
for file in files:
    print(preprocess_img(io.imread(os.getcwd()+"/Images/"+file)))

  warn("The default mode, 'constant', will be changed to 'reflect' in "


[[[ 0.95562632  0.89991279  0.72989447 ...,  0.89991279  0.87416613
    0.81605349]
  [ 0.89014794  0.84009639  0.95660775 ...,  0.85211135  0.90650599
    0.96273012]
  [ 0.85822543  0.93254366  0.851565   ...,  0.74820469  0.96303476
    0.98259277]
  ..., 
  [ 0.99859183  0.99445824  0.83476649 ...,  0.31089059  0.29419185
    0.3434062 ]
  [ 0.81202395  0.89379527  0.65036106 ...,  0.30218902  0.20324431
    0.18924848]
  [ 0.99622959  0.99810167  0.93456067 ...,  0.23045018  0.26219873
    0.26219873]]

 [[ 0.92961832  0.87833561  0.72631333 ...,  0.89559735  0.86995451
    0.81208781]
  [ 0.87084132  0.82341731  0.93391336 ...,  0.85018147  0.89874088
    0.94452411]
  [ 0.85407964  0.90689287  0.84745211 ...,  0.7387043   0.95847621
    0.95958122]
  ..., 
  [ 0.99451891  0.98879355  0.83071967 ...,  0.3028276   0.28654855
    0.33454325]
  [ 0.80638274  0.88951893  0.64713858 ...,  0.29433655  0.19786263
    0.1842128 ]
  [ 0.99181393  0.98966073  0.93015842 ...,  0.2243884   0

[[[ 0.16570963  0.18130712  0.17712916 ...,  0.21671599  0.22611891
    0.20075117]
  [ 0.18590425  0.20075117  0.21671599 ...,  0.2521776   0.23238333
    0.21386811]
  [ 0.19324901  0.21121029  0.21973012 ...,  0.26828679  0.24866388
    0.22927232]
  ..., 
  [ 0.0231172   0.0262749   0.01548654 ...,  0.10989158  0.10989158
    0.0993709 ]
  [ 0.11323848  0.0764356   0.14828767 ...,  0.14285147  0.14285147
    0.11978929]
  [ 0.11650953  0.15988351  0.16380829 ...,  0.05654124  0.15564658
    0.09592988]]

 [[ 0.11684653  0.12649334  0.12230347 ...,  0.14086539  0.14489173
    0.134546  ]
  [ 0.12886544  0.13241035  0.14519971 ...,  0.15903092  0.14606952
    0.14257874]
  [ 0.13166416  0.14008846  0.13923492 ...,  0.15958438  0.15598007
    0.14770429]
  ..., 
  [ 0.01062142  0.01244601  0.00637681 ...,  0.05115643  0.05494579
    0.04697533]
  [ 0.051821    0.02651847  0.06990705 ...,  0.06302271  0.06092195
    0.05105773]
  [ 0.04660381  0.07248053  0.07445831 ...,  0.02010355  0

[[[ 0.45226717  0.61643852  0.69420397 ...,  1.          0.92220456  1.        ]
  [ 0.42824484  0.83699096  0.62225175 ...,  0.99607843  1.          0.85028489]
  [ 0.55940995  0.61017411  0.57903633 ...,  0.86559949  1.          0.95446178]
  ..., 
  [ 0.91197662  0.92220456  0.93659118 ...,  0.8388006   0.79432925
    0.78845304]
  [ 0.90008961  0.92220456  0.94485233 ...,  0.80942928  0.78436829
    0.78436829]
  [ 0.91677676  0.91944298  0.92502191 ...,  0.81539205  0.78711882
    0.78576301]]

 [[ 0.47811101  0.64212346  0.71492648 ...,  1.          0.80692899
    0.99607843]
  [ 0.45343571  0.83699096  0.64764978 ...,  1.          1.          0.7165923 ]
  [ 0.57305409  0.63613896  0.60596825 ...,  0.79140525  1.          0.87059694]
  ..., 
  [ 0.86637779  0.87691773  0.89160208 ...,  0.80360617  0.74236379
    0.73434352]
  [ 0.85783188  0.88926868  0.9000533  ...,  0.75635195  0.72098499
    0.72098499]
  [ 0.87135088  0.87408929  0.87979861 ...,  0.76402877  0.72477277
    0

[[[ 0.60998366  0.58525553  0.70579272 ...,  0.42743059  0.35189056
    0.41595489]
  [ 0.85096886  0.70579272  0.68391148 ...,  0.51372549  0.46223459
    0.1998112 ]
  [ 0.6242511   0.52299535  0.77551612 ...,  0.37240343  0.34281802
    0.41915028]
  ..., 
  [ 0.76346187  0.71646998  0.65230872 ...,  0.76346187  0.73935968
    0.92754025]
  [ 0.76933249  0.78941073  0.70053354 ...,  0.75168632  0.79695108
    0.82115113]
  [ 0.84448387  0.73935968  0.76346187 ...,  0.76346187  0.75763296
    0.77551612]]

 [[ 0.4289995   0.41099405  0.53024022 ...,  0.71108907  0.64513269
    0.55951454]
  [ 0.62763766  0.50157859  0.49964517 ...,  0.89803922  0.84289838
    0.38089009]
  [ 0.43426163  0.3414852   0.57240475 ...,  0.70952653  0.63888813
    0.78345847]
  ..., 
  [ 0.55791444  0.53018779  0.47535263 ...,  0.56158493  0.53277389
    0.68489259]
  [ 0.56687657  0.58088714  0.5146777  ...,  0.55464233  0.57994093
    0.59685522]
  [ 0.62952434  0.5363982   0.56158493 ...,  0.55791444  0

[[[ 0.91119168  0.92757683  0.92740209 ...,  0.69409632  0.70745369
    0.68829655]
  [ 0.91574272  0.9335333   0.92012485 ...,  0.69101036  0.69045841
    0.71316667]
  [ 0.91372085  0.91168424  0.90407943 ...,  0.69076296  0.67322422
    0.69224423]
  ..., 
  [ 0.28582739  0.3007048   0.05866969 ...,  0.15006416  0.12710631
    0.18757565]
  [ 0.2223027   0.35246161  0.35345472 ...,  0.16799963  0.1211764
    0.08255369]
  [ 0.24205907  0.3488592   0.28656143 ...,  0.06412484  0.06422448
    0.12803205]]

 [[ 0.91574478  0.9321499   0.9319661  ...,  0.69040269  0.71119929
    0.68829655]
  [ 0.92028046  0.9335333   0.92470259 ...,  0.69101036  0.68311802
    0.70568626]
  [ 0.90916582  0.91432679  0.90407943 ...,  0.69076296  0.67322422
    0.69072   ]
  ..., 
  [ 0.28211673  0.29682393  0.05790845 ...,  0.14800848  0.12710631
    0.18504628]
  [ 0.22384314  0.35471854  0.35117625 ...,  0.16605654  0.1211764
    0.08194053]
  [ 0.24369317  0.34660513  0.28848729 ...,  0.06412484  0.0

[[[ 0.52074084  0.56049367  0.68192599 ...,  0.12423447  0.52898491
    0.19219239]
  [ 0.45043389  0.35223063  0.07664397 ...,  0.83613897  0.83807859
    0.83711415]
  [ 0.72858356  0.73365443  0.70276611 ...,  0.70960615  0.72286183
    0.73365443]
  ..., 
  [ 0.3094606   0.63582916  0.70960615 ...,  0.90161495  0.94408381
    0.94408381]
  [ 0.65504535  0.73365443  0.71644516 ...,  0.97356236  0.95760202
    0.96730144]
  [ 0.73814659  0.73604741  0.72858356 ...,  0.97356236  0.94408381
    0.97356236]]

 [[ 0.4284109   0.48371371  0.58874356 ...,  0.0980072   0.42467802
    0.1503287 ]
  [ 0.37253931  0.30603645  0.06288736 ...,  0.73367096  0.74857505
    0.76361144]
  [ 0.61582658  0.63380797  0.61277776 ...,  0.60639071  0.6189775
    0.63380797]
  ..., 
  [ 0.26978616  0.60729837  0.66659971 ...,  0.84967005  0.89035546
    0.89035546]
  [ 0.60114921  0.69024293  0.66465394 ...,  0.91882391  0.90332498
    0.91269571]
  [ 0.69929677  0.6927505   0.68521549 ...,  0.91882391  0.

[[[ 0.15296547  0.16482052  0.15296547 ...,  0.06532572  0.05988191
    0.07153535]
  [ 0.17897238  0.16714093  0.15506088 ...,  0.1133077   0.08552794
    0.07749663]
  [ 0.23989431  0.2664889   0.16690686 ...,  0.41193171  0.36290143
    0.32554968]
  ..., 
  [ 0.13644453  0.77900686  0.75469603 ...,  0.45518375  0.6083099
    0.44029312]
  [ 0.26311194  0.77110247  0.81714084 ...,  0.42876906  0.80115719
    0.42883834]
  [ 0.85627522  0.67763368  0.62728926 ...,  0.36901092  0.76759747
    0.65185596]]

 [[ 0.20954174  0.21489258  0.20954174 ...,  0.07621334  0.06532572
    0.07749663]
  [ 0.20213351  0.21310469  0.20954174 ...,  0.1133077   0.10996449
    0.08941919]
  [ 0.20266933  0.23687902  0.21489258 ...,  0.3306294   0.28781838
    0.2500599 ]
  ..., 
  [ 0.14907829  0.63886806  0.77110247 ...,  0.35081041  0.65282038
    0.4730176 ]
  [ 0.25258746  0.62754616  0.80871671 ...,  0.34524262  0.87868854
    0.46408533]
  [ 0.70642705  0.64717824  0.59484326 ...,  0.31340653  0.

[[[ 0.06445464  0.07617366  0.33086507 ...,  0.45117688  0.42029596
    0.66125   ]
  [ 0.08015248  0.09769576  0.35725763 ...,  0.62091674  0.80165233
    0.76803975]
  [ 0.03389505  0.11621209  0.44678341 ...,  0.48198712  0.42924281
    0.42065626]
  ..., 
  [ 0.14659263  0.15853073  0.16312201 ...,  0.79772611  0.84484372
    0.61429628]
  [ 0.16673767  0.16312201  0.26294024 ...,  0.77916086  0.77916086
    0.67816517]
  [ 0.15475678  0.18523379  0.23987648 ...,  0.1364341   0.25147064
    0.85694384]]

 [[ 0.08789269  0.08789269  0.25945534 ...,  0.49080729  0.46438295
    0.72064372]
  [ 0.08631806  0.10990773  0.28282896 ...,  0.66607432  0.87923158
    0.84484372]
  [ 0.04142729  0.08715907  0.35513553 ...,  0.37354002  0.34562408
    0.33817464]
  ..., 
  [ 0.20941804  0.21485086  0.21122209 ...,  0.80190269  0.83630995
    0.58233867]
  [ 0.21675897  0.21122209  0.23769798 ...,  0.63487181  0.63074926
    0.65149575]
  [ 0.21122209  0.20207322  0.20472217 ...,  0.14906688  0

[[[ 0.84576942  0.84576942  0.80556945 ...,  0.50144755  0.48526538
    0.44949183]
  [ 0.85954328  0.84576942  0.80556945 ...,  0.4748984   0.4748984
    0.47267925]
  [ 0.82406009  0.87443696  0.80556945 ...,  0.50167149  0.4748984
    0.47687462]
  ..., 
  [ 0.93706908  0.88555268  0.85656647 ...,  0.30783018  0.29656918
    0.29605309]
  [ 0.86924794  0.83072218  0.83072218 ...,  0.30463513  0.29656918
    0.29332057]
  [ 0.85279751  0.8089974   0.83072218 ...,  0.32721337  0.32761733
    0.32862662]]

 [[ 0.83857138  0.83857138  0.79871354 ...,  0.49441791  0.47844509
    0.44527429]
  [ 0.85224316  0.83857138  0.79871354 ...,  0.47046009  0.47046009
    0.46602179]
  [ 0.81707229  0.86702648  0.79871354 ...,  0.49700326  0.47046009
    0.47016289]
  ..., 
  [ 0.91764649  0.87073172  0.8421704  ...,  0.29858284  0.28753662
    0.2870296 ]
  [ 0.85466598  0.81670156  0.81670156 ...,  0.29851182  0.28904204
    0.28584105]
  [ 0.83845619  0.79528558  0.81670156 ...,  0.31781467  0.3

[[[ 0.15719408  0.02041717  0.03334301 ...,  0.1560478   0.16859127
    0.00734659]
  [ 0.13350929  0.01776629  0.02416204 ...,  0.16469415  0.15296247
    0.00312186]
  [ 0.05148444  0.01906707  0.0629858  ...,  0.15602911  0.1593511
    0.0061011 ]
  ..., 
  [ 0.54613499  0.52754727  0.63873291 ...,  0.72593685  0.69577461
    0.61642223]
  [ 0.46457041  0.57699942  0.38729998 ...,  0.63758416  0.64355601
    0.67875914]
  [ 0.52922992  0.5612159   0.64355601 ...,  0.70360107  0.6383707
    0.72019541]]

 [[ 0.13667556  0.01563928  0.02616254 ...,  0.15447194  0.1702438
    0.00808747]
  [ 0.10796837  0.01438386  0.02052461 ...,  0.16309382  0.15296247
    0.00357376]
  [ 0.03693457  0.01692886  0.04655446 ...,  0.15607341  0.15775923
    0.0061011 ]
  ..., 
  [ 0.51714125  0.49922474  0.62801354 ...,  0.70658255  0.68443096
    0.60255483]
  [ 0.4529744   0.56190254  0.37437988 ...,  0.6223544   0.63275644
    0.67126907]
  [ 0.51661197  0.5382479   0.62012334 ...,  0.69032476  0.62

[[[  3.20334827e-01   3.19064516e-01   3.21190708e-01 ...,   6.14376813e-01
     6.14673985e-01   4.49687980e-01]
  [  3.20334827e-01   3.33747647e-01   3.42113563e-01 ...,   6.28819751e-01
     6.02826635e-01   5.26982536e-01]
  [  2.86305109e-01   3.24339012e-01   3.27142098e-01 ...,   2.28525123e-01
     6.02433526e-01   6.03628945e-01]
  ..., 
  [  3.24361716e-02   3.07791704e-01   4.98622782e-01 ...,   8.30943606e-01
     9.22440931e-01   8.86006392e-01]
  [  4.27762398e-04   7.28471689e-02   3.82479721e-01 ...,   9.14654615e-01
     8.53084795e-01   8.53084795e-01]
  [  9.63972092e-03   6.18569337e-03   2.29692235e-01 ...,   7.61254236e-01
     8.00019973e-01   8.75280264e-01]]

 [[  3.36351568e-01   3.39258472e-01   3.37051977e-01 ...,   6.28918276e-01
     6.14673985e-01   4.49687980e-01]
  [  3.36351568e-01   3.50028020e-01   3.54479113e-01 ...,   6.39724140e-01
     6.02826635e-01   5.23020261e-01]
  [  2.97757314e-01   3.28343197e-01   3.43297264e-01 ...,   2.16172414e-01
  

ValueError: the input array must be have a shape == (.., ..,[ ..,] 3)), got (1062, 1610, 4)

## Preprocess all training images into a numpy array

In [11]:
try:
    with  h5py.File('X.h5') as hf: 
        X, Y = hf['imgs'][:], hf['labels'][:]
    print("Loaded images from X.h5")
    
except (IOError,OSError, KeyError):  
    print("Error in reading X.h5. Processing all images...")
    root_dir = os.cwd()+"/Images/"
    imgs = []
    labels = []

    all_img_paths = glob.glob(os.path.join(root_dir, '*/*.ppm'))
    np.random.shuffle(all_img_paths)
    for img_path in all_img_paths:
        try:
            img = preprocess_img(io.imread(img_path))
            label = get_class(img_path)
            imgs.append(img)
            labels.append(label)

            if len(imgs)%1000 == 0: print("Processed {}/{}".format(len(imgs), len(all_img_paths)))
        except (IOError, OSError):
            print('missed', img_path)
            pass

    X = np.array(imgs, dtype='float32')
    Y = np.eye(NUM_CLASSES, dtype='uint8')[labels]
    print(X)
    with h5py.File('X.h5','w') as hf:
        hf.create_dataset('imgs', data=X)
        hf.create_dataset('labels', data=Y)

Loaded images from X.h5


# Define Keras model

In [12]:
def cnn_model():
    model = Sequential()

    model.add(Conv2D(32, (3, 3), padding='same',
                     input_shape=(3, IMG_SIZE, IMG_SIZE),
                     activation='relu'))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Conv2D(64, (3, 3), padding='same',
                     activation='relu'))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Conv2D(128, (3, 3), padding='same',
                     activation='relu'))
    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASSES, activation='softmax'))
    return model

model = cnn_model()
# let's train the model using SGD + momentum (how original).
lr = 0.01
sgd = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
          optimizer=sgd,
          metrics=['accuracy'])


def lr_schedule(epoch):
    return lr*(0.1**int(epoch/10))

# Start Training

In [13]:
batch_size = 32
nb_epoch = 30
print(X)
# model.fit(X, Y,
#           batch_size=batch_size,
#           epochs=nb_epoch,
#           validation_split=0.2,
#           shuffle=True,
#           callbacks=[LearningRateScheduler(lr_schedule),
#                     ModelCheckpoint('model.h5',save_best_only=True)]
#             )

[]


# Load Test data

In [None]:
import pandas as pd
test = pd.read_csv('GT-final_test.csv',sep=';')

X_test = []
y_test = []
i = 0
for file_name, class_id  in zip(list(test['Filename']), list(test['ClassId'])):
    img_path = os.path.join('GTSRB/Final_Test/Images/',file_name)
    X_test.append(preprocess_img(io.imread(img_path)))
    y_test.append(class_id)
    
X_test = np.array(X_test)
y_test = np.array(y_test)

In [None]:
y_pred = model.predict_classes(X_test)
acc = np.sum(y_pred==y_test)/np.size(y_pred)
print("Test accuracy = {}".format(acc))

# With Data augmentation

In [None]:
from sklearn.cross_validation import train_test_split

X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=42)

datagen = ImageDataGenerator(featurewise_center=False, 
                            featurewise_std_normalization=False, 
                            width_shift_range=0.1,
                            height_shift_range=0.1,
                            zoom_range=0.2,
                            shear_range=0.1,
                            rotation_range=10.,)

datagen.fit(X_train)

In [None]:
# Reinstallise models 

model = cnn_model()
# let's train the model using SGD + momentum (how original).
lr = 0.01
sgd = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',
          optimizer=sgd,
          metrics=['accuracy'])


def lr_schedule(epoch):
    return lr*(0.1**int(epoch/10))

In [None]:
nb_epoch = 30
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
                            steps_per_epoch=X_train.shape[0],
                            epochs=nb_epoch,
                            validation_data=(X_val, Y_val),
                            callbacks=[LearningRateScheduler(lr_schedule),
                                       ModelCheckpoint('model.h5',save_best_only=True)]
                           )

In [None]:
y_pred = model.predict_classes(X_test)
acc = np.sum(y_pred==y_test)/np.size(y_pred)
print("Test accuracy = {}".format(acc))

In [None]:
model.summary()

In [None]:
model.count_params()