In [57]:
import torch
import torch.nn as nn
import numpy as np

In [58]:
#compressed = torch.load("../layer30_v_proj_weight_packed.pt",map_location=torch.device('cpu'))
#quantized = torch.load("../layer30_v_proj_weight.pt",map_location=torch.device('cpu'))

In [59]:
class Packer:
    def __init__(self):
        self.s = torch.from_numpy(np.array([1, 2, 4, 8, 16, 32, 64, 128])).view(
            [-1, 1])
        if torch.cuda.is_available():
            self.s = self.s.cuda()
        self.w_pool = {}

    def __get_weight(self, shape, dtype):
        key = np.prod(shape)
        if key not in self.w_pool.keys():
            self.w_pool[key] = torch.zeros(shape, dtype=dtype)
            if torch.cuda.is_available():
                self.w_pool[key] = self.w_pool[key].cuda()
        return self.w_pool[key].reshape(shape)

    def pack(self, b):
        shape = b.shape
        p_b = b
        print("!!!!\n",b)
        if torch.cuda.is_available():
            p_b = p_b.cuda()
        p_b = (p_b + 1) / 2  # (-1., +1.) -> (0, 1)
        p_b = torch.reshape(p_b, [8, -1]).type(torch.uint8)
        print("shape 8 ",p_b.shape)
        p_b = p_b * self.s
        p_b = p_b.sum(0)
        p_b = p_b.type(torch.uint8)
        return p_b, shape

    def unpack(self, pb, shape, dtype=torch.float16):
        b = self.__get_weight(shape, dtype).view([8, -1])
        for i in range(8):
            b[i] = (pb & 1)  # (pB%2)
            pb = pb >> 1  # //2
        b = b * 2 - 1
        b = b.reshape(shape)
        return b


PACKER = Packer()

class CompressionParameter(nn.Parameter):
    def compress(self, **kwargs):
        raise NotImplemented

    def decompress(self, **kwargs):
        raise NotImplemented


In [60]:
def pseudo_quantize_tensor(
    w, n_bit=4, zero_point=True, q_group_size=128, inplace=False, get_scale_zp=True
):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2**n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        (
            (w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)
        ).mul_(scales)
    else:
        w = (
            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
        ) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)
    print(scales.shape)
    scales = scales.view(w.shape[0], -1)
    zeros = zeros.view(w.shape[0], -1)
    intweight = []
    scale_zeros = zeros * scales
    for idx in range(w.shape[1]):
        intweight.append(
            torch.round(
                (w[:, idx] + scale_zeros[:, idx // q_group_size])
                / scales[:, idx // q_group_size]
            ).to(torch.int)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)
    # intweight = intweight.t().contiguous()
    intweight = intweight.to(dtype=torch.int32).reshape([intweight.shape[0], -1, q_group_size])
    scales = scales.view(w.shape[0], -1,1)
    zeros = zeros.view(w.shape[0], -1,1)
    if get_scale_zp:
        return w, scales, zeros, intweight
    else:
        return w

In [61]:
import torch
import torch.nn as nn
import numpy as np

#from clet.functions.rtn import Quantizer
from utils import CompressionParameter, PACKER, Quantizer
from bcq_parameter import BCQParameter


class RTNParameter(CompressionParameter):
    def compress(self, in_ch_wise=False, **kwargs):
        data_shape = self.data.shape
        group_size = -1
        if 'group_size' in kwargs:
            group_size = kwargs.pop('group_size')
        out_ch = data_shape[0]
        in_ch = data_shape[1]

        quant = Quantizer()
        quant.configure(**kwargs)
        if in_ch_wise == False:
            data = self.data
            if group_size > 0:
                data = data.reshape([-1, group_size])
            quant.find_params(data, weight=True)
            quant_data  = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
            quant_data  = quant_data.reshape([out_ch, -1]).to(torch.int)
            quant.scale = quant.scale.reshape([out_ch, -1, 1])
            quant.zero  = quant.zero.reshape([out_ch, -1, 1])
        else:
            data = self.data.T
            if group_size > 0:
                data = data.reshape([-1, group_size])
            quant.find_params(data, weight=True)
            quant_data = torch.clamp(torch.round(data / quant.scale) + quant.zero, 0, quant.maxq)
            quant_data = quant_data.reshape([in_ch, -1, group_size]).to(torch.int)
            quant.scale = quant.scale.reshape([in_ch, -1, 1])
            quant.zero  = quant.zero.reshape([in_ch, -1, 1])

        return quant.scale, quant.zero, quant_data, quant_data.shape
    
    def decompress(self, scale, zero, quant_data, quant_data_shape, in_ch_wise=False):
        # w.shape = [out_ch, in_ch]
        # in_ch_wise == True
        #   -> quant_data.shape = [in_ch, out_ch//group_size, group_size]
        #   -> scale.shape      = [in_ch, out_ch//group_size, 1]
        #   -> zero.shape       = [in_ch, out_ch//group_size, 1]
        # in_ch_wise == False
        #   -> quant_data.shape = [out_ch, in_ch//group_size, group_size]
        #   -> scale.shape      = [out_ch, in_ch//group_size, 1]
        #   -> zero.shape       = [out_ch, in_ch//group_size, 1]

        if in_ch_wise == True:
            out_ch = quant_data_shape[1] * quant_data_shape[2]
            decomp_w = scale * (quant_data - zero)
            decomp_w = decomp_w.reshape([-1, out_ch]).T
        else:
            out_ch = quant_data_shape[0]
            decomp_w = scale * (quant_data - zero)
            decomp_w = decomp_w.reshape([out_ch, -1])
        self.data = decomp_w

    def convert_bcq_format(self, scale, zero, quant_data, qbits, do_packing=False, in_ch_wise=False):
        global PACKER

        zero   = scale * zero
        upack  = torch.Tensor([[2**i for i in range(qbits)]])
        scale  = scale / 2.0
        print("bf matmul",scale.shape)
        scale  = torch.matmul(scale, upack)
        print("af matmul",scale.shape,upack.shape)
        offset = scale.sum(-1).unsqueeze(-1) - zero

        binary = torch.zeros(list(quant_data.shape) + [qbits])
        binary_shape = binary.shape
        print(binary_shape)
        for i in range(qbits):
            binary[:, :, i] = ((quant_data >> i) & 1) * 2.0 - 1.0

        scale = scale.permute(1,2,0)
        binary = binary.permute(1,2,0)
        offset = offset.permute(1,0,2)
        K = binary.shape[0] #input
        
        binary = binary.reshape(K,qbits,-1)

        N = binary.shape[2] #output

        bW = torch.zeros([K // 32, qbits, N], dtype=torch.int32)
        binary_shape = binary.shape
        if do_packing == True:
            for n in range(N):
                for b in range(qbits):
                    for k in range(0, K, 32):
                        s = 0
                        for t in range(32):
                            if binary[k + t][b][n] == 1:
                                s |= 1 << t  # 비트를 설정
                        bW[k // 32][b][n] = (s & 0xFFFFFFFF)

        return scale, bW, binary_shape, offset


In [62]:
tensor = torch.arange(3*2*3*4).reshape(3,2,3,4)  # shape: (2, 3, 4, 5)

# 차원 순서 변경 (0, 3, 1, 2)
permuted_tensor = tensor.permute(0, 3, 1, 2)

print(tensor)

print(permuted_tensor)


tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]],


        [[[24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35]],

         [[36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]],


        [[[48, 49, 50, 51],
          [52, 53, 54, 55],
          [56, 57, 58, 59]],

         [[60, 61, 62, 63],
          [64, 65, 66, 67],
          [68, 69, 70, 71]]]])
tensor([[[[ 0,  4,  8],
          [12, 16, 20]],

         [[ 1,  5,  9],
          [13, 17, 21]],

         [[ 2,  6, 10],
          [14, 18, 22]],

         [[ 3,  7, 11],
          [15, 19, 23]]],


        [[[24, 28, 32],
          [36, 40, 44]],

         [[25, 29, 33],
          [37, 41, 45]],

         [[26, 30, 34],
          [38, 42, 46]],

         [[27, 31, 35],
          [39, 43, 47]]],


        [[[48, 52, 56],
          [60, 64, 68]],

         [[49, 53, 57],
 

In [73]:
import torch
import torch.nn as nn
import numpy as np

#from clet.functions.rtn import Quantizer
from utils import CompressionParameter, PACKER, Quantizer
from bcq_parameter import BCQParameter



if __name__ == '__main__':
    w_org = torch.randn(1024,4096)
    #w =torch.tensor([511]).to(torch.int16)
    #print(w.to(torch.uint8))
    #w_org = torch.arange(4 * 32).reshape(4, 32).to(torch.float16)
    #print(w_org)
    # INT4 Quantization -> RTN
    w_rtn = RTNParameter(w_org)
    scale, zero, w_quant, w_quant_shape = w_rtn.compress(in_ch_wise=False, qbits=4, group_size=128, perchannel=True, sym=False)
    #print(w_quant)
    #w, s, z ,i= pseudo_quantize_tensor(w_org)
    #a = scale*(w_quant-zero)
    print("quant",scale.shape,zero.shape,w_quant.shape)
    #print(scale.dtype,zero.dtype,w_quant.dtype)
    #print("!!!",w_quant.shape[1])
    #print(zero)
    #print(scale)
    #print(w_quant)
    #print(a)
    #print(z)

    #w_rtn.decompress(scale, zero, w_quant, w_quant_shape, in_ch_wise=False)
    #print(abs(w_org-w_rtn.data).mean())

    # Convert INT4 -> BCQ4
    alpha, binary, binary_shape, offset = w_rtn.convert_bcq_format(scale, zero, w_quant, qbits=4, do_packing=False, in_ch_wise=False)
    
    print(binary.size(),alpha.size(),offset.size())
    #print(alpha)
    #print(binary)
    #print(offset)
    # BCQ Decompress Check
    #w_bcq = BCQParameter(w_org)
    #w_bcq.decompress(alpha, binary, binary_shape, offset=offset, do_packing=True, in_ch_wise=False)
    #print(abs(w_bcq.data - w_rtn.data).mean())


tensor([255], dtype=torch.uint8)
quant torch.Size([1024, 32, 1]) torch.Size([1024, 32, 1]) torch.Size([1024, 4096])
bf matmul torch.Size([1024, 32, 1])
af matmul torch.Size([1024, 32, 4]) torch.Size([1, 4])
torch.Size([1024, 4096, 4])
torch.Size([128, 4, 1024]) torch.Size([32, 4, 1024]) torch.Size([32, 1024, 1])
