In [None]:
import math

import tensorflow as tf
import keras_cv
from tensorflow import keras

In [None]:
# load the pipeline, then get the diffusion/denoise model

# Only the pipeline's diffusion (denoising) sub-model is split and quantized below.
model = keras_cv.models.StableDiffusion(img_height=512, img_width=512)
diffusion_model = model.diffusion_model

In [None]:
# find the layer at which the model can be split into two chunks with roughly equal parameter counts

def find_split_layer(model):
    """Find the layer at which ``model`` splits into two chunks of ~equal parameter count.

    Layers are walked in ``model.layers`` order. The split layer is the first
    one at which the running element count of ``layer.weights`` exceeds half
    of the model's total.

    Args:
        model: Keras model whose ``layers`` expose ``name`` and ``weights``.

    Returns:
        ``(first_layers, split_layer_name)``: ``first_layers`` holds the names
        of every layer up to and including the split layer, in
        ``model.layers`` order. ``split_layer_name`` is the name of the split
        layer itself.

    Raises:
        ValueError: if the model has no weights, so no split point exists.
    """
    def _num_params(layer):
        # Count elements from the static shape instead of ``w.numpy().size``,
        # which would copy every weight tensor to host memory just to size it.
        return sum(math.prod(w.shape) for w in layer.weights)

    half_size = sum(_num_params(layer) for layer in model.layers) / 2

    first_layers = []
    accumulator = 0
    for layer in model.layers:
        first_layers.append(layer.name)
        # Weightless layers add 0, so they can never be the first to cross the halfway mark.
        accumulator += _num_params(layer)
        if accumulator > half_size:
            return first_layers, layer.name

    # Only reachable when the total is 0 (any positive total is exceeded by the final layer).
    raise ValueError("cannot split a model without weights into two chunks")

In [None]:
# find the tensors (graph edges) that flow from the first chunk into the second;
# they become the output tensors of the first chunk and the input tensors of the second chunk

def find_boundary_tensors(model, first_layers, end_of_first_chunk):
    """Collect the tensors that flow from the first chunk into the second chunk.

    Layers are visited in ``model.layers`` order. Every layer after
    ``end_of_first_chunk`` belongs to the second chunk. Any input of a
    second-chunk layer that is produced by a layer named in ``first_layers``
    is a boundary tensor. It becomes an output of the first sub-model and an
    input of the second.

    Args:
        model: functional Keras model being split.
        first_layers: names of the layers in the first chunk.
        end_of_first_chunk: name of the last layer of the first chunk.

    Returns:
        List of boundary tensors, at most one per producing layer, in the
        order they are first consumed by the second chunk.
    """
    first_layer_names = set(first_layers)
    boundary_tensors = []
    seen_producers = set()
    in_second_chunk = False

    for layer in model.layers:
        if not in_second_chunk:
            # The end layer itself still belongs to the first chunk; start with the next one.
            in_second_chunk = layer.name == end_of_first_chunk
            continue

        # ``layer.input`` is a single tensor for single-input layers and a list otherwise.
        layer_inputs = layer.input
        if not isinstance(layer_inputs, (list, tuple)):
            layer_inputs = [layer_inputs]

        for tensor in layer_inputs:
            producer = tensor.node.layer.name
            # Same de-duplication rule for single- and multi-input consumers.
            if producer in first_layer_names and producer not in seen_producers:
                boundary_tensors.append(tensor)
                seen_producers.add(producer)

    return boundary_tensors

In [None]:
first_layers, end_of_first_chunk = find_split_layer(diffusion_model)
boundary_tensors = find_boundary_tensors(
    diffusion_model, first_layers, end_of_first_chunk
)

# Build the two chunks: the boundary tensors are the outputs of the first
# sub-model and, at the same time, the inputs of the second.
first_part = keras.Model(inputs=diffusion_model.inputs, outputs=boundary_tensors)
second_part = keras.Model(inputs=boundary_tensors, outputs=diffusion_model.outputs)

In [None]:
def get_timestep_embedding(timestep, batch_size, dim=320, max_period=10000):
    """Build a sinusoidal timestep embedding and repeat it for every batch row.

    Args:
        timestep: scalar diffusion timestep.
        batch_size: number of rows in the result.
        dim: embedding width. The first ``dim // 2`` columns hold cosines and
            the next ``dim // 2`` hold sines.
        max_period: frequencies decay geometrically from 1 towards
            ``1 / max_period``.

    Returns:
        float32 tensor of shape ``(batch_size, 2 * (dim // 2))``.
    """
    half_dim = dim // 2
    log_scaled = -math.log(max_period) * tf.range(0, half_dim, dtype=tf.float32)
    frequencies = tf.math.exp(log_scaled / half_dim)

    angles = tf.convert_to_tensor([timestep], dtype=tf.float32) * frequencies
    row = tf.reshape(
        tf.concat([tf.math.cos(angles), tf.math.sin(angles)], axis=0), [1, -1]
    )
    return tf.repeat(row, batch_size, axis=0)


prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
encoding_1 = model.encode_text(prompt_1)

In [None]:
def representative_data_gen_first(num_samples=1):
    """Yield calibration samples for int8 quantization of the first chunk.

    Each sample pairs the prompt encoding ``encoding_1`` with a timestep
    embedding for timestep ``i + 1`` and fresh Gaussian noise of shape
    ``(1, 64, 64, 4)``.

    Args:
        num_samples: number of calibration samples to yield. The default of 1
            matches the original single-sample behaviour.

    Yields:
        dict mapping the model's input names to tensors. NOTE(review): the keys
        are presumably Keras' auto-generated input names (input_1 = context,
        input_2 = timestep embedding, input_3 = latent); confirm against
        ``first_part.input_names``.
    """
    for i in range(num_samples):
        em = get_timestep_embedding(i + 1, 1)
        noise = tf.random.normal((1, 64, 64, 4))
        # Yield inside the loop so every sample reaches the calibrator, not just the last one.
        yield {'input_1': encoding_1, 'input_2': em, 'input_3': noise}

In [None]:
# When a Keras model is converted to a TFLite model, it is first saved as a SavedModel.
# In that SavedModel the N inputs are named args_0, args_0_1, args_0_2, ..., args_0_{N-1},
# where N is the number of boundary tensors (13 for the split used here). The keys
# are generated from the actual outputs so that a different split still works.
def representative_data_gen_second(num_samples=1):
    """Yield calibration samples for int8 quantization of the second chunk.

    The second chunk's inputs are the first chunk's outputs. Each sample is
    therefore produced by running ``first_part`` on (noise, timestep
    embedding, ``encoding_1``).

    Args:
        num_samples: number of calibration samples to yield (default 1).

    Yields:
        dict mapping the SavedModel input names (``args_0``, ``args_0_1``, ...)
        to the corresponding boundary tensors.
    """
    for i in range(num_samples):
        em = get_timestep_embedding(i + 1, 1)
        noise = tf.random.normal((1, 64, 64, 4))
        boundary = first_part((noise, em, encoding_1))
        # A single-output Keras model returns a bare tensor instead of a list.
        if not isinstance(boundary, (list, tuple)):
            boundary = [boundary]
        yield {
            ('args_0' if k == 0 else f'args_0_{k}'): tensor
            for k, tensor in enumerate(boundary)
        }

In [None]:
def _make_int8_converter(keras_model, representative_dataset):
    """Return a TFLite converter configured for full-integer (int8) quantization.

    Args:
        keras_model: Keras model to convert.
        representative_dataset: zero-argument generator function yielding
            calibration samples for ``keras_model``.

    Returns:
        ``tf.lite.TFLiteConverter`` restricted to int8 builtin ops, with int8
        input and output tensors.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    converter.representative_dataset = representative_dataset
    return converter


converter1 = _make_int8_converter(first_part, representative_data_gen_first)
first_chunk_qint8_tflite = converter1.convert()

converter2 = _make_int8_converter(second_part, representative_data_gen_second)
second_chunk_qint8_tflite = converter2.convert()

with open('/tmp/diffusion_model_first_qint8.tflite', 'wb') as f:
    f.write(first_chunk_qint8_tflite)

with open('/tmp/diffusion_model_second_qint8.tflite', 'wb') as f:
    f.write(second_chunk_qint8_tflite)