# Selective Quantization

In [1]:
import torch
from torch import nn

class LeNet(nn.Module):
    """Toy classifier: flatten the input, then apply ``Linear(784, 10)`` and ReLU.

    Despite the name, this is not the classic LeNet-5 (it has no convolutions);
    it is a minimal module with two children (``l1`` and ``relu1``) used to show
    how per-layer qconfig assignment interacts with ``prepare_qat``.
    """

    def __init__(self) -> None:
        super().__init__()
        # 28 * 28 = 784 input features — presumably one flattened 28x28 image per sample.
        self.l1 = nn.Linear(28 * 28, 10)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all dims after the batch dim, then apply ``l1`` and ``relu1``.

        Args:
            x: Tensor of shape ``(batch, ...)`` whose non-batch dims multiply to
                784, e.g. ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` with non-negative entries.
        """
        # ``torch.flatten`` (unlike ``view``) accepts non-contiguous inputs and
        # empty batches, where ``view(0, -1)`` cannot infer the -1 dimension.
        return self.relu1(self.l1(torch.flatten(x, start_dim=1)))

## Selective qconfig assignment and top-level transform

In [2]:
# Only ``l1`` is given a QAT qconfig; ``relu1`` is not assigned one.
qat_qconfig = torch.ao.quantization.get_default_qat_qconfig()
model = LeNet()
model.l1.qconfig = qat_qconfig
# ``model`` is the root passed to prepare_qat, so its child ``l1`` is swapped
# for the QAT Linear and gets an output observer attached.
torch.ao.quantization.prepare_qat(model, inplace=True)
print(model)

LeNet(
  (l1): Linear(
    in_features=784, out_features=10, bias=True
    (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (relu1): ReLU(inplace=True)
)




In [7]:
# ``l1`` was a child of the module passed to prepare_qat, so this should be the QAT Linear.
print(model.l1.__class__)

<class 'torch.ao.nn.qat.modules.linear.Linear'>


## Selective qconfig assignment and selective transform

In [3]:
# Same qconfig assignment as the top-level case, but prepare_qat receives ``l1`` itself.
model2 = LeNet()
l1_root = model2.l1
l1_root.qconfig = torch.ao.quantization.get_default_qat_qconfig()
# ``l1_root`` is the root of this transform: only an output observer
# (``activation_post_process``) is attached to it in place.
torch.ao.quantization.prepare_qat(l1_root, inplace=True)
print(model2)

LeNet(
  (l1): Linear(
    in_features=784, out_features=10, bias=True
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (relu1): ReLU(inplace=True)
)




In [6]:
# ``l1`` was itself the root passed to prepare_qat, so it should still be the plain nn.Linear.
print(model2.l1.__class__)

<class 'torch.nn.modules.linear.Linear'>


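`prepare_qat()` calls `convert()` to swap modules for their QAT counterparts, but `convert()` only swaps the *children* of the module it is given, never that module itself. Printing the type of `l1` in each case shows the difference:

- **Case 1:** `l1` is a child of `model`, so it is swapped for `<class 'torch.ao.nn.qat.modules.linear.Linear'>`, which has a `weight_fake_quant` attribute.
- **Case 2:** `model2.l1` is the root module passed to `prepare_qat()`, so it is not swapped. It keeps the type `<class 'torch.nn.modules.linear.Linear'>`, which has no `weight_fake_quant` attribute. Only the output observer (`activation_post_process`) is attached, so its activations are fake-quantized but its weights are not.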

## Exclude layers from quantization

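Set `model.layer.qconfig = None` to exclude a layer from quantization. Because qconfigs propagate from parent to child, all of its children are excluded as well, unless a child is explicitly given its own `qconfig`.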