In [41]:
#|export
import torch
import struct
import numpy as np
from sympy import divisors
import copy
from fasterai.misc.bn_folding import *

In [42]:
# Seed torch's default and CUDA random generators with a fixed value so the
# random tensors and layer initialisations created below are repeatable.
seed = 1337
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [43]:
#|export
def serialize_fp32(file, tensor):
    '''
    Write one tensor to `file` as a flat run of float32 values.

    file:   binary file object open for writing (e.g. mode 'wb'); only
            `file.write` is used
    tensor: torch tensor of any shape; it is detached, moved to CPU, flattened
            in row-major order and converted to float32 before being written

    The bytes are in the machine's native byte order, i.e. the same bytes
    `struct.pack(f'{n}f', *values)` would produce.
    '''
    # reshape (rather than view) so non-contiguous tensors such as transposes are accepted
    d = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # dump the buffer in one go instead of unpacking every element into struct.pack
    file.write(d.tobytes())

def serialize_int8(file, tensor):
    '''
    Write one tensor to `file` as a flat run of int8 values (one byte each).

    file:   binary file object open for writing (e.g. mode 'wb'); only
            `file.write` is used
    tensor: torch tensor of any shape; it is detached, moved to CPU, flattened
            in row-major order and cast to int8 with numpy before being written
    '''
    # reshape (rather than view) so non-contiguous tensors such as transposes are accepted
    d = tensor.detach().cpu().reshape(-1).numpy().astype(np.int8)
    # dump the buffer in one go instead of unpacking every element into struct.pack
    file.write(d.tobytes())

In [44]:
# Round-trip tests for serialize_fp32 & serialize_int8 share this scratch file
fname = '.test.bin'

# Round-trip test for serialize_fp32
def test_serialize_fp32(fname):
    ''' Write a random (2,3) tensor to `fname`, read the bytes back and compare element-wise '''
    expected = torch.randn(2,3)
    with open(fname, 'wb') as sink:
        serialize_fp32(sink, expected)
    with open(fname, 'rb') as source:
        raw = source.read()
    # six native float32 values, restored to the original layout
    values = struct.unpack(f'6f', raw)
    restored = torch.tensor(values).reshape(2,3)
    assert torch.all(expected == restored)

test_serialize_fp32(fname)


# Round-trip test for serialize_int8
def test_serialize_int8(fname):
    ''' Write a fixed (2,3) int8 tensor to `fname`, read the bytes back and compare element-wise '''
    expected = torch.tensor([[0,1,2],[3,4,5]], dtype=torch.int8)
    with open(fname, 'wb') as sink:
        serialize_int8(sink, expected)
    with open(fname, 'rb') as source:
        raw = source.read()
    # six signed bytes, restored to the original layout
    values = struct.unpack(f'6b', raw)
    restored = torch.tensor(values).reshape(2,3)
    assert torch.all(expected == restored)

test_serialize_int8(fname)

In [45]:
#|export
def quantize_q80(w, group_size):
    '''
    Take a tensor and returns the Q8_0 quantized version
    i.e. symmetric quantization into int8, range [-127,127]

    w:          tensor whose number of elements is a multiple of group_size
    group_size: number of consecutive (flattened) elements sharing one scale

    Returns (int8val, scale, maxerr):
      int8val: int8 tensor of shape (numel // group_size, group_size)
      scale:   float32 tensor with one scale per group, float = quant * scale
      maxerr:  largest absolute dequantization error over all elements (float)

    Raises AssertionError if w.numel() is not a multiple of group_size.
    '''
    assert w.numel() % group_size == 0
    # convert to float32 and lay the values out one group per row
    w = w.float().reshape(-1, group_size)
    # find the max magnitude in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # an all-zero group has scale 0; divide by 1 there so 0/0 does not produce
    # NaN (casting NaN to int8 is undefined). The returned scale stays 0, which
    # still dequantizes the group to exact zeros.
    divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    # scale into range [-127, 127] and round to nearest integer
    int8val = torch.round(w / divisor[:, None]).to(torch.int8)
    # dequantize by rescaling
    fp32val = int8val.float() * scale[:, None]
    # find the max error across all groups
    maxerr = torch.abs(fp32val - w).max().item()
    return int8val, scale, maxerr

In [46]:
# Unit test for quantize_q80
def test_quantize_q80():
    ''' Check quantized values, per-group scales and max error on a tiny tensor '''
    w0 = torch.tensor([0,1,2,3], dtype=torch.float32)
    w1, scale, err = quantize_q80(w0, 2)
    assert (w1==torch.tensor([0,127,85,127], dtype=torch.int8).reshape(2,2)).all()
    # one scale per group: max(|group|) / 127
    assert torch.allclose(scale, torch.tensor([1/127, 3/127]))
    # worst element is 2.0 -> 85 -> 85*3/127, i.e. off by 1/127
    assert abs(err - 1/127) < 1e-5
    # the dequantized values stay within the reported max error
    dequant = w1.float() * scale[:,None]
    assert (torch.abs(dequant.view(-1) - w0) <= err + 1e-7).all()

test_quantize_q80()

In [65]:
#|export
def export_model(model, file_path="model.bin", nclasses=10):
    '''
    Export the model's fp32 parameters to a binary file.

    model:     torch module; it is first passed through BN_Folder().fold
    file_path: destination path, opened in 'wb' mode
    nclasses:  number of classes written into the header (default 10)

    The data inside the file follows this order:
    1. Header, 4 ints: nclasses, number of Conv2d, Linear and BatchNorm1d layers
    2. Layer configurations
       - Conv2d, 7 ints: kernel size, stride, padding, in channels,
         out channels, parameter offset, bias flag
       - Linear, 4 ints: in features, out features, parameter offset, bias flag
       - BatchNorm1d, 2 ints: number of features, parameter offset
    3. Parameters as float32: Conv2d layers, then Linear layers (each layer's
       parameters in `parameters()` order), then for every BatchNorm1d its
       weight, bias, running_mean and running_var

    Offsets are counted in float32 elements (offset 1 = 4 bytes) from the start
    of the parameter section. Only the first entry of kernel_size, stride and
    padding is written, and the offsets use kernel_size[0]**2 weights per
    channel pair, so square kernels are assumed.
    '''
    # batchnorm folding
    model = BN_Folder().fold(model)

    modules = list(model.modules())
    conv_layers = [layer for layer in modules if isinstance(layer, torch.nn.Conv2d)]
    # batchnorm1d layers to which batchnorm folding cannot be applied
    bn_layers = [layer for layer in modules if isinstance(layer, torch.nn.BatchNorm1d)]
    linear_layers = [layer for layer in modules if isinstance(layer, torch.nn.Linear)]

    # the context manager closes the file even if serialization fails midway
    with open(file_path, "wb") as f:
        # write model config
        f.write(struct.pack("4i", nclasses, len(conv_layers), len(linear_layers), len(bn_layers)))

        # write layers' config
        offset = 0 # counted in float32 elements (i.e. offset 1 = 4 bytes)
        for l in conv_layers:
            bias = 1 if l.bias is not None else 0
            f.write(struct.pack("7i", l.kernel_size[0], l.stride[0], l.padding[0],
                    l.in_channels, l.out_channels, offset, bias))
            # advance to the start of the next layer: weights, plus biases when present
            offset += l.out_channels * l.in_channels * l.kernel_size[0]**2 + bias * l.out_channels

        for l in linear_layers:
            bias = 1 if l.bias is not None else 0
            f.write(struct.pack("4i", l.in_features, l.out_features, offset, bias))
            offset += l.in_features * l.out_features + bias * l.out_features

        for l in bn_layers:
            f.write(struct.pack("2i", l.num_features, offset))
            # weight, bias, running_mean, running_var
            offset += 4 * l.num_features

        # write the weights and biases of the model
        for l in [*conv_layers, *linear_layers]:
            for p in l.parameters():
                serialize_fp32(f, p)

        for l in bn_layers:
            for p in [l.weight, l.bias, l.running_mean, l.running_var]:
                serialize_fp32(f, p)

    print(f"wrote {file_path}")

In [66]:
import os
import torch.nn as nn

class TestModel(nn.Module):
    '''
    Tiny fixture for the export tests: two 3x3 convolutions (the second
    without bias), a BatchNorm2d, a biased Linear(3, 2), a BatchNorm1d and a
    bias-free Linear(2, 1).
    '''
    def __init__(self):
        super().__init__()
        # the registration order below determines the order seen by model.modules()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)
        self.bn2d = nn.BatchNorm2d(num_features=1)
        self.fc1 = nn.Linear(in_features=3, out_features=2)
        self.bn1d = nn.BatchNorm1d(num_features=2)
        self.fc2 = nn.Linear(in_features=2, out_features=1, bias=False)

    def forward(self, x):
        ''' Apply the layers one after another in registration order '''
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn2d(out)
        out = self.fc1(out)
        out = self.bn1d(out)
        return self.fc2(out)

m = TestModel().eval()

# Expected file size in bytes (every value is 4 bytes wide):
#   header 4*4 + conv configs 2*7*4 + linear configs 2*4*4 + bn config 1*2*4
#   + conv params 2*(3*3+1)*4 + linear params (3*2+2+2*1)*4 + bn params 2*4*4 = 264
def test_export_model():
    ''' Export the fixture model and compare the file size with the hand-computed total '''
    out_path = '/tmp/.test.bin'
    export_model(m, out_path)
    written = os.path.getsize(out_path)
    assert written==264 # manually calc from export_model()

test_export_model()

wrote /tmp/.test.bin


In [67]:
def test_bias_not_none():
    ''' Check that a bias-free conv has a bias after fold_batchnorm(conv, bn) '''
    # NOTE(review): this test is defined but never called in this cell, and
    # fold_batchnorm is not defined in this notebook - presumably it is expected
    # from the fasterai star import; confirm it exists before enabling the test.
    conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
    bn = nn.BatchNorm2d(16)
    fold_batchnorm(conv, bn)
    assert conv.bias is not None, "Bias is None after folding."

In [68]:
#|export
def calculate_groupsize(dim, gs):
    '''
    Pick the group size to use for a dimension of `dim` elements.

    - dim smaller than gs: the whole dimension is one group, return dim
    - dim a multiple of gs: return gs unchanged
    - otherwise: return the divisor of dim closest to gs; on a tie the first
      divisor in the order returned by `divisors` wins
    '''
    # the requested group does not fit: use the dimension itself
    if dim < gs:
        return dim
    # the requested group size already tiles the dimension exactly
    if dim % gs == 0:
        return gs

    def distance_to_gs(candidate):
        return abs(candidate - gs)

    # fall back to the factor of "dim" nearest to the requested group size
    return min(list(divisors(dim)), key=distance_to_gs)

def export_modelq8(model, file_path="modelq8.bin", gs=64, nclasses=10):
    '''
    Export the model to a binary file with Q8_0-quantized Conv2d/Linear parameters.

    model:     torch module; its Conv2d, Linear, BatchNorm1d and BatchNorm2d
               layers are exported as found (no batchnorm folding is done here)
    file_path: destination path, opened in 'wb' mode
    gs:        requested quantization group size; the size actually used per
               tensor comes from calculate_groupsize
    nclasses:  number of classes written into the header (default 10)

    The data inside the file follows this order:
    1. Header, 5 ints: nclasses, number of Conv2d, Linear and BN layers, and
       the total number of Conv2d/Linear parameters
    2. Layer configurations
       - Conv2d, 9 ints: kernel size, stride, padding, in channels, out channels,
         quantized offset, scale offset, weight group size, bias group size
       - Linear, 6 ints: in features, out features, quantized offset,
         scale offset, weight group size, bias group size
       - BN, 2 ints: number of features, offset
       A bias group size of 0 means the layer has no bias.
    3. CNN and FC layers' quantized parameters (int8)
    4. CNN and FC layers' scaling factors (float32)
    5. BN layers' weight, bias, running_mean and running_var (float32)

    Quantized offsets count int8 elements; scale and BN offsets count float32
    elements. Only the first entry of kernel_size, stride and padding is used,
    so square kernels are assumed.
    '''
    modules = list(model.modules())
    conv_layers = [layer for layer in modules if isinstance(layer, torch.nn.Conv2d)]
    linear_layers = [layer for layer in modules if isinstance(layer, torch.nn.Linear)]
    bn_layers = [layer for layer in modules
                 if isinstance(layer, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d))]
    quantized_layers = [*conv_layers, *linear_layers]
    nparameters = sum(p.numel() for layer in quantized_layers for p in layer.parameters())

    # the context manager closes the file even if quantization or writing fails midway
    with open(file_path, "wb") as f:
        # write model config
        f.write(struct.pack("5i", nclasses, len(conv_layers), len(linear_layers),
                            len(bn_layers), nparameters))

        # write layers' config
        qoffset = 0 # offset for quantized parameters
        soffset = 0 # offset for scaling factors
        group_sizes = [] # one entry per quantized tensor, in parameters() order
        for l in conv_layers:
            has_bias = l.bias is not None
            # calculates group sizes for weights and biases
            gs_weight = calculate_groupsize(l.in_channels * l.kernel_size[0]**2, gs)
            gs_bias = calculate_groupsize(l.out_channels, gs) if has_bias else 0
            group_sizes.append(gs_weight)
            if has_bias:
                group_sizes.append(gs_bias)

            f.write(struct.pack("9i", l.kernel_size[0], l.stride[0], l.padding[0], l.in_channels,
                                l.out_channels, qoffset, soffset, gs_weight, gs_bias))
            # set offsets to the start of next layer
            nweights = l.out_channels * l.in_channels * l.kernel_size[0]**2
            qoffset += nweights
            soffset += nweights // gs_weight
            if has_bias:
                qoffset += l.out_channels
                soffset += l.out_channels // gs_bias

        for l in linear_layers:
            has_bias = l.bias is not None
            gs_weight = calculate_groupsize(l.in_features, gs)
            gs_bias = calculate_groupsize(l.out_features, gs) if has_bias else 0
            group_sizes.append(gs_weight)
            if has_bias:
                group_sizes.append(gs_bias)

            f.write(struct.pack("6i", l.in_features, l.out_features, qoffset, soffset,
                                gs_weight, gs_bias))

            nweights = l.in_features * l.out_features
            qoffset += nweights
            soffset += nweights // gs_weight
            if has_bias:
                qoffset += l.out_features
                soffset += l.out_features // gs_bias

        for l in bn_layers:
            f.write(struct.pack("2i", l.num_features, soffset))
            # weight, bias, running_mean, running_var
            soffset += 4 * l.num_features

        # write layers' parameters
        errors = []
        scaling_factors = []
        params = [p for l in quantized_layers for p in l.parameters()]
        for i, p in enumerate(params):
            q, s, err = quantize_q80(p, group_sizes[i])
            serialize_int8(f, q) # save the tensor in int8
            scaling_factors.append(s)
            errors.append(err)
            print(f"Quantized {tuple(p.shape)} to Q8_0 with max error {err}")

        for s in scaling_factors:
            serialize_fp32(f, s) # save scale factors

        for l in bn_layers:
            for p in [l.weight, l.bias, l.running_mean, l.running_var]:
                serialize_fp32(f, p)

    # print the highest error across all parameters, should be very small, e.g. O(~0.001)
    # (skipped when the model has no Conv2d/Linear parameters at all)
    if errors:
        print(f"max quantization group error across all weights: {max(errors)}")
    print(f"wrote {file_path}")

In [69]:
# Table-driven unit test for calculate_groupsize()
def test_calculate_groupsize():
    ''' Each row is (dim, requested group size, expected group size) '''
    cases = [
        (4, 8, 4),    # dim smaller than the group size
        (8, 4, 4),    # exact multiple
        (9, 4, 3),    # closest divisor
        (10, 5, 5),
        (5, 10, 5),
        (20, 3, 2),   # 2 and 4 are equally close
    ]
    for dim, gs, expected in cases:
        assert calculate_groupsize(dim, gs)==expected

test_calculate_groupsize()

def test_edge_cases():
    ''' Boundary inputs for calculate_groupsize: (dim, gs, expected, failure message) '''
    cases = (
        (1, 10, 1, "Should handle dimension of 1 correctly."),
        (100, 101, 100, "Should adjust group size to dimension if it's larger."),
        (13, 2, 1, "Should return 1 for prime dimensions."),
    )
    for dim, gs, expected, message in cases:
        assert calculate_groupsize(dim, gs) == expected, message

test_edge_cases()


# Unit test for export_modelq8()
# Expected file size in bytes for the fixture model with gs=10:
#   header 5*4 + conv configs 2*9*4 + linear configs 2*6*4 + bn configs 2*2*4
#   + int8 params (9+1 + 9 + 6+2 + 2) + scales 7*4 + bn params (1+2)*4*4 = 261
def test_export_modelq8():
    ''' Export the fixture model in Q8_0 and compare the file size with the hand-computed total '''
    fname = '/tmp/.test.bin'
    export_modelq8(m, fname, gs=10)
    assert os.path.getsize(fname)==261

test_export_modelq8()

Quantized (1, 1, 3, 3) to Q8_0 with max error 0.001264810562133789
Quantized (1,) to Q8_0 with max error 0.0
Quantized (1, 1, 3, 3) to Q8_0 with max error 0.0011606067419052124
Quantized (2, 3) to Q8_0 with max error 0.001620173454284668
Quantized (2,) to Q8_0 with max error 0.0009410381317138672
Quantized (1, 2) to Q8_0 with max error 0.0001624971628189087
max quantization group error across all weights: 0.001620173454284668
wrote /tmp/.test.bin


In [70]:
from train import load

# Export the resnet18 model twice: once as fp32 and once Q8_0-quantized with group size 64.
# NOTE(review): `load` is imported from the project's `train` module; presumably it
# returns an object whose `.model` attribute is the torch module - confirm.
model = load("resnet18").model
export_model(model, "resnet18.bin")
export_modelq8(model, "resnet18_q8.bin", gs=64)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

wrote resnet18.bin
Quantized (64, 3, 7, 7) to Q8_0 with max error 0.00399431586265564
Quantized (64, 64, 3, 3) to Q8_0 with max error 0.0030971625819802284
Quantized (64, 64, 3, 3) to Q8_0 with max error 0.0018653776496648788
Quantized (64, 64, 3, 3) to Q8_0 with max error 0.0025304292794317007
Quantized (64, 64, 3, 3) to Q8_0 with max error 0.0014979839324951172
Quantized (128, 64, 3, 3) to Q8_0 with max error 0.0012937188148498535
Quantized (128, 128, 3, 3) to Q8_0 with max error 0.0016725650057196617
Quantized (128, 64, 1, 1) to Q8_0 with max error 0.0030133600812405348
Quantized (128, 128, 3, 3) to Q8_0 with max error 0.001718759536743164
Quantized (128, 128, 3, 3) to Q8_0 with max error 0.0013971449807286263
Quantized (256, 128, 3, 3) to Q8_0 with max error 0.0015289076836779714
Quantized (256, 256, 3, 3) to Q8_0 with max error 0.0013069347478449345
Quantized (256, 128, 1, 1) to Q8_0 with max error 0.0010194750502705574
Quantized (256, 256, 3, 3) to Q8_0 with max error 0.001156233