In [None]:
import torch

from PIL import Image
import open_clip

In [None]:
# Load CLIP 'ViT-B-16-plus-240' with the 'laion400m_e32' pretrained weights, plus its train/eval image transforms.
model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-16-plus-240', pretrained='laion400m_e32')
# open_clip hands the model back in train mode; switch to eval so dropout / stochastic-depth
# layers (if any) are inactive for any later forward passes. Trailing ';' suppresses the repr.
model.eval();

In [None]:
tokenizer = open_clip.get_tokenizer('ViT-B-16-plus-240')

# Sample inputs: one preprocessed image with a leading batch axis of size 1,
# and three candidate captions tokenized as a batch.
image = eval_transform(Image.open("../docs/CLIP.png"))[None]
text = tokenizer(["a diagram", "a dog", "a cat"])

### Prepare calibration data for quantization

In [None]:
import os
import random
from io import BytesIO
import requests
import numpy as np

def get_pil_from_url(url, timeout=10):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds ``requests`` may wait on the server. Without a timeout a
            single unresponsive host can block the calling DataLoader worker forever.

    Returns:
        The decoded image converted to ``"RGB"`` mode.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a non-2xx status.
        Whatever ``PIL.Image.open`` raises if the body is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail fast on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    return image.convert("RGB")

# Last-resort (image, caption) pair, downloaded once when this cell runs. It is used
# when neither the requested URL nor any cached URL produces an image.
BACKUP_PAIR = (
    get_pil_from_url("https://thumbs.dreamstime.com/t/altai-mountains-mountain-lake-russia-siberia-chuya-ridge-49130812.jpg"),
    "Altai mountains Stock Photography",
)
# (url, caption) pairs that loaded successfully; reused as fallbacks for broken URLs.
AVAILABLE_EXAMPLES = list()

def check_text_data(data):
    """Return True if ``data`` is a usable caption: a string or a non-empty list of strings.

    An empty list is rejected: ``tokenize_captions`` cannot pick a caption from it,
    and that call happens outside laion2B_preprocess_train's try/except, so letting it
    through would crash the batch instead of triggering the fallback path.
    """
    if isinstance(data, str):
        return True
    if isinstance(data, list):
        return len(data) > 0 and all(isinstance(x, str) for x in data)
    return False

def _use_backup_pair(examples, text_column):
    """Overwrite the caption with BACKUP_PAIR's caption and return a copy of its image."""
    examples[text_column] = BACKUP_PAIR[1]
    return BACKUP_PAIR[0].copy()


def _load_fallback_example(examples, text_column):
    """Return a substitute image (and overwrite the caption) for an example that could not be used.

    Re-downloads a random previously successful (url, caption) pair from AVAILABLE_EXAMPLES.
    If the cache is empty, or that download fails too, falls back to BACKUP_PAIR.
    """
    if not AVAILABLE_EXAMPLES:
        print("Example cache is empty, using backup")
        return _use_backup_pair(examples, text_column)
    cached_url, cached_caption = AVAILABLE_EXAMPLES[random.randint(0, len(AVAILABLE_EXAMPLES) - 1)]
    try:
        image = get_pil_from_url(cached_url)
    except Exception:
        print(f"Can't load image from cached url: {cached_url}, using backup")
        return _use_backup_pair(examples, text_column)
    examples[text_column] = cached_caption
    return image


def laion2B_preprocess_train(examples, train_transforms, tokenize_captions, image_column="url", text_column="caption"):
    """Download, transform and tokenize one LAION row, substituting a fallback on failure.

    Args:
        examples: A single dataset row (mapping). It is mutated in place: on fallback
            ``text_column`` is overwritten, and ``"pixel_values"`` / ``"text"`` are added.
        train_transforms: Callable applied to the PIL image.
        tokenize_captions: Callable taking the row and returning its token ids.
        image_column: Key holding the image URL.
        text_column: Key holding the caption.

    Returns:
        The same ``examples`` mapping with ``"pixel_values"`` and ``"text"`` set.

    Successfully loaded rows are appended to AVAILABLE_EXAMPLES so later failures can
    reuse them. A row whose image cannot be fetched, or whose caption fails
    check_text_data, is replaced by a cached pair or BACKUP_PAIR.
    """
    url = examples[image_column]
    try:
        image = get_pil_from_url(url)
        if not check_text_data(examples[text_column]):
            raise ValueError("Text data is not valid")
        AVAILABLE_EXAMPLES.append((url, examples[text_column]))
    except Exception as err:
        # The failure may be the download OR the caption check; report which via `err`.
        print(f"Can't use example from url: {url} ({err!r}), falling back to cache of size: {len(AVAILABLE_EXAMPLES)}")
        image = _load_fallback_example(examples, text_column)

    examples["pixel_values"] = train_transforms(image)
    examples["text"] = tokenize_captions(examples)
    return examples

def tokenize_captions(examples, is_train=True, caption_column="caption"):
    """Pick one caption from a dataset row and tokenize it with the global ``tokenizer``.

    Args:
        examples: A single dataset row (mapping) holding the caption under
            ``caption_column``. The value may be a string or a list / numpy array of strings.
        is_train: With several captions, choose one at random if True, else the first.
        caption_column: Key of the caption field; defaults to ``"caption"``.

    Returns:
        Element ``[0]`` of ``tokenizer(caption)``, i.e. the token ids of the chosen caption.

    Raises:
        ValueError: if the caption is neither a string nor a list/array, or is an empty list/array.
    """
    caption = examples[caption_column]
    if isinstance(caption, str):
        chosen = caption
    elif isinstance(caption, (list, np.ndarray)):
        if len(caption) == 0:
            raise ValueError(f"Caption column `{caption_column}` contains an empty list of captions.")
        # take a random caption if there are multiple
        chosen = random.choice(caption) if is_train else caption[0]
    else:
        raise ValueError(f"Caption column `{caption_column}` should contain either strings or lists of strings.")
    # The tokenizer returns a batch; unwrap the single row.
    return tokenizer(chosen)[0]

In [None]:
from datasets import load_dataset

# Stream LAION-400M rather than downloading it, shuffling through a fixed-size buffer.
# NOTE(review): in the cells shown, max_train_samples only sizes the shuffle buffer;
# it does not cap how many rows are drawn.
max_train_samples = 10_000
dataset = load_dataset(path="laion/laion400m", streaming=True)
train_dataset = dataset["train"].shuffle(buffer_size=max_train_samples, seed=42)

In [None]:
# Dtype reported by the text transformer; embeddings are cast to it before use.
cast_dtype = model.transformer.get_cast_dtype()


def preprocess_text(text):
    """Look up token embeddings and add positional embeddings, without tracking gradients.

    Both terms are cast to ``cast_dtype``. The result is what collate_fn_image stores
    under "input_ids", which the text encoder's calibration later consumes in place of
    raw token ids.
    """
    with torch.no_grad():
        token_features = model.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        return token_features + model.positional_embedding.to(cast_dtype)

def collate_fn_image(examples):
    """Turn a list of raw LAION rows into one calibration batch.

    Each row goes through laion2B_preprocess_train (download, transform, tokenize,
    with fallbacks). The returned dict holds:
      - "pixel_values": stacked images, contiguous, as float32.
      - "input_ids": per-row preprocess_text outputs (token + positional embeddings,
        not raw ids despite the key), stacked and permuted NLD -> LND.
      - "attention_masks": the model's single attn_mask, shared across the batch.
    """
    processed = [
        laion2B_preprocess_train(example, train_transform, tokenize_captions)
        for example in examples
    ]
    images = torch.stack([row["pixel_values"] for row in processed])
    images = images.to(memory_format=torch.contiguous_format).float()

    text_embeddings = torch.stack([preprocess_text(row["text"]) for row in processed])
    return {
        "pixel_values": images,
        "input_ids": text_embeddings.permute(1, 0, 2),  # NLD -> LND
        "attention_masks": model.attn_mask,
    }

In [None]:
import itertools
from tqdm.notebook import tqdm

def prepare_nncf_init_data(dataloader, init_steps):
    """Collect up to ``init_steps`` batches from ``dataloader`` as NNCF calibration samples.

    Args:
        dataloader: Iterable of dicts with "pixel_values", "input_ids" and
            "attention_masks" tensors (the format collate_fn_image returns).
        init_steps: Maximum number of batches to take. Fewer are returned if the
            iterable runs out first.

    Returns:
        A list of ``(pixel_values, input_ids, attention_masks)`` tuples moved to CPU.
    """
    print(f"Fetching {init_steps} batches for the initialization...")
    samples = []
    with torch.no_grad():
        # islice stops after init_steps batches without pulling extra items from the stream.
        for batch in tqdm(itertools.islice(dataloader, init_steps), total=init_steps):
            samples.append(
                (
                    batch["pixel_values"].to("cpu"),
                    batch["input_ids"].to("cpu"),
                    batch["attention_masks"].to("cpu"),
                )
            )
    return samples

In [None]:
# Calibration loader over the streamed split. collate_fn_image downloads each row's image,
# so presumably several workers are used to overlap that network I/O.
# NOTE(review): with worker processes, each worker fills its own copy of AVAILABLE_EXAMPLES.
train_batch_size = 1
dataloader_num_workers = 4
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    num_workers=dataloader_num_workers,
    collate_fn=collate_fn_image,
)

In [None]:
# Number of calibration batches to pull (each of size train_batch_size).
opt_init_steps = 10
init_data = prepare_nncf_init_data(dataloader=train_dataloader, init_steps=opt_init_steps)

In [None]:
class InitDataset(torch.utils.data.Dataset):
    """Map-style Dataset over an in-memory sequence of calibration samples.

    ``len()`` and indexing delegate straight to the wrapped sequence.
    """

    def __init__(self, data):
        super().__init__()
        # The wrapped sequence, exposed under its original attribute name.
        self.init_data = data

    def __getitem__(self, index):
        items = self.init_data
        return items[index]

    def __len__(self):
        items = self.init_data
        return len(items)

### Quantize Image Encoder

In [None]:
import nncf

def quantize_image_encoder(model, data_loader):
    """Post-training-quantize the OpenVINO image encoder with NNCF.

    Args:
        model: The OpenVINO image-encoder model to quantize.
        data_loader: Iterable of calibration samples. In this notebook each sample is a
            ``(pixel_values, input_ids, attention_masks)`` tuple; only element 0 is fed
            to the image encoder.

    Returns:
        The result of ``nncf.quantize`` (transformer model type, MIXED preset).
    """
    def select_pixel_values(sample):
        return sample[0]

    calibration_dataset = nncf.Dataset(data_loader, select_pixel_values)
    return nncf.quantize(
        model,
        calibration_dataset,
        preset=nncf.QuantizationPreset.MIXED,
        model_type=nncf.ModelType.TRANSFORMER,
    )

In [None]:
import openvino.runtime as ov
from pathlib import Path

# Read the image encoder IR from the working directory (exported outside the cells shown)
# and quantize it using the cached calibration samples.
core = ov.Core()
ov_model_path = Path("image_encoder.xml")
image_encoder = core.read_model(ov_model_path)
q_image_encoder = quantize_image_encoder(image_encoder, InitDataset(init_data))

In [None]:
# Write the quantized image encoder to q_image_encoder.xml with ov.serialize.
ov.serialize(q_image_encoder, "q_image_encoder.xml")

### Quantize Text Encoder

In [None]:
def quantize_text_encoder(model, data_loader):
    """Post-training-quantize the OpenVINO text encoder with NNCF.

    Args:
        model: The OpenVINO text-encoder model to quantize.
        data_loader: Iterable of calibration samples. In this notebook each sample is a
            ``(pixel_values, input_ids, attention_masks)`` tuple; elements 1 and 2 (the
            precomputed text embeddings and the attention mask) are fed to the text encoder.

    Returns:
        The result of ``nncf.quantize`` (transformer model type, MIXED preset).
    """
    def select_text_inputs(sample):
        return (sample[1], sample[2])

    calibration_dataset = nncf.Dataset(data_loader, select_text_inputs)
    return nncf.quantize(
        model,
        calibration_dataset,
        preset=nncf.QuantizationPreset.MIXED,
        model_type=nncf.ModelType.TRANSFORMER,
    )

In [None]:
# Read the text encoder IR from the working directory (exported outside the cells shown)
# and quantize it using the same cached calibration samples.
core = ov.Core()
ov_model_path = Path("text_encoder.xml")
text_encoder = core.read_model(ov_model_path)
q_text_encoder = quantize_text_encoder(text_encoder, InitDataset(init_data))

In [None]:
# Write the quantized text encoder to q_text_encoder.xml with ov.serialize.
ov.serialize(q_text_encoder, "q_text_encoder.xml")