In [1]:
import torch
import warnings

from src.models.resnet20 import ResNet20
from torch.utils.data import DataLoader
from data.data import get_train_data
from main import evaluate
from src.utils import Timer

# Suppress every Python warning for the rest of the notebook. This keeps the cell
# outputs short, but it also hides deprecation warnings raised by the libraries
# used below.
warnings.filterwarnings("ignore")

# ResNet pretraining

In [2]:
# Run the project's main.py as a shell command. The later cells load
# "checkpoints/resnet20_final.pth" — presumably written by this run (confirm in main.py).
!python3 main.py

Files already downloaded and verified
Files already downloaded and verified
[34m[1mwandb[0m: Currently logged in as: [33mjohan_ddc[0m ([33mjohan_ddc_team[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.15.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/johan/PycharmProjects/quantization/wandb/run-20230708_201400-18g01tsv[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mresnet20_train[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/johan_ddc_team/quatization_simple[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/johan_ddc_team/quatization_simple/runs/18g01tsv[0m
100%|███████████████████████████████████████████| 50/50 [16:40<00:00, 20.02s/it]
[34m[1mwandb[0m: 

In [2]:
def memory_consumption(model, bits):
    """Estimate the size of a model's trainable parameters in mebibytes.

    Args:
        model: module whose ``parameters()`` are counted; only parameters with
            ``requires_grad`` set contribute to the total.
        bits: number of bits assumed to be stored per parameter.

    Returns:
        The trainable parameter count times ``bits``, converted from bits to
        MiB (8 bits per byte, 1024 ** 2 bytes per MiB).
    """
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()

    bits_per_mebibyte = 8 * 1024 ** 2
    return trainable_count * bits / bits_per_mebibyte

In [3]:
# Baseline: evaluate the pretrained full-precision (fp32) network on the CIFAR-10
# test split on CPU. `criterion`, `train_loader` and `test_loader` defined here
# are reused by the cells below.
timer = Timer("Unquantized model on cpu")
model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
# Put the network into inference mode, exactly as every other evaluation cell in
# this notebook does. Without this the baseline would be measured with train-mode
# layers (e.g. BatchNorm using batch statistics) unless evaluate() switches modes
# itself, which makes the comparison with the quantized models unfair.
model.eval()

criterion = torch.nn.CrossEntropyLoss()
cifar10_train, cifar10_test = get_train_data(root_dir="data")
# The shuffled train loader is only used as a source of calibration batches later on.
train_loader = DataLoader(cifar10_train, batch_size=128, shuffle=True, pin_memory=True, num_workers=1, drop_last=True)
test_loader = DataLoader(cifar10_test, batch_size=128, shuffle=False, pin_memory=True, num_workers=1)

# Only the evaluation pass is timed; model construction and data loading are not.
with timer:
    val_loss, val_accuracy = evaluate(model, criterion, test_loader)
print()
print(f"Final test loss: {val_loss.item()}")
print(f"Final test accuracy: {val_accuracy.item()}")
# 32 bits per parameter for the unquantized fp32 weights.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 32), 3)}Mb")

Files already downloaded and verified
Files already downloaded and verified
Unquantized model on cpu took 48.01813s.

Final test loss: 0.38893190026283264
Final test accuracy: 0.9089201092720032
(Theoretical) memory consumption of model: 1.117Mb


# PyTorch Post Training Quantization

Model quantized to 16 bits.

In [7]:
# 16-bit variant: cast the pretrained weights to half precision and evaluate on GPU.
# `criterion` and `test_loader` come from the baseline cell above.
timer = Timer("16-bit quantized model")

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
model.half()  # converts floating-point parameters and buffers to float16 in place
model.cuda()

with timer:
    # The preprocessor casts a batch to float16; presumably evaluate() applies it to
    # each input batch so inputs match the half-precision weights (confirm in main.evaluate).
    val_loss, val_accuracy = evaluate(model, criterion, test_loader, device="cuda",
                                      batch_preprocessor=lambda batch: batch.half())
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# 16 bits per parameter for the float16 weights.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 16), 3)}Mb")

16-bit quantized model took 3.99165s.
Quantized model test loss: 0.38875532150268555
Quantized model test accuracy: 0.9090189933776855
(Theoretical) memory consumption of model: 0.559Mb


Model quantized to 8 bits with a per-tensor quantization scheme.

In [8]:
# Eager-mode post-training static quantization with PyTorch's built-in observers.
timer = Timer("8-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# Activations: unsigned 8-bit with a min/max observer.
# Weights: signed 8-bit, one symmetric scale per tensor.
model.qconfig = torch.ao.quantization.qconfig.QConfig(
    activation=torch.ao.quantization.observer.MinMaxObserver.with_args(dtype=torch.quint8),
    weight=torch.ao.quantization.observer.MinMaxObserver.with_args(dtype=torch.qint8,
                                                                   qscheme=torch.per_tensor_symmetric)
)
# Project-defined fusion step (presumably conv/bn/relu fusion — see ResNet20.fuse_model).
model.fuse_model()
# Attach observers according to model.qconfig.
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration: run a few training batches through the observed model; the returned
# loss/accuracy are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# convert() is not in place by default, so `model` stays the observed float model
# and `model_q` is the quantized one.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# The estimate counts the parameters of the observed float `model`, at 8 bits each.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 8), 3)}Mb")

8-bit quantized model (per tensor scheme) took 33.56244s.
Quantized model test loss: 0.3915562331676483
Quantized model test accuracy: 0.9067444801330566
(Theoretical) memory consumption of model: 0.279Mb


# Custom Post Training Quantization
## Simple implementation using torch Quantization API

Model quantized to 8 bits (per tensor quantization scheme).

In [10]:
from src.ptq.activation_observer import SimpleObserver

# Same eager-mode PTQ pipeline as the built-in version, but with the project's
# SimpleObserver used for both activations and weights.
timer = Timer("8-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# QConfig positional order is (activation, weight): unsigned 8-bit activations,
# signed 8-bit per-tensor symmetric weights.
model.qconfig = torch.ao.quantization.qconfig.QConfig(SimpleObserver.with_args(dtype=torch.quint8),
                                                      SimpleObserver.with_args(dtype=torch.qint8,
                                                                               qscheme=torch.per_tensor_symmetric))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Parameter count of the observed float `model`, at 8 bits each.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 8), 3)}Mb")

8-bit quantized model (per tensor scheme) took 34.17681s.
Quantized model test loss: 0.3917810916900635
Quantized model test accuracy: 0.905656635761261
(Theoretical) memory consumption of model: 0.279Mb


Model quantized to 4 bits (per tensor quantization scheme).

In [6]:
# 4-bit PTQ with the project's SimpleObserver (imported in an earlier cell).
timer = Timer("4-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# The tensors keep the 8-bit quantized dtypes; the 4-bit setting is expressed by
# narrowing the integer range: activations to [0, 15], weights to [-8, 7].
model.qconfig = torch.ao.quantization.qconfig.QConfig(
    SimpleObserver.with_args(dtype=torch.quint8, quant_min=0, quant_max=2 ** 4 - 1),
    SimpleObserver.with_args(dtype=torch.qint8,
                             qscheme=torch.per_tensor_symmetric, quant_min=-2 ** 3, quant_max=2 ** 3 - 1))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 4 bits per parameter, although the dtypes above are 8-bit.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 4), 3)}Mb")

4-bit quantized model (per tensor scheme) took 32.71681s.
Quantized model test loss: 1.8605022430419922
Quantized model test accuracy: 0.3407832384109497
(Theoretical) memory consumption of model: 0.139Mb


In [7]:
# 2-bit PTQ with the project's SimpleObserver (imported in an earlier cell).
timer = Timer("2-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# The tensors keep the 8-bit quantized dtypes; the 2-bit setting is expressed by
# narrowing the integer range: activations to [0, 3], weights to [-2, 1].
model.qconfig = torch.ao.quantization.qconfig.QConfig(
    SimpleObserver.with_args(dtype=torch.quint8, quant_min=0, quant_max=2 ** 2 - 1),
    SimpleObserver.with_args(dtype=torch.qint8,
                             qscheme=torch.per_tensor_symmetric, quant_min=-2, quant_max=1))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 2 bits per parameter, although the dtypes above are 8-bit.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 2), 3)}Mb")

2-bit quantized model (per tensor scheme) took 32.92772s.
Quantized model test loss: 2.3025832176208496
Quantized model test accuracy: 0.1002769023180008
(Theoretical) memory consumption of model: 0.07Mb


## Custom quantization engine

In [5]:
from src.ptq.model_quantizer import ModelQuantizer
from src.ptq.activation_observer import SimpleObserver

# Custom quantization engine: the project's ModelQuantizer wraps the fused model
# and is passed to evaluate() in place of the model.
timer = Timer("16-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))

model.eval()
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
mq = ModelQuantizer(model, SimpleObserver, num_bits=16, dtype=torch.int16)
# Presumably switches the wrapper into calibration mode so the following forward
# passes collect statistics, and quantize() then fixes the quantization
# parameters — confirm in src.ptq.model_quantizer.
mq.calibrate()
_, _ = evaluate(mq, criterion, train_loader, num_batches=num_calibration_batches)
mq.quantize()

with timer:
    val_loss, val_accuracy = evaluate(mq, criterion, test_loader)

# NOTE(review): the recorded output for this cell shows ~0.099 accuracy, i.e. chance
# level for 10 classes, even at 16 bits — the custom engine likely has a defect; verify.
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Parameter count of the fused float `model`, at 16 bits each.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 16), 3)}Mb")

[W NNPACK.cpp:53] Could not initialize NNPACK! Reason: Unsupported hardware.


16-bit quantized model (per tensor scheme) took 361.7651s.
Quantized model test loss: 4.250782012939453
Quantized model test accuracy: 0.09889240562915802
(Theoretical) memory consumption of model: 0.557Mb


In [6]:
# Custom quantization engine, 8-bit setting.
timer = Timer("8-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))

model.eval()
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
# The 8-bit setting is expressed only through the integer range [0, 255].
# NOTE(review): num_bits=16 and dtype=torch.int16 are unchanged from the 16-bit cell —
# confirm this is intentional (int16 as a container) and not a copy-paste leftover.
mq = ModelQuantizer(model, SimpleObserver, num_bits=16, dtype=torch.int16, quant_min=0, quant_max=2 ** 8 - 1)
# Presumably: calibrate() enables statistics collection, the evaluate() call feeds
# calibration batches, quantize() finalizes — confirm in src.ptq.model_quantizer.
mq.calibrate()
_, _ = evaluate(mq, criterion, train_loader, num_batches=num_calibration_batches)
mq.quantize()

with timer:
    val_loss, val_accuracy = evaluate(mq, criterion, test_loader)

# NOTE(review): the recorded output shows ~0.099 accuracy (chance level) and a very
# large loss — the custom engine likely has a defect; verify.
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 8 bits per parameter of the fused float `model`.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 8), 3)}Mb")

8-bit quantized model (per tensor scheme) took 362.56452s.
Quantized model test loss: 877.6320190429688
Quantized model test accuracy: 0.09889240562915802
(Theoretical) memory consumption of model: 0.279Mb


In [7]:
# Custom quantization engine, 4-bit setting.
timer = Timer("4-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))

model.eval()
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
# The 4-bit setting is expressed only through the integer range [0, 15].
# NOTE(review): num_bits=16 and dtype=torch.int16 are unchanged from the 16-bit cell —
# confirm this is intentional (int16 as a container) and not a copy-paste leftover.
mq = ModelQuantizer(model, SimpleObserver, num_bits=16, dtype=torch.int16, quant_min=0, quant_max=2 ** 4 - 1)
# Presumably: calibrate() enables statistics collection, the evaluate() call feeds
# calibration batches, quantize() finalizes — confirm in src.ptq.model_quantizer.
mq.calibrate()
_, _ = evaluate(mq, criterion, train_loader, num_batches=num_calibration_batches)
mq.quantize()

with timer:
    val_loss, val_accuracy = evaluate(mq, criterion, test_loader)

# NOTE(review): the recorded output shows ~0.099 accuracy (chance level) and a very
# large loss — the custom engine likely has a defect; verify.
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 4 bits per parameter of the fused float `model`.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 4), 3)}Mb")

4-bit quantized model (per tensor scheme) took 353.24939s.
Quantized model test loss: 15700.380859375
Quantized model test accuracy: 0.09889240562915802
(Theoretical) memory consumption of model: 0.139Mb


In [8]:
# Custom quantization engine, 2-bit setting.
timer = Timer("2-bit quantized model (per tensor scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))

model.eval()
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
# The 2-bit setting is expressed only through the integer range [0, 3].
# NOTE(review): num_bits=16 and dtype=torch.int16 are unchanged from the 16-bit cell —
# confirm this is intentional (int16 as a container) and not a copy-paste leftover.
mq = ModelQuantizer(model, SimpleObserver, num_bits=16, dtype=torch.int16, quant_min=0, quant_max=2 ** 2 - 1)
# Presumably: calibrate() enables statistics collection, the evaluate() call feeds
# calibration batches, quantize() finalizes — confirm in src.ptq.model_quantizer.
mq.calibrate()
_, _ = evaluate(mq, criterion, train_loader, num_batches=num_calibration_batches)
mq.quantize()

with timer:
    val_loss, val_accuracy = evaluate(mq, criterion, test_loader)

# NOTE(review): the recorded output shows ~0.099 accuracy (chance level) and a very
# large loss — the custom engine likely has a defect; verify.
print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 2 bits per parameter of the fused float `model`.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 2), 3)}Mb")

2-bit quantized model (per tensor scheme) took 364.39447s.
Quantized model test loss: 76492.9140625
Quantized model test accuracy: 0.09889240562915802
(Theoretical) memory consumption of model: 0.07Mb


# Additional experiments
## Per channel quantization

Model quantized to 8-bit with per channel quantization scheme (native torch implementation).

In [8]:
# Eager-mode PTQ using PyTorch's stock qconfig for the "fbgemm" backend instead of
# a hand-built QConfig.
timer = Timer("8-bit quantized model (per channel scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Parameter count of the observed float `model`, at 8 bits each.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 8), 3)}Mb")

8-bit quantized model (per channel scheme) took 34.25016s.
Quantized model test loss: 0.383556067943573
Quantized model test accuracy: 0.9060522317886353
(Theoretical) memory consumption of model: 0.279Mb


Model quantized to 8-bit with per channel quantization scheme (custom Observer implementation).

In [11]:
from src.ptq.activation_observer import PerChannelObserver, SimpleObserver

# Per-channel PTQ with the project's observers: SimpleObserver for activations,
# PerChannelObserver for weights.
timer = Timer("8-bit quantized model (per channel scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# QConfig positional order is (activation, weight): activations stay per-tensor
# affine (quint8); weights use a per-channel symmetric scheme (qint8).
model.qconfig = torch.ao.quantization.qconfig.QConfig(SimpleObserver.with_args(dtype=torch.quint8,
                                                                               qscheme=torch.per_tensor_affine),
                                                      PerChannelObserver.with_args(dtype=torch.qint8,
                                                                                   qscheme=torch.per_channel_symmetric))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Parameter count of the observed float `model`, at 8 bits each.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 8), 3)}Mb")

8-bit quantized model (per tensor scheme) took 32.75303s.
Quantized model test loss: 0.4082055389881134
Quantized model test accuracy: 0.9034810066223145
(Theoretical) memory consumption of model: 0.279Mb


Model quantized to 4-bit with per channel quantization scheme (custom Observer implementation).

In [14]:
from src.ptq.activation_observer import PerChannelObserver, SimpleObserver

# 4-bit per-channel PTQ with the project's observers.
timer = Timer("4-bit quantized model (per channel scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# The tensors keep the 8-bit quantized dtypes; the 4-bit setting is expressed by
# narrowing the integer range: activations to [0, 15], weights to [-8, 7].
model.qconfig = torch.ao.quantization.qconfig.QConfig(SimpleObserver.with_args(dtype=torch.quint8,
                                                                               qscheme=torch.per_tensor_affine, quant_min=0, quant_max=2**4-1),
                                                      PerChannelObserver.with_args(dtype=torch.qint8,
                                                                                   qscheme=torch.per_channel_symmetric, quant_min=-2**3, quant_max=2**3 - 1))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 4 bits per parameter, although the dtypes above are 8-bit.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 4), 3)}Mb")

8-bit quantized model (per tensor scheme) took 32.97311s.
Quantized model test loss: 1.8132280111312866
Quantized model test accuracy: 0.3720332384109497
(Theoretical) memory consumption of model: 0.139Mb


Model quantized to 2-bit with per channel quantization scheme (custom Observer implementation).

In [13]:
from src.ptq.activation_observer import PerChannelObserver, SimpleObserver

# 2-bit per-channel PTQ with the project's observers.
timer = Timer("2-bit quantized model (per channel scheme)")
num_calibration_batches = 20

model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
model.eval()
# The tensors keep the 8-bit quantized dtypes; the 2-bit setting is expressed by
# narrowing the integer range: activations to [0, 3], weights to [-2, 1].
model.qconfig = torch.ao.quantization.qconfig.QConfig(SimpleObserver.with_args(dtype=torch.quint8,
                                                                               qscheme=torch.per_tensor_affine, quant_min=0, quant_max=2**2-1),
                                                      PerChannelObserver.with_args(dtype=torch.qint8,
                                                                                   qscheme=torch.per_channel_symmetric, quant_min=-2, quant_max=1))
# Project-defined fusion step (see ResNet20.fuse_model).
model.fuse_model()
model = torch.ao.quantization.prepare(model, inplace=True)
# Calibration pass on a few training batches; results are discarded.
_, _ = evaluate(model, criterion, train_loader, num_batches=num_calibration_batches)
# Not in place: `model` stays the observed float model, `model_q` is quantized.
model_q = torch.ao.quantization.convert(model)

with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print(f"Quantized model test loss: {val_loss.item()}")
print(f"Quantized model test accuracy: {val_accuracy.item()}")
# Theoretical figure at 2 bits per parameter, although the dtypes above are 8-bit.
print(f"(Theoretical) memory consumption of model: {round(memory_consumption(model, 2), 3)}Mb")

8-bit quantized model (per tensor scheme) took 33.21948s.
Quantized model test loss: 2.3041255474090576
Quantized model test accuracy: 0.1002769023180008
(Theoretical) memory consumption of model: 0.07Mb


## QAT

In [46]:
# Quantization-aware training setup: start from the pretrained checkpoint, fuse,
# attach the default fbgemm QAT qconfig and insert fake-quantization modules.
qat_model = ResNet20(configuration=(3, 2, 2), num_classes=10, quantize=True)
qat_model.load_state_dict(torch.load("checkpoints/resnet20_final.pth"))
# Fusion is performed in eval mode and the model is switched to train mode afterwards.
# NOTE(review): if fuse_model() folds BatchNorm into the convolutions in eval mode,
# the later freeze_bn_stats call has nothing to act on — confirm in ResNet20.fuse_model.
qat_model.eval()
qat_model.fuse_model()
qat_model.train()

# The optimizer is created before prepare_qat(); it is the one later passed to
# qat_train_epoch.
optimizer = torch.optim.SGD(qat_model.parameters(), lr=1e-5)
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")

torch.ao.quantization.prepare_qat(qat_model, inplace=True)
qat_model.cuda()
# Emits an empty line; presumably here so the cell does not end on the bare
# qat_model.cuda() expression, whose repr would otherwise be displayed.
print()




In [47]:
def qat_train_epoch(model, criterion, optimizer, loader, device="cpu"):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: module to train; it is moved to ``device`` and put in train mode.
        criterion: callable mapping ``(output, target)`` to a scalar loss tensor.
        optimizer: optimizer stepped once per batch.
        loader: sized iterable yielding ``(input, target)`` pairs.
        device: device on which the model and the batches are placed.

    Returns:
        Tensor of shape ``(1,)`` on ``device`` holding the sum of the per-batch
        losses divided by ``len(loader)``.
    """
    model.to(device)
    model.train()
    running_loss = torch.zeros((1,), device=device)

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # Gradients are cleared right after the update, ready for the next batch.
        optimizer.zero_grad(set_to_none=True)

        # Accumulate a detached copy so the autograd graph is not kept alive.
        running_loss += batch_loss.detach()

    return running_loss / len(loader)

In [48]:
# QAT fine-tuning: train on GPU, then convert the model to a quantized CPU model
# after every epoch to track test metrics.
timer = Timer("QAT model")
num_eval_batches = 32
qat_train_epochs = 10

for epoch in range(qat_train_epochs):
    qat_train_epoch(qat_model, criterion, optimizer, train_loader, device="cuda")
    # Both checks run after the epoch's training, so each freeze takes effect
    # from the following epoch onwards.
    if epoch > 3:
        # Stop updating the observers' quantization statistics.
        qat_model.apply(torch.ao.quantization.disable_observer)
    if epoch > 2:
        # Stop updating BatchNorm running statistics in the QAT modules.
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # .cpu() moves qat_model itself to the CPU; qat_train_epoch moves it back to
    # the GPU at the start of the next epoch. convert() returns a separate model.
    model_q = torch.ao.quantization.convert(qat_model.cpu(), inplace=False)
    # Quick check on a subset of the test loader only.
    # NOTE(review): the recorded per-epoch accuracy (~0.39) is far below the final
    # full-test accuracy (~0.91); possibly evaluate() normalizes by the full loader
    # length when num_batches is given — confirm in main.evaluate.
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader, num_batches=num_eval_batches)
    print(f"test loss: {round(val_loss.item(),  4)},\ttest accuracy: {round(val_accuracy.item(),  4)}")

# No training happens between the last loop iteration and this conversion, so it
# is built from the same qat_model state as the last per-epoch model.
model_q = torch.ao.quantization.convert(qat_model.cpu(), inplace=False)
# Timed evaluation on the full test loader.
with timer:
    val_loss, val_accuracy = evaluate(model_q, criterion, test_loader)

print()
print(f"QAT model test loss: {val_loss.item()}")
print(f"QAT model test accuracy: {val_accuracy.item()}")

test loss: 0.1653,	test accuracy: 0.3895
test loss: 0.1604,	test accuracy: 0.3898
test loss: 0.1558,	test accuracy: 0.3896
test loss: 0.153,	test accuracy: 0.3901
test loss: 0.1495,	test accuracy: 0.39
test loss: 0.1469,	test accuracy: 0.3902
test loss: 0.1459,	test accuracy: 0.3899
test loss: 0.1434,	test accuracy: 0.3903
test loss: 0.142,	test accuracy: 0.3903
test loss: 0.1408,	test accuracy: 0.3901
QAT model took 32.70717s.

QAT model test loss: 0.31536567211151123
QAT model test accuracy: 0.909414529800415


Now we additionally quantize the QAT model using the custom PTQ (to 4 and 2 bits).