In [26]:
import torch
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from datasets import COCODataset
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
# from utils import print_model_structure

from collections import defaultdict
from functools import partial


model_name = "Salesforce/blip2-opt-2.7b"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Blip2Processor.from_pretrained(model_name)


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.44it/s]


In [None]:

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model = model.to(device)


In [17]:
# TODO: determine appropriate size for this

coco_dataset = COCODataset(ann_file='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/annotations/captions_val2017.json',
                           img_dir='/nfshomes/vla/project_dirs/low-bit-vision/datasets/cocow/images/val2017')

calibration_set = [coco_dataset[0]]

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!


In [18]:
inputs = processor(images=calibration_set[0][0], return_tensors="pt").to(device)
inputs

{'pixel_values': tensor([[[[-0.1572, -0.2740, -0.5368,  ..., -1.3251, -1.6171, -1.6171],
          [-0.1572, -0.4492, -0.6098,  ..., -1.5149, -1.6025, -1.3397],
          [-0.2156, -0.5076, -0.6682,  ..., -1.5441, -1.5879, -1.4857],
          ...,
          [ 0.2077,  0.2661,  0.2953,  ..., -1.0331, -0.9893, -1.0331],
          [ 0.2369,  0.2223,  0.2661,  ..., -1.0769, -1.0331, -1.0915],
          [ 0.2807,  0.2807,  0.2661,  ..., -1.0623, -1.0477, -1.0915]],

         [[ 0.1089, -0.0112, -0.3414,  ..., -1.4369, -1.5720, -1.5570],
          [ 0.1089, -0.1913, -0.4614,  ..., -1.5120, -1.5720, -1.3019],
          [ 0.0789, -0.3864, -0.5215,  ..., -1.5270, -1.5570, -1.4069],
          ...,
          [-0.1613, -0.1613, -0.1163,  ..., -1.1968, -1.1818, -1.1818],
          [-0.1613, -0.1463, -0.1313,  ..., -1.2118, -1.2268, -1.2568],
          [-0.1313, -0.1163, -0.1313,  ..., -1.2118, -1.2268, -1.2869]],

         [[-0.5559, -0.5417, -0.4706,  ..., -1.1816, -1.3238, -1.2811],
          [-0

In [19]:
inputs['pixel_values'].shape

torch.Size([1, 3, 224, 224])

In [20]:
calibration_set

[(<PIL.Image.Image image mode=RGB size=640x427>,
  ['A man is in a kitchen making pizzas.',
   'Man in apron standing on front of oven with pans and bakeware',
   'A baker is working in the kitchen rolling dough.',
   'A person standing by a stove in a kitchen.',
   'A table with pies being made and a person standing near a wall with pots and pans hanging on the wall.'])]

In [21]:
model

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

In [22]:
def get_named_linears(module):
    return {name: mod for name, mod in module.named_modules() if isinstance(mod, nn.Linear)}

In [23]:
named_linears = get_named_linears(model)

In [38]:
def input_hook(module, input, output, module_name, inputs_dict, activations_dict):
    x = input[0]
    x = x.detach().cpu()

    out = output[0]
    out = out.detach().cpu()

    inputs_dict[module_name].append(x)
    activations_dict[module_name].append(out)

# dict containing list of different inputs for each named linear
inputs_dict = defaultdict(list)
activations_dict = defaultdict(list)

# list of hooks so we can remove them after
hooks = []

In [39]:
for module_name in named_linears:
    hooks.append(
        named_linears[module_name].register_forward_hook(partial(input_hook, 
                                                                 module_name=module_name, 
                                                                 inputs_dict=inputs_dict, 
                                                                 activations_dict=activations_dict))
    )

In [37]:
# TODO: setup proper dataloader for this
for batch in calibration_set:
    X = processor(images=batch[0], return_tensors="pt").to(device)
    print(X)
    model.generate(**X)

{'pixel_values': tensor([[[[-0.1572, -0.2740, -0.5368,  ..., -1.3251, -1.6171, -1.6171],
          [-0.1572, -0.4492, -0.6098,  ..., -1.5149, -1.6025, -1.3397],
          [-0.2156, -0.5076, -0.6682,  ..., -1.5441, -1.5879, -1.4857],
          ...,
          [ 0.2077,  0.2661,  0.2953,  ..., -1.0331, -0.9893, -1.0331],
          [ 0.2369,  0.2223,  0.2661,  ..., -1.0769, -1.0331, -1.0915],
          [ 0.2807,  0.2807,  0.2661,  ..., -1.0623, -1.0477, -1.0915]],

         [[ 0.1089, -0.0112, -0.3414,  ..., -1.4369, -1.5720, -1.5570],
          [ 0.1089, -0.1913, -0.4614,  ..., -1.5120, -1.5720, -1.3019],
          [ 0.0789, -0.3864, -0.5215,  ..., -1.5270, -1.5570, -1.4069],
          ...,
          [-0.1613, -0.1613, -0.1163,  ..., -1.1968, -1.1818, -1.1818],
          [-0.1613, -0.1463, -0.1313,  ..., -1.2118, -1.2268, -1.2568],
          [-0.1313, -0.1163, -0.1313,  ..., -1.2118, -1.2268, -1.2869]],

         [[-0.5559, -0.5417, -0.4706,  ..., -1.1816, -1.3238, -1.2811],
          [-0

KeyboardInterrupt: 

In [None]:
# TODO: exlude certain linear layers, reading from quant config

In [None]:
# TODO: solve for optimal (per input channel) scaling factor
# TODO: grid search for \alpha which balances protection of salient / non-salient weights