In [1]:
%config Completer.use_jedi = False

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, MaxPooling2D, Conv2D, Flatten, InputLayer
from tensorflow.keras.datasets import cifar10
from sklearn.metrics import confusion_matrix

In [3]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

### CIFAR 10 classes     airplane : 0 ,    automobile : 1  ,   bird : 2  ,   cat : 3  ,   deer : 4  ,   dog : 5  ,  frog : 6   ,  horse : 7  ,   ship : 8   ,  truck : 9

In [4]:
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

(50000, 32, 32, 3)
(50000, 1)
(10000, 32, 32, 3)
(10000, 1)


In [5]:
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

## As we are using CNN, we don't need to flatten the input images before giving into the model

In [6]:
print(x_train.shape)
print(x_test.shape)

(50000, 32, 32, 3)
(10000, 32, 32, 3)


In [9]:
model = Sequential()
model.add(InputLayer(input_shape=(32,32,3)))
model.add(Conv2D(32, kernel_size=3, padding='valid', activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(64, kernel_size=3, activation='relu'))
model.add(MaxPooling2D())
model.add(Conv2D(128, kernel_size=3, activation='relu'))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dense(10))

In [10]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_1 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 4, 4, 128)         73856     
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 64)               

### Keras Model Traiing APIs - https://keras.io/api/models/model_training_apis/

In [12]:
model.compile(
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer = keras.optimizers.Adam(),
    metrics= ['accuracy']
)

In [13]:
model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=2)

Epoch 1/10
782/782 - 17s - loss: 1.6078 - accuracy: 0.4102
Epoch 2/10
782/782 - 8s - loss: 1.2410 - accuracy: 0.5572
Epoch 3/10
782/782 - 8s - loss: 1.0805 - accuracy: 0.6193
Epoch 4/10
782/782 - 9s - loss: 0.9687 - accuracy: 0.6582
Epoch 5/10
782/782 - 8s - loss: 0.8788 - accuracy: 0.6929
Epoch 6/10
782/782 - 8s - loss: 0.8171 - accuracy: 0.7175
Epoch 7/10
782/782 - 8s - loss: 0.7581 - accuracy: 0.7374
Epoch 8/10
782/782 - 7s - loss: 0.7049 - accuracy: 0.7555
Epoch 9/10
782/782 - 10s - loss: 0.6638 - accuracy: 0.7699
Epoch 10/10
782/782 - 9s - loss: 0.6162 - accuracy: 0.7874


<tensorflow.python.keras.callbacks.History at 0x2048486ab88>

In [17]:
model.evaluate(x_test, y_test, batch_size=64, verbose=2)

157/157 - 1s - loss: 0.8156 - accuracy: 0.7254


[0.8156260848045349, 0.7253999710083008]

In [18]:
y_pred = model.predict(x_test)

In [19]:
y_pred

array([[ -0.75745624,  -3.634709  ,  -2.3077433 , ...,  -3.1414852 ,
         -0.67020977,  -2.8515463 ],
       [  1.2681766 ,   3.064496  ,  -6.172635  , ..., -10.04764   ,
          8.146093  ,   1.1396742 ],
       [  0.1684213 ,   0.9035431 ,  -4.0576334 , ...,  -3.1569579 ,
          1.0653434 ,   1.6354463 ],
       ...,
       [ -7.06685   ,  -9.600877  ,   2.7736156 , ...,   1.9572964 ,
         -5.0661    ,  -4.9726586 ],
       [  0.565251  ,  -0.3054345 ,  -3.7934618 , ...,  -1.1840146 ,
         -5.5216546 ,  -4.631081  ],
       [ -4.1383905 ,  -5.221046  ,   0.5248078 , ...,   9.768485  ,
         -8.706483  ,  -7.1981635 ]], dtype=float32)