In [42]:
import os
os.environ["TFF_CPP_MIN_LOG_LEVEL"]="2"

In [43]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

In [44]:
#load and split the data
(ds_train,ds_test), ds_info=tfds.load(
    "mnist",
    split=["train","test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,

)

In [45]:
# normalize the data
def normalize_image(image,label):
    return tf.cast(image,tf.float32)/255.0,label

In [46]:
AUTOTUNE=tf.data.experimental.AUTOTUNE
BATCH_SIZE=128

In [47]:
#setup train dataset
ds_train=ds_train.map(normalize_image,num_parallel_calls=AUTOTUNE).cache()
ds_train=ds_train.shuffle(ds_info.splits["train"].num_examples).batch(BATCH_SIZE)
ds_train=ds_train.prefetch(AUTOTUNE)

In [48]:
#setup test dataset
ds_test=ds_test.map(normalize_image,num_parallel_calls=AUTOTUNE).cache()
ds_test=ds_test.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [49]:
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

In [50]:
save_callback=keras.callbacks.ModelCheckpoint(
    "checkpoints/",save_weights_only=True, monitor="train_acc",save_best_only=False,
)

In [51]:
lr_scheduler=keras.callbacks.ReduceLROnPlateau(
    monitor="loss", factor=0.1, patience=3, mode="max",verbose=1
)

In [52]:
class my_callback(keras.callbacks.Callback):
    def on_epoch_end(self,epoch,logs=None):
        if logs.get("accuracy")>.95:
            print("accuracy over 95 %, quitting trainning")
            self.model.stop_training=True
            

In [53]:
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=.001),
    metrics=["accuracy"]
)

In [54]:
model.fit(
    ds_train,
    epochs=10,
    callbacks=[save_callback,lr_scheduler,my_callback()],
    verbose=2,
)

Epoch 1/10


469/469 - 13s - loss: 0.2520 - accuracy: 0.9286 - lr: 0.0010 - 13s/epoch - 28ms/step
Epoch 2/10
accuracy over 95 %, quitting trainning
469/469 - 11s - loss: 0.0905 - accuracy: 0.9748 - lr: 0.0010 - 11s/epoch - 23ms/step


<keras.src.callbacks.History at 0x1fbbff8dad0>