In [1]:
import torch
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoProcessor, Blip2ForImageTextRetrieval
from datasets import COCODataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
# from utils import print_model_structure

from collections import defaultdict
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
model_name = "Salesforce/blip2-opt-2.7b"

model = Blip2ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

processor = Blip2Processor.from_pretrained(model_name)


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.82it/s]


In [4]:
# TODO: determine appropriate size for this calibration set
# AutoAWQ defaults to a size of 512

coco_dataset = COCODataset(ann_file='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/annotations/captions_val2017.json',
                           img_dir='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/images/val2017')

# calibration_set = [coco_dataset[0], coco_dataset[1]]

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!


In [14]:
# base class for AWQ quantizer
class BaseAWQQuantizer():
    
    def __init__(self, model, device, inputs_processor, dataset, **kwargs):
        self.model = model
        self.device = device
        self.inputs_processor = inputs_processor
        self.dataset = dataset

        self.group_size = 128

        # TODO: change to something appropriate
        self.n_samples = 2
        
        self.run_model = None

    @torch.no_grad
    def quantize(self):
        layer_groups = self._get_model_layer_groups()
        calibration_set = self._get_calibration_set()
        first_inputs, layer_args, layer_kwargs = self._gather_first_inputs(layer_groups, calibration_set)
        

        for layer_group, modules in layer_groups.items():

            inps = first_inputs[layer_group]
            for i in tqdm(range(len(modules)), desc= f"Quantizing {layer_group}"):
                
                inps = modules[i](inps, *layer_args[layer_group], **layer_kwargs[layer_group])
                inps = inps[0]
                

        return layer_groups, first_inputs, layer_args, layer_kwargs
    
    def _compute_scales(self, layer, inp):

        for mod_name, xs in inp.items():
            x_flat = torch.cat([x.cpu().abs().view(-1, x.shape[-1]) for x in xs], dim = 0)

            # average of absolute value of all channels in linear module
            x_mean = x_flat.mean(0)
            print(x_mean)

            

    def _gather_first_inputs(self, layer_groups, calibration_set):

        first_inputs = {}
        layer_args = {}
        layer_kwargs = {}

        # get input and kwargs to layer 0
        # use this Catcher hack cause forward hooks cannot capture kwargs
        class Catcher(nn.Module):
            def __init__(self, module, layer_group):
                super().__init__()
                self.module = module
                self.layer_group = layer_group

            def forward(self, *args, **kwargs):
                # assume first input to forward is hidden states
                if len(args) > 0:
                    hidden_states = args[0]
                    # del args
                else:
                    first_key = list(kwargs.keys())[0]
                    hidden_states = kwargs.pop(first_key)

                first_inputs[self.layer_group] = hidden_states
                layer_args[self.layer_group] = args[1:]
                layer_kwargs[self.layer_group] = kwargs

                # print(f'args: {type(args)}')
                # print(args)
                # print(f'kwargs: {type(kwargs)}')
                # print(kwargs)

                return self.module.forward(*args, **kwargs)

        # raise ValueError  # early exit to break later inference
       
        for layer_group, modules in layer_groups.items():
            # replace first module in group of layers with a Catcher
            modules[0] = Catcher(modules[0], layer_group)


        self.run_model(calibration_set)
        
        for layer_group, modules in layer_groups.items():
            # restore proper module at beginning of layer group
            modules[0] = modules[0].module
        
        return first_inputs, layer_args, layer_kwargs
       


    def _gather_layer_input(self, layers, calibration_set):

        def input_hook(module, input, output, layer_index, module_name, inputs):
            x = input[0]
            x = x.detach().cpu()

            out = output[0]
            out = out.detach().cpu()

            inputs[layer_index][module_name].append(x)
        

        # list of dicts holding inputs for each layer
        inputs = [defaultdict(list)] * len(layers)
        # list of hooks so we can remove them after
        hooks = []
        
        for i,layer in enumerate(layers):
            named_linears = self._get_named_linears(layer)
            for name, mod in named_linears.items():
                hooks.append(
                    mod.register_forward_hook(partial(input_hook,
                                                      layer_index = i, 
                                                      module_name=name, 
                                                      inputs = inputs))
                )

        
        # TODO: setup proper dataloader for this
        for batch in calibration_set:
            X = self._prepare_input(batch)
            self.run_model(**X)

        # remove hooks from model
        for hook in hooks:
            hook.remove()

        return inputs
        

    # returns all nn.linear within module (a layer)
    def _get_named_linears(self, module):
        return {name: mod for name, mod in module.named_modules() if isinstance(mod, nn.Linear)}

    # return layers of model to consider for quantization (modify with config file)
    def _get_model_layer_groups(self):
        raise NotImplementedError('_get_model_layers')
    
    def _get_calibration_set(self):
        raise NotImplementedError('_get_calibration_set')

    def _prepare_input(self):
        raise NotImplementedError('_prepare_input')
    

class Blip2ForConditionalGenerationAWQQuantizer(BaseAWQQuantizer):

    def __init__(self, model, inputs_processor, dataset):
        assert isinstance(model, Blip2ForConditionalGeneration)

        super().__init__(model, device, inputs_processor, dataset)
        self.run_model = model.generate
        
    def _get_model_layer_groups(self):
        # NOTE: returning all layers for now
        return {'vit_layers': self.model.vision_model.encoder.layers,
                'qformer_layers': self.model.qformer.encoder.layer,
                'llm_layers': self.model.language_model.model.decoder.layers
              }

    def _get_calibration_set(self):
        # NOTE: small set for testing

        samples = []
        n = 0
        for data in self.dataset:
            
            sample = self._prepare_input(data[0])
            samples.append(sample)
            
            n += 1
            if n == self.n_samples:
                break
        
        samples = torch.cat(samples, dim = 0)
        return samples

    def _prepare_input(self, inp):
        X = self.inputs_processor(images=inp, return_tensors="pt").to(device)
        return X['pixel_values']


class Blip2ForImageTextRetrievalAWQQuantizer(BaseAWQQuantizer):

    def __init__(self, model, device, inputs_processor, dataset):
        assert isinstance(model, Blip2ForImageTextRetrieval)
        super().__init__(model, device, inputs_processor, dataset)
        self.run_model = model.forward
        
    def _get_model_layer_groups(self):
        # NOTE: returning all layers for now
        return [*[layer for layer in self.model.vision_model.encoder.layers],
                *[layer for layer in self.model.qformer.encoder.layer]]

    def _get_calibration_set(self):
        return [self.dataset[0], self.dataset[1]]

    def _prepare_input(self, batch):
        X = self.processor(images=batch[0], text=batch[1][0], return_tensors="pt").to(device, torch.float16)
        return X


In [15]:
# model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16)
# processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")
# model.to(device)

# device

In [16]:
model

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

In [17]:
b = Blip2ForConditionalGenerationAWQQuantizer(model, processor, coco_dataset)


In [18]:
layers, first_inputs, layer_args, layer_kwargs = b.quantize()

Quantizing vit_layers: 100%|██████████| 39/39 [00:00<00:00, 2491.67it/s]
Quantizing qformer_layers: 100%|██████████| 12/12 [00:00<00:00, 1307.93it/s]
Quantizing llm_layers: 100%|██████████| 32/32 [00:00<00:00, 328.86it/s]


In [21]:
first_inputs['vit_layers']

tensor([[[ 0.7319,  0.2039, -0.1177,  ...,  0.3483, -0.1315, -0.3481],
         [-0.2441,  0.8111, -0.0983,  ...,  0.0234, -0.1451, -0.7400],
         [-0.0418,  1.7882, -0.3203,  ..., -0.0219, -0.0488, -0.1417],
         ...,
         [-0.9605, -0.2865,  0.5448,  ..., -0.0762,  0.4271,  1.2226],
         [-0.6000, -0.1437,  0.1221,  ..., -0.1582,  0.1567,  1.4119],
         [-0.2725, -0.3828,  0.3872,  ...,  0.2089,  0.1482,  0.7765]],

        [[ 0.7319,  0.2039, -0.1177,  ...,  0.3483, -0.1315, -0.3481],
         [ 0.4066,  0.8089, -0.1578,  ...,  0.0691,  0.0124, -0.3917],
         [ 0.2490,  1.6557, -0.3386,  ..., -0.1113,  0.0690, -0.0610],
         ...,
         [-1.0813, -0.3526,  0.3304,  ..., -0.0324,  0.3346,  0.8685],
         [-0.5170, -0.1903, -0.0778,  ..., -0.1721,  0.1628,  1.2161],
         [-0.0057, -0.5496,  0.3215,  ...,  0.1134, -0.0510,  0.9223]]],
       device='cuda:0')

In [25]:
layer_args['vit_layers']

(None,)

In [23]:
layer_kwargs['vit_layers']

{'output_attentions': False}

In [28]:
out = layers['vit_layers'][0](first_inputs['vit_layers'], *layer_args['vit_layers'], **layer_kwargs['vit_layers'])

In [30]:
out[0]

tensor([[[ 0.3378, -0.1946, -0.0252,  ...,  0.6891, -0.1665, -0.2123],
         [-0.5796, -0.0850,  0.8516,  ...,  0.2505,  0.2570, -2.0609],
         [-0.1985,  1.0866,  0.2261,  ..., -0.4743, -0.8173, -0.4517],
         ...,
         [-1.9534, -0.5337,  0.5927,  ...,  0.0530,  0.0967,  1.0787],
         [-1.4451, -0.4682,  0.1352,  ...,  0.0619,  0.0195,  1.4128],
         [-0.4682, -1.1753,  0.9844,  ...,  0.6326,  1.3233,  0.8188]],

        [[ 0.0635, -0.3202,  0.1072,  ...,  0.6998, -0.0949, -0.1115],
         [-0.6654,  0.2025,  0.1473,  ...,  0.6025,  0.0049,  0.0392],
         [-0.9092,  1.0444, -0.1438,  ...,  0.4590,  0.1668,  0.3372],
         ...,
         [-1.1537, -1.1512,  1.0716,  ...,  0.6038, -0.2326,  0.4638],
         [-0.8844, -0.3487,  0.2901,  ...,  0.2318, -0.4852,  1.4794],
         [-0.9237, -0.8065,  0.7050,  ...,  0.6329,  0.0026,  1.1944]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [31]:
layers['vit_layers'][1](out[0], *layer_args['vit_layers'],  **layer_kwargs['vit_layers'])

(tensor([[[ 0.0026, -0.0644, -0.1888,  ...,  1.0750,  0.1186,  0.5769],
          [-1.0819,  0.1315,  0.9809,  ...,  0.7254,  0.2461, -1.5226],
          [-0.2447,  0.9629,  0.4360,  ..., -0.0313, -1.1122, -0.2950],
          ...,
          [-2.0357,  0.1081,  0.3105,  ...,  0.5053,  0.0971,  1.0850],
          [-1.2883,  0.0584, -0.2668,  ...,  0.5149,  0.1054,  1.6861],
          [-0.4440, -1.2503,  0.4363,  ...,  1.1043,  0.8403,  0.6618]],
 
         [[-0.1361,  0.0898, -0.0243,  ...,  0.9252,  0.0271,  0.6004],
          [-0.6104,  0.3142, -0.1762,  ...,  0.7918, -0.2915,  0.5835],
          [-0.9518,  0.8436, -0.6670,  ...,  0.6580,  0.1580,  0.7574],
          ...,
          [-0.8174, -0.8920,  1.0149,  ...,  0.7384,  0.2734,  0.8379],
          [-1.1361, -0.0739, -0.0643,  ...,  0.7002,  0.1879,  2.2353],
          [-0.8765, -0.5891,  0.4028,  ...,  1.1088,  0.3058,  1.6114]]],
        device='cuda:0', grad_fn=<AddBackward0>),)

In [42]:
first_inputs['qformer_layers']

(tensor([[[-0.7876, -0.3205, -0.0842,  ..., -0.6614, -0.0151, -0.5240],
          [-0.0952, -0.0247,  0.3759,  ..., -0.2078,  0.4916, -0.4537],
          [-0.5563,  0.5105, -0.6659,  ..., -0.2041,  0.5277,  0.8380],
          ...,
          [-0.1493,  1.2919,  1.5551,  ...,  0.2978, -1.4789,  0.2294],
          [ 0.2532,  0.0649, -0.7901,  ..., -0.4740, -1.6942, -0.6370],
          [ 0.3278, -0.4323,  0.2681,  ..., -0.4160,  0.3958, -0.1349]],
 
         [[-0.7876, -0.3205, -0.0842,  ..., -0.6614, -0.0151, -0.5240],
          [-0.0952, -0.0247,  0.3759,  ..., -0.2078,  0.4916, -0.4537],
          [-0.5563,  0.5105, -0.6659,  ..., -0.2041,  0.5277,  0.8380],
          ...,
          [-0.1493,  1.2919,  1.5551,  ...,  0.2978, -1.4789,  0.2294],
          [ 0.2532,  0.0649, -0.7901,  ..., -0.4740, -1.6942, -0.6370],
          [ 0.3278, -0.4323,  0.2681,  ..., -0.4160,  0.3958, -0.1349]]],
        device='cuda:0'),
 tensor([[[[-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.

In [44]:
layers['qformer_layers'][0](*first_inputs['qformer_layers'], **layer_kwargs['qformer_layers'])

(tensor([[[-0.5611, -0.4313,  0.7348,  ..., -0.5339,  0.2271, -0.6419],
          [-0.0081,  2.1908,  0.5378,  ...,  0.0602,  0.0924, -0.4073],
          [-0.3419,  0.2889, -0.3378,  ..., -0.3128,  0.8518, -0.0093],
          ...,
          [-0.4123,  1.7034, -0.6390,  ...,  0.3584, -0.3071, -0.3509],
          [-0.3916,  0.2184,  0.2309,  ..., -0.4097,  0.2242, -0.1063],
          [ 0.1907,  1.5663,  0.5063,  ..., -0.2080,  0.2138, -0.3192]],
 
         [[-0.6695, -0.3799,  0.7295,  ..., -0.4814,  0.0236, -0.6724],
          [-0.1965, -0.3873, -0.5339,  ..., -0.6227,  0.0911,  0.0212],
          [-0.8092,  0.3860, -0.2228,  ..., -0.5321, -0.2148, -0.6384],
          ...,
          [-0.1741,  0.9254, -0.0700,  ...,  0.0874, -1.2012, -0.1439],
          [-0.3797,  0.2424, -0.4833,  ..., -0.7007, -0.8832, -0.7238],
          [ 0.4638, -0.7695, -0.3554,  ..., -0.5841, -0.0452, -0.0482]]],
        device='cuda:0', grad_fn=<NativeLayerNormBackward0>),
 (tensor([[[[-1.7426,  0.2527,  1.7284,

In [18]:
layer_kwargs['qformer_layers']

{}

In [15]:
first_inputs['llm_layers'].shape

torch.Size([2, 1, 2560])

In [19]:
layer_kwargs['llm_layers']

{'attention_mask': tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
 
 
         [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],
        device='cuda:0'),
 'layer_head_mask': None,
 'past_key_value': (tensor([[[[ 3.4119e+00, -1.4275e+00,  1.6319e+00,  ...,  5.2653e-01,
              2.4766e+00, -1.3366e+00],
            [ 1.7859e+00, -2.7154e-01,  8.6161e-01,  ..., -8.4367e-02,
              1.5484e+00,  4.6518e-01],
            [ 1.8167e+00, -5.6734e-01,  1.4872e+00,  ..., -7.3273e-01,
              6.8322e-01, -1.6523e-01],
            ...,
            [ 2.3114e+00, -3.1942e-01, -6.0153e-01,  ..., -1.4472e+00,
              1.0140e+00, -6.6369e-01],
            [ 6.2211e-01,  4.0976e-01, -

In [160]:
modules = model.vision_model.encoder.layers
                # *[layer for layer in self.model.qformer.encoder.layer],
                # *[layer for layer in self.model.language_model.model.decoder.layers]]

modules[0] = modules[0].cuda()

In [161]:
calib_set = b._get_calibration_set()
calib_set.shape

torch.Size([2, 3, 224, 224])

In [164]:

inps = []
layer_kwargs = {}
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # assume first input to forward is hidden states
        if len(args) > 0:
            hidden_states = args[0]
            del args
        else:
            first_key = list(kwargs.keys())[0]
            hidden_states = kwargs.pop(first_key)

        inps.append(hidden_states)
        layer_kwargs.update(kwargs)
        raise ValueError  # early exit to break later inference


modules[0] = Catcher(modules[0])

try:
    model.generate(calib_set.to(next(model.parameters()).device))
except ValueError:
    pass


In [165]:
modules

ModuleList(
  (0): Catcher(
    (module): Catcher(
      (module): Catcher(
        (module): Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
  )
  (1-38): 38 x Blip2EncoderLayer(
    (self_attn): Blip2Attention(
      (dropout): Dropout(p=0.0, inplace=False)
      (qkv): Linear(in_features=1408, out_features=4224, bias=True)
      (projection): Linea

In [166]:
inps

[tensor([[[ 0.7319,  0.2039, -0.1177,  ...,  0.3483, -0.1315, -0.3481],
          [-0.2441,  0.8111, -0.0983,  ...,  0.0234, -0.1451, -0.7400],
          [-0.0418,  1.7882, -0.3203,  ..., -0.0219, -0.0488, -0.1417],
          ...,
          [-0.9605, -0.2865,  0.5448,  ..., -0.0762,  0.4271,  1.2226],
          [-0.6000, -0.1437,  0.1221,  ..., -0.1582,  0.1567,  1.4119],
          [-0.2725, -0.3828,  0.3872,  ...,  0.2089,  0.1482,  0.7765]],
 
         [[ 0.7319,  0.2039, -0.1177,  ...,  0.3483, -0.1315, -0.3481],
          [ 0.4066,  0.8089, -0.1578,  ...,  0.0691,  0.0124, -0.3917],
          [ 0.2490,  1.6557, -0.3386,  ..., -0.1113,  0.0690, -0.0610],
          ...,
          [-1.0813, -0.3526,  0.3304,  ..., -0.0324,  0.3346,  0.8685],
          [-0.5170, -0.1903, -0.0778,  ..., -0.1721,  0.1628,  1.2161],
          [-0.0057, -0.5496,  0.3215,  ...,  0.1134, -0.0510,  0.9223]]],
        device='cuda:0')]

In [None]:
# b = Blip2ForImageTextRetrievalAWQQuantizer(model, processor, coco_dataset)
# inputs = b.quantize()


In [None]:
# TODO: exlude certain linear layers, reading from quant config

In [None]:
# TODO: solve for optimal (per input channel) scaling factor
# TODO: grid search for \alpha which balances protection of salient / non-salient weights