In [2]:
#|export
import torch
import struct
import numpy as np
from sympy import divisors
from train import load

In [3]:
# Fix the RNG state so every run of the notebook draws the same random tensors
seed = 1337
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

In [4]:
#|export
def serialize_fp32(file, tensor):
    '''
    Write one tensor to ``file`` (open in 'wb' mode) as raw float32 values.

    The tensor is flattened in row-major (logical) order and converted to
    float32. Bytes use the machine's native byte order, which is the same layout
    ``struct.pack(f'{n}f', ...)`` produces.
    '''
    # reshape (unlike view) also works for non-contiguous tensors, e.g. transposes
    d = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # write the buffer directly instead of unpacking every element into a
    # Python float for struct.pack, which is very slow for large layers
    file.write(d.tobytes())

def serialize_int8(file, tensor):
    '''
    Write one tensor to ``file`` (open in 'wb' mode) as raw int8 values.

    The tensor is flattened in row-major (logical) order. Values are cast with
    numpy's ``astype(np.int8)``, so integers outside [-128, 127] wrap around.
    '''
    # reshape (unlike view) also works for non-contiguous tensors
    d = tensor.detach().cpu().reshape(-1).numpy().astype(np.int8)
    # identical bytes to struct.pack(f'{n}b', *d), without per-element unpacking
    file.write(d.tobytes())

In [5]:
# Unit tests for serialize_fp32 & serialize_int8
fname = '.test.bin'

# Unit test for serialize_fp32
def test_serialize_fp32(fname):
    '''Round-trip a random 2x3 fp32 tensor through serialize_fp32.'''
    original = torch.randn(2, 3)
    with open(fname, 'wb') as out:
        serialize_fp32(out, original)
    with open(fname, 'rb') as src:
        raw = src.read()
    restored = torch.tensor(struct.unpack('6f', raw)).reshape(2, 3)
    assert (restored == original).all()

test_serialize_fp32(fname)


# Unit test for serialize_int8
def test_serialize_int8(fname):
    '''Round-trip a fixed 2x3 int8 tensor through serialize_int8.'''
    expected = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int8)
    with open(fname, 'wb') as out:
        serialize_int8(out, expected)
    with open(fname, 'rb') as src:
        raw = src.read()
    decoded = torch.tensor(struct.unpack('6b', raw)).reshape(2, 3)
    assert (decoded == expected).all()

test_serialize_int8(fname)

In [6]:
#|export
def quantize_q80(w, group_size):
    '''
    Take a tensor and return its Q8_0 quantized version, i.e. symmetric
    quantization into int8 with range [-127, 127] and one scale per group.

    Args:
        w: tensor whose number of elements is a multiple of ``group_size``.
           It is flattened into groups of ``group_size`` consecutive elements.
        group_size: number of elements that share one scaling factor.

    Returns:
        int8val: int8 tensor of shape (numel // group_size, group_size).
        scale: float32 tensor of shape (numel // group_size,) such that
               float ~= int8val * scale[:, None]. An all-zero group has scale 0.
        maxerr: largest absolute dequantization error over all elements (float).
    '''
    assert w.numel() % group_size == 0
    w = w.float().reshape(-1, group_size)  # convert to float32 and group
    # find the max magnitude in each group
    wmax = torch.abs(w).max(dim=1).values
    # calculate the scaling factor such that float = quant * scale
    scale = wmax / 127.0
    # An all-zero group has scale 0. Dividing by it gives NaN, and casting NaN
    # to int8 is undefined. Divide by 1 instead; every element is 0, so the
    # quantized values are 0 either way. The returned scale stays 0.
    safe_scale = torch.where(scale != 0, scale, torch.ones_like(scale))
    # scale into range [-127, 127] and round to nearest integer
    int8val = torch.round(w / safe_scale[:, None]).to(torch.int8)
    # dequantize by rescaling
    fp32val = int8val.float() * scale[:, None]
    # max absolute error across all groups
    maxerr = torch.abs(fp32val - w).max().item()
    return int8val, scale, maxerr

In [7]:
# Unit test for quantize_q80
def test_quantize_q80():
    '''Check int8 values, per-group scales and the max error on a tiny input.'''
    w0 = torch.tensor([0,1,2,3], dtype=torch.float32)
    w1, scale, err = quantize_q80(w0, 2)
    # int8 values, one row per group of 2 elements
    assert (w1==torch.tensor([0,127,85,127], dtype=torch.int8).reshape(2,2)).all()
    # one scale per group: max |w| of the group / 127
    assert torch.allclose(scale, torch.tensor([1/127, 3/127]))
    # the worst element is 2 in the second group: |85 * 3/127 - 2| = 1/127
    assert abs(err - 1/127) < 1e-6

test_quantize_q80()

In [8]:
#|export
def fold_batchnorm(cnn, bn):
    '''
    Fold batchnorm layer ``bn`` into convolutional layer ``cnn`` (in place).

    The folding uses bn's running statistics. Afterwards ``cnn(x)`` matches
    ``bn(cnn_before(x))`` with ``bn`` in eval mode:
        inv_std = 1 / sqrt(running_var + eps)
        W' = W * gamma * inv_std                      (per output channel)
        b' = gamma * (b - running_mean) * inv_std + beta
    If ``cnn`` has no bias, a zero bias is created first, on the same device
    and with the same dtype as its weight. A non-affine ``bn`` (no weight/bias)
    is treated as gamma = 1, beta = 0. ``bn`` itself is not modified.
    '''
    inv_std = 1 / torch.sqrt(bn.running_var + bn.eps)
    gamma = bn.weight.data if bn.weight is not None else torch.ones_like(inv_std)
    beta = bn.bias.data if bn.bias is not None else torch.zeros_like(inv_std)

    # scale each output channel's filter
    cnn.weight.data = cnn.weight.data * (gamma * inv_std).view(-1, 1, 1, 1)

    if cnn.bias is None:
        # initialize bias as 0 if it doesn't exist. Match the weight's
        # device/dtype so the folded-bias arithmetic below does not mix
        # CPU and CUDA (or float32 and float64) tensors.
        cnn.bias = torch.nn.Parameter(torch.zeros(cnn.weight.shape[0],
                                                  device=cnn.weight.device,
                                                  dtype=cnn.weight.dtype))

    cnn.bias.data = gamma * (cnn.bias.data - bn.running_mean) * inv_std + beta
    

def export_model(model_name="resnet18", file_path="model.bin", nclasses=10):
    '''
    Export the fp32 (non-quantized) model to a binary file.

    Each BatchNorm2d layer is first folded into a convolutional layer with
    ``fold_batchnorm``. This mutates the freshly loaded model in place.
    The data inside the file follows this order:
    1. Header (4 x int32): number of classes, conv layers, linear layers and
       BatchNorm1d layers
    2. CNN, FC and BN1d layers' configuration (int32 fields; offsets count
       float32 elements, i.e. offset 1 = 4 bytes)
    3. CNN, FC and BN1d layers' parameters as float32

    Args:
        model_name: name passed to ``load`` to obtain the model.
        file_path: output binary file.
        nclasses: value written as the number of classes in the header.
    '''
    model = load(model_name)
    modules = list(model.modules())
    conv_layers = [m for m in modules if isinstance(m, torch.nn.Conv2d)]
    bn_layers = [m for m in modules if isinstance(m, torch.nn.BatchNorm2d)]
    bn1_layers = [m for m in modules if isinstance(m, torch.nn.BatchNorm1d)]
    linear_layers = [m for m in modules if isinstance(m, torch.nn.Linear)]

    # NOTE(review): conv and BatchNorm2d layers are paired purely by their
    # order in model.modules(), and zip() stops at the shorter list. This
    # assumes the k-th BatchNorm2d normalizes the k-th conv's output; confirm per model.
    for conv_layer, bn_layer in zip(conv_layers, bn_layers):
        fold_batchnorm(conv_layer, bn_layer)

    # `with` guarantees the file is closed even if serialization fails
    with open(file_path, "wb") as f:
        # 1. header
        f.write(struct.pack("4i", nclasses, len(conv_layers), len(linear_layers),
                            len(bn1_layers)))

        # 2. layers' config; offset = start of the layer's parameters,
        #    counted in float32 elements (offset 1 = 4 bytes)
        offset = 0
        for layer in conv_layers:
            # checked after folding: convs that were folded always have a bias
            bias = 1 if layer.bias is not None else 0
            f.write(struct.pack("7i", layer.kernel_size[0], layer.stride[0], layer.padding[0],
                                layer.in_channels, layer.out_channels, offset, bias))
            nweights = layer.out_channels * layer.in_channels * layer.kernel_size[0]**2
            offset += nweights + bias * layer.out_channels

        for layer in linear_layers:
            bias = 1 if layer.bias is not None else 0
            f.write(struct.pack("4i", layer.in_features, layer.out_features, offset, bias))
            offset += layer.in_features * layer.out_features + bias * layer.out_features

        for layer in bn1_layers:
            f.write(struct.pack("2i", layer.num_features, offset))
            # weight, bias, running_mean, running_var
            offset += 4 * layer.num_features

        # 3. parameters, in the same order as the config above
        for layer in [*conv_layers, *linear_layers]:
            for p in layer.parameters():  # weight, then bias (if any)
                serialize_fp32(f, p)

        for layer in bn1_layers:
            for p in layer.parameters():  # weight, bias
                serialize_fp32(f, p)
            serialize_fp32(f, layer.running_mean)
            serialize_fp32(f, layer.running_var)

    print(f"wrote {file_path}")

In [73]:
import os
import torch.nn as nn

class TestModel(nn.Module):
    '''Tiny layer container (no forward) used to check export file layouts.'''
    def __init__(self):
        super().__init__()
        self.nclasses = 4
        # two bias-free single-channel 3x3 convs, then one BatchNorm2d
        self.conv1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(1)
        # 3 -> 2 with bias, then 2 -> 1 without bias
        self.fc1 = nn.Linear(3, 2)
        self.fc2 = nn.Linear(2, 1, bias=False)

# make sure the target directory exists so torch.save does not fail on a fresh checkout
os.makedirs("models", exist_ok=True)
m = TestModel()
torch.save(m, "models/test.pt")

def test_export_model():
    '''Export the tiny TestModel and check the file size of the fp32 layout.'''
    fname = '/tmp/.test.bin'
    mname = "test"
    export_model(mname, fname)
    # manually calculated from export_model():
    # header 4*4 + conv config 2*7*4 + linear config 2*4*4
    # + fp32 params (conv1 9 + 1 folded bias, conv2 9, fc1 6 + 2, fc2 2) * 4
    # = 16 + 56 + 32 + 116 = 220
    assert os.path.getsize(fname)==220

test_export_model()

wrote /tmp/.test.bin


In [74]:
def test_bias_not_none():
    '''Folding into a bias-free conv must create a bias.'''
    conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
    bn = nn.BatchNorm2d(16)
    fold_batchnorm(conv, bn)
    assert conv.bias is not None, "Bias is None after folding."

# Unit test for batchnorm folding
def test_folding():
    '''A folded conv must reproduce bn(conv(x)) once bn is in eval mode.'''
    x = torch.randn(1, 3, 32, 32)
    conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=True)
    bn = nn.BatchNorm2d(16)
    conv_folded = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=True)
    conv_folded.load_state_dict(conv.state_dict())
    with torch.no_grad():
        # one training-mode pass updates bn's running mean and var
        bn(conv(x))
        # freeze bn so it uses those running statistics, then fold
        bn.eval()
        fold_batchnorm(conv_folded, bn)
        expected = bn(conv(x))
        actual = conv_folded(x)
    assert torch.allclose(expected, actual, atol=1e-5), "Output mismatch after folding."

test_bias_not_none()
test_folding()


In [75]:
#|export
def calculate_groupsize(dim, gs):
    '''
    Pick the quantization group size to use for a dimension of size ``dim``.

    - ``dim`` smaller than ``gs``: use ``dim`` itself (a single group).
    - ``dim`` a multiple of ``gs``: keep ``gs``.
    - otherwise: the divisor of ``dim`` closest to ``gs``. On a tie, the
      smaller divisor wins.
    '''
    if dim < gs:
        return dim
    if dim % gs == 0:
        return gs
    # every positive divisor of dim, in ascending order
    candidates = [d for d in range(1, dim + 1) if dim % d == 0]
    # min() keeps the first (i.e. smallest) candidate among equal distances
    return min(candidates, key=lambda d: abs(d - gs))

def export_modelq8(model_path="model.pt", file_path="modelq8.bin", gs=64, nclasses=10):
    '''
    Export the Q8_0 quantized model to a file.
    The data inside the file follows this order:
    1. Header (5 x int32): number of classes, conv / linear / BatchNorm2d
       layers, and the total number of conv + linear parameters
    2. CNN, FC and BN layers' configuration (offsets and group sizes)
    3. CNN and FC layers' quantized parameters (int8)
    4. CNN and FC layers' scaling factors (fp32)
    5. BN layers' parameters (fp32: weight, bias, running_mean, running_var)

    Args:
        model_path: path of a full pickled model readable by ``torch.load``.
        file_path: output binary file.
        gs: target quantization group size, adjusted per tensor by
            ``calculate_groupsize``.
        nclasses: value written as the number of classes in the header.
    '''
    # NOTE(review): torch.load unpickles a full model object, which can run
    # arbitrary code. Only load trusted files.
    model = torch.load(model_path)
    modules = list(model.modules())
    conv_layers = [m for m in modules if isinstance(m, torch.nn.Conv2d)]
    linear_layers = [m for m in modules if isinstance(m, torch.nn.Linear)]
    bn_layers = [m for m in modules if isinstance(m, torch.nn.BatchNorm2d)]
    quant_layers = [*conv_layers, *linear_layers]
    nparameters = sum(p.numel() for layer in quant_layers for p in layer.parameters())

    # `with` guarantees the file is closed even if quantization fails
    with open(file_path, "wb") as f:
        # 1. header
        f.write(struct.pack("5i", nclasses, len(conv_layers), len(linear_layers),
                            len(bn_layers), nparameters))

        # 2. layers' config
        qoffset = 0  # start of the layer's quantized values, in int8 elements
        soffset = 0  # start of the layer's scales / BN values, in float32 elements
        group_sizes = []  # one group size per parameter tensor, in parameters() order
        for layer in conv_layers:
            # group along one output channel's fan-in (in_channels * kernel area)
            # so that no group spans two output channels
            gs_weight = calculate_groupsize(layer.in_channels * layer.kernel_size[0]**2, gs)
            # conv layers expose out_channels (Conv2d has no out_features)
            gs_bias = calculate_groupsize(layer.out_channels, gs) if layer.bias is not None else 0
            f.write(struct.pack("9i", layer.kernel_size[0], layer.stride[0], layer.padding[0],
                                layer.in_channels, layer.out_channels, qoffset, soffset,
                                gs_weight, gs_bias))
            nweights = layer.out_channels * layer.in_channels * layer.kernel_size[0]**2
            group_sizes.append(gs_weight)
            qoffset += nweights
            soffset += nweights // gs_weight
            if layer.bias is not None:
                # the bias is quantized too: it needs its own group size and must
                # be counted in the offsets of the following layers
                group_sizes.append(gs_bias)
                qoffset += layer.out_channels
                soffset += layer.out_channels // gs_bias

        for layer in linear_layers:
            gs_weight = calculate_groupsize(layer.in_features, gs)
            gs_bias = calculate_groupsize(layer.out_features, gs) if layer.bias is not None else 0
            f.write(struct.pack("6i", layer.in_features, layer.out_features, qoffset, soffset,
                                gs_weight, gs_bias))
            nweights = layer.in_features * layer.out_features
            group_sizes.append(gs_weight)
            qoffset += nweights
            soffset += nweights // gs_weight
            if layer.bias is not None:
                group_sizes.append(gs_bias)
                qoffset += layer.out_features
                soffset += layer.out_features // gs_bias

        for layer in bn_layers:
            f.write(struct.pack("2i", layer.num_features, soffset))
            # weight, bias, running_mean, running_var (fp32, stored after the scales)
            soffset += 4 * layer.num_features

        # 3. quantized parameters (int8); scales are collected for section 4
        params = [p for layer in quant_layers for p in layer.parameters()]
        # every parameter tensor must have exactly one group size from the config loops
        assert len(params) == len(group_sizes)
        scaling_factors = []
        max_err = 0.0
        for p, group_size in zip(params, group_sizes):
            q, s, err = quantize_q80(p, group_size)
            serialize_int8(f, q)  # save the tensor in int8
            scaling_factors.append(s)
            max_err = max(max_err, err)
            print(f"Quantized {tuple(p.shape)} to Q8_0 with max error {err}")

        # 4. scaling factors (fp32)
        for s in scaling_factors:
            serialize_fp32(f, s)

        # 5. BN parameters (fp32)
        for layer in bn_layers:
            for p in [layer.weight, layer.bias, layer.running_mean, layer.running_var]:
                serialize_fp32(f, p)

        # the highest error across all parameters should be very small, e.g. O(~0.001)
        print(f"max quantization group error across all weights: {max_err}")
    print(f"wrote {file_path}")

In [155]:
# Unit tests for calculate_groupsize()
def test_calculate_groupsize():
    '''Table of (dim, gs) -> expected group size.'''
    cases = [
        ((4, 8), 4),
        ((8, 4), 4),
        ((9, 4), 3),
        ((10, 5), 5),
        ((5, 10), 5),
        ((20, 3), 2),
    ]
    for args, expected in cases:
        assert calculate_groupsize(*args) == expected

test_calculate_groupsize()

def test_edge_cases():
    '''Degenerate dimensions: 1, smaller than gs, and prime.'''
    assert calculate_groupsize(1, 10) == 1, "Should handle dimension of 1 correctly."
    assert calculate_groupsize(100, 101) == 100, "Should adjust group size to dimension if it's larger."
    assert calculate_groupsize(13, 2) == 1, "Should return 1 for prime dimensions."

test_edge_cases()


# Unit test for export_modelq8()
def test_export_modelq8():
    '''Export TestModel as Q8_0 with gs=10 and check the file size.'''
    fname = '/tmp/.test.bin'
    # use the file saved by the export_model test cell; a bare "test.pt" is
    # never created by this notebook and only worked through leftover state
    mpath = "models/test.pt"
    export_modelq8(mpath, fname, gs=10)
    # header 5*4 + config (2 conv * 9 + 2 linear * 6 + 1 bn * 2) * 4
    # + int8 params (9 + 9 + 6 + 2 + 2) + fp32 scales (1 + 1 + 2 + 1 + 1) * 4
    # + BN fp32 4 * 4 = 20 + 128 + 28 + 24 + 16 = 216
    assert os.path.getsize(fname)==216

test_export_modelq8()

wrote /tmp/.test.bin
Quantized (1, 1, 3, 3) to Q8_0 with max error 0.0011526569724082947
Quantized (1, 1, 3, 3) to Q8_0 with max error 0.001158013939857483
Quantized (2, 3) to Q8_0 with max error 0.0010108351707458496
Quantized (2,) to Q8_0 with max error 0.00028876960277557373
Quantized (1, 2) to Q8_0 with max error 0.002393275499343872
max quantization group error across all weights: 0.002393275499343872
wrote /tmp/.test.bin


In [14]:
# Export the model loaded under the name "resnet50" to the default file_path (model.bin)
export_model(model_name="resnet50")

wrote model.bin
