In [None]:
import math

import tensorflow as tf
import keras_cv
from tensorflow import keras

In [None]:
# Instantiate the KerasCV Stable Diffusion pipeline at 512x512, then keep
# handles to the sub-components that the later cells work with.
model = keras_cv.models.StableDiffusion(
    img_height=512,
    img_width=512,
)

# The tokenizer is used to build the calibration prompt; the text encoder and
# the decoder are the two Keras models that get converted to TFLite.
tokenizer = model.tokenizer
text_encoder_model = model.text_encoder
decoder_model = model.decoder

In [None]:
# Fixed length, in tokens, that the calibration prompt is padded to; also the
# number of position ids produced by get_pos_ids().
# NOTE(review): 77 presumably matches the text encoder's context length - confirm.
MAX_PROMPT_LENGTH = 77

def get_pos_ids():
    """Return the position ids for one prompt.

    Returns:
        An int32 tensor of shape (1, MAX_PROMPT_LENGTH) holding
        0, 1, ..., MAX_PROMPT_LENGTH - 1.
    """
    positions = [index for index in range(MAX_PROMPT_LENGTH)]
    # Wrap in an outer list so the tensor carries a leading batch axis of size 1.
    batched_positions = [positions]
    return tf.convert_to_tensor(batched_positions, dtype=tf.int32)

def representative_data_gen_text_encoder():
    """Yield the calibration sample used when quantizing the text encoder.

    A single sample is produced: the prompt 'This is a test', tokenized and
    right-padded to MAX_PROMPT_LENGTH, together with the matching position ids.

    Yields:
        A two-element list ``[token_ids, position_ids]``; ``token_ids`` is an
        int32 tensor with a leading batch axis of size 1, ``position_ids`` is
        the result of ``get_pos_ids()``.
    """
    token_ids = tokenizer.encode('This is a test')
    # NOTE(review): 49407 is presumably the tokenizer's end-of-text id, used
    # here as the padding token - confirm against the tokenizer vocabulary.
    # A prompt longer than MAX_PROMPT_LENGTH would get no padding and would
    # not be truncated; the fixed prompt above is assumed to be shorter.
    padding = [49407] * (MAX_PROMPT_LENGTH - len(token_ids))
    padded_prompt = tf.convert_to_tensor([token_ids + padding], dtype=tf.int32)

    yield [padded_prompt, get_pos_ids()]
        
def representative_data_gen_decoder():
    """Yield the calibration sample used when quantizing the decoder.

    A single sample of standard-normal noise with shape (1, 64, 64, 4) is
    produced. The noise is drawn with a stateless, explicitly seeded generator
    so that every run (and every re-run of the cell in the same kernel) feeds
    the converter the same calibration data and therefore produces the same
    quantization ranges.

    Yields:
        A one-element list containing a float32 tensor of shape (1, 64, 64, 4).
    """
    # NOTE(review): 64x64x4 is presumably the latent shape for the 512x512
    # pipeline built earlier in the notebook - confirm if the image size changes.
    for sample_index in range(1):
        # The sample index is folded into the seed so that, should more
        # samples ever be generated, each one is distinct yet reproducible.
        latent = tf.random.stateless_normal(
            (1, 64, 64, 4), seed=(0, sample_index)
        )
        yield [latent]

In [None]:
# Convert the text encoder and the decoder to int8-quantized TFLite flatbuffers.

def _convert_to_int8_tflite(keras_model, representative_dataset, output_path,
                            use_new_quantizer=None):
    """Quantize a Keras model to a TFLite flatbuffer and write it to disk.

    Args:
        keras_model: The Keras model to convert.
        representative_dataset: Zero-argument callable returning a generator of
            calibration inputs for the converter.
        output_path: File path the serialized flatbuffer is written to.
        use_new_quantizer: When not None, assigned to the converter's
            ``experimental_new_quantizer`` flag; when None the converter's
            default is left untouched.

    Returns:
        A ``(converter, flatbuffer)`` tuple: the configured converter and the
        serialized model bytes returned by ``converter.convert()``.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Restrict the converter to the int8 builtin op set.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    if use_new_quantizer is not None:
        converter.experimental_new_quantizer = use_new_quantizer
    converter.representative_dataset = representative_dataset
    flatbuffer = converter.convert()

    with open(output_path, 'wb') as f:
        f.write(flatbuffer)

    return converter, flatbuffer


# new quantizer cannot handle text_encoder (yet)
converter1, tflite_text_encoder_qint8 = _convert_to_int8_tflite(
    text_encoder_model,
    representative_data_gen_text_encoder,
    '/tmp/sd_text_encoder_qint8.tflite',
    use_new_quantizer=False,
)

converter2, tflite_decoder_qint8 = _convert_to_int8_tflite(
    decoder_model,
    representative_data_gen_decoder,
    '/tmp/sd_decoder_qint8.tflite',
)