In [9]:
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer, QAT_Quantizer

import torch.nn as nn
import onnx
import onnx.numpy_helper

### markus
import numpy as np
import PIL

import matplotlib.pyplot as plt

%matplotlib notebook

# Load the Data

In [10]:
torch.manual_seed(0)
# choose the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Dataloader for MNIST Dataset

## convert images from 1-color channel to 3-color channel images
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,)), transforms.Lambda(lambda x: x.repeat(3, 1, 1) )])

root='data'
# if not exist, download mnist dataset
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

batch_size = 50

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False,)


### Export and view image to test

In [11]:
numbers = set()
for i in range(0,100):
    image, label = train_set[i]

    # export until all numbers are present
    if len(numbers)==10:
        break
    if label not in numbers:
        numbers.add(label)
        # denormalize
        imag=(np.array(image.tolist())+0.5) * 255
        # shape image from CHW [RRRRR[..],GGG[...],BBB[...]] -> HWC (3x28x28  -> 28x28x3 [RGB,RGB,RGB,RGB,...])
        imag_tp = np.ascontiguousarray( imag.transpose((1,2,0)), dtype=np.uint8)
        print(f"{label} : {imag.shape} -> {imag_tp.shape}")
       
        #print(imag.shape)
        #print(imag.astype(np.uint8))
        pil_image = PIL.Image.frombytes('RGB',(28,28), imag_tp)
        pil_image.save("example_images_rgb\\"+str(label)+".bmp")
        
        
# visualize the last image as example
plt.imshow(imag_tp)


5 : (3, 28, 28) -> (28, 28, 3)
0 : (3, 28, 28) -> (28, 28, 3)
4 : (3, 28, 28) -> (28, 28, 3)
1 : (3, 28, 28) -> (28, 28, 3)
9 : (3, 28, 28) -> (28, 28, 3)
2 : (3, 28, 28) -> (28, 28, 3)
3 : (3, 28, 28) -> (28, 28, 3)
6 : (3, 28, 28) -> (28, 28, 3)
7 : (3, 28, 28) -> (28, 28, 3)
8 : (3, 28, 28) -> (28, 28, 3)


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7feb15b4a910>

In [12]:
print(image[:,:,0])
print( (imag_tp[:,:, 0].astype(float)-127) / 255)

tensor([[-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000]])
[[-0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803922
  -0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803

# Build the normal model

In [13]:
# from: https://karanbirchahal.medium.com/how-to-quantise-an-mnist-network-to-8-bits-in-pytorch-no-retraining-required-from-scratch-39f634ac8459
## we want true rgb data to be trained
mnist = False
if mnist:
  num_channels = 1
else:
  num_channels = 3

class Mnist(nn.Module):
    def __init__(self):
        super(Mnist, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 20, 5, 1)

        self.conv3 = nn.Conv2d(20, 50, 1, 1)
        ########################
        self.conv4 = nn.Conv2d(50, 50, 1, 1)
        self.conv5 = nn.Conv2d(50, 50, 1, 1)
        self.conv6 = nn.Conv2d(50, 50, 1, 1)
        self.conv7 = nn.Conv2d(50, 50, 1, 1)
        self.conv8 = nn.Conv2d(50, 50, 1, 1)
        self.conv9 = nn.Conv2d(50, 50, 1, 1)
        self.conv10 = nn.Conv2d(50, 50, 1, 1)
        self.conv11 = nn.Conv2d(50, 50, 1, 1)
        self.conv12 = nn.Conv2d(50, 50, 1, 1)
        self.conv13 = nn.Conv2d(50, 50, 1, 1)
        self.conv14 = nn.Conv2d(50, 50, 1, 1)
        self.conv15 = nn.Conv2d(50, 50, 1, 1)
        self.conv16 = nn.Conv2d(50, 50, 1, 1)
        self.conv17 = nn.Conv2d(50, 50, 1, 1)
        self.conv18 = nn.Conv2d(50, 50, 1, 1)
        self.conv19 = nn.Conv2d(50, 50, 1, 1)
        self.conv20 = nn.Conv2d(50, 50, 1, 1)
        self.conv21 = nn.Conv2d(50, 50, 1, 1)
        self.conv22 = nn.Conv2d(50, 50, 1, 1)
        self.conv23 = nn.Conv2d(50, 50, 1, 1)
        self.conv24 = nn.Conv2d(50, 50, 1, 1)
        self.conv25 = nn.Conv2d(50, 50, 1, 1)
        self.conv26 = nn.Conv2d(50, 50, 1, 1)
        self.conv27 = nn.Conv2d(50, 50, 1, 1)
        self.conv28 = nn.Conv2d(50, 50, 1, 1)
        self.conv29 = nn.Conv2d(50, 50, 1, 1)
        self.conv30 = nn.Conv2d(50, 50, 1, 1)
        self.conv31 = nn.Conv2d(50, 50, 1, 1)
        self.conv32 = nn.Conv2d(50, 50, 1, 1)
        self.conv33 = nn.Conv2d(50, 50, 1, 1)
        self.conv34 = nn.Conv2d(50, 50, 1, 1)
        self.conv35 = nn.Conv2d(50, 50, 1, 1)
        self.conv36 = nn.Conv2d(50, 50, 1, 1)
        self.conv37 = nn.Conv2d(50, 50, 1, 1)
        self.conv38 = nn.Conv2d(50, 50, 1, 1)
        self.conv39 = nn.Conv2d(50, 50, 1, 1)
        self.conv40 = nn.Conv2d(50, 50, 1, 1)
        self.conv41 = nn.Conv2d(50, 50, 1, 1)
        self.conv42 = nn.Conv2d(50, 50, 1, 1)
        self.conv43 = nn.Conv2d(50, 50, 1, 1)
        self.conv44 = nn.Conv2d(50, 50, 1, 1)
        self.conv45 = nn.Conv2d(50, 50, 1, 1)
        self.conv46 = nn.Conv2d(50, 50, 1, 1)
        self.conv47 = nn.Conv2d(50, 50, 1, 1)
        self.conv48 = nn.Conv2d(50, 50, 1, 1)
        self.conv49 = nn.Conv2d(50, 50, 1, 1)
        self.conv50 = nn.Conv2d(50, 50, 1, 1)
        self.conv51 = nn.Conv2d(50, 50, 1, 1)
        self.conv52 = nn.Conv2d(50, 50, 1, 1)
        self.conv53 = nn.Conv2d(50, 50, 1, 1)


        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))

        x = F.relu(self.conv3(x))
        ########################
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.relu(self.conv13(x))
        x = F.relu(self.conv14(x))
        x = F.relu(self.conv15(x))
        x = F.relu(self.conv16(x))
        x = F.relu(self.conv17(x))
        x = F.relu(self.conv18(x))
        x = F.relu(self.conv19(x))
        x = F.relu(self.conv20(x))
        x = F.relu(self.conv21(x))
        x = F.relu(self.conv22(x))
        x = F.relu(self.conv23(x))
        x = F.relu(self.conv24(x))
        x = F.relu(self.conv25(x))
        x = F.relu(self.conv26(x))
        x = F.relu(self.conv27(x))
        x = F.relu(self.conv28(x))
        x = F.relu(self.conv29(x))
        x = F.relu(self.conv30(x))
        x = F.relu(self.conv31(x))
        x = F.relu(self.conv32(x))
        x = F.relu(self.conv33(x))
        x = F.relu(self.conv34(x))
        x = F.relu(self.conv35(x))
        x = F.relu(self.conv36(x))
        x = F.relu(self.conv37(x))
        x = F.relu(self.conv38(x))
        x = F.relu(self.conv39(x))
        x = F.relu(self.conv40(x))
        x = F.relu(self.conv41(x))
        x = F.relu(self.conv42(x))
        x = F.relu(self.conv43(x))
        x = F.relu(self.conv44(x))
        x = F.relu(self.conv45(x))
        x = F.relu(self.conv46(x))
        x = F.relu(self.conv47(x))
        x = F.relu(self.conv48(x))
        x = F.relu(self.conv49(x))
        x = F.relu(self.conv50(x))
        x = F.relu(self.conv51(x))
        x = F.relu(self.conv52(x))
        x = F.relu(self.conv53(x))


        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [14]:
model = Mnist()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# Train the model

In [15]:
def train(model,  device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%)\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

In [16]:
model.to(device)
# training only 30 epoches
for epoch in range(10):
    print('# Epoch {} #'.format(epoch))
    train(model,  device, train_loader, optimizer)
    test(model, device, test_loader)

# Epoch 0 #
 0%  Loss 2.297818422317505
 8%  Loss 2.3035624027252197
17%  Loss 2.2982606887817383
25%  Loss 2.306532144546509
33%  Loss 2.3108413219451904
42%  Loss 2.3004724979400635
50%  Loss 2.308532953262329
58%  Loss 2.2933411598205566
67%  Loss 2.3072328567504883
75%  Loss 2.2946364879608154
83%  Loss 2.2918643951416016
92%  Loss 2.2930808067321777
Loss: 2.301144715881348  Accuracy: 11.35%)

# Epoch 1 #
 0%  Loss 2.2986860275268555
 8%  Loss 2.289917230606079
17%  Loss 2.299856662750244
25%  Loss 2.297154664993286
33%  Loss 2.3028573989868164
42%  Loss 2.3093507289886475
50%  Loss 2.296455144882202
58%  Loss 2.3064589500427246
67%  Loss 2.303717613220215
75%  Loss 2.3310399055480957
83%  Loss 2.2961723804473877
92%  Loss 2.290550470352173
Loss: 2.3012802696228025  Accuracy: 11.35%)

# Epoch 2 #
 0%  Loss 2.2978012561798096
 8%  Loss 2.301926612854004
17%  Loss 2.3033406734466553
25%  Loss 2.300546646118164
33%  Loss 2.2976627349853516
42%  Loss 2.3046207427978516
50%  Loss 2.2925

# Model Save and Test

In [17]:
"""
The main function of this page is to convert pytorch model to onnx model.
Convertion from pytorch model to onnx model is primary so that a critical
problem is caused that Layer name of pytorch model fail to convert to onnx
layer name directly. To solve it, we wrap pytorch model in new wrapper which
multiply bits number and input before computation of each op. Only in this
way can onnx model get bits number of corresponded layer.
"""

class LayernameModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_bits) -> None:
        """
        Parameters
        ----------
        module : torch.nn.Module
            Layer module of pytorch model
        module_bits : int
            Bits width setting for module
        """
        super().__init__()
        self.module = module
        self.module_bits = module_bits

    def forward(self, inputs):
        inputs = inputs*self.module_bits
        inputs = self.module(inputs)
        return inputs

def _setattr(model, name, module):
    """
    Parameters
    ----------
    model : pytorch model
        The model to speedup by quantization
    name : str
        name of pytorch module
    module : torch.nn.Module
        Layer module of pytorch model
    """
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)

def unwrapper(model_onnx, index2name, config):
    """
    Fill onnx config and remove wrapper node in onnx
    Parameters
    ----------
    model_onnx : onnx model
        Onnx model which is converted from pytorch model
    index2name : dict
        Dictionary of layer index and name
    config : dict
        Config recording name of layers and calibration parameters
    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool
    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP']
    idx = 0
    onnx_config = {}
    while idx < len(model_onnx.graph.node):
        nd = model_onnx.graph.node[idx]
        if nd.name[0:4] in support_op and  idx > 1:
            # Grad constant node and multiply node
            const_nd = model_onnx.graph.node[idx-2]
            mul_nd = model_onnx.graph.node[idx-1]
            # Get index number which is transferred by constant node
            index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t))
            if index != -1:
                name = index2name[index]
                onnx_config[nd.name] = config[name]
            nd.input[0] = mul_nd.input[0]
            # Remove constant node and multiply node
            model_onnx.graph.node.remove(const_nd)
            model_onnx.graph.node.remove(mul_nd)
            idx = idx-2
        idx = idx+1
    return model_onnx, onnx_config

def torch_to_onnx(model, input_shape, model_path, input_names, output_names):
    """
    Convert torch model to onnx model and get layer bits config of onnx model.
    Parameters
    ----------
    model : pytorch model
        The model to speedup by quantization
    config : dict
        Config recording bits number and name of layers
    input_shape : tuple
        The input shape of model, shall pass it to torch.onnx.export
    model_path : str
        The path user want to store onnx model which is converted from pytorch model
    input_names : list
        Input name of onnx model providing for torch.onnx.export to generate onnx model
    output_name : list
        Output name of onnx model providing for torch.onnx.export to generate onnx model
    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Convert torch model to onnx model and save it in model_path
    dummy_input = torch.randn(input_shape)
    model.to('cpu')
    torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True)

    # Load onnx model
    model_onnx = onnx.load(model_path)
    model_onnx, onnx_config = unwrapper(model_onnx, index2name, config)
    onnx.save(model_onnx, model_path)

    onnx.checker.check_model(model_onnx)
    return model_onnx, onnx_config



def export_model_to_onnx(model, input_shape=(1,3,28,28), path="mnistrB.onnx"):

    dummy_input = torch.randn(input_shape)
    model.to('cpu')
        
    # very important or must leave out - not sure need to test again...
    #traced = torch.jit.trace(model, input_dimension)
    print("------------- Exporting to onnx")
    torch.onnx.export(
                      model, 
                      dummy_input, 
                      path,
                      opset_version=7,
                      verbose=True,
                      export_params=True, 
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=None
    )
    
    print("------------- Checking exported model")
    
    # Load the ONNX model
    onnx_model = onnx.load(path)

    # Check that the IR is well formed
    onnx.checker.check_model(onnx_model)

    # Print a Human readable representation of the graph
    print( onnx.helper.printable_graph(onnx_model.graph) )



In [18]:
export_model_to_onnx(model)

------------- Exporting to onnx
graph(%input : Float(1, 3, 28, 28, strides=[2352, 784, 28, 1], requires_grad=0, device=cpu),
      %conv1.weight : Float(20, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(20, 20, 5, 5, strides=[500, 25, 5, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %conv3.weight : Float(50, 20, 1, 1, strides=[20, 1, 1, 1], requires_grad=1, device=cpu),
      %conv3.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %conv4.weight : Float(50, 50, 1, 1, strides=[50, 1, 1, 1], requires_grad=1, device=cpu),
      %conv4.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %conv5.weight : Float(50, 50, 1, 1, strides=[50, 1, 1, 1], requires_grad=1, device=cpu),
      %conv5.bias : Float(50, strides=[1], requires_grad=1, device=cpu),
      %conv6.weight : Float(50, 5

In [19]:
print(mnist)

False
