In [1]:
import torch.ao.quantization
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

import ai_edge_torch
from ai_edge_torch.debug import find_culprits
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

from gorillatracker.datasets.cxl import CXLDataset
from gorillatracker.model import BaseModule
from gorillatracker.quantization.utils import get_model_input
from gorillatracker.utils.embedding_generator import get_model_for_run_url

# Configuration.
# NOTE(review): save_quantized_model, load_quantized_model and
# save_model_architecture are not read by any cell in this notebook
# (the export cell writes the .tflite file unconditionally).
save_quantized_model = False
load_quantized_model = False
save_model_architecture = False
# Number of "train"-partition samples fed through the prepared model for calibration.
number_of_calibration_images = 100
dataset_path = "/workspaces/gorillatracker/data/splits/ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25"
# W&B run whose model checkpoint is downloaded and quantized.
model_wandb_url = (
    "https://wandb.ai/gorillas/Embedding-EfficientNet-CXL-OpenSet/runs/famq71r6/workspace?nw=nwuserkajohpi"
)


# 1. Quantization

# Calibration data from the "train" partition. Despite the variable name these
# are model *inputs* (they are passed straight to the model below); the shape
# check later in this notebook shows each element is a (3, 224, 224) tensor.
calibration_input_embeddings, _ = get_model_input(
    CXLDataset, dataset_path=dataset_path, partion="train", amount_of_tensors=number_of_calibration_images
)
model: BaseModule = get_model_for_run_url(model_wandb_url)
model = model.eval()

# Symmetric, per-channel quantization; is_dynamic=True requests dynamic
# (runtime) quantization of activations.
pt2e_quantizer = PT2EQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))

# Capture the float model as a graph module and insert observers per the quantizer.
# NOTE(review): capture_pre_autograd_graph is deprecated (it logs a warning
# recommending torch.export instead).
pt2e_torch_model = capture_pre_autograd_graph(model, (calibration_input_embeddings,))
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)
print("Prepared")

# Calibration: one forward pass so the inserted observers record value ranges.
# The output is discarded and no gradients are needed, so avoid building an
# autograd graph for the entire calibration batch.
# NOTE(review): with is_dynamic=True activation ranges are presumably computed
# at runtime, so this pass may not affect the result — confirm.
with torch.no_grad():
    pt2e_torch_model(calibration_input_embeddings)

# Convert the calibrated (observer-instrumented) model into the quantized model.
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)
# Presumably re-enables .train()/.eval() on the exported graph module (per the function name).
torch.ao.quantization.allow_exported_model_train_eval(pt2e_torch_model)
print("Converted")

I0000 00:00:1717060154.188171  752003 service.cc:145] XLA service 0x55749cb8bcc0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1717060154.188219  752003 service.cc:153]   StreamExecutor device (0): NVIDIA H100 80GB HBM3, Compute Capability 9.0
I0000 00:00:1717060154.188875  752003 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1717060154.188915  752003 gpu_helpers.cc:114] XLA backend allocating 63707234304 bytes on device 0 for BFCAllocator.
I0000 00:00:1717060154.188944  752003 gpu_helpers.cc:154] XLA backend will use up to 21235744768 bytes on device 0 for CollectiveBFCAllocator.


ParseResult(scheme='https', netloc='wandb.ai', path='/gorillas/Embedding-EfficientNet-CXL-OpenSet/runs/famq71r6/workspace', params='', query='nw=nwuserkajohpi', fragment='') ['', 'gorillas', 'Embedding-EfficientNet-CXL-OpenSet', 'runs', 'famq71r6', 'workspace'] /gorillas/Embedding-EfficientNet-CXL-OpenSet/runs/famq71r6/workspace
Using model from run: 261-add-the-ability-to-do-quantization-using-pytorch-2-export-quantization-2024-05-21-07-13-45
Config: {'s': 64, 'seed': 42, 'beta1': 0.9, 'beta2': 0.999, 'debug': False, 'kfold': False, 'end_lr': 1e-06, 'margin': 1, 'resume': False, 'compile': False, 'delta_t': 50, 'epsilon': 1e-07, 'l2_beta': 0.01, 'offline': False, 'plugins': None, 'use_ssl': False, 'workers': 16, 'data_dir': '/workspaces/gorillatracker/data/splits/ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25', 'l2_alpha': 0.1, 'only_val': False, 'profiler': None, 'run_name': '261-add-the-ability-to-do-quantization-using-pytorch-

[34m[1mwandb[0m: Downloading large artifact model-famq71r6:v0, 1348.64MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.3
W0530 09:09:21.580000 140652183078720 torch/_export/__init__.py:97] capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.
W0530 09:09:21.580000 140652183078720 torch/_export/__init__.py:98] Please switch to use torch.export instead.


Prepared
Converted


In [10]:
# Shape of a single calibration sample with a leading batch dimension added —
# the same expression used as the sample input for ai_edge_torch.convert below.
calibration_input_embeddings[0].unsqueeze(0).shape

torch.Size([1, 3, 224, 224])

In [5]:
# Disabled experiment, kept for reference: check whether the converted PT2E
# model is treated as an exported model, then re-export it with torch.export.
# NOTE(review): this cell executes nothing; delete it or turn it into a
# markdown note once the investigation is finished.
# import torch.ao.quantization.pt2e.export_utils
# torch.ao.quantization.pt2e.export_utils.model_is_exported(pt2e_torch_model)

# quantized_ep = torch.export.export(pt2e_torch_model, (calibration_input_embeddings,))



In [11]:
# 2. Conversion to TFLite
# (ai_edge_torch is imported in the top import cell.)

# Sample input for conversion: the first calibration image with a leading
# batch dimension of 1.
# NOTE(review): presumably this fixes the TFLite model's input shape to batch
# size 1, while the graph was captured with the full calibration batch —
# confirm this is intended.
conversion_sample_args = (calibration_input_embeddings[0].unsqueeze(0),)
tflite_output_path = "quantized_model.tflite"

# Reuse the same quantizer that was configured for prepare_pt2e.
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model,
    conversion_sample_args,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer),
)
# NOTE(review): the save_quantized_model flag from the config cell is not
# consulted; the file is always written.
pt2e_drq_model.export(tflite_output_path)

W0000 00:00:1717061088.589067  752003 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.
W0000 00:00:1717061088.589090  752003 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.
