In [1]:
from keras import layers
from keras import models

Using TensorFlow backend.


In [2]:
def build_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPool2D((2,2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPool2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model 

In [3]:
model = build_model()
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten_1 (Flatten)          (None, 576)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                36928     
__________

In [4]:
from keras.datasets import mnist
from keras.utils import to_categorical
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

import numpy as np
def add_dim_1 (shape):
    return np.concatenate( [np.array(shape), np.array((1,))] )
tis = add_dim_1(train_images.shape)
print(tis)
train_images = train_images.reshape(add_dim_1(train_images.shape))
test_images = test_images.reshape(add_dim_1(test_images.shape))
print(train_images.shape)
print(test_images.shape)
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
print(train_labels.shape, test_labels.shape)

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)
[60000    28    28     1]
(60000, 28, 28, 1)
(10000, 28, 28, 1)
(60000, 10) (10000, 10)


In [5]:
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5, batch_size=64)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f50aa5e3978>

In [6]:
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(test_loss, test_acc)

14.461155032348632 0.1028


In [7]:
import matplotlib.pyplot as plt
idx = 12
print(test_labels[idx])
print(test_images[idx].shape)
print(add_dim_1(test_images[idx].shape))
plt.imshow(test_images[idx].reshape(test_images[idx].shape[0:2]), cmap='gray')
pred_ar = test_images[idx:idx+1]
print(pred_ar.shape)
print(pred_ar[0].shape)
plt.figure()
plt.imshow(pred_ar[0].reshape(pred_ar[0].shape[0:2]), cmap='gray')
pred_probs = model.predict(pred_ar)
print(pred_probs)
print(np.argmax(pred_probs))

[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
(28, 28, 1)
[28 28  1  1]
(1, 28, 28, 1)
(28, 28, 1)
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
7


In [21]:
# kaggle competitions download -c dogs-vs-cats
import os, shutil
original_dataset_dir = '/home/ktdiedrich/.kaggle/competitions/dogs-vs-cats'
base_dir = '/home/ktdiedrich/Documents/localData/machineLearning/dogs-vs-cats_small'
if not os.path.exists(base_dir):
    os.makedirs(base_dir, exist_ok=True)
train_dir = os.path.join(base_dir, 'train')
test_dir = os.path.join(base_dir, 'test')
train_cats_dir = os.path.join
validation_dir = os.path.join(base_dir, 'validation')
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
test_cats_dir = os.path.join(test_dir, 'cats')
test_dogs_dir = os.path.join(test_dir, 'dogs')

dir_paths = [train_cats_dir, train_dogs_dir, 
        validation_cats_dir, validation_dogs_dir, test_cats_dir, test_dogs_dir]

for dir_path in dir_paths: 
    if not os.path.exists(dir_path):
        os.makedirs(dir_path, exist_ok=True)
import cv2
def copy_files(base_name, destination, begin, end, source=original_dataset_dir):
    fnames = ['{}.{}.jpg'.format(base_name, i) for i in range(begin, end)]
    for fname in fnames:
        src = os.path.join(source, fname)
        img = cv2.imread(src)
        # print(img.shape)
        img = cv2.resize(img, (150, 150))
        dst = os.path.join(destination, fname)
        cv2.imwrite(dst, img)
        #shutil.copyfile(src, dst)
        
data_set_dir = os.path.join(original_dataset_dir, 'train')
copy_files(base_name='cat', destination=train_cats_dir, begin=0, end=1000, source=data_set_dir)
copy_files(base_name='dog', destination=train_dogs_dir, begin=0, end=1000, source=data_set_dir)

copy_files(base_name='cat', destination=validation_cats_dir, begin=1000, end=1500, source=data_set_dir)
copy_files(base_name='dog', destination=validation_dogs_dir, begin=1000, end=1500, source=data_set_dir)

copy_files(base_name='cat', destination=test_cats_dir, begin=1500, end=2000, source=data_set_dir)
copy_files(base_name='dog', destination=test_dogs_dir, begin=1500, end=2000, source=data_set_dir)

for dir_path in dir_paths:
    print("{}: {}".format(dir_path, len(os.listdir(dir_path))))

(374, 500, 3)
(280, 300, 3)
(396, 312, 3)
(414, 500, 3)
(375, 499, 3)
(144, 175, 3)
(303, 400, 3)
(499, 495, 3)
(345, 461, 3)
(425, 320, 3)
(499, 489, 3)
(410, 431, 3)
(224, 300, 3)
(315, 499, 3)
(267, 320, 3)
(353, 405, 3)
(258, 448, 3)
(375, 499, 3)
(374, 500, 3)
(223, 320, 3)
(374, 500, 3)
(499, 431, 3)
(345, 500, 3)
(256, 334, 3)
(374, 500, 3)
(500, 345, 3)
(374, 500, 3)
(479, 370, 3)
(270, 286, 3)
(375, 499, 3)
(262, 349, 3)
(374, 500, 3)
(374, 500, 3)
(375, 499, 3)
(499, 375, 3)
(426, 499, 3)
(311, 500, 3)
(337, 499, 3)
(337, 350, 3)
(500, 374, 3)
(383, 499, 3)
(499, 333, 3)
(173, 237, 3)
(500, 356, 3)
(102, 107, 3)
(370, 500, 3)
(433, 400, 3)
(214, 258, 3)
(93, 139, 3)
(129, 180, 3)
(196, 299, 3)
(473, 256, 3)
(300, 399, 3)
(457, 492, 3)
(499, 500, 3)
(500, 344, 3)
(332, 499, 3)
(291, 335, 3)
(375, 499, 3)
(331, 464, 3)
(176, 180, 3)
(226, 328, 3)
(377, 500, 3)
(374, 500, 3)
(374, 500, 3)
(375, 499, 3)
(382, 499, 3)
(374, 500, 3)
(181, 249, 3)
(374, 500, 3)
(375, 499, 3)
(417, 4

(497, 500, 3)
(374, 500, 3)
(346, 258, 3)
(245, 367, 3)
(374, 500, 3)
(359, 480, 3)
(199, 75, 3)
(325, 265, 3)
(333, 499, 3)
(374, 500, 3)
(374, 500, 3)
(375, 499, 3)
(332, 499, 3)
(374, 500, 3)
(359, 480, 3)
(180, 111, 3)
(466, 349, 3)
(375, 499, 3)
(375, 499, 3)
(375, 499, 3)
(476, 499, 3)
(373, 298, 3)
(297, 350, 3)
(302, 350, 3)
(500, 369, 3)
(396, 500, 3)
(499, 483, 3)
(140, 145, 3)
(190, 249, 3)
(375, 499, 3)
(373, 499, 3)
(299, 187, 3)
(248, 288, 3)
(198, 239, 3)
(453, 500, 3)
(378, 500, 3)
(177, 200, 3)
(499, 375, 3)
(222, 350, 3)
(332, 499, 3)
(256, 287, 3)
(499, 446, 3)
(278, 300, 3)
(375, 499, 3)
(494, 499, 3)
(500, 407, 3)
(375, 499, 3)
(315, 399, 3)
(500, 383, 3)
(383, 499, 3)
(375, 499, 3)
(240, 319, 3)
(273, 350, 3)
(374, 500, 3)
(500, 201, 3)
(382, 499, 3)
(375, 499, 3)
(333, 499, 3)
(228, 300, 3)
(375, 499, 3)
(366, 283, 3)
(149, 200, 3)
(175, 348, 3)
(374, 500, 3)
(500, 349, 3)
(295, 500, 3)
(455, 368, 3)
(475, 425, 3)
(375, 499, 3)
(269, 300, 3)
(180, 499, 3)
(330, 4

(359, 480, 3)
(240, 226, 3)
(356, 478, 3)
(299, 400, 3)
(288, 359, 3)
(439, 440, 3)
(338, 300, 3)
(300, 399, 3)
(499, 432, 3)
(471, 382, 3)
(431, 362, 3)
(89, 120, 3)
(499, 315, 3)
(391, 500, 3)
(199, 178, 3)
(357, 370, 3)
(466, 499, 3)
(398, 499, 3)
(408, 365, 3)
(374, 500, 3)
(375, 499, 3)
(231, 404, 3)
(250, 350, 3)
(499, 500, 3)
(253, 200, 3)
(375, 499, 3)
(375, 499, 3)
(359, 336, 3)
(500, 471, 3)
(179, 240, 3)
(270, 246, 3)
(399, 320, 3)
(224, 300, 3)
(466, 499, 3)
(500, 486, 3)
(247, 429, 3)
(499, 443, 3)
(374, 500, 3)
(400, 326, 3)
(275, 199, 3)
(500, 489, 3)
(240, 319, 3)
(186, 287, 3)
(247, 299, 3)
(380, 499, 3)
(191, 199, 3)
(375, 499, 3)
(500, 396, 3)
(397, 399, 3)
(215, 286, 3)
(481, 400, 3)
(375, 499, 3)
(374, 500, 3)
(427, 432, 3)
(270, 275, 3)
(374, 500, 3)
(374, 500, 3)
(375, 499, 3)
(500, 272, 3)
(499, 477, 3)
(375, 499, 3)
(374, 500, 3)
(472, 489, 3)
(500, 392, 3)
(371, 499, 3)
(452, 397, 3)
(382, 421, 3)
(377, 499, 3)
(232, 350, 3)
(240, 319, 3)
(367, 500, 3)
(375, 4

(337, 499, 3)
(374, 500, 3)
(375, 499, 3)
(99, 90, 3)
(411, 499, 3)
(374, 500, 3)
(400, 322, 3)
(426, 400, 3)
(495, 500, 3)
(402, 500, 3)
(286, 430, 3)
(499, 302, 3)
(500, 374, 3)
(319, 240, 3)
(374, 500, 3)
(378, 499, 3)
(187, 200, 3)
(375, 499, 3)
(375, 499, 3)
(246, 230, 3)
(375, 499, 3)
(316, 450, 3)
(496, 499, 3)
(375, 499, 3)
(374, 500, 3)
(375, 499, 3)
(365, 350, 3)
(499, 375, 3)
(373, 500, 3)
(375, 499, 3)
(499, 267, 3)
(475, 474, 3)
(255, 499, 3)
(374, 500, 3)
(499, 500, 3)
(384, 332, 3)
(431, 500, 3)
(459, 500, 3)
(419, 500, 3)
(434, 407, 3)
(221, 214, 3)
(389, 499, 3)
(395, 247, 3)
(200, 146, 3)
(343, 499, 3)
(374, 500, 3)
(331, 500, 3)
(500, 310, 3)
(499, 375, 3)
(375, 499, 3)
(74, 100, 3)
(372, 400, 3)
(258, 329, 3)
(374, 500, 3)
(331, 500, 3)
(375, 499, 3)
(262, 349, 3)
(181, 249, 3)
(499, 494, 3)
(500, 483, 3)
(149, 150, 3)
(499, 336, 3)
(407, 379, 3)
(269, 259, 3)
(375, 499, 3)
(333, 500, 3)
(375, 499, 3)
(374, 500, 3)
(368, 328, 3)
(400, 353, 3)
(374, 500, 3)
(273, 500

(460, 429, 3)
(360, 479, 3)
(329, 375, 3)
(324, 480, 3)
(333, 499, 3)
(303, 499, 3)
(307, 500, 3)
(280, 308, 3)
(483, 349, 3)
(500, 464, 3)
(492, 425, 3)
(373, 500, 3)
(374, 500, 3)
(374, 500, 3)
(462, 399, 3)
(499, 367, 3)
(289, 360, 3)
(375, 499, 3)
(413, 500, 3)
(422, 499, 3)
(374, 500, 3)
(374, 500, 3)
(224, 300, 3)
(299, 300, 3)
(421, 421, 3)
(375, 499, 3)
(329, 499, 3)
(499, 363, 3)
(375, 499, 3)
(212, 288, 3)
(500, 414, 3)
(500, 423, 3)
(374, 500, 3)
(499, 432, 3)
(179, 180, 3)
(225, 299, 3)
(499, 415, 3)
(299, 293, 3)
(426, 467, 3)
(374, 500, 3)
(418, 500, 3)
(399, 387, 3)
(373, 500, 3)
(216, 286, 3)
(459, 500, 3)
(200, 174, 3)
(368, 198, 3)
(375, 499, 3)
(374, 500, 3)
(380, 499, 3)
(500, 499, 3)
(500, 360, 3)
(393, 311, 3)
(499, 360, 3)
(338, 449, 3)
(374, 500, 3)
(225, 299, 3)
(242, 322, 3)
(499, 492, 3)
(375, 499, 3)
(200, 174, 3)
(221, 318, 3)
(268, 288, 3)
(499, 392, 3)
(374, 500, 3)
(375, 499, 3)
(375, 499, 3)
(499, 375, 3)
(375, 499, 3)
(199, 319, 3)
(491, 363, 3)
(470, 

(374, 500, 3)
(394, 499, 3)
(499, 412, 3)
(205, 229, 3)
(499, 492, 3)
(375, 499, 3)
(298, 236, 3)
(228, 300, 3)
(409, 499, 3)
(499, 247, 3)
(187, 215, 3)
(499, 375, 3)
(402, 369, 3)
(407, 500, 3)
(337, 327, 3)
(374, 499, 3)
(374, 500, 3)
(374, 500, 3)
(470, 499, 3)
(375, 499, 3)
(220, 208, 3)
(375, 499, 3)
(298, 500, 3)
(230, 306, 3)
(474, 350, 3)
(225, 299, 3)
(197, 265, 3)
(375, 499, 3)
(288, 194, 3)
(333, 324, 3)
(500, 373, 3)
(375, 499, 3)
(364, 500, 3)
(239, 206, 3)
(477, 367, 3)
(299, 400, 3)
(479, 381, 3)
(448, 321, 3)
(374, 500, 3)
(319, 278, 3)
(314, 329, 3)
(240, 319, 3)
(374, 500, 3)
(199, 165, 3)
(375, 499, 3)
(499, 375, 3)
(261, 350, 3)
(270, 266, 3)
(334, 499, 3)
(499, 473, 3)
(301, 250, 3)
(322, 312, 3)
(500, 215, 3)
(392, 380, 3)
(400, 399, 3)
(110, 99, 3)
(239, 240, 3)
(375, 499, 3)
(359, 480, 3)
(500, 374, 3)
(258, 240, 3)
(375, 499, 3)
(500, 385, 3)
(311, 332, 3)
(200, 239, 3)
(448, 229, 3)
(234, 249, 3)
(375, 299, 3)
(226, 342, 3)
(374, 500, 3)
(240, 319, 3)
(262, 3

(499, 424, 3)
(500, 450, 3)
(351, 480, 3)
(420, 289, 3)
(375, 499, 3)
(326, 267, 3)
(200, 167, 3)
(249, 250, 3)
(437, 500, 3)
(375, 499, 3)
(499, 399, 3)
(375, 499, 3)
(337, 449, 3)
(370, 275, 3)
(285, 249, 3)
(203, 250, 3)
(500, 356, 3)
(386, 350, 3)
(374, 500, 3)
(217, 173, 3)
(374, 500, 3)
(399, 400, 3)
(499, 276, 3)
(287, 311, 3)
(499, 494, 3)
(225, 311, 3)
(299, 283, 3)
(373, 500, 3)
(242, 500, 3)
(239, 319, 3)
(352, 500, 3)
(499, 375, 3)
(299, 224, 3)
(499, 487, 3)
(374, 500, 3)
(401, 500, 3)
(499, 429, 3)
(500, 344, 3)
(346, 303, 3)
(499, 367, 3)
(212, 320, 3)
(288, 386, 3)
(225, 299, 3)
(499, 457, 3)
(374, 500, 3)
(215, 288, 3)
(333, 499, 3)
(267, 394, 3)
(191, 191, 3)
(427, 419, 3)
(374, 500, 3)
(232, 231, 3)
(122, 160, 3)
(373, 499, 3)
(330, 500, 3)
(374, 500, 3)
(499, 479, 3)
(374, 500, 3)
(499, 366, 3)
(211, 400, 3)
(375, 499, 3)
(375, 499, 3)
(250, 249, 3)
(457, 499, 3)
(334, 500, 3)
(499, 417, 3)
(300, 270, 3)
(299, 400, 3)
(375, 499, 3)
(329, 250, 3)
(374, 500, 3)
(265, 