In [171]:
#|export
import torch
import struct
import numpy as np
from sympy import divisors

In [172]:
import random
seed = 1337
# Seed every generator this notebook has in scope so reruns are reproducible.
# `random` was imported but never seeded, leaving the stdlib generator unseeded.
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [173]:
#|export
def serialize_fp32(file, tensor):
    ''' Write one fp32 tensor to file that is open in wb mode

    The tensor is detached, moved to the CPU, flattened and converted to
    float32; its values are then written back to back in the machine's
    native byte order.
    '''
    # reshape (unlike view) also flattens non-contiguous tensors such as
    # transposes or slices, copying only when it has to
    d = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # tobytes yields the same native-order bytes as struct.pack(f'{len(d)}f', *d)
    # without expanding every element into a separate Python argument
    file.write(d.tobytes())

def serialize_int8(file, tensor):
    ''' Write one int8 tensor to file that is open in wb mode

    The tensor is detached, moved to the CPU, flattened and cast to int8;
    its values are then written one byte each.
    '''
    # reshape (unlike view) also flattens non-contiguous tensors such as
    # transposes or slices, copying only when it has to
    d = tensor.detach().cpu().reshape(-1).numpy().astype(np.int8)
    # tobytes yields the same bytes as struct.pack(f'{len(d)}b', *d)
    # without expanding every element into a separate Python argument
    file.write(d.tobytes())

In [174]:
# Scratch file shared by the serialize_fp32 & serialize_int8 unit tests
fname = '.test.bin'

# Round-trip unit test for serialize_fp32
def test_serialize_fp32(fname):
    ''' Write a random 2x3 tensor with serialize_fp32 and read it back unchanged '''
    expected = torch.randn(2,3)
    with open(fname, 'wb') as sink:
        serialize_fp32(sink, expected)
    with open(fname, 'rb') as source:
        payload = source.read()
    # the file must decode as exactly six native-order floats
    values = struct.unpack('6f', payload)
    restored = torch.tensor(values).reshape(2,3)
    assert (expected==restored).all()

test_serialize_fp32(fname)


# Round-trip unit test for serialize_int8
def test_serialize_int8(fname):
    ''' Write a fixed 2x3 int8 tensor with serialize_int8 and read it back unchanged '''
    expected = torch.tensor([[0,1,2],[3,4,5]], dtype=torch.int8)
    with open(fname, 'wb') as sink:
        serialize_int8(sink, expected)
    with open(fname, 'rb') as source:
        payload = source.read()
    # the file must decode as exactly six signed bytes
    values = struct.unpack('6b', payload)
    restored = torch.tensor(values).reshape(2,3)
    assert (expected==restored).all()

test_serialize_int8(fname)

In [175]:
#|export
def quantize_q80(w, group_size):
    '''
    Take a tensor and returns the Q8_0 quantized version
    i.e. symmetric quantization into int8, range [-127,127]

    w is converted to float32, flattened and split into consecutive groups of
    group_size values; each group gets its own scaling factor such that
    float ~= quant * scale.

    Returns (int8val, scale, maxerr):
      int8val - int8 tensor of shape (w.numel() // group_size, group_size)
      scale   - float32 tensor with one scaling factor per group
      maxerr  - largest absolute quantize/dequantize error over all values (float)

    Raises AssertionError if w.numel() is not a multiple of group_size.
    '''
    assert w.numel() % group_size == 0
    w = w.float() # convert to float32
    w = w.reshape(-1, group_size)
    # find the max in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # an all-zero group has scale 0: dividing by it would give 0/0 = NaN, and
    # casting NaN to int8 is undefined. Divide by 1 there so the group
    # quantizes to exact zeros; the returned scale stays 0.
    divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    # scale into range [-127, 127]
    quant = w / divisor[:,None]
    # round to nearest integer
    int8val = torch.round(quant).to(torch.int8)
    # dequantize by rescaling (same (groups, group_size) layout as w)
    fp32val = int8val.float() * scale[:,None]
    # calculate the max error in each group
    err = torch.abs(fp32val - w).max(dim=1).values
    # find the max error across all groups
    maxerr = err.max().item()
    return int8val, scale, maxerr

In [176]:
# Unit test for quantize_q80
def test_quantize_q80():
    ''' Check the int8 values, the per-group scales and the max error for [0,1,2,3] in groups of 2 '''
    w0 = torch.tensor([0,1,2,3], dtype=torch.float32)
    w1, scale, err = quantize_q80(w0, 2)
    assert (w1==torch.tensor([0,127,85,127], dtype=torch.int8).reshape(2,2)).all()
    # one scale per group: max(|group|) / 127
    assert torch.allclose(scale, torch.tensor([1/127, 3/127]))
    # the worst value is 2.0: it is stored as 85 and restored as 85*3/127
    assert abs(err - (85*3/127 - 2)) < 1e-5

test_quantize_q80()

In [177]:
#|export
def export_model(model, file_path="model.bin", model_path="model.pt"):
    '''
    Export the model to a file
    The data inside the file follows this order:
    1. The number of classes, CNN layers, and FC layers.
    2. CNN and FC layers' configuration.
    3. CNN and FC layers' parameters.

    model: object exposing nclasses, conv1, conv2, fc1 and fc2
    file_path: destination of the binary export
    model_path: where the whole model is also saved with torch.save, for
                loading in python; pass None to skip that copy
    '''
    conv_layers = [model.conv1, model.conv2]
    nconv = len(conv_layers)
    linear_layers = [model.fc1, model.fc2]
    nlinear = len(linear_layers)

    # the context manager closes the file even if a write raises midway
    with open(file_path, "wb") as f:
        # write model config
        f.write(struct.pack("iii", model.nclasses, nconv, nlinear))

        # write layers' config
        # offset is counted in float32 values (i.e. offset 1 = 4 bytes) and
        # starts at 0 for the first layer
        offset = 0
        for layer in conv_layers:
            f.write(struct.pack("6i", layer.kernel_size[0], layer.stride[0], layer.padding[0],
                    layer.in_channels, layer.out_channels, offset))
            # set offset to the start of next layer: weights + biases
            # (only kernel_size[0] is used, i.e. kernels are treated as square)
            offset += layer.out_channels * layer.in_channels * layer.kernel_size[0]**2 + layer.out_channels

        for layer in linear_layers:
            f.write(struct.pack("3i", layer.in_features, layer.out_features, offset))
            offset += layer.in_features * layer.out_features + layer.out_features

        # write the weights and biases of the model, conv layers first
        for layer in [*conv_layers, *linear_layers]:
            for p in layer.parameters():
                serialize_fp32(f, p)

    print(f"wrote {file_path}")

    if model_path is not None:
        torch.save(model, model_path) # for loading in python

In [18]:
# Unit test for export_model()
import os
import torch.nn as nn

class TestModel(nn.Module):
    ''' Tiny model exposing nclasses, conv1, conv2, fc1 and fc2 for the export tests '''
    def __init__(self):
        super().__init__()
        self.nclasses = 4
        # two single-channel 3x3 convolutions with padding 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
        # two small fully connected layers: 3 -> 2 -> 1
        self.fc1 = nn.Linear(in_features=3, out_features=2)
        self.fc2 = nn.Linear(in_features=2, out_features=1)

m = TestModel()

def test_export_model():
    ''' Export the module-level test model and check the size of the written file '''
    out_path = '/tmp/.test.bin'
    export_model(m, out_path)
    # calculated manually from export_model(): 3 header ints, 2 conv configs of
    # 6 ints, 2 fc configs of 3 ints and 31 fp32 parameters, 4 bytes each
    expected_bytes = 4 * (3 + 2*6 + 2*3 + 31)
    assert os.path.getsize(out_path)==expected_bytes

test_export_model()

wrote /tmp/.test.bin


In [193]:
#|export
def calculate_groupsize(dim, gs):
    '''
    Pick the group size to use for a dimension of length dim.
    Returns dim when dim is smaller than gs, gs when gs divides dim evenly,
    and otherwise the divisor of dim that is closest to gs.
    '''
    if dim < gs:
        return dim
    if dim % gs == 0:
        return gs
    # gs does not divide dim: fall back to the nearest divisor of dim
    # (min keeps the first of equally close candidates)
    candidates = list(divisors(dim))
    return min(candidates, key=lambda d: abs(d - gs))

def export_modelq8(model_path="model.pt", file_path="modelq8.bin", gs=64):
    '''
    Export the quantized model to a file
    The data inside the file follows this order:
    1. The number of classes, CNN layers, FC layers, and parameters.
    2. CNN and FC layers' configuration.
    3. CNN and FC layers' quantized parameters.
    4. CNN and FC layers' scaling factors.

    model_path: file holding a whole model saved with torch.save; the loaded
                object must expose nclasses, conv1, conv2, fc1 and fc2
    file_path: destination of the quantized binary
    gs: requested group size; adapted per tensor by calculate_groupsize
    '''
    # NOTE(review): torch.load unpickles the file, so only load model files you
    # trust. Newer PyTorch releases default to weights_only=True and may need
    # weights_only=False to load a whole module - confirm for the installed version.
    model = torch.load(model_path)
    conv_layers = [model.conv1, model.conv2]
    nconv = len(conv_layers)
    linear_layers = [model.fc1, model.fc2]
    nlinear = len(linear_layers)
    nparameters = sum(p.numel() for p in model.parameters())

    # the context manager closes the file even if quantization or a write raises
    with open(file_path, "wb") as f:
        # write model config
        f.write(struct.pack("4i", model.nclasses, nconv, nlinear, nparameters))

        # write layers' config
        qoffset = 0 # offset for quantized parameters
        soffset = 0 # offset for scaling factors
        group_sizes = [] # group sizes in parameter order: weight, bias per layer
        for layer in conv_layers:
            # calculates group sizes for weights and biases
            # (only kernel_size[0] is used, i.e. kernels are treated as square)
            filter_size = layer.in_channels * layer.kernel_size[0]**2
            gs_weight = calculate_groupsize(filter_size, gs)
            gs_bias = calculate_groupsize(layer.out_channels, gs)
            group_sizes += [gs_weight, gs_bias]

            f.write(struct.pack("9i", layer.kernel_size[0], layer.stride[0], layer.padding[0],
                    layer.in_channels, layer.out_channels, qoffset, soffset, gs_weight, gs_bias))
            # set offsets to the start of next layer
            nweights = layer.out_channels * filter_size
            qoffset += nweights + layer.out_channels
            soffset += nweights // gs_weight + layer.out_channels // gs_bias

        for layer in linear_layers:
            gs_weight = calculate_groupsize(layer.in_features, gs)
            gs_bias = calculate_groupsize(layer.out_features, gs)
            group_sizes += [gs_weight, gs_bias]

            f.write(struct.pack("6i", layer.in_features, layer.out_features, qoffset, soffset,
                                gs_weight, gs_bias))

            nweights = layer.in_features * layer.out_features
            qoffset += nweights + layer.out_features
            soffset += nweights // gs_weight + layer.out_features // gs_bias

        # quantize and write every parameter; the scales are kept back because
        # they are stored after all quantized values
        # (pairing with group_sizes assumes each layer has a weight and a bias)
        params = [p for layer in [*conv_layers, *linear_layers] for p in layer.parameters()]
        errors = []
        scaling_factors = []
        for p, group_size in zip(params, group_sizes):
            q, s, err = quantize_q80(p, group_size)
            serialize_int8(f, q) # save the tensor in int8
            scaling_factors.append(s)
            errors.append(err)
            print(f"Quantized {tuple(p.shape)} to Q8_0 with max error {err}")

        for s in scaling_factors:
            serialize_fp32(f, s) # save scale factors

    # print the highest error across all parameters, should be very small, e.g. O(~0.001)
    print(f"max quantization group error across all weights: {max(errors)}")
    print(f"wrote {file_path}")

In [23]:
# Unit test for calculate_groupsize()
def test_calculate_groupsize():
    ''' Cover the three branches of calculate_groupsize '''
    # (dim, requested group size) -> expected group size
    cases = {
        (4, 8): 4, # dim is smaller than the group size
        (8, 4): 4, # the group size divides dim
        (9, 4): 3, # 3 is the divisor of 9 closest to 4
    }
    for (dim, gs), expected in cases.items():
        assert calculate_groupsize(dim, gs)==expected

test_calculate_groupsize()

# Unit test for export_modelq8()
def test_export_modelq8():
    ''' Quantize the exported test model with group size 1 and check the file size '''
    out_path = '/tmp/.test.bin'
    # export_modelq8 loads its default "model.pt", which test_export_model writes
    test_export_model()
    export_modelq8(file_path=out_path, gs=1)
    # 4 header ints, 2 conv configs of 9 ints and 2 fc configs of 6 ints,
    # then 31 int8 parameters and (with gs=1) one fp32 scale per parameter
    expected_bytes = 4*(4 + 2*9 + 2*6) + 31 + 4*31
    assert os.path.getsize(out_path)==expected_bytes

test_export_modelq8()

wrote /tmp/.test.bin
Quantized (1, 1, 3, 3) to Q8_0 with max error 3.725290298461914e-09
Quantized (1,) to Q8_0 with max error 0.0
Quantized (1, 1, 3, 3) to Q8_0 with max error 1.4901161193847656e-08
Quantized (1,) to Q8_0 with max error 0.0
Quantized (2, 3) to Q8_0 with max error 0.0
Quantized (2,) to Q8_0 with max error 0.0
Quantized (1, 2) to Q8_0 with max error 0.0
Quantized (1,) to Q8_0 with max error 0.0
max quantization group error across all weights: 1.4901161193847656e-08
wrote /tmp/.test.bin
