In [17]:
import torch

In [16]:
compressed = torch.load("../layer30_v_proj_weight_packed.pt",map_location=torch.device('cpu'))
quantized = torch.load("../layer30_v_proj_weight.pt",map_location=torch.device('cpu'))

In [56]:
def pseudo_quantize_tensor(
    w, n_bit=4, zero_point=True, q_group_size=128, inplace=False, get_scale_zp=True
):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        (
            (w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)
        ).mul_(scales)
    else:
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)
    print(scales.shape)
    scales = scales.view(w.shape[0], -1)
    zeros = zeros.view(w.shape[0], -1)
    intweight = []
    scale_zeros = zeros * scales
    for idx in range(w.shape[1]):
        intweight.append(
            torch.round(
                (w[:, idx] + scale_zeros[:, idx // q_group_size])
                / scales[:, idx // q_group_size]
            ).to(torch.int)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)
    # intweight = intweight.t().contiguous()
    intweight = intweight.to(dtype=torch.int32).reshape([intweight.shape[0], -1, q_group_size])
    scales = scales.view(w.shape[0], -1,1)
    zeros = zeros.view(w.shape[0], -1,1)
    if get_scale_zp:
        return w, scales, zeros, intweight
    else:
        return w

In [49]:
def compress(self, in_ch_wise=False, **kwargs):
    data_shape = self.data.shape
    group_size = -1
    if 'group_size' in kwargs:
        group_size = kwargs.pop('group_size')
    out_ch = data_shape[0]
    in_ch = data_shape[1]

    quant = Quantizer()
    quant.configure(**kwargs)
    if in_ch_wise == False:
        data = self.data
        if group_size > 0:
            data = data.reshape([-1, group_size])
        quant.find_params(data, weight=True) #scale
        quant_data  = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
        quant_data  = quant_data.reshape([out_ch, -1, group_size]).to(torch.int)
        quant.scale = quant.scale.reshape([out_ch, -1, 1])
        quant.zero  = quant.zero.reshape([out_ch, -1, 1])
    else:
        data = self.data.T
        if group_size > 0:
            data = data.reshape([-1, group_size])
        quant.find_params(data, weight=True)
        quant_data = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
        quant_data = quant_data.reshape([in_ch, -1, group_size]).to(torch.int)
        quant.scale = quant.scale.reshape([in_ch, -1, 1])
        quant.zero  = quant.zero.reshape([in_ch, -1, 1])

    return quant.scale, quant.zero, quant_data, quant_data.shape

In [50]:
import torch
import torch.nn as nn
import numpy as np

#from clet.functions.rtn import Quantizer
from utils import CompressionParameter, PACKER, Quantizer
from bcq_parameter import BCQParameter

def pack(self, b):
    shape = b.shape
    p_b = b
    if torch.cuda.is_available():
        p_b = p_b.cuda()
    p_b = (p_b + 1) / 2  # (-1., +1.) -> (0, 1)
    p_b = torch.reshape(p_b, [8, -1]).type(torch.uint8)
    p_b = p_b * self.s
    p_b = p_b.sum(0)
    p_b = p_b.type(torch.uint8)
    return p_b, shape

def convert_bcq_format(self, scale, zero, quant_data, qbits, do_packing=False, in_ch_wise=False):
    global PACKER

    zero   = scale * zero
    upack  = torch.Tensor([[2**i for i in range(qbits)]])
    scale  = scale / 2.0
    scale  = torch.matmul(scale, upack)

    offset = scale.sum(-1).unsqueeze(-1) - zero

    binary = torch.zeros(list(quant_data.shape) + [qbits])
    binary_shape = binary.shape
    for i in range(qbits):
        binary[:, :, :, i] = ((quant_data >> i) & 1) * 2.0 - 1.0

    if do_packing == True:
        binary, binary_shape = PACKER.pack(binary)
        binary = binary.to(self.data.device)

    return scale, binary, binary_shape, offset

class RTNParameter(CompressionParameter):
    def compress(self, in_ch_wise=False, **kwargs):
        data_shape = self.data.shape
        group_size = -1
        if 'group_size' in kwargs:
            group_size = kwargs.pop('group_size')
        out_ch = data_shape[0]
        in_ch = data_shape[1]

        quant = Quantizer()
        quant.configure(**kwargs)
        if in_ch_wise == False:
            data = self.data
            if group_size > 0:
                data = data.reshape([-1, group_size])
            quant.find_params(data, weight=True) #scale
            quant_data  = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
            quant_data  = quant_data.reshape([out_ch, -1, group_size]).to(torch.int)
            quant.scale = quant.scale.reshape([out_ch, -1, 1])
            quant.zero  = quant.zero.reshape([out_ch, -1, 1])
        else:
            data = self.data.T
            if group_size > 0:
                data = data.reshape([-1, group_size])
            quant.find_params(data, weight=True)
            quant_data = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
            quant_data = quant_data.reshape([in_ch, -1, group_size]).to(torch.int)
            quant.scale = quant.scale.reshape([in_ch, -1, 1])
            quant.zero  = quant.zero.reshape([in_ch, -1, 1])

        return quant.scale, quant.zero, quant_data, quant_data.shape

    def decompress(self, scale, zero, quant_data, quant_data_shape, in_ch_wise=False):
        # w.shape = [out_ch, in_ch]
        # in_ch_wise == True
        #   -> quant_data.shape = [in_ch, out_ch//group_size, group_size]
        #   -> scale.shape      = [in_ch, out_ch//group_size, 1]
        #   -> zero.shape       = [in_ch, out_ch//group_size, 1]
        # in_ch_wise == False
        #   -> quant_data.shape = [out_ch, in_ch//group_size, group_size]
        #   -> scale.shape      = [out_ch, in_ch//group_size, 1]
        #   -> zero.shape       = [out_ch, in_ch//group_size, 1]

        if in_ch_wise == True:
            out_ch = quant_data_shape[1] * quant_data_shape[2]
            decomp_w = scale * (quant_data - zero)
            decomp_w = decomp_w.reshape([-1, out_ch]).T
        else:
            out_ch = quant_data_shape[0]
            decomp_w = scale * (quant_data - zero)
            decomp_w = decomp_w.reshape([out_ch, -1])
        self.data = decomp_w

    def convert_bcq_format(self, scale, zero, quant_data, qbits, do_packing=False, in_ch_wise=False):
        global PACKER

        zero   = scale * zero
        upack  = torch.Tensor([[2**i for i in range(qbits)]])
        scale  = scale / 2.0
        scale  = torch.matmul(scale, upack)

        offset = scale.sum(-1).unsqueeze(-1) - zero

        binary = torch.zeros(list(quant_data.shape) + [qbits])
        binary_shape = binary.shape
        for i in range(qbits):
            binary[:, :, :, i] = ((quant_data >> i) & 1) * 2.0 - 1.0

        if do_packing == True:
            binary, binary_shape = PACKER.pack(binary)
            binary = binary.to(self.data.device)

        return scale, binary, binary_shape, offset


In [60]:
if __name__ == '__main__':
    w_org = torch.randn(5120,17920)

    # INT4 Quantization -> RTN
    w_rtn = RTNParameter(w_org)
    scale, zero, w_quant, w_quant_shape = w_rtn.compress(in_ch_wise=False, qbits=4, group_size=128, perchannel=True, sym=False)
    w, s, z ,i= pseudo_quantize_tensor(w_org)

    print(scale.shape,zero.shape,w_quant.shape)
    print(scale.dtype,zero.dtype,w_quant.dtype)

    print(zero-z)
    print(scale-s)
    print(i-w_quant)

    #print(z)

    #w_rtn.decompress(scale, zero, w_quant, w_quant_shape, in_ch_wise=False)
    #print(abs(w_org-w_rtn.data).mean())

    # Convert INT4 -> BCQ4
    #alpha, binary, binary_shape, offset = w_rtn.convert_bcq_format(scale, zero, w_quant, qbits=3, do_packing=True, in_ch_wise=False)
    
    #print(binary.size(),alpha.size(),offset.size())
    
    #print(zero)

    #print(offset)
    # BCQ Decompress Check
    #w_bcq = BCQParameter(w_org)
    #w_bcq.decompress(alpha, binary, binary_shape, offset=offset, do_packing=True, in_ch_wise=False)
    #print(abs(w_bcq.data - w_rtn.data).mean())

torch.Size([716800, 1])
torch.Size([5120, 140, 1]) torch.Size([5120, 140, 1]) torch.Size([5120, 140, 128])
torch.float32 torch.float32 torch.int32
tensor([[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        ...,

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]])
tensor([[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
        

TypeError: unsupported operand type(s) for -: 'torch.Size' and 'torch.Size'