In [1]:
from functools import partial

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tensorflow.keras.applications import imagenet_utils
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

In [2]:
IMAGE_SIZE = 256
PATCH_SIZE = 4
EXPANSION_FACTOR = 2
BATCH_SIZE = 32
EPOCHS = 50
LABEL_SMOOTHING_FACTOR = 0.1

In [3]:
train_dataset, val_dataset = tfds.load('tf_flowers', split=['train[:90%]', 'train[90%:]'], as_supervised=True)

print('training: ', train_dataset)
print('validation: ', val_dataset)

training:  <PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
validation:  <PrefetchDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>


In [4]:
train_dataset.cardinality()

<tf.Tensor: shape=(), dtype=int64, numpy=3303>

In [5]:
val_dataset.cardinality()

<tf.Tensor: shape=(), dtype=int64, numpy=367>

In [6]:
LARGER_IMAGE_SIZE = 280

def random_jitter(image, label):
    image = tf.image.resize(image, [LARGER_IMAGE_SIZE, LARGER_IMAGE_SIZE])
    image = tf.image.random_crop(image, [IMAGE_SIZE, IMAGE_SIZE, 3])
    image = tf.image.random_flip_left_right(image)
    # image = tf.image.random_jpeg_quality(image, 80, 100)
    label = tf.one_hot(label, 5)
    return image, label


train_ds = train_dataset.map(random_jitter).batch(BATCH_SIZE)
print(train_ds)

<BatchDataset shapes: ((None, 256, 256, 3), (None, 5)), types: (tf.float32, tf.float32)>


In [7]:
def preprocess_image(image, label):
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    label = tf.one_hot(label, 5)
    return image, label


val_ds = val_dataset.map(preprocess_image).batch(BATCH_SIZE)
print(val_ds)

<BatchDataset shapes: ((None, 256, 256, 3), (None, 5)), types: (tf.float32, tf.float32)>


In [8]:
def inverted_residual(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    padding = 'same'
    if strides == 2:
        m = layers.ZeroPadding2D(padding=[1,1])(m)
        padding = 'valid'
    m = layers.DepthwiseConv2D(3, strides=strides, padding=padding, use_bias=False)(m)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    if tf.math.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m


def mlp(x, hidden_units, dropout_rate):
    for unit in hidden_units:
        x = layers.Dense(unit, activation=tf.nn.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(x, num_layers, projection, num_heads=2):
    norm = partial(layers.LayerNormalization, epsilon=1e-6)
    for _ in range(num_layers):
        x1 = norm()(x)
        attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection, dropout=0.1)(x1, x1)
        x2 = layers.Add()([attention, x])
        x3 = norm()(x)
        x3 = mlp(x3, [x.shape[-1]*2, x.shape[-1]], 0.1)
        x = layers.Add()([x3, x2])
    return x


def vit(x, num_blocks, projection, strides=1):
    local = layers.Conv2D(projection, 3, padding='same', strides=strides, activation=tf.nn.swish)(x)
    local = layers.Conv2D(projection, 1, padding='same', strides=strides, activation=tf.nn.swish)(local)
    num_patches = int((local.shape[1] * local.shape[2]) / PATCH_SIZE)
    patches = layers.Reshape((PATCH_SIZE, num_patches, projection))(local)
    global_features = transformer(patches, num_blocks, projection)
    folded = layers.Reshape((*local.shape[1:-1], projection))(global_features)
    folded = layers.Conv2D(x.shape[-1], 1, padding='same', strides=strides, activation=tf.nn.swish)(folded)
    features = layers.Concatenate(axis=-1)([x, folded])
    features = layers.Conv2D(projection, 3, padding='same', strides=strides, activation=tf.nn.swish)(features)
    return features


def create_mobilevit(output_size):
    inputs = layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, 3])
    
    x = layers.Conv2D(16, 3, strides=2, padding='same', activation=tf.nn.swish)(inputs)
    x = inverted_residual(x, 16 * EXPANSION_FACTOR, 16)
    
    x = inverted_residual(x, 16 * EXPANSION_FACTOR, 24, strides=2)
    x = inverted_residual(x, 24 * EXPANSION_FACTOR, 24)
    x = inverted_residual(x, 24 * EXPANSION_FACTOR, 24)
    
    x = inverted_residual(x, 24 * EXPANSION_FACTOR, 48, strides=2)
    x = vit(x, 2, 64)
    
    x = inverted_residual(x, 64 * EXPANSION_FACTOR, 64, strides=2)
    x = vit(x, 4, 80)

    x = inverted_residual(x, 80 * EXPANSION_FACTOR, 80, strides=2)
    x = vit(x, 3, 96)
    
    x = layers.Conv2D(320, 1, padding='same', strides=1, activation=tf.nn.swish)(x)
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(output_size, activation='softmax')(x)
    return keras.Model(inputs, outputs)
    

keras.backend.clear_session()
model = create_mobilevit(5)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 16) 448         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 32) 512         conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 32) 128         conv2d_1[0][0]                   
______________________________________________________________________________________________

In [9]:
class LogLearningRate(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        lr = self.model.optimizer.lr(self.model.optimizer.iterations)
        print(f'learning rate: {lr.numpy()}')


def run_experiment(model, epochs):
    checkpoint_path = '/tmp/checkpoint'
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_path, monitor='val_accuracy', save_best_only=True)
    log_lr_callback = LogLearningRate()
    
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=[checkpoint_callback, log_lr_callback])
    
    model.load_weights(checkpoint_path)
    _, accuracy = model.evaluate(val_ds)
    print(f'validation accuracy: {round(accuracy * 100, 2)}%')
    return model


keras.backend.clear_session()

loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTHING_FACTOR)

decay_steps = int(len(train_dataset) / BATCH_SIZE * EPOCHS)
lr = tf.keras.optimizers.schedules.CosineDecay(2e-3, decay_steps)
optimizer = keras.optimizers.Adam(lr)

model = create_mobilevit(5)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
model = run_experiment(model, EPOCHS)

Epoch 1/50
learning rate: 0.0020000000949949026








INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 2/50
learning rate: 0.0019979961216449738
Epoch 3/50
learning rate: 0.0019919921178370714
Epoch 4/50
learning rate: 0.0019820125307887793




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 5/50
learning rate: 0.0019680969417095184




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 6/50
learning rate: 0.0019503012299537659




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 7/50
learning rate: 0.0019286967581138015




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 8/50
learning rate: 0.0019033702556043863
Epoch 9/50
learning rate: 0.001874422887340188




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 10/50
learning rate: 0.0018419709522277117
Epoch 11/50
learning rate: 0.0018061446025967598
Epoch 12/50
learning rate: 0.0017670870292931795
Epoch 13/50
learning rate: 0.0017249551601707935




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 14/50
learning rate: 0.001679917797446251
Epoch 15/50
learning rate: 0.0016321552684530616
Epoch 16/50
learning rate: 0.001581858959980309
Epoch 17/50
learning rate: 0.0015292306197807193




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 18/50
learning rate: 0.0014744813088327646




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 19/50
learning rate: 0.001417830353602767
Epoch 20/50
learning rate: 0.0013595045311376452
Epoch 21/50
learning rate: 0.0012997378362342715
Epoch 22/50
learning rate: 0.0012387699680402875




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 23/50
learning rate: 0.0011768450494855642
Epoch 24/50
learning rate: 0.0011142113944515586
Epoch 25/50
learning rate: 0.0010511199943721294




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 26/50
learning rate: 0.000987823586910963
Epoch 27/50
learning rate: 0.000924576073884964




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 28/50
learning rate: 0.0008616308332420886
Epoch 29/50
learning rate: 0.0007992401951923966




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 30/50
learning rate: 0.0007376540452241898
Epoch 31/50
learning rate: 0.0006771195912733674
Epoch 32/50
learning rate: 0.0006178790354169905
Epoch 33/50
learning rate: 0.0005601702141575515
Epoch 34/50
learning rate: 0.0005042242119088769




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 35/50
learning rate: 0.0004502648371271789




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 36/50
learning rate: 0.00039850923349149525
Epoch 37/50
learning rate: 0.00034916415461339056




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 38/50
learning rate: 0.000302427593851462
Epoch 39/50
learning rate: 0.00025848689256235957
Epoch 40/50
learning rate: 0.00021751809981651604
Epoch 41/50
learning rate: 0.00017968547763302922
Epoch 42/50
learning rate: 0.00014514077338390052




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 43/50
learning rate: 0.00011402214295230806




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 44/50
learning rate: 8.645451453048736e-05
Epoch 45/50
learning rate: 6.254828622331843e-05
Epoch 46/50
learning rate: 4.239934787619859e-05
Epoch 47/50
learning rate: 2.608847717056051e-05
Epoch 48/50
learning rate: 1.3680875781574287e-05




INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


INFO:tensorflow:Assets written to: /tmp/checkpoint/assets


Epoch 49/50
learning rate: 5.226493158261292e-06
Epoch 50/50
learning rate: 7.590651875943877e-07
validation accuracy: 88.01%
