In [1]:
# import tensorflow as tf
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

In [4]:
def prepare_tf_data(Train_size = 32, Test_size = 16, img_size = 224, random_ratio = 0.1):
    """Build batched CIFAR-10 training and test input pipelines.

    Both splits are loaded through ``tfds.load``, converted to float32 in
    [0, 1] and resized to ``img_size`` x ``img_size``. The training split is
    additionally shuffled and randomly augmented (horizontal flip, brightness
    and contrast jitter).

    Args:
        Train_size: Batch size of the training pipeline.
        Test_size: Batch size of the test pipeline.
        img_size: Side length, in pixels, every image is resized to.
        random_ratio: Augmentation strength: maximum brightness delta, and
            half-width of the contrast factor range centred on 1.

    Returns:
        A ``(train_ds, test_ds)`` tuple of ``tf.data.Dataset`` objects that
        yield ``(image, label)`` batches.
    """
    # The dataset metadata was never used, so it is no longer requested.
    train_ds, test_ds = tfds.load(
        'cifar10',
        split=['train', 'test'],
        as_supervised=True,
    )

    def preprocess(image, label):
        # Scale pixel values to [0, 1], then resize to the requested resolution.
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.resize(image, (img_size, img_size))
        return image, label

    def augment(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, random_ratio)
        image = tf.image.random_contrast(image, 1 - random_ratio, 1 + random_ratio)
        # Brightness/contrast jitter can push values outside [0, 1]; clip so the
        # training inputs stay in the same range as the (un-augmented) test inputs.
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image, label

    AUTOTUNE = tf.data.AUTOTUNE

    # Shuffle the raw records *before* the resize/float conversion: the shuffle
    # buffer holds 10000 elements, and buffering resized float32 images instead
    # of the original small ones multiplies its memory footprint needlessly.
    train_ds = (train_ds
                .shuffle(10000)
                .map(preprocess, num_parallel_calls=AUTOTUNE)
                .map(augment, num_parallel_calls=AUTOTUNE)
                .batch(Train_size)
                .prefetch(AUTOTUNE))

    # No shuffling or augmentation for evaluation data.
    test_ds = (test_ds
               .map(preprocess, num_parallel_calls=AUTOTUNE)
               .batch(Test_size)
               .prefetch(AUTOTUNE))

    return train_ds, test_ds

In [5]:
# Build the input pipelines at 32x32 instead of the function's default of 224;
# batch sizes keep their defaults (32 for training, 16 for test).
train_ds, test_ds = prepare_tf_data(img_size=32)

In [13]:
from models.densenet import DenseNet
import tensorflow_addons as tfa

In [14]:
# Resume fine-tuning from the best checkpoint written by a previous run.
# (To train from scratch instead, use: den = DenseNet(num_classes=10))
den = tf.keras.models.load_model('./den_cifar10_best')

initial_learning_rate = 5e-5
decay_steps = 5000

# Staircase exponential decay: the learning rate is multiplied by 0.9 once
# every `decay_steps` optimizer steps.
lr_schedule_den = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=0.9,
    staircase=True
)

den.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=lr_schedule_den,
        weight_decay=0.0001
    ),
    # NOTE(review): from_logits=True assumes the model's final layer emits raw
    # logits (no softmax) -- confirm against models.densenet.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

# Score the restored model before training. The checkpoint callback below
# writes to the same directory the model was loaded from and starts with no
# "best" value, so without a threshold the first epoch would always overwrite
# the previous best checkpoint -- even when it is worse.
baseline_metrics = den.evaluate(test_ds, return_dict=True)

callbacks_den = [
    tf.keras.callbacks.ModelCheckpoint(
        'den_cifar10_best',
        save_best_only=True,
        monitor='val_accuracy',
        # Only overwrite the checkpoint when an epoch beats the restored model.
        initial_value_threshold=baseline_metrics['accuracy']
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(log_dir='./logs/den_cifar10')
]

# NOTE(review): the test split doubles as validation data, so checkpoint
# selection and early stopping are tuned on it; the resulting val_accuracy is
# an optimistic estimate. Consider holding out part of the training split.
history_den = den.fit(
    train_ds,
    validation_data=test_ds,
    epochs=10,
    callbacks=callbacks_den
)

Epoch 1/10



INFO:tensorflow:Assets written to: den_cifar10_best\assets


INFO:tensorflow:Assets written to: den_cifar10_best\assets


Epoch 2/10
Epoch 3/10
Epoch 4/10



INFO:tensorflow:Assets written to: den_cifar10_best\assets


INFO:tensorflow:Assets written to: den_cifar10_best\assets


Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
