# Quantization

In [25]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import models

## Load CIFAR100

In [26]:
def load_data(batch_size=128):
    """Build the CIFAR-100 training and test data loaders.

    The training pipeline augments every image with a padded 32x32 random
    crop and a random horizontal flip before tensor conversion; the test
    pipeline only converts and normalizes. Both pipelines normalize with the
    same fixed per-channel mean/std constants. The dataset is stored under
    ``./data`` and downloaded when missing.

    Args:
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple. Only the training loader
        shuffles; both use two worker processes.
    """
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    def tensor_steps():
        # Fresh transform instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *tensor_steps(),
    ])
    test_pipeline = transforms.Compose(tensor_steps())

    loaders = []
    # Train split first, then test split; shuffling is enabled for training only.
    for is_train, pipeline in ((True, train_pipeline), (False, test_pipeline)):
        dataset = torchvision.datasets.CIFAR100(root='./data', train=is_train, download=True, transform=pipeline)
        loaders.append(
            torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=2)
        )

    trainloader, testloader = loaders
    return trainloader, testloader

In [27]:

def finetune_vgg11(model, trainloader, testloader, device, epochs=5, freeze_layers=False,
                   checkpoint_path='/kaggle/working/best_vgg11_finetuned.pth'):
    """Fine-tune ``model`` and return it with the best-validated weights loaded.

    Trains with SGD (lr 0.01, momentum 0.9, weight decay 5e-4) under a cosine
    annealing schedule spanning ``epochs``. After every epoch the model is
    evaluated on ``testloader``; whenever the accuracy improves, the state
    dict is written to ``checkpoint_path``. The best checkpoint written during
    this call is loaded back into ``model`` before returning.

    Args:
        model: Module to train. It is expected to already live on ``device``;
            it must expose ``.features`` when ``freeze_layers`` is true.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` batches used for the
            per-epoch validation accuracy.
        device: Device the batches are moved to.
        epochs: Number of training epochs.
        freeze_layers: When true, parameters under ``model.features`` are
            frozen and excluded from the optimizer.
        checkpoint_path: File the best state dict is saved to. Defaults to the
            original Kaggle working-directory location.

    Returns:
        The same ``model`` object, trained in place.
    """
    if freeze_layers:
        for param in model.features.parameters():
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable accuracy so the first evaluated epoch is always
    # checkpointed, even when its accuracy is exactly 0%.
    best_acc = -1.0
    # Tracks whether THIS call wrote a checkpoint, so a stale file left over
    # from an earlier run is never loaded by mistake.
    checkpoint_saved = False

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Report every 50 batches: the loss is averaged over the last 50
            # batches, the accuracy is cumulative since the start of the epoch.
            if i % 50 == 49:
                print(f'Finetuning: Epoch {epoch + 1}, Batch {i + 1}, '
                      f'Loss: {running_loss / 50:.3f}, Acc: {100. * correct / total:.2f}%')
                running_loss = 0.0

        # Evaluate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        print(f'Finetuning Epoch {epoch + 1}: Validation Accuracy: {acc:.2f}%')

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        scheduler.step()

    if checkpoint_saved:
        # weights_only=True restricts unpickling to tensor data, which is all a
        # state dict contains, and avoids torch.load's FutureWarning.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    return model

In [28]:
def evaluate_accuracy(model, testloader, device):
    """Measure and print the top-1 accuracy of ``model`` over ``testloader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to ``device`` before the forward pass.

    Args:
        model: Module producing per-class scores along dimension 1.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a percentage (float).
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in testloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            # Index of the highest score in each row is the predicted class.
            predicted = torch.max(model(batch_inputs), 1)[1]
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()

    accuracy = 100. * hits / seen
    print(f'Final Test Accuracy of the Finetuned Model: {accuracy:.2f}%')
    return accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `pretrained=True` is deprecated in torchvision; request the same ImageNet
# weights explicitly through the weights enum instead.
model = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
# Replace the final classifier layer with a 100-class head, reading the input
# width from the layer being replaced rather than hard-coding it.
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, 100)
model = model.to(device)

trainloader, testloader = load_data(batch_size=128)

finetuned_model = finetune_vgg11(model, trainloader, testloader, device, epochs=5, freeze_layers=False)

# Last expression of the cell: the returned accuracy is displayed as its output.
evaluate_accuracy(finetuned_model, testloader, device)

Files already downloaded and verified
Files already downloaded and verified
Finetuning: Epoch 1, Batch 50, Loss: 4.120, Acc: 7.39%
Finetuning: Epoch 1, Batch 100, Loss: 3.244, Acc: 13.35%
Finetuning: Epoch 1, Batch 150, Loss: 2.838, Acc: 18.08%
Finetuning: Epoch 1, Batch 200, Loss: 2.700, Acc: 21.16%
Finetuning: Epoch 1, Batch 250, Loss: 2.584, Acc: 23.33%
Finetuning: Epoch 1, Batch 300, Loss: 2.447, Acc: 25.45%
Finetuning: Epoch 1, Batch 350, Loss: 2.351, Acc: 27.32%
Finetuning Epoch 1: Validation Accuracy: 43.05%
Finetuning: Epoch 2, Batch 50, Loss: 2.152, Acc: 41.42%
Finetuning: Epoch 2, Batch 100, Loss: 2.036, Acc: 43.15%
Finetuning: Epoch 2, Batch 150, Loss: 2.055, Acc: 43.80%
Finetuning: Epoch 2, Batch 200, Loss: 1.998, Acc: 44.14%
Finetuning: Epoch 2, Batch 250, Loss: 1.946, Acc: 44.72%
Finetuning: Epoch 2, Batch 300, Loss: 1.972, Acc: 44.91%
Finetuning: Epoch 2, Batch 350, Loss: 1.959, Acc: 45.13%
Finetuning Epoch 2: Validation Accuracy: 49.83%
Finetuning: Epoch 3, Batch 50, Lo

  model.load_state_dict(torch.load('/kaggle/working/best_vgg11_finetuned.pth'))


Final Test Accuracy of the Finetuned Model: 61.84%


61.84

In [29]:
import copy

# Make four independent deep copies of the fine-tuned model.
copy1, copy2, copy3, copy4 = (copy.deepcopy(finetuned_model) for _ in range(4))

In [30]:
def ptq_quantization(model, testloader, device, dtype=torch.float16):
    """Evaluate the top-1 accuracy of ``model`` after casting it to ``dtype``.

    Despite the name, this performs a plain dtype cast of the module (and of
    every input batch) followed by an evaluation pass; no integer
    quantization or calibration takes place.

    The cast is applied to a private deep copy, so the caller's ``model``
    keeps its original precision. ``nn.Module.to`` converts a module in
    place, which would otherwise silently change the model that was passed
    in. The copy temporarily holds a second set of weights in memory.

    Args:
        model: Module producing per-class scores along dimension 1.
        testloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the evaluation copy and the batches are placed on.
        dtype: Floating-point dtype to cast the model and inputs to.

    Returns:
        Accuracy as a percentage (float).
    """
    import copy  # local import so this cell does not depend on an earlier one

    # Cast a copy instead of the caller's module, and make sure it sits on the
    # same device the batches are sent to.
    model = copy.deepcopy(model).to(device=device, dtype=dtype)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device, dtype=dtype), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total
    print(f'PTQ Test Accuracy with dtype={dtype}: {accuracy:.2f}%')
    return accuracy

# Accuracy of the fine-tuned weights evaluated in float16 (first copy).
ptq_quantization(copy1, testloader, device, dtype=torch.float16)

# Accuracy in bfloat16 (second copy). As the last expression of the cell, this
# call's return value is what the notebook displays as the cell output.
ptq_quantization(copy2, testloader, device, dtype=torch.bfloat16)

PTQ Test Accuracy with dtype=torch.float16: 61.83%
PTQ Test Accuracy with dtype=torch.bfloat16: 61.89%


61.89

In [31]:
from torch.amp.autocast_mode import autocast

def qat_finetuning(model, trainloader, testloader, device, epochs=5, dtype=torch.float16,
                   checkpoint_path='/kaggle/working/best_vgg11_qat.pth'):
    """Fine-tune ``model`` in reduced precision and return the best weights.

    Despite the "QAT" name, no fake-quantization is inserted: the module is
    cast to ``dtype`` (float16 or bfloat16; any other dtype leaves the module
    precision untouched) and trained under ``autocast``. Training uses SGD
    (lr 0.01, momentum 0.9, weight decay 5e-4) with cosine annealing over
    ``epochs``. After every epoch the model is evaluated on ``testloader``
    and the state dict is saved to ``checkpoint_path`` whenever the accuracy
    improves; the best checkpoint written by this call is loaded back before
    returning.

    Args:
        model: Module to train; it is moved to ``device`` and cast in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` validation batches.
        device: Target device (``torch.device`` or device string).
        epochs: Number of training epochs.
        dtype: Reduced-precision dtype for both the weights and autocast.
        checkpoint_path: File the best state dict is saved to. Defaults to the
            original Kaggle working-directory location. Pass distinct paths to
            keep the checkpoints of different runs apart.

    Returns:
        The same ``model`` object, trained in place.
    """
    model = model.to(device)
    if dtype == torch.float16:
        model = model.half()
    elif dtype == torch.bfloat16:
        model = model.to(torch.bfloat16)

    # autocast must run on the device actually in use and in the requested
    # dtype; left at its defaults it would target CUDA float16 even for a
    # bfloat16 run. For any other dtype keep autocast's own default.
    device_type = torch.device(device).type
    autocast_dtype = dtype if dtype in (torch.float16, torch.bfloat16) else None

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Start below any reachable accuracy so the first evaluated epoch is always
    # checkpointed, even at exactly 0%.
    best_acc = -1.0
    # Only load a checkpoint that this call wrote, never a stale file.
    checkpoint_saved = False

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with autocast(device_type=device_type, dtype=autocast_dtype):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Report every 50 batches: the loss is averaged over the last 50
            # batches, the accuracy is cumulative since the start of the epoch.
            if i % 50 == 49:
                print(f'QAT: Epoch {epoch + 1}, Batch {i + 1}, '
                      f'Loss: {running_loss / 50:.3f}, Acc: {100. * correct / total:.2f}%')
                running_loss = 0.0

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                with autocast(device_type=device_type, dtype=autocast_dtype):
                    outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        print(f'QAT Epoch {epoch + 1}: Validation Accuracy: {acc:.2f}%')

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        scheduler.step()

    if checkpoint_saved:
        # weights_only=True restricts unpickling to tensor data, which is all a
        # state dict contains, and avoids torch.load's FutureWarning.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    return model

# Reduced-precision fine-tuning run in float16, starting from the first model copy.
qat_finetuned_model_float16 = qat_finetuning(copy1, trainloader, testloader, device, epochs=5, dtype=torch.float16)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    Exception ignored in: self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
Traceback (most recent call last):
    if w.is_alive():
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
AssertionError    : if w.is_alive():can only test a child process

  File "/op

QAT: Epoch 1, Batch 50, Loss: 1.378, Acc: 61.39%
QAT: Epoch 1, Batch 100, Loss: 1.576, Acc: 59.00%
QAT: Epoch 1, Batch 150, Loss: 1.591, Acc: 57.76%
QAT: Epoch 1, Batch 200, Loss: 1.539, Acc: 57.57%
QAT: Epoch 1, Batch 250, Loss: 1.609, Acc: 57.30%
QAT: Epoch 1, Batch 300, Loss: 1.591, Acc: 57.12%
QAT: Epoch 1, Batch 350, Loss: 1.578, Acc: 57.03%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
      File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
if w.is_alive():
      File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    self._shutdown_workers()assert self._parent_pid == os.getpid(), 'can only test a child process'

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():AssertionError
  File "/opt/conda/lib/python3.10/multiproc

QAT Epoch 1: Validation Accuracy: 56.62%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child processException ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 2, Batch 50, Loss: 1.401, Acc: 60.52%
QAT: Epoch 2, Batch 100, Loss: 1.394, Acc: 60.02%
QAT: Epoch 2, Batch 150, Loss: 1.413, Acc: 59.78%
QAT: Epoch 2, Batch 200, Loss: 1.418, Acc: 59.83%
QAT: Epoch 2, Batch 250, Loss: 1.409, Acc: 60.00%
QAT: Epoch 2, Batch 300, Loss: 1.398, Acc: 60.22%
QAT: Epoch 2, Batch 350, Loss: 1.460, Acc: 60.04%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Exception ignored in: Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
      File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'    
if w.is_alive():AssertionError
: can only test a child process
  File "/op

QAT Epoch 2: Validation Accuracy: 57.58%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 3, Batch 50, Loss: 1.204, Acc: 65.47%
QAT: Epoch 3, Batch 100, Loss: 1.203, Acc: 65.36%
QAT: Epoch 3, Batch 150, Loss: 1.211, Acc: 65.35%
QAT: Epoch 3, Batch 200, Loss: 1.166, Acc: 65.58%
QAT: Epoch 3, Batch 250, Loss: 1.172, Acc: 65.65%
QAT: Epoch 3, Batch 300, Loss: 1.205, Acc: 65.64%
QAT: Epoch 3, Batch 350, Loss: 1.184, Acc: 65.72%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: Exception ignored in: can only test a child process<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT Epoch 3: Validation Accuracy: 60.87%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 4, Batch 50, Loss: 0.964, Acc: 71.80%
QAT: Epoch 4, Batch 100, Loss: 0.965, Acc: 71.65%
QAT: Epoch 4, Batch 150, Loss: 0.988, Acc: 71.35%
QAT: Epoch 4, Batch 200, Loss: 0.953, Acc: 71.45%
QAT: Epoch 4, Batch 250, Loss: 0.997, Acc: 71.42%
QAT: Epoch 4, Batch 300, Loss: 0.939, Acc: 71.53%
QAT: Epoch 4, Batch 350, Loss: 0.914, Acc: 71.70%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>    if w.is_alive():

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers

AssertionError:     can only test a child processif w.is_alive():

  File "/op

QAT Epoch 4: Validation Accuracy: 63.62%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: Exception ignored in: can only test a child process
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 5, Batch 50, Loss: 0.821, Acc: 76.11%
QAT: Epoch 5, Batch 100, Loss: 0.789, Acc: 76.31%
QAT: Epoch 5, Batch 150, Loss: 0.783, Acc: 76.22%
QAT: Epoch 5, Batch 200, Loss: 0.774, Acc: 76.27%
QAT: Epoch 5, Batch 250, Loss: 0.798, Acc: 76.18%
QAT: Epoch 5, Batch 300, Loss: 0.797, Acc: 76.07%
QAT: Epoch 5, Batch 350, Loss: 0.768, Acc: 76.21%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
Exception ignored in:   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>    
self._shutdown_workers()
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'    
self._shutdown_workers()AssertionError
:   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
can only test a child process    
if w.is_alive():
  File "/op

QAT Epoch 5: Validation Accuracy: 65.41%


  model.load_state_dict(torch.load('/kaggle/working/best_vgg11_qat.pth'))


In [32]:
    # Reduced-precision fine-tuning run in bfloat16, starting from the second model copy.
    qat_finetuned_model_bfloat16 = qat_finetuning(copy2, trainloader, testloader, device, epochs=5, dtype=torch.bfloat16)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
Exception ignored in:   File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
    Traceback (most recent call last):
if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()    
assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
: can only test a child process    
if w.is_alive():
  File "/op

QAT: Epoch 1, Batch 50, Loss: 1.342, Acc: 62.14%
QAT: Epoch 1, Batch 100, Loss: 1.482, Acc: 60.20%
QAT: Epoch 1, Batch 150, Loss: 1.493, Acc: 59.38%
QAT: Epoch 1, Batch 200, Loss: 1.523, Acc: 58.90%
QAT: Epoch 1, Batch 250, Loss: 1.501, Acc: 58.68%
QAT: Epoch 1, Batch 300, Loss: 1.529, Acc: 58.45%
QAT: Epoch 1, Batch 350, Loss: 1.508, Acc: 58.38%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
Exception ignored in:   File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>    
assert self._parent_pid == os.getpid(), 'can only test a child process'Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    
self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
AssertionError: can only test a child process
    if w.is_alive():
  File "/op

QAT Epoch 1: Validation Accuracy: 56.39%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in: AssertionError<function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>: 
can only test a child processTraceback (most recent call last):

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 2, Batch 50, Loss: 1.370, Acc: 60.83%
QAT: Epoch 2, Batch 100, Loss: 1.415, Acc: 60.61%
QAT: Epoch 2, Batch 150, Loss: 1.347, Acc: 60.78%
QAT: Epoch 2, Batch 200, Loss: 1.381, Acc: 60.59%
QAT: Epoch 2, Batch 250, Loss: 1.319, Acc: 60.83%
QAT: Epoch 2, Batch 300, Loss: 1.333, Acc: 61.20%
QAT: Epoch 2, Batch 350, Loss: 1.372, Acc: 61.29%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
        self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
if w.is_alive():
      File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    if w.is_alive():assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/opt/conda/lib/python3.10/multiprocessing/process.

QAT Epoch 2: Validation Accuracy: 58.02%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 3, Batch 50, Loss: 1.176, Acc: 66.00%
QAT: Epoch 3, Batch 100, Loss: 1.142, Acc: 66.48%
QAT: Epoch 3, Batch 150, Loss: 1.150, Acc: 66.54%
QAT: Epoch 3, Batch 200, Loss: 1.124, Acc: 66.69%
QAT: Epoch 3, Batch 250, Loss: 1.120, Acc: 66.75%
QAT: Epoch 3, Batch 300, Loss: 1.124, Acc: 66.79%
QAT: Epoch 3, Batch 350, Loss: 1.157, Acc: 66.79%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        self._shutdown_workers()
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():

AssertionError  File "/opt/conda/lib/python3.10/multiproc

QAT Epoch 3: Validation Accuracy: 61.91%


Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10><function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
Traceback (most recent call last):
      File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
self._shutdown_workers()    self._shutdown_workers()

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():    
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/opt/conda/lib/python3.10/multiprocessing/process.

QAT: Epoch 4, Batch 50, Loss: 0.958, Acc: 71.67%
QAT: Epoch 4, Batch 100, Loss: 0.925, Acc: 72.05%
QAT: Epoch 4, Batch 150, Loss: 0.940, Acc: 72.20%
QAT: Epoch 4, Batch 200, Loss: 0.888, Acc: 72.49%
QAT: Epoch 4, Batch 250, Loss: 0.925, Acc: 72.48%
QAT: Epoch 4, Batch 300, Loss: 0.929, Acc: 72.45%
QAT: Epoch 4, Batch 350, Loss: 0.944, Acc: 72.41%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>self._shutdown_workers()

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
        if w.is_alive():self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers

      File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 

QAT Epoch 4: Validation Accuracy: 64.01%


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/op

QAT: Epoch 5, Batch 50, Loss: 0.879, Acc: 73.66%
QAT: Epoch 5, Batch 100, Loss: 0.863, Acc: 73.97%
QAT: Epoch 5, Batch 150, Loss: 0.852, Acc: 74.20%
QAT: Epoch 5, Batch 200, Loss: 0.822, Acc: 74.54%
QAT: Epoch 5, Batch 250, Loss: 0.815, Acc: 74.75%
QAT: Epoch 5, Batch 300, Loss: 0.788, Acc: 74.98%
QAT: Epoch 5, Batch 350, Loss: 0.824, Acc: 74.99%


Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10><function _MultiProcessingDataLoaderIter.__del__ at 0x786554598c10>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()    self._shutdown_workers()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    assert self._parent_pid == os.getpid(), 'can only test a child process'    if w.is_alive():
  File "/opt/conda/lib/python3.10/multiprocessing/process.

QAT Epoch 5: Validation Accuracy: 64.93%


  model.load_state_dict(torch.load('/kaggle/working/best_vgg11_qat.pth'))


In [33]:
!pip install torchao



In [34]:
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    int8_weight_only
)

In [35]:
def qat_finetuning_int8(model, trainloader, testloader, device, epochs=5,
                        checkpoint_path='/kaggle/working/best_vgg11_qat_int8.pth'):
    """Quantize ``model`` with the int8-activation / int8-weight config and fine-tune it.

    The model is moved to ``device``, passed through ``quantize_`` with
    ``int8_dynamic_activation_int8_weight()``, and then trained with SGD and a
    cosine-annealed learning rate. After every epoch it is evaluated on
    ``testloader``; whenever the accuracy improves, the state dict is written
    to ``checkpoint_path``.

    NOTE(review): ``quantize_`` with this config appears to be torchao's
    post-training path (weights are converted up front), so this loop
    fine-tunes an already-quantized model rather than running fake-quant
    QAT -- confirm this is the intended workflow.

    Args:
        model: Network to quantize and fine-tune. It is modified in place.
        trainloader: Iterable yielding ``(inputs, labels)`` training batches.
        testloader: Iterable yielding ``(inputs, labels)`` validation batches.
        device: Device (string or ``torch.device``) for the model and batches.
        epochs: Number of fine-tuning epochs; also the cosine schedule length.
        checkpoint_path: Where the best state dict is saved.

    Returns:
        None. The fine-tuned weights live in ``model``; the best epoch's
        state dict is on disk at ``checkpoint_path``.
    """
    model = model.to(device)

    quantize_(model, int8_dynamic_activation_int8_weight())

    criterion = nn.CrossEntropyLoss()
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Mixed precision is only meaningful on CUDA; on any other device both the
    # autocast context and the gradient scaler are turned into no-ops.
    use_amp = torch.device(device).type == 'cuda'
    # fp16 autocast needs loss scaling so small gradients do not underflow.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    best_acc = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Report every 50 batches: loss is the mean over those 50 batches,
            # accuracy is cumulative since the start of the epoch.
            if i % 50 == 49:
                print(f'QAT (int8): Epoch {epoch + 1}, Batch {i + 1}, '
                      f'Loss: {running_loss / 50:.3f}, Acc: {100. * correct / total:.2f}%')
                running_loss = 0.0

        # Per-epoch validation pass.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.autocast(device_type='cuda', enabled=use_amp):
                    outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        print(f'QAT (int8) Epoch {epoch + 1}: Validation Accuracy: {acc:.2f}%')

        # Keep only the best-scoring epoch on disk.
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)

        scheduler.step()


# Quantizes and fine-tunes copy3 for 5 epochs. The function calls quantize_ on
# the object it is given, so copy3 itself is the quantized model afterwards.
qat_finetuning_int8(copy3, trainloader, testloader, device, epochs=5)

QAT (int8): Epoch 1, Batch 50, Loss: 1.125, Acc: 66.20%
QAT (int8): Epoch 1, Batch 100, Loss: 1.118, Acc: 66.52%
QAT (int8): Epoch 1, Batch 150, Loss: 1.085, Acc: 67.00%
QAT (int8): Epoch 1, Batch 200, Loss: 1.106, Acc: 67.19%
QAT (int8): Epoch 1, Batch 250, Loss: 1.083, Acc: 67.35%
QAT (int8): Epoch 1, Batch 300, Loss: 1.103, Acc: 67.42%
QAT (int8): Epoch 1, Batch 350, Loss: 1.070, Acc: 67.54%
QAT (int8) Epoch 1: Validation Accuracy: 61.83%
QAT (int8): Epoch 2, Batch 50, Loss: 1.111, Acc: 67.44%
QAT (int8): Epoch 2, Batch 100, Loss: 1.086, Acc: 67.96%
QAT (int8): Epoch 2, Batch 150, Loss: 1.126, Acc: 67.51%
QAT (int8): Epoch 2, Batch 200, Loss: 1.108, Acc: 67.53%
QAT (int8): Epoch 2, Batch 250, Loss: 1.118, Acc: 67.45%
QAT (int8): Epoch 2, Batch 300, Loss: 1.102, Acc: 67.52%
QAT (int8): Epoch 2, Batch 350, Loss: 1.099, Acc: 67.48%
QAT (int8) Epoch 2: Validation Accuracy: 61.85%
QAT (int8): Epoch 3, Batch 50, Loss: 1.090, Acc: 67.66%
QAT (int8): Epoch 3, Batch 100, Loss: 1.093, Acc: 67

In [36]:
def qat_finetuning_int4(model, trainloader, testloader, device, epochs=5,
                        checkpoint_path='/kaggle/working/best_vgg11_qat_int4.pth'):
    """Quantize ``model`` with the int8-activation / int4-weight config and fine-tune it.

    The model is moved to ``device``, passed through ``quantize_`` with
    ``int8_dynamic_activation_int4_weight()``, and then trained with SGD and a
    cosine-annealed learning rate. After every epoch it is evaluated on
    ``testloader``; whenever the accuracy improves, the state dict is written
    to ``checkpoint_path``.

    The default checkpoint name is int4-specific so that this run does not
    overwrite the checkpoint produced by the int8 variant.

    NOTE(review): ``quantize_`` with this config appears to be torchao's
    post-training path (weights are converted up front), so this loop
    fine-tunes an already-quantized model rather than running fake-quant
    QAT -- confirm this is the intended workflow.

    Args:
        model: Network to quantize and fine-tune. It is modified in place.
        trainloader: Iterable yielding ``(inputs, labels)`` training batches.
        testloader: Iterable yielding ``(inputs, labels)`` validation batches.
        device: Device (string or ``torch.device``) for the model and batches.
        epochs: Number of fine-tuning epochs; also the cosine schedule length.
        checkpoint_path: Where the best state dict is saved.

    Returns:
        None. The fine-tuned weights live in ``model``; the best epoch's
        state dict is on disk at ``checkpoint_path``.
    """
    model = model.to(device)

    quantize_(model, int8_dynamic_activation_int4_weight())

    criterion = nn.CrossEntropyLoss()
    # Only parameters that still require gradients are handed to the optimizer.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Mixed precision is only meaningful on CUDA; on any other device both the
    # autocast context and the gradient scaler are turned into no-ops.
    use_amp = torch.device(device).type == 'cuda'
    # fp16 autocast needs loss scaling so small gradients do not underflow.
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    best_acc = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Report every 50 batches: loss is the mean over those 50 batches,
            # accuracy is cumulative since the start of the epoch.
            if i % 50 == 49:
                print(f'QAT (int4): Epoch {epoch + 1}, Batch {i + 1}, '
                      f'Loss: {running_loss / 50:.3f}, Acc: {100. * correct / total:.2f}%')
                running_loss = 0.0

        # Per-epoch validation pass.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.autocast(device_type='cuda', enabled=use_amp):
                    outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        print(f'QAT (int4) Epoch {epoch + 1}: Validation Accuracy: {acc:.2f}%')

        # Keep only the best-scoring epoch on disk.
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)

        scheduler.step()


# Quantizes and fine-tunes copy4 for 5 epochs. The function calls quantize_ on
# the object it is given, so copy4 itself is the quantized model afterwards.
qat_finetuning_int4(copy4, trainloader, testloader, device, epochs=5)

QAT (int4): Epoch 1, Batch 50, Loss: 1.123, Acc: 67.38%
QAT (int4): Epoch 1, Batch 100, Loss: 1.138, Acc: 67.16%
QAT (int4): Epoch 1, Batch 150, Loss: 1.106, Acc: 67.26%
QAT (int4): Epoch 1, Batch 200, Loss: 1.071, Acc: 67.62%
QAT (int4): Epoch 1, Batch 250, Loss: 1.115, Acc: 67.67%
QAT (int4): Epoch 1, Batch 300, Loss: 1.158, Acc: 67.47%
QAT (int4): Epoch 1, Batch 350, Loss: 1.124, Acc: 67.44%
QAT (int8) Epoch 1: Validation Accuracy: 62.01%
QAT (int4): Epoch 2, Batch 50, Loss: 1.141, Acc: 67.33%
QAT (int4): Epoch 2, Batch 100, Loss: 1.105, Acc: 67.68%
QAT (int4): Epoch 2, Batch 150, Loss: 1.096, Acc: 67.65%
QAT (int4): Epoch 2, Batch 200, Loss: 1.097, Acc: 67.73%
QAT (int4): Epoch 2, Batch 250, Loss: 1.109, Acc: 67.75%
QAT (int4): Epoch 2, Batch 300, Loss: 1.104, Acc: 67.82%
QAT (int4): Epoch 2, Batch 350, Loss: 1.124, Acc: 67.78%
QAT (int8) Epoch 2: Validation Accuracy: 62.04%
QAT (int4): Epoch 3, Batch 50, Loss: 1.113, Acc: 67.86%
QAT (int4): Epoch 3, Batch 100, Loss: 1.105, Acc: 67

In [37]:
def ptq_quantization_int8_int4(model, testloader, device, quant_type='int8'):
    """Quantize ``model`` post-training and report its accuracy on ``testloader``.

    Args:
        model: Network to quantize. ``quantize_`` is applied to this object,
            so the caller's model is modified.
        testloader: Iterable yielding ``(inputs, labels)`` evaluation batches.
        device: Device the model and batches are moved to for evaluation.
        quant_type: ``'int8'`` selects ``int8_dynamic_activation_int8_weight``;
            ``'int4'`` selects ``int8_dynamic_activation_int4_weight``.

    Returns:
        Top-1 accuracy on ``testloader`` as a percentage (float).

    Raises:
        ValueError: If ``quant_type`` is neither ``'int8'`` nor ``'int4'``.
    """
    # Validate before touching the model: an unknown type must not silently
    # evaluate the unquantized network and label the result as quantized.
    if quant_type not in ('int8', 'int4'):
        raise ValueError(
            f"Unsupported quant_type {quant_type!r}; expected 'int8' or 'int4'"
        )

    if quant_type == 'int8':
        quant_config = int8_dynamic_activation_int8_weight()
    else:
        quant_config = int8_dynamic_activation_int4_weight()
    quantize_(model, quant_config)

    model = model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Index of the highest logit per sample is the predicted class.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total
    print(f'PTQ Test Accuracy with {quant_type} quantization: {accuracy:.2f}%')
    return accuracy

# NOTE(review): copy3 and copy4 were already passed through quantize_ and
# fine-tuned by the QAT cells above, so these calls quantize those models a
# second time instead of starting from a fresh full-precision network --
# confirm this is the intended PTQ baseline.
ptq_quantization_int8_int4(copy3, testloader, device, quant_type='int8')

ptq_quantization_int8_int4(copy4, testloader, device, quant_type='int4')


PTQ Test Accuracy with int8 quantization: 61.85%
PTQ Test Accuracy with int4 quantization: 62.02%


62.02