In [None]:
import tensorflow as tf
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras import Sequential
from tensorflow.keras.losses import SparseCategoricalCrossentropy

In [None]:
from fbnet.blocks import get_super_net, Block
from fbnet.model import FBNet, Trainer
from fbnet.lookup_table import read as read_lookup_table, get_lookup_table

In [None]:
# Load the CIFAR-10 train and test splits (NumPy image/label arrays) via Keras.
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = (
    cifar10.load_data()
)

In [None]:
# Convert images to float32 and divide by 255 (i.e. scale into [0, 1],
# assuming the raw pixels are 0..255 integers).
x_train, x_test = (
    images.astype('float32') / 255.0 for images in (train_images, test_images)
)

# Index splitting the training set: examples before it train the super-net
# weights, examples from it onward train the architecture thetas.
WEIGHTS_FRACTION = 0.8
split_at = int(x_train.shape[0] * WEIGHTS_FRACTION)
split_at

In [None]:
# Shuffled, batched pipeline over the first `split_at` training examples;
# consumed by trainer.train_weights during the search.
train_weights_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train[:split_at], train_labels[:split_at]))
    .shuffle(buffer_size=1024)
    .batch(128)
)
train_weights_dataset

In [None]:
# Shuffled, batched pipeline over the held-out remainder of the training set
# (from `split_at` onward); consumed by trainer.train_thetas during the search.
train_thetas_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train[split_at:], train_labels[split_at:]))
    .shuffle(buffer_size=1024)
    .batch(128)
)
train_thetas_dataset

In [None]:
# Evaluation pipeline over the test split: batched, no shuffling.
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_test, test_labels))
    .batch(128)
)
test_dataset

In [None]:
# Shuffled, batched pipeline over the *entire* training set; used later to
# train the sampled architecture.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, train_labels))
    .shuffle(buffer_size=1024)
    .batch(128)
)
train_dataset

In [None]:
# Build the search-space "super net" for the 10 CIFAR-10 classes.
# NOTE(review): bn=True presumably enables batch normalization, and 'ss'
# presumably lists per-stage strides — confirm in fbnet.blocks.
super_net_config = {'ss': [1, 1, 2, 2, 1, 1, 1, 1, 1]}
super_net = get_super_net(num_classes=10, bn=True, config=super_net_config)

In [None]:
# Load a precomputed lookup table from JSON (presumably per-block latencies
# measured on the target device — TODO confirm against fbnet.lookup_table).
LOOKUP_TABLE_PATH = 'lookup_table_pi.json'
lookup_table = read_lookup_table(LOOKUP_TABLE_PATH)

In [None]:
# Alternative to loading the JSON file above: build the lookup table directly
# from super_net (uncomment to regenerate it instead of reading it from disk).
# lookup_table = get_lookup_table(super_net)

In [None]:
# Combine the super net with its lookup table into the searchable FBNet model.
fbnet = FBNet(
    super_net,
    lookup_table,
)

In [None]:
# Search hyperparameters, grouped by what they (presumably, judging by name)
# control — see fbnet.model.Trainer for the authoritative semantics.
trainer_hparams = dict(
    # temperature schedule
    initial_temperature=5,
    temperature_decay_rate=0.956,
    temperature_decay_steps=1,
    # latency term of the search loss
    latency_alpha=0.2,
    latency_beta=0.6,
    # optimizer for the super-net weights
    weight_lr=0.01,
    weight_momentum=0.9,
    weight_decay=1e-4,
    # optimizer for the architecture thetas
    theta_lr=1e-3,
    theta_beta1=0.9,
    theta_beta2=0.999,
    theta_decay=5e-4,
)

trainer = Trainer(fbnet, input_shape=(None, 32, 32, 3), **trainer_hparams)

In [None]:
# Architecture search: every epoch trains the super-net weights; once
# trainer.epoch reaches THETA_WARMUP_EPOCHS, each epoch also trains the
# architecture thetas. Re-running this cell resumes from trainer.epoch.
NUM_SEARCH_EPOCHS = 90
THETA_WARMUP_EPOCHS = 10   # thetas are trained only when trainer.epoch >= this
LOG_EVERY = 100            # print running metrics every N steps
CHECKPOINT_EVERY = 10      # save weights whenever trainer.epoch is a multiple of this
CHECKPOINT_DIR = 'drive/My Drive/fbnet/checkpoints'


def log_training_metrics(trainer, phase, step):
    """Print the trainer's running accuracy and mean loss for one phase."""
    print(
        'training {} step {}: accuracy = {}, mean loss = {}'
        .format(phase, step, trainer.training_accuracy, trainer.training_loss)
    )


def run_training_pass(trainer, train_step, dataset, phase):
    """Apply ``train_step`` to every batch of ``dataset``, then reset metrics.

    Logs running metrics every LOG_EVERY steps and once more at the end of
    the pass, then calls ``trainer.reset_metrics()``.

    Args:
        trainer: the search Trainer whose metrics are reported and reset.
        train_step: callable taking ``(x_batch, y_batch)``
            (trainer.train_weights or trainer.train_thetas).
        dataset: iterable of ``(x_batch, y_batch)`` pairs.
        phase: label used in log lines, e.g. 'weights' or 'thetas'.

    Returns:
        Number of batches processed (0 if ``dataset`` was empty).
    """
    step = 0  # defined even if the dataset yields nothing
    # tqdm wraps the dataset (not enumerate) so it can read len() for a total.
    for step, (x_batch, y_batch) in enumerate(tqdm(dataset, desc=phase), start=1):
        train_step(x_batch, y_batch)
        if step % LOG_EVERY == 0:
            log_training_metrics(trainer, phase, step)
    log_training_metrics(trainer, phase, step)
    trainer.reset_metrics()
    return step


def evaluate_and_report(trainer, dataset):
    """Evaluate ``trainer`` on ``dataset``, print and return the accuracy."""
    accuracy = trainer.evaluate(tqdm(dataset, desc='test'))
    print('test accuracy: {}'.format(accuracy))
    return accuracy


for epoch in tqdm(range(trainer.epoch, NUM_SEARCH_EPOCHS), desc='epochs'):
    print('Start of epoch %d' % (epoch,))

    run_training_pass(trainer, trainer.train_weights, train_weights_dataset, 'weights')
    test_accuracy = evaluate_and_report(trainer, test_dataset)

    if trainer.epoch >= THETA_WARMUP_EPOCHS:
        run_training_pass(trainer, trainer.train_thetas, train_thetas_dataset, 'thetas')
        test_accuracy = evaluate_and_report(trainer, test_dataset)

    trainer.epoch += 1
    if trainer.epoch % CHECKPOINT_EVERY == 0:
        # float() makes ':.4f' work whether evaluate returns a Python float,
        # a NumPy scalar, or a scalar tensor.
        trainer.save_weights(
            '{}/checkpoints_epoch_{}_accuracy_{:.4f}'
            .format(CHECKPOINT_DIR, trainer.epoch, float(test_accuracy))
        )

In [None]:
# save checkpoints manually at any point: uncomment and replace 'PATH' with the
# destination (the search loop above also saves periodically on its own)
# trainer.save_weights('PATH')

In [None]:
# Display the trainer's current temperature. Given the Trainer arguments it
# presumably starts at initial_temperature and decays as training proceeds —
# confirm in fbnet.model.
current_temperature = trainer.temperature
current_temperature

In [None]:
# Print every model variable whose name contains 'theta' (the architecture
# parameters being searched).
theta_variables = [w for w in trainer.fbnet.weights if 'theta' in w.name]
for theta in theta_variables:
    print(theta)

In [None]:
# Sample one concrete architecture (presumably drawn according to the learned
# thetas) and rebuild it as a standalone Keras Sequential model.
seq_config = trainer.sample_sequential_config()
sampled_fbnet = Sequential.from_config(
    seq_config,
    # Block is a project-defined layer, so Keras must be told how to build it.
    custom_objects={'Block': Block},
)

In [None]:
# List the sampled architecture's layers, one name per line.
layer_names = [layer.name for layer in sampled_fbnet.layers]
for name in layer_names:
    print(name)

In [None]:
# from_logits=True assumes the model's final layer emits raw logits (no
# softmax) — TODO confirm against fbnet.blocks.
loss_fn = SparseCategoricalCrossentropy(from_logits=True)
sampled_fbnet.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

In [None]:
# Train the sampled architecture on the full training set.
# NOTE(review): the test split doubles as validation data here, so the
# validation metrics are not an unbiased held-out estimate.
FINAL_TRAIN_EPOCHS = 30
history = sampled_fbnet.fit(
    train_dataset,
    epochs=FINAL_TRAIN_EPOCHS,
    validation_data=test_dataset,
)

In [None]:
# Build a TFLite converter directly from the in-memory Keras model.
converter = tf.lite.TFLiteConverter.from_keras_model(
    sampled_fbnet
)

In [None]:
# Run the conversion; the result is the serialized TFLite model as bytes.
tflite_model = (
    converter.convert()
)

In [None]:
# Persist the converted model. Make sure the output directory exists, and use a
# context manager so the file is flushed and closed even if the write fails.
TFLITE_DIR = 'data'
TFLITE_PATH = TFLITE_DIR + '/fbnet.tflite'
tf.io.gfile.makedirs(TFLITE_DIR)
with open(TFLITE_PATH, 'wb') as tflite_file:
    bytes_written = tflite_file.write(tflite_model)
bytes_written

In [None]:
# Load the converted model into a TFLite interpreter from the in-memory bytes.
# To load the file written above instead, use:
#   interpreter = tf.lite.Interpreter(model_path='data/fbnet.tflite')
interpreter = tf.lite.Interpreter(
    model_content=tflite_model,
)

In [None]:
# Tensors must be allocated before any input can be set or the model invoked.
interpreter.allocate_tensors()

In [None]:
# Metadata (tensor index, shape, dtype, ...) for the model's inputs and outputs.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

In [None]:
# Run one training image through the TFLite model. Slicing [0:1] keeps the
# leading batch dimension, so the input has batch size 1.
input_index = input_details[0]['index']
output_index = output_details[0]['index']

sample_batch = x_train[0:1]
interpreter.set_tensor(input_index, sample_batch)
interpreter.invoke()
output_data = interpreter.get_tensor(output_index)

In [None]:
# TFLite output for the first training image.
(output_data)

In [None]:
# Keras output for the same image, to compare against the TFLite result above.
sampled_fbnet.predict(x_train[:1])

In [None]:
# Ground-truth label of that same first training image.
train_labels[:1]