# CLIP-JAX demo

In [1]:
import json
from dataclasses import asdict
from functools import partial
from io import BytesIO

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import requests
from flax.training import checkpoints
from flax.traverse_util import flatten_dict
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec
from PIL import Image
from transformers import AutoTokenizer

from clip_jax import CLIPModel
from clip_jax.data import image_to_logits
from clip_jax.partitions import logical_axis_rules
from clip_jax.utils import count_params, load_config

## Explore a model

In [2]:
# Load the model hyperparameters from a local JSON config file.
model_path = "../configs/small-patch16.json"
config = load_config(model_path)

In [3]:
# Build the CLIP module definition from the config's keyword arguments;
# parameters are created separately via `model.init`.
model = CLIPModel(**config)

In [4]:
# save loaded config (adds potential missing defaults)
# `asdict(model)` yields every dataclass field of the module, including defaults
# that were absent from the source JSON; the "parent"/"name" module fields are
# not hyperparameters and are excluded.
excluded_fields = {"parent", "name"}
config = {
    field_name: field_value
    for field_name, field_value in asdict(model).items()
    if field_name not in excluded_fields
}

with open("config.json", "w") as f:
    serialized = json.dumps(config, indent=4)
    f.write(serialized)

In [5]:
# Build example inputs for init/tabulate from a fixed PRNG seed (reproducible).
rng = jax.random.PRNGKey(seed=0)
model_inputs = model.init_inputs(rng)

In [6]:
# display summary
# Render the module table as plain text (no terminal/Jupyter styling), 400 columns wide.
tabulation = model.tabulate(
    **model_inputs, console_kwargs={"width": 400, "force_terminal": False, "force_jupyter": False}
)
# write to a file (too long to be displayed in the notebook)
# The rendered table may contain non-ASCII box-drawing characters, so write
# UTF-8 explicitly instead of relying on the platform's default encoding
# (which can raise UnicodeEncodeError, e.g. cp1252 on Windows).
with open("summary.md", "w", encoding="utf-8") as f:
    f.write(tabulation)

In [7]:
# get logical params
# Trace `model.init` abstractly: this yields the parameter tree's shapes/dtypes
# (plus the partitioning metadata read later) without allocating arrays.
abstract_variables = jax.eval_shape(lambda inputs: model.init(**inputs), model_inputs)
logical_params = abstract_variables["params"]

In [None]:
# Number of parameters: grand total, then one line per top-level key of the tree.
print(f"Number of parameters: {count_params(logical_params):,}")
for submodule_name, subtree in logical_params.items():
    n_params = count_params(subtree)
    print(f"{submodule_name}: {n_params:,}")

In [None]:
# Extract the logical (named-axis) partition spec for every parameter leaf.
logical_spec = nn.get_partition_spec(logical_params)

In [None]:
# view all logical axes
# Union of every axis name appearing in any parameter's spec.
logical_axes = set()
for axis_names in flatten_dict(logical_spec).values():
    logical_axes.update(axis_names)
logical_axes

In [None]:
# get partition spec
# Translate logical axis names into physical mesh axes using the project's rules.
rules = logical_axis_rules(
    activation_partitioning_dims=1,
    parameter_partitioning_dims=1,
)
params_spec = nn.logical_to_mesh(logical_spec, rules)
# Input batches are sharded along the "data" mesh axis.
data_spec = PartitionSpec("data")

In [None]:
# create mesh
# 2D ("data", "model") device mesh: `mp_devices` devices per model-parallel
# group; the remaining factor of the device count is used for data parallelism.
mp_devices = 1
# `create_device_mesh` lays out *all* devices from `jax.devices()` by default,
# so size the mesh from the global device count (equal to the local count on a
# single host).
n_devices = jax.device_count()
if n_devices % mp_devices != 0:
    raise ValueError(f"Device count ({n_devices}) is not divisible by mp_devices ({mp_devices}).")
dp_devices = n_devices // mp_devices
dev_mesh = create_device_mesh((dp_devices, mp_devices))
mesh = Mesh(dev_mesh, ("data", "model"))

In [None]:
# init params
# Initialize parameters under pjit so they are produced directly with the
# sharding given by `params_spec` (no input shardings; outputs: params_spec).


def init_params():
    """Run `model.init` on the example inputs and return only the "params" collection."""
    variables = model.init(**model_inputs)
    return variables["params"]


init_params = pjit(init_params, in_shardings=None, out_shardings=params_spec)

with mesh:
    # PartitionSpecs are resolved against the mesh active in this context.
    params = init_params()

## Inference

In [None]:
# load tokenizer
# The locally trained tokenizer is used; the original OpenAI CLIP tokenizer
# ("openai/clip-vit-base-patch32") can be substituted here instead.
tokenizer_name = "../training/craiyon_tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [None]:
# load model
# Trained run directory: holds config.json and the checkpoints restored below.
model_path = "gs://craiyon_models_us_central2/clip/20230430112259"
config_file = f"{model_path}/config.json"
config = load_config(config_file)
model = CLIPModel(**config)

In [None]:
# initialize model
rng = jax.random.PRNGKey(0)
model_inputs = model.init_inputs(rng)
# Only shapes/dtypes are needed: `eval_shape` traces `init` without allocating params.
logical_shape = jax.eval_shape(model.init, **model_inputs)["params"]
# Materialize zero-filled placeholder arrays with matching shapes/dtypes.
# (`jax.tree_map` is deprecated in favor of `jax.tree_util.tree_map`.)
params = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), logical_shape)

In [None]:
# restore checkpoint
# `restore_checkpoint` silently returns `target` unchanged when no matching
# checkpoint exists, which would leave the zero-filled placeholder params in
# place and produce meaningless embeddings; fail loudly instead.
if checkpoints.latest_checkpoint(model_path, prefix="model_") is None:
    raise FileNotFoundError(f"No checkpoint with prefix 'model_' found in {model_path}")
params = checkpoints.restore_checkpoint(model_path, target=params, prefix="model_")

In [None]:
# inference functions
# Both helpers are jit-compiled; `model` is captured from the enclosing scope,
# while `params` is passed as an argument so it is traced rather than baked in.


@jax.jit
def get_text_features(input_ids, attention_mask, params):
    """Encode tokenized text with the CLIP text tower and return its "text_embeds"."""
    outputs = model.apply(
        {"params": params},
        input_ids=input_ids,
        attention_mask=attention_mask,
        method=model.get_text_features,
    )
    return outputs["text_embeds"]


@jax.jit
def get_image_features(pixel_values, params):
    """Encode preprocessed pixels with the CLIP vision tower and return its "image_embeds"."""
    outputs = model.apply(
        {"params": params},
        pixel_values=pixel_values,
        method=model.get_image_features,
    )
    return outputs["image_embeds"]

In [None]:
# image data
img_url = "https://hips.hearstapps.com/hmg-prod/images/dog-puppy-on-garden-royalty-free-image-1586966191.jpg?crop=0.752xw:1.00xh;0.175xw,0&resize=1200:*"
# Bound the request time and surface HTTP errors instead of trying to decode
# an error page as an image.
response = requests.get(img_url, timeout=30)
response.raise_for_status()
# Normalize the mode (RGBA / grayscale / palette) to 3-channel RGB;
# NOTE(review): `image_to_logits` presumably expects RGB input — confirm.
img = Image.open(BytesIO(response.content)).convert("RGB")
img = img.resize((256, 256))

In [None]:
# image inference
# Preprocess the PIL image, then prepend a batch axis of size 1.
pixel_values = image_to_logits(img)[None, ...]
img_embeds = get_image_features(pixel_values, params)

In [None]:
# text inference
text = "a dog"
# Pad/truncate to the text tower's configured sequence length; return NumPy arrays.
max_length = config["text_config"]["max_length"]
text_inputs = tokenizer(
    text,
    padding="max_length",
    truncation=True,
    max_length=max_length,
    return_tensors="np",
)
input_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]
text_embeds = get_text_features(input_ids=input_ids, attention_mask=attention_mask, params=params)

In [None]:
# calculate similarity
# L2-normalize both embeddings so the score is cosine similarity whether or not
# `get_*_features` already returns unit-norm vectors (re-normalizing unit
# vectors is effectively a no-op). Assumes embeddings are (batch, dim) — confirm.
img_unit = img_embeds / jnp.linalg.norm(img_embeds, axis=-1, keepdims=True)
text_unit = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
similarity = jnp.matmul(img_unit, text_unit.T)
similarity