# Quick Tutorial of Brevitas for Hardware-Oriented QNN Training

##### *Author: Yuhao Liu, Chair of Processor Design, TU Dresden, Email: yuhao.liu1@tu-dresden.de*

## Section I: Library Import for Tutorial

#### 1.1 Import PyTorch Library

In [1]:
import torch 
import torch.nn as nn

#### 1.2 Import Brevitas 0.11 Library

In [2]:
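import brevitas.nn as qnn # Import QNN layers in Brevitas
from brevitas.core.quant import QuantType # Import the quantization type enum (QuantType.INT, QuantType.BINARY, ...) used by the layers below
from brevitas.quant_tensor.int_quant_tensor import IntQuantTensor # Import the integer quantized tensor type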

## Section II: Quantization Type in Brevitas

#### 2.1 How to Define Integer Quantized Tensors (IntQuantTensor) in Brevitas

`IntQuantTensor` is the tensor type Brevitas uses to carry integer-quantized data through the training and inference of QNN models. It has six attributes:

<ol>
  <li>value: the floating-point values of the tensor (Brevitas stores these rather than the integers)</li>
  <li>scale: the scale factor that maps between the floating-point values and the integers</li>
  <li>zero_point: the offset added to the scaled values during quantization</li>
  <li>bit_width: the bit width of the quantized integers</li>
  <li>signed: whether the quantized integers are signed</li>
  <li>training: whether the tensor is used in training mode</li>
</ol>
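Together, `bit_width` and `signed` fix which integers a quantized tensor may hold. The plain-Python helper below (`int_range` is our own illustrative name, not a Brevitas function) computes that range, assuming the full-range convention; Brevitas quantizers can also be configured with a narrow range, which this sketch ignores.

```python
def int_range(bit_width: int, signed: bool) -> tuple:
    """Smallest and largest integer representable with the given bit width.

    Assumes the full range: two's complement for signed values,
    [0, 2^bit_width - 1] for unsigned values.
    """
    if signed:
        return -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    return 0, 2 ** bit_width - 1


for bits, is_signed in [(8, False), (8, True), (4, True), (1, False)]:
    print(f"bit_width={bits}, signed={is_signed}: {int_range(bits, is_signed)}")
```

With 8 bits and `signed=False`, as in the examples that follow, every quantized integer has to lie in [0, 255].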

Let us now create some IntQuantTensor objects.

First, we create three floating-point tensors to serve as their values.

In [3]:
value_1d = torch.arange(0, 10, dtype=torch.float32)
value_2d = torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
value_3d = torch.arange(0, 32, dtype=torch.float32).reshape(2, 4, 4)

print("Raw Value for 1-d integer quantized tensor: \n", value_1d, "\n")
print("Raw Value for 2-d integer quantized tensor: \n", value_2d, "\n")
print("Raw Value for 3-d integer quantized tensor: \n", value_3d, "\n")

Raw Value for 1-d integer quantized tensor: 
 tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) 

Raw Value for 2-d integer quantized tensor: 
 tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]]) 

Raw Value for 3-d integer quantized tensor: 
 tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]]]) 



Next, we define the scale factor and zero point of each of the three IntQuantTensor objects. For now, all of them use 8-bit unsigned quantization.

In [4]:
scale_1d, scale_2d, scale_3d = 0.1, 0.2, 0.5
zero_point_1d, zero_point_2d, zero_point_3d = 1.0, 2.0, 3.0
bit_width = 8
signed = False

With these parameters, we create the three IntQuantTensor objects:

In [5]:
int_tensor_1d = IntQuantTensor(value=value_1d, 
                               scale=scale_1d, 
                               zero_point=zero_point_1d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=False)
int_tensor_2d = IntQuantTensor(value=value_2d, 
                               scale=scale_2d, 
                               zero_point=zero_point_2d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=False)
int_tensor_3d = IntQuantTensor(value=value_3d, 
                               scale=scale_3d, 
                               zero_point=zero_point_3d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=False)

print("1-d integer quantized tensor: \n", int_tensor_1d, "\n")
print("2-d integer quantized tensor: \n", int_tensor_2d, "\n")
print("3-d integer quantized tensor: \n", int_tensor_3d, "\n")

1-d integer quantized tensor: 
 IntQuantTensor(value=tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), scale=0.10000000149011612, zero_point=1.0, bit_width=8.0, signed_t=False, training_t=False) 

2-d integer quantized tensor: 
 IntQuantTensor(value=tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]]), scale=0.20000000298023224, zero_point=2.0, bit_width=8.0, signed_t=False, training_t=False) 

3-d integer quantized tensor: 
 IntQuantTensor(value=tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]]]), scale=0.5, zero_point=3.0, bit_width=8.0, signed_t=False, training_t=False) 



As the printout shows, IntQuantTensor stores its data as floating-point values, not as the quantized integers. To obtain the actual integer values of these tensors, we can use either of the following two methods:

1. Call the `int()` method of IntQuantTensor, which returns the quantized integer tensor computed from the value, scale, and zero point.
It computes $Q_{out} = \mathrm{round}\left(\frac{Raw_{in}}{scale\_rate} + zero\_point\right)$
2. Extract the value, scale, and zero point manually and apply the same equation, rounding the result to integers.

In [11]:
real_int_tensor_1d = int_tensor_1d.int(float_datatype=True)
real_int_tensor_2d = int_tensor_2d.int(float_datatype=True)
real_int_tensor_3d = int_tensor_3d.int(float_datatype=True)

print("Real quantized 1-d tensor: \n", real_int_tensor_1d, "\n")
print("Real quantized 2-d tensor: \n", real_int_tensor_2d, "\n")
print("Real quantized 3-d tensor: \n", real_int_tensor_3d, "\n")

Real quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Real quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

Real quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 



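Note the `float_datatype=True` argument: it makes `int()` return the integer values in a floating-point dtype, so they can be mixed with the floating-point operations of a network. Without it, `int()` returns an integer-typed tensor (for example, `torch.uint8` for 8-bit unsigned data), which cannot take part in gradient computation.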
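The plain-PyTorch sketch below (not Brevitas code; the variable names are ours) mimics the two cases: integers kept in a float dtype versus the same integers cast to `torch.uint8`. Only the float version stays in the autograd graph.

```python
import torch

x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
scale, zero_point = 0.1, 1.0

q_float = torch.round(x / scale + zero_point)  # float32 integers, still tracked by autograd
q_uint8 = q_float.detach().to(torch.uint8)     # same integers in an integer dtype

print(q_float.dtype, q_float.requires_grad)    # torch.float32 True
print(q_uint8.dtype, q_uint8.requires_grad)    # torch.uint8 False

try:
    q_uint8.requires_grad_(True)               # integer tensors cannot require gradients
    uint8_can_require_grad = True
except RuntimeError as err:
    uint8_can_require_grad = False
    print("RuntimeError:", err)
```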

The second method extracts the value, scale, and zero point manually, applies the equation above, and rounds the result to integers.

In [12]:
# Extract the quantization parameters stored in each IntQuantTensor
extract_scale_1d, extract_scale_2d, extract_scale_3d = int_tensor_1d.scale, int_tensor_2d.scale, int_tensor_3d.scale
extract_zero_point_1d, extract_zero_point_2d, extract_zero_point_3d = int_tensor_1d.zero_point, int_tensor_2d.zero_point, int_tensor_3d.zero_point

# Apply Q_out = round(value / scale + zero_point) with the extracted parameters
manual_int_tensor_1d = torch.round(int_tensor_1d.value / extract_scale_1d + extract_zero_point_1d)
manual_int_tensor_2d = torch.round(int_tensor_2d.value / extract_scale_2d + extract_zero_point_2d)
manual_int_tensor_3d = torch.round(int_tensor_3d.value / extract_scale_3d + extract_zero_point_3d)

# Print the manually computed tensors (not the int() results from the previous cell)
print("Manually quantized 1-d tensor: \n", manual_int_tensor_1d, "\n")
print("Manually quantized 2-d tensor: \n", manual_int_tensor_2d, "\n")
print("Manually quantized 3-d tensor: \n", manual_int_tensor_3d, "\n")

Manually quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Manually quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

Manually quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 
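Inverting the equation maps the integers back to floating point: $value = (Q_{out} - zero\_point) \cdot scale\_rate$. The plain-PyTorch sketch below uses our own hypothetical helpers `quantize` and `dequantize` (not Brevitas APIs) with the 1-d tensor's scale and zero point. Values that are not multiples of the scale come back with an error of at most half the scale.

```python
import torch


def quantize(x: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    """Hypothetical helper: Q = round(x / scale + zero_point), kept in a float dtype."""
    return torch.round(x / scale + zero_point)


def dequantize(q: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    """Hypothetical helper: inverse mapping value = (Q - zero_point) * scale."""
    return (q - zero_point) * scale


scale, zero_point = 0.1, 1.0              # same parameters as int_tensor_1d
x = torch.tensor([0.0, 0.26, 1.04, 9.0])  # 0.26 and 1.04 are not on the 0.1 grid
q = quantize(x, scale, zero_point)        # integers 1, 4, 11, 91
x_hat = dequantize(q, scale, zero_point)  # approximately 0.0, 0.3, 1.0, 9.0
print(q)
print(x_hat)
print("max round-trip error:", (x_hat - x).abs().max().item())
```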



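## Section III: Quantized Linear Layers

As a floating-point baseline, we first pass a small batch through a standard PyTorch `nn.Linear` layer without bias, and then repeat the experiment with its Brevitas counterpart `QuantLinear`.

In [6]: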
in_features = 10
in_batch = 3
out_features = 4

In [7]:
fc_torch = nn.Linear(in_features=in_features, out_features=out_features, bias=False)

In [8]:
input = torch.arange(0, (in_features*in_batch), dtype=torch.float32).reshape(in_batch, in_features)
print("Input: \n", input, "\n")
print("Input Shape: \n", input.shape)

Input: 
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]]) 

Input Shape: 
 torch.Size([3, 10])


In [9]:
output = fc_torch(input)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 tensor([[  3.3895,  -5.7051,  -0.0827,   1.8867],
        [  8.4728, -19.1678,  -4.5732,   9.1544],
        [ 13.5561, -32.6305,  -9.0636,  16.4221]], grad_fn=<MmBackward0>) 

Output Shape: 
 torch.Size([3, 4])


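Next, we build the same layer with Brevitas's `QuantLinear`, whose weights are quantized to 8-bit integers (`weight_quant_type=QuantType.INT`).

In [24]: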
weight_bit_width = 8

In [25]:
fc_brevtias = qnn.QuantLinear(
                         in_features, 
                         out_features, 
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.INT,
                         bias=False
                     )

In [26]:
input = torch.arange(0, (in_features*in_batch), dtype=torch.float32).reshape(in_batch, in_features)
print("Input: \n", input, "\n")
print("Input Shape: \n", input.shape)

Input: 
 tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]]) 

Input Shape: 
 torch.Size([3, 10])


In [27]:
output = fc_brevtias(input)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 tensor([[ 1.3777e+00,  6.3361e+00, -1.0303e+00, -2.4129e-03],
        [ 5.8656e+00,  1.7797e+01, -3.3949e+00, -2.9195e-01],
        [ 1.0353e+01,  2.9258e+01, -5.7594e+00, -5.8149e-01]],
       grad_fn=<MmBackward0>) 

Output Shape: 
 torch.Size([3, 4])


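The Brevitas output differs from the PyTorch baseline mainly because the two layers start from independent random weights, not because of quantization alone. The `weight` attribute printed below holds the full-precision weights: Brevitas keeps these floating-point parameters for training and quantizes them on the fly in every forward pass.

In [28]: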
fc_brevtias.weight

Parameter containing:
tensor([[ 0.1873,  0.0713,  0.1197, -0.2085,  0.2534, -0.2766,  0.1688,  0.0643,
         -0.0056,  0.0702],
        [ 0.1553, -0.1768,  0.1419,  0.2165,  0.2388, -0.1434,  0.1745,  0.2389,
          0.0745,  0.2245],
        [ 0.1911,  0.1431, -0.1837, -0.0698, -0.2489, -0.1370, -0.3064,  0.1281,
          0.2084,  0.0405],
        [-0.0802,  0.1470, -0.0577, -0.0128, -0.0954,  0.1145, -0.1578,  0.2424,
         -0.2232,  0.0974]], requires_grad=True)
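To see how much 8-bit weight quantization alone changes the result, we can quantize one set of weights and compare the outputs for the same input. The plain-PyTorch sketch below uses its own seeded layer and assumes symmetric per-tensor quantization with scale $\max|w| / 127$ and integers in [-127, 127]; it illustrates the idea and is not claimed to reproduce Brevitas's default weight quantizer exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(10, 4, bias=False)
w = fc.weight.detach()

# Assumed scheme: symmetric per-tensor 8-bit quantization, scale = max|w| / 127
scale = w.abs().max() / 127
w_int = torch.clamp(torch.round(w / scale), -127, 127)  # integers a hardware accelerator would store
w_q = w_int * scale                                      # fake-quantized float weights

x = torch.arange(0, 30, dtype=torch.float32).reshape(3, 10)
y_float = x @ w.T    # what nn.Linear computes without bias
y_quant = x @ w_q.T  # same computation with quantized weights
print("max |y_float - y_quant|:", (y_float - y_quant).abs().max().item())
```

Each weight moves by at most half a quantization step, so each output can move by at most `scale / 2` times the sum of the absolute inputs in its row.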

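Next, we normalize the output of the quantized linear layer with a standard PyTorch `nn.BatchNorm1d` layer, which is in training mode by default.

In [29]: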
bn = nn.BatchNorm1d(out_features)

In [30]:
output = bn(output)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 tensor([[-1.2247e+00, -1.2247e+00,  1.2247e+00,  1.2246e+00],
        [ 9.3746e-08,  1.6303e-08,  1.0984e-07,  6.7152e-07],
        [ 1.2247e+00,  1.2247e+00, -1.2247e+00, -1.2246e+00]],
       grad_fn=<NativeBatchNormBackward0>) 

Output Shape: 
 torch.Size([3, 4])
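The values $\pm 1.2247$ and $\approx 0$ are no coincidence. Each input row equals the previous row plus 10, so every output column of the linear layer is an arithmetic progression $(a, a+d, a+2d)$ over the batch of three. In training mode, `nn.BatchNorm1d` normalizes each column with the batch mean and the biased batch variance $2d^2/3$, and its affine parameters start at weight 1 and bias 0, so the column becomes $(-d, 0, d) / \sqrt{2d^2/3 + \epsilon}$. That is approximately $(-1.2247, 0, 1.2247)$ for $d > 0$ and the reverse for $d < 0$, since $\sqrt{3/2} \approx 1.2247$. A minimal check with one hand-picked column:

```python
import math

import torch
import torch.nn as nn

x = torch.tensor([[1.0], [3.0], [5.0]])  # batch of 3, one feature, progression with a=1, d=2
bn = nn.BatchNorm1d(1)                    # training mode by default: uses batch statistics
y = bn(x).detach().squeeze()
print(y)                                  # approximately [-1.2247, 0.0, 1.2247]
print(math.sqrt(1.5))                     # about 1.2247
```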


We then apply a GELU activation and quantize its output to 8 bits with a `QuantIdentity` layer.

In [31]:
quan_bit_width = 8

In [32]:
gelu = nn.GELU()
output = gelu(output)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 tensor([[-1.3513e-01, -1.3513e-01,  1.0896e+00,  1.0895e+00],
        [ 4.6873e-08,  8.1513e-09,  5.4920e-08,  3.3576e-07],
        [ 1.0896e+00,  1.0896e+00, -1.3513e-01, -1.3515e-01]],
       grad_fn=<GeluBackward0>) 

Output Shape: 
 torch.Size([3, 4])
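`nn.GELU` in its default exact form computes $\mathrm{GELU}(x) = x\,\Phi(x)$, where $\Phi$ is the standard normal CDF. Plugging in the batch-normalized values reproduces the printed numbers; a plain-Python check using `math.erf` (the `gelu` helper is ours):

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi written via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for v in (-1.2247, 0.0, 1.2247):
    print(f"GELU({v:+.4f}) = {gelu(v):+.4f}")
```

This also explains the pattern in the printout: $\mathrm{GELU}(x) - \mathrm{GELU}(-x) = x$, so $1.0896 - (-0.1351) \approx 1.2247$.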


In [33]:
quant_gelu = qnn.QuantIdentity(
                         quant_type='int',
                         scaling_impl_type='const',
                         bit_width=quan_bit_width,
                         min_val=-128.0,
                         max_val=127.0, 
                         return_quant_tensor=True
                     )
output = quant_gelu(output)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 QuantTensor(value=tensor([[-0., -0., 1., 1.],
        [0., 0., 0., 0.],
        [1., 1., -0., -0.]], grad_fn=<MulBackward0>), scale=tensor(1.), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) 

Output Shape: 
 torch.Size([3, 4])
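The printed `scale=tensor(1.)` means this quantizer simply rounds its input to the nearest integer in [-128, 127]. The GELU outputs here lie between about -0.14 and 1.09, so they collapse to {0, 1} and most of their information is lost. The plain-PyTorch sketch below (our own `fake_quant` helper, not a Brevitas function) contrasts that fixed scale with a scale fitted to the data, $\max|x| / 127$; the fitted scale is an illustrative assumption, not the configuration used in the cell above.

```python
import torch


def fake_quant(x: torch.Tensor, scale: float, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    """Round x / scale onto a signed 8-bit integer grid and map it back to floating point."""
    return torch.clamp(torch.round(x / scale), qmin, qmax) * scale


# GELU outputs copied from the printout above (middle row rounded to zero)
x = torch.tensor([[-0.13513, -0.13513, 1.0896, 1.0895],
                  [0.0, 0.0, 0.0, 0.0],
                  [1.0896, 1.0896, -0.13513, -0.13515]])

coarse = fake_quant(x, scale=1.0)            # fixed scale of 1, as printed above
fitted_scale = (x.abs().max() / 127).item()  # assumed alternative: scale fitted to max|x|
fine = fake_quant(x, scale=fitted_scale)

print(coarse)
print("max error, scale 1:     ", (coarse - x).abs().max().item())
print("max error, fitted scale:", (fine - x).abs().max().item())
```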


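The next two cells come from an earlier run of the notebook (note the lower execution counts). They use a quantizer named `quant_relu` that is not defined in the cells shown here, so their outputs do not follow from the cells above; they illustrate a signed 8-bit `QuantTensor` and its integer representation.

In [15]: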
output = quant_relu(output)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 QuantTensor(value=tensor([[-0.8410, -0.8410,  0.8345,  0.8345],
        [ 0.0000,  0.0000,  0.0000, -0.0000],
        [ 0.8345,  0.8345, -0.8410, -0.8410]], grad_fn=<MulBackward0>), scale=tensor(0.0066, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True)) 

Output Shape: 
 torch.Size([3, 4])


In [16]:
output = output._pre_round_int_value # value / scale + zero_point before rounding (an internal attribute)
print("Output: \n", output, "\n")
print("Output Shape: \n", output.shape)

Output: 
 tensor([[-128., -128.,  127.,  127.],
        [   0.,    0.,    0.,    0.],
        [ 127.,  127., -128., -128.]], grad_fn=<AddBackward0>) 

Output Shape: 
 torch.Size([3, 4])
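These integers follow from the printed values and scale. The values sit at both ends of the signed 8-bit range, the zero point is 0, and the unrounded scale is $s \approx 0.00657$ (printed as 0.0066), so

$$\frac{-0.8410}{s} \approx -128, \qquad \frac{0.8345}{s} \approx 127, \qquad \frac{0}{s} = 0.$$

This is the same relation $Q_{out} = \frac{Raw_{in}}{scale\_rate} + zero\_point$ used by `int()`, evaluated before the final rounding step.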


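Finally, we combine these building blocks into a small convolutional network for MNIST with binary (1-bit) weights and 1-bit ReLU activations. The first cell sets its hyperparameters.

In [None]:
kernel_size = (3, 3)  # 3x3 convolution kernels

in_channels1 = 1      # MNIST images are grayscale
out_channels1 = 64

in_channels2 = 64
out_channels2 = 64

input_size = 7 * 7 * 64  # flattened feature map after two 2x2 max-pools: 28 -> 14 -> 7

weight_bit_width = 1  # binary weights
act_bit_width = 1     # 1-bit activations

hidden1 = 64      # hidden units of the first fully connected layer
num_classes = 10  # MNIST digits 0-9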
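The value of `input_size` follows from the layer geometry: a 3x3 convolution with stride 1 and padding 1 keeps the spatial size, and each 2x2 max-pool halves it, so a 28x28 MNIST image shrinks to 7x7 with 64 channels. A plain-Python check (`conv_out` is our own helper implementing the standard convolution output-size formula):

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Standard output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size, channels = 28, 64    # MNIST height/width; channels after each conv
for stage in (1, 2):
    size = conv_out(size)  # 3x3 conv, stride 1, padding 1: size unchanged
    size = size // 2       # 2x2 max-pool with stride 2 halves the size
    print(f"after stage {stage}: {size}x{size}x{channels}")

flat_features = size * size * channels
print("flattened features:", flat_features)  # equals 7 * 7 * 64
```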

In [None]:
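from brevitas.core.quant import QuantType  # provides QuantType.BINARY used below

class TCV_W1A1(nn.Module):  # W1A1: 1-bit weights, 1-bit activations
    def __init__(self):
        super().__init__()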
        
        self.input = qnn.QuantIdentity(  # 8-bit input quantizer; quan_bit_width is set in an earlier cell
                         quant_type='int',
                         scaling_impl_type='const',
                         bit_width=quan_bit_width,
                         min_val=-128.0,
                         max_val=127.0, 
                         return_quant_tensor=True
                     )
        
        self.conv1 = qnn.QuantConv2d( 
                         in_channels=in_channels1,
                         out_channels=out_channels1,
                         kernel_size=kernel_size, 
                         stride=1, 
                         padding=1,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        
        self.bn1   = nn.BatchNorm2d(out_channels1)
        self.relu1 = qnn.QuantReLU(
                         bit_width=act_bit_width, 
                         return_quant_tensor=True
                     )
        
        self.pool1 = qnn.QuantMaxPool2d(2, return_quant_tensor=True)
        
        self.conv2 = qnn.QuantConv2d( 
                         in_channels=in_channels2,
                         out_channels=out_channels2,
                         kernel_size=kernel_size, 
                         stride=1, 
                         padding=1,
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        
        self.bn2   = nn.BatchNorm2d(out_channels2)
        self.relu2 = qnn.QuantReLU(
                         bit_width=act_bit_width, 
                         return_quant_tensor=True
                     )
        
        self.pool2 = qnn.QuantMaxPool2d(2, return_quant_tensor=True)
        
        self.fc1   = qnn.QuantLinear(
                         input_size, 
                         hidden1, 
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )
        
        self.bn3   = nn.BatchNorm1d(hidden1)
        self.relu3 = qnn.QuantReLU(
                         bit_width=act_bit_width, 
                         return_quant_tensor=True
                     )
        
        self.out   = qnn.QuantLinear(
                         hidden1, 
                         num_classes, 
                         weight_bit_width=weight_bit_width,
                         weight_quant_type=QuantType.BINARY,
                         bias=False
                     )

    def forward(self, x):
        out = self.input(x) # MNIST INPUT 28x28 channel 1 => 28x28x1
        out = self.pool1(self.relu1(self.bn1(self.conv1(out))))
                            # Conv OUTPUT 28X28 channel 64 => 28x28x64
                            # MaxPool OUTPUT 14x14 channel 64 => 14x14x64
        out = self.pool2(self.relu2(self.bn2(self.conv2(out))))
                            # Conv OUTPUT 14X14 channel 64 => 14x14x64
                            # MaxPool OUTPUT 7x7 channel 64 => 7x7x64
        out = out.reshape(out.shape[0], -1) # FC INPUT 7X7X64
        out = self.relu3(self.bn3(self.fc1(out))) # FC OUTPUT 64
        out = self.out(out) # OUTPUT 10
        return out
   
model = TCV_W1A1()
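Before training, it can help to confirm that the tensor shapes line up. The sketch below builds a hypothetical full-precision twin of the same architecture from standard PyTorch layers (no quantization, no input quantizer, and plain `nn.ReLU` in place of `QuantReLU`) and pushes a random batch through it. It checks shapes only, not the behavior of the quantized model.

```python
import torch
import torch.nn as nn

# Hypothetical float twin of TCV_W1A1: same layer sizes, no quantization
float_twin = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),  # 28x28 -> 14x14
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),  # 14x14 -> 7x7
    nn.Flatten(),                                    # 7 * 7 * 64 features
    nn.Linear(7 * 7 * 64, 64, bias=False),
    nn.BatchNorm1d(64), nn.ReLU(),
    nn.Linear(64, 10, bias=False),                   # one logit per MNIST class
)

x = torch.randn(2, 1, 28, 28)  # batch of 2: BatchNorm in training mode needs more than one sample
logits = float_twin(x)
print(logits.shape)            # torch.Size([2, 10])
```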