In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from cifar_model import MobileNet, ConvBNReLU, Bottleneck, ConvPool
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import ipdb
import time
import copy
import numpy as np
from utils import *

In [None]:
# Static quantization of our CIFAR model
# adapted from https://github.com/leimao/PyTorch-Static-Quantization/blob/main/cifar.py
# Lei has great tutorials on quantization aware training too!

def set_random_seeds(random_seed=0):
    """Seed the RNGs used by this notebook and make cuDNN deterministic.

    Args:
        random_seed: Integer seed applied to Python's ``random`` module,
            NumPy's global generator and PyTorch.
    """
    # Seed each library's generator with the same value.
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Disable cuDNN autotuning and request deterministic algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_workers=8,
                       train_batch_size=128,
                       eval_batch_size=256):
    """Create the CIFAR-10 training and evaluation data loaders.

    Both splits are downloaded to ``./data`` if they are not already there.

    Args:
        num_workers: Number of worker processes used by each loader.
        train_batch_size: Batch size of the training loader.
        eval_batch_size: Batch size of the evaluation loader.

    Returns:
        Tuple ``(train_loader, test_loader)``. The training loader samples
        randomly and applies random crop / horizontal flip augmentation; the
        evaluation loader walks the test split sequentially with no
        augmentation.
    """
    data_root = "./data"

    augmented = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([transforms.ToTensor()])

    train_set = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=augmented)
    # The test split doubles as the validation split in this project.
    # Do not do this in practice!
    test_set = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=plain)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=train_batch_size,
        sampler=torch.utils.data.RandomSampler(train_set),
        num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=eval_batch_size,
        sampler=torch.utils.data.SequentialSampler(test_set),
        num_workers=num_workers)

    return train_loader, test_loader


def evaluate_model(model, test_loader, device, criterion=None):
    """Compute the average loss and accuracy of ``model`` over ``test_loader``.

    The model is switched to evaluation mode and moved to ``device``; both
    side effects persist after the call.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        device: Device on which the evaluation runs.
        criterion: Optional loss function called as
            ``criterion(outputs, labels)``. When ``None`` the loss is
            reported as 0.

    Returns:
        Tuple ``(eval_loss, eval_accuracy)``: the sample-weighted mean loss
        and the fraction of correct predictions. The accuracy is derived from
        ``torch.sum`` and is therefore normally a 0-dim tensor.
    """
    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    # Evaluation never back-propagates, so disable autograd to avoid building
    # (and holding in memory) a graph for every forward pass.
    with torch.no_grad():
        for inputs, labels in test_loader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0

            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the dataset-level average.
            running_loss += loss * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy


def train_model(model,
                train_loader,
                test_loader,
                device,
                learning_rate=1e-1,
                num_epochs=200):
    """Train ``model`` with SGD and report metrics after every epoch.

    The model is evaluated once before training (printed as epoch -1) and
    again after each epoch. The learning rate is multiplied by 0.1 at epochs
    100 and 150. These settings were not carefully tuned.

    Args:
        model: Module to train; it is moved to ``device``.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        test_loader: Loader handed to ``evaluate_model`` for evaluation.
        device: Device on which training and evaluation run.
        learning_rate: Initial SGD learning rate.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    criterion = nn.CrossEntropyLoss()
    model.to(device)

    # SGD with momentum; Adam and cosine annealing were tried as alternatives
    # by the original author and left disabled.
    optimizer = optim.SGD(model.parameters(),
                          lr=learning_rate,
                          momentum=0.9,
                          weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[100, 150], gamma=0.1, last_epoch=-1)

    # Baseline metrics before any weight update.
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(
        model=model, test_loader=test_loader, device=device,
        criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
        -1, eval_loss, eval_accuracy))

    for epoch in range(num_epochs):
        model.train()

        loss_sum = 0
        correct_sum = 0

        for batch_inputs, batch_labels in train_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            logits = model(batch_inputs)
            predictions = torch.max(logits, 1)[1]
            batch_loss = criterion(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()

            # Accumulate sample-weighted statistics for the epoch summary.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            correct_sum += (predictions == batch_labels.data).sum()

        train_loss = loss_sum / len(train_loader.dataset)
        train_accuracy = correct_sum / len(train_loader.dataset)

        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model=model, test_loader=test_loader, device=device,
            criterion=criterion)

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

        print(
            "Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}"
            .format(epoch, train_loss, train_accuracy, eval_loss,
                    eval_accuracy))

    return model


def calibrate_model(model, loader, device=torch.device("cpu:0")):
    """Run every batch of ``loader`` through ``model`` and discard the outputs.

    This notebook calls it between ``torch.quantization.prepare`` and
    ``torch.quantization.convert`` so the prepared model sees representative
    inputs. The model is moved to ``device`` and left in evaluation mode.

    Args:
        model: Module to feed.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: Device on which the forward passes run.
    """
    model.to(device)
    model.eval()

    # Only forward passes are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, _ in loader:
            model(inputs.to(device))


def measure_inference_latency(model,
                              device,
                              input_size=(1, 3, 32, 32),
                              num_samples=100,
                              num_warmups=10):
    """Measure the mean forward-pass latency of ``model`` on a random input.

    The model is moved to ``device`` and switched to evaluation mode; both
    side effects persist after the call.

    Args:
        model: Module to time.
        device: Device on which to run (``torch.device`` or device string).
        input_size: Shape of the random input tensor.
        num_samples: Number of timed forward passes.
        num_warmups: Number of untimed forward passes run first.

    Returns:
        Average time per forward pass, in seconds.
    """
    model.to(device)
    model.eval()

    x = torch.rand(size=input_size).to(device)

    # CUDA kernels launch asynchronously, so the timer is only meaningful if
    # we wait for them. On any other device there is nothing to wait for, and
    # calling torch.cuda.synchronize() would fail on a CUDA-less machine.
    target = torch.device(device)
    needs_sync = target.type == "cuda"

    with torch.no_grad():
        for _ in range(num_warmups):
            _ = model(x)
    if needs_sync:
        torch.cuda.synchronize(target)

    with torch.no_grad():
        # perf_counter is monotonic and has a higher resolution than time.time.
        start_time = time.perf_counter()
        for _ in range(num_samples):
            _ = model(x)
            if needs_sync:
                torch.cuda.synchronize(target)
        end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    elapsed_time_ave = elapsed_time / num_samples

    return elapsed_time_ave


def save_model(model, model_dir, model_filename):
    """Write ``model``'s state dict to ``model_dir/model_filename``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        model_dir: Target directory; created (with parents) if missing.
        model_filename: File name inside ``model_dir``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_filepath)


def load_model(model, model_filepath, device):
    """Load a saved state dict into ``model``.

    Args:
        model: Module whose parameters are overwritten.
        model_filepath: Path of the file to read with ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object, now holding the loaded weights.
    """
    state_dict = torch.load(model_filepath, map_location=device)
    model.load_state_dict(state_dict)
    return model


def save_torchscript_model(model, model_dir, model_filename):
    """Script ``model`` with TorchScript and save it to ``model_dir/model_filename``.

    Args:
        model: Module compiled with ``torch.jit.script``.
        model_dir: Target directory; created (with parents) if missing.
        model_filename: File name inside ``model_dir``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)


def load_torchscript_model(model_filepath, device):
    """Load a TorchScript model from disk.

    Args:
        model_filepath: Path of the file to read with ``torch.jit.load``.
        device: Passed to ``torch.jit.load`` as ``map_location``.

    Returns:
        The loaded TorchScript module.
    """
    return torch.jit.load(model_filepath, map_location=device)


class QuantizedMobileNet(nn.Module):
    """Wrap a float model between a ``QuantStub`` and a ``DeQuantStub``.

    The input goes through the quant stub, then the wrapped model, then the
    dequant stub.
    """

    def __init__(self, model_fp32):
        """Store ``model_fp32`` and create the two stubs.

        Args:
            model_fp32: Float module to wrap.
        """
        super().__init__()
        # Registration order is kept as quant, dequant, model_fp32 because it
        # determines the order in which ``modules()`` yields the children.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Return ``dequant(model_fp32(quant(x)))``."""
        return self.dequant(self.model_fp32(self.quant(x)))

def model_equivalence(model_1,
                      model_2,
                      device,
                      rtol=1e-05,
                      atol=1e-08,
                      num_tests=100,
                      input_size=(1, 3, 32, 32)):
    """Check that two models give numerically close outputs on random inputs.

    Both models are moved to ``device``. Their train/eval mode is left
    untouched, so callers should set it before calling.

    Args:
        model_1: First module.
        model_2: Second module.
        device: Device on which the comparison runs.
        rtol: Relative tolerance passed to ``np.allclose``.
        atol: Absolute tolerance passed to ``np.allclose``.
        num_tests: Number of random inputs to try.
        input_size: Shape of each random input.

    Returns:
        ``True`` if all outputs match within tolerance. ``False`` as soon as
        one input produces a mismatch; the two outputs are printed first.
    """
    model_1.to(device)
    model_2.to(device)

    # Pure comparison: no gradients are needed for either forward pass.
    with torch.no_grad():
        for _ in range(num_tests):
            x = torch.rand(size=input_size).to(device)
            y1 = model_1(x).detach().cpu().numpy()
            y2 = model_2(x).detach().cpu().numpy()
            if not np.allclose(a=y1, b=y2, rtol=rtol, atol=atol,
                               equal_nan=False):
                print("Model equivalence test sample failed: ")
                print(y1)
                print(y2)
                return False

    return True

In [None]:
# --- Configuration ---------------------------------------------------------
random_seed = 0
num_classes = 10
cuda_device = torch.device("cuda:0")
cpu_device = torch.device("cpu:0")

model_dir = "models"
model_filename = "cifar-convnet.pth"
quantized_model_filename = "cifar-convnet-quantized.pt"
model_filepath = os.path.join(model_dir, model_filename)
quantized_model_filepath = os.path.join(model_dir,
                                        quantized_model_filename)

# --- Load the trained float model ------------------------------------------
# Use the configuration values above instead of repeating them as literals,
# so that changing num_classes or the checkpoint name takes effect here too.
model = MobileNet(num_classes)
model.load_state_dict(torch.load(model_filepath, map_location=cpu_device))

set_random_seeds(random_seed=random_seed)
train_loader, test_loader = prepare_dataloader(num_workers=8,
                                               train_batch_size=128,
                                               eval_batch_size=256)

# Move the model to CPU since static quantization does not support CUDA currently.
model.to(cpu_device)
# Make a copy of the model for layer fusion
fused_model = copy.deepcopy(model)

model.eval()
# The model has to be switched to evaluation mode before any layer fusion.
# Otherwise the quantization will not work correctly.
# Assigning the result keeps the notebook from echoing the module repr.
capture = fused_model.eval()

In [None]:
# Fuse layer groups in place so the quantized model has a kernel for each of
# them. The children are addressed by name: '0', '1', '2' (and '4', '5' inside
# a bottleneck) -- presumably conv / batch-norm / activation groups; confirm
# against cifar_model. Exact type checks are used on purpose so that a
# subclass is never fused by its parent's branch.
for m in fused_model.modules():
    if type(m) in (ConvBNReLU, ConvPool):
        torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
    elif type(m) is Bottleneck:
        for name, block in m.named_children():
            if name != 'bottleneck':
                continue
            torch.quantization.fuse_modules(block, ['0', '1', '2'], inplace=True)
            torch.quantization.fuse_modules(block, ['4', '5'], inplace=True)

In [None]:
# Print fused model.
# print(fused_model)

In [None]:
# Fusion must not change the numerics: the original and the fused model have
# to agree on random inputs within a loose tolerance.
assert model_equivalence(model_1=model,
                         model_2=fused_model,
                         device=cpu_device,
                         rtol=1e-03,
                         atol=1e-06,
                         num_tests=100,
                         input_size=(1, 3, 32, 32)), \
    "Fused model is not equivalent to the original model!"

In [None]:
# Static quantization in three steps: prepare -> calibrate -> convert.
#
# Wrap the fused float model so that its input passes through a QuantStub and
# its output through a DeQuantStub.
quantized_model = QuantizedMobileNet(model_fp32=fused_model)
# Using the un-fused model will fail,
# because there is no quantized layer implementation for a single batch normalization layer.
# quantized_model = QuantizedMobileNet(model_fp32=model)
# Select quantization schemes from
# https://pytorch.org/docs/stable/quantization-support.html
quantization_config = torch.quantization.get_default_qconfig("fbgemm")
# Custom quantization configurations (alternatives, left disabled):
# quantization_config = torch.quantization.default_qconfig
# quantization_config = torch.quantization.QConfig(activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8), weight=torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

quantized_model.qconfig = quantization_config

# Print quantization configurations
print(quantized_model.qconfig)

# Insert observers that will record activation statistics during calibration.
# https://pytorch.org/docs/master/torch.quantization.html#torch.quantization.prepare
torch.quantization.prepare(quantized_model, inplace=True)

# Use training data for calibration. Note that this runs the whole training
# loader through the model on the CPU.
calibrate_model(model=quantized_model,
                loader=train_loader,
                device=cpu_device)

# Convert the calibrated model in place; the returned object is the same
# module that was passed in.
quantized_model = torch.quantization.convert(quantized_model, inplace=True)

In [None]:
# Using high-level static quantization wrapper
# The above steps, including torch.quantization.prepare, calibrate_model, and torch.quantization.convert, are also equivalent to
# quantized_model = torch.quantization.quantize(model=quantized_model, run_fn=calibrate_model, run_args=[train_loader], mapping=None, inplace=False)

# Save the converted model produced by the previous cell. This cell used to
# reference `compressed_model`, which is only created further down, so it
# raised NameError on a fresh Restart & Run All. That object is also not the
# wrapper whose `.model_fp32` is inspected on the reloaded model later on.
quantized_model.eval()

# Print quantized model.
# print(quantized_model)

# Save quantized model.
save_torchscript_model(model=quantized_model,
                       model_dir=model_dir,
                       model_filename=quantized_model_filename)

# Load quantized model.
quantized_jit_model = load_torchscript_model(
    model_filepath=quantized_model_filepath, device=cpu_device)

In [None]:
# Accuracy of the float model versus the reloaded TorchScript INT8 model.
_, fp32_eval_accuracy = evaluate_model(model=model,
                                       test_loader=test_loader,
                                       device=cpu_device,
                                       criterion=None)
_, int8_eval_accuracy = evaluate_model(model=quantized_jit_model,
                                       test_loader=test_loader,
                                       device=cpu_device,
                                       criterion=None)

# Skip this assertion since the values might deviate a lot.
# assert model_equivalence(model_1=model, model_2=quantized_jit_model, device=cpu_device, rtol=1e-01, atol=1e-02, num_tests=100, input_size=(1,3,32,32)), "Quantized model deviates from the original model too much!"

print("FP32 evaluation accuracy: {:.3f}".format(fp32_eval_accuracy))
print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))

# Single-image latency of each variant.
fp32_cpu_inference_latency = measure_inference_latency(model=model,
                                                       device=cpu_device,
                                                       input_size=(1, 3,
                                                                   32, 32),
                                                       num_samples=100)
int8_cpu_inference_latency = measure_inference_latency(
    model=quantized_model,
    device=cpu_device,
    input_size=(1, 3, 32, 32),
    num_samples=100)
int8_jit_cpu_inference_latency = measure_inference_latency(
    model=quantized_jit_model,
    device=cpu_device,
    input_size=(1, 3, 32, 32),
    num_samples=100)

# The GPU number is optional: without CUDA the measurement would raise, so
# skip it instead of aborting the whole cell.
if torch.cuda.is_available():
    fp32_gpu_inference_latency = measure_inference_latency(
        model=model,
        device=cuda_device,
        input_size=(1, 3, 32, 32),
        num_samples=100)
else:
    fp32_gpu_inference_latency = None

print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(
    fp32_cpu_inference_latency * 1000))
if fp32_gpu_inference_latency is not None:
    print("FP32 CUDA Inference Latency: {:.2f} ms / sample".format(
        fp32_gpu_inference_latency * 1000))
else:
    print("FP32 CUDA Inference Latency: skipped (CUDA not available)")
print("INT8 CPU Inference Latency: {:.2f} ms / sample".format(
    int8_cpu_inference_latency * 1000))
print("INT8 JIT CPU Inference Latency: {:.2f} ms / sample".format(
    int8_jit_cpu_inference_latency * 1000))

In [None]:
# Put the converted model in evaluation mode; eval() returns the module
# itself, so this is also the value of the cell's last expression.
quantized_model.eval()
# Uncomment to check the type of one converted layer:
# type(quantized_model.model_fp32.features[1].bottleneck[4])

In [None]:
# Flatten the converted model: keep, in traversal order, only the quantized
# conv / linear layers and the parameter-free layers listed below, and chain
# them into one nn.Sequential.
new_layers = [
    module for module in quantized_model.modules()
    if isinstance(module, (
        torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d,
        torch.nn.quantized.modules.conv.Conv2d,
        torch.nn.quantized.modules.linear.Linear,
        nn.ReLU6,
        nn.MaxPool2d,
        nn.ReLU,
        nn.Flatten,
    ))
]
compressed_model = nn.Sequential(*new_layers)
compressed_model

In [None]:
# Rebuild the quantized network as ordinary float layers ("folded" model).
# Every quantized conv / linear layer is replaced by a float nn.Conv2d /
# nn.Linear whose weights are the layer's integer weights divided by 128 and
# whose bias is copied unchanged; parameter-free layers are reused as they are.
# The result is collected in `new_layers`, which a later cell turns into a
# model.
new_layers = []
for module in quantized_model.modules():
    if isinstance(module, (torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d, torch.nn.quantized.modules.conv.Conv2d)):
        # NOTE(review): only stride, padding and groups are forwarded; a
        # non-default dilation on the source layer would be dropped.
        new_module = nn.Conv2d(module.in_channels, module.out_channels, module.kernel_size, stride=module.stride, padding=module.padding, groups=module.groups)
        # NOTE(review): int_repr() is scaled by a fixed 128 rather than by the
        # layer's own quantization scale -- presumably a deliberate fixed-point
        # convention for the folded model; confirm.
        new_module.weight = torch.nn.Parameter(module.weight().int_repr().float()/128, requires_grad=False)
        # NOTE(review): assumes module.bias() is not None -- TODO confirm.
        new_module.bias = torch.nn.Parameter(module.bias(), requires_grad=False)
        new_layers.append(new_module)
        # The fused conv+ReLU layer has its activation built in, so an
        # explicit ReLU has to follow the float replacement.
        if isinstance(module, torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d):
            new_layers.append(nn.ReLU())
    elif isinstance(module, torch.nn.quantized.modules.linear.Linear):
        new_module = nn.Linear(module.in_features, module.out_features)
        # Same fixed /128 weight scaling as for the convolutions above.
        new_module.weight = torch.nn.Parameter(module.weight().int_repr().float()/128, requires_grad=False)
        new_module.bias = torch.nn.Parameter(module.bias(), requires_grad=False)
        new_layers.append(new_module)
    elif isinstance(module, (nn.ReLU6, nn.MaxPool2d, nn.ReLU, nn.Flatten)):
        # Parameter-free layers are carried over unchanged.
        new_layers.append(module)

In [None]:
# Chain the layers collected in `new_layers` into a single sequential model.
# NOTE(review): `new_layers` is assigned by more than one cell above; this
# uses whichever of them ran last.
folded_model = nn.Sequential(*new_layers)
# folded_model

In [None]:
# torch.save(folded_model, "./cifar-convnet-quantized2.pth") # don't forget to set model.eval() after loading

In [None]:
# Replace the in-memory folded model with one loaded from disk.
# NOTE(review): no active cell in this notebook writes this file (the only
# save call is commented out and uses a different file name), so the cell
# depends on an artifact produced elsewhere -- confirm which model is intended.
# NOTE(review): torch.load unpickles the file; only load files from a trusted
# source.
folded_model = torch.load("cifar-convnet-quantized.pth")
folded_model

In [None]:
# Prefer the GPU, but fall back to the CPU so this cell (and the ones that
# use `device` below) also run on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
folded_model.to(device)

In [None]:
# CIFAR-10 test split without augmentation, served one image per batch and
# in a fixed order.
transform = transforms.Compose([transforms.ToTensor()])
valid_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

In [None]:
# Make sure the folded model lives on `device`, then score it on the
# validation loader.
# NOTE(review): get_folded_accuracy is not defined in this notebook;
# presumably it comes from `from utils import *` -- confirm.
folded_model.to(device)
get_folded_accuracy(folded_model, valid_loader, device=device)

In [None]:
# Pull the first batch from the validation loader and display its inputs.
inputs, target = next(iter(valid_loader))
inputs

In [None]:
# new_inputs = torch.tensor((inputs*255).clone().detach().requires_grad_(False), dtype=torch.int8)
# new_inputs.to(device)

In [None]:
# Run the sampled batch through the folded model and display the raw outputs.
folded_model(inputs.to(device))

In [None]:
# Re-run the INT8 TorchScript evaluation on the CPU (same call as in the
# benchmark cell above); the loss is discarded.
_, int8_eval_accuracy = evaluate_model(
    model=quantized_jit_model,
    test_loader=test_loader,
    device=cpu_device,
    criterion=None,
)

In [None]:
# Display the INT8 accuracy computed by the most recent evaluation.
int8_eval_accuracy

In [None]:
# Inspect the reloaded TorchScript model. `model_fp32` is the attribute under
# which QuantizedMobileNet stores the wrapped network; `features` is
# presumably a submodule of MobileNet -- confirm in cifar_model.
quantized_jit_model.model_fp32.features