In [1]:
import os

import torch
import torch.nn as nn
import torch.quantization as quantization
import torch.quantization._numeric_suite as ns
from torch.ao.quantization import (
    default_eval_fn,
    default_qconfig,
    quantize,
)
from torch.ao.quantization.qconfig import QConfig

import torchvision
import torchvision.transforms as transforms
from torchvision import models, datasets

from ultralytics import YOLO
from ultralytics.nn.modules import Conv, Detect
from ultralytics.nn.modules_quantized import Q_Conv
from ultralytics.yolo.utils import yaml_load, LOGGER, RANK

import onnx
from thop import profile

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Float reference model: quantizable ResNet18 with ImageNet weights, kept un-quantized.
float_model = torchvision.models.quantization.resnet18(
    weights=models.ResNet18_Weights.IMAGENET1K_V1,
    quantize=False,
)
# .to() and .eval() both return the module itself, so they can be chained.
float_model.to('cpu').eval()
float_model.fuse_model()
float_model.qconfig = default_qconfig

# Two batches of random (image, label) pairs passed to the calibration function.
img_data = [
    (
        torch.rand(2, 3, 10, 10, dtype=torch.float),
        torch.randint(0, 1, (2,), dtype=torch.long),
    )
    for _ in range(2)
]

# Eager-mode post-training quantization; inplace=False leaves float_model untouched.
qmodel = quantize(
    model=float_model,
    run_fn=default_eval_fn,
    run_args=[img_data],
    inplace=False,
)
print(qmodel)

QuantizableResNet(
  (conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.00818648561835289, zero_point=0, padding=(3, 3))
  (bn1): Identity()
  (relu): Identity()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): QuantizableBasicBlock(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.006201634649187326, zero_point=0, padding=(1, 1))
      (bn1): Identity()
      (relu): Identity()
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.03158609941601753, zero_point=51, padding=(1, 1))
      (bn2): Identity()
      (add_relu): QFunctional(
        scale=0.023167304694652557, zero_point=0
        (activation_post_process): Identity()
      )
    )
    (1): QuantizableBasicBlock(
      (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.013521037064492702, zero_point=0, padding=(1, 1))
      (bn1): Iden

In [3]:
# Pair up the weights of the float model and its quantized counterpart by key.
wt_compare_dict = ns.compare_weights(float_model.state_dict(), qmodel.state_dict())

print('keys of wt_compare_dict:', wt_compare_dict.keys(), sep='\n')

# Inspect a single entry: it holds the 'float' and 'quantized' tensors for conv1.
print("\nkeys of wt_compare_dict entry for conv1's weight:")
print(wt_compare_dict['conv1.weight'].keys())
print(
    *(wt_compare_dict['conv1.weight'][side].shape for side in ('float', 'quantized')),
    sep='\n',
)

keys of wt_compare_dict:
dict_keys(['conv1.weight', 'layer1.0.conv1.weight', 'layer1.0.conv2.weight', 'layer1.1.conv1.weight', 'layer1.1.conv2.weight', 'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.downsample.0.weight', 'layer2.1.conv1.weight', 'layer2.1.conv2.weight', 'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.0.downsample.0.weight', 'layer3.1.conv1.weight', 'layer3.1.conv2.weight', 'layer4.0.conv1.weight', 'layer4.0.conv2.weight', 'layer4.0.downsample.0.weight', 'layer4.1.conv1.weight', 'layer4.1.conv2.weight', 'fc._packed_params._packed_params'])

keys of wt_compare_dict entry for conv1's weight:
dict_keys(['float', 'quantized'])
torch.Size([64, 3, 7, 7])
torch.Size([64, 3, 7, 7])


In [4]:
def transfer_weights_qconv(conv, qconv):
    """Copy conv/bn tensors and plain attributes from ``conv`` onto ``qconv``.

    The convolution weight and the four batch-norm entries (``weight``,
    ``bias``, ``running_mean``, ``running_var``) are taken from ``conv``'s
    state dict and loaded into ``qconv``. Every non-callable attribute of
    ``conv`` whose name contains no underscore is also assigned onto
    ``qconv``.

    Args:
        conv: Source module exposing ``conv`` and ``bn`` sub-modules.
        qconv: Destination module with matching ``conv.*`` / ``bn.*`` keys.
    """
    source_state = conv.state_dict()
    target_state = qconv.state_dict()

    # Only these entries are overwritten; everything else keeps qconv's own value.
    copied_keys = ['conv.weight'] + [
        f'bn.{suffix}' for suffix in ('weight', 'bias', 'running_mean', 'running_var')
    ]
    for state_key in copied_keys:
        target_state[state_key] = source_state[state_key]

    # Mirror simple data attributes; callables (methods, sub-modules) and any
    # name containing an underscore are skipped.
    for attr_name in dir(conv):
        attr_value = getattr(conv, attr_name)
        if callable(attr_value) or '_' in attr_name:
            continue
        setattr(qconv, attr_name, attr_value)

    qconv.load_state_dict(target_state)
    
def replace_conv_with_qconv_v2_ptq(module, qconfig=None):
    """Recursively swap every ``Conv`` child of ``module`` for a fused ``Q_Conv``.

    Each ``Conv`` child is rebuilt as a ``Q_Conv`` with the same convolution
    hyper-parameters, tagged with ``qconfig``, installed in place of the
    original, filled with the original weights via ``transfer_weights_qconv``
    and then fused (conv+bn, plus the activation when it is ``nn.ReLU``).
    ``Detect`` children are skipped entirely; any other child is recursed into.

    Args:
        module: Root ``nn.Module``; modified in place.
        qconfig: ``QConfig`` assigned to every replacement module. Defaults to
            ``torch.quantization.default_qconfig``.

    Returns:
        None.
    """
    if qconfig is None:
        # QConfig only has `activation` and `weight` fields (no `qscheme`), and
        # torch.quantization has no `QuantStubConfig`, so a stock config is used.
        qconfig = torch.quantization.default_qconfig

    # Snapshot the children because entries are replaced while walking them.
    for name, child_module in list(module.named_children()):
        if isinstance(child_module, Detect):
            continue
        if not isinstance(child_module, Conv):
            replace_conv_with_qconv_v2_ptq(child_module, qconfig)
            continue

        # Rebuild the Conv as a Q_Conv with identical convolution hyper-parameters.
        conv2d = child_module.conv
        act = child_module.act
        qconv = Q_Conv(
            conv2d.in_channels,
            conv2d.out_channels,
            conv2d.kernel_size,
            conv2d.stride,
            p=conv2d.padding,
            g=conv2d.groups,
            d=conv2d.dilation[0],
            act=act,
        )
        qconv.qconfig = qconfig
        setattr(module, name, qconv)
        transfer_weights_qconv(child_module, qconv)

        # Fusion is done on the eval-mode module (post-training path).
        qconv.eval()
        if isinstance(act, nn.ReLU):
            torch.quantization.fuse_modules(qconv, [['conv', 'bn', 'act']], inplace=True)
        else:
            torch.quantization.fuse_modules(qconv, [['conv', 'bn']], inplace=True)
            # NOTE(review): `forward` is not defined in the visible cells; it must
            # exist in the notebook namespace before a non-ReLU Conv is processed.
            qconv.forward = forward.__get__(qconv)

In [83]:
# Load the exported TorchScript checkpoint and switch it to inference mode.
# torch.jit.load always returns a ScriptModule, so no isinstance check is needed,
# and nothing is "converted" to a plain nn.Module: .eval() returns the same object.
script_model = torch.jit.load("asset/trained_model/UA-DETRAC/v8s_relu_DETRAC.torchscript")
model = script_model.eval()

# Re-trace the loaded module with a random (1, 3, 640, 640) example input.
model = torch.jit.trace(model, torch.randn(1, 3, 640, 640))

# --- Disabled experiment, kept for reference ---
# model2 = YOLO('yolov8s_relu.pt')
# print(model2.model.model)

# replace_conv_with_qconv_v2_ptq(model2.model)
# torch.quantization.prepare(model2.model, inplace=True)
# torch.quantization.convert(model2.model, inplace=True)

# example_input = torch.randn(1, 3, 640, 640)
# model2.model(example_input)
# # scripted_model = torch.jit.script(model2.model)
# traced_model = torch.jit.trace(model2.model, example_input, strict=False)

# traced_model.save('tmp.torchscript')
# float_model = torch.jit.load('tmp.torchscript')

# Dump the state-dict keys of the traced model, one per line.
with open("yolo_model.txt", "w") as file:
    file.writelines(f"{key}\n" for key in model.state_dict().keys())

# wt_compare_dict = ns.compare_weights(model.state_dict(), float_model.state_dict())

# print('keys of wt_compare_dict:')
# print(wt_compare_dict.keys())



