In [None]:
from dataset import load_mnist, load_cifar10
import tensorflow as tf
from resnet import resnet18, resnet34

In [None]:
import tensorflow as tf
import numpy as np



class WarmUpAndCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: linear warm-up, then half-cosine decay to zero.

    For optimizer step ``s``:

    * ``s < warmup_steps``: ``lr = max_learning_rate * s / warmup_steps``
    * otherwise: ``lr = 0.5 * max_learning_rate * (1 + cos(pi * p))`` with
      ``p = (s - warmup_steps) / (total_steps - warmup_steps)`` clipped to ``[0, 1]``.

    Clipping ``p`` holds the rate at 0 once ``s >= total_steps``. Without the clip,
    the cosine wraps past pi and the rate climbs back up.

    Args:
        max_learning_rate: Peak learning rate, reached at the end of warm-up.
        warmup_steps: Length of the linear warm-up, in optimizer steps (batches, not epochs).
        total_steps: Optimizer step at which the cosine decay reaches zero.
    """

    def __init__(self, max_learning_rate, warmup_steps, total_steps):
        super().__init__()
        self.max_learning_rate = max_learning_rate
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        # Keras passes the optimizer's iteration counter, which can be an integer tensor.
        # Cast everything to float32 so the mixed float/int arithmetic below is well-typed.
        step = tf.cast(step, tf.float32)
        max_lr = tf.cast(self.max_learning_rate, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)

        # tf.where evaluates both branches. Guard the denominators so that
        # warmup_steps == 0 or total_steps == warmup_steps cannot produce inf/NaN.
        linear_warmup = max_lr * step / tf.maximum(warmup_steps, 1.0)
        decay_progress = tf.clip_by_value(
            (step - warmup_steps) / tf.maximum(total_steps - warmup_steps, 1.0),
            0.0,
            1.0,
        )
        cosine_decay = 0.5 * max_lr * (1.0 + tf.cos(np.pi * decay_progress))
        return tf.where(step < warmup_steps, linear_warmup, cosine_decay)

    def get_config(self):
        """Return constructor arguments so the schedule can be serialized with the model."""
        return {
            "max_learning_rate": self.max_learning_rate,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps
        }
    


In [None]:
# Training configuration. WarmUpAndCosineDecay counts optimizer steps (batches), not
# epochs, so the epoch-based warm-up/decay lengths are converted to steps here.
# Keep BATCH_SIZE / EPOCHS in sync with the model.fit call further down.
BATCH_SIZE = 256
EPOCHS = 50
WARMUP_EPOCHS = 5
# NOTE(review): assumes load_cifar10 returns the full 50,000-image CIFAR-10 train split.
NUM_TRAIN_EXAMPLES = 50_000
# Ceil division: the final partial batch of each epoch is also an optimizer step.
STEPS_PER_EPOCH = -(-NUM_TRAIN_EXAMPLES // BATCH_SIZE)

_input = tf.keras.layers.Input(shape=(32, 32, 3))
# NOTE(review): presumably resnet18 returns raw logits (no softmax), which matches
# from_logits=True below. Confirm in resnet.py.
_output = resnet18(_input, num_classes=10)

model = tf.keras.models.Model(inputs=_input, outputs=_output)
lr_schedule = WarmUpAndCosineDecay(
    max_learning_rate=0.001,
    warmup_steps=WARMUP_EPOCHS * STEPS_PER_EPOCH,
    total_steps=EPOCHS * STEPS_PER_EPOCH,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(
    optimizer,
    # NOTE(review): label_smoothing=0.8 is unusually strong for 10 classes (the true-class
    # target becomes 0.2 + 0.8/10 = 0.28). Typical values are ~0.1; confirm this is intended.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.8),
    metrics=['accuracy'],
)


model.summary()
# This runs before training, so it saves the compiled but untrained model.
model.save("resnet18.h5")

In [None]:
# Load the data through the project's dataset helper, unpacked as train images/labels
# followed by test images/labels.
# NOTE(review): CategoricalCrossentropy in the model cell expects one-hot labels. Confirm
# load_cifar10 returns one-hot y_train / y_test; otherwise use SparseCategoricalCrossentropy.
(x_train, y_train,
 x_test, y_test) = load_cifar10()

In [None]:
# Train the model. Keep batch_size / epochs consistent with the step counts given to the
# learning-rate schedule in the model cell; that schedule is defined in optimizer steps.
# The test split doubles as validation data here.
history = model.fit(
    x_train,
    y_train,
    batch_size=256,
    epochs=50,
    validation_data=(x_test, y_test),
    validation_batch_size=16,
)

# Persist the trained weights. The save in the model cell runs before fit, so it only
# stores the untrained model; this overwrites it with the trained one.
model.save("resnet18.h5")

