In [1]:
# import tensorflow as tf
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

In [2]:
# `tf.test.is_built_with_cuda()` only reports whether this TensorFlow binary
# was compiled against CUDA; it does not tell you whether a GPU is actually
# visible at runtime, so report both.
print("Built with CUDA:", tf.test.is_built_with_cuda())
print("Visible GPUs:", tf.config.list_physical_devices('GPU'))

True


In [4]:
def prepare_tf_data(Train_size = 32, Test_size = 16, img_size = 224, random_ratio = 0.1,
                    shuffle_buffer = 10000, seed = None):
    """Load CIFAR-10 from TFDS and build batched, prefetched train/test pipelines.

    Args:
        Train_size: Batch size of the training pipeline (a batch size, not a
            dataset size, despite the name).
        Test_size: Batch size of the test pipeline.
        img_size: Images are resized to ``(img_size, img_size)``.
        random_ratio: Augmentation strength: brightness is shifted by a random
            delta bounded by ``random_ratio`` and contrast is scaled by a factor
            in ``[1 - random_ratio, 1 + random_ratio]``.
        shuffle_buffer: Number of training examples held in the shuffle buffer.
        seed: Optional seed for the training-set shuffle only; the augmentation
            ops are not seeded.

    Returns:
        ``(train_ds, test_ds)``: ``tf.data.Dataset`` objects yielding
        ``(images, labels)`` batches with float32 images in ``[0, 1]``.
        Only ``train_ds`` is shuffled and augmented.
    """
    train_ds, test_ds = tfds.load(
        'cifar10',
        split=['train', 'test'],
        as_supervised=True,  # yield (image, label) tuples instead of feature dicts
    )

    def preprocess(image, label):
        # Scale to [0, 1] before resizing so interpolation operates on floats.
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.image.resize(image, (img_size, img_size))
        return image, label

    def augment(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, random_ratio)
        image = tf.image.random_contrast(image, 1 - random_ratio, 1 + random_ratio)
        # Brightness/contrast jitter can push pixels outside [0, 1]; clip so the
        # augmented training inputs stay in the same range as the test inputs.
        image = tf.clip_by_value(image, 0.0, 1.0)
        return image, label

    AUTOTUNE = tf.data.AUTOTUNE

    # Shuffle the raw examples *before* preprocessing: the buffer then holds the
    # small un-cast originals instead of resized float32 tensors (at the default
    # img_size=224, 10000 such tensors would occupy roughly 6 GB).
    train_ds = (train_ds
                .shuffle(shuffle_buffer, seed=seed)
                .map(preprocess, num_parallel_calls=AUTOTUNE)
                .map(augment, num_parallel_calls=AUTOTUNE)
                .batch(Train_size)
                .prefetch(AUTOTUNE))

    test_ds = (test_ds
               .map(preprocess, num_parallel_calls=AUTOTUNE)
               .batch(Test_size)
               .prefetch(AUTOTUNE))

    return train_ds, test_ds

In [5]:
# img_size=32 keeps CIFAR-10 at its native 32x32 resolution instead of the
# helper's default of 224 (which would upsample every image).
train_ds, test_ds = prepare_tf_data(
    Train_size=64,  # training batch size
    Test_size=32,   # evaluation batch size
    img_size=32,
)

In [7]:
from models.resnet_cifar import ResNetCIFAR
import tensorflow_addons as tfa

In [10]:
# Build, compile and train the CIFAR-10 ResNet.
res_ci = ResNetCIFAR(num_classes=10)

# Training hyper-parameters (previously inline magic numbers).
initial_learning_rate = 1e-2
decay_steps = 5000   # optimizer steps between learning-rate drops
decay_rate = 0.75    # LR multiplier applied every `decay_steps` steps
weight_decay = 1e-4
epochs = 20
patience = 5         # epochs without a val_accuracy gain before early stopping

# Staircase exponential decay:
#   lr(step) = initial_learning_rate * decay_rate ** floor(step / decay_steps)
lr_schedule_res_ci = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    decay_rate=decay_rate,
    staircase=True,
)

# NOTE(review): per the TensorFlow Addons docs, AdamW's decoupled weight decay
# is not scaled by the learning-rate schedule, and they recommend decaying
# weight_decay manually alongside the LR. Left constant here to preserve the
# reported results — TODO evaluate.
res_ci.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=lr_schedule_res_ci,
        weight_decay=weight_decay,
    ),
    # from_logits=True assumes ResNetCIFAR emits unnormalized logits — TODO confirm.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# NOTE(review): `test_ds` doubles as validation data below, so checkpoint
# selection and early stopping are tuned on the test split and any final
# test-set accuracy is optimistically biased. TODO carve a validation split
# out of the training data instead.
callbacks_res_ci = [
    # Saves the model to 'res_ci_cifar10_best' whenever val_accuracy improves.
    tf.keras.callbacks.ModelCheckpoint(
        'res_ci_cifar10_best',
        save_best_only=True,
        monitor='val_accuracy',
    ),
    # NOTE(review): depending on the Keras version, best weights may only be
    # restored when early stopping actually triggers (not when all epochs run)
    # — TODO confirm for the installed version.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=patience,
        restore_best_weights=True,
    ),
    tf.keras.callbacks.TensorBoard(log_dir='./logs/res_ci_cifar10'),
]

history_res_ci = res_ci.fit(
    train_ds,
    validation_data=test_ds,
    epochs=epochs,
    callbacks=callbacks_res_ci,
)

Epoch 1/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 2/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 3/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 4/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 5/20
Epoch 6/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 7/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 8/20
Epoch 9/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 10/20
Epoch 11/20
Epoch 12/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 13/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


Epoch 19/20
Epoch 20/20



INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets


INFO:tensorflow:Assets written to: res_ci_cifar10_best\assets




In [11]:
# Final evaluation. The model was compiled with metrics=['accuracy'], so
# evaluate() yields [loss, accuracy].
# NOTE(review): test_ds was also passed as validation_data during fit(), where it
# drove checkpointing and early stopping, so this figure is optimistically biased.
eval_results = res_ci.evaluate(test_ds)
test_loss, test_accuracy = eval_results
print(f"Test accuracy: {test_accuracy:.4f}")

Test accuracy: 0.8671
