In [None]:
from pathlib import Path

import tensorflow as tf
from keras import Sequential
from tqdm.notebook import tqdm
from vit_keras import vit

from ViT.utils import split_functional_model

In [None]:
def save_tflite(model, name):
    """Convert a Keras model to TFLite and write it to ``models/head/<name>.tflite``.

    Args:
        model: Keras model passed to ``tf.lite.TFLiteConverter.from_keras_model``.
        name: Identifier used as the file stem; it is converted with ``str()``,
            so both integers and strings are accepted.

    The ``models/head`` directory is created when it does not exist yet, so the
    function also works on a fresh checkout instead of raising
    ``FileNotFoundError``.
    """
    # Convert the model.
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    # Save the model.
    output_dir = Path("models/head")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / (str(name) + ".tflite")).write_bytes(tflite_model)

In [None]:
def igelu(x):
    """Polynomial stand-in for GELU built only from elementwise TensorFlow ops.

    With ``u = x / sqrt(2)`` and ``s = tanh(1000 * u)`` the result is

        0.5 * x * (1 + s * (a * (min(u * s, -b) + b) ** 2 + 1))

    where ``a = -0.2888`` and ``b = -1.769``. The steep ``tanh`` is close to
    ``sign(u)`` away from zero, which makes ``u * s`` close to ``|u|``.
    """
    import math
    coeff_a = -0.2888
    coeff_b = -1.769

    scaled = x / math.sqrt(2)
    smooth_sign = tf.math.tanh(1000 * scaled)
    # Smooth |u|, capped at -b so the squared term is zero beyond the cap.
    clipped_abs = tf.math.minimum(scaled * smooth_sign, -coeff_b)
    polynomial = coeff_a * (clipped_abs + coeff_b) ** 2 + 1
    return 0.5 * x * (1 + (smooth_sign * polynomial))

In [None]:
# Build the model returned by vit_keras' `vit_b16` factory, using its default arguments.
model = vit.vit_b16()

In [None]:
def replace_gelu(model):
    """Return a new ``Sequential`` mirroring ``model`` with Lambda layers swapped.

    Every layer of ``model`` that is a ``tf.keras.layers.Lambda`` is replaced by
    a fresh ``Lambda`` wrapping ``igelu``; all other layers are reused as the
    same objects (not copied), in their original order.
    """
    rebuilt = Sequential()
    for original in model.layers:
        is_lambda = isinstance(original, tf.keras.layers.Lambda)
        rebuilt.add(tf.keras.layers.Lambda(igelu) if is_lambda else original)
    return rebuilt

In [None]:
# Swap the Lambda activation inside each encoder block's MLP for the
# polynomial GELU approximation; all other layers are left untouched.
for layer in model.layers:
    if "encoderblock_" not in layer.name:
        continue
    layer.mlpblock = replace_gelu(layer.mlpblock)

In [None]:
# forward pass is needed for weight initialization (also sets batch size to 1)
# NOTE(review): the batch-size claim above is not verifiable from this cell alone;
# the code only shows a single call with a batch of one — confirm.
# The dummy batch is all zeros: one sample with the model's input shape minus its batch axis.
dummy_input = tf.zeros((1,) + model.input_shape[1:])
# The output is discarded; only the side effects of calling the model are wanted.
_ = model(dummy_input)

In [None]:
# Export the complete (unsplit) model as models/head/19.tflite.
# NOTE(review): "19" presumably continues the 1..18 split indices used for the
# partial models, i.e. a head that contains everything — confirm.
print("Save full models")
save_tflite(model, "19")
# Saving the corresponding tail model is currently disabled.
# model.save("models/tail/0")

In [None]:
# Split the model at every index from 1 to 18 and export each head part as TFLite.
print("Save partial models")
# skip full model with first and last index
for i in tqdm(range(1, 19)):
    # split_functional_model yields a (head, tail) pair for split index i.
    head, tail = split_functional_model(model, i)
    # Written to models/head/<i>.tflite.
    save_tflite(head, i)
    # Saving the tail part is currently disabled.
    # tail.save("models/tail/" + str(i))