# CLIP-JAX demo

In [1]:
from dataclasses import dataclass, field, asdict
from typing import Any, Mapping
from clip_jax import CLIPModel
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
from flax.traverse_util import flatten_dict

## Create a model

In [2]:
@dataclass
class CLIPConfig:
    """Container for the keyword arguments used to construct a ``CLIPModel``.

    The instance is converted with ``dataclasses.asdict`` and unpacked into the
    model constructor, so each field name doubles as a constructor keyword.

    Attributes:
        projection_dim: Required integer, forwarded as ``projection_dim``.
        text_config: String-keyed settings forwarded as ``text_config``;
            defaults to a new empty dict for every instance.
        vision_config: String-keyed settings forwarded as ``vision_config``;
            defaults to a new empty dict for every instance.
    """

    projection_dim: int
    # default_factory gives each instance its own dict instead of a shared mutable default
    text_config: Mapping[str, Any] = field(default_factory=dict)
    vision_config: Mapping[str, Any] = field(default_factory=dict)

In [4]:
# Settings for the text side of the model.
text_tower_config = {
    "vocab_size": 50000,
    "hidden_size": 256,
    "max_position_embeddings": 80,
    "num_layers": 2,
    "use_rmsnorm": True,
    "ln_type": "normformer",
    "num_heads": 8,
    "position_embedding_type": "rotary",
    "use_causal_mask": False,
    "mlp_dim": 512,
    "activations": ("relu", "linear"),
}

# Settings for the vision side. Width, depth, head count, normalisation and MLP
# values repeat the text settings; only the image/patch entries are specific here.
vision_tower_config = {
    "image_size": 256,
    "hidden_size": 256,
    "patch_size": 16,
    "num_layers": 2,
    "use_rmsnorm": True,
    "ln_type": "normformer",
    "num_heads": 8,
    "use_causal_mask": False,
    "mlp_dim": 512,
    "activations": ("relu", "linear"),
}

# Bundle both sets of settings together with the projection size.
config = CLIPConfig(
    projection_dim=128,
    text_config=text_tower_config,
    vision_config=vision_tower_config,
)

In [5]:
# Turn the dataclass into a plain dict (nested values are copied), then hand
# each field to the model constructor as a keyword argument.
model_kwargs = asdict(config)
model = CLIPModel(**model_kwargs)

In [6]:
# Build example inputs for the model from a fixed PRNG key (seed 0).
# The result is later unpacked with ** into model.tabulate, so it is a mapping
# of keyword arguments.
rng = jax.random.PRNGKey(seed=0)
model_inputs = model.init_inputs(rng, model)

In [26]:
# model summary
# Render the summary table as a plain string: a 400-column console keeps wide
# rows from being wrapped, and terminal/Jupyter rendering is switched off.
tabulation = model.tabulate(
    **model_inputs, console_kwargs={"width": 400, "force_terminal": False, "force_jupyter": False}
)
# write to a file (too long to be displayed in the notebook)
# Pin the encoding: the rendered table typically contains non-ASCII box-drawing
# characters, and the platform default encoding (e.g. cp1252 on Windows) may
# fail to encode them.
with open("summary.md", "w", encoding="utf-8") as f:
    f.write(tabulation)