In [1]:
import cv2
from keras.callbacks import *
from keras.layers import *
from keras.models import *
from keras.preprocessing import image
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.utils import to_categorical
import sys
from tqdm import tqdm

Using TensorFlow backend.


In [2]:
batch_size=64
gene=4
width=150
height=80
num_class=100

In [4]:
datagen = image.ImageDataGenerator(featurewise_center=False,
                                    samplewise_center=False,
                                    featurewise_std_normalization=False,
                                    samplewise_std_normalization=False,
                                    zca_whitening=False,
                                    rotation_range=0.,
                                    width_shift_range=0.,
                                    height_shift_range=0.,
                                    shear_range=0.,
                                    zoom_range=0.,
                                    channel_shift_range=0.,
                                    fill_mode='nearest',
                                    cval=0.0,
                                    horizontal_flip=False,
                                    vertical_flip=False,
                                    rescale=None,
                                    preprocessing_function=None,
                                    data_format=K.image_data_format(),
                                    )

In [3]:
def read_data(flag, path='../../../../assets/brand_images'):
    """
    read the dataset，dir format must be: ./train/, ./train.txt, ./test/, ./test.txt
    :param flag: 'train' or 'test'
    :param path
    """
    content = open(os.path.join(path, '%s.txt' % flag))

    imgs = []
    labels = []

    lines = content.readlines()
    # tqdm is the progress bar, please install it with "pip install tqdm".
    # if u wan't, you can replace tqdm(lines) with lines.
    for i in tqdm(lines):
        fname, y = i.replace('\n', '').split(' ')
        y = int(y)

        x = os.path.join(path, flag, fname)

        imgs.append(x)
        labels.append(y)
    return np.array(imgs), np.array(labels) - 1

files, labels = read_data('train')

100%|██████████| 2725/2725 [00:00<00:00, 127039.37it/s]


In [5]:
def data_gen(batch_size=128, gene=4):
    """
    generate the data with the size batch_size
    :param batch_size:
    :param gene:
    :return:
    """
    X_ = np.zeros([batch_size, height, width, 3], np.uint8)
    y_ = np.zeros([batch_size, 1], np.uint8)

    while True:

        index = np.random.choice(len(labels), batch_size, replace=False)
        label_list = labels[index]
        file_list = files[index]

        for i in range(len(file_list)):
            fname = file_list[i]
            img = cv2.imread(fname)
            img = cv2.resize(img, (width, height))
            img = img / 255.0
            X_[i] = img
            y_[i] = np.array(label_list[i])

        i = 0
        X = np.zeros([batch_size * gene, height, width, 3], np.uint8)
        y = np.zeros([batch_size * gene, 1], np.uint8)
        for batch in datagen.flow(X_, y_, batch_size=batch_size):
            X[i * batch_size: (i + 1) * batch_size] = batch[0]
            y[i * batch_size: (i + 1) * batch_size] = batch[1]

            i += 1
            if i >= gene:
                break

        yield X, to_categorical(np.array(y), num_classes=100)

In [8]:
input_tensor = Input((height, width, 3))
# vgg = VGG16(weights='../../../../models/VGG16_WEIGHTS.h5', include_top=False, input_tensor=input_tensor)
# print('the last vgg layer is ', vgg.layers[-1])
# print('the last vgg layer output is ', vgg.output)

# tensor_shape = vgg.output.shape
# print(tensor_shape)

# rnn_length = tensor_shape[1].value
# rnn_dimen = tensor_shape[2].value * tensor_shape[3].value
# units = tensor_shape[3].value
# print(rnn_length, rnn_dimen, units)
x = input_tensor
for i in range(3):
    x = Conv2D(32 * 2 ** i, (3, 3), kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(32 * 2 ** 2, (3, 3), kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPool2D(pool_size=(2, 2))(x)

x = Flatten()(x)
x = Dense(128, kernel_initializer='he_normal')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Dropout(0.25)(x)
x = Dense(100, kernel_initializer='he_normal', activation='softmax')(x)
# print('now x\'s shape:', x.shape)

base_model = Model(input=input_tensor, output=x)
modellabels = Input(name='labels', shape=[num_class], dtype='float32')
model = Model(inputs=input_tensor, outputs=x)



In [9]:
# model.compile(loss='mean_squared_error', optimizer='adam')
model.compile(loss='mean_squared_logarithmic_error', optimizer='adam')

model.fit_generator(
    data_gen(batch_size, gene),
    steps_per_epoch=1,
    epochs=1,
    validation_data=data_gen(1, 1),
    validation_steps=1
)
base_model.save('brand_classify_vgg16.h5')
base_model.save_weights('brand_classify_vgg16_weights.h5')

Epoch 1/1


In [12]:
bb = 0
for i in range(2525):
    ff, ll = files[i], labels[i]
    ll = to_categorical(np.array(ll), num_classes=100)
    img = cv2.imread(ff)
    img = cv2.resize(img, (150, 80)) / 255.0
    img = np.expand_dims(img, 0)
    print('mean', np.mean(base_model.predict(img) - ll))
    print('mse', np.mean(np.square(base_model.predict(img) - ll)))
    bb += np.mean(np.square(base_model.predict(img) - ll))
print(bb / 2525.0)

('mean', 1.1350494e-10)
('mse', 0.016619993)
('mean', 8.105497e-11)
('mse', 0.018495856)
('mean', -4.7239154e-11)
('mse', 0.014919)
('mean', -1.3562385e-10)
('mse', 0.01354617)
('mean', 2.1787137e-09)
('mse', 0.015260234)
('mean', 7.94098e-10)
('mse', 0.016874498)
('mean', -1.254848e-09)
('mse', 0.01587939)
('mean', 5.994365e-10)
('mse', 0.0198724)
('mean', 1.3916462e-09)
('mse', 0.018332586)
('mean', -7.2866896e-10)
('mse', 0.01979855)
('mean', -1.6943522e-09)
('mse', 0.016448675)
('mean', 4.7941284e-10)
('mse', 0.0143881785)
('mean', -4.1244042e-10)
('mse', 0.01980546)
('mean', -1.2933742e-09)
('mse', 0.0178015)
('mean', 1.2188746e-09)
('mse', 0.014224026)
('mean', 2.840534e-10)
('mse', 0.016234748)
('mean', -1.700555e-09)
('mse', 0.014941922)
('mean', 9.690686e-10)
('mse', 0.01940785)
('mean', -3.050078e-10)
('mse', 0.014642507)
('mean', -1.8154722e-10)
('mse', 0.01723555)
('mean', 1.6681588e-09)
('mse', 0.018901134)
('mean', 1.1328075e-09)
('mse', 0.017725592)
('mean', -1.4700344e-

('mean', 3.1590403e-10)
('mse', 0.016307915)
('mean', 2.7319915e-09)
('mse', 0.018125772)
('mean', 3.8882716e-10)
('mse', 0.013011902)
('mean', -1.2246889e-09)
('mse', 0.016040074)
('mean', 4.0978193e-10)
('mse', 0.01955565)
('mean', -5.0029464e-10)
('mse', 0.019035159)
('mean', -4.638858e-10)
('mse', 0.018941557)
('mean', -7.916013e-11)
('mse', 0.016502626)
('mean', 6.472692e-10)
('mse', 0.0152791785)
('mean', 7.385097e-10)
('mse', 0.016285129)
('mean', -6.0290406e-10)
('mse', 0.01699071)
('mean', -1.1767043e-10)
('mse', 0.019536717)
('mean', 5.7298166e-11)
('mse', 0.019208003)
('mean', -2.2078166e-09)
('mse', 0.012923939)
('mean', -2.2378117e-11)
('mse', 0.018221602)
('mean', -9.1239544e-10)
('mse', 0.019423457)
('mean', -2.4301696e-11)
('mse', 0.019607672)
('mean', -7.3106904e-10)
('mse', 0.013603359)
('mean', 1.5950689e-09)
('mse', 0.013017731)
('mean', 4.7957655e-11)
('mse', 0.01510085)
('mean', -3.3527592e-10)
('mse', 0.0148847345)
('mean', -5.9932065e-10)
('mse', 0.014463136)
('

('mean', -3.1664665e-10)
('mse', 0.014113316)
('mean', -1.0604717e-09)
('mse', 0.014243874)
('mean', -3.9784936e-10)
('mse', 0.014450269)
('mean', -5.2140364e-10)
('mse', 0.012495451)
('mean', 8.568623e-10)
('mse', 0.016170293)
('mean', -1.6778355e-09)
('mse', 0.014364579)
('mean', -2.1470115e-10)
('mse', 0.015121379)
('mean', -1.2393685e-10)
('mse', 0.019893475)
('mean', -8.766074e-10)
('mse', 0.015812425)
('mean', -6.239854e-10)
('mse', 0.018883616)
('mean', 4.3723948e-11)
('mse', 0.017083792)
('mean', 7.537892e-11)
('mse', 0.015682295)
('mean', 5.7963917e-10)
('mse', 0.016945593)
('mean', -1.1532393e-10)
('mse', 0.014566005)
('mean', 1.0206713e-09)
('mse', 0.01750146)
('mean', -1.625085e-10)
('mse', 0.019906892)
('mean', 1.7043204e-09)
('mse', 0.014494338)
('mean', -4.970934e-10)
('mse', 0.016786676)
('mean', -8.614734e-11)
('mse', 0.017203027)
('mean', -5.0274157e-10)
('mse', 0.018227043)
('mean', -1.4691613e-09)
('mse', 0.016132835)
('mean', -5.966285e-11)
('mse', 0.018489515)
('m

('mean', -5.7212673e-10)
('mse', 0.014930708)
('mean', -2.0005972e-09)
('mse', 0.018618027)
('mean', -8.9449087e-10)
('mse', 0.019953167)
('mean', 5.2154064e-10)
('mse', 0.01436078)
('mean', 7.162043e-10)
('mse', 0.019947425)
('mean', -5.6112187e-10)
('mse', 0.015017795)
('mean', 5.7400357e-10)
('mse', 0.015545588)
('mean', 4.8778437e-10)
('mse', 0.014489109)
('mean', 1.2049713e-09)
('mse', 0.017616939)
('mean', -1.1349767e-09)
('mse', 0.015765311)
('mean', -8.2607476e-10)
('mse', 0.014908508)
('mean', 5.634502e-10)
('mse', 0.01798351)
('mean', -2.537849e-10)
('mse', 0.013465685)
('mean', -6.521259e-10)
('mse', 0.019788982)
('mean', 2.0023437e-10)
('mse', 0.015811654)
('mean', 3.651803e-10)
('mse', 0.014244313)
('mean', 4.3946785e-10)
('mse', 0.015990239)
('mean', -1.025619e-09)
('mse', 0.015642975)
('mean', 6.00121e-10)
('mse', 0.015334056)
('mean', 4.4819948e-10)
('mse', 0.013351356)
('mean', -1.3236422e-09)
('mse', 0.014170574)
('mean', 2.0532752e-10)
('mse', 0.018985622)
('mean', 2

('mean', -6.529444e-10)
('mse', 0.01821311)
('mean', 1.5149272e-09)
('mse', 0.019096918)
('mean', -2.4709151e-10)
('mse', 0.01368405)
('mean', -2.1893323e-10)
('mse', 0.018185737)
('mean', 1.8673085e-09)
('mse', 0.01788206)
('mean', 5.873153e-10)
('mse', 0.013948207)
('mean', 1.8362699e-11)
('mse', 0.019578647)
('mean', -5.9375455e-10)
('mse', 0.01737711)
('mean', 7.5029677e-10)
('mse', 0.015713131)
('mean', 1.7185812e-10)
('mse', 0.013890629)
('mean', -1.1688096e-09)
('mse', 0.01733893)
('mean', 6.2267647e-10)
('mse', 0.015200447)
('mean', 5.9604643e-10)
('mse', 0.014478934)
('mean', -5.820766e-10)
('mse', 0.01807682)
('mean', -8.009374e-10)
('mse', 0.01660311)
('mean', -1.810258e-10)
('mse', 0.014692781)
('mean', 6.705524e-10)
('mse', 0.012437427)
('mean', -1.1888911e-09)
('mse', 0.012882736)
('mean', -8.3167834e-10)
('mse', 0.016117275)
('mean', 2.6142574e-10)
('mse', 0.015881145)
('mean', 8.8475643e-11)
('mse', 0.01694494)
('mean', -1.0337681e-09)
('mse', 0.01538088)
('mean', 4.659

('mean', -2.5843894e-10)
('mse', 0.014116512)
('mean', 1.1082745e-09)
('mse', 0.01740745)
('mean', -4.5867637e-10)
('mse', 0.01814954)
('mean', 1.832043e-09)
('mse', 0.019641679)
('mean', 5.0653964e-10)
('mse', 0.017630715)
('mean', 1.527369e-09)
('mse', 0.017483335)
('mean', 1.3371573e-09)
('mse', 0.010982514)
('mean', -1.862645e-11)
('mse', 0.018701807)
('mean', 3.9495263e-10)
('mse', 0.018032856)
('mean', -2.9976935e-10)
('mse', 0.014823666)
('mean', -1.0929896e-09)
('mse', 0.014743167)
('mean', 1.4430497e-10)
('mse', 0.017893359)
('mean', -7.238732e-10)
('mse', 0.012507933)
('mean', 4.1443854e-10)
('mse', 0.010391109)
('mean', -1.9324936e-10)
('mse', 0.015599982)
('mean', -1.3038515e-10)
('mse', 0.018504215)
('mean', -4.958923e-10)
('mse', 0.019978873)
('mean', -2.7868778e-09)
('mse', 0.017068552)
('mean', -1.9133906e-09)
('mse', 0.0122295525)
('mean', -1.2386486e-09)
('mse', 0.01608985)
('mean', 3.7417228e-10)
('mse', 0.014172072)
('mean', -7.373455e-10)
('mse', 0.015430188)
('mea

('mean', 2.8858266e-10)
('mse', 0.019566756)
('mean', -1.151202e-09)
('mse', 0.014304228)
('mean', 5.1716142e-11)
('mse', 0.019049548)
('mean', 4.1933534e-10)
('mse', 0.015371228)
('mean', 6.106984e-10)
('mse', 0.01499717)
('mean', 3.4404365e-10)
('mse', 0.014737474)
('mean', -6.1001626e-10)
('mse', 0.014281525)
('mean', 1.1734664e-09)
('mse', 0.013368561)
('mean', 7.811521e-10)
('mse', 0.01293413)
('mean', 8.009388e-10)
('mse', 0.01452381)
('mean', 1.1560769e-09)
('mse', 0.017637052)
('mean', -3.873515e-10)
('mse', 0.018370114)
('mean', -1.8626451e-11)
('mse', 0.015063056)
('mean', -6.3617334e-10)
('mse', 0.016542086)
('mean', 7.0838724e-10)
('mse', 0.017059615)
('mean', -1.6752151e-09)
('mse', 0.014699484)
('mean', 3.0937372e-10)
('mse', 0.019368358)
('mean', -1.2006786e-09)
('mse', 0.0166929)
('mean', -3.0500813e-10)
('mse', 0.014982269)
('mean', -1.6390913e-09)
('mse', 0.018765872)
('mean', -1.6124608e-09)
('mse', 0.019699002)
('mean', 1.706776e-15)
('mse', 0.013824779)
('mean', 7.

('mean', -1.5017576e-10)
('mse', 0.013338856)
('mean', -4.294816e-10)
('mse', 0.019250352)
('mean', -9.138421e-10)
('mse', 0.0144470325)
('mean', -3.2086972e-10)
('mse', 0.019480612)
('mean', -1.9965228e-10)
('mse', 0.015048545)
('mean', 9.691575e-11)
('mse', 0.018979313)
('mean', -2.577144e-10)
('mse', 0.017292207)
('mean', 2.9693184e-10)
('mse', 0.017505758)
('mean', -1.458102e-10)
('mse', 0.017697144)
('mean', 9.1331404e-10)
('mse', 0.019468987)
('mean', -2.8748135e-09)
('mse', 0.01628879)
('mean', -1.2212502e-09)
('mse', 0.017310455)
('mean', -1.694425e-09)
('mse', 0.018104034)
('mean', 1.3205136e-09)
('mse', 0.015037427)
('mean', 1.1079101e-09)
('mse', 0.019546295)
('mean', -1.710655e-09)
('mse', 0.016616955)
('mean', 3.5390257e-10)
('mse', 0.01334216)
('mean', -6.0880895e-10)
('mse', 0.019838518)
('mean', -1.2340023e-10)
('mse', 0.015999472)
('mean', -8.9465174e-10)
('mse', 0.01487229)
('mean', 7.3275025e-10)
('mse', 0.019977879)
('mean', 1.2274549e-09)
('mse', 0.014599838)
('mea

('mean', -2.7168403e-09)
('mse', 0.014260662)
('mean', 1.1007117e-09)
('mse', 0.019078713)
('mean', 5.838228e-10)
('mse', 0.01580362)
('mean', 3.8882716e-10)
('mse', 0.014199441)
('mean', 6.522055e-11)
('mse', 0.014789233)
('mean', 9.555879e-10)
('mse', 0.019494949)
('mean', 4.938922e-10)
('mse', 0.0154491365)
('mean', -1.1897646e-09)
('mse', 0.013145447)
('mean', 6.542541e-10)
('mse', 0.015460462)
('mean', -4.305184e-10)
('mse', 0.018133437)
('mean', -8.3491e-12)
('mse', 0.019433238)
('mean', 8.0560075e-10)
('mse', 0.014782578)
('mean', 9.924406e-10)
('mse', 0.014511555)
('mean', -1.1865632e-09)
('mse', 0.01598584)
('mean', 4.5227366e-10)
('mse', 0.014717934)
('mean', 3.370651e-10)
('mse', 0.01488046)
('mean', -2.724109e-10)
('mse', 0.018050363)
('mean', 1.7565253e-09)
('mse', 0.01807525)
('mean', -1.1061638e-09)
('mse', 0.015015567)
('mean', 2.444727e-11)
('mse', 0.014185095)
('mean', 7.5531714e-10)
('mse', 0.0162731)
('mean', -4.636238e-10)
('mse', 0.014139687)
('mean', -2.1885654e-

('mean', 8.7632546e-10)
('mse', 0.018878248)
('mean', -1.5832484e-10)
('mse', 0.015282666)
('mean', 8.9406965e-10)
('mse', 0.016821979)
('mean', 8.335337e-10)
('mse', 0.014595184)
('mean', -1.2923556e-09)
('mse', 0.017694548)
('mean', -5.7742e-10)
('mse', 0.014922806)
('mean', 2.242451e-09)
('mse', 0.019957483)
('mean', 5.104812e-10)
('mse', 0.019649614)
('mean', 5.089532e-10)
('mse', 0.018453822)
('mean', -4.9345544e-10)
('mse', 0.018885856)
('mean', -2.4592738e-11)
('mse', 0.013299609)
('mean', 3.4167896e-10)
('mse', 0.01854313)
('mean', -9.2415575e-10)
('mse', 0.016013639)
('mean', -4.775823e-10)
('mse', 0.012859813)
('mean', 4.782487e-10)
('mse', 0.017674392)
('mean', 1.6349077e-09)
('mse', 0.016255554)
('mean', -7.543713e-10)
('mse', 0.013185504)
('mean', 1.6542617e-09)
('mse', 0.017698996)
('mean', 4.2302417e-10)
('mse', 0.014972254)
('mean', 3.7172868e-10)
('mse', 0.016826587)
('mean', 6.9849196e-11)
('mse', 0.01676505)
('mean', -5.715992e-10)
('mse', 0.018655274)
('mean', -1.04

('mean', -1.4202665e-10)
('mse', 0.0019386705)
('mean', -8.242202e-10)
('mse', 0.005825473)
('mean', 7.1089745e-10)
('mse', 2.4134748e-05)
('mean', -5.353286e-11)
('mse', 1.4658898e-05)
('mean', -2.1752296e-09)
('mse', 0.019463478)
('mean', 3.3898687e-10)
('mse', 0.0039571477)
('mean', -7.0082024e-10)
('mse', 0.0151104955)
('mean', -7.9569873e-10)
('mse', 0.0020803513)
('mean', -4.8875337e-10)
('mse', 0.003184601)
('mean', 1.8533319e-09)
('mse', 0.013292551)
('mean', -4.4892635e-11)
('mse', 0.019926121)
('mean', 8.741699e-10)
('mse', 0.005440458)
('mean', 6.984919e-12)
('mse', 0.014812023)
('mean', -2.4259499e-09)
('mse', 0.015548107)
('mean', 3.8870895e-10)
('mse', 0.0025888188)
('mean', -1.1117937e-09)
('mse', 0.01396475)
('mean', -9.685754e-10)
('mse', 0.00030706407)
('mean', -8.1956386e-10)
('mse', 0.014969592)
('mean', -1.5832483e-10)
('mse', 0.004943136)
('mean', -8.7195073e-10)
('mse', 0.018125493)
('mean', 6.625487e-10)
('mse', 0.016269965)
('mean', 7.9980966e-10)
('mse', 1.464

('mean', -2.5043845e-10)
('mse', 0.015436001)
('mean', -7.439425e-10)
('mse', 0.01872127)
('mean', -1.023692e-09)
('mse', 0.01587198)
('mean', 6.9303495e-12)
('mse', 0.018857036)
('mean', -1.0904663e-09)
('mse', 0.019933248)
('mean', 2.916204e-10)
('mse', 0.019387359)
('mean', -1.1631982e-09)
('mse', 0.019612798)
('mean', -3.6132405e-10)
('mse', 0.018145833)
('mean', 3.8999137e-11)
('mse', 0.015118853)
('mean', 1.5774276e-10)
('mse', 0.016130624)
('mean', -5.5821153e-10)
('mse', 0.017130122)
('mean', 5.835318e-11)
('mse', 0.013619485)
('mean', 4.6071363e-10)
('mse', 0.01739086)
('mean', 6.5996825e-10)
('mse', 0.012754088)
('mean', -7.513108e-11)
('mse', 0.01807154)
('mean', 1.2279315e-10)
('mse', 0.01778803)
('mean', -8.3076884e-10)
('mse', 0.014578946)
('mean', -7.8693835e-10)
('mse', 0.013522703)
('mean', -7.2792317e-10)
('mse', 0.019690804)
('mean', -3.4835707e-10)
('mse', 0.011574198)
('mean', -3.2440767e-10)
('mse', 0.016367191)
('mean', -8.2945917e-10)
('mse', 0.014597897)
('mean

('mean', -1.3899989e-09)
('mse', 0.017947603)
('mean', 4.1254683e-10)
('mse', 0.014427848)
('mean', -1.3621829e-09)
('mse', 0.0154391145)
('mean', 2.5122426e-09)
('mse', 0.019971376)
('mean', -1.1526572e-09)
('mse', 0.017757267)
('mean', 1.466833e-10)
('mse', 0.017516425)
('mean', -1.013582e-09)
('mse', 0.018058289)
('mean', 2.1471351e-10)
('mse', 0.014535653)
('mean', 7.4506075e-11)
('mse', 0.015390949)
('mean', 1.1155686e-09)
('mse', 0.013909609)
('mean', 3.592868e-10)
('mse', 0.013485337)
('mean', -1.2665987e-09)
('mse', 0.013189272)
('mean', 6.530916e-10)
('mse', 0.014149462)
('mean', 2.3655602e-09)
('mse', 0.018543977)
('mean', 1.340559e-09)
('mse', 0.018926183)
('mean', 1.9904018e-09)
('mse', 0.016672907)
('mean', 4.147297e-12)
('mse', 0.0128596695)
('mean', 4.8501186e-11)
('mse', 0.019584112)
('mean', -3.0326183e-10)
('mse', 0.015141875)
('mean', 8.8510205e-10)
('mse', 0.014733903)
('mean', -2.2086169e-10)
('mse', 0.0149822915)
('mean', 3.1257513e-10)
('mse', 0.016489051)
('mean

('mean', -1.185108e-09)
('mse', 0.017016413)
('mean', -2.3646857e-11)
('mse', 0.014085828)
('mean', 5.118534e-10)
('mse', 0.019177558)
('mean', 7.722747e-10)
('mse', 0.014242602)
('mean', 7.2979554e-16)
('mse', 0.014142368)
('mean', 4.1050957e-10)
('mse', 0.016132733)
('mean', 1.3387764e-11)
('mse', 0.017569054)
('mean', 1.8755864e-09)
('mse', 0.015196465)
('mean', 1.5791738e-09)
('mse', 0.014906399)
('mean', -4.858839e-10)
('mse', 0.017777449)
('mean', 6.396342e-10)
('mse', 0.014882567)
('mean', 5.107313e-10)
('mse', 0.016296677)
('mean', 7.026404e-10)
('mse', 0.017995423)
('mean', -1.9864275e-09)
('mse', 0.015973438)
('mean', 2.5152075e-10)
('mse', 0.01926252)
('mean', -8.981169e-10)
('mse', 0.018738167)
('mean', -6.7375366e-10)
('mse', 0.013402258)
('mean', 3.4429878e-10)
('mse', 0.014768911)
('mean', 9.51841e-10)
('mse', 0.017092172)
('mean', 5.0989957e-10)
('mse', 0.016625708)
('mean', 4.6885534e-10)
('mse', 0.01875921)
('mean', 1.2215092e-09)
('mse', 0.012894421)
('mean', 3.86283

In [150]:
ff, ll = files[2525], labels[2525]
img = cv2.imread(ff)
img = cv2.resize(img, (150, 80)) / 255.0
img = np.expand_dims(img, 0)
img.shape

(1, 80, 150, 3)

In [151]:
np.mean(base_model.predict(img) - ll)

-93.99

In [152]:
np.mean(np.square(base_model.predict(img) - ll), axis=-1)

array([8834.126], dtype=float32)

In [163]:
ll, np.argmax(base_model.predict(img), axis=1)

(94, array([27]))

In [154]:
index = np.random.choice(len(labels), batch_size, replace=False)
label_list = labels[index]
file_list = files[index]
label_list, batch_size

(array([86, 51,  5,  7, 28, 51, 21, 15, 77, 48, 47, 30, 94, 50, 51, 53, 86,
        67, 15,  4, 43, 12, 94, 17, 14, 58, 40, 24,  7, 99, 71, 29, 32, 47,
        16, 61, 83, 96, 13, 81, 56, 31, 50, 80, 71, 10, 44,  9, 45, 16, 32,
        26, 58, 25, 43, 63,  8, 75,  5, 68, 42, 44, 81, 36]), 64)

In [157]:
lbl = to_categorical(np.array(labels), num_classes=100)

In [162]:
base_model.predict(img)

array([[9.48159468e-11, 3.82158105e-05, 9.07421963e-06, 3.96392039e-13,
        3.58873876e-05, 4.00542532e-09, 1.06933681e-11, 4.73708004e-15,
        2.17511328e-07, 4.54494439e-05, 3.79323873e-12, 1.14529866e-14,
        5.20818867e-04, 4.96327957e-08, 6.74605047e-08, 1.48367016e-02,
        2.47269677e-10, 2.38397566e-04, 4.09140138e-10, 1.03598531e-07,
        1.56483479e-10, 2.67315636e-06, 5.91897697e-06, 1.33341438e-09,
        1.35822082e-03, 2.35824762e-11, 1.73823489e-03, 7.36553252e-01,
        2.11590964e-06, 1.32346045e-07, 2.21699872e-03, 3.77253957e-11,
        4.56198445e-03, 2.09035843e-05, 6.48990861e-10, 8.90399213e-11,
        7.61320844e-05, 3.96524230e-10, 8.84075746e-08, 4.75062663e-03,
        2.62610658e-12, 1.58525324e-07, 1.39562950e-09, 1.50020223e-05,
        6.74512776e-05, 3.01604478e-14, 7.28327564e-07, 2.85222984e-10,
        4.29088622e-02, 8.20120447e-12, 6.02269247e-05, 2.00682243e-05,
        2.37111263e-02, 5.10784835e-02, 8.41704160e-02, 1.095841

In [8]:
model = load_model('brand_classify_vgg16.h5')



In [22]:
x, y = next(data_gen(batch_size=200, gene=1))

In [23]:
np.mean(np.argmax(model.predict(x), axis=-1) == np.argmax(y, axis=-1))

0.955