In [2]:
import torch
from torch import Tensor

from paretoq_qat import replace_linear_with_quantized_linear


def quantize_lsq_binary_ternary_extension(
    input: Tensor,
    alpha: Tensor,
    num_bits: int,
) -> tuple[Tensor, Tensor]:
    """Turn ``input`` into LSQ-style codes using the step size ``alpha``.

    ``alpha`` is floored first: every element that is not greater than 1e-5
    is replaced by 1e-5. The codes then depend on the bit width:

    * ``num_bits == 1``: ``input.sign()`` (the step size is not used).
    * ``num_bits == 0``: ``round(input / alpha)`` clamped to ``[-1, 1]``.
    * otherwise: ``round(input / alpha)`` clamped to
      ``[-2**(num_bits - 1), 2**(num_bits - 1) - 1]``.

    Args:
        input: Values to quantize.
        alpha: Step size; must broadcast against ``input``.
        num_bits: Requested bit width.

    Returns:
        ``(codes, alpha)`` where ``codes`` are NOT multiplied back by the
        step size and ``alpha`` is the floored step size.

    Raises:
        NotImplementedError: If ``num_bits >= 16``.
    """
    if num_bits >= 16:
        raise NotImplementedError

    # Lower bound for the step size so the division below never uses a
    # zero or negative scale.
    floor = torch.tensor(1e-5, device=alpha.device).float()
    step = torch.where(alpha > floor, alpha, floor)

    if num_bits == 1:
        # Binary: only the sign of each element is kept.
        return input.sign(), step

    if num_bits == 0:
        # Ternary: codes are limited to -1, 0, +1.
        lower, upper = -1, 1
    else:
        half_range = 2 ** (num_bits - 1)
        lower, upper = -half_range, half_range - 1

    codes = torch.clamp(torch.round(input / step), lower, upper)
    return codes, step


def quantize_stretched_elastic_quant(
    input: Tensor,
    alpha: Tensor,
    num_bits: int,
) -> tuple[Tensor, Tensor]:
    """Quantize ``input`` onto the stretched-elastic grid scaled by ``alpha``.

    ``alpha`` is floored first: every element that is not greater than 1e-5
    is replaced by 1e-5. For ``num_bits == 1`` the result is ``input.sign()``.
    Otherwise ``input / alpha`` is clamped to ``[-0.99, 0.99]`` and snapped as
    ``(round(x * n_levels - shift) + shift) / n_levels`` with

    * ``n_levels = 1.5, shift = 0`` when ``num_bits == 0``;
    * ``n_levels = 2**(num_bits - 1), shift = 0.5`` otherwise.

    Args:
        input: Values to quantize.
        alpha: Scale; must broadcast against ``input``.
        num_bits: Requested bit width.

    Returns:
        ``(levels, alpha)`` where ``levels`` are NOT multiplied back by the
        scale (they are generally fractional) and ``alpha`` is the floored
        scale.

    Raises:
        NotImplementedError: If ``num_bits >= 16``.
    """
    if num_bits >= 16:
        raise NotImplementedError

    # Lower bound for the scale so the division below never uses a zero or
    # negative value.
    floor = torch.tensor(1e-5, device=alpha.device).float()
    step = torch.where(alpha > floor, alpha, floor)

    if num_bits == 1:
        # Binary: only the sign of each element is kept.
        return input.sign(), step

    if num_bits == 0:
        n_levels, shift = 1.5, 0
    else:
        n_levels, shift = 2 ** (num_bits - 1), 0.5

    # Clip slightly inside (-1, 1) before snapping to the grid.
    bound = 1 - 1e-2
    normalized = torch.clamp(input / step, -bound, bound)
    snapped = torch.round(normalized * n_levels - shift) + shift
    return snapped / n_levels, step


def quantize(weight: Tensor, weight_clip_val: Tensor, w_bits: int) -> tuple[Tensor, Tensor]:
    """Quantize ``weight`` with the quantizer that matches ``w_bits``.

    * ``w_bits`` of 0 or 2 -> ``quantize_stretched_elastic_quant``.
    * any other ``w_bits <= 4`` -> ``quantize_lsq_binary_ternary_extension``.

    Args:
        weight: Weight tensor to quantize.
        weight_clip_val: Scale / step size forwarded to the quantizer.
        w_bits: Requested bit width.

    Returns:
        Whatever the selected quantizer returns: ``(quantized, scale)``.

    Raises:
        NotImplementedError: If ``w_bits > 4``.
    """
    if w_bits in (0, 2):
        quantizer = quantize_stretched_elastic_quant
    elif w_bits <= 4:
        quantizer = quantize_lsq_binary_ternary_extension
    else:
        raise NotImplementedError

    return quantizer(weight, weight_clip_val, w_bits)

In [None]:
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer
from torch import Tensor, nn


def update_quantized_linear(
    linear: nn.Linear,
    weight: Tensor,
    weight_clip_val: Tensor,
    w_bits: int,
) -> HQQLinear:
    """Build an ``HQQLinear`` for ``linear`` that stores QAT-quantized weights.

    ``weight`` is quantized with ``quantize(weight, weight_clip_val, w_bits)``.
    The resulting codes are shifted so the smallest one becomes 0, packed with
    the packer selected by the layer's ``meta["packing"]`` and written into
    ``W_q``; ``meta["scale"]`` and ``meta["zero"]`` are overwritten to match.
    With ``zero = -min(codes)`` and ``W_q = codes - min(codes)``,
    ``(W_q - zero) * scale`` equals ``codes * scale``
    (assumes HQQ dequantizes as ``(W_q - zero) * scale`` -- TODO confirm).

    Args:
        linear: Layer handed to the ``HQQLinear`` constructor; its weight's
            device is used for the new layer.
        weight: QAT weight tensor to quantize.
        weight_clip_val: QAT scale forwarded to ``quantize``.
        w_bits: QAT bit width forwarded to ``quantize``.

    Returns:
        The populated ``HQQLinear``.

    Raises:
        ValueError: If the shifted codes are not whole numbers in
            ``[0, 2**4 - 1]``. This happens for ``w_bits`` of 0 or 2, where
            ``quantize`` returns fractional levels.
        NotImplementedError: Propagated from ``quantize`` for ``w_bits > 4``.
    """
    # Storage width of the HQQ layer. It is fixed and independent of `w_bits`.
    nbits = 4
    max_code = 2**nbits - 1

    data, scale = quantize(weight, weight_clip_val, w_bits)

    # Shift so the smallest code is 0; computed once and reused for `zero`.
    offset = data.min()
    codes = (data - offset).float()

    # Validate before constructing the layer so a bad input fails loudly
    # instead of being packed into corrupted weights.
    # NOTE(review): assumes the HQQ packers store unsigned integer codes --
    # confirm against hqq.core.bitpack.
    if not torch.equal(codes, codes.round()) or codes.max().item() > max_code:
        raise ValueError(
            f"w_bits={w_bits} produced codes that are not whole numbers in "
            f"[0, {max_code}] after shifting; they cannot be packed into "
            f"{nbits}-bit storage"
        )

    # NOTE(review): initialize=True presumably runs HQQ's own quantization,
    # whose scale/zero/W_q are all overwritten below -- confirm whether that
    # work can be skipped.
    hqq_linear = HQQLinear(
        linear,
        BaseQuantizeConfig(nbits=nbits, group_size=None, axis=1),  # type: ignore
        compute_dtype=torch.bfloat16,
        device=linear.weight.device,  # type: ignore
        initialize=True,
    )

    hqq_linear.meta["scale"] = scale  # type: ignore
    hqq_linear.meta["zero"] = torch.full_like(scale, -offset)  # type: ignore

    with torch.no_grad():
        hqq_linear.W_q.set_(Quantizer.pack[hqq_linear.meta["packing"]](codes))  # type: ignore

    return hqq_linear


def quantize_model_after_qat(
    model: nn.Module,
    qat_state_dict: dict[str, Tensor],
    w_bits: int,
) -> None:
    """Replace every ``nn.Linear`` in ``model`` using QAT tensors, in place.

    For each ``nn.Linear`` found under the dotted path ``name``, the entries
    ``name + ".weight"`` and ``name + ".weight_clip_val"`` are read from
    ``qat_state_dict`` and passed to ``update_quantized_linear``; the returned
    module is assigned over the original attribute on its parent.

    Args:
        model: Module tree to modify in place.
        qat_state_dict: Must contain both keys above for every linear layer
            in ``model``; a missing key raises ``KeyError``.
        w_bits: Bit width forwarded to ``update_quantized_linear``.
    """
    # Collect the targets up front so the replacement loop does not walk a
    # tree it is modifying.
    linear_layers = [
        (path, layer) for path, layer in model.named_modules() if isinstance(layer, nn.Linear)
    ]

    for path, layer in linear_layers:
        # "a.b.c" -> parent "a.b", attribute "c"; a top-level child has parent "".
        parent_path, _, attr = path.rpartition(".")
        owner = model.get_submodule(parent_path)
        replacement = update_quantized_linear(
            layer,
            qat_state_dict[f"{path}.weight"],
            qat_state_dict[f"{path}.weight_clip_val"],
            w_bits,
        )
        setattr(owner, attr, replacement)

In [66]:
from copy import deepcopy


class DummyModule(nn.Module):
    """Minimal model for the equivalence check: one 512 -> 1024 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        # Kept under the attribute name `linear`, so its parameters are
        # exposed as "linear.weight" / "linear.bias".
        self.linear = nn.Linear(in_features=512, out_features=1024)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``self.linear`` applied to ``x``."""
        projected = self.linear(x)
        return projected


# Seed the RNG so the random layer initialisation and the random input below
# are identical on every Restart & Run All.
torch.manual_seed(0)

model1 = DummyModule().to("cpu", torch.bfloat16)
model2 = deepcopy(model1)

# Order matters: model2 is converted first because quantize_model_after_qat
# reads "<name>.weight" and "<name>.weight_clip_val" from model2's state dict.
# NOTE(review): presumably replace_linear_with_quantized_linear is what adds
# the weight_clip_val entries -- confirm in paretoq_qat.
replace_linear_with_quantized_linear(model2, w_bits=4)
quantize_model_after_qat(model1, model2.state_dict(), w_bits=4)

x = torch.randn(1, 512, device="cpu", dtype=torch.bfloat16)

# Inference-only comparison; no autograd graph is needed (matches the
# larger-model checks further down).
with torch.no_grad():
    y1 = model1(x)
    y2 = model2(x)

torch.allclose(y1, y2)

True

In [68]:
from transformers import AutoModelForCausalLM  # type: ignore

# Load the same pretrained checkpoint twice: model1 is converted with
# quantize_model_after_qat, model2 with replace_linear_with_quantized_linear.
model1, model2 = (
    AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-0.6B",
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    for _ in range(2)
)

# Only the inner `.model` submodule is passed to either conversion.
# model2 goes first because its state dict is the source for model1.
replace_linear_with_quantized_linear(model2.model, w_bits=4)
quantize_model_after_qat(model1.model, model2.model.state_dict(), w_bits=4)

# A fixed batch of one sequence holding token ids 0..31.
input_ids = torch.arange(32).unsqueeze(0)

with torch.no_grad():
    out1, out2 = model1(input_ids), model2(input_ids)

torch.allclose(out1.logits, out2.logits)

True

In [69]:
model1

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): HQQLinear(in_features=1024, out_features=2048, bias=False)
          (k_proj): HQQLinear(in_features=1024, out_features=1024, bias=False)
          (v_proj): HQQLinear(in_features=1024, out_features=1024, bias=False)
          (o_proj): HQQLinear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): HQQLinear(in_features=1024, out_features=3072, bias=False)
          (up_proj): HQQLinear(in_features=1024, out_features=3072, bias=False)
          (down_proj): HQQLinear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_atte

In [90]:
from hqq.models.hf.base import AutoHQQHFModel

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16, device_map="cpu")

# Give lm_head its own copy of the weight data.
# NOTE(review): presumably this unties lm_head from the input embedding before
# quantization -- confirm.
with torch.no_grad():
    model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.data.clone())

# 4-bit weights, no grouping (group_size=None), axis=1, HQQ optimisation off;
# scale and zero are not quantized themselves.
quant_config = dict(
    weight_quant_params=dict(
        nbits=4,
        channel_wise=True,
        group_size=None,
        optimize=False,
        round_zero=True,
        axis=1,
        view_as_float=False,
    ),
    scale_quant_params=None,
    zero_quant_params=None,
    offload_meta=False,
)

AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.bfloat16, device="cpu")

# Writes the quantized model into the relative directory "test".
AutoHQQHFModel.save_to_safetensors(model, "test")

100%|██████████| 143/143 [00:00<00:00, 101280.90it/s]
100%|██████████| 197/197 [00:00<00:00, 376.11it/s]
mkdir: test/: File exists


saving 1 : 685 / 3839
saving 2 : 685 / 3839
saving 3 : 685 / 3839
saving 4 : 685 / 3839
saving 5 : 1099 / 3839


In [91]:
from transformers import AutoModelForCausalLM  # type: ignore

# model1 comes from the local "test" directory written by the HQQ save step,
# model2 from the original pretrained checkpoint.
# NOTE(review): the recorded output of this cell shows a "weights ... were not
# used" warning listing HQQ tensors (W_q, scale, zero, ...) and the comparison
# returns False -- it looks like "test" is not being loaded as HQQ layers
# here; confirm the intended loader.
model1, model2 = (
    AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    for checkpoint in ("test", "Qwen/Qwen3-0.6B")
)

replace_linear_with_quantized_linear(model2.model, w_bits=4)

# A fixed batch of one sequence holding token ids 0..31.
input_ids = torch.arange(32).unsqueeze(0)

with torch.no_grad():
    out1, out2 = model1(input_ids), model2(input_ids)

torch.allclose(out1.logits, out2.logits)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some weights of the model checkpoint at test were not used when initializing Qwen3ForCausalLM: ['model.layers.0.mlp.down_proj.W_q', 'model.layers.0.mlp.down_proj.axis', 'model.layers.0.mlp.down_proj.channel_wise', 'model.layers.0.mlp.down_proj.compute_dtype', 'model.layers.0.mlp.down_proj.encoded_state_dict', 'model.layers.0.mlp.down_proj.group_size', 'model.layers.0.mlp.down_proj.nbits', 'model.layers.0.mlp.down_proj.offload_meta', 'model.layers.0.mlp.down_proj.optimize', 'model.layers.0.mlp.down_proj.packing', 'model.layers.0.mlp.down_proj.quant_scale', 'model.layers.0.mlp.down_proj.quant_zero', 'model.layers.0.mlp.down_proj.round_zero', 'model.layers.0.mlp.down_proj.scale', 'model.layers.0.mlp.down_proj.shape', 'model.layers.0.mlp.down_proj.stores_quant_config', 'model.layers.0.mlp.down_proj.unpack_view_dtype', 'model.layers.0.mlp.down_proj.view_as_float', 'model.layers.0.mlp.down_proj.zero', 'model.layers.0.mlp.gate_proj.W_q', 'model.layers.0.mlp.gate_proj.axis', 'model.layers.0.ml

False