### **Quantization steps**

In [43]:
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from tinynn.graph.quantization.quantizer import QATQuantizer
from torchvision.transforms import transforms
from src.models.unet_module import UNETLitModule
from src.models.components.depth_net_efficient_ffn import DepthNet
from src.data.components.nyu_dataset import NYUDataset
from torch.utils.data import DataLoader, Dataset, random_split
from src.data.components.custom_transforms import BilinearInterpolation, NormalizeData

In [61]:
model_ckpt = "logs/train/runs/2024-05-31_03-13-02/checkpoints/last.ckpt"

In [62]:
model = UNETLitModule.load_from_checkpoint(model_ckpt)

/net/tscratch/people/plgkzaleska/envs/ennca-project/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.


In [63]:
model.net

DepthNet(
  (encoder): EfficientNet(
    (layer1): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activa

### **Fuse BatchNorm**

In [56]:
model_fuse = quantize_fx.fuse_fx(model.eval())

In [57]:
model_fuse

GraphModule(
  (net): Module(
    (encoder): Module(
      (layer1): Module(
        (0): Module(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (2): SiLU(inplace=True)
        )
        (1): Module(
          (0): Module(
            (block): Module(
              (0): Module(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
                (2): SiLU(inplace=True)
              )
              (1): Module(
                (avgpool): AdaptiveAvgPool2d(output_size=1)
                (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
                (activation): SiLU(inplace=True)
                (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
                (scale_activation): Sigmoid()
              )
              (2): Module(
                (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
              )
            )
          )
        )
        (2): Module(
          (

### **PTQ**

In [64]:
model.eval()

UNETLitModule(
  (net): DepthNet(
    (encoder): EfficientNet(
      (layer1): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
        )
        (1): Sequential(
          (0): MBConv(
            (block): Sequential(
              (0): Conv2dNormActivation(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
                (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): SiLU(inplace=True)
              )
              (1): SqueezeExcitation(
                (avgpool): AdaptiveAvgPool2d(output_size=1)
                (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
                (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            

In [41]:
transforms_img = transforms.Compose(
    [transforms.PILToTensor(), transforms.Resize((224, 224))]
)

transforms_mask_train = transforms.Compose(
    [transforms.ToTensor(), BilinearInterpolation((56, 56))]
)

In [44]:
trainset = NYUDataset(
    "nyu2_train.csv",
    "data/",
    transform=transforms_img,
    target_transform=transforms_mask_train
)

data_train, data_val = random_split(
    dataset=trainset,
    lengths=[0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

val_dataloader = DataLoader(
    dataset=data_val,
    batch_size=32,
    num_workers=2
)

In [49]:
def calibration(model, num_iterations, val_dataloader):
    count = 0
    for data in val_dataloader:
        img, mask = data

        if torch.cuda.is_available():
            img = img.cuda()
            mask = mask.cuda()
        model(img)

        count += 1

        if count >= num_iterations:
            break

    return model

In [65]:
quantizer_per_tensor = QATQuantizer(
    model,
    torch.randn(1,3,52,52),
    work_dir='quant_output',
    config={
        'asymmetric': True,
        'backend': 'qnnpack',
        "disable_requantization_for_cat": True,
        'per_tensor': True,
})

quantizer_per_channel = QATQuantizer(
    model,
    torch.randn(1,3,52,52),
    work_dir='quant_output',
    config={
        'asymmetric': True,
        'backend': 'qnnpack',
        "disable_requantization_for_cat": True,
        'per_tensor': False,
})

In [66]:
ptq_model_with_quantizer_tensor = quantizer_per_tensor.quantize()
ptq_model_with_quantizer_channel = quantizer_per_channel.quantize()

In [67]:
ptq_model_with_quantizer_tensor.to("cuda")
ptq_model_with_quantizer_channel.to("cuda")

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-127, quant_max=127, dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale

In [68]:
# post quantization calibration
ptq_model_with_quantizer_tensor.apply(torch.quantization.disable_fake_quant)
ptq_model_with_quantizer_tensor.apply(torch.quantization.enable_observer)
ptq_model_with_quantizer_tensor = calibration(ptq_model_with_quantizer_tensor, 50, val_dataloader)

ptq_model_with_quantizer_channel.apply(torch.quantization.disable_fake_quant)
ptq_model_with_quantizer_channel.apply(torch.quantization.enable_observer)
ptq_model_with_quantizer_channel = calibration(ptq_model_with_quantizer_tensor, 50, val_dataloader)

In [69]:
# disable observer and enable fake quantization to validate model with quantization error
ptq_model_with_quantizer_tensor.apply(torch.quantization.disable_observer)
ptq_model_with_quantizer_tensor.apply(torch.quantization.enable_fake_quant)
# ptq_model_with_quantizer_tensor(next(iter(val_dataloader))[0].to("cuda"))

ptq_model_with_quantizer_channel.apply(torch.quantization.disable_observer)
ptq_model_with_quantizer_channel.apply(torch.quantization.enable_fake_quant)
# ptq_model_with_quantizer_channel(next(iter(val_dataloader))[0].to("cuda"))

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([0], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([0], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, sc

### **QAT**

In [73]:
quantizer_per_tensor = QATQuantizer(
    model,
    torch.randn(1,3,52,52),
    work_dir='quant_output',
    config={
        'asymmetric': True,
        'backend': 'qnnpack',
        "disable_requantization_for_cat": True,
        'per_tensor': True,
})

In [74]:
qat_model = quantizer_per_tensor.quantize()

In [75]:
qat_model = calibration(qat_model, 50, val_dataloader)

In [76]:
qat_model.to("cuda")

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, sc

In [77]:
qat_model.train()

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, sc

In [78]:
qat_model.apply(torch.quantization.enable_fake_quant)
qat_model.apply(torch.quantization.enable_observer)

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, sc

In [79]:
# train model here

In [80]:
# validate the model with quantization error via fake quantization
qat_model.apply(torch.quantization.disable_observer)
# validate here

QUNETLitModule(
  (fake_quant_0): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([0], device='cuda:0', dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([0.0039], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32)
      (activation_post_process): MovingAverageMinMaxObserver(min_val=0.0, max_val=1.0)
    )
  )
  (net_encoder_layer1_0_0): ConvBn2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0', dtype=torch.uint8), observer_enabled=tensor([0], device='cuda:0', dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, sc