# __Quick Tutorial of Brevitas for Hardware-Oriented QNN Training__

##### *Author: Yuhao Liu, TU Dresden & ScaDS.AI, Email: yuhao.liu1@tu-dresden.de*

## __Section I: Library Import for Tutorial__

#### __1.1 Import PyTorch Library__

In [1]:
import torch 
import torch.nn as nn

#### __1.2 Import Basic Brevitas 0.11 Library__

In [2]:
import brevitas.nn as qnn # Import QNN layers in Brevitas
from brevitas.quant_tensor.int_quant_tensor import IntQuantTensor # Import Integer Quantization types for QNN

## __Section II: Introduction to the Basic Quantization Type in Brevitas__

#### __2.1 How to Define Integer Quantization Type Tensors in Brevitas__

To train low-precision QNNs (1 to 8 bits) for hardware accelerator design, *Brevitas* uses *IntQuantTensor* both to train QNN models and to run inference with them. An *IntQuantTensor* has six attributes:

<div style="width:fit-content; float:left; margin-right:20px;">

| Attribute    | Definition                                    |
|--------------|-----------------------------------------------|
| *value*      | The non-quantized raw value                   |
| *scale*      | The scale rate used to quantize the raw value |
| *zero_point* | The zero point used to quantize the raw value |
| *bit_width*  | The bit width of the quantized value          |
| *signed*     | Whether the quantized value is signed         |
| *training*   | Whether the tensor is used in training        |

</div>
<div style="clear:both;"></div>

Next, let's create some *IntQuantTensor* objects.

First, we create three raw value tensors for our *IntQuantTensor* objects.

In [3]:
value_1d = torch.arange(0, 10, dtype=torch.float32)
value_2d = torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
value_3d = torch.arange(0, 32, dtype=torch.float32).reshape(2, 4, 4)

print("Raw Value for 1-d integer quantized tensor: \n", value_1d, "\n")
print("Raw Value for 2-d integer quantized tensor: \n", value_2d, "\n")
print("Raw Value for 3-d integer quantized tensor: \n", value_3d, "\n")

Raw Value for 1-d integer quantized tensor: 
 tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) 

Raw Value for 2-d integer quantized tensor: 
 tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]]) 

Raw Value for 3-d integer quantized tensor: 
 tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]]]) 



Next, we define the scale rate and zero point of the three *IntQuantTensor*s. For now, all of them use 8-bit unsigned data.

In [4]:
scale_1d, scale_2d, scale_3d = 0.1, 0.2, 0.5
zero_point_1d, zero_point_2d, zero_point_3d = 1.0, 2.0, 3.0
bit_width = 8
signed = False
training = False

Then we create the three *IntQuantTensor*s:

In [5]:
int_tensor_1d = IntQuantTensor(value = value_1d, 
                               scale=scale_1d, 
                               zero_point=zero_point_1d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=training)
int_tensor_2d = IntQuantTensor(value = value_2d, 
                               scale=scale_2d, 
                               zero_point=zero_point_2d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=training)
int_tensor_3d = IntQuantTensor(value = value_3d, 
                               scale=scale_3d, 
                               zero_point=zero_point_3d, 
                               bit_width=bit_width, 
                               signed=signed, 
                               training=training)

print("1-d integer quantized tensor: \n", int_tensor_1d, "\n")
print("2-d integer quantized tensor: \n", int_tensor_2d, "\n")
print("3-d integer quantized tensor: \n", int_tensor_3d, "\n")

1-d integer quantized tensor: 
 IntQuantTensor(value=tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), scale=0.10000000149011612, zero_point=1.0, bit_width=8.0, signed_t=False, training_t=False) 

2-d integer quantized tensor: 
 IntQuantTensor(value=tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]]), scale=0.20000000298023224, zero_point=2.0, bit_width=8.0, signed_t=False, training_t=False) 

3-d integer quantized tensor: 
 IntQuantTensor(value=tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]]]), scale=0.5, zero_point=3.0, bit_width=8.0, signed_t=False, training_t=False) 



Tip: *IntQuantTensor* has one detail worth noting. When creating the tensor we pass *signed* and *training*, but the printed tensor shows the attributes *signed_t* and *training_t* instead. The difference between them is shown below:

In [6]:
print(int_tensor_3d.signed)
print(int_tensor_3d.signed_t)

False
tensor(False)


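As the output shows, *signed* is a plain Python `bool`, while *signed_t* stores the same flag as a `torch.Tensor` (*training* and *training_t* are analogous). Note also that the *value* of an *IntQuantTensor* is stored in floating point, not as quantized integers. To obtain the actual integer values of these *IntQuantTensor*s, we can use either of the following two methods: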

1. Use the *int()* method of *IntQuantTensor*, which returns the quantized integer tensor computed from the value, scale rate, and zero point.
This method follows $Q_{out} = \mathrm{round}\left(\frac{Raw_{in}}{scale\_rate} + zero\_point\right)$
2. Manually extract the raw value, scale rate, and zero point, apply the same equation, and round the result to integers.
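The quantization equation itself needs nothing from Brevitas. Below is a minimal plain-PyTorch sketch of it, using the 1-d example's raw value, scale rate 0.1 and zero point 1. The helper names `quantize` and `dequantize` are hypothetical. The inverse mapping $Raw \approx scale\_rate \cdot (Q - zero\_point)$ is an assumption added only to show that the integers map back onto the raw values.

```python
import torch

def quantize(x: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    # Q = round(x / scale + zero_point), the equation above
    return torch.round(x / scale + zero_point)

def dequantize(q: torch.Tensor, scale: float, zero_point: float) -> torch.Tensor:
    # Assumed inverse mapping: x_hat = scale * (Q - zero_point)
    return scale * (q - zero_point)

x = torch.arange(0, 10, dtype=torch.float32)        # same raw values as value_1d
q = quantize(x, scale=0.1, zero_point=1.0)          # integer codes, kept in float32
x_hat = dequantize(q, scale=0.1, zero_point=1.0)    # approximately recovers x

print(q)
print(x_hat)
```

The integer codes should match the 1-d result that *int()* produces below.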

In [7]:
real_int_tensor_1d = int_tensor_1d.int()
real_int_tensor_2d = int_tensor_2d.int()
real_int_tensor_3d = int_tensor_3d.int()

print("Real quantized 1-d tensor: \n", real_int_tensor_1d, "\n")
print("Real quantized 2-d tensor: \n", real_int_tensor_2d, "\n")
print("Real quantized 3-d tensor: \n", real_int_tensor_3d, "\n")

Real quantized 1-d tensor: 
 tensor([ 1, 11, 21, 31, 41, 51, 61, 71, 81, 91], dtype=torch.uint8) 

Real quantized 2-d tensor: 
 tensor([[ 2,  7, 12, 17],
        [22, 27, 32, 37],
        [42, 47, 52, 57],
        [62, 67, 72, 77]], dtype=torch.uint8) 

Real quantized 3-d tensor: 
 tensor([[[ 3,  5,  7,  9],
         [11, 13, 15, 17],
         [19, 21, 23, 25],
         [27, 29, 31, 33]],

        [[35, 37, 39, 41],
         [43, 45, 47, 49],
         [51, 53, 55, 57],
         [59, 61, 63, 65]]], dtype=torch.uint8) 



In [8]:
real_int_tensor_1d = int_tensor_1d.int(float_datatype=True)
real_int_tensor_2d = int_tensor_2d.int(float_datatype=True)
real_int_tensor_3d = int_tensor_3d.int(float_datatype=True)

print("Real quantized 1-d tensor: \n", real_int_tensor_1d, "\n")
print("Real quantized 2-d tensor: \n", real_int_tensor_2d, "\n")
print("Real quantized 3-d tensor: \n", real_int_tensor_3d, "\n")

Real quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Real quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

Real quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 



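By default, *int()* returns an integer dtype (here `torch.uint8`, since the tensors are 8-bit unsigned). Passing *float_datatype=True* returns the same integer values in a floating-point format. Otherwise, the result stays in an integer dtype. PyTorch does not allow integer tensors to carry gradients (only floating-point and complex tensors can require gradients), so the result cannot be used directly in network training.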
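The dtype restriction can be checked with plain PyTorch. In this minimal sketch, ordinary tensors stand in for the two kinds of *int()* output: a `torch.uint8` tensor for the default result and its `float32` copy for the *float_datatype=True* result. The values are copied from the 1-d output above, and no Brevitas call is involved.

```python
import torch

# Stand-in for the default int() output of an 8-bit unsigned tensor
q_int = torch.tensor([1, 11, 21, 31], dtype=torch.uint8)
# Stand-in for the float_datatype=True output: same integers, floating-point dtype
q_float = q_int.to(torch.float32)

try:
    q_int.requires_grad_(True)      # rejected for integer dtypes
    int_can_require_grad = True
except RuntimeError as e:
    int_can_require_grad = False
    print("Integer tensor:", e)

q_float.requires_grad_(True)        # accepted for floating-point dtypes
print("Float tensor requires_grad:", q_float.requires_grad)
```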

Another method is manually extracting the raw value, scale rate, and zero point, computing the result following the equation above, and rounding the output as integers.

In [9]:
# Extract the scale rate and zero point stored in each IntQuantTensor
extract_scale_1d, extract_scale_2d, extract_scale_3d = int_tensor_1d.scale, int_tensor_2d.scale, int_tensor_3d.scale
extract_zero_point_1d, extract_zero_point_2d, extract_zero_point_3d = int_tensor_1d.zero_point, int_tensor_2d.zero_point, int_tensor_3d.zero_point

# Apply Q = round(value / scale + zero_point) with the extracted parameters
manual_int_tensor_1d = torch.round((int_tensor_1d.value / extract_scale_1d) + extract_zero_point_1d)
manual_int_tensor_2d = torch.round((int_tensor_2d.value / extract_scale_2d) + extract_zero_point_2d)
manual_int_tensor_3d = torch.round((int_tensor_3d.value / extract_scale_3d) + extract_zero_point_3d)

print("Manually quantized 1-d tensor: \n", manual_int_tensor_1d, "\n")
print("Manually quantized 2-d tensor: \n", manual_int_tensor_2d, "\n")
print("Manually quantized 3-d tensor: \n", manual_int_tensor_3d, "\n")

Manually quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Manually quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

Manually quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 



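In fact, we don't need to extract the scale rate and zero point ourselves. *IntQuantTensor* provides the attribute *_pre_round_int_value*, which returns $\frac{Raw_{in}}{scale\_rate} + zero\_point$ before rounding. The leading underscore marks it as an internal attribute, so it may change between Brevitas versions.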

In [10]:
manual_int_tensor_1d = torch.round(int_tensor_1d._pre_round_int_value)
manual_int_tensor_2d = torch.round(int_tensor_2d._pre_round_int_value)
manual_int_tensor_3d = torch.round(int_tensor_3d._pre_round_int_value)

print("Manually quantized 1-d tensor: \n", manual_int_tensor_1d, "\n")
print("Manually quantized 2-d tensor: \n", manual_int_tensor_2d, "\n")
print("Manually quantized 3-d tensor: \n", manual_int_tensor_3d, "\n")

Manually quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Manually quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

Manually quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 



Here we use the regular *torch.round* to round the floating-point results to integers. Brevitas also provides *round_ste*, a rounding function designed for QNN training. It rounds in the forward pass but uses a straight-through estimator (STE) in the backward pass, so the gradient passes through unchanged. By contrast, the gradient of *torch.round* is zero. If you round manually in your training code, consider using *round_ste*.
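To illustrate the difference in gradients, here is a minimal sketch of straight-through rounding built from plain PyTorch with the common "detach" trick. The function `round_ste_sketch` is hypothetical and only illustrates the idea; Brevitas's own *round_ste* may be implemented differently internally.

```python
import torch

def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    # Forward: the value equals torch.round(x).
    # Backward: the correction term is detached, so d(output)/dx = 1.
    return x + (torch.round(x) - x).detach()

x = torch.tensor([0.4, 1.6, 2.5], requires_grad=True)

# Gradient through plain torch.round: zero everywhere
torch.round(x).sum().backward()
grad_round = x.grad.clone()

# Gradient through the straight-through variant: passed through unchanged
x.grad = None
round_ste_sketch(x).sum().backward()
grad_ste = x.grad.clone()

print("forward:", round_ste_sketch(x).detach())
print("grad of torch.round:", grad_round)
print("grad of STE rounding:", grad_ste)
```

Note that `torch.round` rounds exact ties to the nearest even integer, so 2.5 becomes 2.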

In [11]:
from brevitas.function.ops_ste import round_ste

round_ste_int_tensor_1d = round_ste(int_tensor_1d._pre_round_int_value)
round_ste_int_tensor_2d = round_ste(int_tensor_2d._pre_round_int_value)
round_ste_int_tensor_3d = round_ste(int_tensor_3d._pre_round_int_value)

print("round_ste quantized 1-d tensor: \n", round_ste_int_tensor_1d, "\n")
print("round_ste quantized 2-d tensor: \n", round_ste_int_tensor_2d, "\n")
print("round_ste quantized 3-d tensor: \n", round_ste_int_tensor_3d, "\n")

round_ste quantized 1-d tensor: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

round_ste quantized 2-d tensor: 
 tensor([[ 2.,  7., 12., 17.],
        [22., 27., 32., 37.],
        [42., 47., 52., 57.],
        [62., 67., 72., 77.]]) 

round_ste quantized 3-d tensor: 
 tensor([[[ 3.,  5.,  7.,  9.],
         [11., 13., 15., 17.],
         [19., 21., 23., 25.],
         [27., 29., 31., 33.]],

        [[35., 37., 39., 41.],
         [43., 45., 47., 49.],
         [51., 53., 55., 57.],
         [59., 61., 63., 65.]]]) 



#### __2.2 What Happens When the Quantized Data Exceeds the Range of the Bit Width__

We define a 1-d raw tensor with a small scale rate and a small bit width (4 bits) to test what happens when the quantized output falls outside the range that the bit width can represent.

In [12]:
value_1d = torch.arange(0, 10, dtype=torch.float32)
scale_1d = 0.1
zero_point_1d = 1
bit_width = 4
signed = False

new_quant_tensor = IntQuantTensor(value = value_1d, 
                                  scale=scale_1d, 
                                  zero_point=zero_point_1d, 
                                  bit_width=bit_width, 
                                  signed=signed, 
                                  training=False)

print("Original quantized 1-d tensor: \n", new_quant_tensor, "\n")

Original quantized 1-d tensor: 
 IntQuantTensor(value=tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), scale=0.10000000149011612, zero_point=1.0, bit_width=4.0, signed_t=False, training_t=False) 



Calling the *int()* method of this *IntQuantTensor* raises the error *IntQuantTensor not valid.*, because the quantized values (up to 91) do not fit in the 4-bit unsigned range (0 to 15):

In [13]:
try:
    print("Rounded quantized 1-d tensor with int(): \n", new_quant_tensor.int(float_datatype=True), "\n")
except Exception as e:
    print("There is an ERROR:", e)

There is an ERROR: IntQuantTensor not valid.


Manual quantization, on the other hand, raises no error, but its output also falls outside the 4-bit unsigned range (0 to 15):

In [14]:
manual_round_int_tensor_1d = torch.round(new_quant_tensor._pre_round_int_value)
manual_round_ste_int_tensor_1d = round_ste(new_quant_tensor._pre_round_int_value)

print("Manually rounded quantized 1-d tensor with torch.round: \n", manual_round_int_tensor_1d, "\n")
print("Manually rounded quantized 1-d tensor with round_ste: \n", manual_round_ste_int_tensor_1d, "\n")

Manually rounded quantized 1-d tensor with torch.round: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 

Manually rounded quantized 1-d tensor with round_ste: 
 tensor([ 1., 11., 21., 31., 41., 51., 61., 71., 81., 91.]) 



Therefore, if we want to obtain the quantized data anyway, we have to limit the manually quantized values to the representable range. For 4-bit unsigned data, any value above $2^4 - 1 = 15$ is saturated to 15:

In [15]:
limited_round_int_tensor_1d = torch.where(manual_round_int_tensor_1d >= ((2**new_quant_tensor.bit_width) - 1), 
                                          torch.full_like(manual_round_int_tensor_1d, ((2**new_quant_tensor.bit_width) - 1)), 
                                          manual_round_int_tensor_1d)
limited_round_ste_int_tensor_1d = torch.where(manual_round_ste_int_tensor_1d >= ((2**new_quant_tensor.bit_width) - 1), 
                                          torch.full_like(manual_round_ste_int_tensor_1d, ((2**new_quant_tensor.bit_width) - 1)), 
                                          manual_round_ste_int_tensor_1d)

print("Limited rounded quantized 1-d tensor with torch.round: \n", limited_round_int_tensor_1d, "\n")
print("Limited rounded quantized 1-d tensor with round_ste: \n", limited_round_ste_int_tensor_1d, "\n")

Limited rounded quantized 1-d tensor with torch.round: 
 tensor([ 1., 11., 15., 15., 15., 15., 15., 15., 15., 15.]) 

Limited rounded quantized 1-d tensor with round_ste: 
 tensor([ 1., 11., 15., 15., 15., 15., 15., 15., 15., 15.]) 



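Signed quantization works in a similar way. The difference is that the representable range of a $b$-bit signed integer is $[-2^{b-1}, 2^{b-1} - 1]$, i.e. -8 to 7 for 4 bits: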

In [16]:
value_1d = -1.0 * torch.arange(0, 10, dtype=torch.float32)
scale_1d = 0.1
zero_point_1d = 15
bit_width = 4
signed = True

new_quant_tensor = IntQuantTensor(value = value_1d, 
                                  scale=scale_1d, 
                                  zero_point=zero_point_1d, 
                                  bit_width=bit_width, 
                                  signed=signed, 
                                  training=False)

print("Original signed quantized 1-d tensor: \n", new_quant_tensor, "\n")

manual_round_int_tensor_1d = torch.round(new_quant_tensor._pre_round_int_value)
manual_round_ste_int_tensor_1d = round_ste(new_quant_tensor._pre_round_int_value)

print("Manually rounded quantized 1-d tensor with torch.round: \n", manual_round_int_tensor_1d, "\n")
print("Manually rounded quantized 1-d tensor with round_ste: \n", manual_round_ste_int_tensor_1d, "\n")

limited_round_int_tensor_1d = torch.where(manual_round_int_tensor_1d >= ((2**(new_quant_tensor.bit_width-1)) - 1), 
                                          torch.full_like(manual_round_int_tensor_1d, ((2**(new_quant_tensor.bit_width-1)) - 1)), 
                                          torch.where(manual_round_int_tensor_1d <= (-1.0*(2**(new_quant_tensor.bit_width-1))), 
                                                torch.full_like(manual_round_int_tensor_1d, (-1.0*(2**(new_quant_tensor.bit_width-1)))), 
                                                manual_round_int_tensor_1d))
limited_round_ste_int_tensor_1d = torch.where(manual_round_ste_int_tensor_1d >= ((2**(new_quant_tensor.bit_width-1)) - 1), 
                                          torch.full_like(manual_round_ste_int_tensor_1d, ((2**(new_quant_tensor.bit_width-1)) - 1)), 
                                          torch.where(manual_round_ste_int_tensor_1d <= (-1.0*(2**(new_quant_tensor.bit_width-1))), 
                                                torch.full_like(manual_round_ste_int_tensor_1d, (-1.0*(2**(new_quant_tensor.bit_width-1)))), 
                                                manual_round_ste_int_tensor_1d))

print("Limited rounded quantized 1-d tensor with torch.round: \n", limited_round_int_tensor_1d, "\n")
print("Limited rounded quantized 1-d tensor with round_ste: \n", limited_round_ste_int_tensor_1d, "\n")

Original signed quantized 1-d tensor: 
 IntQuantTensor(value=tensor([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.]), scale=0.10000000149011612, zero_point=15.0, bit_width=4.0, signed_t=True, training_t=False) 

Manually rounded quantized 1-d tensor with torch.round: 
 tensor([ 15.,   5.,  -5., -15., -25., -35., -45., -55., -65., -75.]) 

Manually rounded quantized 1-d tensor with round_ste: 
 tensor([ 15.,   5.,  -5., -15., -25., -35., -45., -55., -65., -75.]) 

Limited rounded quantized 1-d tensor with torch.round: 
 tensor([ 7.,  5., -5., -8., -8., -8., -8., -8., -8., -8.]) 

Limited rounded quantized 1-d tensor with round_ste: 
 tensor([ 7.,  5., -5., -8., -8., -8., -8., -8., -8., -8.]) 
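The nested `torch.where` calls above can also be expressed with `torch.clamp` and a small helper that computes the representable range. This is a minimal sketch with hypothetical helper names (`int_range`, `saturate`), applied to the rounded values printed in the unsigned and signed examples. Unlike the unsigned cell above, it also clamps unsigned values from below at 0.

```python
import torch

def int_range(bit_width: int, signed: bool) -> tuple:
    # Smallest and largest integers representable with `bit_width` bits
    if signed:
        return -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    return 0, 2 ** bit_width - 1

def saturate(q: torch.Tensor, bit_width: int, signed: bool) -> torch.Tensor:
    # Clip out-of-range integers to the nearest representable value
    lo, hi = int_range(bit_width, signed)
    return torch.clamp(q, min=lo, max=hi)

# Rounded values copied from the 4-bit unsigned and signed examples above
unsigned_q = torch.tensor([1., 11., 21., 31., 41., 51., 61., 71., 81., 91.])
signed_q = torch.tensor([15., 5., -5., -15., -25., -35., -45., -55., -65., -75.])

unsigned_sat = saturate(unsigned_q, bit_width=4, signed=False)
signed_sat = saturate(signed_q, bit_width=4, signed=True)

print(unsigned_sat)
print(signed_sat)
```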



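We can therefore wrap these steps into one function. It uses *round_ste* for tensors in training mode and *torch.round* otherwise, and then limits the result to the range of the bit width. With *quant_type=True*, it returns the integers as an *IntQuantTensor* with scale rate 1 and zero point 0: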

In [17]:
def quant_tensor(raw_in, quant_type = False):
    # Round the pre-rounding integer value: round_ste keeps gradients flowing in training,
    # torch.round is sufficient otherwise
    if raw_in.training:
        manual_round_int_tensor = round_ste(raw_in._pre_round_int_value)
    else:
        manual_round_int_tensor = torch.round(raw_in._pre_round_int_value)

    # Representable integer range for the bit width and signedness of the input
    bit_width = int(raw_in.bit_width)
    if raw_in.signed:
        int_min, int_max = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    else:
        int_min, int_max = 0, 2 ** bit_width - 1

    # Saturate out-of-range values to the nearest representable integer
    limited_round_int_tensor = torch.clamp(manual_round_int_tensor, min=int_min, max=int_max)

    if quant_type:
        # Return the integers as an IntQuantTensor with unit scale rate and zero zero point
        return IntQuantTensor(value=limited_round_int_tensor, 
                              scale=1.0, 
                              zero_point=0.0, 
                              bit_width=raw_in.bit_width, 
                              signed=raw_in.signed, 
                              training=raw_in.training)
    else:
        return limited_round_int_tensor

The following cell tests this function:

In [18]:
value_1d = torch.arange(0, 10, dtype=torch.float32)
value_2d = -1.0 * torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
value_3d = torch.arange(0, 32, dtype=torch.float32).reshape(2, 4, 4)

scale_1d, scale_2d, scale_3d = 0.5, 0.2, 0.1
zero_point_1d, zero_point_2d, zero_point_3d = 1.5, 2.5, 3.5
bit_width = 4
signed_1d, signed_2d, signed_3d = False, True, False
training = False

test_int_tensor_1d = IntQuantTensor(value = value_1d, 
                                    scale=scale_1d, 
                                    zero_point=zero_point_1d, 
                                    bit_width=bit_width, 
                                    signed=signed_1d, 
                                    training=training)
test_int_tensor_2d = IntQuantTensor(value = value_2d, 
                                    scale=scale_2d, 
                                    zero_point=zero_point_2d, 
                                    bit_width=bit_width, 
                                    signed=signed_2d, 
                                    training=training)
test_int_tensor_3d = IntQuantTensor(value = value_3d, 
                                    scale=scale_3d, 
                                    zero_point=zero_point_3d, 
                                    bit_width=bit_width, 
                                    signed=signed_3d, 
                                    training=training)

print("Raw 1-d integer quantized tensor for test: \n", test_int_tensor_1d, "\n")
print("Raw 2-d integer quantized tensor for test: \n", test_int_tensor_2d, "\n")
print("Raw 3-d integer quantized tensor for test: \n", test_int_tensor_3d, "\n \n")

quantized_int_tensor_1d = quant_tensor(test_int_tensor_1d, quant_type=True)
quantized_int_tensor_2d = quant_tensor(test_int_tensor_2d)
quantized_int_tensor_3d = quant_tensor(test_int_tensor_3d)

print("Rounded 1-d integer quantized tensor result: \n", quantized_int_tensor_1d, "\n")
print("Rounded 2-d integer quantized tensor result: \n", quantized_int_tensor_2d, "\n")
print("Rounded 3-d integer quantized tensor result: \n", quantized_int_tensor_3d, "\n")

Raw 1-d integer quantized tensor for test: 
 IntQuantTensor(value=tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]), scale=0.5, zero_point=1.5, bit_width=4.0, signed_t=False, training_t=False) 

Raw 2-d integer quantized tensor for test: 
 IntQuantTensor(value=tensor([[ -0.,  -1.,  -2.,  -3.],
        [ -4.,  -5.,  -6.,  -7.],
        [ -8.,  -9., -10., -11.],
        [-12., -13., -14., -15.]]), scale=0.20000000298023224, zero_point=2.5, bit_width=4.0, signed_t=True, training_t=False) 

Raw 3-d integer quantized tensor for test: 
 IntQuantTensor(value=tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]],

        [[16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.]]]), scale=0.10000000149011612, zero_point=3.5, bit_width=4.0, signed_t=False, training_t=False) 
 

Rounded 1-d integer quantized tensor result: 
 IntQuantTensor(value=tensor([ 2.,  4.,  6.,  8.

## __Section III: Computation of Quantized Tensors in Brevitas__
#### __3.1 Addition of Quantized Tensors in Brevitas__

Quantized tensors in Brevitas carry the additional attributes scale rate, zero point, and bit width, so computing with them is more complex than with regular tensors. Here we define two *IntQuantTensor*s whose scale rates, zero points, and bit widths all differ, and try to add them. This raises the error *"Scaling factors are different"*:

In [19]:
value_1 = torch.arange(1, 5, dtype=torch.float32)
value_2 = torch.arange(2, 6, dtype=torch.float32)

scale_1, scale_2 = 0.2, 0.5
zero_point_1, zero_point_2 = 2.5, 3.5
bit_width_1, bit_width_2 = 6, 8
signed = False
training = False

int_tensor_1 = IntQuantTensor(value=value_1, 
                              scale=scale_1, 
                              zero_point=zero_point_1, 
                              bit_width=bit_width_1, 
                              signed=signed, 
                              training=training)
int_tensor_2 = IntQuantTensor(value=value_2, 
                              scale=scale_2, 
                              zero_point=zero_point_2, 
                              bit_width=bit_width_2, 
                              signed=signed, 
                              training=training)
try:
    int_tensor_result = int_tensor_1 + int_tensor_2
except RuntimeError as e:
    print(f"Caught a RuntimeError: {e}")

Caught a RuntimeError: Scaling factors are different


Therefore, we can define a function that changes the scale rate of a tensor when necessary:

In [None]:
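def quant_transform(tensor_in, new_scale):
    # Sketch (assumption): re-quantize the same raw value with new_scale; all other attributes are kept
    return IntQuantTensor(value=tensor_in.value, scale=new_scale, zero_point=tensor_in.zero_point,
                          bit_width=tensor_in.bit_width, signed=tensor_in.signed, training=tensor_in.training)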

Once the two tensors share the same scale rate, the error disappears, even though their zero points and bit widths still differ.
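A short derivation shows why the shared scale rate matters. It inverts the quantization equation from Section 2.1 while ignoring rounding, i.e. it assumes $Raw = scale\_rate \cdot (Q - zero\_point)$. Consider two tensors with a common scale rate $s$ and zero points $z_1, z_2$:

$$
\begin{aligned}
Raw_1 + Raw_2 &= s\,(Q_1 - z_1) + s\,(Q_2 - z_2) \\
&= s\,\big((Q_1 + Q_2) - (z_1 + z_2)\big)
\end{aligned}
$$

The sum therefore keeps the scale rate $s$ and has zero point $z_1 + z_2$. This is consistent with the zero_point=6.0 (= 2.5 + 3.5) that Brevitas reports for the addition result in the next cell. With two different scale rates $s_1 \neq s_2$, the right-hand side cannot be factored into a single scale rate, which helps explain why Brevitas requires matching scale rates.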

Interestingly, adding the two quantized tensors directly and then converting the sum to integers gives a different result from first converting each input to integers and then adding them:

In [20]:
value_1 = torch.arange(1, 5, dtype=torch.float32)
value_2 = torch.arange(2, 6, dtype=torch.float32)

scale = 0.2
zero_point_1, zero_point_2 = 2.5, 3.5
bit_width_1, bit_width_2 = 6, 8
signed = False
training = False

int_tensor_1 = IntQuantTensor(value=value_1, 
                              scale=scale, 
                              zero_point=zero_point_1, 
                              bit_width=bit_width_1, 
                              signed=signed, 
                              training=training)
int_tensor_2 = IntQuantTensor(value=value_2, 
                              scale=scale, 
                              zero_point=zero_point_2, 
                              bit_width=bit_width_2, 
                              signed=signed, 
                              training=training)

int_tensor_result = int_tensor_1 + int_tensor_2
    
print("Raw Input Tensor 1: \n", int_tensor_1, "\n")
print("Raw Input Tensor 2: \n", int_tensor_2, "\n")
print("Raw Addition Result: \n", int_tensor_result, "\n")

rounded_int_tensor_1 = quant_tensor(int_tensor_1)
rounded_int_tensor_2 = quant_tensor(int_tensor_2)
rounded_int_tensor_result = rounded_int_tensor_1+rounded_int_tensor_2
rounded_old_int_tensor_result = quant_tensor(int_tensor_result)

print("Quantized Input Tensor 1: \n", rounded_int_tensor_1, "\n")
print("Quantized Input Tensor 2: \n", rounded_int_tensor_2, "\n")
print("New Addition Result: \n", rounded_int_tensor_result, "\n")
print("Quantized Previous Addition Result: \n", rounded_old_int_tensor_result, "\n")

Raw Input Tensor 1: 
 IntQuantTensor(value=tensor([1., 2., 3., 4.]), scale=0.20000000298023224, zero_point=2.5, bit_width=6.0, signed_t=False, training_t=False) 

Raw Input Tensor 2: 
 IntQuantTensor(value=tensor([2., 3., 4., 5.]), scale=0.20000000298023224, zero_point=3.5, bit_width=8.0, signed_t=False, training_t=False) 

Raw Addition Result: 
 IntQuantTensor(value=tensor([3., 5., 7., 9.]), scale=0.20000000298023224, zero_point=6.0, bit_width=9.0, signed_t=False, training_t=False) 

Quantized Input Tensor 1: 
 tensor([ 8., 12., 18., 22.]) 

Quantized Input Tensor 2: 
 tensor([14., 18., 24., 28.]) 

New Addition Result: 
 tensor([22., 30., 42., 50.]) 

Quantized Previous Addition Result: 
 tensor([21., 31., 41., 51.]) 



However, suppose we compute the integer form of the two input *IntQuantTensor*s without rounding, add them, and only then round the sum. The result matches the rounded integer value of the directly added *IntQuantTensor*s. In other words, *IntQuantTensor* addition is not performed on rounded integers.

In [21]:
unround_int_tensor_1 = int_tensor_1.value / int_tensor_1.scale + int_tensor_1.zero_point
unround_int_tensor_2 = int_tensor_2.value / int_tensor_2.scale + int_tensor_2.zero_point
unround_int_tensor_sum = unround_int_tensor_1 + unround_int_tensor_2
simulated_int_tensor_sum = torch.round(unround_int_tensor_sum)

print("Quantized Input Tensor 1 without Rounding: \n", unround_int_tensor_1, "\n")
print("Quantized Input Tensor 2 without Rounding: \n", unround_int_tensor_2, "\n")
print("Addition Result without Rounding: \n", unround_int_tensor_sum, "\n")
print("Rounded Quantized Addition Result: \n", simulated_int_tensor_sum, "\n")

Quantized Input Tensor 1 without Rounding: 
 tensor([ 7.5000, 12.5000, 17.5000, 22.5000]) 

Quantized Input Tensor 2 without Rounding: 
 tensor([13.5000, 18.5000, 23.5000, 28.5000]) 

Addition Result without Rounding: 
 tensor([21., 31., 41., 51.]) 

Rounded Quantized Addition Result: 
 tensor([21., 31., 41., 51.]) 
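The gap between the two results comes from where rounding happens. The minimal plain-PyTorch sketch below uses the unrounded values printed above. Because of the fractional zero points (2.5 and 3.5), every input value lands exactly on a .5 tie. `torch.round` resolves ties to the nearest even integer, so each input picks up a rounding error of ±0.5. The errors of the two inputs add up rather than cancel, while the unrounded sum is already an integer.

```python
import torch

# Unrounded integer views of the two inputs (copied from the output above)
pre_1 = torch.tensor([7.5, 12.5, 17.5, 22.5])
pre_2 = torch.tensor([13.5, 18.5, 23.5, 28.5])

# torch.round sends exact .5 ties to the nearest even integer
round_1, round_2 = torch.round(pre_1), torch.round(pre_2)

# Per-element rounding error introduced for each input
err_1, err_2 = round_1 - pre_1, round_2 - pre_2

round_then_add = round_1 + round_2            # round each input first, then add
add_then_round = torch.round(pre_1 + pre_2)   # add first, round once at the end

# The difference between the two results is exactly the accumulated rounding error
gap = round_then_add - add_then_round
print(round_then_add, add_then_round, gap)
```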



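This shows that Brevitas does not actually compute *IntQuantTensor* additions in integer arithmetic. It adds the floating-point values and the zero points (note zero_point=6.0 = 2.5 + 3.5 in the addition result above), and rounding only happens when the sum is converted to integers. To simulate integer addition, where each input is rounded before the sum, we can create a function as: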

In [26]:
def quant_add(tensor_in_0, tensor_in_1, quant_type = False):
    # Round and saturate each input to integers first, as integer hardware would
    rounded_tensor_in_0 = quant_tensor(tensor_in_0)
    rounded_tensor_in_1 = quant_tensor(tensor_in_1)
    
    # One extra bit holds the sum when both inputs have the same signedness
    new_bit_width = max(tensor_in_0.bit_width, tensor_in_1.bit_width) + 1
    new_signed = tensor_in_0.signed or tensor_in_1.signed
    new_training = tensor_in_0.training or tensor_in_1.training
    
    int_tensor_sum = rounded_tensor_in_0 + rounded_tensor_in_1

    # Saturate the sum to the representable range of the output bit width
    int_bit_width = int(new_bit_width)
    if new_signed:
        int_min, int_max = -(2 ** (int_bit_width - 1)), 2 ** (int_bit_width - 1) - 1
    else:
        int_min, int_max = 0, 2 ** int_bit_width - 1
    limited_int_tensor_sum = torch.clamp(int_tensor_sum, min=int_min, max=int_max)

    if quant_type:
        # Return the integer sum as an IntQuantTensor with unit scale rate and zero zero point
        return IntQuantTensor(value=limited_int_tensor_sum, 
                              scale=1.0, 
                              zero_point=0.0, 
                              bit_width=new_bit_width, 
                              signed=new_signed, 
                              training=new_training)
    else:
        return limited_int_tensor_sum

In [28]:
print(quant_add(int_tensor_1, int_tensor_2, quant_type=True))

IntQuantTensor(value=tensor([22., 30., 42., 50.]), scale=1.0, zero_point=0.0, bit_width=9.0, signed_t=False, training_t=False)
