<a href="https://colab.research.google.com/github/karankulshrestha/ai-notebooks/blob/main/fast_dequantize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [35]:
import ctypes

import bitsandbytes as bnb

# Cached CUDA stream handle. Starts empty; fast_dequantize fills it in on its
# first call. (A `global` statement at module level is a no-op, so a plain
# assignment is all that is needed here.)
CUDA_STREAM = None

# Short module-level aliases so the hot path avoids repeated attribute lookups.
get_ptr = bnb.functional.get_ptr
ctypes_c_int   = ctypes.c_int
ctypes_c_int32 = ctypes.c_int32

# Raw entry points exposed by the bitsandbytes native library.
cdequantize_blockwise_fp32      = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4  = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4  = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16

In [36]:
import torch

In [37]:
# Scratch tensors reused by fast_dequantize when use_global_buffer is truthy.
# Both start unallocated and are created on first use. (A `global` statement
# at module level has no effect, so plain assignments are used.)
WEIGHT_BUFFER = None
ABSMAX_BUFFER = None

In [38]:
@torch.inference_mode()
def fast_dequantize(W, quant_state=None, out=None, use_global_buffer=None):
  """Dequantize a 4-bit, double-quantized weight through the raw bitsandbytes kernels.

  The work happens in two kernel launches: first the quantized ``absmax``
  statistics are expanded to float32 (and ``offset`` is added), then the NF4
  kernel matching ``dtype`` expands ``W`` into ``out``.

  Args:
    W: Quantized weight tensor. Returned untouched when ``quant_state`` is None.
      NOTE(review): presumably packed NF4 storage, since only the ``*_nf4``
      kernels are called here -- confirm against the caller.
    quant_state: Either a legacy list
      ``[absmax, shape, dtype, blocksize, (offset, state2), _, _]`` where
      ``state2`` is ``[absmax2, code2, blocksize2, _, _, _, _]``, or an object
      exposing ``absmax``, ``shape``, ``dtype``, ``blocksize``, ``offset`` and
      ``state2`` (with ``code``, ``absmax``, ``blocksize``).
    out: Optional preallocated destination; must match ``shape`` and ``dtype``.
      Ignored when ``use_global_buffer`` is truthy.
    use_global_buffer: When truthy, write into the module-level
      ``WEIGHT_BUFFER`` / ``ABSMAX_BUFFER`` scratch tensors instead of
      allocating. The returned tensor is then a view of that shared buffer and
      is overwritten by the next call that uses the global buffer.

  Returns:
    The dequantized tensor of shape ``shape`` (transposed via ``.t()`` when
    ``W.shape[0] == 1``), or ``W`` itself when ``quant_state`` is None.

  Raises:
    TypeError: If ``dtype`` is neither ``torch.float16`` nor ``torch.bfloat16``.
    AssertionError: If a caller-supplied ``out`` has the wrong shape or dtype.
  """
  global CUDA_STREAM, WEIGHT_BUFFER, ABSMAX_BUFFER

  if quant_state is None: return W

  if isinstance(quant_state, list):
    # Legacy list layout: nested statistics live in the fifth slot.
    absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
    offset, state2 = compressed_stats
    absmax2, code2, blocksize2, _, _, _, _ = state2
  else:
    absmax = quant_state.absmax
    shape = quant_state.shape
    dtype = quant_state.dtype
    blocksize = quant_state.blocksize
    offset = quant_state.offset
    state2 = quant_state.state2
    code2 = state2.code
    absmax2 = state2.absmax
    blocksize2 = state2.blocksize

  # Pick the kernel up front. Any dtype other than fp16 used to fall through to
  # the bf16 kernel and silently fill `out` with garbage, so reject it before
  # any CUDA work is issued.
  if dtype == torch.float16:
    dequantize_fn = cdequantize_blockwise_fp16_nf4
  elif dtype == torch.bfloat16:
    dequantize_fn = cdequantize_blockwise_bf16_nf4
  else:
    raise TypeError(
        f"fast_dequantize only supports torch.float16 and torch.bfloat16, got {dtype}"
    )

  # The stream handle is looked up once and then reused on every later call.
  if CUDA_STREAM is None:
    CUDA_STREAM = torch.cuda.current_stream("cuda:0")

  n_elements_absmax = absmax.numel()

  if use_global_buffer:

    size = shape[0] * shape[1]

    # Reallocate when the requested dtype differs from the cached buffer's:
    # a view of the old buffer would reinterpret the kernel's output bits.
    if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:
      WEIGHT_BUFFER = torch.empty(size, dtype=dtype, device="cuda:0", requires_grad=False)
    elif size > WEIGHT_BUFFER.numel():
      WEIGHT_BUFFER.resize_(size)

    if ABSMAX_BUFFER is None:
      ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype=torch.float32, device="cuda:0", requires_grad=False)
    elif n_elements_absmax > ABSMAX_BUFFER.numel():
      ABSMAX_BUFFER.resize_(n_elements_absmax)

    # Hand out views sized for this call; the buffers may be larger.
    out = WEIGHT_BUFFER[:size].view(shape)
    out_absmax = ABSMAX_BUFFER[:n_elements_absmax]

  else:

    if out is None:
      out = torch.empty(shape, dtype=dtype, device="cuda:0", requires_grad=False)

    else:
      # Compare as tuples so a list-typed `shape` still matches a torch.Size.
      assert tuple(out.shape) == tuple(shape) and out.dtype == dtype, \
          "out must match the quant_state shape and dtype"

    out_absmax = torch.empty(n_elements_absmax, dtype=torch.float32, device="cuda:0", requires_grad=False)

  # Stage 1: expand the quantized absmax statistics to float32.
  ptr_out_absmax = get_ptr(out_absmax)
  cdequantize_blockwise_fp32(
      get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
      ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM,
  )

  out_absmax += offset

  # Stage 2: expand the 4-bit weights using the float32 absmax values.
  dequantize_fn(
      get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
      ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,
  )

  # A leading dimension of 1 on W is treated as the transposed-storage marker.
  is_transposed = W.shape[0] == 1

  return out.t() if is_transposed else out
