# CLIP-JAX demo

In [None]:
from dataclasses import dataclass, field, asdict
from typing import Any, Mapping
from clip_jax import CLIPModel
from clip_jax.partitions import logical_axis_rules
from jax.experimental.mesh_utils import create_device_mesh
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import PartitionSpec
from jax.experimental.pjit import pjit, with_sharding_constraint
import json
from flax.traverse_util import flatten_dict
from clip_jax.utils import load_config, count_params
import flax.linen as nn

## Create a model

In [None]:
@dataclass
class CLIPConfig:
    """Small configuration container used to build the demo model.

    An instance is converted with ``dataclasses.asdict`` and expanded into
    keyword arguments for ``CLIPModel``.
    """

    # NOTE(review): presumably the size of the shared text/image projection
    # space -- confirm against CLIPModel.
    projection_dim: int
    # free-form settings for the text model; defaults to a fresh empty dict
    text_config: Mapping[str, Any] = field(default_factory=dict)
    # free-form settings for the vision model; defaults to a fresh empty dict
    vision_config: Mapping[str, Any] = field(default_factory=dict)

In [None]:
# Small demo configuration: the text and vision settings share the same
# transformer options (hidden size, depth, heads, MLP width, activations)
# and differ only in their input-specific keys.
config = CLIPConfig(
    projection_dim=128,
    text_config=dict(
        vocab_size=50000,
        hidden_size=256,
        max_position_embeddings=80,
        num_layers=2,
        use_rmsnorm=True,
        ln_type="normformer",
        num_heads=8,
        position_embedding_type="rotary",
        use_causal_mask=True,
        mlp_dim=512,
        activations=("relu", "linear"),
    ),
    vision_config=dict(
        image_size=256,
        hidden_size=256,
        patch_size=16,
        num_layers=2,
        use_rmsnorm=True,
        ln_type="normformer",
        num_heads=8,
        use_causal_mask=False,
        mlp_dim=512,
        activations=("relu", "linear"),
    ),
)

In [None]:
# build the model: the config dataclass is converted to a dict and expanded
# into keyword arguments
model = CLIPModel(**asdict(config))

In [None]:
# inputs
# fixed seed so the generated inputs are the same on every run
rng = jax.random.PRNGKey(0)
# the result is unpacked as keyword arguments by the cells that follow
model_inputs = model.init_inputs(rng)

In [None]:
# model summary
# `tabulate` returns the rendered summary table as a string; the console is
# made very wide and told it is neither a terminal nor Jupyter.
tabulation = model.tabulate(
    **model_inputs, console_kwargs={"width": 400, "force_terminal": False, "force_jupyter": False}
)
# write to a file (too long to be displayed in the notebook)
# Explicit UTF-8: the rendered table contains non-ASCII box-drawing
# characters, which the platform default encoding (e.g. cp1252 on Windows)
# may be unable to encode.
with open("summary.md", "w", encoding="utf-8") as f:
    f.write(tabulation)

In [None]:
# Recover the complete configuration from the model instance itself,
# leaving out the "parent" and "name" fields.
config = {
    key: value
    for key, value in asdict(model).items()
    if key not in ("parent", "name")
}
config

In [None]:
# Persist the extracted configuration as indented JSON, streaming it
# straight into the file handle.
with open("config.json", "w") as f:
    json.dump(config, f, indent=4)

## Instantiate a model

In [None]:
# load a config
# NOTE: relative path, resolved against the current working directory.
# This rebinds `config`, replacing the value from the previous section.
model_path = "../configs/huge-patch16.json"
config = load_config(model_path)

In [None]:
# instantiate model
# `config` is expanded as keyword arguments, so it must be a mapping whose
# keys are valid CLIPModel arguments
model = CLIPModel(**config)

In [None]:
# create inputs
# fixed seed so the generated inputs are the same on every run
rng = jax.random.PRNGKey(0)
# the result is unpacked as keyword arguments into `model.init` below
model_inputs = model.init_inputs(rng)

In [None]:
# get logical params
# jax.eval_shape evaluates `model.init` abstractly: the result carries only
# shapes and dtypes, so no parameter memory is allocated here.
logical_params = jax.eval_shape(lambda inputs: model.init(**inputs), model_inputs)["params"]

In [None]:
# Number of parameters
# overall total first, then one line per top-level entry of the params tree
print(f"Number of parameters: {count_params(logical_params):,}")
for k, v in logical_params.items():
    print(f"{k}: {count_params(v):,}")

In [None]:
# get logical spec
# extracts a tree of partition specs from the abstract params
# (the axis names in it are logical; they are mapped to mesh axes later)
logical_spec = nn.get_partition_spec(logical_params)

In [None]:
# Collect every distinct axis entry that appears in any leaf of the
# logical spec tree.
logical_axes = set().union(*flatten_dict(logical_spec).values())
logical_axes

In [None]:
# we can manually check params shape, type and axes
# NOTE: the two flattened dicts are paired positionally, which assumes both
# trees flatten in the same key order.
for (shape_k, shape_v), (_, axis_v) in zip(flatten_dict(logical_params).items(), flatten_dict(logical_spec).items()):
    # printed columns: key path, dtype, shape, logical axes, parameter count
    print(shape_k, shape_v.value.dtype, shape_v.value.shape, axis_v, f"{count_params({shape_k: shape_v}):,}")

In [None]:
# get partitioning rules
# NOTE(review): presumably the two arguments choose how many mesh dimensions
# activations and parameters are split over -- confirm in clip_jax.partitions
rules = logical_axis_rules(activation_partitioning_dims=1, parameter_partitioning_dims=2)

In [None]:
# get params spec
# apply the rules to translate the logical axis names into mesh axis names
params_spec = nn.logical_to_mesh(logical_spec, rules)

In [None]:
# data spec
# partitions the first array dimension over the mesh axis named "data"
data_spec = PartitionSpec("data")

In [None]:
# create mesh
# Size of the "model" axis; the remaining local devices form the "data" axis.
mp_devices = 1
# Derive both mesh dimensions from `mp_devices` so that changing it above
# actually takes effect (previously the literal 1 was hard-coded here).
dp_devices = jax.local_device_count() // mp_devices
dev_mesh = create_device_mesh((dp_devices, mp_devices))
# axis order matches the shape above: (data, model)
mesh = Mesh(dev_mesh, ("data", "model"))

In [None]:
@partial(pjit, in_shardings=None, out_shardings=params_spec)
def init_params():
    """Initialize the model and return only its "params" collection.

    Takes no arguments; the model and its inputs are read from the enclosing
    scope. The function is wrapped in ``pjit`` with ``params_spec`` as the
    output shardings.
    """
    return model.init(**model_inputs)["params"]

In [None]:
# run the pjit-wrapped initializer inside the mesh context so the mesh axis
# names used by its output shardings can be resolved
with mesh:
    params = init_params()