In [2]:
import pandas as pd
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoProcessor, Blip2ForImageTextRetrieval
from operator import attrgetter

import torch.nn as nn
import os
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from collections import OrderedDict

def get_leaf_modules(model: nn.Module) -> OrderedDict[str, nn.Module]:
    """
    Returns an ordered dictionary containing only the leaf modules of a PyTorch model.
    Leaf modules are those that do not have any children.
    """
    leaf_modules = OrderedDict()
    for name, module in model.named_modules():
        if not list(module.children()):  # Check if the module has no children
            leaf_modules[name] = module
    return leaf_modules

In [None]:
def compute_bpw(leaves, quantized_mods, total_params, vision_bits = None, qformer_bits =None, llm_bits=None, fp_size = 16):

    total_bits = 0

    for key, module in leaves.items():

        fp_mod_flag = True

        # check if parameters in module should be quantized
        for q_mod in quantized_mods:
            
            # add quantized linear bit sizes
            if q_mod in key and isinstance(module, nn.Linear):
                num_el = module.weight.numel()

                if "vision" in q_mod:
                    total_bits += vision_bits*num_el
                elif "qformer" in q_mod:
                    total_bits += qformer_bits*num_el
                elif "language" in q_mod:
                    total_bits += llm_bits*num_el
                else:
                    raise Exception()
                
                fp_mod_flag = False
        
        # full_precision module
        if fp_mod_flag:
            # print(key)
            for param in module.parameters():
                total_bits += fp_size*param.numel()

    return total_bits / total_params

In [4]:
path = '/fs/cfar-projects/low-bit-vision/final_results/blip2/awq/image_captioning/awq_image_captioning.csv'
df_awq_coco = pd.read_csv(path)
df_awq_coco = df_awq_coco.drop(['model_size'], axis = 1)
df_awq_coco

Unnamed: 0,vit_bits,qformer_bits,llm_bits,METEOR,CIDEr
0,2,2,2,0.029884,0.000790
1,2,2,3,0.149857,0.389274
2,2,2,4,0.183735,0.544352
3,2,2,5,0.188660,0.577806
4,2,2,6,0.192159,0.594062
...,...,...,...,...,...
338,16,16,4,0.266413,1.163837
339,16,16,5,0.270866,1.195060
340,16,16,6,0.278989,1.245283
341,16,16,8,0.280147,1.249383


In [5]:
# compute bpw
model_name = "Salesforce/blip2-opt-2.7b"
model = Blip2ForConditionalGeneration.from_pretrained(model_name)
model.to('cpu')

leaves = get_leaf_modules(model)
total_params = sum(p.numel() for p in model.parameters())
quantized_mods = [
    "vision_model.encoder.layers",
    "qformer.encoder.layer",
    "language_model.model.decoder.layers"
]

df_awq_coco['bpw'] = [compute_bpw(leaves, quantized_mods, total_params,
                                  vision_bits=x['vit_bits'],
                                  qformer_bits=x['qformer_bits'],
                                  llm_bits=x['llm_bits']) for x in df_awq_coco.to_dict(orient='records')]

df_awq_coco['quant_method'] = 'awq'

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


In [6]:
df_awq_coco

Unnamed: 0,vit_bits,qformer_bits,llm_bits,METEOR,CIDEr,bpw,quant_method
0,2,2,2,0.029884,0.000790,3.063071,awq
1,2,2,3,0.149857,0.389274,3.735099,awq
2,2,2,4,0.183735,0.544352,4.407126,awq
3,2,2,5,0.188660,0.577806,5.079154,awq
4,2,2,6,0.192159,0.594062,5.751181,awq
...,...,...,...,...,...,...,...
338,16,16,4,0.266413,1.163837,8.478457,awq
339,16,16,5,0.270866,1.195060,9.150484,awq
340,16,16,6,0.278989,1.245283,9.822512,awq
341,16,16,8,0.280147,1.249383,11.166566,awq


In [7]:
df_awq_coco.to_csv(os.path.join('/fs/cfar-projects/low-bit-vision/final_results/all_results','blip2_awq_coco.csv'), index=False)

In [8]:
path = '/fs/cfar-projects/low-bit-vision/final_results/blip2/awq/image_text_retrieval/awq_image_text_retrieval.csv'
df_awq_flickr = pd.read_csv(path)
df_awq_flickr

Unnamed: 0,vit_bits,qformer_bits,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,agg_metrics,model_size
0,2,2,67.5,83.0,88.1,79.533333,61.32,81.88,86.72,76.64,78.086667,79.533333,3103760704
1,2,3,83.8,95.7,97.6,92.366667,70.5,89.62,93.62,84.58,88.473333,92.366667,3265519936
2,2,4,84.5,95.4,97.4,92.433333,71.22,89.9,93.62,84.913333,88.673333,92.433333,3427279168
3,2,5,83.9,95.6,97.5,92.333333,71.42,89.74,93.86,85.006667,88.67,92.333333,3589038400
4,2,6,83.7,95.3,97.4,92.133333,71.1,89.82,93.7,84.873333,88.503333,92.133333,3750797632
5,2,8,84.0,95.1,97.3,92.133333,71.2,89.94,93.66,84.933333,88.533333,92.133333,4074316096
6,2,16,84.1,95.1,97.4,92.2,71.24,89.98,93.68,84.966667,88.583333,92.2,5368389952
7,3,2,87.8,94.2,95.5,92.5,82.1,94.94,96.64,91.226667,91.863333,92.5,4088297920
8,3,3,97.2,100.0,100.0,99.066667,88.54,98.18,99.02,95.246667,97.156667,99.066667,4250057152
9,3,4,97.5,100.0,100.0,99.166667,88.52,97.88,99.06,95.153333,97.16,99.166667,4411816384


In [9]:
model_name = "Salesforce/blip2-itm-vit-g-coco"
model = Blip2ForImageTextRetrieval.from_pretrained(model_name)

leaves = get_leaf_modules(model)
total_params = sum(p.numel() for p in model.parameters())
quantized_mods = [
    "vision_model.encoder.layers",
    "qformer.encoder.layer",
]

df_awq_flickr['bpw'] = [compute_bpw(leaves, quantized_mods, total_params,
                                  vision_bits=x['vit_bits'],
                                  qformer_bits=x['qformer_bits'],
                                  llm_bits=None) for x in df_awq_flickr.to_dict(orient='records')]


df_awq_flickr['quant_method'] = 'awq'

df_awq_flickr = df_awq_flickr.drop(['model_size'], axis=1)

In [10]:
df_awq_flickr

Unnamed: 0,vit_bits,qformer_bits,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,agg_metrics,bpw,quant_method
0,2,2,67.5,83.0,88.1,79.533333,61.32,81.88,86.72,76.64,78.086667,79.533333,2.299832,awq
1,2,3,83.8,95.7,97.6,92.366667,70.5,89.62,93.62,84.58,88.473333,92.366667,2.437653,awq
2,2,4,84.5,95.4,97.4,92.433333,71.22,89.9,93.62,84.913333,88.673333,92.433333,2.575473,awq
3,2,5,83.9,95.6,97.5,92.333333,71.42,89.74,93.86,85.006667,88.67,92.333333,2.713294,awq
4,2,6,83.7,95.3,97.4,92.133333,71.1,89.82,93.7,84.873333,88.503333,92.133333,2.851115,awq
5,2,8,84.0,95.1,97.3,92.133333,71.2,89.94,93.66,84.933333,88.533333,92.133333,3.126756,awq
6,2,16,84.1,95.1,97.4,92.2,71.24,89.98,93.68,84.966667,88.583333,92.2,4.229321,awq
7,3,2,87.8,94.2,95.5,92.5,82.1,94.94,96.64,91.226667,91.863333,92.5,3.138995,awq
8,3,3,97.2,100.0,100.0,99.066667,88.54,98.18,99.02,95.246667,97.156667,99.066667,3.276816,awq
9,3,4,97.5,100.0,100.0,99.166667,88.52,97.88,99.06,95.153333,97.16,99.166667,3.414637,awq


In [11]:
df_awq_flickr.to_csv(os.path.join('/fs/cfar-projects/low-bit-vision/final_results/all_results','blip2_awq_flickr.csv'), index=False)

In [3]:
df_awq_flickr = pd.read_csv(os.path.join('/fs/cfar-projects/low-bit-vision/final_results/all_results','blip2_awq_flickr.csv'))
df_awq_flickr.head(2)

Unnamed: 0,vit_bits,qformer_bits,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,agg_metrics,bpw,quant_method
0,2,2,67.5,83.0,88.1,79.533333,61.32,81.88,86.72,76.64,78.086667,79.533333,2.299832,awq
1,2,3,83.8,95.7,97.6,92.366667,70.5,89.62,93.62,84.58,88.473333,92.366667,2.437653,awq


In [5]:
df_awq_flickr.bpw.agg(['min', 'max'])

min     2.299832
max    15.977611
Name: bpw, dtype: float64

In [None]:
# GQA
df_gptq_gqa = pd.read_csv('/fs/cfar-projects/low-bit-vision/final_results/llava/llava_gptq_gqa_results.csv')
df_gptq_gqa.head(5)

Unnamed: 0,vision_bits,language_bits,acc
0,6,5,61.34
1,8,4,60.79
2,3,5,59.59
3,5,6,61.19
4,6,6,61.33


In [14]:
from transformers import LlavaForConditionalGeneration
import torch

# Load the model
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", torch_dtype=torch.float16)
# offload model to cpu for now
model.to('cpu')

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  2.73it/s]


LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPSdpaAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): Q

In [24]:
quantized_mods = [
    "vision_tower.vision_model.encoder.layers",
    "language_model.model.layers",
   
]

leaves = get_leaf_modules(model)
total_params = sum(p.numel() for p in model.parameters())


df_gptq_gqa['bpw'] = [compute_bpw(leaves, quantized_mods, total_params,
                                  vision_bits=x['vision_bits'],
                                  llm_bits=x['language_bits']) for x in df_gptq_gqa.to_dict(orient='records')]

df_gptq_gqa['quant_method'] = 'gptq'

In [25]:
df_gptq_gqa

Unnamed: 0,vision_bits,language_bits,acc,bpw,quant_method
0,6,5,61.34,5.486759,gptq
1,8,4,60.79,4.655431,gptq
2,3,5,59.59,5.358497,gptq
3,5,6,61.19,6.360841,gptq
4,6,6,61.33,6.403595,gptq
5,2,4,35.78,4.398906,gptq
6,6,16,61.27,15.571956,gptq
7,16,3,55.82,4.080627,gptq
8,3,2,0.0,2.607988,gptq
9,4,5,60.81,5.401251,gptq


In [27]:
df_gptq_gqa.to_csv(os.path.join('/fs/cfar-projects/low-bit-vision/final_results/all_results','llava_gptq_gqa.csv'), index=False)

In [3]:
# uniform flickr
df_uniform_flickr = pd.read_csv('/fs/cfar-projects/low-bit-vision/final_results/blip2/uniform/blip2_flickr_results.csv')
df_uniform_flickr.head(5)

Unnamed: 0,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,agg_metrics,...,qformer_front_blocks,qformer_middle_blocks,qformer_end_blocks,qformer_self_attn,qformer_cross_attn,qformer_text_ff,qformer_img_ff,qformer_weight_bits,Quantized Portion,weight_bits
0,0.0,0.0,0.4,0.133333,0.1,0.34,0.72,0.386667,0.26,0.133333,...,True,False,False,False,False,True,False,2.0,ViT + Q-Former,2.0
1,0.0,0.1,0.3,0.133333,0.06,0.22,0.56,0.28,0.206667,0.133333,...,True,False,False,False,False,True,False,2.0,Q-Former,2.0
2,0.0,0.2,0.2,0.133333,0.14,0.3,0.72,0.386667,0.26,0.133333,...,True,False,True,False,False,True,False,2.0,Q-Former,2.0
3,0.0,0.3,0.4,0.233333,0.14,0.42,0.62,0.393333,0.313333,0.233333,...,True,False,False,True,True,True,False,2.0,ViT + Q-Former,2.0
4,0.0,0.3,0.7,0.333333,0.16,0.6,1.08,0.613333,0.473333,0.333333,...,False,True,False,True,True,True,False,2.0,ViT + Q-Former,2.0


In [4]:
len(df_uniform_flickr)

952

In [5]:
model_name = "Salesforce/blip2-itm-vit-g-coco"
model = Blip2ForImageTextRetrieval.from_pretrained(model_name)

In [7]:
def compute_bpw_uniform(leaves, quantized_mods, total_params, row_dict, fp_size = 16):

    total_bits = 0

    for key, module in leaves.items():

        fp_mod_flag = True

        # check if parameters in module should be quantized
        for q_mod in quantized_mods:
            
            # add quantized linear bit sizes
            if q_mod in key and isinstance(module, nn.Linear):
                num_el = module.weight.numel()

                # parse out layer index and module name
                layer_idx = int(re.findall(r'layer[s]*.(\d*)', key)[-1])
                mod_name = key.split('.')[-1]

                if mod_name == 'projection':
                    mod_name = 'proj'

                # quantized vision module and layer idx included and mod_name included
                if "vision" in q_mod: 
                   
                   # sanity check for nan values    
                    if row_dict['visual_encoder_block_indices'] == row_dict['visual_encoder_block_indices'] and \
                        layer_idx in eval(row_dict['visual_encoder_block_indices']) and \
                        mod_name in eval(row_dict['visual_encoder_block_modules']):
                    
                        # print(layer_idx)
                    # print(mod_name)

                        total_bits += int(row_dict['visual_encoder_block_weight_bits']) * num_el
                        fp_mod_flag = False


                    # total_bits += vision_bits*num_el

                elif "qformer" in q_mod: #and \
                    
                    # sanity check for nan values   
                    if row_dict['qformer_layer_indices'] == row_dict['qformer_layer_indices'] and \
                       layer_idx in eval(row_dict['qformer_layer_indices']):
                        
                        qformer_weight_bits = int(row_dict['qformer_weight_bits'])
                        
                        # NOTE: same quantized mods for self/cross-attn
                        if 'attention' in key:
                            if row_dict['qformer_self_attention_modules'] == row_dict['qformer_self_attention_modules'] and \
                               mod_name in eval(row_dict['qformer_self_attention_modules']):
                                total_bits += qformer_weight_bits * num_el
                                fp_mod_flag = False
                        # img_ff
                        elif 'query' in key:
                            
                            if row_dict['qformer_img_ff_modules'] == row_dict['qformer_img_ff_modules'] and \
                               any(x in key for x in eval(row_dict['qformer_img_ff_modules'])):
                                total_bits += qformer_weight_bits * num_el
                                fp_mod_flag = False

                                
                        # text_ff
                        else:
                            if row_dict['qformer_text_ff_modules'] == row_dict['qformer_text_ff_modules'] and \
                               any(x in key for x in eval(row_dict['qformer_text_ff_modules'])): 
                                total_bits += qformer_weight_bits * num_el
                                fp_mod_flag = False

                            
        # full_precision module
        if fp_mod_flag:
            # print(key)
            for param in module.parameters():
                total_bits += fp_size*param.numel()

    return total_bits / total_params

In [476]:
df_uniform_flickr['visual_encoder_block_modules'].value_counts()

visual_encoder_block_modules
['qkv', 'proj']                  252
['fc1', 'fc2']                   252
['qkv', 'proj', 'fc1', 'fc2']    252
Name: count, dtype: int64

In [477]:
row_dict = df_uniform_flickr.to_dict(orient='records')[202]

In [478]:
row_dict.keys()

dict_keys(['txt_r1', 'txt_r5', 'txt_r10', 'txt_r_mean', 'img_r1', 'img_r5', 'img_r10', 'img_r_mean', 'r_mean', 'agg_metrics', 'model_size', 'visual_encoder_block_modules', 'visual_encoder_block_indices', 'visual_encoder_block_weight_bits', 'qformer_layer_indices', 'qformer_self_attention_modules', 'qformer_self_attention_weight_bits', 'qformer_cross_attention_modules', 'qformer_cross_attention_weight_bits', 'qformer_text_ff_modules', 'qformer_text_ff_weight_bits', 'qformer_img_ff_modules', 'qformer_img_ff_weight_bits', 'job_batch', 'vit_attn', 'vit_ff', 'vit_front_blocks', 'vit_middle_blocks', 'vit_end_blocks', 'vit_weight_bits', 'qformer_front_blocks', 'qformer_middle_blocks', 'qformer_end_blocks', 'qformer_self_attn', 'qformer_cross_attn', 'qformer_text_ff', 'qformer_img_ff', 'qformer_weight_bits', 'Quantized Portion', 'weight_bits'])

In [479]:
row_dict['qformer_layer_indices'] == row_dict['qformer_layer_indices']

True

In [493]:
row_dict['qformer_self_attention_modules']

"['query', 'key', 'value', 'dense']"

In [492]:
row_dict['qformer_cross_attention_modules']

"['query', 'key', 'value', 'dense']"

In [489]:
row_dict['qformer_img_ff_modules']

nan

In [488]:
row_dict['qformer_text_ff_modules']

"['intermediate', 'output']"

In [484]:
row_dict['visual_encoder_block_indices']

'[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]'

In [490]:
row_dict['qformer_weight_bits']

4.0

In [491]:
row_dict['visual_encoder_block_weight_bits']

4.0

In [8]:

leaves = get_leaf_modules(model)
total_params = sum(p.numel() for p in model.parameters())

quantized_mods = [
    "vision_model.encoder.layers",
    "qformer.encoder.layer",
]


df_uniform_flickr['bpw'] = [compute_bpw_uniform(leaves, quantized_mods, total_params, row_dict)
                            for row_dict in df_uniform_flickr.to_dict(orient='records')]


df_uniform_flickr['quant_method'] = 'uniform'

In [11]:
df_uniform_flickr.bpw.agg(['min', 
                           'max'])

min     2.299832
max    15.876399
Name: bpw, dtype: float64

In [12]:
df_uniform_flickr.columns

Index(['txt_r1', 'txt_r5', 'txt_r10', 'txt_r_mean', 'img_r1', 'img_r5',
       'img_r10', 'img_r_mean', 'r_mean', 'agg_metrics', 'model_size',
       'visual_encoder_block_modules', 'visual_encoder_block_indices',
       'visual_encoder_block_weight_bits', 'qformer_layer_indices',
       'qformer_self_attention_modules', 'qformer_self_attention_weight_bits',
       'qformer_cross_attention_modules',
       'qformer_cross_attention_weight_bits', 'qformer_text_ff_modules',
       'qformer_text_ff_weight_bits', 'qformer_img_ff_modules',
       'qformer_img_ff_weight_bits', 'job_batch', 'vit_attn', 'vit_ff',
       'vit_front_blocks', 'vit_middle_blocks', 'vit_end_blocks',
       'vit_weight_bits', 'qformer_front_blocks', 'qformer_middle_blocks',
       'qformer_end_blocks', 'qformer_self_attn', 'qformer_cross_attn',
       'qformer_text_ff', 'qformer_img_ff', 'qformer_weight_bits',
       'Quantized Portion', 'weight_bits', 'bpw', 'quant_method'],
      dtype='object')

In [14]:
df_export = df_uniform_flickr[['txt_r1', 'txt_r5', 'txt_r10', 'txt_r_mean', 'img_r1', 'img_r5',
       'img_r10', 'img_r_mean', 'r_mean', 'vit_attn', 'vit_ff',
       'vit_front_blocks', 'vit_middle_blocks', 'vit_end_blocks',
       'vit_weight_bits', 'qformer_front_blocks', 'qformer_middle_blocks',
       'qformer_end_blocks', 'qformer_self_attn', 'qformer_cross_attn',
       'qformer_text_ff', 'qformer_img_ff', 'qformer_weight_bits',
       'Quantized Portion', 'weight_bits', 'bpw', 'quant_method']]

df_export

Unnamed: 0,txt_r1,txt_r5,txt_r10,txt_r_mean,img_r1,img_r5,img_r10,img_r_mean,r_mean,vit_attn,...,qformer_end_blocks,qformer_self_attn,qformer_cross_attn,qformer_text_ff,qformer_img_ff,qformer_weight_bits,Quantized Portion,weight_bits,bpw,quant_method
0,0.0,0.0,0.4,0.133333,0.10,0.34,0.72,0.386667,0.260000,True,...,False,False,False,True,False,2.0,ViT + Q-Former,2.0,14.529316,uniform
1,0.0,0.1,0.3,0.133333,0.06,0.22,0.56,0.280000,0.206667,False,...,False,False,False,True,False,2.0,Q-Former,2.0,15.761088,uniform
2,0.0,0.2,0.2,0.133333,0.14,0.30,0.72,0.386667,0.260000,False,...,True,False,False,True,False,2.0,Q-Former,2.0,15.535536,uniform
3,0.0,0.3,0.4,0.233333,0.14,0.42,0.62,0.393333,0.313333,True,...,False,True,True,True,False,2.0,ViT + Q-Former,2.0,14.336585,uniform
4,0.0,0.3,0.7,0.333333,0.16,0.60,1.08,0.613333,0.473333,False,...,False,True,True,True,False,2.0,ViT + Q-Former,2.0,12.881694,uniform
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
947,98.0,100.0,100.0,99.333333,88.12,97.88,98.82,94.940000,97.136667,True,...,False,False,False,False,False,2.0,ViT + Q-Former,2.0,14.754867,uniform
948,98.0,100.0,100.0,99.333333,88.12,97.88,98.82,94.940000,97.136667,True,...,False,False,False,False,False,,ViT,2.0,14.754867,uniform
949,98.0,100.0,100.0,99.333333,89.60,98.10,98.96,95.553333,97.443333,False,...,True,True,True,False,True,4.0,Q-Former,4.0,14.910858,uniform
950,98.0,100.0,100.0,99.333333,89.66,98.10,98.92,95.560000,97.446667,False,...,True,True,True,False,True,4.0,Q-Former,4.0,15.269452,uniform


In [16]:
df_export.to_csv(os.path.join('/fs/cfar-projects/low-bit-vision/final_results/all_results', 'blip2_uniform_flickr.csv'), index=None)