# Demo of BLIP-2 Quantization, Inference, and Scoring

## 1. Load Model and Quantize

In [None]:
from blip_quantizer import BlipQuantizer, QuantConfig, ModelPart, LayerGroup, LayerType
from quant_functions import uniform_quantization
import torch
from transformers import Blip2ForImageTextRetrieval
from dataset import Flickr30kEvalDataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from utils import print_model_structure

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16 halves GPU memory, but some CPU ops lack half-precision kernels (PyTorch-version
# dependent), so fall back to fp32 when no GPU is available.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g-coco", torch_dtype=model_dtype)
model = model.to(device)

quantizer = BlipQuantizer(model)
# Each QuantConfig targets one slice of the model. The positional args appear to be
# (model part, layer group, layer type, quantization function) plus the bit width.
# TODO confirm against blip_quantizer.QuantConfig.
configs = [
    # Vision transformer, first layer group: attention and MLP at 8 bits.
    QuantConfig(
        ModelPart.VIT,
        LayerGroup.FIRST,
        LayerType.BOTH,
        uniform_quantization,
        num_bits=8,
    ),
    # Vision transformer, middle layer group: MLP only at 8 bits.
    QuantConfig(
        ModelPart.VIT,
        LayerGroup.MIDDLE,
        LayerType.MLP,
        uniform_quantization,
        num_bits=8,
    ),
    # Q-Former, middle layer group: MLP only at a more aggressive 4 bits.
    QuantConfig(
        ModelPart.QFORMER,
        LayerGroup.MIDDLE,
        LayerType.MLP,
        uniform_quantization,
        num_bits=4,
    ),
]


print("Quantizing model...")
quantizer.apply_quantization(configs)

# Optional: uncomment to inspect the model's layer structure after quantization.
# print_model_structure(model)

## 2. Run Inference on Model and Generate a .json File

In [None]:
# Import from the project-local `dataset` module, as in the first cell. The Hugging Face
# `datasets` package does not provide Flickr30kEvalDataset.
from dataset import Flickr30kEvalDataset
import numpy as np
import re
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from inference_pipeline import InferencePipeline

# Preprocessing: bicubic resize to 224x224, convert to a tensor, then normalize each
# channel with the mean/std constants below.
img_transform = transforms.Compose(
    [
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ]
)

# Flickr30k test split: annotation file plus the directory of test images.
flickr30k = Flickr30kEvalDataset(
    "./data/flickr30k/annotations/flickr30k_test.json",
    "./data/flickr30k/images_flickr_1k_test",
    img_transform=img_transform,
)

# Run the (quantized) model from the first cell over the dataset on the selected device.
inferencer = InferencePipeline(model, device)

results = inferencer.run_inference(flickr30k, task="image_text_retrieval")

## 3. Score Results from .json File

In [None]:
from scoring_pipeline import ScoringPipeline

scorer = ScoringPipeline()
# Score the inference output from section 2, using the same task name it was run with.
retrieval_results = scorer.compute_scores(results, "image_text_retrieval")

# Bare last expression: Jupyter renders it with its rich repr rather than plain print() text.
retrieval_results

## Sample Results

This step is optional; it just helps you qualitatively understand how the predictions relate to the reference captions.

In [None]:
import json

# NOTE(review): this file name and its "predictions"/"references" layout look like output from
# an earlier COCO captioning run, not the image_text_retrieval run above. TODO confirm the path.
SAMPLE_RESULTS_PATH = "./results/coco_quantized_inference.json"
NUM_SAMPLES = 5

# Context manager closes the file even if json.load raises. JSON is UTF-8 by spec, so don't
# rely on the platform default encoding.
with open(SAMPLE_RESULTS_PATH, encoding="utf-8") as f:
    data = json.load(f)

# Slicing plus zip stop early instead of raising IndexError when fewer than NUM_SAMPLES
# predictions/references exist.
for prediction, references in zip(data["predictions"][:NUM_SAMPLES], data["references"][:NUM_SAMPLES]):
    # Assumes each prediction dict holds exactly (image id, caption), in that key order. TODO confirm.
    img_id, caption = prediction.values()
    print(f"Image Id: {img_id}\nPredicted Caption:{caption}")
    print(f"Reference Captions: {' '.join(references)}\n")

### Here's the image for the first prediction above:

In [None]:
# `coco_dataset` is never defined in this notebook, so show the first image of the Flickr30k
# test set loaded in section 2 instead.
# NOTE(review): assumes flickr30k[i][0] is the image after `img_transform`, i.e. a normalized
# (3, H, W) tensor. TODO confirm against Flickr30kEvalDataset.
norm_mean = torch.tensor((0.48145466, 0.4578275, 0.40821073)).view(3, 1, 1)
norm_std = torch.tensor((0.26862954, 0.26130258, 0.27577711)).view(3, 1, 1)
# Undo transforms.Normalize (same constants as img_transform) so the tensor renders as RGB.
transforms.ToPILImage()((flickr30k[0][0].cpu() * norm_std + norm_mean).clamp(0, 1))

## Cleanup

In [None]:
import gc

# Release the model's memory. `quantizer` and `inferencer` were both constructed from `model`
# and presumably keep references to it, so all three names must be dropped before the weights
# can be garbage-collected. Popping with a default keeps this cell safe to run even if an
# earlier cell failed or this cell already ran.
if "model" in globals():
    model.to("cpu")
for _name in ("model", "quantizer", "inferencer"):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()