In [None]:
from gluoncv.model_zoo import cifar_resnet56_v1
import numpy as np
import os, shutil
import struct
import ctypes

In [None]:
import sys

# Put the parent directory on the module search path (presumably so that the
# project-local packages imported further down can be resolved - confirm
# against the repository layout). The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate entry each time.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
from compress import utils as cutils, huffman

In [None]:
# Output directory for the exported files. It is recreated empty on every
# run, so anything left over from a previous export is removed first.
params_dir = './params/'
if os.path.exists(params_dir):
    # Delete the whole tree, including every file inside it.
    shutil.rmtree(params_dir)
os.mkdir(params_dir)

In [None]:
# Instantiate the model from the gluoncv model zoo, then restore its weights
# from the checkpoint on disk.
# ignore_extra=True is forwarded to load_parameters; presumably it lets the
# load succeed when the file holds entries the network does not declare -
# confirm against the MXNet Gluon documentation.
net = cifar_resnet56_v1()
net.load_parameters(
    "./checkpoints/cifa10_resnet_56_v1_CBQuantize_mergebn_wprune_3bits-000500.params",
    ignore_extra=True,
)

In [None]:
def save(data, path, dtype="f", compress_bits=8):
    """Serialize the elements of ``data`` to a binary file.

    Parameters
    ----------
    data : array-like
        Values to write. The input is flattened before writing.
    path : str
        Destination file name. In bit-packed mode (see ``compress_bits``)
        ``str(compress_bits)`` is appended to this name, e.g. ``"w.dat"``
        becomes ``"w.dat3"``.
    dtype : str
        ``struct`` format string used to pack one element (default ``"f"``).
    compress_bits : int
        Only consulted when ``dtype == "B"``. When it is smaller than 8, the
        low ``compress_bits`` bits of every element are concatenated most
        significant bit first; a final partial byte is zero-padded on the
        right.

    Raises
    ------
    struct.error
        If an element cannot be packed with ``dtype``.
    """
    packer = struct.Struct(dtype)
    # tolist() yields plain Python scalars. The bit accumulator below must be
    # an arbitrary-precision int: with NumPy scalars it can stay a fixed-width
    # uint8, where the left shift wraps and the wide byte mask is out of range.
    values = np.asarray(data).reshape(-1).tolist()

    if dtype == "B" and compress_bits < 8:
        low_mask = (1 << compress_bits) - 1
        packed = bytearray()
        acc = 0        # bits collected so far that have not been emitted
        acc_bits = 0   # number of valid bits currently held in acc
        for value in values:
            acc = (acc << compress_bits) | (value & low_mask)
            acc_bits += compress_bits
            if acc_bits >= 8:
                # Emit the 8 oldest bits and keep the remainder for later.
                acc_bits -= 8
                packed.append(acc >> acc_bits)
                acc &= (1 << acc_bits) - 1
        if acc_bits > 0:
            # Left-align the leftover bits inside the last byte.
            packed.append(acc << (8 - acc_bits))
        with open(path + str(compress_bits), 'wb') as f:
            f.write(packed)
    else:
        with open(path, 'wb') as f:
            f.write(b"".join(packer.pack(value) for value in values))

In [None]:
def load(path):
    """Read a file of back-to-back ``struct`` ``"f"`` values into a NumPy array.

    The file is consumed 4 bytes at a time and every chunk is unpacked as one
    value. The unpacked Python floats are returned as a 1-D array built with
    ``np.array``.

    Raises
    ------
    struct.error
        If the final chunk read from the file is shorter than 4 bytes.
    """
    unpack_one = struct.Struct("f").unpack_from
    values = []
    with open(path, 'rb') as f:
        # iter() with a sentinel keeps calling read(4) until it returns b"",
        # which is what read() yields at end of file.
        for chunk in iter(lambda: f.read(4), b""):
            values.append(unpack_one(chunk)[0])
    return np.array(values)

In [None]:
def sparsity(arr, idx_bits):
    """Encode ``arr`` as codebook indices plus relative positions.

    The array is flattened and scanned once. Every non-zero element emits its
    position in the codebook together with a jump equal to the number of zeros
    skipped since the previous emitted entry, plus one. When ``2 ** idx_bits``
    zeros occur in a row, a placeholder entry (code ``0``, jump
    ``2 ** idx_bits - 1``) is emitted and the zero counter restarts. A run of
    zeros at the very end that is shorter than ``2 ** idx_bits`` emits
    nothing.

    Parameters
    ----------
    arr : array-like
        Values to encode.
    idx_bits : int
        Exponent that sets the longest zero run allowed before a placeholder
        is inserted.

    Returns
    -------
    indices : numpy.ndarray
        1-D ``uint8`` array of jumps, one per emitted entry.
    (codebook, codes) : tuple
        ``codebook`` is the sorted array of unique values in ``arr``;
        ``codes`` is a 1-D ``uint8`` array of codebook positions, one per
        emitted entry.
    (nnz_cnt, ph_cnt) : tuple of int
        Number of non-zero entries and number of placeholders emitted.

    Notes
    -----
    Both outputs are stored as ``uint8``, so jumps and codebook positions
    above 255 are not representable.
    NOTE(review): a placeholder uses code 0, i.e. ``codebook[0]``, which is
    the smallest value and not necessarily zero - confirm that the decoder
    treats placeholders specially.
    """
    flat = np.asarray(arr).reshape(-1)
    codebook = np.unique(flat)
    max_jump = 2 ** idx_bits

    # The codebook is sorted and holds every value of `flat`, so one binary
    # search yields each element's exact position as a plain scalar. Keeping
    # the codes scalar means `res_data` is never a ragged mix of 1-element
    # arrays and scalar placeholders, and the result is always 1-D.
    codes = np.searchsorted(codebook, flat).tolist()
    zero_flags = (flat == 0).tolist()

    res_idx, res_data = [], []
    nnz_cnt, ph_cnt, ptr_cnt = 0, 0, 0
    for is_zero, code in zip(zero_flags, codes):
        if is_zero:
            ptr_cnt += 1
            if ptr_cnt == max_jump:
                # Zero run hit the limit: emit a placeholder and restart.
                res_data.append(0)
                res_idx.append(max_jump - 1)
                ph_cnt += 1
                ptr_cnt = 0
        else:
            res_data.append(code)
            res_idx.append(ptr_cnt + 1)
            nnz_cnt += 1
            ptr_cnt = 0

    return np.array(res_idx, dtype='uint8'), (codebook, np.array(res_data, dtype='uint8')), (nnz_cnt, ph_cnt)

In [None]:
# Collect the layers to export, passing the first entry of net.features in
# the exclude list. (Judging by the helper's name it returns the conv and
# fully connected layers - confirm in compress.utils.)
blocks = cutils.collect_conv_and_fc(
    net,
    exclude=[net.features[0]],
)

In [None]:
# weight_bits only enters the ratio printed below; index_bits is handed to
# sparsity() as the zero-run exponent.
weight_bits = 3
index_bits = 5
for blk in blocks:
    weight = blk.weight.data().asnumpy()
    sparse_indices, codebook_and_codes, _ = sparsity(weight, index_bits)
    codebook, sparse_data = codebook_and_codes

    # One codebook file plus two Huffman-encoded streams per layer.
    save(codebook, f'{params_dir}{blk.name}.weight.codebook.dat')
    bytes_codebook1, bytes_data1 = huffman.huffman_encode(sparse_data, f'{params_dir}{blk.name}.weight.data')
    bytes_codebook2, bytes_data2 = huffman.huffman_encode(sparse_indices, f'{params_dir}{blk.name}.weight.index')

    # Ratio of 8 * (sum of the four returned sizes, presumably byte counts)
    # to weight_bits bits per weight element.
    print(f'{blk.name}: {8*(bytes_data1 + bytes_codebook1 + bytes_data2 + bytes_codebook2)/(weight_bits * weight.size)}')

    # Layers without a bias have nothing more to write.
    if blk.bias is None:
        continue
    bias = blk.bias.data().asnumpy()
    save(bias, f'{params_dir}{blk.name}.bias.dat')