In [2]:
import numpy as np

def dequantize_4bit(B_quant, scales, zero_points=None, block_size=32):
    """Unpack block-wise 4-bit quantized weights and dequantize them.

    Every byte of ``B_quant`` carries two 4-bit values, the low nibble
    first and the high nibble second. All values of one block share a
    scale and a zero point; each value is dequantized as
    ``q * scale - scale * zero_point``.

    Parameters
    ----------
    B_quant : integer ndarray, shape (N, num_blocks, bytes_per_block)
        Packed 4-bit values.
    scales : ndarray, shape (N, num_blocks) or flat (N * num_blocks,)
        One scale per block.
    zero_points : None, scalar or ndarray, optional
        ``None`` selects the fixed zero point 8. An array holding exactly
        ``N * num_blocks`` elements is treated as one zero point per block
        (any layout that reshapes to ``(N, num_blocks)``). Any other input
        is broadcast against ``(N, num_blocks, block_size)`` unchanged.
        Nibble-packed zero points are NOT unpacked here.
    block_size : int, optional
        Number of 4-bit values per block; must equal ``2 * bytes_per_block``.

    Returns
    -------
    ndarray, shape (K, N)
        Dequantized weights, transposed, with ``K = num_blocks * block_size``.

    Raises
    ------
    ValueError
        If ``B_quant`` is not 3-D or ``block_size`` does not match the
        packed block width.
    """
    N, num_blocks, bytes_per_block = B_quant.shape
    if bytes_per_block * 2 != block_size:
        # Fail early with a readable message instead of an opaque reshape
        # error further down.
        raise ValueError(
            f"block_size={block_size} does not match B_quant: "
            f"{bytes_per_block} bytes per block hold {bytes_per_block * 2} 4-bit values"
        )
    K = num_blocks * block_size

    scales = np.asarray(scales)
    if scales.ndim == 1:
        scales = scales.reshape(N, num_blocks)
    # Trailing axis so each scale broadcasts over the values of its block.
    block_scales = scales[:, :, np.newaxis]

    low_4bit = B_quant & 0x0F
    high_4bit = (B_quant >> 4) & 0x0F

    # Interleave (low, high) per byte to restore the element order in a block.
    quant_values = np.stack([low_4bit, high_4bit], axis=-1)
    quant_values = quant_values.reshape(N, num_blocks, block_size)

    if zero_points is None:
        block_zero_points = 8
    else:
        block_zero_points = np.asarray(zero_points).astype(np.float32)
        if block_zero_points.ndim > 0 and block_zero_points.size == N * num_blocks:
            # Per-block zero points need the same trailing axis as the scales;
            # without it they would broadcast along the wrong dimensions.
            block_zero_points = block_zero_points.reshape(N, num_blocks, 1)

    # Same arithmetic form as before (q*s - s*zp) to keep float results stable.
    B_dequant = (quant_values * block_scales) - (block_scales * block_zero_points)

    return B_dequant.reshape(N, K).T

def matmul_nbits(A, B_quant, scales, zero_points=None, block_size=32):
    """Multiply ``A`` by a block-quantized 4-bit weight matrix.

    The packed weights are expanded with ``dequantize_4bit`` (which returns
    them transposed) and used as the right-hand operand of ``np.matmul``.

    Parameters
    ----------
    A : array_like
        Left-hand operand of the matrix product.
    B_quant, scales, zero_points, block_size
        Forwarded unchanged to ``dequantize_4bit``.

    Returns
    -------
    ndarray
        ``np.matmul(A, dequantized_weights)``.
    """
    dequantized_weights = dequantize_4bit(
        B_quant,
        scales,
        zero_points=zero_points,
        block_size=block_size,
    )
    return np.matmul(A, dequantized_weights)


Output shape: (2, 128, 9216)
Output matches expected: True


In [None]:
# Test 0
# Runs matmul_nbits on the tensors stored for the layer-0 attention
# qkv_proj MatMul and compares the result with the stored expected output.
X = np.load("model.layers.0.input_layernorm.output_0.npy")  # left-hand operand of the matmul
W = np.load("model.layers.0.attn.qkv_proj.MatMul.weight_Q4.npy")  # packed 4-bit weights; presumably [N, num_blocks, 16] as in Test 1 — TODO confirm
scales = np.load("model.layers.0.attn.qkv_proj.MatMul.weight_scales.npy")  # per-block scales

# Flat scales are reshaped to one entry per (row, block) of W.
# NOTE(review): dequantize_4bit performs the same reshape for 1-D scales,
# so this step is redundant with it.
if scales.ndim == 1:
    scales = scales.reshape(W.shape[0], W.shape[1])

# No explicit zero points: dequantize_4bit then uses its fixed zero point of 8.
zero_points = None
Y = matmul_nbits(X, W, scales, zero_points, block_size=32)
print("Output shape:", Y.shape)

# Compare with the stored expected output using an absolute tolerance of 1e-3
# (np.allclose additionally applies its default relative tolerance).
expected_output = np.load("model.layers.0.attn.qkv_proj.MatMul.output_0.npy")
print("Output matches expected:", np.allclose(Y, expected_output, atol=1e-3))

In [4]:
# Test 1
# Runs matmul_nbits on the tensors stored for the lm_head MatMul and
# compares the result with the stored logits.
X = np.load("model.layers.32.final_norm_layernorm.output_0.npy")  # Shape: [2, 128, 3072]
W = np.load("lm_head.MatMul.weight_Q4.npy")  # Shape: [32064, 96, 16]
scales = np.load("lm_head.MatMul.weight_scales.npy")  # Shape: [32064, 96]

# Flat scales are reshaped to one entry per (row, block) of W.
# NOTE(review): dequantize_4bit performs the same reshape for 1-D scales,
# so this step is redundant with it.
if scales.ndim == 1:
    scales = scales.reshape(W.shape[0], W.shape[1])

# No explicit zero points: dequantize_4bit then uses its fixed zero point of 8.
zero_points = None
Y = matmul_nbits(X, W, scales, zero_points, block_size=32)
print("Output shape:", Y.shape)

# Compare with the stored logits using an absolute tolerance of 1e-3
# (np.allclose additionally applies its default relative tolerance).
expected_output = np.load("logits.npy")
print("Output matches expected:", np.allclose(Y, expected_output, atol=1e-3))

Output shape: (2, 128, 32064)
Output matches expected: True
