In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10

In [None]:
# CIFAR-10 comes back already partitioned into a train and a test split;
# unpack each split into (images, labels).
train_split, test_split = cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Rescale 8-bit pixel values into [0, 1] as float32.
# Guarded on dtype: the raw arrays from load_data are not float32 (presumably uint8,
# per the Keras CIFAR-10 docs), so a float32 array means this cell already ran and
# re-running it must not divide by 255 a second time.
if x_train.dtype != "float32":
    x_train = x_train.astype("float32") / 255.0
if x_test.dtype != "float32":
    x_test = x_test.astype("float32") / 255.0

In [None]:
# Seed TensorFlow's global RNG before any layer is built so that weight
# initialisation is repeatable across Restart & Run All.
SEED = 42
tf.random.set_seed(SEED)

# Small conv net: three Conv2D stages (the first two each followed by 2x2
# max-pooling), then a dense head. The last Dense(10) has no activation, so it
# emits raw logits; the loss in the compile cell uses from_logits=True to match.
model = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(128, 3, padding="same", activation="relu"),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dense(10),  # one logit per class
    ]
)



In [None]:
# Integer class labels + logits output -> sparse categorical cross-entropy with
# from_logits=True. Adam is configured via `learning_rate`: the old `lr` keyword
# is deprecated (it emitted a deprecation warning) and is not accepted by Keras 3.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=3e-4),
    metrics=["accuracy"],
)

  "The `lr` argument is deprecated, use `learning_rate` instead.")


In [None]:
# Train on the full training split, then score the held-out test split.
# NOTE(review): no validation data is passed to fit(), so any train/test gap
# (overfitting) only becomes visible in the evaluate() result below.
BATCH_SIZE = 64
history = model.fit(x_train, y_train, epochs=10, batch_size=BATCH_SIZE, verbose=2)

# [test loss, test accuracy]; kept as the last expression so the cell displays it.
test_metrics = model.evaluate(x_test, y_test, verbose=2, batch_size=BATCH_SIZE)
test_metrics

Epoch 1/10
782/782 - 103s - loss: 0.1918 - accuracy: 0.9364
Epoch 2/10
782/782 - 103s - loss: 0.1656 - accuracy: 0.9469
Epoch 3/10
782/782 - 103s - loss: 0.1471 - accuracy: 0.9525
Epoch 4/10
782/782 - 103s - loss: 0.1240 - accuracy: 0.9607
Epoch 5/10
782/782 - 103s - loss: 0.1088 - accuracy: 0.9654
Epoch 6/10
782/782 - 103s - loss: 0.0952 - accuracy: 0.9705
Epoch 7/10
782/782 - 103s - loss: 0.0860 - accuracy: 0.9738
Epoch 8/10
782/782 - 103s - loss: 0.0732 - accuracy: 0.9778
Epoch 9/10
782/782 - 103s - loss: 0.0688 - accuracy: 0.9783
Epoch 10/10
782/782 - 103s - loss: 0.0543 - accuracy: 0.9843
157/157 - 5s - loss: 1.5118 - accuracy: 0.7293


[1.5117930173873901, 0.7293000221252441]

In [None]:
# Persist the trained model; the .h5 extension selects Keras' HDF5 format.
# The next cell reloads it from this relative path.
model.save(filepath='model.h5')

In [None]:
# Reload the model written by the save cell above.
saved_model = keras.models.load_model('./model.h5')

# Round-trip check on a small sample: the reloaded model must reproduce the
# in-memory model's logits (requires `model` from the training cells above).
roundtrip_sample = x_test[:100]
tf.debugging.assert_near(
    saved_model.predict(roundtrip_sample),
    model.predict(roundtrip_sample),
    rtol=1e-5,
    atol=1e-5,
)

# Logits for the whole test split; display only the first few rows rather than
# dumping the full prediction array.
test_logits = saved_model.predict(x_test)
test_logits[:5]

array([[  9.108884 ,  14.680792 , -10.833942 ,  -4.6781697,  -8.805697 ,
         -3.8885143, -12.696902 , -21.26303  ,  30.844404 ,   3.6461673]],
      dtype=float32)