In [1]:
import torch
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoProcessor, Blip2ForImageTextRetrieval
from datasets import COCODataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
# from utils import print_model_structure

from collections import defaultdict
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
model_name = "Salesforce/blip2-opt-2.7b"

model = Blip2ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

processor = Blip2Processor.from_pretrained(model_name)


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.18it/s]


In [4]:
# TODO: determine appropriate size for this calibration set
# AutoAWQ defaults to a size of 512

coco_dataset = COCODataset(ann_file='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/annotations/captions_val2017.json',
                           img_dir='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/images/val2017')

# calibration_set = [coco_dataset[0], coco_dataset[1]]

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [28]:
# base class for AWQ quantizer
class BaseAWQQuantizer():
    
    def __init__(self, model, device, inputs_processor, dataset, **kwargs):
        self.model = model
        self.device = device
        self.inputs_processor = inputs_processor
        self.dataset = dataset

        self.group_size = 128
        
        self.run_model = None


    def quantize(self):
        layers = self._get_model_layers()
        calibration_set = self._get_calibration_set()
        inputs = self._gather_inputs(layers, calibration_set)

        # for i,layer in enumerate(layers):
        #     inp = inputs[i]

        #     # # TODO: some inputs have different shape? cannot just concat them all channelwise
        #     # inp = {k: torch.cat(v, dim=0) for k, v in inp.items()}

        #     self._compute_scales(layer, inp)
        
        return layers, inputs
    
    def _compute_scales(self, layer, inp):

        for mod_name, xs in inp.items():
            x_flat = torch.cat([x.cpu().abs().view(-1, x.shape[-1]) for x in xs], dim = 0)

            # average of absolute value of all channels in linear module
            x_mean = x_flat.mean(0)
            print(x_mean)

            

    def _gather_inputs(self, layers, calibration_set):

        def input_hook(module, input, output, layer_index, module_name, inputs):
            x = input[0]
            x = x.detach().cpu()

            out = output[0]
            out = out.detach().cpu()

            inputs[layer_index][module_name].append(x)
        

        # list of dicts holding inputs for each layer
        inputs = [defaultdict(list)] * len(layers)
        # list of hooks so we can remove them after
        hooks = []
        
        for i,layer in enumerate(layers):
            named_linears = self._get_named_linears(layer)
            for name, mod in named_linears.items():
                hooks.append(
                    mod.register_forward_hook(partial(input_hook,
                                                      layer_index = i, 
                                                      module_name=name, 
                                                      inputs = inputs))
                )

        
        # TODO: setup proper dataloader for this
        for batch in calibration_set:
            X = self._prepare_input(batch)
            self.run_model(**X)

        # remove hooks from model
        for hook in hooks:
            hook.remove()

        return inputs
        

    # returns all nn.linear within module (a layer)
    def _get_named_linears(self, module):
        return {name: mod for name, mod in module.named_modules() if isinstance(mod, nn.Linear)}

    # return layers of model to consider for quantization (modify with config file)
    def _get_model_layers(self):
        raise NotImplementedError('_get_model_layers')
    
    def _get_calibration_set(self):
        raise NotImplementedError('_get_calibration_set')

    def _prepare_input(self):
        raise NotImplementedError('_prepare_input')
    

class Blip2ForConditionalGenerationAWQQuantizer(BaseAWQQuantizer):

    def __init__(self, model, inputs_processor, dataset):
        assert isinstance(model, Blip2ForConditionalGeneration)

        super().__init__(model, device, inputs_processor, dataset)
        self.run_model = model.generate
        
    def _get_model_layers(self):
        # NOTE: returning all layers for now
        return [*[layer for layer in self.model.vision_model.encoder.layers],
                *[layer for layer in self.model.qformer.encoder.layer],
                *[layer for layer in self.model.language_model.model.decoder.layers]]

    def _get_calibration_set(self):
        # NOTE: small set for testing
        return [self.dataset[0], self.dataset[1]]

    def _prepare_input(self, batch):
        X = self.inputs_processor(images=batch[0], return_tensors="pt").to(device)
        return X


class Blip2ForImageTextRetrievalAWQQuantizer(BaseAWQQuantizer):

    def __init__(self, model, device, inputs_processor, dataset):
        assert isinstance(model, Blip2ForImageTextRetrieval)
        super().__init__(model, device, inputs_processor, dataset)
        self.run_model = model.forward
        
    def _get_model_layers(self):
        # NOTE: returning all layers for now
        return [*[layer for layer in self.model.vision_model.encoder.layers],
                *[layer for layer in self.model.qformer.encoder.layer]]

    def _get_calibration_set(self):
        return [self.dataset[0], self.dataset[1]]

    def _prepare_input(self, batch):
        X = self.processor(images=batch[0], text=batch[1][0], return_tensors="pt").to(device, torch.float16)
        return X


In [29]:
# model = Blip2ForImageTextRetrieval.from_pretrained("Salesforce/blip2-itm-vit-g", torch_dtype=torch.float16)
# processor = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")
# model.to(device)

# device

In [31]:
b = Blip2ForConditionalGenerationAWQQuantizer(model, processor, coco_dataset)
layers, inputs = b.quantize()

In [36]:
layer = layers[0]
layer


Blip2EncoderLayer(
  (self_attn): Blip2Attention(
    (dropout): Dropout(p=0.0, inplace=False)
    (qkv): Linear(in_features=1408, out_features=4224, bias=True)
    (projection): Linear(in_features=1408, out_features=1408, bias=True)
  )
  (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
  (mlp): Blip2MLP(
    (activation_fn): GELUActivation()
    (fc1): Linear(in_features=1408, out_features=6144, bias=True)
    (fc2): Linear(in_features=6144, out_features=1408, bias=True)
  )
  (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
)

In [37]:
b._get_named_linears(layer)

{'self_attn.qkv': Linear(in_features=1408, out_features=4224, bias=True),
 'self_attn.projection': Linear(in_features=1408, out_features=1408, bias=True),
 'mlp.fc1': Linear(in_features=1408, out_features=6144, bias=True),
 'mlp.fc2': Linear(in_features=6144, out_features=1408, bias=True)}

In [None]:
weight = torch.cat([_m.weight for _m in layers], dim=0)

In [33]:
layers[0]

Blip2EncoderLayer(
  (self_attn): Blip2Attention(
    (dropout): Dropout(p=0.0, inplace=False)
    (qkv): Linear(in_features=1408, out_features=4224, bias=True)
    (projection): Linear(in_features=1408, out_features=1408, bias=True)
  )
  (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
  (mlp): Blip2MLP(
    (activation_fn): GELUActivation()
    (fc1): Linear(in_features=1408, out_features=6144, bias=True)
    (fc2): Linear(in_features=6144, out_features=1408, bias=True)
  )
  (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
)

In [23]:
for x in inputs[0]['self_attn.q_proj']:
    print(x.shape)

torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 33, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 2560])
torch.Size([1, 1, 

In [26]:
xs = [x.cpu().abs().view(-1, x.shape[-1]) for x in inputs[0]['self_attn.q_proj']]
torch.cat(xs, dim = 0).shape

torch.Size([2816, 2560])

In [27]:
torch.cat(xs, dim = 0).mean(0)

tensor([0.6801, 0.4476, 0.5009,  ..., 0.7988, 0.7426, 0.6116])

In [None]:
# b = Blip2ForImageTextRetrievalAWQQuantizer(model, processor, coco_dataset)
# inputs = b.quantize()


In [None]:
# TODO: exlude certain linear layers, reading from quant config

In [None]:
# TODO: solve for optimal (per input channel) scaling factor
# TODO: grid search for \alpha which balances protection of salient / non-salient weights