In [None]:
#| default_exp quantize.quantizer

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

In [None]:
#| include: false
from nbdev.showdoc import *
from fastai.vision.all import *
import warnings
warnings.filterwarnings('ignore')

In [None]:
#| export
class Quantizer():
    "Post-training static quantization of a model using PyTorch's FX graph mode."
    def __init__(self, backend="x86"):
        "Store the default qconfig mapping for `backend` (passed to `get_default_qconfig_mapping`)."
        self.qconfig = get_default_qconfig_mapping(backend)

    def quantize(self, model, calibration_dl):
        """Prepare `model`, calibrate it on `calibration_dl.valid`, and return the converted model.

        `model` is switched to eval mode (in place) before being prepared.
        `calibration_dl` must expose a `.valid` loader that provides `one_batch()`
        and iterates over `(inputs, targets)` pairs; only the inputs are used.
        Calibration runs on CPU, so every batch is moved there first.
        """
        x, _ = calibration_dl.valid.one_batch()
        # `prepare_fx` documents `example_inputs` as a tuple of positional args for `forward`.
        model_prepared = prepare_fx(model.eval(), self.qconfig, (x.to('cpu'),))
        # Calibration only needs the observers to see the activations: skip autograd
        # bookkeeping and discard each output instead of keeping them all alive in a list.
        with torch.no_grad():
            for xb, _ in calibration_dl.valid:
                model_prepared(xb.to('cpu'))

        return convert_fx(model_prepared)

In [None]:
# Render the API documentation entry for `Quantizer` in the generated docs.
show_doc(Quantizer)

---

### Quantizer

>      Quantizer (backend='x86')

Initialize self.  See help(type(self)) for accurate signature.

In [None]:
# Fetch the dataset referenced by `URLs.PETS` and collect the image files under its `images/` folder.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f):
    "Binary label for an item: `True` when the first character of `f` is upper-case."
    # NOTE(review): `f` is presumably the image file name supplied by `from_name_func` — confirm.
    first_char = f[0]
    return first_char.isupper()

# Build the DataLoaders used to calibrate the quantizer below; items are labelled by
# `label_func` and resized with `Resize(64)`.
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

In [None]:
# NOTE(review): this import would normally live in the notebook's top import cell.
import timm
# ResNet-34 from timm, created with `pretrained=True`.
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)
qt = Quantizer()

# Calibrate on `dls.valid`, convert, and display the resulting quantized module.
q_model = qt.quantize(pretrained_resnet_34, dls); q_model

GraphModule(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.057870179414749146, zero_point=0, padding=(3, 3))
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.06179201602935791, zero_point=77, padding=(1, 1))
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.18166492879390717, zero_point=39, padding=(1, 1))
    )
    (1): Module(
      (conv1): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.12881271541118622, zero_point=68, padding=(1, 1))
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.21964868903160095, zero_point=87, padding=(1, 1))
    )


In [None]:
# Second example: quantize a ResNet-18 whose classifier head has 2 outputs.
# `model` has to be defined in this cell: no other cell in the notebook creates it,
# so leaving these lines commented out makes the cell fail on a fresh kernel.
model = resnet18()
model.fc = nn.Linear(512, 2)

qt = Quantizer()

# Calibrate on `dls.valid`, convert, and display the resulting quantized module.
q_model = qt.quantize(model, dls); q_model

KeyboardInterrupt: 