# Network Compression (Weight Quantization)


# Read state_dict


In [None]:
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin

import os
import torch

# Report the on-disk size of the downloaded 32-bit checkpoint as the baseline
# that the compressed files below are compared against.
print(f"\noriginal cost: {os.stat('student_custom_small.bin').st_size} bytes.")
# map_location='cpu' lets the checkpoint load on machines without CUDA even
# when its tensors were saved from a GPU; the encoders below move every
# tensor to the CPU anyway, so nothing downstream depends on the device.
params = torch.load('student_custom_small.bin', map_location='cpu')

Downloading...
From: https://drive.google.com/uc?id=12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL
To: /content/student_custom_small.bin
  0% 0.00/1.05M [00:00<?, ?B/s] 50% 524k/1.05M [00:00<00:00, 3.60MB/s]100% 1.05M/1.05M [00:00<00:00, 4.90MB/s]

original cost: 1047430 bytes.


# 32-bit Tensor -> 16-bit 

In [None]:
import numpy as np
import pickle

def encode16(params, fname):
    '''Compress ``params`` to 16-bit floats and pickle the result to ``fname``.

    Args:
      params: the model's state_dict (maps each name to a torch tensor).
      fname: name of the file the compressed parameters are written to.
    '''

    custom_dict = {}
    for (name, param) in params.items():
        param = np.float64(param.cpu().numpy())
        # Some entries are not an ndarray after the conversion above but a
        # single number (of type numpy.float64); those are stored unchanged
        # because there is nothing to compress.
        # e.g. cnn.3.1.num_batches_tracked
        #      tensor(53148, device='cuda:0')
        if isinstance(param, np.ndarray):
            custom_dict[name] = np.float16(param)
        else:
            custom_dict[name] = param

    # A context manager guarantees the file is flushed and closed before the
    # caller inspects it (e.g. to measure its size).
    with open(fname, 'wb') as f:
        pickle.dump(custom_dict, f)


def decode16(fname):
    '''Read the parameters pickled in ``fname`` and return them as tensors.

    Each stored value is passed to ``torch.tensor`` as-is, so this function
    performs no explicit dtype conversion of the 16-bit arrays.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only call
    this on files you produced yourself or otherwise trust.

    Args:
      fname: name of the compressed file written by ``encode16``.

    Returns:
      A dict mapping each parameter name to a ``torch.Tensor``, in the same
      order as the stored entries.
    '''

    # Close the file deterministically instead of leaking the handle.
    with open(fname, 'rb') as f:
        params = pickle.load(f)

    return {name: torch.tensor(param) for (name, param) in params.items()}


# Write the 16-bit version of the loaded state_dict and report its file size.
encode16(params, '16_bit_model.pkl')
print(f"16-bit cost: {os.stat('16_bit_model.pkl').st_size} bytes.")

16-bit cost: 522958 bytes.


# 32-bit Tensor -> 8-bit 

In [None]:
def encode8(params, fname):
    '''Quantize ``params`` to 8-bit integers and pickle the result to ``fname``.

    Every array is min-max scaled onto the integer range 0..255 and stored as
    a ``(min_val, max_val, uint8_array)`` tuple so that ``decode8`` can map it
    back. Entries that are a single number rather than an ndarray are stored
    unchanged.

    Args:
      params: the model's state_dict (maps each name to a torch tensor).
      fname: name of the file the compressed parameters are written to.
    '''
    custom_dict = {}
    for (name, param) in params.items():
        param = np.float64(param.cpu().numpy())
        if isinstance(param, np.ndarray):
            min_val = np.min(param)
            max_val = np.max(param)
            value_range = max_val - min_val
            if value_range == 0:
                # All elements are equal (e.g. a constant or single-element
                # tensor): scaling would divide by zero and cast NaN to
                # uint8. Any code decodes back to min_val, so store zeros.
                quantized = np.zeros(param.shape, dtype=np.uint8)
            else:
                quantized = np.uint8(
                    np.round((param - min_val) / value_range * 255))
            custom_dict[name] = (min_val, max_val, quantized)
        else:
            custom_dict[name] = param

    # A context manager guarantees the file is flushed and closed before the
    # caller inspects it (e.g. to measure its size).
    with open(fname, 'wb') as f:
        pickle.dump(custom_dict, f)


def decode8(fname):
    '''Read the 8-bit parameters pickled in ``fname`` and return tensors.

    Entries stored as a ``(min_val, max_val, uint8_array)`` tuple are mapped
    from the 0..255 codes back onto the ``[min_val, max_val]`` interval; any
    other entry is converted to a tensor as-is.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only call
    this on files you produced yourself or otherwise trust.

    Args:
      fname: name of the compressed file written by ``encode8``.

    Returns:
      A dict mapping each parameter name to a ``torch.Tensor``.
    '''
    # Close the file deterministically instead of leaking the handle.
    with open(fname, 'rb') as f:
        params = pickle.load(f)

    custom_dict = {}
    for (name, param) in params.items():
        if isinstance(param, tuple):
            min_val, max_val, codes = param
            # Undo the min-max scaling in float64: code 0 -> min_val,
            # code 255 -> max_val.
            param = (np.float64(codes) / 255 * (max_val - min_val)) + min_val
        custom_dict[name] = torch.tensor(param)

    return custom_dict

# Write the 8-bit version of the loaded state_dict and report its file size.
encode8(params, '8_bit_model.pkl')
print(f"8-bit cost: {os.stat('8_bit_model.pkl').st_size} bytes.")

8-bit cost: 268471 bytes.
