In [1]:
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import LsqQuantizer, QAT_Quantizer

import torch.nn as nn
import onnx
import onnx.numpy_helper

### markus
import numpy as np
import PIL

import matplotlib.pyplot as plt

%matplotlib notebook

[2022-08-13 03:48:46] [32mPyTorch Lightning is not installed.[0m
[2022-08-13 03:48:46] [32mTensorflow is not installed.[0m


# Load the Data

In [2]:
torch.manual_seed(0)
# choose the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Dataloader for MNIST Dataset

## convert images from 1-color channel to 3-color channel images
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,)), transforms.Lambda(lambda x: x.repeat(3, 1, 1) )])

root='data'
# if not exist, download mnist dataset
train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)

batch_size = 50

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False,)


### Export and view image to test

In [3]:
numbers = set()
for i in range(0,100):
    image, label = train_set[i]

    # export until all numbers are present
    if len(numbers)==10:
        break
    if label not in numbers:
        numbers.add(label)
        # denormalize
        imag=(np.array(image.tolist())+0.5) * 255
        # shape image from CHW [RRRRR[..],GGG[...],BBB[...]] -> HWC (3x28x28  -> 28x28x3 [RGB,RGB,RGB,RGB,...])
        imag_tp = np.ascontiguousarray( imag.transpose((1,2,0)), dtype=np.uint8)
        print(f"{label} : {imag.shape} -> {imag_tp.shape}")
       
        #print(imag.shape)
        #print(imag.astype(np.uint8))
        pil_image = PIL.Image.frombytes('RGB',(28,28), imag_tp)
        pil_image.save("example_images_rgb\\"+str(label)+".bmp")
        
        
# visualize the last image as example
plt.imshow(imag_tp)


5 : (3, 28, 28) -> (28, 28, 3)
0 : (3, 28, 28) -> (28, 28, 3)
4 : (3, 28, 28) -> (28, 28, 3)
1 : (3, 28, 28) -> (28, 28, 3)
9 : (3, 28, 28) -> (28, 28, 3)
2 : (3, 28, 28) -> (28, 28, 3)
3 : (3, 28, 28) -> (28, 28, 3)
6 : (3, 28, 28) -> (28, 28, 3)
7 : (3, 28, 28) -> (28, 28, 3)
8 : (3, 28, 28) -> (28, 28, 3)


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fee082790d0>

In [4]:
print(image[:,:,0])
print( (imag_tp[:,:, 0].astype(float)-127) / 255)

tensor([[-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000, -0.5000,
         -0.5000, -0.5000, -0.5000, -0.5000]])
[[-0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803922
  -0.49803922 -0.49803922 -0.49803922 -0.49803922 -0.49803

# Build the normal model

In [5]:
# from: https://karanbirchahal.medium.com/how-to-quantise-an-mnist-network-to-8-bits-in-pytorch-no-retraining-required-from-scratch-39f634ac8459
## we want true rgb data to be trained
mnist = False
if mnist:
  num_channels = 1
else:
  num_channels = 3

class Mnist(nn.Module):
    def __init__(self):
        super(Mnist, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 20, 5, 1)

        self.conv3 = nn.Conv2d(20, 50, 1, 1)
        self.conv4 = nn.Conv2d(50, 50, 1, 1)
        self.conv5 = nn.Conv2d(50, 50, 1, 1)

        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))

        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [6]:
model = Mnist()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# Train the model

In [7]:
def train(model,  device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%)\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

In [8]:
model.to(device)
# training only 30 epoches
for epoch in range(10):
    print('# Epoch {} #'.format(epoch))
    train(model,  device, train_loader, optimizer)
    test(model, device, test_loader)

# Epoch 0 #
 0%  Loss 2.3046133518218994
 8%  Loss 2.29872989654541
17%  Loss 2.2896604537963867
25%  Loss 2.284914255142212
33%  Loss 2.287930727005005
42%  Loss 2.2509796619415283
50%  Loss 2.183379888534546
58%  Loss 0.7409600019454956
67%  Loss 0.40377846360206604
75%  Loss 0.4745100438594818
83%  Loss 0.07942009717226028
92%  Loss 0.2577817440032959
Loss: 0.18880681320130824  Accuracy: 93.82%)

# Epoch 1 #
 0%  Loss 0.16079120337963104
 8%  Loss 0.159143328666687
17%  Loss 0.2190445512533188
25%  Loss 0.12469472736120224
33%  Loss 0.3984004259109497
42%  Loss 0.0426090769469738
50%  Loss 0.06527391076087952
58%  Loss 0.23984478414058685
67%  Loss 0.14549529552459717
75%  Loss 0.07671437412500381
83%  Loss 0.275117963552475
92%  Loss 0.06703899800777435
Loss: 0.07920336971916259  Accuracy: 97.48%)

# Epoch 2 #
 0%  Loss 0.023926405236124992
 8%  Loss 0.018319547176361084
17%  Loss 0.07069689780473709
25%  Loss 0.07843434810638428
33%  Loss 0.02024276740849018
42%  Loss 0.0418521314

# Model Save and Test

In [9]:
"""
The main function of this page is to convert pytorch model to onnx model.
Convertion from pytorch model to onnx model is primary so that a critical
problem is caused that Layer name of pytorch model fail to convert to onnx
layer name directly. To solve it, we wrap pytorch model in new wrapper which
multiply bits number and input before computation of each op. Only in this
way can onnx model get bits number of corresponded layer.
"""

class LayernameModuleWrapper(torch.nn.Module):
    def __init__(self, module, module_bits) -> None:
        """
        Parameters
        ----------
        module : torch.nn.Module
            Layer module of pytorch model
        module_bits : int
            Bits width setting for module
        """
        super().__init__()
        self.module = module
        self.module_bits = module_bits

    def forward(self, inputs):
        inputs = inputs*self.module_bits
        inputs = self.module(inputs)
        return inputs

def _setattr(model, name, module):
    """
    Parameters
    ----------
    model : pytorch model
        The model to speedup by quantization
    name : str
        name of pytorch module
    module : torch.nn.Module
        Layer module of pytorch model
    """
    name_list = name.split(".")
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)

def unwrapper(model_onnx, index2name, config):
    """
    Fill onnx config and remove wrapper node in onnx
    Parameters
    ----------
    model_onnx : onnx model
        Onnx model which is converted from pytorch model
    index2name : dict
        Dictionary of layer index and name
    config : dict
        Config recording name of layers and calibration parameters
    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool
    support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP']
    idx = 0
    onnx_config = {}
    while idx < len(model_onnx.graph.node):
        nd = model_onnx.graph.node[idx]
        if nd.name[0:4] in support_op and  idx > 1:
            # Grad constant node and multiply node
            const_nd = model_onnx.graph.node[idx-2]
            mul_nd = model_onnx.graph.node[idx-1]
            # Get index number which is transferred by constant node
            index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t))
            if index != -1:
                name = index2name[index]
                onnx_config[nd.name] = config[name]
            nd.input[0] = mul_nd.input[0]
            # Remove constant node and multiply node
            model_onnx.graph.node.remove(const_nd)
            model_onnx.graph.node.remove(mul_nd)
            idx = idx-2
        idx = idx+1
    return model_onnx, onnx_config

def torch_to_onnx(model, input_shape, model_path, input_names, output_names):
    """
    Convert torch model to onnx model and get layer bits config of onnx model.
    Parameters
    ----------
    model : pytorch model
        The model to speedup by quantization
    config : dict
        Config recording bits number and name of layers
    input_shape : tuple
        The input shape of model, shall pass it to torch.onnx.export
    model_path : str
        The path user want to store onnx model which is converted from pytorch model
    input_names : list
        Input name of onnx model providing for torch.onnx.export to generate onnx model
    output_name : list
        Output name of onnx model providing for torch.onnx.export to generate onnx model
    Returns
    -------
    onnx model
        Onnx model which is converted from pytorch model
    dict
        The configuration of onnx model layers and calibration parameters
    """
    # Convert torch model to onnx model and save it in model_path
    dummy_input = torch.randn(input_shape)
    model.to('cpu')
    torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True)

    # Load onnx model
    model_onnx = onnx.load(model_path)
    model_onnx, onnx_config = unwrapper(model_onnx, index2name, config)
    onnx.save(model_onnx, model_path)

    onnx.checker.check_model(model_onnx)
    return model_onnx, onnx_config



def export_model_to_onnx(model, input_shape=(1,3,28,28), path="mnist2.onnx"):

    dummy_input = torch.randn(input_shape)
    model.to('cpu')
        
    # very important or must leave out - not sure need to test again...
    #traced = torch.jit.trace(model, input_dimension)
    print("------------- Exporting to onnx")
    torch.onnx.export(
                      model, 
                      dummy_input, 
                      path,
                      opset_version=7,
                      verbose=True,
                      export_params=True, 
                      input_names=['input'],
                      output_names=['output'],
                      dynamic_axes=None
    )
    
    print("------------- Checking exported model")
    
    # Load the ONNX model
    onnx_model = onnx.load(path)

    # Check that the IR is well formed
    onnx.checker.check_model(onnx_model)

    # Print a Human readable representation of the graph
    print( onnx.helper.printable_graph(onnx_model.graph) )



In [10]:
export_model_to_onnx(model)

------------- Exporting to onnx
graph(%input : Float(1, 3, 28, 28, strides=[2352, 784, 28, 1], requires_grad=0, device=cpu),
      %conv1.weight : Float(200, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu),
      %conv1.bias : Float(200, strides=[1], requires_grad=1, device=cpu),
      %conv2.weight : Float(200, 200, 5, 5, strides=[5000, 25, 5, 1], requires_grad=1, device=cpu),
      %conv2.bias : Float(200, strides=[1], requires_grad=1, device=cpu),
      %conv3.weight : Float(500, 200, 1, 1, strides=[200, 1, 1, 1], requires_grad=1, device=cpu),
      %conv3.bias : Float(500, strides=[1], requires_grad=1, device=cpu),
      %conv4.weight : Float(500, 500, 1, 1, strides=[500, 1, 1, 1], requires_grad=1, device=cpu),
      %conv4.bias : Float(500, strides=[1], requires_grad=1, device=cpu),
      %conv5.weight : Float(500, 500, 1, 1, strides=[500, 1, 1, 1], requires_grad=1, device=cpu),
      %conv5.bias : Float(500, strides=[1], requires_grad=1, device=cpu),
      %fc1.weig

In [11]:
print(mnist)

False
