# CLIP model to TFLite

Convert a CLIP model to TFLite with a resizing layer, so the CLIP model can process input image of size 512x512 instead of 224x224.

You may need to use a high RAM instance when running this notebook.

Based originally on this notebook:
https://github.com/freedomtan/clip_score_on_android/blob/main/test_clip_model.ipynb

## Load CLIP model and processor

In [14]:
SAVED_MODEL_DIR = './clip_model'
TFLITE_MODEL_PATH = './clip_model.tflite'
MODEL_NAME = "openai/clip-vit-large-patch14"

In [15]:
from PIL import Image
import requests
import tensorflow as tf

from transformers import TFCLIPModel, CLIPProcessor

# Load the pre-trained CLIP model and processor
model = TFCLIPModel.from_pretrained(MODEL_NAME)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

# Load the image from the URL
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

All model checkpoint layers were used when initializing TFCLIPModel.

All the layers of TFCLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFCLIPModel for predictions without further training.


In [16]:
# Process the inputs
inputs = processor(
    text=["a photo of a cat"],
    images=image,
    size={"shortest_edge": 512},
    crop_size=512,
    return_tensors="tf",
    padding="max_length",
    truncation=True,
)
for i in inputs:
    print(i, ":", inputs[i].shape)

input_ids : (1, 77)
attention_mask : (1, 77)
pixel_values : (1, 3, 512, 512)


## Convert CLIP model to TF SavedModel

In [17]:
# Create a new model that includes the resize operation
class ResizedModel(tf.keras.Model):
    def __init__(self, original_model):
        super(ResizedModel, self).__init__()
        self.original_model = original_model

    def call(self, attention_mask, input_ids, pixel_values):
        # Resize the pixel values to 224x224. pixel_values is expected to have NHWC layout.
        resized_images = tf.image.resize(pixel_values, [224, 224])
        # convert image from NHWC to NCHW
        resized_images = tf.transpose(resized_images, [0, 3, 1, 2])
        return self.original_model(
            attention_mask=attention_mask,
            input_ids=input_ids,
            pixel_values=resized_images
        )

# Wrap the original model with the resize operation
resized_model = ResizedModel(model)

# Run the model
outputs = resized_model(
    inputs['attention_mask'],
    inputs['input_ids'],
    tf.transpose(inputs['pixel_values'], perm=[0, 2, 3, 1])
)

print('logits_per_image:', outputs['logits_per_image'])
print('logits_per_text:', outputs['logits_per_text'])

# Define a function that will be used as the signature to have named inputs when inspecting the model
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 77], dtype=tf.int32, name='attention_mask'),
    tf.TensorSpec(shape=[None, 77], dtype=tf.int32, name='input_ids'),
    tf.TensorSpec(shape=[None, 512, 512, 3], dtype=tf.float32, name='pixel_values')
])
def serving_fn(attention_mask, input_ids, pixel_values):
    output = resized_model(attention_mask, input_ids, pixel_values)
    output_dict = {key: value for key, value in output.items() if isinstance(value, tf.Tensor)}
    print(output_dict)
    return output_dict

# Save the model with the signature
tf.saved_model.save(
    resized_model,
    SAVED_MODEL_DIR,
    signatures={tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_fn}
)

logits_per_image: tf.Tensor([[24.341438]], shape=(1, 1), dtype=float32)
logits_per_text: tf.Tensor([[24.341438]], shape=(1, 1), dtype=float32)
{'logits_per_image': <tf.Tensor 'resized_model_6/tfclip_model_2/clip/transpose:0' shape=(None, None) dtype=float32>, 'logits_per_text': <tf.Tensor 'resized_model_6/tfclip_model_2/clip/mul:0' shape=(None, None) dtype=float32>, 'text_embeds': <tf.Tensor 'resized_model_6/tfclip_model_2/clip/truediv_1:0' shape=(None, 512) dtype=float32>, 'image_embeds': <tf.Tensor 'resized_model_6/tfclip_model_2/clip/truediv:0' shape=(None, 512) dtype=float32>}




## Convert TF SavedModel to TFLite model

In [18]:
# Load the saved model
loaded = tf.saved_model.load(SAVED_MODEL_DIR)
concrete_func = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Inspect the concrete function
print(concrete_func.structured_input_signature)
print(concrete_func.structured_outputs)

# Convert the model to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
tflite_model = converter.convert()

# Save the TensorFlow Lite model
with open(TFLITE_MODEL_PATH, 'wb') as f:
    f.write(tflite_model)



((), {'attention_mask': TensorSpec(shape=(None, 77), dtype=tf.int32, name='attention_mask'), 'pixel_values': TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='pixel_values'), 'input_ids': TensorSpec(shape=(None, 77), dtype=tf.int32, name='input_ids')})
{'logits_per_text': TensorSpec(shape=(None, None), dtype=tf.float32, name='logits_per_text'), 'image_embeds': TensorSpec(shape=(None, 512), dtype=tf.float32, name='image_embeds'), 'text_embeds': TensorSpec(shape=(None, 512), dtype=tf.float32, name='text_embeds'), 'logits_per_image': TensorSpec(shape=(None, None), dtype=tf.float32, name='logits_per_image')}


## Test the converted TFLite model

In [23]:
# Load the TensorFlow Lite model
i = tf.lite.Interpreter(TFLITE_MODEL_PATH)
i.allocate_tensors()

# Set the input tensors
# convert image from NCHW to NHWC
pixel_values = tf.transpose(inputs['pixel_values'], [0, 2, 3, 1])
assert(pixel_values.shape == (1, 512, 512, 3))
i.set_tensor(0, inputs['attention_mask'])
i.set_tensor(1, inputs['input_ids'])
i.set_tensor(2, pixel_values)

# Run inference
i.invoke()

# Print the outputs of TFLite model
print('TFLite model:')
print(f'logits_per_image', i.get_tensor(i.get_output_details()[1]['index']))
print(f'logits_per_text', i.get_tensor(i.get_output_details()[2]['index']))

# Print the outputs of the original model for comparision
print('Original model:')
print('logits_per_image:', outputs['logits_per_image'].numpy())
print('logits_per_text:', outputs['logits_per_text'].numpy())

TFLite model:
logits_per_image [[24.341446]]
logits_per_text [[24.341446]]
Original model:
logits_per_image: [[24.341438]]
logits_per_text: [[24.341438]]
