### **Quantization steps**

In [None]:
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from tinynn.graph.quantization.quantizer import QATQuantizer
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import transforms

from src.data.components.custom_transforms import BilinearInterpolation
from src.data.components.nyu_dataset import NYUDataset
from src.models.unet_module import UNETLitModule

In [None]:
# Checkpoint file loaded by the cells below (path is relative to the working directory).
model_ckpt = "logs/train/runs/2024-06-02_01-46-33/checkpoints/epoch_000.ckpt"

In [None]:
# Load the raw checkpoint object so it can be inspected interactively.
# NOTE(review): `checkpoint` is not referenced by any other visible cell.
# NOTE(security): torch.load unpickles the file; only open checkpoints from a trusted source.
checkpoint = torch.load(model_ckpt)

In [None]:
# Rebuild the UNETLitModule from the checkpoint file; this is the model every quantizer below receives.
model = UNETLitModule.load_from_checkpoint(model_ckpt)

In [None]:
# Echo the `net` attribute of the loaded module (presumably the wrapped UNet — TODO confirm).
model.net

### **Fuse BatchNorm**

In [None]:
# fuse_fx expects a model in eval mode; `model.eval()` switches `model` in place and
# returns it, so it can be passed inline.
# NOTE(review): the whole UNETLitModule is fused here rather than `model.net` — confirm intended.
model_fuse = quantize_fx.fuse_fx(model.eval())

In [None]:
# Echo the fused module to inspect the result of the fusion pass.
model_fuse

### **PTQ**

In [None]:
# Put the original (unfused) model in evaluation mode before it is handed to the quantizers.
model.eval()

In [None]:
# Image pipeline: PIL image -> tensor, then resize to 224x224.
transforms_img = transforms.Compose([transforms.PILToTensor(), transforms.Resize((224, 224))])

# Target pipeline: convert to a tensor, then resample to 56x56 with the project's
# BilinearInterpolation transform.
# NOTE(review): named "*_train" but used for the validation split below; the 56x56 target
# size presumably matches the network's output resolution — TODO confirm.
transforms_mask_train = transforms.Compose(
    [transforms.ToTensor(), BilinearInterpolation((56, 56))]
)

In [None]:
# Dataset built from "nyu2_train.csv" and "data/" (presumably the CSV index and the data
# root — TODO confirm against NYUDataset), with the image/target pipelines attached.
trainset = NYUDataset(
    "nyu2_train.csv",
    "data/",
    transform=transforms_img,
    target_transform=transforms_mask_train,
)

# 80/20 split driven by a generator seeded with 42, so the split is reproducible.
data_train, data_val = random_split(
    trainset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

# Only the 20% subset is wrapped in a loader; it supplies the calibration batches.
val_dataloader = DataLoader(data_val, batch_size=32, num_workers=2)

In [None]:
def calibration(model, num_iterations, val_dataloader):
    """Run forward passes over at most ``num_iterations`` batches and return ``model``.

    Only the forward pass is executed and its output is discarded, so the call is
    useful purely for its side effects on the model (e.g. observers collecting
    statistics).

    Args:
        model: Callable applied to each image batch. When it exposes
            ``parameters()``, batches are moved to the device of its first
            parameter; otherwise they go to CUDA if available, else stay on CPU.
        num_iterations: Maximum number of batches to feed. Values ``<= 0`` run
            no batches at all.
        val_dataloader: Iterable yielding ``(img, mask)`` pairs; the mask is ignored.

    Returns:
        The same ``model`` object that was passed in.
    """
    if num_iterations <= 0:
        return model

    # Follow the model's device so the inputs never end up on a different device
    # than the weights (e.g. CUDA is available but the model is still on the CPU).
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No gradients are needed: the outputs are thrown away.
    with torch.no_grad():
        for step, (img, _mask) in enumerate(val_dataloader, start=1):
            model(img.to(device))

            if step >= num_iterations:
                break

    return model

In [None]:
# Two quantizers over the same model that differ only in the "per_tensor" flag.
# The generator builds them in order (per-tensor first, then per-channel), each with
# its own random dummy input.
# NOTE(review): the dummy input is 52x52 while the image transform resizes to 224x224,
# and both quantizers share work_dir="quant_output" — confirm both are intended.
quantizer_per_tensor, quantizer_per_channel = (
    QATQuantizer(
        model,
        torch.randn(1, 3, 52, 52),
        work_dir="quant_output",
        config={
            "asymmetric": True,
            "backend": "qnnpack",
            "disable_requantization_for_cat": True,
            "per_tensor": use_per_tensor,
        },
    )
    for use_per_tensor in (True, False)
)

In [None]:
# Ask each quantizer for its quantization-ready model (presumably a rewritten copy of
# `model` with observers / fake-quant modules inserted — TODO confirm against TinyNN).
ptq_model_with_quantizer_tensor = quantizer_per_tensor.quantize()
ptq_model_with_quantizer_channel = quantizer_per_channel.quantize()

In [None]:
# Use the GPU when one is present and fall back to the CPU otherwise, instead of
# hard-coding "cuda" (which raises on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
ptq_model_with_quantizer_tensor.to(device)
ptq_model_with_quantizer_channel.to(device)

In [None]:
# Post-training calibration: fake quantization is switched off and observers are
# switched on, so the forward passes below only gather statistics.
ptq_model_with_quantizer_tensor.apply(torch.quantization.disable_fake_quant)
ptq_model_with_quantizer_tensor.apply(torch.quantization.enable_observer)
ptq_model_with_quantizer_tensor = calibration(ptq_model_with_quantizer_tensor, 50, val_dataloader)

ptq_model_with_quantizer_channel.apply(torch.quantization.disable_fake_quant)
ptq_model_with_quantizer_channel.apply(torch.quantization.enable_observer)
# Calibrate the per-channel model itself. Passing the per-tensor model here would leave
# the per-channel model uncalibrated and rebind this name to the per-tensor model.
ptq_model_with_quantizer_channel = calibration(ptq_model_with_quantizer_channel, 50, val_dataloader)

In [None]:
# disable observer and enable fake quantization to validate model with quantization error
# (observers stop collecting statistics; fake-quant modules start simulating quantization)
ptq_model_with_quantizer_tensor.apply(torch.quantization.disable_observer)
ptq_model_with_quantizer_tensor.apply(torch.quantization.enable_fake_quant)
# Optional smoke test on a single validation batch:
# ptq_model_with_quantizer_tensor(next(iter(val_dataloader))[0].to("cuda"))

ptq_model_with_quantizer_channel.apply(torch.quantization.disable_observer)
ptq_model_with_quantizer_channel.apply(torch.quantization.enable_fake_quant)
# Optional smoke test on a single validation batch:
# ptq_model_with_quantizer_channel(next(iter(val_dataloader))[0].to("cuda"))

### **QAT**

In [None]:
# Fresh per-tensor quantizer for the QAT flow, with the same settings as the per-tensor
# PTQ quantizer. This rebinds `quantizer_per_tensor`, replacing the earlier instance.
# NOTE(review): the dummy input is 52x52 while the image transform resizes to 224x224 —
# confirm this is intended.
quantizer_per_tensor = QATQuantizer(
    model,
    torch.randn(1, 3, 52, 52),
    work_dir="quant_output",
    config={
        "asymmetric": True,
        "backend": "qnnpack",
        "disable_requantization_for_cat": True,
        "per_tensor": True,
    },
)

In [None]:
# Model to be fine-tuned with quantization-aware training (presumably a rewritten copy of
# `model` with fake-quant modules inserted — TODO confirm against TinyNN).
qat_model = quantizer_per_tensor.quantize()

In [None]:
# Move the model to its target device *before* calibrating, as the PTQ flow does.
# Calibrating first would feed CUDA batches to a model that is still on the CPU.
qat_model.to("cuda" if torch.cuda.is_available() else "cpu")
qat_model = calibration(qat_model, 50, val_dataloader)

In [None]:
# Use the GPU when one is present and fall back to the CPU otherwise, instead of
# hard-coding "cuda" (which raises on CPU-only machines).
qat_model.to("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Switch to training mode ahead of the QAT fine-tuning step.
qat_model.train()

In [None]:
# For QAT both switches are on: fake-quant modules simulate quantization in the forward
# pass while observers keep collecting statistics.
qat_model.apply(torch.quantization.enable_fake_quant)
qat_model.apply(torch.quantization.enable_observer)

In [None]:
# train model here

In [None]:
# validate the model with quantization error via fake quantization
qat_model.apply(torch.quantization.disable_observer)
# validate here

In [None]:
import os


def check_saved_pytorch_model_size(filepath):
    """Return the size of ``filepath`` in mebibytes, or ``None`` if it is not a regular file.

    Args:
        filepath: Path of the saved model file to measure.

    Returns:
        The file size in bytes divided by 1024 * 1024 as a float, or ``None`` when
        ``filepath`` does not point to an existing regular file (missing path,
        directory, ...).
    """
    if not os.path.isfile(filepath):
        return None

    bytes_per_mib = 1024 * 1024
    return os.path.getsize(filepath) / bytes_per_mib

In [None]:
# Size in MiB of "ptq_tensor.pty", or None if that file does not exist.
# NOTE(review): no visible cell writes this file — it must be produced elsewhere.
check_saved_pytorch_model_size("ptq_tensor.pty")

In [None]:
# Size in MiB of "ptq.pty", or None if that file does not exist.
# NOTE(review): no visible cell writes this file — it must be produced elsewhere.
check_saved_pytorch_model_size("ptq.pty")

In [None]:
# Size in MiB of "qat.pty", or None if that file does not exist.
# NOTE(review): no visible cell writes this file — it must be produced elsewhere.
check_saved_pytorch_model_size("qat.pty")