# ML model generation

This notebook creates and trains a simple fully connected classifier on the MNIST dataset. At the end, the trained model is exported as `model.bin`; it can also be exported as an int8-quantized (Q8_0) `modelq8.bin`.

In [184]:
import numpy as np
import torchvision
from torchvision import transforms
import torch
from torch import nn
import matplotlib.pyplot as plt
import struct
import os

In [2]:
def generate_dataloader(batch_size=32):
    """
    Build the MNIST train and test dataloaders.

    Both splits live under ``./data`` (downloaded on first use) and are
    normalized with mean 0.5 / std 0.5 after ``ToTensor``.

    Args:
        batch_size (int): the number of instances per batch, for both loaders

    Returns:
        tuple: ``(trainloader, testloader)``; only the training loader shuffles.
    """
    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    def make_loader(is_train):
        # dataset and loader are built train-first, exactly like the split order below
        dataset = torchvision.datasets.MNIST(
            "./data", train=is_train, download=True, transform=normalize
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train
        )

    train_loader = make_loader(True)
    test_loader = make_loader(False)
    return train_loader, test_loader

In [180]:
# sanity check function
def write_activation(tensor, filepath="actp.txt"):
    """
    Dump every element of ``tensor`` to a text file, one value per line.

    Sanity-check helper: values are flattened in row-major order and written
    with six decimal places so they can be diffed against activations from
    another implementation.

    Args:
        tensor (torch.Tensor): activation to dump; any shape, any device.
        filepath (str): output path, ``actp.txt`` by default.
    """
    # reshape (not view) so non-contiguous tensors such as transposes also work;
    # we only read the values, so no clone is needed
    values = tensor.detach().cpu().reshape(-1).numpy()
    # `with` guarantees the file is closed even if formatting/writing fails
    with open(filepath, "w") as f:
        f.writelines("{:.6f}\n".format(x) for x in values)

In [181]:
class Model(nn.Module):
    """
    Small fully connected MNIST classifier.

    Architecture: flatten -> Linear(784, dim) -> ReLU -> Linear(dim, dim // 2)
    -> ReLU -> Linear(dim // 2, 10), returning raw logits.

    Parameter order (layers.0, layers.1, out; weight before bias) follows the
    module registration order below; keep it stable, since serialized weights
    may be loaded by iterating ``model.parameters()``.

    Args:
        dim (int): width of the first hidden layer; the second has ``dim // 2``.
    """

    def __init__(self, dim=128):
        super().__init__()
        self.dim = dim
        self.nclass = 10
        hidden = dim // 2
        self.flatten = nn.Flatten()
        # input (28*28) -> dim -> dim // 2, each followed by the shared activation
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in ((28 * 28, dim), (dim, hidden))]
        )
        self.activation = nn.ReLU()
        self.out = nn.Linear(hidden, self.nclass)

    def forward(self, x):
        """Return class logits of shape ``(batch, nclass)`` for the batch ``x``."""
        h = self.flatten(x)
        for linear in self.layers:
            h = self.activation(linear(h))
        return self.out(h)

In [261]:
def test_model(model, testloader):
    """
    Evaluate ``model`` on every batch of ``testloader``.

    The model is switched to eval mode (and left in it); gradients are disabled
    during evaluation.

    Args:
        model (nn.Module): classifier returning logits of shape (batch, nclass).
        testloader (DataLoader): yields ``(X, y)`` batches where ``y`` holds
            class indices; ``testloader.dataset`` must support ``len``.

    Returns:
        tuple[float, float]: ``(mean_loss, accuracy)``, both averaged per
        sample over ``testloader.dataset``.

    Raises:
        ValueError: if the dataset is empty.
    """
    n = len(testloader.dataset)
    if n == 0:
        raise ValueError("test_model: testloader.dataset is empty")
    # sum per-sample losses so a smaller final batch is not over-weighted
    # (a mean of per-batch means is biased when the last batch is partial)
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in testloader:
            out = model(X)
            total_loss += loss_fn(out, y).item()
            correct += (torch.argmax(out, 1) == y).sum().item()

    return total_loss / n, correct / n


# keep pytest from collecting this helper as a test (its name starts with test_)
test_model.__test__ = False

def train_model(model, epochs=3, lr=0.001, batch_size=32):
    """
    Train ``model`` in place on MNIST with AdamW and cross-entropy loss.

    After every epoch the model is evaluated on the test split with
    ``test_model``. The mean training loss, test loss and test accuracy are
    then printed.

    Args:
        model (nn.Module): classifier to train (modified in place).
        epochs (int): number of passes over the training set (default 3).
        lr (float): AdamW learning rate (default 0.001).
        batch_size (int): batch size for both dataloaders (default 32).
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    trainloader, testloader = generate_dataloader(batch_size)

    for _ in range(epochs):
        # test_model switches the model to eval mode, so re-enable training each epoch
        model.train()
        tloss = 0.0
        for X, y in trainloader:
            opt.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            tloss += loss.item()
            opt.step()

        tloss = tloss / len(trainloader)
        vloss, correct = test_model(model, testloader)

        print('LOSS train {} valid {} accuracy {:.5f}'.format(tloss, vloss, correct))

In [243]:
# export the model
def serialize_fp32(file, tensor):
    """
    Append ``tensor`` to a binary file as raw float32 values.

    Elements are written in row-major (C) order using the machine's native
    byte order. No header or length prefix is written.

    Args:
        file: a file object opened in ``"wb"`` mode (anything with ``write``).
        tensor (torch.Tensor): tensor of any shape/device; converted to float32.
    """
    # reshape (not view) so non-contiguous tensors such as transposes also work
    d = tensor.detach().cpu().reshape(-1).to(torch.float32).numpy()
    # tobytes() yields the same native-endian bytes as struct.pack(f"{n}f", *d)
    # without splatting every element into a Python argument list
    file.write(d.tobytes())
    
def export_model(model, filepath="model.bin"):
    """
    Serialize ``model`` to a flat float32 binary file.

    Layout (native byte order):
        int32 dim, int32 nclass,
        float32 weights of layers[0], layers[1], out,
        float32 biases of layers[0], layers[1], out.

    Args:
        model (Model): the network to export.
        filepath (str): destination path, ``model.bin`` by default.
    """
    linears = [*model.layers, model.out]
    # `with` closes the file even if serialization raises part-way through
    with open(filepath, "wb") as f:
        # model structure header
        f.write(struct.pack("ii", model.dim, model.nclass))
        # all weight matrices first, then all bias vectors
        for layer in linears:
            serialize_fp32(f, layer.weight)
        for layer in linears:
            serialize_fp32(f, layer.bias)
    print(f"wrote {filepath}")

In [244]:
def read_model(filepath="model.bin"):
    """
    Load a ``Model`` from a float32 file written by ``export_model``.

    Expected layout (native byte order): int32 dim, int32 nclass, then the
    float32 weights of the input, hidden and output layers, followed by their
    biases in the same order.

    Args:
        filepath (str): path of the binary model, ``model.bin`` by default.

    Returns:
        Model: a freshly constructed model holding the stored parameters.

    Raises:
        ValueError: if the file is truncated or the stored ``nclass`` does not
            match the ``Model`` class.
    """
    inp_size = 28 * 28

    def read_exact(f, nbytes):
        # fail loudly on a short read instead of building a wrong-sized tensor
        raw = f.read(nbytes)
        if len(raw) != nbytes:
            raise ValueError(f"{filepath}: unexpected end of file")
        return raw

    def read_fp32(f, count):
        # `count` native-endian float32 values as a 1-D tensor (copied: frombuffer is read-only)
        return torch.from_numpy(np.frombuffer(read_exact(f, 4 * count), dtype=np.float32).copy())

    with open(filepath, "rb") as f:
        dim, nclass = struct.unpack("ii", read_exact(f, 8))
        dim2 = dim // 2
        wi = read_fp32(f, inp_size * dim).view(dim, inp_size)
        wh = read_fp32(f, dim * dim2).view(dim2, dim)
        wo = read_fp32(f, dim2 * nclass).view(nclass, dim2)
        bi = read_fp32(f, dim)
        bh = read_fp32(f, dim2)
        bo = read_fp32(f, nclass)

    model = Model(dim)
    if nclass != model.nclass:
        raise ValueError(f"{filepath}: stored nclass={nclass}, Model expects {model.nclass}")

    # assign explicitly per layer; copy_ raises on any shape mismatch
    with torch.no_grad():
        for layer, w, b in ((model.layers[0], wi, bi), (model.layers[1], wh, bh), (model.out, wo, bo)):
            layer.weight.copy_(w)
            layer.bias.copy_(b)

    return model

In [247]:
def serialize_int8(file, tensor):
    """
    Append ``tensor`` to a binary file as raw int8 values (row-major order).

    Args:
        file: a file object opened in ``"wb"`` mode (anything with ``write``).
        tensor (torch.Tensor): tensor of any shape/device; values are cast to
            int8, so they are expected to already lie in [-128, 127].
    """
    # reshape (not view) so non-contiguous tensors such as transposes also work
    d = tensor.detach().cpu().reshape(-1).numpy().astype(np.int8)
    # same bytes as struct.pack(f"{n}b", *d), without a per-element argument list
    file.write(d.tobytes())

def quantize_q80(w, group_size):
    """
    Quantize ``w`` to Q8_0: symmetric per-group int8 in the range [-127, 127].

    ``w`` is flattened and split into consecutive groups of ``group_size``
    elements. Each group gets one float32 scale ``max(|group|) / 127`` such that
    ``float ~= int8 * scale``.

    Args:
        w (torch.Tensor): tensor to quantize; ``w.numel()`` must be a multiple
            of ``group_size``.
        group_size (int): number of consecutive elements sharing one scale.

    Returns:
        tuple: ``(int8val, scale, maxerr)``:
            - ``int8val``: int8 tensor of shape ``(numel // group_size, group_size)``.
            - ``scale``: float32 tensor of shape ``(numel // group_size,)``.
            - ``maxerr``: the largest absolute quantize/dequantize round-trip
              error, as a Python float.

    Raises:
        ValueError: if ``group_size`` is not positive or does not divide ``w.numel()``.
    """
    if group_size <= 0 or w.numel() % group_size != 0:
        raise ValueError(f"numel {w.numel()} is not a multiple of group_size {group_size}")
    groups = w.detach().float().reshape(-1, group_size)
    # scale such that float = quant * scale
    scale = groups.abs().max(dim=1).values / 127.0
    # an all-zero group has scale 0; divide by 1 instead to avoid 0/0 = NaN.
    # Its int8 values are then 0 and dequantize back to exactly 0.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    # scale into [-127, 127] and round to the nearest integer
    int8val = torch.round(groups / safe_scale[:, None]).to(torch.int8)
    # dequantize to measure the worst-case round-trip error
    fp32val = int8val.float() * scale[:, None]
    maxerr = (fp32val - groups).abs().max().item() if groups.numel() else 0.0
    return int8val, scale, maxerr

def dequantize_q80(w, scale, group_size):
    """
    Convert a Q8_0 tensor back to float32 (inverse of ``quantize_q80``).

    Args:
        w (torch.Tensor): integer values; ``w.numel()`` must be a multiple of
            ``group_size``.
        scale (torch.Tensor): one scale per group, ``w.numel() // group_size`` values.
        group_size (int): number of consecutive elements sharing one scale.

    Returns:
        torch.Tensor: flat float32 tensor with ``w.numel()`` elements.

    Raises:
        ValueError: if ``group_size`` does not divide ``w.numel()`` or the
            number of scales does not match the number of groups.
    """
    if group_size <= 0 or w.numel() % group_size != 0:
        raise ValueError(f"numel {w.numel()} is not a multiple of group_size {group_size}")
    groups = w.reshape(-1, group_size).float()
    # without this check a single scale would silently broadcast over all groups
    if scale.numel() != groups.shape[0]:
        raise ValueError(f"expected {groups.shape[0]} scales, got {scale.numel()}")
    return (groups * scale.reshape(-1)[:, None]).reshape(-1)

def export_modelq8(model=None, filepath="modelq8.bin", gs=64):
    """
    Export a Q8_0-quantized (int8) copy of ``model`` to ``filepath``.

    If ``model`` is None, it is loaded with ``read_model()``.

    Layout (native byte order): int32 dim, int32 nclass, int32 gs. Then, for
    each of layers[0].weight, layers[1].weight, out.weight, layers[0].bias,
    layers[1].bias and out.bias, the int8 values followed by their float32
    per-group scales. Every tensor uses group size ``gs`` except the output
    bias. The output bias is always quantized as a single group of ``nclass``
    values (exactly one scale), which is what ``read_modelq8`` reads back.

    Args:
        model (Model | None): network to export.
        filepath (str): destination path, ``modelq8.bin`` by default.
        gs (int): quantization group size for all other tensors.
    """
    if model is None:
        model = read_model()

    weights = [*(layer.weight for layer in model.layers), model.out.weight]
    biases = [*(layer.bias for layer in model.layers), model.out.bias]
    params = [*weights, *biases]

    errors = []
    with open(filepath, "wb") as f:
        # model structure header
        f.write(struct.pack("iii", model.dim, model.nclass, gs))
        for i, p in enumerate(params):
            # the output bias is stored as one group with one scale, regardless of gs
            # (previously gs <= nclass wrote several scales the reader never reads)
            group_size = model.nclass if i == len(params) - 1 else gs
            q, s, err = quantize_q80(p, group_size)
            serialize_int8(f, q)  # int8 values
            serialize_fp32(f, s)  # per-group scale factors
            errors.append(err)
            print(f"{i+1}/{len(params)} quantized {tuple(p.shape)} to Q8_0 with max error {err}")

    # the highest error across all parameters should be very small, e.g. O(~0.001)
    print(f"max quantization group error across all weights: {max(errors)}")
    print(f"wrote {filepath}")

In [245]:
def read_modelq8(filepath="modelq8.bin"):
    """
    Load a ``Model`` from a Q8_0 file written by ``export_modelq8``.

    Header: int32 dim, int32 nclass, int32 gs (native byte order). It is
    followed by the input, hidden and output weights, then the input, hidden
    and output biases. Each tensor is stored as int8 values followed by
    float32 per-group scales and is dequantized back to float32. The output
    bias is stored as a single group of ``nclass`` values.

    Args:
        filepath (str): path of the quantized model, ``modelq8.bin`` by default.

    Returns:
        Model: a freshly constructed model holding the dequantized parameters.

    Raises:
        ValueError: if the file is truncated or the stored ``nclass`` does not
            match the ``Model`` class.
    """
    inp_size = 28 * 28

    def read_exact(f, nbytes):
        # fail loudly on a short read instead of building a wrong-sized tensor
        raw = f.read(nbytes)
        if len(raw) != nbytes:
            raise ValueError(f"{filepath}: unexpected end of file")
        return raw

    def read_q80(f, count, group_size):
        # `count` int8 values followed by count // group_size float32 scales
        q = torch.from_numpy(np.frombuffer(read_exact(f, count), dtype=np.int8).copy())
        nscales = count // group_size
        s = torch.from_numpy(np.frombuffer(read_exact(f, 4 * nscales), dtype=np.float32).copy())
        return dequantize_q80(q, s, group_size)

    with open(filepath, "rb") as f:
        dim, nclass, gs = struct.unpack("iii", read_exact(f, 12))
        dim2 = dim // 2
        wi = read_q80(f, inp_size * dim, gs).view(dim, inp_size)
        wh = read_q80(f, dim * dim2, gs).view(dim2, dim)
        wo = read_q80(f, dim2 * nclass, gs).view(nclass, dim2)
        bi = read_q80(f, dim, gs)
        bh = read_q80(f, dim2, gs)
        bo = read_q80(f, nclass, nclass)  # single group / single scale

    model = Model(dim)
    if nclass != model.nclass:
        raise ValueError(f"{filepath}: stored nclass={nclass}, Model expects {model.nclass}")

    # assign explicitly per layer; copy_ raises on any shape mismatch
    with torch.no_grad():
        for layer, w, b in ((model.layers[0], wi, bi), (model.layers[1], wh, bh), (model.out, wo, bo)):
            layer.weight.copy_(w)
            layer.bias.copy_(b)

    return model

In [205]:
# sanity check function
def write_output(filepath="outputp.txt"):
    """
    Write the model's predicted class for every MNIST test image to a text file.

    The model is loaded with ``read_model()``. One integer prediction is
    written per line, in test-loader order, so the file can be compared with
    the model outputs in C.

    Args:
        filepath (str): output path, ``outputp.txt`` by default.
    """
    model = read_model()
    model.eval()
    _, testloader = generate_dataloader()

    # no_grad: inference only, so don't build autograd graphs for every batch;
    # `with open` closes the file even if a batch fails
    with torch.no_grad(), open(filepath, "w") as f:
        for X, _ in testloader:
            preds = torch.argmax(model(X), 1).cpu()
            f.writelines("{}\n".format(int(p)) for p in preds)

In [269]:
# example: evaluate the float32 model stored in model.bin on the MNIST test split
_, testloader = generate_dataloader()
loss, acc = test_model(read_model(), testloader)
print(f"Accuracy: {100 * acc:.2f} %")

Accuracy: 96.36 %
