In [8]:
!pip install -q flax einops optax

In [9]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from dataclasses import dataclass
from einops import rearrange

In [10]:
@dataclass
class CLIPConfig:
    embed_dim: int = 512

    image_size: int = 224
    patch_size: int = 32
    vision_width: int = 768
    vision_layers: int = 12
    vision_heads: int = 12

    vocab_size: int = 49408
    context_length: int = 77
    text_width: int = 512
    text_layers: int = 12
    text_heads: int = 8

In [11]:
class VisionTransformer(nn.Module):
    cfg: CLIPConfig

    @nn.compact
    def __call__(self, x):
        B = x.shape[0]

        x = nn.Conv(
            features=self.cfg.vision_width,
            kernel_size=(self.cfg.patch_size, self.cfg.patch_size),
            strides=(self.cfg.patch_size, self.cfg.patch_size),
            padding="VALID"
        )(x)

        x = rearrange(x, "b h w c -> b (h w) c")

        cls = self.param(
            "cls_token",
            nn.initializers.zeros,
            (1, 1, self.cfg.vision_width)
        )
        cls = jnp.tile(cls, (B, 1, 1))
        x = jnp.concatenate([cls, x], axis=1)

        pos = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.01),
            (1, x.shape[1], self.cfg.vision_width)
        )
        x = x + pos

        for _ in range(self.cfg.vision_layers):
            h = nn.LayerNorm()(x)
            h = nn.SelfAttention(
                num_heads=self.cfg.vision_heads,
                qkv_features=self.cfg.vision_width
            )(h)
            x = x + h
        x = nn.LayerNorm()(x[:, 0])
        return x

In [None]:
class TextTransformer(nn.Module):
    cfg: CLIPConfig

    @nn.compact
    def __call__(self, tokens):
        x = nn.Embed(self.cfg.vocab_size, self.cfg.text_width)(tokens)

        pos = self.param(
            "pos_embed",
            nn.initializers.normal(stddev=0.01),
            (1, self.cfg.context_length, self.cfg.text_width)
        )
        x = x + pos

        causal_mask = nn.attention.make_causal_mask(tokens)

        for _ in range(self.cfg.text_layers):
            h = nn.LayerNorm()(x)
            h = nn.SelfAttention(
                num_heads=self.cfg.text_heads,
                qkv_features=self.cfg.text_width,
            )(h, mask=causal_mask)
            x = x + h

        x = nn.LayerNorm()(x[:, -1])
        return x

In [13]:
class CLIP(nn.Module):
    cfg: CLIPConfig

    @nn.compact
    def __call__(self, images, texts):

        image_feat = VisionTransformer(self.cfg)(images)
        text_feat = TextTransformer(self.cfg)(texts)

        image_emb = nn.Dense(self.cfg.embed_dim, name="image_proj")(image_feat)
        text_emb = nn.Dense(self.cfg.embed_dim, name="text_proj")(text_feat)

        image_emb = image_emb / jnp.linalg.norm(image_emb, axis=-1, keepdims=True)
        text_emb = text_emb / jnp.linalg.norm(text_emb, axis=-1, keepdims=True)

        logit_scale = self.param("logit_scale", nn.initializers.zeros, ())
        logit_scale = jnp.exp(logit_scale)

        logits = logit_scale * image_emb @ text_emb.T
        return logits

In [14]:
cfg = CLIPConfig()
model = CLIP(cfg)

key = jax.random.PRNGKey(0)

dummy_images = jnp.ones((2, 224, 224, 3))
dummy_texts = jnp.ones((2, cfg.context_length), dtype=jnp.int32)

params = model.init(key, dummy_images, dummy_texts)

In [15]:
logits = model.apply(params, dummy_images, dummy_texts)
logits

Array([[0.01149538, 0.01149538],
       [0.01149538, 0.01149538]], dtype=float32)

In [16]:
def count_params(p):
    return sum(x.size for x in jax.tree_util.tree_leaves(p))

count_params(params)

69381121