In [1]:
import torch
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from datasets import COCODataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
# from utils import print_model_structure

from collections import defaultdict
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "Salesforce/blip2-opt-2.7b"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Blip2Processor.from_pretrained(model_name)

device



device(type='cuda')

In [3]:
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model = model.to(device)


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.48it/s]


In [4]:
# TODO: determine appropriate size for this calibration set
# AutoAWQ defaults to a size of 512

coco_dataset = COCODataset(ann_file='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/annotations/captions_val2017.json',
                           img_dir='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/images/val2017')

# calibration_set = [coco_dataset[0], coco_dataset[1]]

loading annotations into memory...
Done (t=0.18s)
creating index...
index created!


In [5]:
# base class for AWQ quantizer
class BaseAWQQuantizer():
    
    def __init__(self, model):
        self.model = model
        self.inputs_processor = None
        self.dataset = None
        self.run_model = None

    
 
    def quantize(self):
        layers = self.get_model_layers()
        calibration_set = self.get_calibration_set()
        inputs = self.gather_inputs(layers, calibration_set)
        
        return inputs

    def gather_inputs(self, layers, calibration_set):

        def input_hook(module, input, output, layer_index, module_name, inputs):
            x = input[0]
            x = x.detach().cpu()

            out = output[0]
            out = out.detach().cpu()

            inputs[layer_index][module_name].append(x)
            # inputs_dict[module_name].append(x)
            # activations_dict[module_name].append(out)

        # list of dicts holding inputs for each layer
        inputs = [defaultdict(list)] * len(layers)

        # list of hooks so we can remove them after
        hooks = []
        
        for i,layer in enumerate(layers):

            named_linears = self.get_named_linears(layer)
            for name, mod in named_linears.items():
                hooks.append(
                    mod.register_forward_hook(partial(input_hook,
                                                      layer_index = i, 
                                                      module_name=name, 
                                                      inputs = inputs))
                )

        
        # TODO: setup proper dataloader for this
        for batch in calibration_set:
            X = self.inputs_processor(images=batch[0], return_tensors="pt").to(device)
            self.run_model(**X)

        # remove hooks from model
        for hook in hooks:
            hook.remove()

        return inputs
        
    # returns all nn.linear within module (a layer)
    def get_named_linears(self, module):
        return {name: mod for name, mod in module.named_modules() if isinstance(mod, nn.Linear)}

    # return layers of model to consider for quantization (modify with config file)
    def get_model_layers(self):
        raise NotImplementedError('get_model_layers')
    
    def get_calibration_set(self):
        raise NotImplementedError('sample_calibration_set')
    

class Blip2AWQQuantizer(BaseAWQQuantizer):

    def __init__(self, model, inputs_processor, dataset):
        assert isinstance(model, Blip2ForConditionalGeneration)
        super().__init__(model)
        
        self.inputs_processor = inputs_processor
        self.dataset = dataset
        self.run_model = model.generate
        
    def get_model_layers(self):
        # NOTE: returning all layers for now
        return [*[layer for layer in self.model.vision_model.encoder.layers],
                *[layer for layer in self.model.qformer.encoder.layer],
                *[layer for layer in self.model.language_model.model.decoder.layers]]

    def get_calibration_set(self):
        return [self.dataset[0], self.dataset[1]]


In [6]:
b = Blip2AWQQuantizer(model, processor, coco_dataset)
inputs = b.quantize()

In [8]:
inputs[0]

defaultdict(list,
            {'self_attn.qkv': [tensor([[[ 0.0006, -0.0002, -0.0024,  ...,  0.0005, -0.0017,  0.0013],
                       [ 0.0005,  0.0015, -0.0020,  ...,  0.0007, -0.0017,  0.0003],
                       [ 0.0005,  0.0040, -0.0032,  ...,  0.0008, -0.0017,  0.0020],
                       ...,
                       [ 0.0005, -0.0015,  0.0014,  ...,  0.0009, -0.0018,  0.0072],
                       [ 0.0005, -0.0010, -0.0010,  ...,  0.0010, -0.0017,  0.0079],
                       [ 0.0005, -0.0017,  0.0004,  ...,  0.0006, -0.0017,  0.0053]]]),
              tensor([[[ 7.7160e-05,  1.1226e-01, -5.9010e-02,  ...,  7.0399e-03,
                        -9.1602e-02, -9.3570e-02],
                       [ 6.3412e-05,  1.5977e-01,  2.3643e-01,  ..., -4.8116e-03,
                         1.4551e-02, -3.2326e-01],
                       [ 6.8773e-05,  4.3250e-01,  2.7647e-02,  ..., -2.3672e-02,
                        -2.1486e-01, -1.1653e-01],
                       ..

In [48]:
for k,v in inputs_dict.items():
    try:
        torch.cat(v, dim=0)
    except:
        print(k)

language_model.model.decoder.layers.0.self_attn.q_proj
language_model.model.decoder.layers.0.self_attn.k_proj
language_model.model.decoder.layers.0.self_attn.v_proj
language_model.model.decoder.layers.0.self_attn.out_proj
language_model.model.decoder.layers.1.self_attn.q_proj
language_model.model.decoder.layers.1.self_attn.k_proj
language_model.model.decoder.layers.1.self_attn.v_proj
language_model.model.decoder.layers.1.self_attn.out_proj
language_model.model.decoder.layers.2.self_attn.q_proj
language_model.model.decoder.layers.2.self_attn.k_proj
language_model.model.decoder.layers.2.self_attn.v_proj
language_model.model.decoder.layers.2.self_attn.out_proj
language_model.model.decoder.layers.3.self_attn.q_proj
language_model.model.decoder.layers.3.self_attn.k_proj
language_model.model.decoder.layers.3.self_attn.v_proj
language_model.model.decoder.layers.3.self_attn.out_proj
language_model.model.decoder.layers.4.self_attn.q_proj
language_model.model.decoder.layers.4.self_attn.k_proj
la

In [52]:
for x in inputs_dict['language_model.model.decoder.layers.0.self_attn.q_proj']:
    print(x[0])

tensor([[-1.5179,  0.3456, -0.4672,  ..., -0.9774,  1.5519, -0.0911],
        [ 1.5346, -1.3748, -0.3652,  ..., -0.3387, -0.8434, -0.2678],
        [ 0.5751, -0.0796, -0.3691,  ...,  0.5701,  0.0799, -0.4667],
        ...,
        [ 0.4681,  0.9722,  0.7046,  ...,  3.3068, -0.1594,  0.6633],
        [ 2.2080, -0.9600, -0.1465,  ..., -0.0391, -1.2819, -0.7929],
        [-1.0473,  0.1923, -0.0220,  ..., -0.7799,  0.4223,  0.2008]])
tensor([[-1.6023, -0.7651,  0.6915,  ...,  0.3768, -1.0898,  0.6077]])
tensor([[-0.3831, -0.1472, -0.6872,  ...,  0.2345,  0.9957, -1.0965]])
tensor([[-0.2948, -1.1158, -1.7285,  ..., -0.9329,  0.3891, -0.2080]])
tensor([[ 1.4899, -0.4618, -1.2205,  ..., -0.8670, -0.2333, -0.1425]])
tensor([[ 1.0793,  0.9101, -0.1845,  ...,  2.1612,  0.3138,  0.5595]])
tensor([[ 0.9503, -0.0749, -0.9298,  ...,  0.2445, -1.3204,  0.3572]])
tensor([[ 1.0121, -0.9353, -0.3473,  ...,  0.5288, -0.3467, -0.0470]])
tensor([[-0.3844,  0.4571, -1.1839,  ..., -1.0338,  2.2866, -0.4311]]

In [44]:
torch.cat(inputs_dict['vision_model.encoder.layers.0.self_attn.qkv'], dim = 0).shape

torch.Size([2, 257, 1408])

In [45]:
inputs_dict['vision_model.encoder.layers.0.self_attn.qkv']

[tensor([[[ 0.0006, -0.0002, -0.0024,  ...,  0.0005, -0.0017,  0.0013],
          [ 0.0005,  0.0015, -0.0020,  ...,  0.0007, -0.0017,  0.0003],
          [ 0.0005,  0.0040, -0.0032,  ...,  0.0008, -0.0017,  0.0020],
          ...,
          [ 0.0005, -0.0015,  0.0014,  ...,  0.0009, -0.0018,  0.0072],
          [ 0.0005, -0.0010, -0.0010,  ...,  0.0010, -0.0017,  0.0079],
          [ 0.0005, -0.0017,  0.0004,  ...,  0.0006, -0.0017,  0.0053]]]),
 tensor([[[ 5.9693e-04, -2.4308e-04, -2.4253e-03,  ...,  5.1895e-04,
           -1.6880e-03,  1.2974e-03],
          [ 5.8581e-04,  1.8597e-03, -2.6897e-03,  ...,  7.2060e-04,
           -1.7048e-03,  8.9084e-04],
          [ 5.7178e-04,  4.4605e-03, -3.7989e-03,  ...,  9.3806e-04,
           -1.7107e-03,  2.1687e-03],
          ...,
          [ 4.7314e-04, -1.6773e-03, -3.2294e-04,  ...,  8.6584e-04,
           -1.7345e-03,  5.2242e-03],
          [ 5.0922e-04, -1.2801e-03, -2.3081e-03,  ...,  1.0080e-03,
           -1.7200e-03,  6.9340e-03],


In [31]:
activations_dict['vision_model.encoder.layers.0.self_attn.qkv'][0]

tensor([[-0.4629, -0.0478,  0.0863,  ..., -0.4253, -0.0268, -0.1028],
        [-0.7950,  0.5636,  0.0464,  ...,  0.1780,  0.2326,  0.4707],
        [-2.1751, -2.1898,  1.5828,  ...,  0.2761, -0.0725,  0.2002],
        ...,
        [-1.0926, -0.3367,  0.6469,  ..., -0.0923,  0.0876,  0.0290],
        [-1.2412, -0.3969,  0.7164,  ...,  0.0388, -0.0405,  0.0099],
        [-0.8863, -0.8924,  0.1963,  ...,  0.2184,  0.1151, -0.1381]])

In [24]:
activations_dict['vision_model.encoder.layers.0.self_attn.qkv']

[tensor([[-0.4629, -0.0478,  0.0863,  ..., -0.4253, -0.0268, -0.1028],
         [-0.7950,  0.5636,  0.0464,  ...,  0.1780,  0.2326,  0.4707],
         [-2.1751, -2.1898,  1.5828,  ...,  0.2761, -0.0725,  0.2002],
         ...,
         [-1.0926, -0.3367,  0.6469,  ..., -0.0923,  0.0876,  0.0290],
         [-1.2412, -0.3969,  0.7164,  ...,  0.0388, -0.0405,  0.0099],
         [-0.8863, -0.8924,  0.1963,  ...,  0.2184,  0.1151, -0.1381]]),
 tensor([[-0.4629, -0.0478,  0.0863,  ..., -0.4253, -0.0268, -0.1028],
         [-0.2788,  0.3219, -0.0484,  ..., -0.0533, -0.0405, -0.0077],
         [-0.2399,  0.3956, -0.1201,  ..., -0.0804, -0.0429, -0.0314],
         ...,
         [-0.7399,  1.0482,  0.2760,  ..., -0.5047, -0.0524, -0.3422],
         [-0.7270,  0.3344,  0.1099,  ...,  0.0149,  0.2532,  0.0598],
         [-0.3091,  0.3351,  0.1163,  ..., -0.0599,  0.0194, -0.0341]])]

In [None]:
# TODO: exlude certain linear layers, reading from quant config

In [None]:
# TODO: solve for optimal (per input channel) scaling factor
# TODO: grid search for \alpha which balances protection of salient / non-salient weights