# In [None]:
import time
import math
import random
import torch
import torch.nn as nn
import transformers
import numpy as np
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from PIL import Image
import os

# ====================================================
# Constants and Configuration
# ====================================================

# Model and dataset configuration
MODEL_NAME = "Salesforce/blip2-opt-2.7b"  # Checkpoint name passed to from_pretrained for both model and processor
DATASET = "coco_captions"  # Dataset label; NOTE(review): not referenced by any code visible in this file
DATASET_PATH = "../data/coco/val2017"       # Directory containing the COCO val2017 images
ANNOTATIONS_FILE = "../data/coco/annotations/captions_val2017.json"  # COCO caption annotations (JSON)

# Quantization parameters
SEED = 0  # Seed for Python, NumPy and PyTorch RNGs
NUM_SAMPLES = 16  # Number of calibration images (the next 100 shuffled images become the test split)
PERCENT_DAMPENING = 0.01  # Fraction of the mean Hessian diagonal added to the diagonal before inversion
BITS = 4  # Bit width; the quantization grid has 2 ** BITS levels
GROUP_SIZE = -1  # Columns per quantization group; -1 disables grouping (one parameter set per output row)
USE_SYMMETRIC = True  # Use a symmetric range (zero point fixed at the middle of the grid)
USE_ACT_ORDER = False  # Process weight columns in order of decreasing Hessian diagonal
USE_STATIC_GROUPS = False  # Pre-compute per-group quantizers before the column loop (only used with grouping)

# Device configuration: first CUDA device when available, otherwise CPU
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Debugging flag
DEBUG = False  # When True, GPTQ keeps the last layer input/output and prints reconstruction errors

# Disable TensorFloat-32 for CUDA matmuls and cuDNN ops so they run at full
# float32 precision. This controls numerical precision, not determinism.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# ====================================================
# Quantization Functions and Classes
# ====================================================

def quantize(x, scale, zero, maxq):
    """
    Fake-quantize ``x``: snap it onto the grid described by ``scale``/``zero``
    and return the de-quantized values.

    A negative ``maxq`` selects the ternary mode, where ``scale`` holds the
    positive level and ``zero`` the negative level; everything between the
    two half-way thresholds maps to 0.
    """
    if maxq >= 0:
        # Uniform affine grid with integer levels 0..maxq.
        levels = torch.round(x / scale) + zero
        levels = torch.clamp(levels, 0, maxq)
        return scale * (levels - zero)

    # Ternary mode: pick the positive level, the negative level, or zero.
    positive_part = (x > scale / 2).float() * scale
    negative_part = (x < zero / 2).float() * zero
    return positive_part + negative_part

class Quantizer(nn.Module):
    """
    Holds the parameters (``maxq``, ``scale``, ``zero``) of a uniform or
    ternary quantization grid and fits them to a tensor.
    """
    def __init__(self, shape=1):
        super().__init__()
        # Registered as buffers so they travel with .to() / state_dict().
        initial_buffers = (
            ('maxq', torch.tensor(0)),       # highest integer level (-1 => ternary)
            ('scale', torch.zeros(shape)),   # step size of the grid
            ('zero', torch.zeros(shape)),    # zero point of the grid
        )
        for buffer_name, initial_value in initial_buffers:
            self.register_buffer(buffer_name, initial_value)

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4,
                  grid=100, maxshrink=0.8, trits=False):
        """
        Store the quantization settings.

        bits:       bit width; the grid gets ``2 ** bits`` levels.
        perchannel: fit one (scale, zero) pair per channel instead of one globally.
        sym:        use a range symmetric around zero.
        mse:        refine the range with a grid search minimising the error norm.
        norm:       exponent applied to the absolute error in that search.
        grid:       number of steps the search grid is divided into.
        maxshrink:  fraction of ``grid`` steps actually tried.
        trits:      switch to ternary mode (``maxq`` becomes -1).
        """
        top_level = 2 ** bits - 1
        if trits:
            top_level = -1  # sentinel understood by quantize() as ternary mode
        self.maxq = torch.tensor(top_level)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        """
        Fit ``scale`` and ``zero`` to ``x`` and reshape them so they broadcast
        against a tensor shaped like ``x``.

        With ``weight=True`` the first dimension of ``x`` is the channel
        dimension; otherwise the channel dimension is dim 1 for 4-D input and
        the last dimension for 3-D / 2-D input.
        """
        dev = x.device
        self.maxq = self.maxq.to(dev)

        orig_shape = x.shape
        ndim = len(orig_shape)

        # Bring x into (channels, values) layout; a single row when global.
        if not self.perchannel:
            x = x.flatten().unsqueeze(0)
        elif weight:
            x = x.flatten(1)
        elif ndim == 4:
            x = x.permute(1, 0, 2, 3).flatten(1)
        elif ndim == 3:
            x = x.reshape(-1, orig_shape[-1]).t()
        elif ndim == 2:
            x = x.t()

        # The range always contains 0.
        zeros = torch.zeros(x.shape[0], device=dev)
        lo = torch.minimum(x.min(dim=1)[0], zeros)
        hi = torch.maximum(x.max(dim=1)[0], zeros)

        if self.sym:
            hi = torch.maximum(torch.abs(lo), hi)
            # Mirror the upper bound for rows that contain negative values.
            lo = torch.where(lo < 0, -hi, lo)

        # All-zero rows would give scale 0; fall back to the range [-1, 1].
        degenerate = (lo == 0) & (hi == 0)
        lo[degenerate] = -1
        hi[degenerate] = 1

        if self.maxq < 0:
            # Ternary mode stores the two outer levels directly.
            self.scale = hi
            self.zero = lo
        else:
            self.scale = (hi - lo) / self.maxq
            if self.sym:
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-lo / self.scale)

        if self.mse:
            # Shrink the range step by step, keeping per row whichever
            # candidate yields the smallest error.
            lowest_error = torch.full([x.shape[0]], float('inf'), device=dev)
            for step in range(int(self.maxshrink * self.grid)):
                shrink = 1 - step / self.grid
                lo_s = shrink * lo
                hi_s = shrink * hi
                scale_s = (hi_s - lo_s) / self.maxq
                zero_s = self.zero if self.sym else torch.round(-lo_s / scale_s)
                dequantized = quantize(x, scale_s.unsqueeze(1), zero_s.unsqueeze(1), self.maxq)
                error = (dequantized - x).abs().pow(self.norm).sum(dim=1)
                improved = error < lowest_error
                if torch.any(improved):
                    lowest_error[improved] = error[improved]
                    self.scale[improved] = scale_s[improved]
                    self.zero[improved] = zero_s[improved]

        if not self.perchannel:
            # Replicate the single global pair once per channel.
            if weight:
                copies = orig_shape[0]
            elif ndim == 3:
                copies = orig_shape[2]
            else:
                copies = orig_shape[1]
            self.scale = self.scale.repeat(copies)
            self.zero = self.zero.repeat(copies)

        # Pick the broadcast shape: channels first for weights, otherwise the
        # channel axis of the activation layout.
        if weight:
            target_shape = [-1] + [1] * (ndim - 1)
        else:
            target_shape = {4: (1, -1, 1, 1), 3: (1, 1, -1), 2: (1, -1)}.get(ndim)

        if target_shape is not None:
            self.scale = self.scale.reshape(target_shape)
            self.zero = self.zero.reshape(target_shape)

    def quantize(self, x):
        """
        Fake-quantize ``x`` with the stored parameters; return ``x`` untouched
        while the quantizer is not ready.
        """
        if not self.ready():
            return x
        return quantize(x, self.scale, self.zero, self.maxq)

    def enabled(self):
        """
        Tell whether a positive ``maxq`` has been configured.
        """
        return self.maxq > 0

    def ready(self):
        """
        Tell whether parameters have been fitted (no entry of ``scale`` is 0).
        """
        return (self.scale != 0).all()

# ====================================================
# GPTQ Quantization Class
# ====================================================

class GPTQ:
    """
    Accumulate input statistics for one layer and quantize its weights.

    Usage: construct with the layer, call ``add_batch`` for every calibration
    forward pass (typically from a forward hook), configure ``self.quantizer``,
    then call ``fasterquant`` once. ``fasterquant`` consumes the accumulated
    Hessian and overwrites ``layer.weight.data`` with the quantized weights.

    Supported layer types: ``nn.Linear``, ``nn.Conv2d`` and
    ``transformers.Conv1D``.
    """
    def __init__(self, layer):
        self.layer = layer
        self.device = self.layer.weight.device
        # Only the 2-D (rows, columns) layout is needed here, so work on views
        # of the weight instead of cloning the full matrix.
        W = layer.weight.data
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows, self.columns = W.shape
        # Running Hessian approximation, one row/column per weight column.
        self.H = torch.zeros((self.columns, self.columns), device=self.device)
        self.nsamples = 0  # Number of samples (batch entries) accumulated so far
        self.quantizer = Quantizer()

    def add_batch(self, inp, out):
        """
        Fold one batch of layer inputs into the running Hessian approximation.

        inp: the layer input for this batch; 2-D input is treated as a batch
             of one.
        out: the layer output; only stored (together with ``inp``) when the
             module-level DEBUG flag is set.
        """
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        if isinstance(self.layer, (nn.Linear, transformers.Conv1D)):
            if len(inp.shape) == 3:
                inp = inp.reshape(-1, inp.shape[-1])
            # Lay the data out as (features, observations).
            inp = inp.t()
        elif isinstance(self.layer, nn.Conv2d):
            # Unfold into the patches the convolution actually multiplies with
            # the flattened kernel.
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute(1, 0, 2)
            inp = inp.flatten(1)
        # Running average: rescale the old sum to the new sample count, then
        # add this batch's contribution 2/n * X X^T.
        self.H *= self.nsamples / (self.nsamples + batch_size)
        self.nsamples += batch_size
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterquant(self, blocksize=128, percdamp=0.01, groupsize=-1,
                    actorder=False, static_groups=False):
        """
        Quantize the layer's weights column by column, propagating each
        column's quantization error to the columns not yet processed.

        blocksize:     number of columns handled per block.
        percdamp:      fraction of the mean Hessian diagonal added to the
                       diagonal before inversion.
        groupsize:     columns sharing one set of quantizer parameters;
                       -1 keeps a single set for the whole matrix.
        actorder:      process columns in order of decreasing Hessian
                       diagonal (the permutation is undone before write-back).
        static_groups: fit all per-group quantizers up front instead of
                       re-fitting inside the column loop.

        Side effects: deletes ``self.H``, overwrites ``self.layer.weight.data``
        (in the weight's original dtype) and prints timing and total error.
        """
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        start_time = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        del self.H
        # Columns whose input was always zero: give them a unit diagonal so H
        # stays invertible, and zero the corresponding weights.
        dead_mask = torch.diag(H) == 0
        H[dead_mask, dead_mask] = 1
        W[:, dead_mask] = 0

        if static_groups:
            import copy
            groups = []
            for i in range(0, self.columns, groupsize):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i:i+groupsize], weight=True)
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            inv_perm = torch.argsort(perm)

        losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Dampen the diagonal, then replace H by the upper Cholesky factor of
        # its inverse.
        damp = percdamp * torch.mean(torch.diag(H))
        diag_indices = torch.arange(self.columns, device=self.device)
        H[diag_indices, diag_indices] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        H_inv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            H_inv1 = H_inv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = H_inv1[i, i]

                if groupsize != -1:
                    if not static_groups:
                        # Re-fit the quantizer at the start of every group.
                        if (i1 + i) % groupsize == 0:
                            self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
                    else:
                        idx = i1 + i
                        if actorder:
                            # Map back to the original column to find its group.
                            idx = perm[idx]
                        self.quantizer = groups[idx // groupsize]

                q = quantize(w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q).pow(2) / d.pow(2) / 2

                # Spread this column's error over the remaining columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(H_inv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            losses[:, i1:i2] = Losses1

            # Spread the block's accumulated error over all later columns.
            W[:, i2:] -= Err1.matmul(H_inv[i1:i2, i2:])

            if DEBUG:
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1).pow(2)))
                print(torch.sum(losses))

        # Wait for pending GPU kernels so the timing below is meaningful. Only
        # do so when the work ran on CUDA: torch.cuda.synchronize() raises on
        # machines without a usable CUDA device.
        if W.is_cuda:
            torch.cuda.synchronize()
        print(f"Time for quantization: {time.time() - start_time:.2f} seconds")
        print(f"Total quantization error: {torch.sum(losses).item()}")

        if actorder:
            Q = Q[:, inv_perm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if DEBUG:
            print(torch.sum((self.layer(self.inp1) - self.out1).pow(2)))

    def free(self):
        """
        Drop references to the large tensors held by this object and ask the
        CUDA caching allocator to release unused memory.
        """
        if DEBUG:
            self.inp1 = None
            self.out1 = None
        self.H = None
        torch.cuda.empty_cache()

# ====================================================
# Data Loader Functions
# ====================================================

def set_seed(seed):
    """
    Seed the NumPy, PyTorch and Python ``random`` generators with ``seed``.
    """
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)

def get_coco_dataset(nsamples, seed, dataset_path, annotations_file):
    """
    Build the calibration and test sets from COCO caption annotations.

    The image ids are shuffled after re-seeding the Python, NumPy and PyTorch
    RNGs with ``seed``. The first ``nsamples`` ids form the calibration set
    and the following 100 ids form the test set.

    Returns ``(trainloader, testloader)``: a list of processor outputs, and a
    list of ``(processor output, reference caption)`` pairs where the
    reference is the image's first caption (empty string if it has none).
    """
    from pycocotools.coco import COCO

    coco_api = COCO(annotations_file)
    all_ids = list(coco_api.imgs.keys())

    # Re-seed so the shuffle below is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    random.shuffle(all_ids)
    calibration_ids = all_ids[:nsamples]
    evaluation_ids = all_ids[nsamples:nsamples + 100]

    processor = Blip2Processor.from_pretrained(MODEL_NAME)

    def open_rgb(image_id):
        # Resolve the image file through the annotation index and load it as RGB.
        file_name = coco_api.loadImgs(image_id)[0]['file_name']
        return Image.open(os.path.join(dataset_path, file_name)).convert('RGB')

    # Calibration set: pre-processed images only.
    trainloader = [
        processor(images=open_rgb(image_id), return_tensors="pt")
        for image_id in calibration_ids
    ]

    # Test set: pre-processed image plus one reference caption.
    testloader = []
    for image_id in evaluation_ids:
        picture = open_rgb(image_id)
        annotations = coco_api.loadAnns(coco_api.getAnnIds(imgIds=image_id))
        captions = [annotation['caption'] for annotation in annotations]
        if captions:
            reference_caption = captions[0]
        else:
            reference_caption = ""
        encoded = processor(images=picture, return_tensors="pt")
        testloader.append((encoded, reference_caption))

    return trainloader, testloader

def get_loaders(nsamples, seed, dataset_path, annotations_file):
    """
    Return the ``(calibration, test)`` loaders; currently always backed by
    the COCO captions dataset.
    """
    loaders = get_coco_dataset(
        nsamples=nsamples,
        seed=seed,
        dataset_path=dataset_path,
        annotations_file=annotations_file,
    )
    return loaders

# ====================================================
# Model Utility Functions
# ====================================================

def find_layers(module, layers=None, name=''):
    """
    Recursively collect every sub-module whose type is listed in ``layers``.

    module: root module to search.
    layers: collection of layer classes to look for; defaults to
            ``nn.Conv2d``, ``nn.Linear`` and ``transformers.Conv1D``.
    name:   dotted-path prefix of ``module`` (empty for the root).

    Returns a dict mapping dotted module paths to the matching modules.
    Matching is on the exact type, so subclasses of a listed class are not
    reported, and the search does not descend into a matched module.
    """
    if layers is None:
        # Resolved per call instead of using a shared mutable default list.
        layers = (nn.Conv2d, nn.Linear, transformers.Conv1D)
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child_module in module.named_children():
        child_path = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child_module, layers=layers, name=child_path))
    return found

def get_blip2(model_name):
    """
    Load a BLIP-2 model from ``model_name`` for quantization.

    The random weight initialisers ``kaiming_uniform_``, ``uniform_`` and
    ``normal_`` in ``torch.nn.init`` are replaced by no-ops while the model is
    constructed, since the pretrained weights overwrite them anyway. The
    original functions are restored before returning, so modules created later
    in the same process are initialised normally.

    Returns the loaded ``Blip2ForConditionalGeneration`` instance.
    """
    def skip_init(*args, **kwargs):
        pass

    init_module = torch.nn.init
    patched_names = ('kaiming_uniform_', 'uniform_', 'normal_')
    originals = {fn_name: getattr(init_module, fn_name) for fn_name in patched_names}

    for fn_name in patched_names:
        setattr(init_module, fn_name, skip_init)
    try:
        model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype='auto')
    finally:
        # Undo the monkey-patch even if loading fails.
        for fn_name, original_fn in originals.items():
            setattr(init_module, fn_name, original_fn)
    return model


@torch.no_grad()
def blip2_sequential(model, dataloader):
    """
    Quantize every ``nn.Linear`` layer of ``model`` in place with GPTQ.

    Forward hooks on all linear layers accumulate input statistics while each
    calibration sample is run through the model; afterwards each layer is
    quantized by ``GPTQ.fasterquant``, which writes the quantized weights back
    into the layer in the layer's own dtype.

    model:      the model to quantize; put into eval mode and modified in place.
    dataloader: iterable of mappings of tensors that are moved to ``DEVICE``
                and passed to ``model(**inputs)``.

    Returns the same ``model`` object.
    """
    print("Starting quantization...")

    # Disable the generation cache during calibration and restore the
    # caller's setting afterwards.
    has_use_cache = hasattr(model.config, 'use_cache')
    use_cache = getattr(model.config, 'use_cache', False)
    if has_use_cache:
        model.config.use_cache = False

    try:
        model.eval()

        print("Collecting calibration data...")

        layers_to_quantize = {
            name: layer
            for name, layer in find_layers(model).items()
            if isinstance(layer, nn.Linear)
        }

        gptq_layers = {}
        for name, layer in layers_to_quantize.items():
            gptq = GPTQ(layer)
            gptq.quantizer.configure(
                bits=BITS, perchannel=True, sym=USE_SYMMETRIC, mse=False
            )
            gptq_layers[name] = gptq

        def get_activation(name):
            # Factory so each hook is bound to its own layer name.
            def hook(module, input, output):
                gptq_layers[name].add_batch(input[0].data, output.data)
            return hook

        handles = [
            layer.register_forward_hook(get_activation(name))
            for name, layer in layers_to_quantize.items()
        ]

        try:
            for inputs in dataloader:
                inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
                model(**inputs)
        finally:
            # Never leave calibration hooks attached, even if a forward pass fails.
            for handle in handles:
                handle.remove()

        print("Quantizing layers...")
        for name, gptq in gptq_layers.items():
            if gptq.nsamples == 0:
                # The layer never ran during calibration, so its Hessian is
                # all zeros and fasterquant would treat every column as dead
                # and zero the whole weight matrix. Leave it untouched.
                print(f"Skipping layer {name}: no calibration data collected")
                gptq.free()
                continue

            print(f"Quantizing layer {name}")
            # fasterquant stores the final quantized weights in the layer.
            # They must not be re-quantized afterwards: with grouping the
            # quantizer only holds the parameters of the last group.
            gptq.fasterquant(
                percdamp=PERCENT_DAMPENING,
                groupsize=GROUP_SIZE,
                actorder=USE_ACT_ORDER,
                static_groups=USE_STATIC_GROUPS
            )
            gptq.free()
    finally:
        if has_use_cache:
            model.config.use_cache = use_cache

    print("Quantization complete.")
    return model

# Seed all RNGs before any model or data work.
set_seed(SEED)

# Load the model and move it to the selected device
model = get_blip2(MODEL_NAME)
model.to(DEVICE)
# No explicit model.eval() here: blip2_sequential switches the model to eval mode itself.

# Build the calibration set (dataloader) and the held-out test set (testloader)
dataloader, testloader = get_loaders(
    nsamples=NUM_SAMPLES,
    seed=SEED,
    dataset_path=DATASET_PATH,
    annotations_file=ANNOTATIONS_FILE
)

# Quantize the model and report the wall-clock time.
# blip2_sequential modifies `model` in place and returns it, so
# `quantized_model` refers to the same object as `model`.
start_time = time.time()
quantized_model = blip2_sequential(model, dataloader)
print(f"Quantization time: {time.time() - start_time:.2f} seconds")