Reference: [ML: bitsandbytes NF4 quantize/dequantize analysis](https://id2thomas.medium.com/ml-bitsandbytes-nf4-quantize-dequantize-analysis-1ad91d9912c9)

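- Normalization: The weights are split into small blocks (64 weights per block here), and each block is divided by its absolute maximum (`absmax`), so every normalized weight falls in [-1, 1]. Because the weights are roughly normally distributed, most normalized values cluster near 0, and NF4 places most of its quantization levels there.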

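- Quantization: The normalized weights are quantized to 4 bits, i.e. one of 16 levels. In NF4 these levels are *not* evenly spaced (see `data_type` below). They are placed at quantiles of a standard normal distribution, rescaled to [-1, 1], so each level is expected to receive roughly the same share of normally distributed weights. This uses the 16 available codes more efficiently than evenly spaced levels would when representing the original higher-precision (here 16-bit) weights.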

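- Dequantization: The weights are stored in 4 bits and dequantized on the fly to the **compute** dtype (`bnb_4bit_compute_dtype`, here `bfloat16`) whenever a layer is used. The gain is memory: roughly 4x smaller than 16-bit weights, plus a small overhead for the per-block scales. The matrix multiplication itself still runs in 16-bit precision, and dequantization adds some compute overhead rather than speeding inference up.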

In [2]:
import os
# Route HTTP(S) traffic (e.g. downloading the model checkpoint) through a local proxy.
# The URL lives in one constant; set PROXY_URL = None on machines that need no proxy.
PROXY_URL = "http://127.0.0.1:7890"
if PROXY_URL:
    for proxy_var in ("http_proxy", "https_proxy"):
        os.environ[proxy_var] = PROXY_URL

In [1]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization settings: weights are stored as 4-bit NF4 codes,
# the per-block scales are NOT double-quantized, and compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

In [3]:
# Load Llama-3-8B-Instruct with the 4-bit config, mapping the whole model
# (root module key "") onto a single GPU.
# NOTE(review): no torch_dtype is passed; the quant_state inspected below reports
# torch.float16 as the pre-quantization dtype. Pass torch_dtype=torch.bfloat16 if
# the checkpoint's bf16 dtype should be kept instead — confirm this is intended.
gpu_num = 0
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config=bnb_config,
    device_map={"": f"cuda:{gpu_num}"},
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


In [4]:
# Show the module tree: the attention (q/k/v/o) and MLP (gate/up/down) projections
# have become bitsandbytes Linear4bit layers, while the embedding and RMSNorm
# layers keep their regular types.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Ll

In [15]:
# Take the query projection of the first decoder layer. Its weight is a bitsandbytes
# Params4bit holding packed uint8 codes rather than floating-point values.
q_proj = model.model.layers[0].self_attn.q_proj
W = q_proj.weight
W

Parameter containing:
Parameter(Params4bit([[ 97],
            [101],
            [110],
            ...,
            [119],
            [ 88],
            [119]], device='cuda:0', dtype=torch.uint8))

In [17]:
# Packed storage: two 4-bit codes per uint8 byte, so q_proj's 4096 x 4096 = 16,777,216
# weights occupy 4096 * 4096 / 2 = 8,388,608 bytes, stored as a single column.
W.dtype, W.shape

(torch.uint8, torch.Size([8388608, 1]))

In [14]:
# Each uint8 byte packs two 4-bit codes (0-15 each), so the packed bytes can take
# any value in 0-255. The observed min/max confirm the full byte range is used.
W.min(), W.max()

(tensor(0, device='cuda:0', dtype=torch.uint8),
 tensor(255, device='cuda:0', dtype=torch.uint8))

In [19]:
# Look at the first two packed bytes and their bit patterns. Each byte holds two
# 4-bit NF4 codes, high nibble first: 97 = 0b0110_0001 -> codes 6 and 1.
# With data_type[6] = -0.0911, data_type[1] = -0.6962 and the first block's scale
# absmax[0] = 0.0405, these decode to -0.0037 and -0.0282, which are the first two
# entries of dequantized[0] shown further down.
first_two_bytes = W[:2].flatten().tolist()
print(first_two_bytes)
for packed_byte in first_two_bytes:
    print(f"{packed_byte:08b}")

[97, 101]
01100001
01100101


In [21]:
# Unpack the 4-bit quantization metadata attached to the packed weight.
# Older bitsandbytes (< 0.41.3, the version used here) stores it as a 7-element list.
# Newer releases use a QuantState object whose list view drops the NF4 code table,
# so in that case the fields are read by attribute instead.
quant_state = W.quant_state
if isinstance(quant_state, (list, tuple)):
    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
else:
    absmax = quant_state.absmax
    shape = quant_state.shape
    dtype = quant_state.dtype
    blocksize = quant_state.blocksize
    quant_type = quant_state.quant_type
    data_type = quant_state.code  # 16-entry NF4 lookup table
    # Mirrors the old list layout: [offset, state2] when double quantization is on, else None.
    compressed_stats = [quant_state.offset, quant_state.state2] if quant_state.nested else None

In [23]:
# Inspect the quantization metadata unpacked from W.quant_state.
# absmax holds one float32 scale per block of 64 weights: 4096 * 4096 / 64 = 262144.
print("absmax", absmax, absmax.shape)
for label, value in (
    ("shape", shape),            # original (unpacked) shape of the q_proj weight
    ("dtype", dtype),            # dtype the weight had before quantization
    ("blocksize", blocksize),    # number of weights sharing one absmax scale
    ("quant_type", quant_type),  # "nf4", as requested in bnb_config
):
    print(label, value)
# data_type is the 16-entry NF4 code table; each 4-bit code indexes into it.
print("data_type", data_type, len(data_type))

absmax tensor([0.0405, 0.0386, 0.0376,  ..., 0.0581, 0.0957, 0.1079], device='cuda:0') torch.Size([262144])
shape torch.Size([4096, 4096])
dtype torch.float16
blocksize 64
quant_type nf4
data_type tensor([-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,  0.0000,
         0.0796,  0.1609,  0.2461,  0.3379,  0.4407,  0.5626,  0.7230,  1.0000],
       device='cuda:0') 16


In [25]:
import bitsandbytes.functional as F  # NOTE: `F` is bitsandbytes.functional here, not torch.nn.functional

# Decode the packed 4-bit codes back to floats: each weight is data_type[code] * its
# block's absmax. The result is then cast to bfloat16, the compute dtype set in bnb_config.
dequantized = F.dequantize_4bit(W, W.quant_state).to(torch.bfloat16)
# q_proj maps hidden_size -> hidden_size (4096 -> 4096), so the restored matrix is
# (4096, 4096). It must match the original shape recorded in the quant state.
assert dequantized.shape == shape, (dequantized.shape, shape)
print(dequantized.shape)
print(dequantized[0])


torch.Size([4096, 4096])
tensor([-0.0037, -0.0282, -0.0037,  ...,  0.0078, -0.0483, -0.0190],
       device='cuda:0', dtype=torch.bfloat16)


In [26]:
def get_absmax(x):
    """Return the largest absolute value in tensor ``x`` as a 0-d tensor.

    Uses tensor ops (``x.abs().max()``) instead of Python's builtin ``max``.
    The builtin iterates element by element in Python. For a tensor with more
    than one dimension it compares whole rows and raises "Boolean value of
    Tensor with more than one value is ambiguous".

    Args:
        x: non-empty tensor of any shape (e.g. one quantization block of weights).

    Returns:
        0-d tensor with the same dtype and device as ``x``.
    """
    return x.abs().max()

# Recompute the scale of the first quantization block from the dequantized values:
# the first `blocksize` (64) weights of row 0. The NF4 table spans [-1, 1], and the
# block's largest weight normalizes to magnitude 1, so max|dequantized| recovers the
# stored scale. The result (0.0405) matches absmax[0] printed earlier, up to bf16 rounding.
weight_block = dequantized[0][:blocksize].clone().cpu()
# NOTE: this rebinds `absmax`, which previously held the 262144 per-block scales
# unpacked from W.quant_state. Re-run that unpacking cell before reusing the full tensor.
absmax = get_absmax(weight_block)
print("absmax of block:", absmax)

absmax of block: tensor(0.0405, dtype=torch.bfloat16)
