In [1]:
import torch, time
import torch.nn as nn
import torch.optim as optim

from src.utils import *
from src.override_resnet import *


class Args:
    arch = 50
    dataset = "ImageNet"
    # dataset = "CIFAR100"
    lr = 0.001
    momentum = 0.9
    batch = 16
    epochs = 10
    save_every = 1
    quan = "static"
    only_eval = True
    verbose = True


args = Args()

In [2]:
def run_benchmark(model, img_loader):
    elapsed = 0
    model.eval()
    num_batches = 1
    # 이미지 배치들 이용하여 스크립트된 모델 실행
    for i, (images, target) in enumerate(img_loader):
        if i < num_batches:
            start = time.time()
            output = model(images)
            end = time.time()
            elapsed = elapsed + (end - start)
        else:
            break
    num_images = images.size()[0] * num_batches

    print("Elapsed time: %3.0f ms" % (elapsed / num_images * 1000))
    return elapsed

In [3]:
def fuse_model(model) -> nn.Module:
    flag = False
    for m in model.modules():
        if m.__class__.__name__ == ResNet_quan.__name__:
            if flag == True:
                raise ValueError("ResNet_quan is already fused")
            flag = True
            torch.quantization.fuse_modules(
                m,
                ["conv1", "bn1", "relu"],
                inplace=True,
            )

        if type(m) == BottleNeck_quan:
            torch.quantization.fuse_modules(
                m,
                [
                    ["conv1", "bn1", "relu1"],
                    ["conv2", "bn2", "relu2"],
                    ["conv3", "bn3"],
                ],
                inplace=True,
            )
            if m.downsample is not None:
                torch.quantization.fuse_modules(
                    m.downsample,
                    ["0", "1"],
                    inplace=True,
                )
    return model

In [4]:
# %% my code

args = Args()
# %% Load the ResNet-50 model
if args.quan == "fp32":
    # case 0 : no quantization case
    print("----------No quantization enabled")
    device = str(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    model = layers_mapping[args.arch](
        weights=pretrained_weights_mapping[args.arch]
    ).to(device)

elif args.quan == "dynamic":
    # case 1 : Dynamic Quantization
    print("----------Dynamic Quantization enabled")
    device = "cuda"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    model = quantized_model

elif args.quan == "static":
    # case 2 : Static Quantization
    print("----------Static Quantization enabled")
    device = "cpu"
    model = resnet50_quan(weights=pretrained_weights_mapping[args.arch]).to(device)

elif args.quan == "qat":
    # case 3 : Quantization Aware Training
    print("----------Quantization Aware Training enabled")
else:
    raise ValueError("Invalid quantization method")

_folder_path = f"resnet{args.arch}_{args.dataset}" + "_" + args.quan
_file_name = (
    f"resnet{args.arch}_{args.dataset}_epoch"  # resnet18_cifar10_epoch{epoch}.pth
)

# %%Set up training and evaluation processes
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
device = "cpu"
print(device)

----------Static Quantization enabled
cpu


In [5]:
model.to(device)
model.eval()

train_loader, test_loader = GetDataset(
    dataset_name=args.dataset,
    device=device,
    root="data",
    batch_size=64,
    num_workers=8,
)

# 0. REF (acc@1 : 80.35%, 시간 및 크기 측정만)

In [6]:
_ = run_benchmark(model, test_loader)
print_size_of_model(model)
print(model.layer1[0])

Elapsed time:  43 ms
Size (MB): 102.52663
BottleNeck_quan(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (add): FloatFunctional(
    (activation_post_process): Identity()
  )
)


fuse 확인

In [7]:
model = fuse_model(model)
print(print_size_of_model(model))
print(model.layer1[0])

Size (MB): 102.158986
None
BottleNeck_quan(
  (conv1): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (bn1): Identity()
  (conv2): ConvReLU2d(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (bn2): Identity()
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
  (bn3): Identity()
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): Identity()
  )
  (relu1): Identity()
  (relu2): Identity()
  (relu3): ReLU()
  (add): FloatFunctional(
    (activation_post_process): Identity()
  )
)


# calibration (training set 한 바퀴 돌림)

In [8]:
model.qconfig = torch.quantization.get_default_qconfig("x86")
print(model.qconfig)
torch.quantization.prepare(model, inplace=True)
print("Post Training Quantization Prepare: Inserting Observers")

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Post Training Quantization Prepare: Inserting Observers




In [9]:
# _, _ = SingleEpochEval(model, train_loader, criterion, device)
_, _ = SingleEpochEval(model, test_loader, criterion, device)
print("Post Training Quantization: Calibration done")

100%|██████████| 782/782 [37:28<00:00,  2.88s/it]

Post Training Quantization: Calibration done





convert

In [10]:
torch.quantization.convert(model, inplace=True)
print("Post Training Quantization: Convert done")

Post Training Quantization: Convert done


# Static quantization 완료

In [11]:
_ = run_benchmark(model, test_loader)

eval_loss, eval_acc = SingleEpochEval(model, test_loader, criterion, device)
print_size_of_model(model)
print(f"Eval Loss: {eval_loss:.4f}, Eval Acc: {eval_acc:.2f}%")
print("Post Training Quantization: Eval done")

Elapsed time:  20 ms


100%|██████████| 782/782 [09:36<00:00,  1.36it/s]


Size (MB): 26.151272
Eval Loss: 1.4316, Eval Acc: 79.81%
Post Training Quantization: Eval done
