In [1]:
import bitsandbytes as bnb
import bitsandbytes.functional as bnb_func
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
    torch.manual_seed(0)
    input_weight = torch.empty(embed_dim, embed_dim, device=device, dtype=dtype)
    input_weight.normal_(0, 1)
    return input_weight


def _build_bnb_linear(input_weight, device):
    """Build a bias-free bitsandbytes ``LinearNF4`` layer around ``input_weight``.

    Args:
        input_weight: Dense weight tensor to wrap in an NF4 ``Params4bit``.
            Assumed to use the ``nn.Linear`` layout
            ``(out_features, in_features)`` — TODO confirm with callers.
        device: CUDA device the parameter and the layer are moved to.

    Returns:
        A ``bnb.nn.LinearNF4`` module whose ``weight`` is the ``Params4bit``
        built from ``input_weight``.
    """
    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4").cuda(
        device
    )
    # LinearNF4 expects (input_features, output_features).  For an
    # (out_features, in_features) weight that is size(1) then size(0); the
    # original passed them the other way round, which only went unnoticed
    # because the weights used here are square.
    out_features, in_features = input_weight.size(0), input_weight.size(1)
    bnb_linear = bnb.nn.LinearNF4(in_features, out_features, bias=False)
    # Replace the layer's freshly initialised weight with the prepared parameter.
    bnb_linear.weight = param
    bnb_linear.to(device)
    return bnb_linear

In [3]:
# Shared setup for the cells below: a seeded bf16 weight on the GPU and its
# bitsandbytes NF4 parameter.  Requires a CUDA device.
torch.manual_seed(0)
dim = 512
device = "cuda"
dtype = torch.bfloat16
input_weight = _build_input_weight(dim, device, dtype)
# The original cell also ran `nf4_weight = to_nf4(input_weight)`.  `to_nf4` is
# neither imported nor defined in the visible cells, so it raised NameError on a
# fresh kernel, and its result was not used by any visible cell.  It is dropped
# here so the notebook survives Restart & Run All.
param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4").cuda(
    device
)

In [26]:
# Reference path: let bitsandbytes dequantize the scalers stored in the
# parameter's quant state, recording only CUDA activity for that single call.
quant_state = param.quant_state
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CUDA]
) as prof:
    absmax = bnb_func.dequantize_blockwise(
        quant_state.absmax, quant_state.state2
    )
# Outside the profiled region: add the stored offset in place, then make sure
# the result is float32.
absmax = absmax.add_(quant_state.offset).float()

In [28]:
# Print the aggregated profiler events captured around the
# `dequantize_blockwise` call in the previous cell.
print(prof.key_averages().table())

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       cudaLaunchKernel        65.74%      42.913us        65.74%      42.913us      42.913us       0.000us         0.00%       0.000us       0.000us             1  
void kDequantizeBlockwise<float, 512, 64, 8, 0>(floa...         0.00%       0.000us         0.00%       0.000us       0.000us       1.792us       100.00%       1.792us       1.792us             1  
         

In [29]:
# Lookup table used by the following cells to decode the quantized scalers by
# plain indexing, moved to the same device as the weights.
# NOTE(review): presumably this matches the code held in `quant_state.state2`;
# the identical outputs of the two dequantization paths suggest so — confirm.
blockwise_code = bnb_func.create_dynamic_map().to(device)

In [33]:
# Manual path, step 1: use every quantized scaler as an index into the lookup
# table.  The scalers are cast to long so they can serve as indices.
quantized_scalers = quant_state.absmax
dquantized_scalers = torch.index_select(
    blockwise_code, 0, quantized_scalers.long()
)

In [34]:
# Manual path, step 2: group the decoded scalers into blocks (one row per
# block) and rescale each row by its second-level absmax, broadcast across the
# row.  The block length is read from the nested quant state rather than
# hard-coding 256, so the cell follows whatever block size was used.
dq_scalers = (
    dquantized_scalers.reshape(-1, quant_state.state2.blocksize)
    * quant_state.state2.absmax[:, None]
)

In [36]:
# Manual path, step 3: add the stored offset in place, then make sure the
# result is float32.
# NOTE: the add is in place, so re-running this cell on its own applies the
# offset a second time.
dq_scalers = dq_scalers.add_(quant_state.offset).float()
# Peek at the first five manually dequantized scalers.
dq_scalers.view(-1)[:5]

tensor([2.6390, 3.1359, 2.2418, 3.3183, 2.0594], device='cuda:0')

In [37]:
# First five scalers from the bitsandbytes kernel path, to compare against the
# manually dequantized values shown by the previous cell.
absmax.view(-1)[:5]

tensor([2.6390, 3.1359, 2.2418, 3.3183, 2.0594], device='cuda:0')

In [39]:
# Peek at the first five entries of the lookup table.
blockwise_code[:5]

tensor([-0.9930, -0.9789, -0.9648, -0.9508, -0.9367], device='cuda:0')