In [13]:
import torch


In [14]:
# Float32 operands of the worked example. They are later combined as
# input @ weight^T + bias, giving a 3x10 result.

# 3x3 input matrix.
input_data = torch.tensor(
    [
        [-1.0, 0.5, 0.5],
        [0.5, 0.5, -0.5],
        [-0.5, 0.0, 0.5],
    ],
    dtype=torch.float32,
)

# 10x3 weight matrix (transposed before the matrix product).
weight_data = torch.tensor(
    [
        [0.5, -0.5, 0.3],
        [0.2, -0.2, 0.1],
        [0.4, -0.4, 0.2],
        [0.6, -0.6, 0.3],
        [0.7, -0.7, 0.4],
        [0.8, -0.8, 0.5],
        [0.9, -0.9, 0.6],
        [1.0, -1.0, 0.7],
        [1.1, -1.1, 0.8],
        [1.2, -1.2, 0.9],
    ],
    dtype=torch.float32,
)

# One bias value per weight row (10 entries).
bias_data = torch.tensor(
    [0.5, -0.5, 0.3, 0.2, -0.2, 0.1, 0.4, -0.4, 0.2, 0.6],
    dtype=torch.float32,
)

In [15]:
# Float reference result: input @ weight^T, with the bias vector broadcast
# across the rows of the product.
output_data = input_data @ weight_data.t() + bias_data
print(output_data)

tensor([[-0.1000, -0.7500, -0.2000, -0.5500, -1.0500, -0.8500, -0.6500, -1.5500,
         -1.0500, -0.7500],
        [ 0.3500, -0.5500,  0.2000,  0.0500, -0.4000, -0.1500,  0.1000, -0.7500,
         -0.2000,  0.1500],
        [ 0.4000, -0.5500,  0.2000,  0.0500, -0.3500, -0.0500,  0.2500, -0.5500,
          0.0500,  0.4500]])


In [16]:
def calculate_asymmetric_quant_params(
    min_float: torch.FloatTensor,
    max_float: torch.FloatTensor,
    min_quant: torch.IntTensor,
    max_quant: torch.IntTensor,
    eps: torch.FloatTensor,
):
    """Compute affine (asymmetric) quantization parameters for a float range.

    The float range is first widened, if necessary, so that it contains 0.0.
    Without this, a strictly positive or strictly negative range yields a zero
    point outside ``[min_quant, max_quant]``; clamping it then makes every
    quantized value saturate at one end of the integer range. Ranges that
    already straddle zero are left untouched.

    Args:
        min_float: Smallest float value to represent.
        max_float: Largest float value to represent.
        min_quant: Lowest integer level (e.g. ``tensor([-128])``).
        max_quant: Highest integer level (e.g. ``tensor([127])``).
        eps: Lower bound applied to the scale to avoid a zero scale.

    Returns:
        Tuple ``(scale_factor, zero_point, min_float, max_float)``.
        ``zero_point`` is rounded and clamped to the integer range but is
        returned as a floating-point tensor. ``min_float`` / ``max_float``
        are the (possibly widened) bounds that were actually used.
    """
    # Make sure 0.0 is inside the range so it maps to a valid integer level.
    min_float = torch.clamp(torch.as_tensor(min_float), max=0)
    max_float = torch.clamp(torch.as_tensor(max_float), min=0)

    quant_span = max_quant.float() - min_quant.float()
    scale_factor = torch.max((max_float - min_float) / quant_span, eps)

    # Anchor the mapping at the top of the range: max_float -> max_quant.
    zero_point = max_quant - (max_float / scale_factor)
    zero_point = zero_point.round().clamp(min_quant, max_quant)

    return scale_factor, zero_point, min_float, max_float

In [17]:
def calculate_symmetric_quant_params(
    min_float: torch.FloatTensor,
    max_float: torch.FloatTensor,
    min_quant: torch.IntTensor,
    max_quant: torch.IntTensor,
    eps: torch.FloatTensor,
):
    """Compute symmetric quantization parameters for a float range.

    The range is replaced by ``[-m, +m]`` where ``m`` is the larger of
    ``|min_float|`` and ``|max_float|``; the zero point is therefore 0.

    Args:
        min_float: Smallest float value to represent.
        max_float: Largest float value to represent.
        min_quant: Lowest integer level (e.g. ``tensor([-127])``).
        max_quant: Highest integer level (e.g. ``tensor([127])``).
        eps: Lower bound applied to the scale to avoid a zero scale.

    Returns:
        Tuple ``(scale_factor, zero_point, min_float, max_float)`` where
        ``min_float`` / ``max_float`` are the symmetric bounds ``-m`` / ``+m``
        and ``zero_point`` is an all-zero tensor shaped like ``scale_factor``.
    """
    max_extent = torch.max(torch.abs(min_float), torch.abs(max_float))
    max_float = max_extent
    min_float = -max_extent

    quant_span = max_quant.float() - min_quant.float()
    scale_factor = torch.max((max_float - min_float) / quant_span, eps)

    # Build the zero point with the scale's dtype and device so it can be mixed
    # with the scale (and with data living on the same device) without a
    # dtype or device mismatch.
    zero_point = torch.zeros(
        scale_factor.size(),
        dtype=scale_factor.dtype,
        device=scale_factor.device,
    )

    return scale_factor, zero_point, min_float, max_float


In [18]:
# Asymmetric parameters for the input, from its observed min/max and the
# signed 8-bit integer range [-128, 127].
(
    scale_factor_input,
    zero_point,
    min_float,
    max_float,
) = calculate_asymmetric_quant_params(
    min_float=input_data.min(),
    max_float=input_data.max(),
    min_quant=torch.tensor([-128], dtype=torch.int32),
    max_quant=torch.tensor([127], dtype=torch.int32),
    eps=torch.tensor([1e-6], dtype=torch.float32),
)
print(
    f"scale_factor: {scale_factor_input.item()}",
    f"zero_point: {zero_point}",
    f"min_float: {min_float.item()}",
    f"max_float: {max_float.item()}",
    sep="\n",
)

scale_factor: 0.0058823530562222
zero_point: tensor([42.])
min_float: -1.0
max_float: 0.5


In [19]:
def quantize(
    x: torch.FloatTensor,
    scale_factor: torch.FloatTensor,
    zero_point: torch.IntTensor,
    min_quant: torch.IntTensor,
    max_quant: torch.IntTensor,
) -> torch.IntTensor:
    """Map float values onto integer levels: ``round(x / scale + zero_point)``.

    Args:
        x: Float tensor to quantize.
        scale_factor: Float step size between adjacent integer levels.
        zero_point: Offset added after scaling.
        min_quant: Single-element tensor holding the lowest allowed level.
        max_quant: Single-element tensor holding the highest allowed level.

    Returns:
        ``torch.int32`` tensor with values limited to
        ``[min_quant, max_quant]``.
    """
    shifted = x / scale_factor + zero_point
    # torch rounds halves to the nearest even integer.
    rounded = shifted.round()
    limited = rounded.clamp(min=min_quant.item(), max=max_quant.item())
    return limited.to(torch.int32)


In [20]:
# Asymmetric parameters for the weights, from their observed min/max and the
# signed 8-bit integer range [-128, 127].
(
    scale_factor_weight,
    zero_point,
    min_float,
    max_float,
) = calculate_asymmetric_quant_params(
    min_float=weight_data.min(),
    max_float=weight_data.max(),
    min_quant=torch.tensor([-128], dtype=torch.int32),
    max_quant=torch.tensor([127], dtype=torch.int32),
    eps=torch.tensor([1e-6], dtype=torch.float32),
)
print(
    f"scale_factor: {scale_factor_weight.item()}",
    f"zero_point: {zero_point}",
    f"min_float: {min_float.item()}",
    f"max_float: {max_float.item()}",
    sep="\n",
)

scale_factor: 0.00941176526248455
zero_point: tensor([-0.])
min_float: -1.2000000476837158
max_float: 1.2000000476837158


In [21]:
# Quantize the weights with the scale and zero point computed for `weight_data`.
# NOTE: `zero_point` is a notebook-level name that several cells reassign, so
# this cell has to run directly after the weight-parameter cell.
q_weight = quantize(
    weight_data,
    scale_factor_weight,
    zero_point,
    min_quant=torch.tensor([-128], dtype=torch.int32),
    max_quant=torch.tensor([127], dtype=torch.int32),
)
print(q_weight)

# Remove the zero-point offset. `zero_point` is a float tensor, so the
# subtraction promotes the result to float; cast back so that `q_weight` itself
# (not just the printed copy) stays int32.
q_weight = (q_weight - zero_point).to(torch.int32)
print(q_weight)

tensor([[  53,  -53,   32],
        [  21,  -21,   11],
        [  42,  -42,   21],
        [  64,  -64,   32],
        [  74,  -74,   42],
        [  85,  -85,   53],
        [  96,  -96,   64],
        [ 106, -106,   74],
        [ 117, -117,   85],
        [ 127, -128,   96]], dtype=torch.int32)
tensor([[  53,  -53,   32],
        [  21,  -21,   11],
        [  42,  -42,   21],
        [  64,  -64,   32],
        [  74,  -74,   42],
        [  85,  -85,   53],
        [  96,  -96,   64],
        [ 106, -106,   74],
        [ 117, -117,   85],
        [ 127, -128,   96]], dtype=torch.int32)


In [22]:
# Symmetric parameters for the bias over the integer range [-127, 127].
# NOTE: `scale_factor_bias` is reassigned later in the notebook (to
# scale_factor_input * scale_factor_weight) before the bias is quantized.
(
    scale_factor_bias,
    zero_point,
    min_float,
    max_float,
) = calculate_symmetric_quant_params(
    min_float=bias_data.min(),
    max_float=bias_data.max(),
    min_quant=torch.tensor([-127], dtype=torch.int32),
    max_quant=torch.tensor([127], dtype=torch.int32),
    eps=torch.tensor([1e-6], dtype=torch.float32),
)
print(
    f"scale_factor: {scale_factor_bias.item()}",
    f"zero_point: {zero_point}",
    f"min_float: {min_float.item()}",
    f"max_float: {max_float.item()}",
    sep="\n",
)

scale_factor: 0.004724409431219101
zero_point: tensor([0.])
min_float: -0.6000000238418579
max_float: 0.6000000238418579


In [23]:
# Asymmetric parameters for the float reference output, from its observed
# min/max and the signed 8-bit integer range [-128, 127].
(
    scale_factor_output,
    zero_point,
    min_float,
    max_float,
) = calculate_asymmetric_quant_params(
    min_float=output_data.min(),
    max_float=output_data.max(),
    min_quant=torch.tensor([-128], dtype=torch.int32),
    max_quant=torch.tensor([127], dtype=torch.int32),
    eps=torch.tensor([1e-6], dtype=torch.float32),
)
print(
    f"scale_factor: {scale_factor_output.item()}",
    f"zero_point: {zero_point}",
    f"min_float: {min_float.item()}",
    f"max_float: {max_float.item()}",
    sep="\n",
)

scale_factor: 0.007843137718737125
zero_point: tensor([70.])
min_float: -1.5499999523162842
max_float: 0.44999998807907104


In [24]:
# Integer range used for the bias: a signed `quant_bits`-bit range.
# NOTE(review): 8+1+8+1 = 18 bits; presumably 8-bit input + 8-bit weight plus
# headroom for the product/accumulation - confirm the intended bit budget.
quant_bits = 8 + 1 + 8 + 1
min_quant = torch.tensor([-(2 ** (quant_bits - 1))], dtype=torch.int32)
max_quant = torch.tensor([2 ** (quant_bits - 1) - 1], dtype=torch.int32)

# The bias scale is the product of the input and weight scales; this replaces
# the `scale_factor_bias` computed by the symmetric-parameter cell.
scale_factor_bias = scale_factor_input * scale_factor_weight
zero_point = torch.tensor([0], dtype=torch.int32)

q_bias = quantize(
    bias_data,
    scale_factor_bias,
    zero_point,
    min_quant=min_quant,
    max_quant=max_quant,
)
print(q_bias)

# The zero point is 0 here, so this subtraction leaves the values unchanged.
q_bias = q_bias - zero_point
print(q_bias.int())

tensor([ 9031, -9031,  5419,  3612, -3612,  1806,  7225, -7225,  3612, 10838],
       dtype=torch.int32)
tensor([ 9031, -9031,  5419,  3612, -3612,  1806,  7225, -7225,  3612, 10838],
       dtype=torch.int32)


In [None]:
a = torch.tensor([[  53,  -53,   32],
        [  21,  -21,   11],
        [  42,  -42,   21],
        [  64,  -64,   32],
        [  74,  -74,   42],
        [  85,  -85,   53],
        [  96,  -96,   64],
        [ 106, -106,   74],
        [ 117, -117,   85],
        [ 127, -128,   96]], dtype=torch.int32)

In [None]:
# Hard-coded int32 vector; the values match the quantized bias printed earlier
# in the notebook.
b = torch.tensor(
    [9031, -9031, 5419, 3612, -3612, 1806, 7225, -7225, 3612, 10838],
    dtype=torch.int32,
)