In [1]:
## Implementing bitsandbytes nf4 with mlx
import mlx
import mlx.nn as nn
import mlx.core as mx

In [64]:
## Functions from https://medium.com/@id2thomas/ml-bitsandbytes-nf4-quantize-dequantize-analysis-1ad91d9912c9
def get_absmax(x):
	"""Return the largest absolute value found anywhere in ``x``.

	The reduction runs over every axis of ``x``, so the result is a scalar.
	"""
	magnitudes = mx.abs(x)
	return mx.max(magnitudes)

def get_quantile(x, data_type):
	"""Map every element of ``x`` to the index of its nearest codebook level.

	Args:
		x: Array of values to look up.
		data_type: Sequence of codebook levels; it is iterated one level at a time.

	Returns:
		Integer array shaped like ``x`` whose entries are positions in
		``data_type`` of the level with the smallest absolute distance.
	"""
	# One |level - x| map per codebook level, stacked on a new leading axis so
	# that axis 0 enumerates the levels.
	distances = mx.stack([mx.abs(level - x) for level in data_type], axis=0)
	return mx.argmin(distances, axis=0)

def simple_quant(x, absmax, data_type):
	"""Quantize ``x`` to indices into the ``data_type`` codebook.

	``x`` is divided by ``absmax`` and each scaled element is replaced by the
	index of the closest codebook level (see ``get_quantile``).

	Args:
		x: Array of values to quantize.
		absmax: Scale for the block, e.g. the value returned by ``get_absmax(x)``.
		data_type: Sequence of codebook levels.

	Returns:
		Integer index array shaped like ``x``.
	"""
	# An all-zero block has absmax == 0. Scaling by 1/0 = inf would turn every
	# element into 0 * inf = NaN and make the nearest-level search meaningless,
	# so fall back to a scale of 1: zeros then map to the level closest to 0.
	safe_absmax = mx.where(absmax == 0, 1.0, absmax)
	scaled = x * (1 / safe_absmax)
	return get_quantile(scaled, data_type)


def simple_dequant(x_q, absmax, data_type):
	"""Turn codebook indices back into approximate original values.

	Args:
		x_q: Integer index array, as produced by ``simple_quant``.
		absmax: Scale that was used when the block was quantized.
		data_type: Array of codebook levels, indexed by ``x_q``.

	Returns:
		``data_type[x_q] * absmax``, shaped like ``x_q``.
	"""
	levels = data_type[x_q]
	# Multiply by the scale directly rather than dividing by 1/absmax: the
	# reciprocal fails (or becomes inf) when absmax is 0, overflows for very
	# small scales, and adds a rounding step for no benefit.
	return levels * absmax

In [61]:
# Sanity check of argmin along axis 0: for each column, which row holds the minimum.
sample = mx.array([
	[1, 1, 2],
	[2, 0, 3],
])
mx.argmin(sample, axis=0)

array([0, 1, 0], dtype=uint32)

In [38]:
# Codebook of 16 quantization levels in ascending order, spanning [-1, 1]
# and including an exact 0.
data_type = mx.array([
	-1.0000, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0000,
	0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0000,
])
# Displayed only: the converted array is not assigned back to `data_type`.
data_type.astype(mx.float32)

array([-1, -0.6962, -0.5251, ..., 0.5626, 0.723, 1], dtype=float32)

In [45]:
# Seed MLX's global RNG so the random weights, and every result derived from
# them below, are identical on each Restart & Run All.
mx.random.seed(0)
weight = mx.random.normal((32, 8), dtype=mx.float32)
print(weight.shape)

[32, 8]


In [69]:
# Quantize a single block: the first `blocksize` rows of column 0 of `weight`.
blocksize = 8
weight_block = weight[0:blocksize, 0]
print("WEIGHT BLOCK", weight_block)

# The block's largest magnitude serves as its scale.
block_absmax = get_absmax(weight_block)
print("ABSMAX", block_absmax)

# Each weight becomes the index of its nearest codebook level after scaling.
quantized_weight_block = simple_quant(
	x=weight_block,
	absmax=block_absmax,
	data_type=data_type,
)
print("QUANTIZED BLOPCK", quantized_weight_block)

WEIGHT BLOCK array([-1.176, -0.305002, 0.60144, ..., -1.27792, -0.760275, 1.37701], dtype=float32)
ABSMAX array(1.73652, dtype=float32)
QUANTIZED BLOPCK array([1, 5, 11, ..., 1, 3, 14], dtype=uint32)


In [71]:
## Dequantize
# Rebuild approximate weights from the indices, using the same scale and codebook.
dequantized_weight_block = simple_dequant(
	x_q=quantized_weight_block,
	absmax=block_absmax,
	data_type=data_type,
)
print("DEQUANTIZED BLOPCK", dequantized_weight_block)

DEQUANTIZED BLOPCK array([-1.20896, -0.320908, 0.586769, ..., -1.20896, -0.685751, 1.2555], dtype=float32)


In [72]:
# Signed reconstruction error summed over the block.
# NOTE(review): positive and negative errors cancel in a signed sum, so this
# can understate the quantization error; consider summing mx.abs(...) instead.
(dequantized_weight_block - weight_block).sum()

array(-0.104853, dtype=float32)