In [1]:
import torch
import clip
import coremltools as ct
import numpy as np
from PIL import Image

scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.


# 1. Load ViT-B/32 CLIP model

In [4]:
device="cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
text = clip.tokenize("a diagram").to(device)
i = Image.open("IMG_7466.jpg")
image = preprocess(i).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

traced = torch.jit.trace(model, (image, text))

# 2. Export TextEncoder

In [5]:
import torch.nn as nn
from collections import OrderedDict

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
    
class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

In [6]:
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        self.transformer = Transformer(
                width=transformer_width,
                layers=transformer_layers,
                heads=transformer_heads,
                attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.temperature = nn.Parameter(torch.tensor(0.07))

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        print(f"text_projection shape: {self.text_projection.shape}")
        self.dtype = torch.float32

        self.initialize_parameters()
    
    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
        else:
            nn.init.normal_(self.text_projection, std=self.custom_text_config['text_rep_size'] ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text):
        # print(f'text: {text}')
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

In [7]:
text_encoder = TextEncoder(embed_dim=512, context_length=77, vocab_size=49408, 
                           transformer_width=512, transformer_heads=8, transformer_layers=12)

text_projection shape: torch.Size([512, 512])


In [8]:
text_encoder.load_state_dict(model.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['temperature'], unexpected_keys=['visual.class_embedding', 'visual.positional_embedding', 'visual.proj', 'visual.conv1.weight', 'visual.ln_pre.weight', 'visual.ln_pre.bias', 'visual.transformer.resblocks.0.attn.in_proj_weight', 'visual.transformer.resblocks.0.attn.in_proj_bias', 'visual.transformer.resblocks.0.attn.out_proj.weight', 'visual.transformer.resblocks.0.attn.out_proj.bias', 'visual.transformer.resblocks.0.ln_1.weight', 'visual.transformer.resblocks.0.ln_1.bias', 'visual.transformer.resblocks.0.mlp.c_fc.weight', 'visual.transformer.resblocks.0.mlp.c_fc.bias', 'visual.transformer.resblocks.0.mlp.c_proj.weight', 'visual.transformer.resblocks.0.mlp.c_proj.bias', 'visual.transformer.resblocks.0.ln_2.weight', 'visual.transformer.resblocks.0.ln_2.bias', 'visual.transformer.resblocks.1.attn.in_proj_weight', 'visual.transformer.resblocks.1.attn.in_proj_bias', 'visual.transformer.resblocks.1.attn.out_proj.weight', 'visual.transformer.resblocks.1.attn.ou

In [9]:
import coremltools as ct

text_encoder.eval()

example_input = clip.tokenize("a diagram").to(device)
traced_model = torch.jit.trace(text_encoder, example_input)
out = traced_model(example_input)

In [11]:
max_seq_length = 77

text_encoder_model = ct.convert(
            traced_model,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[ct.TensorType(name="prompt",
                                 shape=[1,max_seq_length],
                                 dtype=np.int32)],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
#             compute_units=ct.ComputeUnit[args.compute_unit],
            # skip_model_load=True,
        )

Converting PyTorch Frontend ==> MIL Ops:  92%|█████████████████████████████████████████████████████▌    | 897/972 [00:00<00:00, 8961.48 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████▉| 971/972 [00:00<00:00, 8673.74 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 206.05 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████| 66/66 [00:03<00:00, 20.86 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 842.75 passes/s]


In [12]:
text_encoder_model.save("TextEncoder_float32.mlpackage")

## Validate export  precision

In [29]:
import coremltools as ct

# Load the model
model = ct.models.MLModel('TextEncoder_float32.mlpackage')
text = clip.tokenize("a diagram").to(device)
predictions = model.predict({'prompt': text})

In [46]:
print("PyTorch TextEncoder ckpt out for \"a diagram\":\n>>>", out[0, :10])
print("\nCoreML TextEncoder ckpt out for \"a diagram\":\n>>>", predictions['embOutput'][0, :10])

PyTorch TextEncoder ckpt out for "a diagram":
>>> tensor([ 0.0547, -0.0061,  0.0495,  0.0106,  0.1107, -0.2575, -0.2108, -1.3542,
         0.4390, -0.1328], grad_fn=<SliceBackward0>)

CoreML TextEncoder ckpt out for "a diagram":
>>> [ 0.05474854 -0.00689697  0.04943848  0.01080322  0.11053467 -0.2578125
 -0.21118164 -1.3535156   0.43920898 -0.13305664]


You can see that there is some loss in precision, but it is still acceptable.

# 3. Export ImageEncoder

In [90]:
import torch
import clip
import coremltools as ct
import numpy as np
from PIL import Image

In [91]:
device="cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
i = Image.open("IMG_3628.jpg")
image_orig = preprocess(i).unsqueeze(0).to(device)

In [92]:
traced_image_only = torch.jit.trace(model.visual, image_orig)
out = traced_image_only(image_orig)

In [100]:
import coremltools as ct
# Set the image scale and bias for input image preprocessing
scale = 1/(0.2685697*255.0)
bias = [- 0.48145466/(0.26862954) , - 0.4578275/(0.26130258), - 0.40821073/(0.27577711)]

# imgPIL = Image.open("4111670639918_.pic.png")

image_input_scale = ct.ImageType(name="colorImage",
                           color_layout=ct.colorlayout.RGB,
                           shape=image_orig.shape,
                           scale=scale, bias=bias)


image_encoder_model = ct.convert(
            traced_image_only,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[image_input_scale],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
        )


Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████▉| 970/971 [00:00<00:00, 8897.94 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 212.42 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████| 64/64 [00:04<00:00, 15.84 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 678.72 passes/s]


In [101]:
image_encoder_model.save("ImageEncoder_float32.mlpackage")

## Validate export

In [102]:
import coremltools as ct

# Load the model
image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')

from torchvision import transforms
imgPIL = Image.open("IMG_3628.jpg")
imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)
predictions = image_encoder.predict({'colorImage': imgPIL})

In [103]:
print("PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", out[0, :10])
print("\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", predictions['embOutput'][0, :10])

PyTorch ImageEncoder ckpt out for IMG_3628.jpg:
>>> tensor([-0.0282,  0.6441, -0.2774, -0.0922,  0.3574,  0.3617, -0.6459,  0.3053,
         0.3879,  0.2529], grad_fn=<SliceBackward0>)

CoreML ImageEncoder ckpt out for IMG_3628.jpg:
>>> [ 0.09521484  0.87402344 -0.2861328  -0.09381104  0.34057617  0.07556152
 -0.22106934  0.41137695  0.01852417  0.16931152]


This time <span style='color:red'> the precision error is larger.</span> This may be caused by the wrong norm. 

## What if no norm?

In [104]:
image_input_scale = ct.ImageType(name="colorImage",
                           color_layout=ct.colorlayout.RGB,
                           shape=image_orig.shape)


image_encoder_model = ct.convert(
            traced_image_only,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.iOS16,
            inputs=[image_input_scale],
            outputs=[ct.TensorType(name="embOutput", dtype=np.float32)],
        )

image_encoder_model.save("ImageEncoder_float32.mlpackage")

image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')

from torchvision import transforms
imgPIL = Image.open("IMG_3628.jpg")
imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)
predictions = image_encoder.predict({'colorImage': imgPIL})

print("PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", out[0, :10])
print("\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\n>>>", predictions['embOutput'][0, :10])

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████▉| 970/971 [00:00<00:00, 9653.01 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 214.42 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████| 64/64 [00:03<00:00, 18.28 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 836.44 passes/s]


PyTorch ImageEncoder ckpt out for IMG_3628.jpg:
>>> tensor([-0.0282,  0.6441, -0.2774, -0.0922,  0.3574,  0.3617, -0.6459,  0.3053,
         0.3879,  0.2529], grad_fn=<SliceBackward0>)

CoreML ImageEncoder ckpt out for IMG_3628.jpg:
>>> [ 0.02630615  0.090271    0.04776001  0.04675293  0.41870117 -0.20947266
  0.3623047   0.5439453   0.83984375  0.10150146]


**The error is even worse.**