In [None]:
import torch
from torch import Tensor

from paretoq_qat import replace_linear_with_quantized_linear


def quantize_lsq_binary_ternary_extension(
    input: Tensor,
    alpha: Tensor,
    num_bits: int,
) -> tuple[Tensor, Tensor]:
    """Map ``input`` to quantization codes using the step size ``alpha``.

    The returned codes are NOT multiplied back by ``alpha``; the caller gets the
    codes together with the step size that was actually used.

    Behaviour by ``num_bits``:
      * ``1``  -> ``input.sign()`` (``alpha`` is not used for the codes).
      * ``0``  -> ``round(input / alpha)`` clamped to ``[-1, 1]``.
      * else   -> ``round(input / alpha)`` clamped to
                  ``[-2**(num_bits-1), 2**(num_bits-1) - 1]``.

    Args:
        input: Values to quantize.
        alpha: Step size; entries not greater than 1e-5 are replaced by 1e-5.
        num_bits: Target bit-width selector (see above).

    Returns:
        ``(codes, alpha)`` where ``alpha`` is the floored step size.

    Raises:
        NotImplementedError: if ``num_bits >= 16``.
    """
    if num_bits >= 16:
        raise NotImplementedError

    # Floor the step size so the division below can never be by (near) zero.
    min_step = torch.tensor(0.00001, device=alpha.device).float()
    alpha = torch.where(alpha > min_step, alpha, min_step)

    # Binary case: only the sign of the input matters.
    if num_bits == 1:
        return input.sign(), alpha

    if num_bits == 0:
        lower, upper = -1, 1
    else:
        half_range = 2 ** (num_bits - 1)
        lower, upper = -half_range, half_range - 1

    codes = (input / alpha).round().clamp(lower, upper)
    return codes, alpha


def quantize_stretched_elastic_quant(
    input: Tensor,
    alpha: Tensor,
    num_bits: int,
) -> tuple[Tensor, Tensor]:
    """Quantize ``input / alpha`` onto a symmetric grid inside (-1, 1).

    The result is expressed in units of ``alpha`` (it is not multiplied back).

    Behaviour by ``num_bits``:
      * ``1``  -> ``input.sign()``.
      * ``0``  -> grid with 1.5 levels per unit and no half-step offset,
                  i.e. values in ``{-2/3, 0, 2/3}``.
      * else   -> ``2**(num_bits-1)`` levels per unit with a half-step offset,
                  e.g. ``{-0.75, -0.25, 0.25, 0.75}`` for ``num_bits == 2``.

    Args:
        input: Values to quantize.
        alpha: Scale; entries not greater than 1e-5 are replaced by 1e-5.
        num_bits: Target bit-width selector (see above).

    Returns:
        ``(quantized, alpha)`` where ``alpha`` is the floored scale.

    Raises:
        NotImplementedError: if ``num_bits >= 16``.
    """
    if num_bits >= 16:
        raise NotImplementedError

    # Floor the scale so the division below can never be by (near) zero.
    min_scale = torch.tensor(0.00001, device=alpha.device).float()
    alpha = torch.where(alpha > min_scale, alpha, min_scale)

    # Binary case: only the sign of the input matters.
    if num_bits == 1:
        return input.sign(), alpha

    # Clip slightly inside [-1, 1] before mapping onto the grid.
    bound = 1 - 1e-2
    levels, offset = (1.5, 0) if num_bits == 0 else (2 ** (num_bits - 1), 0.5)

    normalized = torch.clamp(input / alpha, -bound, bound)
    quantized = (torch.round(normalized * levels - offset) + offset) / levels

    return quantized, alpha


def quantize(weight: Tensor, weight_clip_val: Tensor, w_bits: int) -> tuple[Tensor, Tensor]:
    """Quantize ``weight`` with the scheme selected by ``w_bits``.

    Dispatch:
      * ``w_bits`` of 0 or 2 -> ``quantize_stretched_elastic_quant``
      * any other ``w_bits`` in ``[0, 4]`` (i.e. 1, 3, 4)
        -> ``quantize_lsq_binary_ternary_extension``

    Args:
        weight: Tensor to quantize.
        weight_clip_val: Scale / step size forwarded to the selected quantizer.
        w_bits: Bit-width selector.

    Returns:
        The ``(quantized, scale)`` pair produced by the selected quantizer.

    Raises:
        NotImplementedError: if ``w_bits`` is negative or greater than 4.
    """
    # Negative widths previously slipped through the ``<= 4`` check into the LSQ
    # path, where ``2 ** (w_bits - 1)`` is fractional and the clamp bounds end up
    # inverted, silently producing meaningless values. Reject them explicitly.
    if w_bits < 0 or w_bits > 4:
        raise NotImplementedError(f"unsupported w_bits={w_bits!r}; expected a value in [0, 4]")

    if w_bits in (0, 2):
        return quantize_stretched_elastic_quant(
            weight,
            weight_clip_val,
            w_bits,
        )

    return quantize_lsq_binary_ternary_extension(
        weight,
        weight_clip_val,
        w_bits,
    )

In [None]:
from copy import deepcopy

from transformers import AutoModelForCausalLM  # type: ignore

# Load the Qwen3-0.6B causal-LM checkpoint in bfloat16, placed on the CPU.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

# NOTE(review): `replace_linear_with_quantized_linear` comes from `paretoq_qat` and
# its body is not visible here. Its return value is discarded, so it presumably
# modifies `model` in place, giving Linear layers the `weight_clip_val` attribute
# that the export cell checks for — confirm against the helper.
replace_linear_with_quantized_linear(model, w_bits=4)

# Independent copy of the converted model, taken before the export cell below
# deletes parameters from `model`. It is not used in the visible cells.
model_clone = deepcopy(model)

In [None]:
def pack_int4_to_int32(weight: Tensor) -> Tensor:
    """Pack groups of eight 4-bit values along the last dim into int32 words.

    Each element is converted to int32 and reduced to its low 4 bits (so
    negative values are stored as their two's-complement nibble). Within every
    group of eight consecutive elements, element ``k`` occupies bits
    ``4*k .. 4*k+3`` of the output word, i.e. the first element is the least
    significant nibble.

    Args:
        weight: Tensor whose last dimension is a multiple of 8.

    Returns:
        int32 tensor with the same leading dims and ``last_dim // 8`` as the
        last dim, on the same device as ``weight``.

    Raises:
        AssertionError: if the last dimension is not divisible by 8.
    """
    nibbles = weight.int() & 0xF
    *lead_dims, width = nibbles.shape
    assert width % 8 == 0, "Last dim size must be divisible by 8."

    n_words = width // 8
    groups = nibbles.view(*lead_dims, n_words, 8)

    packed = torch.zeros(*lead_dims, n_words, dtype=torch.int32, device=weight.device)
    # OR each of the eight nibble slots into its 4-bit field of the word.
    for slot, nibble in enumerate(groups.unbind(dim=-1)):
        packed |= nibble << (4 * slot)

    return packed


# Export pass: for every Linear layer that carries a `weight_clip_val`, replace
# the float weight with packed 4-bit codes plus the scale returned by `quantize`,
# then write the checkpoint to disk.
#
# The whole pass runs under `torch.no_grad()`: this is a one-off conversion, and
# without it the tensors returned by `quantize` stay attached to the autograd
# graph of the layer parameters, so `weight_scale` would be registered as a
# non-leaf tensor that requires grad (which, among other things, breaks
# `deepcopy` of the exported model).
with torch.no_grad():
    for module in model.modules():
        # Leave anything that is not a Linear with a clip value untouched.
        if not (isinstance(module, torch.nn.Linear) and hasattr(module, "weight_clip_val")):
            continue

        quantized_weight, weight_scale = quantize(
            module.weight,
            module.weight_clip_val,  # type: ignore
            w_bits=4,
        )

        # Remove the original parameters so only the packed form is saved.
        del module.weight
        del module.weight_clip_val

        module.register_buffer("weight_q_packed", pack_int4_to_int32(quantized_weight))
        module.register_buffer("weight_scale", weight_scale)
        module.register_buffer("weight_bits", torch.tensor(4, dtype=torch.int8))

save_dir = "Qwen3-0.6B-int4"
model.save_pretrained(save_dir, safe_serialization=True)