In [None]:
import math

import tensorflow as tf
import keras_cv
from tensorflow import keras

In [None]:
# Build the KerasCV Stable Diffusion pipeline for 512x512 images and keep a handle
# on its diffusion/denoise model, which is the part split and converted below.
model = keras_cv.models.StableDiffusion(img_height=512, img_width=512)
diffusion_model = model.diffusion_model

In [None]:
# find the op/layer that we can use to split the model into two roughly equal chunks

def find_split_layer(model):
    total_size = 0

    for layer in model.layers:
        if layer.weights:
            # print(layer.name)
            if (isinstance(layer.weights, list)):
                  for w in layer.weights:
                    # print(w.shape, w.dtype)
                    total_size = total_size + w.numpy().size
    # print("total size:", total_size)
    half_size = total_size / 2

    first_layers = []
    accumulator = 0 
    for layer in model.layers:
        first_layers.append(layer.name)
        # print(first_layers)
        if layer.weights:
            if (isinstance(layer.weights, list)):
                for w in layer.weights:
                    accumulator = accumulator + w.numpy().size
                if accumulator > half_size:
                    return first_layers, layer.name

In [None]:
# find the edges crossing both chunks
# use them as the output tensors of the first chunk and the input tensors of the second chunk

def find_boundary_tensors(model, first_layers, end_of_first_chunk):
    
    boundary_tensors = []
    boundary_input_layers = []
    in_second_chunk = False
    
    for l in model.layers:
        if in_second_chunk:
            #print(l.name)
            if (isinstance(l.input, list)):
                for i in l.input:
                    #print("  ", i.node.layer.name)
                    if (i.node.layer.name in first_layers):
                        #print("  ", i.node.layer.name)
                        #print(boundary_input_layers)
                        if (i.node.layer.name not in boundary_input_layers):
                            # print(boundary_tensors)
                            boundary_tensors.append(i)
                            boundary_input_layers.append(i.node.layer.name)
            else:
                # print("  whatever", l.input.node.layer.name)
                if (l.input.node.layer.name in first_layers):
                    # print("  yes:", l.input.layer.name)
                    boundary_tensors.append(l.input)
                    boundary_input_layers.append(i.input.name)
                    
        elif (l.name == end_of_first_chunk):
            in_second_chunk = True
            
    return boundary_tensors

In [None]:
first_layers, end_of_first_chunk = find_split_layer(diffusion_model)
boundary_tensors = find_boundary_tensors(diffusion_model, first_layers, end_of_first_chunk)

# construct the two chunks
first_part = keras.Model(diffusion_model.inputs, boundary_tensors)
second_part = keras.Model(boundary_tensors, diffusion_model.outputs)

In [None]:
# convert the two chunks to tflite

converter1 = tf.lite.TFLiteConverter.from_keras_model(first_part)
chunk1 = converter1.convert()
with open('/tmp/sd_diffusion_model_first.tflite', 'wb') as f:
    f.write(chunk1)
    
converter2 = tf.lite.TFLiteConverter.from_keras_model(second_part)
chunk2 = converter2.convert()
with open('/tmp/sd_diffusion_model_second.tflite', 'wb') as f:
    f.write(chunk2)

In [None]:
first_part.save('/tmp/sd/diffusion_model_first')
second_part.save('/tmp/sd/diffusion_model_second')

In [None]:
first_model = tf.saved_model.load('/tmp/sd/diffusion_model_first/')

concrete_func = first_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 77, 768])
concrete_func.inputs[1].set_shape([1, 320])
concrete_func.inputs[2].set_shape([1, 64, 64, 4])
converter1 = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
first_model_fixed_size = converter1.convert()

with open('/tmp/sd_diffusion_model_first_fixed_batch.tflite', 'wb') as f:
    f.write(first_model_fixed_size)

In [None]:
second_model = tf.saved_model.load('/tmp/sd/diffusion_model_second/')

concrete_func_2 = second_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

concrete_func_2.inputs[0].set_shape([1, 8, 8, 1280])
concrete_func_2.inputs[1].set_shape([1, 8, 8, 1280])
concrete_func_2.inputs[2].set_shape([1, 64, 64, 320])
concrete_func_2.inputs[3].set_shape([1, 64, 64, 320])
concrete_func_2.inputs[4].set_shape([1, 64, 64, 320])
concrete_func_2.inputs[5].set_shape([1, 1280])
concrete_func_2.inputs[6].set_shape([1, 16, 16, 1280])
concrete_func_2.inputs[7].set_shape([1, 77, 768])
concrete_func_2.inputs[8].set_shape([1, 16, 16, 1280])
concrete_func_2.inputs[9].set_shape([1, 16, 16, 640])
concrete_func_2.inputs[10].set_shape([1, 32, 32, 640])
concrete_func_2.inputs[11].set_shape([1, 32, 32, 640])
concrete_func_2.inputs[12].set_shape([1, 32, 32, 320])

converter2 = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func_2])
second_model_fixed_size = converter2.convert()

with open('/tmp/sd_diffusion_model_second_fixed_batch.tflite', 'wb') as f:
    f.write(second_model_fixed_size)