In [None]:
import tensorflow as tf
import qresnet
import dataset
from tensorflow import keras
from keras import layers
from qkeras import *

In [None]:
# Load CIFAR-10 through the project-local `dataset` helper. The unpacking order
# (x_train, y_train, x_test, y_test) is fixed by this call site; preprocessing
# and label encoding happen inside dataset.load_cifar10 (not visible here).
x_train, y_train, x_test, y_test = dataset.load_cifar10()

# Cheap sanity checks so a broken loader fails here rather than deep inside fit().
assert len(x_train) == len(y_train), "train images/labels length mismatch"
assert len(x_test) == len(y_test), "test images/labels length mismatch"
assert x_train.shape[1:] == x_test.shape[1:], "train/test image shapes differ"

In [None]:
# Functional graph: image input -> QActivation(quantized_relu_po2) input
# quantizer with stochastic rounding -> project-local qresnet.resnet32 body
# built for num_classes=10.
# NOTE(review): the positional args (4, 1) follow qkeras' quantized_relu_po2
# signature (presumably bits, max_value) — confirm against the installed qkeras.
inputs = layers.Input(x_train.shape[1:], name="input")
input_quantizer = QActivation(quantized_relu_po2(4, 1, use_stochastic_rounding=True))
quantized_inputs = input_quantizer(inputs)
outputs = qresnet.resnet32(quantized_inputs, num_classes=10)

qmodel = keras.Model(inputs=[inputs], outputs=[outputs])
qmodel.summary()


In [None]:
class ResNetPaperLR(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Step-wise learning-rate schedule in the style of the ResNet paper.

    The rate starts at ``initial_lr`` and is divided by ``learning_rate_decay``
    every time the optimizer step reaches one of the boundaries in ``steps``.
    With the default two boundaries:

        step <  steps[0]             -> initial_lr
        steps[0] <= step < steps[1]  -> initial_lr / learning_rate_decay
        step >= steps[1]             -> initial_lr / learning_rate_decay**2

    Any number of boundaries is supported (an empty ``steps`` gives a constant
    rate). Boundaries are assumed to be in ascending order.

    Args:
        initial_lr: learning rate used before the first boundary.
        learning_rate_decay: divisor applied at every boundary (10 -> x0.1).
        steps: ascending optimizer-step boundaries. A fresh list is stored so
            the default is never shared or mutated across instances.
    """

    def __init__(self, initial_lr=0.1, learning_rate_decay=10, steps=(32000, 48000)):
        super().__init__()
        self.initial_lr = initial_lr
        self.learning_rate_decay = learning_rate_decay
        self.steps = list(steps)

    def __call__(self, step):
        step = tf.convert_to_tensor(step)
        # Rate for each interval, computed in Python float precision so the
        # values equal initial_lr / decay**k exactly before the float32 cast.
        rates = [
            self.initial_lr / (self.learning_rate_decay ** k)
            for k in range(len(self.steps) + 1)
        ]
        # Boundaries take the step's own dtype: Keras may pass int64 iterations
        # or a float32 step, and TF refuses mixed-dtype comparisons.
        boundaries = tf.constant(self.steps, dtype=step.dtype)
        # The number of boundaries already reached (step >= boundary) selects the
        # rate. expand_dims keeps this element-wise for non-scalar `step`.
        num_passed = tf.reduce_sum(
            tf.cast(tf.expand_dims(step, -1) >= boundaries, tf.int32), axis=-1
        )
        return tf.gather(tf.constant(rates, dtype=tf.float32), num_passed)

    def get_config(self):
        """Return the constructor arguments so Keras can re-create the schedule."""
        return {
            "initial_lr": self.initial_lr,
            "learning_rate_decay": self.learning_rate_decay,
            "steps": self.steps,
        }


# Training hyper-parameters consumed by the fit() cell.
# NOTE(review): 164 epochs x batch 128 presumably mirrors the ResNet paper's
# ~64k iterations on the 50k-image CIFAR-10 train split — confirm.
NB_EPOCH = 164
BATCH_SIZE = 128
VERBOSE = 1

# Adam driven by the step-wise schedule, starting from 1e-3 rather than the
# schedule's SGD-style default of 0.1.
# NOTE(review): `decay` is only accepted by the legacy (pre-2.11) tf.keras
# optimizers, where it further divides the scheduled rate by (1 + decay*step);
# newer tf.keras versions reject this argument — confirm the TF version in use.
lr_schedule = ResNetPaperLR(0.001)
OPTIMIZER = keras.optimizers.Adam(learning_rate=lr_schedule, decay=2.5e-5)

# CategoricalCrossentropy expects one-hot labels; from_logits=True assumes
# qresnet.resnet32 ends without a softmax — TODO confirm both assumptions.
loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.02)
qmodel.compile(
    optimizer=OPTIMIZER,
    loss=loss_fn,
    metrics=["accuracy"],
)


In [None]:
# Train for NB_EPOCH epochs. The test split doubles as validation data and is
# only evaluated every second epoch (validation_freq=2), so history["val_*"]
# holds roughly half as many entries as history["loss"].
history = qmodel.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=NB_EPOCH,
    initial_epoch=0,
    verbose=VERBOSE,
    validation_data=(x_test, y_test),
    validation_freq=2,
)


In [None]:
import matplotlib.pyplot as plt

# Must match the validation_freq passed to qmodel.fit(): Keras records val_*
# metrics only every VALIDATION_FREQ epochs (epochs VALIDATION_FREQ,
# 2*VALIDATION_FREQ, ...), so those curves need their own epoch axis.
VALIDATION_FREQ = 2


def plot_history(history, metric, validation_freq=VALIDATION_FREQ):
    """Plot the train and validation curves of `metric` against the 1-based epoch.

    Args:
        history: the History object returned by Model.fit().
        metric: base metric key, e.g. "loss" or "accuracy"; "val_" + metric
            must also be present in history.history.
        validation_freq: epoch stride between validation runs.
    """
    train_values = history.history[metric]
    val_values = history.history["val_" + metric]
    train_epochs = range(1, len(train_values) + 1)
    val_epochs = range(validation_freq, validation_freq * len(val_values) + 1, validation_freq)

    fig, ax = plt.subplots()
    ax.plot(train_epochs, train_values, label="train")
    ax.plot(val_epochs, val_values, label="validation")
    ax.set(title=f"ResNet-32 {metric}", xlabel="epoch", ylabel=metric)
    ax.legend()
    plt.show()


plot_history(history, "loss");
plot_history(history, "accuracy");

In [None]:
import os

from qkeras.utils import model_save_quantized_weights, load_qmodel

# All artifacts go under one directory. Create it up front instead of relying
# on the first (checkpoint-format) save having created it as a side effect:
# the HDF5 writers below do not create missing parent directories.
QMODEL_DIR = os.path.join("qmodels", "resnet32")
os.makedirs(QMODEL_DIR, exist_ok=True)

# Quantized weights are written in two formats. Without an extension Keras'
# save_weights defaults to the TensorFlow checkpoint format; with ".h5" it
# writes HDF5. Only the second call's returned dict is kept in `dic`.
model_save_quantized_weights(qmodel, os.path.join(QMODEL_DIR, "qmodel_weights"))
dic = model_save_quantized_weights(qmodel, os.path.join(QMODEL_DIR, "qmodel_weights.h5"))

# Save the full model, then reload it to check it round-trips. The saved
# optimizer config references the custom LR schedule, hence custom_objects.
qmodel.save(os.path.join(QMODEL_DIR, "model.h5"))
qmodel_load_test = load_qmodel(
    os.path.join(QMODEL_DIR, "model.h5"),
    custom_objects={"ResNetPaperLR": ResNetPaperLR},
)
qmodel_load_test.evaluate(x_test, y_test)


In [None]:
# `print_qstats` comes from the `from qkeras import *` star import; it prints
# QKeras' quantization statistics for `qmodel` (see the qkeras docs for the
# exact report format).
print_qstats(qmodel)