In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoConfig
from datasets import load_dataset
from PIL import Image
import random
from tqdm import tqdm
import numpy as np
import gc
import math
import time
import os
import pandas as pd
import shutil

In [None]:
# quantized_model_dir = '/kaggle/working/phi3_vision_8bit_quantizers_wf_inspired.pt'
# if os.path.exists(quantized_model_dir):
#     shutil.rmtree(quantized_model_dir)
#     print(f"Removed directory: {quantized_model_dir}")

In [None]:
# Build a reproducible calibration subset of Flickr30k: read the captions file,
# keep only images that exist on disk, sample up to CALIBRATION_SAMPLE_SIZE of
# them with a fixed seed, and copy them into calibration_output_dir.
base_path = '/kaggle/input/flickr30k'
image_folder_name = 'flickr30k_images'
captions_file_name = 'captions.txt'
calibration_output_dir = '/kaggle/working/flickr30k_calibration'
CALIBRATION_SAMPLE_SIZE = 500
CALIBRATION_SEED = 42  # fixed so Restart & Run All selects the same images every time

image_folder = os.path.join(base_path, image_folder_name)
captions_file = os.path.join(base_path, captions_file_name)

# NOTE(review): parsed as header-less, comma-separated
# (image_name, caption_index, caption) rows — confirm against the actual
# captions.txt layout; unquoted commas inside captions would break this parse.
df_captions = pd.read_csv(
    captions_file,
    delimiter=',',
    header=None,
    names=['image_name', 'caption_index', 'caption']
)

df_captions['caption'] = df_captions['caption'].astype(str).str.strip()
captions_dict = df_captions.groupby('image_name')['caption'].apply(list).to_dict()

# Keep only caption entries whose image file is actually present.
image_paths = {}
for img_name in captions_dict.keys():
    full_path = os.path.join(image_folder, img_name)
    if os.path.exists(full_path):
        image_paths[img_name] = full_path

# Sort before sampling and use a private seeded RNG so the subset is
# deterministic without touching the global `random` state.
available_image_names = sorted(image_paths)
sample_size = min(CALIBRATION_SAMPLE_SIZE, len(available_image_names))
sample_image_names = random.Random(CALIBRATION_SEED).sample(available_image_names, sample_size)

os.makedirs(calibration_output_dir, exist_ok=True)
for img_name in sample_image_names:
    src_path = image_paths[img_name]
    dst_path = os.path.join(calibration_output_dir, img_name)
    shutil.copy(src_path, dst_path)

In [None]:
%%writefile setup.py
from setuptools import setup, Extension
from torch.utils import cpp_extension

# Build recipe for the `quant_cuda` CUDA extension. The extension name must be
# the same as the module name used by `import quant_cuda`; the two sources are
# the C++ binding file and the CUDA kernel file written by the cells below.
quant_cuda_extension = cpp_extension.CUDAExtension(
    name='quant_cuda',
    sources=['quant_cuda.cpp', 'quant_cuda_kernel.cu'],
)

setup(
    name='quant_cuda',
    ext_modules=[quant_cuda_extension],
    # BuildExtension supplies the nvcc/C++ flags needed to compile against torch.
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)

In [None]:
%%writefile quant_cuda.cpp
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>


// Declaration for the 8-bit CUDA launcher (defined in quant_cuda_kernel.cu).
// It launches the kernel that adds vec @ dequant(mat) into mul; it is called
// only from vecquant8matmul below, after input validation.
void vecquant8matmul_cuda(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros // expects uint8 zeros
);


// C++ entry point for the 8-bit kernel: validates the inputs, normalizes their
// layout, and adds vec @ dequant(mat) into mul (mul is not cleared first).
//
//   vec    : [infeatures] or a single row [1, infeatures], floating dtype
//   mat    : [infeatures / 4, outfeatures], int32 or uint32, four 8-bit weights per word
//   mul    : [outfeatures], same dtype as vec, contiguous (written in place)
//   scales : [outfeatures], same dtype as vec
//   zeros  : [outfeatures], integer zero points (converted to uint8)
void vecquant8matmul(
  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
  torch::Tensor scales, torch::Tensor zeros
) {
  // Ensure execution on the correct CUDA device
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));

  // Input validation checks
  TORCH_CHECK(vec.is_cuda(), "Input vector 'vec' must be a CUDA tensor");
  TORCH_CHECK(mat.is_cuda(), "Input matrix 'mat' must be a CUDA tensor");
  TORCH_CHECK(mul.is_cuda(), "Output vector 'mul' must be a CUDA tensor");
  TORCH_CHECK(scales.is_cuda(), "Input scales 'scales' must be a CUDA tensor");
  TORCH_CHECK(zeros.is_cuda(), "Input zeros 'zeros' must be a CUDA tensor");

  TORCH_CHECK(vec.dim() >= 1, "Input vector 'vec' must have at least 1 dimension");
  // The kernel computes a single matrix-vector product; batched rows are not supported.
  TORCH_CHECK(vec.numel() == vec.size(-1), "Input vector 'vec' must contain a single row");
  TORCH_CHECK(mat.dim() == 2, "Input matrix 'mat' must have 2 dimensions");
  TORCH_CHECK(mul.dim() == 1, "Output vector 'mul' must have 1 dimension");
  // Both conditions must hold (the shapes are indexed per output column).
  TORCH_CHECK(scales.dim() == 1 && scales.size(0) == mul.size(0), "Scales must be 1D and match output size");
  TORCH_CHECK(zeros.dim() == 1 && zeros.size(0) == mul.size(0), "Zeros must be 1D and match output size");

  TORCH_CHECK(mat.dtype() == torch::kInt32 || mat.dtype() == torch::kUInt32, "Matrix 'mat' must be int32 or uint32");
  TORCH_CHECK(mat.size(1) == mul.size(0), "Matrix columns must match output size");
  TORCH_CHECK(mat.size(1) == scales.size(0), "Matrix columns must match scales size");
  TORCH_CHECK(mat.size(1) == zeros.size(0), "Matrix columns must match zeros size");
  TORCH_CHECK(mat.size(0) * 4 == vec.size(-1), "Packed matrix rows * 4 must match input vector size");

  TORCH_CHECK(mul.scalar_type() == vec.scalar_type(), "Output vector 'mul' must have the same dtype as 'vec'");
  TORCH_CHECK(scales.scalar_type() == vec.scalar_type(), "Input scales 'scales' must have the same dtype as 'vec'");
  // mul is updated in place, so a contiguous copy would silently drop the result.
  TORCH_CHECK(mul.is_contiguous(), "Output vector 'mul' must be contiguous");

  // Give the launcher a flat, contiguous input row so its length is size(0).
  auto vec_flat = vec.reshape({-1}).contiguous();
  // The kernel reads raw 32-bit words; reinterpret int32 storage as uint32
  // (same bits) so the launcher's uint32 pointer access accepts it.
  auto mat_u32 = (mat.scalar_type() == torch::kUInt32 ? mat : mat.view(torch::kUInt32)).contiguous();
  auto scales_c = scales.contiguous();
  // Ensure zeros tensor is uint8 as expected by the kernel
  auto zeros_uint8 = zeros.to(torch::kUInt8).contiguous();

  // Call the CUDA kernel launcher
  vecquant8matmul_cuda(vec_flat, mat_u32, mul, scales_c, zeros_uint8);
}


// Python bindings: the extension module (named by TORCH_EXTENSION_NAME at
// build time) exposes only the 8-bit wrapper defined above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def(
      "vecquant8matmul",
      &vecquant8matmul,
      "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
}

In [None]:
%%writefile quant_cuda_kernel.cu
// quant_cuda_kernel.cu

#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>


// Forward declarations: the host launchers below are defined before the
// kernels they launch, so each kernel signature is declared here first and
// must match its definition further down exactly.

// 3-bit kernel, generic floating-point path (3-bit values packed across int32 words).
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int height,
    int width
);

// 3-bit kernel, half2 fast path: fp16 input pairs, float32 scales/zeros/output.
__global__ void VecQuant3MatMulKernelFaster(
    const  half2* __restrict__ vec,
    const    int* __restrict__ mat,
           float* __restrict__ mul,
    const  float* __restrict__ scales,
    const  float* __restrict__ zeros,
    int height,
    int width
);

// 8-bit kernel: four 8-bit weights per uint32 word, uint8 zero points.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const  uint32_t* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const   uint8_t* __restrict__ zeros,
    int height,
    int width
);


// 3-bit kernels: threads per block / output columns per block, and also the
// number of input features covered by one row tile.
const int BLOCKWIDTH  = 256;
// 3-bit kernels: packed int32 rows per row tile (24 words = 256 three-bit values).
const int BLOCKHEIGHT =  24;
const int BLOCKWIDTH_8BIT = 128; // Threads per block (one output column each) for the 8-bit kernel


// Host launcher for the generic 3-bit kernel (not bound in quant_cuda.cpp,
// which exposes only the 8-bit path). Grid x tiles the packed rows of mat in
// BLOCKHEIGHT steps and grid y tiles the output columns in BLOCKWIDTH steps;
// the kernel atomically adds its partial sums into mul.
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  int height = mat.size(0);  // packed int32 rows
  int width = mat.size(1);   // output columns

  dim3 blocks(
    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  // Dispatch on scalar_type(); the Tensor::type() overload is deprecated.
  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<scalar_t>(),
        height, width
      );
    })
  );
}

// Host launcher for the half2 fast path of the 3-bit kernel (not bound in
// quant_cuda.cpp). The kernel reinterprets vec's storage as half2 pairs, so
// vec's dtype and layout are checked here; data_ptr<float>() already rejects
// non-float32 mul/scales/zeros.
void vecquant3matmul_faster_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  TORCH_CHECK(vec.scalar_type() == torch::kHalf,
              "vecquant3matmul_faster_cuda: 'vec' must be float16, got ", vec.scalar_type());
  TORCH_CHECK(vec.is_contiguous(), "vecquant3matmul_faster_cuda: 'vec' must be contiguous");

  int height = mat.size(0);  // packed int32 rows
  int width = mat.size(1);   // output columns

  dim3 blocks(
    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  VecQuant3MatMulKernelFaster<<<blocks, threads>>>(
    static_cast<const half2*>(vec.data_ptr()),
    mat.data_ptr<int>(),
    mul.data_ptr<float>(),
    scales.data_ptr<float>(),
    zeros.data_ptr<float>(),
    height, width
  );
}


// Host launcher for the 8-bit kernel: one thread per output column on a 1-D
// grid; each thread adds its full dot product into mul.
//
//   mat    : [infeatures / 4, outfeatures], 32-bit words (int32 or uint32 storage)
//   vec    : [infeatures] or [1, infeatures]; a single input row only
//   mul    : [outfeatures] (result is added to its current contents)
//   scales : [outfeatures], same dtype as vec
//   zeros  : [outfeatures], uint8
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros
) {
  const int packed_rows = mat.size(0);
  const int width = mat.size(1);     // outfeatures
  const int height = vec.size(-1);   // infeatures (last dim, so a [1, in] row also works)

  TORCH_CHECK(vec.numel() == height, "vecquant8matmul_cuda: 'vec' must contain a single input row");
  TORCH_CHECK(static_cast<int64_t>(packed_rows) * 4 >= height,
              "vecquant8matmul_cuda: packed matrix has too few rows for 'vec'");
  TORCH_CHECK(vec.is_contiguous() && mat.is_contiguous() && mul.is_contiguous() &&
              scales.is_contiguous() && zeros.is_contiguous(),
              "vecquant8matmul_cuda: all tensors must be contiguous");

  // Nothing to compute; also avoids an invalid zero-block launch.
  if (width == 0 || height == 0) return;

  // Launch configuration maps threads to output columns (width)
  dim3 blocks((width + BLOCKWIDTH_8BIT - 1) / BLOCKWIDTH_8BIT);
  dim3 threads(BLOCKWIDTH_8BIT);

  // The kernel only needs the raw 32-bit words, so int32 and uint32 storage
  // are both accepted through an untyped pointer.
  const uint32_t* mat_words = static_cast<const uint32_t*>(mat.data_ptr());

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    vec.scalar_type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<scalar_t><<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(),
        mat_words,
        mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(),
        zeros.data_ptr<uint8_t>(),
        height, // infeatures
        width   // outfeatures
      );
    })
  );

  // Surface launch-configuration errors here instead of at a later sync point.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "vecquant8matmul_cuda: kernel launch failed: ", cudaGetErrorString(launch_err));
}


// View a packed int32 word as unsigned so the kernels' right shifts are
// logical (zero-filling) rather than arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  // Signed-to-unsigned conversion is modulo 2^32, so the bit pattern is unchanged.
  const unsigned int bits = static_cast<unsigned int>(i);
  return bits;
}


// Generic 3-bit dequantize-and-multiply kernel. Each block covers BLOCKHEIGHT
// packed rows of mat for BLOCKWIDTH output columns. Every iteration of the
// main loop consumes 3 packed words and 32 inputs, so a block's 24 rows encode
// BLOCKWIDTH = 256 consecutive input features. Partial sums are atomically
// added into mul because several blocks along blockIdx.x contribute to the
// same column.
//
// `height` is the number of packed 32-bit rows (mat.size(0) in the launcher),
// not the input length: the input vector holds height * 32 / 3 values. As the
// loop below assumes, height should be a multiple of BLOCKHEIGHT.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const  scalar_t* __restrict__ vec,
    const       int* __restrict__ mat,
           scalar_t* __restrict__ mul,
    const  scalar_t* __restrict__ scales,
    const  scalar_t* __restrict__ zeros,
    int height,
    int width
) {
  int row = BLOCKHEIGHT * blockIdx.x;
  int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Number of 3-bit input features stored in `height` packed rows.
  const int vec_len = (height * 32) / 3;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  if (threadIdx.x < BLOCKWIDTH) {
    int vec_idx = (row / BLOCKHEIGHT) * BLOCKWIDTH + threadIdx.x;
    if (vec_idx < vec_len) { // Bound by the input length, not the packed row count
      blockvec[threadIdx.x] = vec[vec_idx];
    } else {
      blockvec[threadIdx.x] = 0; // Pad with zero if out of bounds
    }
  }
  __syncthreads();

  if (col >= width) return; // Check column bounds

  scalar_t scale = scales[col];
  scalar_t zero = zeros[col]; // Float zero point for 3bit

  scalar_t res = 0;
  int i = width * row + col; // Index into mat [height=rows(packed), width=cols(outfeatures)]
  int k = 0;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;

  // Assuming height is multiple of BLOCKHEIGHT for simplicity in loop structure
  // A more robust kernel handles arbitrary height/infeatures
  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);
    res += (scale * scalar_t((tmp1 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width; // Move to the next packed row for this column
    tmp2 = as_unsigned(mat[i]);
    // The 11th value straddles two words: 2 bits from tmp1, 1 bit from tmp2.
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp2 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    tmp1 = as_unsigned(mat[i]);
    // The 22nd value straddles two words: 1 bit from tmp2, 2 bits from tmp1.
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;
    res += (scale * scalar_t((tmp1 >>  0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >>  3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >>  6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >>  9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
    i += width;
    k += 10;
  }

  atomicAdd(&mul[col], res);
}

// half2 fast path of the 3-bit kernel. Same tiling as VecQuant3MatMulKernel,
// but inputs are processed in fp16 pairs. deq2[v][*] holds the half2 pair
// (v & 7, v >> 3), which decodes two 3-bit values from one 6-bit index. Per-
// iteration sums are converted to float and atomically added into mul.
//
// `height` is the number of packed 32-bit rows (mat.size(0) in the launcher),
// so the input holds height * 32 / 3 fp16 values.
__global__ void VecQuant3MatMulKernelFaster(
    const  half2* __restrict__ vec,
    const    int* __restrict__ mat,
           float* __restrict__ mul,
    const  float* __restrict__ scales,
    const  float* __restrict__ zeros,
    int height,
    int width
) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Number of 3-bit input features (fp16 values) stored in `height` packed rows.
  const int vec_len = (height * 32) / 3;

  __shared__ half2 blockvec[blockwidth2];
  if (threadIdx.x < blockwidth2) {
    int vec_idx = (row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x;
    if (vec_idx * 2 + 1 < vec_len) { // Both halves of this half2 are in bounds
      blockvec[threadIdx.x] = vec[vec_idx];
    } else if (vec_idx * 2 < vec_len) {
      // Odd input length: only the first half is valid, pad the second with zero
      blockvec[threadIdx.x] = __halves2half2(((half*)vec)[vec_idx * 2], __float2half(0.0f));
    }
    else {
      blockvec[threadIdx.x] = __float2half2_rn(0.0f); // Pad with zero
    }
  }

  // Lookup table for every 6-bit index, replicated across 32 columns
  // (off = threadIdx.x % 32); presumably to spread shared-memory accesses.
  __shared__ half2 deq2[64][32];
  int val = threadIdx.x / 32;
  int off = threadIdx.x % 32;
  for (; val < 64; val += BLOCKWIDTH / 32) {
    deq2[val][off] = __halves2half2(
       __int2half_rn(val & 0x7), __int2half_rn(val >> 3)
    );
  }

  __syncthreads();

  if (col >= width) return; // Check column bounds

  half2 scale = __float2half2_rn(scales[col]);
  half2 zero = __float2half2_rn(-zeros[col]); // Negated so hfma2(q, scale, zero) = scale*q - zeros[col]

  int i = width * row + col;
  int k = 0;

  float res = 0;
  half2 res2;

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;


  while (k < blockwidth2) { // Iterate through input features (paired in half2)
    res2 = {};
    tmp1 = as_unsigned(mat[i]);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >>  0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >>  6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c);
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2);
    tmp2 >>= 4;
    k += 6;
    res2 = __hfma2(__hfma2(deq2[(tmp2 >>  0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >>  6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30);
    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2);
    tmp1 >>= 2;
    k += 5;
    res2 = __hfma2(__hfma2(deq2[(tmp1 >>  0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >>  6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
    i += width;
    k += 5;
    // Fold the fp16 pair sums into a float accumulator once per iteration.
    res += __half2float(res2.x) + __half2float(res2.y);
  }

  atomicAdd(&mul[col], res);
}


// Accumulator type for VecQuant8MatMulKernel: sum in float for half/float
// inputs (avoids rounding the running sum to fp16 at every step), and in
// double for double inputs.
template <typename scalar_t> struct Quant8AccType { using type = float; };
template <> struct Quant8AccType<double> { using type = double; };

// 8-bit dequantize-and-multiply kernel: mul[c] += sum_k scale[c] * (q[k][c] - zero[c]) * vec[k].
// Assumes a 1-D launch in which every output column is handled by exactly one
// thread (as vecquant8matmul_cuda does). The column is therefore written with a
// plain add rather than atomicAdd, which also works for the c10::Half
// instantiation (there is no built-in atomicAdd overload taking c10::Half*).
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const  scalar_t* __restrict__ vec,     // Input vector [height]
    const  uint32_t* __restrict__ mat,     // Packed weights [ceil(height / 4), width]
           scalar_t* __restrict__ mul,     // Output vector [width], result is added to it
    const  scalar_t* __restrict__ scales,  // Scales [width]
    const   uint8_t* __restrict__ zeros,   // Integer zero points [width]
    int height,                           // Infeatures dimension
    int width                             // Outfeatures dimension
) {
    using acc_t = typename Quant8AccType<scalar_t>::type;

    const int out_col = blockIdx.x * blockDim.x + threadIdx.x; // Maps threads to output columns
    if (out_col >= width) return; // Ensure thread is within output bounds

    const acc_t scale = static_cast<acc_t>(scales[out_col]);
    const acc_t zero_point = static_cast<acc_t>(zeros[out_col]);

    acc_t accum = 0;
    const int packed_rows = (height + 3) / 4;
    for (int packed_row = 0; packed_row < packed_rows; ++packed_row) {
        // One 32-bit load yields four consecutive 8-bit weights for this column;
        // adjacent threads read adjacent columns of the same packed row.
        const uint32_t packed_val = mat[static_cast<size_t>(packed_row) * width + out_col];
        const int k0 = packed_row * 4;
        #pragma unroll
        for (int sub_idx = 0; sub_idx < 4; ++sub_idx) {
            const int k = k0 + sub_idx;
            if (k >= height) break; // Tail when height is not a multiple of 4
            const acc_t q_val = static_cast<acc_t>((packed_val >> (sub_idx * 8)) & 0xFFu);
            // Dequantize: W = scale * (Q - zero_point), then multiply-accumulate.
            accum += scale * (q_val - zero_point) * static_cast<acc_t>(vec[k]);
        }
    }

    mul[out_col] = static_cast<scalar_t>(static_cast<acc_t>(mul[out_col]) + accum);
}

In [None]:
# Compile the C++/CUDA sources with setup.py and install the `quant_cuda` module
# into this kernel's environment (%pip, unlike !pip, always targets the kernel).
# --no-build-isolation: setup.py imports torch (torch.utils.cpp_extension), so the
# build must see the torch already installed here; pip's isolated build
# environment would not contain it.
# Make sure the GPU is enabled in your Kaggle notebook settings.
# --verbose surfaces the compiler output, which helps diagnose build failures.
%pip install . --no-build-isolation --verbose

In [None]:
# Confirm the freshly built extension can be imported and list what it exports.
try:
    import quant_cuda
    print("quant_cuda module imported successfully!")
    print("Available functions:", dir(quant_cuda))
except ImportError as import_err:
    # Most often the build in the previous cell failed.
    print(f"Failed to import quant_cuda: {import_err}")
    print("Please check the build output in the cell above for errors.")
except Exception as unexpected_err:
    print(f"An unexpected error occurred during import: {unexpected_err}")

In [None]:
import sys
import os

# Dataset directory expected to hold the modified GPTQ helpers imported below.
# Same directory name as used by demo_single_layer.py and quantize_phi3_vision.py
# (the previous 'gptmodified' spelling did not match them).
module_dir = '/kaggle/input/gptqmodified'

# Prepend so these helpers win over any same-named installed packages.
if module_dir not in sys.path:
    sys.path.insert(0, module_dir)

import gptq
import quant
import modelutils

In [None]:
%%writefile /kaggle/working/demo_single_layer.py
import sys
import os
import torch
import torch.nn as nn
import numpy as np
import time
import traceback

module_dir_working = '/kaggle/working'
module_dir_input = '/kaggle/input/gptqmodified' # Adjust if needed

# Prepend both directories to the import path; the input directory is inserted
# last, so it ends up searched first.
for extra_dir in (module_dir_working, module_dir_input):
    if extra_dir not in sys.path:
        sys.path.insert(0, extra_dir)

# Abort the script early if the GPTQ helpers or the compiled extension are missing.
try:
    from gptq import GPTQ
    from modelutils import find_layers
    from quant import Quantizer, Quant3Linear, make_quant3
    import quant_cuda
    print("Custom modules and CUDA extension imported successfully.")
except ImportError as import_err:
    print(f"Error importing modules: {import_err}")
    sys.exit(1)
except Exception as unexpected_err:
    print(f"An unexpected error occurred during import: {unexpected_err}")
    sys.exit(1)

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if DEV != torch.device("cpu") else torch.float32
print(f"Using device: {DEV}, dtype: {torch_dtype}")

# Fixed seed so the dummy calibration batches and the test input are identical
# on every run, which makes the reported errors comparable between runs.
SEED = 0
torch.manual_seed(SEED)

# Layer geometry. The packed qweight is expected to be (in_features // 4,
# out_features), i.e. four 8-bit weights per 32-bit word along in_features.
in_features = 1024
out_features = 512
assert in_features % 4 == 0, "in_features must be divisible by 4"

# GPTQ settings.
wbits = 8
groupsize = -1      # -1: no grouping (per-channel), as in the quantize script's help text
sym = False
percdamp = 0.01     # Hessian dampening fraction passed to fasterquant
nsamples = 32       # number of dummy calibration batches
seqlen = 16         # rows per calibration batch

class DemoModel(nn.Module):
    """Minimal model holding a single bias-free linear layer.

    The layer is exposed as ``linear_layer`` so it can be looked up and replaced
    by name during packing.
    """

    def __init__(self, in_f, out_f):
        super().__init__()
        self.linear_layer = nn.Linear(in_features=in_f, out_features=out_f, bias=False)

    def forward(self, x):
        projected = self.linear_layer(x)
        return projected

model = DemoModel(in_features, out_features).to(DEV, dtype=torch_dtype)
original_layer = model.linear_layer
print(f"Created original model with layer: {original_layer}")

print(f"\nGenerating {nsamples} dummy calibration samples...")
calibration_data = []
for _ in range(nsamples):
    calibration_data.append(
        torch.randn(seqlen, in_features, device=DEV, dtype=torch_dtype)
    )
print("Dummy data generated.")

# Attach an asymmetric/symmetric (per `sym`) per-channel quantizer to GPTQ.
print("\nSetting up GPTQ and Quantizer...")
layer_name = 'linear_layer'
gptq_instance = GPTQ(original_layer)
quantizer = Quantizer()
quantizer.configure(wbits, perchannel=True, sym=sym, mse=False)
gptq_instance.quantizer = quantizer
print("GPTQ setup complete.")

# Feed every calibration batch (flattened to 2-D) through the layer and hand
# inputs/outputs to GPTQ, which accumulates its Hessian statistics.
print("\nRunning calibration (computing Hessian)...")
with torch.no_grad():
    for calib_batch in calibration_data:
        calib_out = original_layer(calib_batch)
        gptq_instance.add_batch(
            calib_batch.view(-1, in_features),
            calib_out.view(-1, out_features),
        )
print("Calibration complete.")

print("\nPerforming quantization (fasterquant)...")
quant_tick = time.time()
try:
    gptq_instance.fasterquant(percdamp=percdamp, groupsize=groupsize)
except Exception as e:
    print(f"Error during fasterquant: {e}")
    traceback.print_exc()
    sys.exit(1)
quant_time = time.time() - quant_tick
print(f"Quantization finished in {quant_time:.2f}s.")

# --- Get FP Quantized Output (Reference) ---
# Assumes fasterquant wrote the dequantized (fake-quantized) weights back into
# original_layer, so its output is the target the packed kernel must reproduce.
print("\nGetting reference output from FP-quantized layer (before packing)...")
test_input = torch.randn(1, in_features, device=DEV, dtype=torch_dtype)
with torch.no_grad():
    reference_raw = original_layer(test_input)
fp_quantized_output = reference_raw.detach().clone()
print(f"Reference FP Quantized Output (example): {fp_quantized_output.flatten()[:5]}")

# --- Packing ---
print("\nPacking the quantized layer...")
pack_tick = time.time()
try:
    final_quantizer = gptq_instance.quantizer.to('cpu')
    scale, zero = final_quantizer.scale, final_quantizer.zero
    # make_quant3 swaps the named layer for a packed module inside `model` in place.
    make_quant3(model, {layer_name: original_layer})
    packed_layer = model.linear_layer
    if not isinstance(packed_layer, Quant3Linear):
        raise TypeError(f"make_quant3 failed. Found: {type(packed_layer)}")
    print(f"Layer replaced with: {packed_layer}")
    # Pack the fake-quantized float weights into the integer buffers.
    packed_layer.pack(original_layer.to('cpu'), scale, zero)
except Exception as e:
    print(f"Error during packing: {e}")
    traceback.print_exc()
    sys.exit(1)
pack_time = time.time() - pack_tick
print(f"Packing finished in {pack_time:.2f}s.")

# --- Verification Checks ---
def _record_failure(message):
    """Print a hard verification failure and mark the overall check as failed."""
    global verification_passed
    print(message)
    verification_passed = False


print("\n--- Verification Checks ---")
verification_passed = True
try:
    # nn.Module.to() moves the packed module itself to CPU for these checks.
    packed_layer_cpu = packed_layer.to('cpu')

    # 1. qweight: uint32 words, four 8-bit weights per word along in_features.
    qweight = packed_layer_cpu.qweight
    expected_qweight_shape = (in_features // 4, out_features)
    print(f"qweight dtype: {qweight.dtype}")
    print(f"qweight shape: {qweight.shape}")
    if qweight.dtype != torch.uint32:
        _record_failure("!!! Verification FAILED: qweight dtype is not torch.uint32 !!!")
    if qweight.shape != expected_qweight_shape:
        _record_failure(f"!!! Verification FAILED: qweight shape mismatch. Expected {expected_qweight_shape}, Got {qweight.shape} !!!")
    flat_qweight = qweight.flatten()
    print(f"qweight example values: {flat_qweight[:5].tolist()} ... {flat_qweight[-5:].tolist()}")

    # 2. scales: one per output column; non-positive values only warn.
    scales = packed_layer_cpu.scales
    expected_scales_shape = (out_features,)
    print(f"\nscales dtype: {scales.dtype}")
    print(f"scales shape: {scales.shape}")
    if scales.shape != expected_scales_shape:
        _record_failure(f"!!! Verification FAILED: scales shape mismatch. Expected {expected_scales_shape}, Got {scales.shape} !!!")
    if not torch.all(scales > 0):
        print(f"!!! Verification WARNING: Some scales are not positive !!!")
    print(f"scales stats: min={scales.min().item():.4f}, max={scales.max().item():.4f}, mean={scales.mean().item():.4f}")

    # 3. zeros: uint8 integer zero points, one per output column.
    zeros = packed_layer_cpu.zeros
    expected_zeros_shape = (out_features,)
    print(f"\nzeros dtype: {zeros.dtype}")
    print(f"zeros shape: {zeros.shape}")
    if zeros.dtype != torch.uint8:
        _record_failure("!!! Verification FAILED: zeros dtype is not torch.uint8 !!!")
    if zeros.shape != expected_zeros_shape:
        _record_failure(f"!!! Verification FAILED: zeros shape mismatch. Expected {expected_zeros_shape}, Got {zeros.shape} !!!")
    in_range = torch.all(zeros >= 0) and torch.all(zeros <= 255)
    if not in_range:
        print(f"!!! Verification WARNING: zeros values out of range [0, 255] !!!")
    print(f"zeros stats: min={zeros.min().item()}, max={zeros.max().item()}, mean={zeros.float().mean().item():.2f}")

except Exception as e:
    _record_failure(f"!!! Verification FAILED during checks: {e} !!!")
    traceback.print_exc()

if verification_passed:
    print("\n--- Basic Buffer Verification Passed ---")
else:
    print("\n--- Verification FAILED ---")
    sys.exit(1)


# --- Inference ---
print("\nRunning inference on packed layer...")
model = model.to(DEV) # Ensure model is back on GPU
packed_layer = model.linear_layer

# Heuristic bound on the mean absolute error between the packed kernel and the
# fake-quantized reference. Both use the same quantized weights, so they should
# agree up to floating-point rounding; a zero error is the ideal outcome.
MAE_TOLERANCE = 0.1

try:
    if DEV.type == "cuda":
        torch.cuda.synchronize()  # don't charge previously queued GPU work to this timer
    inf_tick = time.time()
    with torch.no_grad():
        quantized_output = packed_layer(test_input) # Use the same test_input
    if DEV.type == "cuda":
        torch.cuda.synchronize()  # kernel launches are asynchronous; wait before reading the clock
    inf_time = time.time() - inf_tick
    print(f"Inference finished in {inf_time:.4f}s.")
    print(f"Output shape: {quantized_output.shape}")
    print(f"Quantized Output (example): {quantized_output.flatten()[:5]}")

    # --- 4. Compare Outputs ---
    # Refuse to compare mismatched shapes: broadcasting would produce a
    # meaningless error value instead of failing.
    if quantized_output.shape != fp_quantized_output.shape:
        raise RuntimeError(
            f"Output shape mismatch: packed {tuple(quantized_output.shape)} "
            f"vs reference {tuple(fp_quantized_output.shape)}"
        )
    # Compare in float32 so fp16 subtraction does not add its own rounding.
    abs_diff = (quantized_output.float() - fp_quantized_output.float()).abs()
    mae = abs_diff.mean().item()
    max_diff = abs_diff.max().item()
    print(f"\nOutput Comparison vs FP Quantized:")
    print(f"  Mean Absolute Error (MAE): {mae:.6f}")
    print(f"  Maximum Absolute Difference: {max_diff:.6f}")
    if mae > MAE_TOLERANCE:
        print(f"  !!! WARNING: MAE ({mae:.6f}) seems high. Check quantization/packing. !!!")
    else:
        print(f"  Packed output matches the FP-quantized reference (MAE <= {MAE_TOLERANCE}).")

except Exception as e:
    print(f"Error during inference or comparison: {e}")
    traceback.print_exc()
    sys.exit(1)

print("\n--- Demo Complete ---")

In [None]:
# Run the single-layer GPTQ quantize/pack/verify demo in a separate Python
# process; the script calls sys.exit(1) on failure, which then ends only that
# subprocess rather than this notebook kernel.
!python /kaggle/working/demo_single_layer.py

In [None]:
%%writefile quantize_phi3_vision.py
import sys
import os

# Dataset directory expected to provide the gptq, modelutils and quant modules
# imported below; it must be on sys.path before those imports run.
module_dir = '/kaggle/input/gptqmodified' # Assuming this is the correct path now
if module_dir not in sys.path:
    sys.path.insert(0, module_dir)

# IMPORTANT: Ensure quant.py used here has the dtype=torch.uint32 fix
from gptq import GPTQ
from modelutils import find_layers
from quant import Quantizer, Quant3Linear, make_quant3

import argparse
import time
import random
import numpy as np
import pandas as pd
import traceback

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image

import warnings
# NOTE(review): this silences every warning for the whole script, including
# torch/transformers deprecation and precision warnings; consider narrowing it
# to specific categories or modules.
warnings.filterwarnings('ignore')

# --- Argument Parsing ---
def parse_args(argv=None):
    """Parse command-line options for GPTQ quantization of Phi-3.5-vision.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list allows calling
            this from a notebook (whose kernel argv would confuse argparse) or
            from tests.

    Returns:
        argparse.Namespace with the parsed options.

    Raises:
        SystemExit: on invalid arguments (argparse behaviour), including any
            ``--wbits`` value other than 8.
        ValueError: if ``wbits`` is not 8 (defensive; ``choices=[8]`` already
            rejects other values).
    """
    parser = argparse.ArgumentParser(description="Quantize Phi-3.5 Vision model using GPTQ")
    parser.add_argument('--model_id', type=str, default="microsoft/Phi-3.5-vision-instruct", help='Model ID from Hugging Face Hub.')
    parser.add_argument('--dataset_path', type=str, default='/kaggle/input/flickr30k', help='Path to the Flickr30k dataset.')
    parser.add_argument('--captions_file', type=str, default='captions.txt', help='Name of the captions file within dataset_path.')
    parser.add_argument('--image_subdir', type=str, default='flickr30k_images', help='Subdirectory containing images within dataset_path.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration data samples.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling calibration data.')
    parser.add_argument('--wbits', type=int, default=8, choices=[8], help='Number of bits for quantization.')
    parser.add_argument('--groupsize', type=int, default=-1, help='Group size for quantization (-1 for per-channel).')
    parser.add_argument('--sym', action='store_true', help='Whether to perform symmetric quantization.')
    parser.add_argument('--percdamp', type=float, default=0.01, help='Damping factor for Hessian.')
    parser.add_argument('--save_dir', type=str, default='/kaggle/working/phi3_vision_quantized_8bit_hf', help='Directory to save the quantized model using save_pretrained.')

    args = parser.parse_args(argv)
    if args.wbits != 8:
        raise ValueError("This script only supports 8-bit quantization (--wbits 8)")
    return args


# --- Input Preparation ---
def prepare_calibration_input(img_path, caption_text, processor, device):
    """Build one Phi-3.5-vision calibration input from an image/caption pair.

    The caption is placed after the ``<|image_1|>`` placeholder in a single user
    turn, rendered with the tokenizer's chat template (generation prompt
    appended), and passed to ``processor`` together with the image.

    Args:
        img_path: Path to the image file.
        caption_text: Caption; converted with ``str()`` and stripped.
        processor: Processor exposing ``tokenizer.apply_chat_template`` and a
            ``__call__(prompt, images, ...)`` returning a mapping of tensors.
        device: Device every returned tensor is moved to.

    Returns:
        Dict of processor outputs on ``device``, or ``None`` if the image file
        does not exist or any step fails (the error is printed).
    """
    try:
        if not os.path.exists(img_path):
            return None
        # Close the file handle immediately: convert() returns an independent,
        # fully loaded image, so it remains usable after the source is closed.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB")
        placeholder = "<|image_1|>\n"
        user_content = placeholder + str(caption_text).strip()
        messages = [{"role": "user", "content": user_content}]
        prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(prompt, [image], return_tensors="pt", padding=False)
        return {k: v.to(device) for k, v in inputs.items()}
    except Exception as e:
        print(f"Error processing {img_path}: {e}")
        return None

# --- Quantization Function ---
# (phi3_vision_sequential function remains the same)
class _StopForward(Exception):
    """Raised by the input-capturing wrapper to abort the model forward pass
    once the inputs of the first decoder layer have been recorded."""


def _to_device_kwargs(sample, dev):
    """Return a copy of a cached layer-input dict with its tensors moved to ``dev``.

    Entries whose value is ``None`` are dropped, so the layer uses its own
    defaults for those keyword arguments.
    """
    return {k: (v.to(dev) if isinstance(v, torch.Tensor) else v)
            for k, v in sample.items() if v is not None}


@torch.no_grad()
def _propagate_layer(layer, layer_inputs, dev, layer_idx, stage):
    """Run ``layer`` on every cached sample and cache its outputs on the CPU.

    Args:
        layer: Decoder layer, already moved to ``dev``.
        layer_inputs: List of dicts with ``hidden_states``, ``attention_mask``
            and ``position_ids`` entries (CPU tensors or ``None``).
        dev: Device to run the layer on.
        layer_idx: Index of the layer; only used in error messages.
        stage: Label used in error messages (e.g. ``"propagation"``).

    Returns:
        The list of input dicts for the next layer. Samples without
        ``hidden_states`` are dropped. If the forward pass fails for a sample,
        the error is printed and that sample's current input is passed through
        unchanged.
    """
    import traceback

    next_inputs = []
    for sample_idx, sample in enumerate(layer_inputs):
        layer_input_args = _to_device_kwargs(sample, dev)
        if 'hidden_states' not in layer_input_args:
            continue
        try:
            layer_outputs = layer(**layer_input_args, use_cache=False)
            next_inputs.append({
                'hidden_states': layer_outputs[0].cpu(),
                'attention_mask': sample.get('attention_mask'),
                'position_ids': sample.get('position_ids'),
            })
            del layer_outputs
        except Exception as e:
            print(f"Error during {stage} L{layer_idx} S{sample_idx}: {e}")
            traceback.print_exc()
            next_inputs.append(sample)
        finally:
            del layer_input_args
            torch.cuda.empty_cache()
    return next_inputs


@torch.no_grad()
def phi3_vision_sequential(model, processor, calibration_data, dev, args):
    """Quantize the decoder layers of a Phi-3-vision model one at a time with GPTQ.

    Stages:
      1. Record the inputs reaching ``model.model.layers[0]`` for up to
         ``args.nsamples`` calibration pairs.
      2. For each layer, accumulate GPTQ statistics for every linear via
         forward hooks.
      3. Quantize those linears (``fasterquant``).
      4. Feed the quantized layer's outputs forward as the next layer's inputs.

    Args:
        model: Model exposing ``model.model.layers``, ``vision_embed_tokens``,
            ``embed_tokens``, ``norm`` and ``lm_head``.
        processor: Processor passed to ``prepare_calibration_input``.
        calibration_data: Iterable of ``(image_path, caption)`` pairs.
        dev: Device used for computation. Decoder layers otherwise stay on
            the CPU.
        args: Namespace providing ``nsamples``, ``wbits``, ``sym``,
            ``percdamp`` and ``groupsize``.

    Returns:
        ``(model, quantizers)``. ``quantizers`` maps
        ``"model.layers.{i}.{linear_name}"`` to the fitted quantizer (on CPU).
        Returns ``(None, None)`` if ``model.model.layers`` does not exist, and
        ``(model, None)`` if no calibration inputs could be captured.
    """
    import traceback

    print('Starting GPTQ quantization...')
    try:
        layers = model.model.layers
        print(f"Found {len(layers)} transformer layers.")
    except AttributeError:
        print("Could not find layers at model.model.layers.")
        return None, None

    # The modules before/after the decoder stack live on `dev`. Decoder layers
    # are kept on the CPU and moved to `dev` one at a time below.
    model.model.vision_embed_tokens = model.model.vision_embed_tokens.to(dev)
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    for i in range(len(layers)):
        layers[i] = layers[i].to('cpu')
    torch.cuda.empty_cache()

    # --- Stage 1: record the inputs that reach the first decoder layer. ---
    inputs_cache = []

    class Catcher(nn.Module):
        """Wraps layer 0, stores its inputs on the CPU, then aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
            if isinstance(hidden_states, torch.Tensor):
                inputs_cache.append({
                    'hidden_states': hidden_states.cpu(),
                    'attention_mask': attention_mask.cpu() if attention_mask is not None else None,
                    'position_ids': position_ids.cpu() if position_ids is not None else None,
                })
            # Dedicated sentinel: real errors raised by the model (including
            # ValueError) are reported below instead of being silently ignored.
            raise _StopForward

    print("Capturing inputs to the first transformer layer...")
    original_layer_0 = layers[0]
    layers[0] = Catcher(layers[0]).to(dev)
    try:
        for i, (img_path, caption) in enumerate(calibration_data):
            if i >= args.nsamples:
                break
            calib_input = prepare_calibration_input(img_path, caption, processor, dev)
            if calib_input is None:
                continue
            try:
                model(**calib_input)
            except _StopForward:
                pass
            except Exception as e:
                print(f"Error during capture sample {i}: {e}")
            finally:
                del calib_input
                torch.cuda.empty_cache()
    finally:
        # Always put the real layer back, even if capture is interrupted.
        layers[0] = original_layer_0.to('cpu')
        torch.cuda.empty_cache()

    if not inputs_cache:
        print("Error: No calibration inputs captured.")
        return model, None
    print(f"Captured inputs for {len(inputs_cache)} samples.")

    quantizers = {}
    current_layer_inputs = inputs_cache

    for i in range(len(layers)):
        print(f"\n--- Quantizing layer {i} ---")
        layer = layers[i].to(dev)
        layer.train(False)
        subset = find_layers(layer)

        if not subset:
            print("  No linear layers found. Propagating inputs.")
            current_layer_inputs = _propagate_layer(layer, current_layer_inputs, dev, i, "fwd (no quant)")
            layers[i] = layer.cpu()
            torch.cuda.empty_cache()
            continue

        # --- Stage 2: accumulate GPTQ statistics for every linear in the layer. ---
        gptq_handlers = {name: GPTQ(subset[name]) for name in subset}
        for handler in gptq_handlers.values():
            handler.quantizer = Quantizer()
            handler.quantizer.configure(args.wbits, perchannel=True, sym=args.sym, mse=False)

        def add_batch(name):
            def tmp(_, inp, out):
                try:
                    inp_tensor = inp[0].data
                    out_tensor = out.data if isinstance(out, torch.Tensor) else out[0].data
                    gptq_handlers[name].add_batch(inp_tensor, out_tensor)
                except Exception as e:
                    print(f"!!! Error in add_batch hook for {name}: {e} !!!")
            return tmp

        handles = [subset[name].register_forward_hook(add_batch(name)) for name in subset]
        try:
            for sample_idx, sample in enumerate(current_layer_inputs):
                layer_input_args = _to_device_kwargs(sample, dev)
                if 'hidden_states' not in layer_input_args:
                    continue
                try:
                    layer(**layer_input_args, use_cache=False)
                except Exception as e:
                    print(f"Error during Hessian L{i} S{sample_idx}: {e}")
                    traceback.print_exc()
                finally:
                    del layer_input_args
                    torch.cuda.empty_cache()
        finally:
            # Hooks must be gone before propagation, or that pass would add
            # more samples to the statistics.
            for h in handles:
                h.remove()

        # --- Stage 3: quantize each linear from its collected statistics. ---
        for name in subset:
            try:
                print(f"  Quantizing {name}...")
                gptq_handlers[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize)
                quantizers[f'model.layers.{i}.{name}'] = gptq_handlers[name].quantizer.cpu()
                gptq_handlers[name].free()
            except Exception as e:
                print(f"Error during fasterquant for {name}: {e}")
                traceback.print_exc()

        # --- Stage 4: the quantized layer's outputs become the next inputs. ---
        current_layer_inputs = _propagate_layer(layer, current_layer_inputs, dev, i, "propagation")

        layers[i] = layer.cpu()
        del layer, gptq_handlers, subset
        torch.cuda.empty_cache()

    print("\nGPTQ quantization finished.")
    model.model.vision_embed_tokens = model.model.vision_embed_tokens.to('cpu')
    model.model.embed_tokens = model.model.embed_tokens.to('cpu')
    model.model.norm = model.model.norm.to('cpu')
    model.lm_head = model.lm_head.to('cpu')
    torch.cuda.empty_cache()
    return model, quantizers


# --- Packing Function ---
# Replaces the GPTQ-quantized linears with packed Quant3Linear layers.
def phi3_vision_pack(model, quantizers):
    """Replace the GPTQ-quantized linears of ``model`` with packed ``Quant3Linear`` layers.

    Args:
        model: Model whose linear weights were already quantized in place
            (e.g. by ``phi3_vision_sequential``).
        quantizers: Mapping from a module path relative to ``model`` (e.g.
            ``"model.layers.0.self_attn.o_proj"``) to its fitted quantizer. Each
            quantizer's ``scale`` and ``zero`` are used for packing.

    Returns:
        ``model``, modified in place. Packing errors are printed; the affected
        layer stays as created by ``make_quant3`` and may not be packed.
    """
    import traceback

    print("\nPacking model weights...")
    # Group quantized linears by their *direct* parent path. Exact grouping,
    # rather than prefix matching, looks up each leaf name only under the
    # module that owns it. Sorting keeps the processing and log order stable
    # between runs, which a set does not.
    relative_names = {}
    for full_name in quantizers:
        parent_name, _, leaf_name = full_name.rpartition('.')
        relative_names.setdefault(parent_name, []).append(leaf_name)
    modules_to_modify = sorted(relative_names)
    total_modules = len(modules_to_modify)
    print(f"  Replacing layers in modules: {modules_to_modify}")

    for mod_index, mod_name in enumerate(modules_to_modify, start=1):
        print(f"\nProcessing module {mod_index}/{total_modules}: '{mod_name}'")
        parent_module = model.get_submodule(mod_name)
        names_in_module = set(relative_names[mod_name])
        # make_quant3 swaps these out of parent_module, but their (quantized)
        # weights are still needed for packing, so keep references. The same
        # dict doubles as make_quant3's target map (one find_layers pass).
        original_linears = {
            name: layer for name, layer in find_layers(parent_module).items()
            if name in names_in_module
        }

        if not original_linears:
            print(f"  No layers found in module '{mod_name}', skipping.")
            continue

        make_quant3(parent_module, original_linears)
        qlayers = find_layers(parent_module, [Quant3Linear])
        total_layers = len(qlayers)
        if total_layers == 0:
            print(f"  No Quant3Linear layers found in module '{mod_name}', skipping.")
            continue

        # Report progress roughly five times per module.
        layers_interval = max(1, total_layers // 5)
        for layer_index, (name, qlayer) in enumerate(qlayers.items(), start=1):
            full_name = f"{mod_name}.{name}" if mod_name else name
            if name not in original_linears or full_name not in quantizers:
                continue

            quantizer_data = quantizers[full_name].to('cpu')
            original_layer_ref = original_linears[name].to('cpu')
            try:
                qlayer.pack(original_layer_ref, quantizer_data.scale, quantizer_data.zero)
            except Exception as e:
                print(f"Error packing {full_name}: {e}")
                traceback.print_exc()

            if layer_index % layers_interval == 0 or layer_index == total_layers:
                print(f"  Processed {layer_index}/{total_layers} layers in module '{mod_name}'")

    print("\nPacking complete.")
    return model



# --- Main Execution Block ---
if __name__ == '__main__':
    # traceback is used for error reporting below; imported locally so the
    # block does not depend on a module-level import.
    import traceback

    args = parse_args()
    print("Parsed arguments:", args)

    DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Using device: {DEV}, dtype: {torch_dtype}")

    # Seed every RNG used below so calibration sampling is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    print(f"Loading model: {args.model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        _attn_implementation='eager'  # Keep eager as requested
    )
    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
    model.eval()

    # --- Calibration data ---
    print("Loading calibration data...")
    captions_path = os.path.join(args.dataset_path, args.captions_file)
    image_dir = os.path.join(args.dataset_path, args.image_subdir)
    if not os.path.exists(captions_path):
        raise FileNotFoundError(f"Captions file not found: {captions_path}")
    if not os.path.isdir(image_dir):
        raise FileNotFoundError(f"Image directory not found: {image_dir}")

    try:
        df_captions = pd.read_csv(captions_path)
        # Fall back to a header-less, tab-separated layout when the default
        # comma-separated parse yields fewer than three columns.
        if len(df_captions.columns) < 3:
            df_captions = pd.read_csv(captions_path, sep='\t', header=None)
        df_captions.columns = ['image_name', 'caption_index', 'caption']
    except Exception as e:
        raise ValueError(f"Error reading captions file {captions_path}: {e}") from e

    df_captions['caption'] = df_captions['caption'].astype(str).str.strip()
    captions_dict = df_captions.groupby('image_name')['caption'].apply(list).to_dict()
    all_image_names = list(captions_dict.keys())
    if not all_image_names:
        raise ValueError(f"No image names found in captions file: {captions_path}")

    # Exclude captioned images whose file is missing *before* sampling. Missing
    # files then cannot silently reduce the sample count below --nsamples.
    available_image_names = [
        img_name for img_name in all_image_names
        if os.path.exists(os.path.join(image_dir, img_name))
    ]
    missing_files = len(all_image_names) - len(available_image_names)
    if missing_files > 0:
        print(f"Warning: {missing_files} captioned images have no file in {image_dir}; they are excluded from sampling.")
    if not available_image_names:
        raise ValueError("No valid calibration data prepared.")

    num_to_sample = min(args.nsamples, len(available_image_names))
    print(f"Sampling {num_to_sample} calibration images.")
    sampled_image_names = random.sample(available_image_names, num_to_sample)
    # Each pair is (image path, first caption listed for that image).
    calibration_data = [
        (os.path.join(image_dir, img_name), captions_dict[img_name][0])
        for img_name in sampled_image_names
    ]
    print(f"Prepared {len(calibration_data)} calibration data pairs.")

    # --- Quantization ---
    tick = time.time()
    model.to('cpu'); torch.cuda.empty_cache()
    model, quantizers = phi3_vision_sequential(model, processor, calibration_data, DEV, args)
    quant_time = time.time() - tick
    print(f"\nQuantization completed in {quant_time:.2f} seconds.")

    if args.save_dir and quantizers:
        print("\nProceeding with packing and saving...")
        model.to('cpu'); torch.cuda.empty_cache()

        pack_tick = time.time()
        model = phi3_vision_pack(model, quantizers)
        pack_time = time.time() - pack_tick
        print(f"Packing finished in {pack_time:.2f} seconds.")

        print(f"\nSaving packed model and processor to {args.save_dir}...")
        save_tick = time.time()
        os.makedirs(args.save_dir, exist_ok=True)

        # --- Temporary reinterpretation before saving ---
        # Packed qweight buffers may be uint32, which the save path is noted not
        # to support. They are viewed (same bits) as int32 while saving.
        # torch.uint32 only exists in recent PyTorch releases; when it is
        # absent, no buffer can have that dtype and nothing is reinterpreted.
        uint32_dtype = getattr(torch, 'uint32', None)
        original_dtypes = {}
        try:
            print("Temporarily casting qweight buffers to int32 for saving...")
            if uint32_dtype is not None:
                for name, module in model.named_modules():
                    if not isinstance(module, Quant3Linear):
                        continue
                    qweight = getattr(module, 'qweight', None)
                    if qweight is not None and qweight.dtype == uint32_dtype:
                        original_dtypes[name] = qweight.dtype
                        module.qweight = qweight.view(torch.int32)  # reinterpret bits as int32

            # safe_serialization=False handles tied weights and the int32 buffers.
            model.save_pretrained(args.save_dir, safe_serialization=False)
            processor.save_pretrained(args.save_dir)
            save_time = time.time() - save_tick
            print(f"Model and processor saved successfully in {save_time:.2f} seconds.")

        except Exception as e:
            print(f"Error during saving: {e}")
            traceback.print_exc()
        finally:
            # Restore each reinterpreted buffer to the dtype it actually had.
            if original_dtypes:
                print("Casting qweight buffers back to original dtype...")
                for name, module in model.named_modules():
                    if name in original_dtypes and module.qweight.dtype == torch.int32:
                        module.qweight = module.qweight.view(original_dtypes[name])
                print("Casting back complete.")

        # save_pretrained writes either one pytorch_model.bin or, for large
        # models, shards plus an index file; accept both layouts.
        checkpoint_names = ("pytorch_model.bin", "pytorch_model.bin.index.json")
        if any(os.path.exists(os.path.join(args.save_dir, fname)) for fname in checkpoint_names):
            print(f"Saved artifacts found in {args.save_dir}")
        else:
            print(f"Warning: Expected files not found in save directory {args.save_dir}")

    elif not quantizers:
        print("\nQuantization failed. Model not packed or saved.")
    else:
        print("\nNo save directory specified (--save_dir). Skipping saving.")

    print("\n--- Quantization Script Finished ---")

In [None]:
# Run the quantization script: 16 calibration samples and symmetric 8-bit
# quantization. The packed model is written to --save_dir, which is the
# directory the inference cell below loads from.
!python quantize_phi3_vision.py \
    --dataset_path /kaggle/input/flickr30k \
    --captions_file captions.txt \
    --image_subdir flickr30k_images \
    --nsamples 16 \
    --save_dir /kaggle/working/phi3_vision_quantized_8bit_hf_sym \
    --sym  

In [None]:
# === INFERENCE CELL (Corrected Path + Memory Footprint) ===

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import os
import random
import pandas as pd
import traceback
import time
import sys

# --- Make the custom quantization code importable ---
# The `gptq`, `modelutils` and `quant` modules and the compiled `quant_cuda`
# extension are looked up in /kaggle/working. That directory is put at the
# front of sys.path only once, so re-running the cell adds no duplicates.
module_dir_working = '/kaggle/working'
if module_dir_working not in sys.path:
    sys.path.insert(0, module_dir_working)

# Import problems are reported but do not stop the cell.
try:
    from gptq import GPTQ
    from modelutils import find_layers
    from quant import Quantizer, Quant3Linear, make_quant3
    import quant_cuda
except ImportError as e:
    print(f"Error importing custom modules needed for loading: {e}")
except Exception as e:
    print(f"An unexpected error occurred during custom module import: {e}")
else:
    print("Custom modules for loading found.")

import warnings
warnings.filterwarnings('ignore')  # silence Python warnings for the rest of the session

# --- Configuration ---
# Output directory of save_pretrained in the quantization run (its --save_dir).
quantized_model_dir = '/kaggle/working/phi3_vision_quantized_8bit_hf_sym'
# The processor is loaded from the original Hub checkpoint.
original_model_id = "microsoft/Phi-3.5-vision-instruct"
# Flickr30k layout: <dataset_path>/<captions_file_name> and <dataset_path>/<image_subdir>/.
dataset_path = '/kaggle/input/flickr30k'
captions_file_name = 'captions.txt'
image_subdir = 'flickr30k_images'

# --- Helpers for rebuilding the quantized model ---
def _load_saved_state_dict(model_dir):
    """Load the state dict written by ``save_pretrained(..., safe_serialization=False)``.

    Handles both a single ``pytorch_model.bin`` and a sharded checkpoint
    described by ``pytorch_model.bin.index.json``. Returns ``None`` if neither
    file exists.
    """
    import json

    single_file = os.path.join(model_dir, "pytorch_model.bin")
    if os.path.exists(single_file):
        # The checkpoint holds only tensors; weights_only refuses arbitrary pickles.
        return torch.load(single_file, map_location="cpu", weights_only=True)
    index_file = os.path.join(model_dir, "pytorch_model.bin.index.json")
    if os.path.exists(index_file):
        with open(index_file) as f:
            shard_names = sorted(set(json.load(f)["weight_map"].values()))
        state_dict = {}
        for shard_name in shard_names:
            state_dict.update(torch.load(os.path.join(model_dir, shard_name), map_location="cpu", weights_only=True))
        return state_dict
    return None


def _replace_with_quant_layers(model, state_dict):
    """Swap in a ``Quant3Linear`` for every layer that has a saved ``qweight``.

    Uses ``make_quant3(parent_module, {leaf_name: layer})`` on each direct
    parent, the same call shape the packing step uses. Saved ``qweight``
    tensors whose dtype differs from the new buffer's (but with the same
    element size) are reinterpreted bit-for-bit. The quantization script
    views uint32 buffers as int32 before saving. Returns the number of
    layers replaced.
    """
    suffix = ".qweight"
    by_parent = {}
    for key in state_dict:
        if key.endswith(suffix):
            parent_name, _, leaf_name = key[:-len(suffix)].rpartition('.')
            by_parent.setdefault(parent_name, []).append(leaf_name)

    for parent_name, leaf_names in by_parent.items():
        parent = model.get_submodule(parent_name)
        make_quant3(parent, {leaf: getattr(parent, leaf) for leaf in leaf_names})

    for name, module in model.named_modules():
        if not isinstance(module, Quant3Linear):
            continue
        key = f"{name}.qweight"
        target = getattr(module, 'qweight', None)
        saved = state_dict.get(key)
        if (target is not None and saved is not None and saved.dtype != target.dtype
                and saved.element_size() == target.element_size()):
            state_dict[key] = saved.view(target.dtype)
    return sum(len(leaves) for leaves in by_parent.values())


# --- Check if model saved ---
# save_pretrained writes either a single pytorch_model.bin or shards + an index.
saved_weight_file = os.path.join(quantized_model_dir, "pytorch_model.bin")
saved_index_file = os.path.join(quantized_model_dir, "pytorch_model.bin.index.json")
quant_code_available = 'make_quant3' in globals() and 'Quant3Linear' in globals()
if not (os.path.exists(saved_weight_file) or os.path.exists(saved_index_file)):
    print(f"Error: Quantized model weights file ({saved_weight_file}) not found.")
    print(f"Please ensure the quantization script ran successfully and saved to: {quantized_model_dir}")
elif not quant_code_available:
    print("Error: make_quant3 / Quant3Linear were not imported, so the quantized layers cannot be rebuilt.")
else:
    print(f"Found quantized model artifacts in {quantized_model_dir}. Proceeding with inference.")

    # --- Setup ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Using device: {device}, dtype: {torch_dtype}")

    print("Loading quantized model and original processor...")
    load_tick = time.time()
    model = None
    try:
        # The saved config describes the original architecture. from_pretrained
        # alone would therefore build plain nn.Linear layers, ignore the saved
        # qweight/scale/zero tensors and randomly initialize those linear
        # weights. Instead: load the non-quantized parts normally, swap in
        # Quant3Linear layers, then load the packed tensors into them.
        model = AutoModelForCausalLM.from_pretrained(
            quantized_model_dir,
            torch_dtype=torch_dtype,
            trust_remote_code=True,   # *** CRITICAL ***
            _attn_implementation='eager',  # same attention implementation as the quantization run
        )
        state_dict = _load_saved_state_dict(quantized_model_dir)
        num_replaced = _replace_with_quant_layers(model, state_dict)
        print(f"Replaced {num_replaced} linear layers with Quant3Linear.")
        load_result = model.load_state_dict(state_dict, strict=False)
        del state_dict
        if load_result.missing_keys:
            print(f"Warning: {len(load_result.missing_keys)} missing keys, e.g. {load_result.missing_keys[:5]}")
        if load_result.unexpected_keys:
            print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys, e.g. {load_result.unexpected_keys[:5]}")
        model = model.to(device)
        model.eval()  # Ensure model is in eval mode

        processor = AutoProcessor.from_pretrained(original_model_id, trust_remote_code=True)
        load_time = time.time() - load_tick
        print(f"Model and processor loaded in {load_time:.2f} seconds.")

        # --- Calculate and Print Memory Footprint ---
        try:
            mem_footprint_bytes = model.get_memory_footprint()
            mem_footprint_gb = mem_footprint_bytes / (1024**3)
            print(f"Model memory footprint (calculated by HF): {mem_footprint_gb:.2f} GB")
            if device == "cuda":
                # Make sure pending host->device copies have finished before reading counters.
                torch.cuda.synchronize()
                allocated_mem_gb = torch.cuda.memory_allocated() / (1024**3)
                reserved_mem_gb = torch.cuda.memory_reserved() / (1024**3)
                print(f"GPU Memory Allocated: {allocated_mem_gb:.2f} GB")
                print(f"GPU Memory Reserved: {reserved_mem_gb:.2f} GB")
        except Exception as mem_e:
            print(f"Could not calculate memory footprint: {mem_e}")

    except Exception as e:
        print(f"\nError loading the model or processor: {e}")
        traceback.print_exc()
        model = None

    # --- Prepare Sample Input ---
    if model is not None:
        print("\nPreparing a sample input...")
        captions_path = os.path.join(dataset_path, captions_file_name)
        image_dir = os.path.join(dataset_path, image_subdir)
        test_img_path = None
        test_caption = "Describe the image."
        try:
            df_captions = pd.read_csv(captions_path)
            # Same fallback as the quantization script: header-less TSV.
            if len(df_captions.columns) < 3:
                df_captions = pd.read_csv(captions_path, sep='\t', header=None)
            df_captions.columns = ['image_name', 'caption_index', 'caption']
            df_captions['caption'] = df_captions['caption'].astype(str).str.strip()
            captions_dict = df_captions.groupby('image_name')['caption'].apply(list).to_dict()
            all_image_names = list(captions_dict.keys())
            if all_image_names:
                test_img_name = random.choice(all_image_names)
                test_img_path = os.path.join(image_dir, test_img_name)
                test_caption = captions_dict.get(test_img_name, [test_caption])[0]
                print(f"Selected sample: Image='{test_img_name}', Caption='{test_caption}'")
                if not os.path.exists(test_img_path):
                    print(f"Error: Image not found: {test_img_path}")
                    test_img_path = None
            else:
                print("No images found in caption data.")
        except Exception as e:
            print(f"Error loading caption data: {e}")

        # --- Run Inference ---
        if test_img_path:
            try:
                with Image.open(test_img_path) as raw_image:
                    image = raw_image.convert("RGB")
                placeholder = "<|image_1|>\n"
                user_content = placeholder + str(test_caption).strip()
                messages = [{"role": "user", "content": user_content}]
                prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs = processor(prompt, [image], return_tensors="pt")
                inputs = {k: v.to(model.device) for k, v in inputs.items()}  # Move to model's device
                generation_args = {"max_new_tokens": 500, "temperature": 0.0, "do_sample": False}

                print("\nRunning model.generate (using quantized layers)...")
                inf_tick = time.time()
                with torch.no_grad():
                    generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
                inf_time = time.time() - inf_tick

                # Keep only the newly generated tokens (drop the prompt).
                input_token_len = inputs['input_ids'].shape[1]
                generate_ids = generate_ids[:, input_token_len:]
                response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

                print(f"\nInference completed in {inf_time:.2f} seconds.")
                print("\n--- Inference Result ---")
                print(response)
                print("------------------------")
            except Exception as e:
                print(f"\nError during inference: {e}")
                traceback.print_exc()
        else:
            print("\nSkipping inference due to issues preparing test sample.")

    print("\n--- Inference Cell Complete ---")