In [None]:
import pandas as pd
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoProcessor,
    Blip2ForImageTextRetrieval,
)
from operator import attrgetter

import torch.nn as nn
import os
import re

from transformers import LlavaForConditionalGeneration
import torch

In [None]:
from collections import OrderedDict


def get_leaf_modules(model: nn.Module) -> OrderedDict[str, nn.Module]:
    """Collect the leaf submodules of ``model`` in ``named_modules()`` order.

    A leaf is any module (``model`` itself included) that has no child modules.

    Args:
        model: The PyTorch module to walk.

    Returns:
        An ``OrderedDict`` mapping each leaf's qualified name, as produced by
        ``model.named_modules()``, to the module object.
    """
    return OrderedDict(
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        # children() is a generator: an immediately exhausted one means "leaf"
        if next(submodule.children(), None) is None
    )

In [None]:
def compute_bpw(
    leaves,
    quantized_mods,
    total_params,
    vision_bits=None,
    qformer_bits=None,
    llm_bits=None,
    fp_size=16,
):
    """Compute the average bits per weight (bpw) of a model under mixed quantization.

    Every ``nn.Linear`` leaf whose name contains one of ``quantized_mods`` has its
    weight matrix counted at the bit width of its group. All other parameters are
    counted at ``fp_size`` bits: non-linear leaves, linears not matched by any
    pattern, and the biases of quantized linears.

    The group of a quantized linear is chosen from the *matching pattern* (first
    match in ``quantized_mods`` order). The pattern is checked in this order:
    contains ``"vision"`` -> ``vision_bits``, contains ``"qformer"`` ->
    ``qformer_bits``, contains ``"language"`` -> ``llm_bits``.

    Args:
        leaves: Mapping of qualified module name -> leaf module (e.g. the output
            of ``get_leaf_modules``).
        quantized_mods: Name substrings selecting the quantized linear layers.
        total_params: Total parameter count of the model; used as the denominator.
        vision_bits: Weight bit width for the "vision" group.
        qformer_bits: Weight bit width for the "qformer" group.
        llm_bits: Weight bit width for the "language" group.
        fp_size: Bit width of every non-quantized parameter.

    Returns:
        ``total_bits / total_params``.

    Raises:
        ValueError: if a matching pattern names no known group, or if a group
            that is actually used has no bit width (``None``).
    """
    # Group tag -> bit width. Dict insertion order is the lookup priority below.
    group_bits = {"vision": vision_bits, "qformer": qformer_bits, "language": llm_bits}
    group_params = dict.fromkeys(group_bits, 0)
    total_bits = 0

    for key, module in leaves.items():
        # Only nn.Linear leaves can be quantized. The first matching pattern wins,
        # so overlapping patterns cannot count the same layer twice.
        q_mod = None
        if isinstance(module, nn.Linear):
            q_mod = next((pattern for pattern in quantized_mods if pattern in key), None)

        if q_mod is None:
            # full-precision module: every parameter at fp_size bits
            total_bits += fp_size * sum(param.numel() for param in module.parameters())
            continue

        group = next((tag for tag in group_bits if tag in q_mod), None)
        if group is None:
            raise ValueError(
                f"quantized module pattern {q_mod!r} does not name a known group "
                "('vision', 'qformer' or 'language')"
            )
        bits = group_bits[group]
        if bits is None:
            raise ValueError(f"layer {key!r} is quantized but no bit width was given for group {group!r}")

        num_el = module.weight.numel()
        total_bits += bits * num_el
        group_params[group] += num_el

        # The bias is not quantized, but it is part of total_params. It must be
        # counted at full precision to keep numerator and denominator consistent.
        if module.bias is not None:
            total_bits += fp_size * module.bias.numel()

    print(f"vision q params: {group_params['vision']}")
    print(f"qformer q params: {group_params['qformer']}")
    print(f"llm_params: {group_params['language']}")

    return total_bits / total_params

In [None]:
# Reference recipe (kept commented out): load a model, collect its leaf modules
# and total parameter count, and pass them to compute_bpw together with the name
# patterns of the quantized layers. The quantized-parameter counts it prints
# presumably produced the constants hard-coded in the compute_bpw_* helpers
# below — TODO confirm.

# --- BLIP-2 image-text retrieval model: ViT + Q-Former patterns ---
# model_name = "Salesforce/blip2-itm-vit-g-coco"
# model = Blip2ForImageTextRetrieval.from_pretrained(model_name)

# leaves = get_leaf_modules(model)
# total_params = sum(p.numel() for p in model.parameters())
# print(total_params)
# quantized_mods = [
#     "vision_model.encoder.layers",
#     "qformer.encoder.layer",
# ]


# --- BLIP-2 OPT-2.7B: ViT + Q-Former + OPT decoder patterns ---
# model_name = "Salesforce/blip2-opt-2.7b"
# model = Blip2ForConditionalGeneration.from_pretrained(model_name)
# model.to('cpu')

# leaves = get_leaf_modules(model)
# total_params = sum(p.numel() for p in model.parameters())
# quantized_mods = [
#     "vision_model.encoder.layers",
#     "qformer.encoder.layer",
#     "language_model.model.decoder.layers"
# ]


# --- LLaVA-1.5-7B: vision tower + language model patterns ---
# Load the model
# model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16)
# # offload model to cpu for now
# model.to('cpu')


# quantized_mods = [
#     "vision_tower.vision_model.encoder.layers",
#     "language_model.model.layers",
# ]

# leaves = get_leaf_modules(model)
# total_params = sum(p.numel() for p in model.parameters())
# print(total_params)

# compute_bpw(leaves, quantized_mods, total_params,
#                                   vision_bits=4,
#                                   qformer_bits=4,
#                                   llm_bits=4)


In [None]:
def compute_bpw_llava(vision_bits, llm_bits, fp_bits=16):
    """Average bits per weight for LLaVA with a quantized vision tower and LLM.

    Uses fixed parameter counts: the model total and the quantized-weight
    counts of the two groups. Every parameter outside those two groups is
    counted at ``fp_bits``.

    Args:
        vision_bits: Bit width of the quantized vision-tower weights.
        llm_bits: Bit width of the quantized language-model weights.
        fp_bits: Bit width of all remaining parameters.

    Returns:
        Weighted bits divided by the total parameter count.
    """
    n_total = 7063427072
    # (bit width, parameter count) per storage class, in summation order
    segments = (
        (vision_bits, 301989888),
        (llm_bits, 6476005376),
    )
    n_full_precision = n_total - sum(count for _, count in segments)
    segments += ((fp_bits, n_full_precision),)
    return sum(bits * count for bits, count in segments) / n_total

In [None]:
def compute_bpw_blip_full(vision_bits, qformer_bits, llm_bits, fp_bits=16):
    """Average bits per weight for the BLIP-2 captioning model.

    The ViT, Q-Former and language-model quantized weight counts are fixed
    constants. The counts appear to come from Salesforce/blip2-opt-2.7b —
    TODO confirm. Everything outside those three groups is counted at
    ``fp_bits``.

    Args:
        vision_bits: Bit width of the quantized ViT weights.
        qformer_bits: Bit width of the quantized Q-Former weights.
        llm_bits: Bit width of the quantized language-model weights.
        fp_bits: Bit width of all remaining parameters.

    Returns:
        Weighted bits divided by the total parameter count.
    """
    TOTAL = 3744761856
    VISION_Q, QFORMER_Q, LLM_Q = 984023040, 104988672, 2516582400

    remainder = TOTAL - VISION_Q - QFORMER_Q - LLM_Q

    weighted_bits = vision_bits * VISION_Q
    weighted_bits += qformer_bits * QFORMER_Q
    weighted_bits += llm_bits * LLM_Q
    weighted_bits += fp_bits * remainder
    return weighted_bits / TOTAL

In [None]:
def compute_bpw_blip_retrieval(vision_bits, qformer_bits, fp_bits=16):
    """Average bits per weight for the BLIP-2 image-text retrieval model.

    Only the ViT and Q-Former weight groups (fixed counts) are quantized. Every
    other parameter is counted at ``fp_bits``.

    Args:
        vision_bits: Bit width of the quantized ViT weights.
        qformer_bits: Bit width of the quantized Q-Former weights.
        fp_bits: Bit width of all remaining parameters.

    Returns:
        Weighted bits divided by the total parameter count.
    """
    total_params = 1172623618
    quantized = [(vision_bits, 984023040), (qformer_bits, 161611776)]

    leftover = total_params
    for _, count in quantized:
        leftover -= count

    numerator = 0
    for bits, count in quantized + [(fp_bits, leftover)]:
        numerator += bits * count
    return numerator / total_params

In [None]:
# AWQ results for BLIP-2 image captioning; the `model_size` column is not kept.
path = "./blip2/awq/image_captioning/awq_image_captioning.csv"
df_awq_coco = pd.read_csv(path).drop(columns=["model_size"])
df_awq_coco

In [None]:
# # compute bpw
# model_name = "Salesforce/blip2-opt-2.7b"
# model = Blip2ForConditionalGeneration.from_pretrained(model_name)
# model.to('cpu')

# leaves = get_leaf_modules(model)
# total_params = sum(p.numel() for p in model.parameters())
# quantized_mods = [
#     "vision_model.encoder.layers",
#     "qformer.encoder.layer",
#     "language_model.model.decoder.layers"
# ]

# df_awq_coco['bpw'] = [compute_bpw(leaves, quantized_mods, total_params,
#                                   vision_bits=x['vit_bits'],
#                                   qformer_bits=x['qformer_bits'],
#                                   llm_bits=x['llm_bits']) for x in df_awq_coco.to_dict(orient='records')]

# Mixed-precision bpw for every AWQ COCO configuration (fixed BLIP-2 parameter counts).
df_awq_coco["bpw"] = [
    compute_bpw_blip_full(vision_bits=vit, qformer_bits=qformer, llm_bits=llm)
    for vit, qformer, llm in zip(
        df_awq_coco["vit_bits"], df_awq_coco["qformer_bits"], df_awq_coco["llm_bits"]
    )
]

df_awq_coco["quant_method"] = "awq"

In [None]:
# Inspect the AWQ COCO table after adding the bpw and quant_method columns.
df_awq_coco

In [None]:
# Export the AWQ COCO results (without the DataFrame index).
df_awq_coco.to_csv(os.path.join("./final_results/all_results", "blip2_awq_coco.csv"), index=False)

In [None]:
# AWQ results for BLIP-2 image-text retrieval.
path = "./blip2/awq/image_text_retrieval/awq_image_text_retrieval.csv"
df_awq_flickr = pd.read_csv(path)
df_awq_flickr

In [None]:
# bpw for each AWQ retrieval configuration (only ViT and Q-Former are quantized).
df_awq_flickr["bpw"] = [
    compute_bpw_blip_retrieval(
        vision_bits=x["vit_bits"],
        qformer_bits=x["qformer_bits"],
    )
    for x in df_awq_flickr.to_dict(orient="records")
]

df_awq_flickr["quant_method"] = "awq"
# errors="ignore" keeps this cell re-runnable: on a second run `model_size` has
# already been dropped, and a plain drop would raise KeyError.
df_awq_flickr = df_awq_flickr.drop(columns=["model_size"], errors="ignore")

In [None]:
# Inspect the AWQ retrieval table with bpw and quant_method added.
df_awq_flickr

In [None]:
# Export the AWQ retrieval results (without the DataFrame index).
df_awq_flickr.to_csv(os.path.join("./final_results/all_results", "blip2_awq_flickr.csv"), index=False)

In [None]:
# GQA: LLaVA GPTQ results
df_gptq_gqa = pd.read_csv("./final_results/llava/llava_gptq_gqa_results.csv")
df_gptq_gqa.head(5)

In [None]:
# bpw for each GPTQ GQA configuration (vision tower and LLM bit widths per row).
df_gptq_gqa["bpw"] = [
    compute_bpw_llava(vision_bits=vision, llm_bits=language)
    for vision, language in zip(df_gptq_gqa["vision_bits"], df_gptq_gqa["language_bits"])
]
df_gptq_gqa["quant_method"] = "gptq"

df_gptq_gqa.head(5)

In [None]:
# Export the LLaVA GPTQ GQA results (without the DataFrame index).
df_gptq_gqa.to_csv(os.path.join("./final_results/all_results", "llava_gptq_gqa.csv"), index=False)

In [None]:
# VQAv2: LLaVA GPTQ results
df_gptq_vqav2 = pd.read_csv("./final_results/llava/llava_gptq_vqav2.csv")

In [None]:
# bpw for each GPTQ VQAv2 configuration.
df_gptq_vqav2["bpw"] = [
    compute_bpw_llava(vision_bits=vision, llm_bits=language)
    for vision, language in zip(df_gptq_vqav2["vision_bits"], df_gptq_vqav2["language_bits"])
]
df_gptq_vqav2["quant_method"] = "gptq"
# Rename the `agg_metrics` column to `acc`.
df_gptq_vqav2 = df_gptq_vqav2.rename(columns={"agg_metrics": "acc"})

df_gptq_vqav2.head(2)

In [None]:
# Export without the DataFrame index; `index` is a bool flag (False, as in the other exports).
df_gptq_vqav2.to_csv("./final_results/all_results/llava_gptq_vqav2.csv", index=False)

In [None]:
# GQA: LLaVA AWQ results
df_awq_gqa = pd.read_csv("./final_results/llava/llava_awq_gqa.csv")
df_awq_gqa.head(5)

In [None]:
# bpw for each AWQ GQA configuration.
df_awq_gqa["bpw"] = [
    compute_bpw_llava(vision_bits=vision, llm_bits=language)
    for vision, language in zip(df_awq_gqa["vision_bits"], df_awq_gqa["language_bits"])
]
df_awq_gqa["quant_method"] = "awq"

df_awq_gqa.head(2)

In [None]:
# Export without the DataFrame index; `index` is a bool flag (False, as in the other exports).
df_awq_gqa.to_csv("./final_results/all_results/llava_awq_gqa.csv", index=False)

In [None]:
# VQAv2: LLaVA AWQ results
df_awq_vqav2 = pd.read_csv("./final_results/llava/llava_awq_vqav2.csv")
df_awq_vqav2.head(5)

In [None]:
# bpw for each AWQ VQAv2 configuration.
df_awq_vqav2["bpw"] = [
    compute_bpw_llava(vision_bits=vision, llm_bits=language)
    for vision, language in zip(df_awq_vqav2["vision_bits"], df_awq_vqav2["language_bits"])
]
df_awq_vqav2["quant_method"] = "awq"
# Rename the `agg_metrics` column to `acc`.
df_awq_vqav2 = df_awq_vqav2.rename(columns={"agg_metrics": "acc"})

df_awq_vqav2.head(2)

In [None]:
# Export without the DataFrame index; `index` is a bool flag (False, as in the other exports).
df_awq_vqav2.to_csv("./final_results/all_results/llava_awq_vqav2.csv", index=False)

In [None]:
# Uniform-quantization results for BLIP-2 on Flickr.
df_uniform_flickr = pd.read_csv("./final_results/blip2/uniform/blip2_flickr_results.csv")
df_uniform_flickr.head(5)

In [None]:
# Number of rows in the uniform Flickr table.
len(df_uniform_flickr)

In [None]:
# BLIP-2 retrieval model; below, only its leaf modules and parameter counts are used.
model_name = "Salesforce/blip2-itm-vit-g-coco"
model = Blip2ForImageTextRetrieval.from_pretrained(model_name)

In [None]:
_LAYER_INDEX_RE = re.compile(r"layers?\.(\d+)")


def _layer_index(key):
    """Return the block index from the last ``layer.<i>`` / ``layers.<i>`` component of ``key``."""
    matches = _LAYER_INDEX_RE.findall(key)
    if not matches:
        raise ValueError(f"cannot find a layer index in module name {key!r}")
    return int(matches[-1])


def compute_bpw_uniform(leaves, quantized_mods, total_params, row_dict, fp_size=16):
    """Average bits per weight for one uniform-quantization configuration.

    ``row_dict`` is one record of the uniform-quantization results table. Its
    selection cells (``*_indices`` / ``*_modules``) are ``eval``-ed into
    containers. A NaN/None cell means "nothing selected".

    * A vision linear is quantized at ``visual_encoder_block_weight_bits`` when
      its layer index is in ``visual_encoder_block_indices`` and its module name
      (``projection`` is aliased to ``proj``) is in
      ``visual_encoder_block_modules``.
    * A Q-Former linear is quantized at ``qformer_weight_bits`` when its layer
      index is in ``qformer_layer_indices`` and it is selected by one of:
      - attention: ``qformer_self_attention_modules``, also used for
        cross-attention;
      - img_ff (names containing "query"): ``qformer_img_ff_modules``;
      - text_ff: ``qformer_text_ff_modules``.

    Only the weight of a quantized ``nn.Linear`` is counted at the reduced bit
    width. Its bias and every other parameter are counted at ``fp_size`` bits,
    so the numerator covers exactly the parameters in ``total_params``.

    Args:
        leaves: Mapping of qualified module name -> leaf module.
        quantized_mods: Name substrings selecting the quantizable linears. Patterns
            containing "vision" or "qformer" select the respective rules above.
        total_params: Total parameter count of the model (denominator).
        row_dict: One record of the results table.
        fp_size: Bit width of non-quantized parameters.

    Returns:
        ``total_bits / total_params``.

    Raises:
        ValueError: if a matched linear's name has no ``layer(s).<idx>`` part.
    """
    parsed_cells = {}

    def cell(column):
        """Parsed contents of ``row_dict[column]`` (cached per call); None when the cell is NaN/None."""
        if column not in parsed_cells:
            raw = row_dict[column]
            # NOTE(review): eval() runs arbitrary code from the results table;
            # ast.literal_eval would be safer if every cell is a plain literal.
            parsed_cells[column] = None if pd.isna(raw) else eval(raw)
        return parsed_cells[column]

    def contains(column, item):
        """True when the parsed cell is present and contains ``item``."""
        chosen = cell(column)
        return chosen is not None and item in chosen

    def mentions(column, key):
        """True when the parsed cell is present and one of its entries is a substring of ``key``."""
        chosen = cell(column)
        return chosen is not None and any(part in key for part in chosen)

    def qformer_module_selected(key, mod_name):
        """Whether a Q-Former linear in a selected layer is selected by its module rule."""
        if "attention" in key:
            # NOTE: same quantized mods for self/cross-attn ("crossattention"
            # also contains "attention").
            return contains("qformer_self_attention_modules", mod_name)
        if "query" in key:  # img_ff
            return mentions("qformer_img_ff_modules", key)
        return mentions("qformer_text_ff_modules", key)  # text_ff

    total_bits = 0

    for key, module in leaves.items():
        weight_bits = None  # None -> module stays at full precision

        q_mod = None
        if isinstance(module, nn.Linear):
            # The first matching pattern wins, so a layer is never counted twice.
            q_mod = next((pattern for pattern in quantized_mods if pattern in key), None)

        if q_mod is not None:
            layer_idx = _layer_index(key)
            mod_name = key.split(".")[-1]
            if mod_name == "projection":
                mod_name = "proj"

            if "vision" in q_mod:
                if contains("visual_encoder_block_indices", layer_idx) and contains(
                    "visual_encoder_block_modules", mod_name
                ):
                    weight_bits = int(row_dict["visual_encoder_block_weight_bits"])
            elif "qformer" in q_mod:
                if contains("qformer_layer_indices", layer_idx) and qformer_module_selected(key, mod_name):
                    weight_bits = int(row_dict["qformer_weight_bits"])

        if weight_bits is None:
            # full-precision module
            total_bits += fp_size * sum(param.numel() for param in module.parameters())
        else:
            total_bits += weight_bits * module.weight.numel()
            # The bias stays at full precision but still counts toward total_params.
            if module.bias is not None:
                total_bits += fp_size * module.bias.numel()

    return total_bits / total_params

In [None]:
# How often each visual-encoder module selection occurs in the table.
df_uniform_flickr["visual_encoder_block_modules"].value_counts()

In [None]:
# Pick a single record to inspect the raw cell formats in the cells below.
# NOTE(review): the hard-coded row 202 raises IndexError if the table has fewer rows.
row_dict = df_uniform_flickr.to_dict(orient="records")[202]

In [None]:
# Column names available in a single record.
row_dict.keys()

In [None]:
# NaN check: False when the cell is NaN, because NaN never equals itself.
row_dict["qformer_layer_indices"] == row_dict["qformer_layer_indices"]

In [None]:
# Raw cell: Q-Former self-attention module selection.
row_dict["qformer_self_attention_modules"]

In [None]:
# Raw cell: Q-Former cross-attention module selection. compute_bpw_uniform does not
# read this column; cross-attention layers use the self-attention selection.
row_dict["qformer_cross_attention_modules"]

In [None]:
# Raw cell: Q-Former image-FFN module selection.
row_dict["qformer_img_ff_modules"]

In [None]:
# Raw cell: Q-Former text-FFN module selection.
row_dict["qformer_text_ff_modules"]

In [None]:
# Raw cell: selected visual-encoder block indices.
row_dict["visual_encoder_block_indices"]

In [None]:
# Raw cell: Q-Former weight bit width.
row_dict["qformer_weight_bits"]

In [None]:
# Raw cell: visual-encoder weight bit width.
row_dict["visual_encoder_block_weight_bits"]

In [None]:
# Quantizable linear layers of the BLIP-2 retrieval model: ViT blocks and Q-Former layers.
quantized_mods = [
    "vision_model.encoder.layers",
    "qformer.encoder.layer",
]
leaves = get_leaf_modules(model)
total_params = sum(param.numel() for param in model.parameters())

# One bpw value per uniform-quantization configuration (table row).
df_uniform_flickr["bpw"] = [
    compute_bpw_uniform(leaves, quantized_mods, total_params, record)
    for record in df_uniform_flickr.to_dict(orient="records")
]
df_uniform_flickr["quant_method"] = "uniform"

In [None]:
# Range of the computed bpw values.
df_uniform_flickr.bpw.agg(["min", "max"])

In [None]:
# Available columns (the export subset is chosen from these below).
df_uniform_flickr.columns

In [None]:
# Export subset: retrieval metrics, quantization configuration, derived bpw and method tag.
df_export = df_uniform_flickr[
    [
        # txt_* metrics
        "txt_r1", "txt_r5", "txt_r10", "txt_r_mean",
        # img_* metrics and the overall mean
        "img_r1", "img_r5", "img_r10", "img_r_mean", "r_mean",
        # vit_* configuration
        "vit_attn", "vit_ff",
        "vit_front_blocks", "vit_middle_blocks", "vit_end_blocks",
        "vit_weight_bits",
        # qformer_* configuration
        "qformer_front_blocks", "qformer_middle_blocks", "qformer_end_blocks",
        "qformer_self_attn", "qformer_cross_attn",
        "qformer_text_ff", "qformer_img_ff",
        "qformer_weight_bits",
        # summary columns
        "Quantized Portion", "weight_bits", "bpw", "quant_method",
    ]
]

df_export

In [None]:
# Export without the DataFrame index; `index` is a bool flag (False, as in the other exports).
df_export.to_csv(os.path.join("./final_results/all_results", "blip2_uniform_flickr.csv"), index=False)