In [1]:
from pathlib import Path
from typing import Dict
from pprint import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.transforms import ToTensor

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant

In [17]:
project_path = Path.cwd().parent
mnist_path = project_path/'data/mnist'
weight_path = project_path/'quant_he_code/weights'

## Load float trained weights to do inference

In [3]:
test_dataset = MNIST(root=mnist_path, train=False, transform=transforms.Compose([
    ToTensor(),
    lambda x: (x*4).int(),
    lambda x: x.float()/4,
]))

test_loader = DataLoader(test_dataset, batch_size=64, pin_memory=True)

In [4]:
## same code from mnist.ipynb

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    """
    PytorchLightining style
    """
    def training_step(self, batch):
        images, labels = batch
        out = self(images)  # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        return loss

    def validation_step(self, batch) -> Dict:
        images, labels = batch
        out = self(images)  # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        acc = accuracy(out, labels)  # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs) -> Dict:
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()  # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()  # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result) -> None:
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch + 1, result['val_loss'], result['val_acc']))


class MNISTConvModel(ImageClassificationBase):
    """
    2 conv layers + 1 linear layer
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, stride=(2, 2),
                               padding=0, bias=True)

        self.conv2 = nn.Conv2d(5, 50, 5, stride=(2, 2),
                               padding=0, bias=True)
        self.fc1 = nn.Linear(800, 10, bias=True)

    def forward(self, xb):
        out = self.conv1(xb)
        out = out * out  # first square
        out = self.conv2(out)
        out = out.reshape(out.shape[0], -1)
        out = out * out  # second square
        out = self.fc1(out)

        return out

In [5]:
def evaluate(model, val_loader) -> Dict:
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

Evaluation before loading trained weights

In [8]:
model = MNISTConvModel()
acc = evaluate(model, test_loader)['val_acc']
print(f"test accuracy with random weights = {acc}")

test accuracy with random weights = 0.10161226242780685


Load the trained weights and evaluate

In [11]:
model.load_state_dict(torch.load(
    './hcnn_mnist_plain.pth', map_location=torch.device('cpu'))
)
dict(model.named_modules())

{'': MNISTConvModel(
   (conv1): Conv2d(1, 5, kernel_size=(5, 5), stride=(2, 2))
   (conv2): Conv2d(5, 50, kernel_size=(5, 5), stride=(2, 2))
   (fc1): Linear(in_features=800, out_features=10, bias=True)
 ),
 'conv1': Conv2d(1, 5, kernel_size=(5, 5), stride=(2, 2)),
 'conv2': Conv2d(5, 50, kernel_size=(5, 5), stride=(2, 2)),
 'fc1': Linear(in_features=800, out_features=10, bias=True)}

In [12]:
acc = evaluate(model, test_loader)['val_acc']
print(f"test accuracy with trained weights = {acc}")

test accuracy with trained weights = 0.9882563948631287


## Quantize the neural network

In [13]:
class QuantMNISTConvModel(nn.Module):
    def __init__(self, weight_bit_width: int = 16):
        super().__init__()
        self.conv1 = qnn.QuantConv2d(1, 5, 5, stride=(2, 2),
                                     padding=0, bias=True, 
                                     weight_bit_width=weight_bit_width, 
                                     return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(5, 50, 5, stride=(2, 2),
                                     padding=0, bias=True, 
                                     weight_bit_width=weight_bit_width)

        self.fc1 = qnn.QuantLinear(800, 10, bias=True, 
                                   weight_bit_width=weight_bit_width, 
                                   return_quant_tensor=True)

    def forward(self, xb):
        out = self.conv1(xb)
        out = out * out
        out = self.conv2(out)
        out = out.reshape(out.shape[0], -1)
        out = out * out
        out = self.fc1(out)

        return out

weight_bit_width = 16
quant_model = QuantMNISTConvModel(weight_bit_width=weight_bit_width)

In [14]:
quant_model.load_state_dict(torch.load(
    './hcnn_mnist_plain.pth', map_location=torch.device('cpu'))
)

<All keys matched successfully>

In [15]:
pprint(dict(quant_model.named_modules()).keys())

dict_keys(['', 'conv1', 'conv1.input_quant', 'conv1.input_quant._zero_hw_sentinel', 'conv1.output_quant', 'conv1.output_quant._zero_hw_sentinel', 'conv1.weight_quant', 'conv1.weight_quant._zero_hw_sentinel', 'conv1.weight_quant.tensor_quant', 'conv1.weight_quant.tensor_quant.int_quant', 'conv1.weight_quant.tensor_quant.int_quant.float_to_int_impl', 'conv1.weight_quant.tensor_quant.int_quant.tensor_clamp_impl', 'conv1.weight_quant.tensor_quant.int_quant.delay_wrapper', 'conv1.weight_quant.tensor_quant.int_quant.delay_wrapper.delay_impl', 'conv1.weight_quant.tensor_quant.scaling_impl', 'conv1.weight_quant.tensor_quant.scaling_impl.parameter_list_stats', 'conv1.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.first_tracked_param', 'conv1.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.first_tracked_param.view_shape_impl', 'conv1.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stats', 'conv1.weight_quant.tensor_quant.scaling_impl.parameter_list_stats.stat

## Export weights to `.py` file

Utility functions

In [16]:
def generate_size(name, layer):
    string = ""

    if "Conv2d" in str(layer.type):
        string += name + "_kernel_size  = " + str(layer.kernel_size) + "\n"
        string += name + "_in_channels  = " + str(layer.in_channels) + "\n"
        string += name + "_out_channels = " + str(layer.out_channels) + "\n"
        string += name + "_stride       = " + str(layer.stride) + "\n"
        string += name + "_dilation     = " + str(layer.dilation) + "\n"
    elif "Linear" in str(layer.type):
        string += name + "_input  = " + str(layer.in_features) + "\n"
        string += name + "_output = " + str(layer.out_features) + "\n"

    string += "\n"
    return string


def generate_string(name, array):
    splitted = name.split("/")
    flat_array = array.flatten()
    variable_name = splitted[0]

    # In case of bias: add to name
    if "bias" in name:
        variable_name = variable_name + "_bias"

    string = "inline double " + str(variable_name) + " [" + str(flat_array.shape[0]) + "] = {"

    # In case of short array, start direct writing. Else add line escape
    if len(flat_array) > 15:
        string += "\n"

    for i in range(len(flat_array)):
        string += str(flat_array[i])
        if (i != len(flat_array) - 1):
            string += ','
        if i % 7 == 0 and i != 0:
            string += '\n'

    # Remove the line escape if it is there
    if string[-1] == '\n':
        string = string[:-1]
    string += "};\n\n"

    return string

Export floating point weights

In [None]:
file_to_write = "float_weights.py"  # args.export_filepath
trained_weight_path = "hcnn_mnist_plain.pth"  # args.data

if ".py" not in file_to_write:
    f = open(file_to_write + ".py", "w")
else:
    f = open(file_to_write, "w")

f.write("import numpy as np \n")
f.write("# " + str(trained_weight_path) + "\n\n")
f.write(generate_size("conv2d", quant_model.conv1))
f.write(generate_size("conv2d_1", quant_model.conv2))
f.write(generate_size("dense", quant_model.fc1))

f.write("conv2d = np.array(" + str(quant_model.conv1.weight.data.tolist()) + ')\n\n')
f.write("conv2d_1 = np.array(" + str(quant_model.conv2.weight.data.tolist()) + ')\n\n')
f.write("dense = np.array(" + str(quant_model.fc1.weight.data.tolist()) + ')\n\n')
f.write("conv2d_bias = np.array(" + str(model.conv1.bias.data.tolist()) + ')\n\n')
f.write("conv2d_1_bias = np.array(" + str(model.conv2.bias.data.tolist()) + ')\n\n')
f.write("dense_bias = np.array(" + str(model.fc1.bias.data.tolist()) + ')\n\n')

Quantized integer weights

In [20]:
file_to_write = f"quant_weights_{weight_bit_width}bits.py"  # args.export_filepath
trained_weight_path = "hcnn_mnist_plain.pth"  # args.data

if ".py" not in file_to_write:
    f = open(file_to_write + ".py", "w")
else:
    f = open(file_to_write, "w")

f.write("import numpy as np \n")
f.write("# " + str(trained_weight_path) + "\n\n")
f.write(generate_size("conv2d", quant_model.conv1))
f.write(generate_size("conv2d_1", quant_model.conv2))
f.write(generate_size("dense", quant_model.fc1))

f.write("conv2d_scale = " + str(quant_model.conv1.quant_weight().scale.data.tolist()) + '\n\n')
f.write("conv2d = np.array(" + str(quant_model.conv1.quant_weight().int().tolist()) + ')\n\n')
f.write("conv2d_1_scale = " + str(quant_model.conv2.quant_weight().scale.data.tolist()) + '\n\n')
f.write("conv2d_1 = np.array(" + str(quant_model.conv2.quant_weight().int().tolist()) + ')\n\n')

f.write("dense_scale = " + str(quant_model.fc1.quant_weight().scale.data.tolist()) + '\n\n')
f.write("dense = np.array(" + str(quant_model.fc1.quant_weight().int().tolist()) + ')\n\n')

f.write("conv2d_bias = np.array(" + str(model.conv1.bias.data.tolist()) + ')\n\n')
f.write("conv2d_1_bias = np.array(" + str(model.conv2.bias.data.tolist()) + ')\n\n')
f.write("dense_bias = np.array(" + str(model.fc1.bias.data.tolist()) + ')\n\n')


236