In [1]:
import torch
from typing import Tuple

In [2]:
def _calculate_scale_and_zeropoint(
    min_val: float, max_val: float, num_bits: int) -> Tuple[float, int]:
    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale = (max_val - min_val) / (qmax - qmin)

    initial_zero_point = qmin - min_val / scale

    zero_point = 0
    if initial_zero_point < qmin:
        zero_point = int(qmin)
    elif initial_zero_point > qmax:
        zero_point = int(qmax)
    else:
        zero_point = int(initial_zero_point)
    
    return scale, zero_point

In [3]:
def quantize(x: torch.Tensor, scale: float, zero_point: int, dtype=torch.uint8,
             num_bits: int = 8):
    """Quantize ``x`` onto an unsigned integer grid.

    Computes ``zero_point + x / scale``, clamps the result to
    ``[0, 2**num_bits - 1]``, rounds it and casts it to ``dtype``.
    ``x`` itself is not modified.

    Args:
        x: Tensor of real values to quantize.
        scale: Real-valued size of one quantization step.
        zero_point: Grid position that represents real 0.
        dtype: Storage dtype of the result. It must be able to hold
            ``2**num_bits - 1``; this is not checked here.
        num_bits: Bit width of the grid. The default of 8 reproduces the
            previous fixed ``[0, 255]`` clamp range.

    Returns:
        Tensor of quantized values with dtype ``dtype``.
    """
    qmax = 2 ** num_bits - 1
    q_x = zero_point + x / scale
    # Clamp before rounding; both bounds are integers, so the rounded
    # values stay inside [0, qmax].
    q_x = q_x.clamp(0, qmax).round()
    return q_x.to(dtype)

def dequantize(x: torch.Tensor, scale: float, zero_point: int):
    """Map quantized values back to real values.

    Evaluates ``scale * (x - zero_point)`` after casting ``x`` to float32,
    so the subtraction is carried out in floating point.

    Args:
        x: Tensor of quantized values.
        scale: Real-valued size of one quantization step.
        zero_point: Grid position that represents real 0.

    Returns:
        Tensor holding the reconstructed real values.
    """
    centered = x.to(torch.float32) - zero_point
    return scale * centered

In [4]:
from copy import deepcopy

def test_case_0():
  """Round-trip a seeded 4x4 random tensor through 8-bit quantization.

  Asserts that the dequantized tensor matches the stored reference values
  (atol 1e-4) and stays within 5e-2 of the original input.

  Returns:
    Tuple ``(test_input, your_dequant, your_quant)``.
  """
  torch.manual_seed(999)
  test_input = torch.randn((4,4))

  # Quantization parameters come from the observed value range of the input.
  scale, zero_point = _calculate_scale_and_zeropoint(
      torch.min(test_input), torch.max(test_input), 8)

  your_quant = quantize(test_input, scale, zero_point)
  your_dequant = dequantize(your_quant, scale, zero_point)

  # Reference reconstruction for seed 999.
  expected_dequant = torch.Tensor([
      [-0.2623,  1.3991,  0.2842,  1.0275],
      [-0.9838, -3.4104,  1.4866,  0.2405],
      [ 1.4866, -0.3716,  0.0874,  2.1424],
      [ 0.6340, -1.1587, -0.7870,  0.0656]])

  assert torch.allclose(your_dequant, expected_dequant, atol=1e-4)
  assert torch.allclose(your_dequant, test_input, atol=5e-2)

  return test_input, your_dequant, your_quant



### Test Case 1
def test_case_1():
  """Round-trip a seeded 8x8 random tensor through 8-bit quantization.

  Asserts that the dequantized tensor matches the stored reference values
  (atol 1e-4) and stays within 5e-2 of the original input.

  Returns:
    Tuple ``(test_input, your_dequant, your_quant)``.
  """
  torch.manual_seed(999)
  test_input = torch.randn((8,8))

  # Quantization parameters come from the observed value range of the input.
  scale, zero_point = _calculate_scale_and_zeropoint(
      torch.min(test_input), torch.max(test_input), 8)

  your_quant = quantize(test_input, scale, zero_point)
  your_dequant = dequantize(your_quant, scale, zero_point)

  # Reference reconstruction for seed 999.
  expected_dequant = torch.Tensor(
      [[-0.2623,  1.3991,  0.2842,  1.0275, -0.9838, -3.4104,  1.4866,  0.2405],
      [ 1.4866, -0.3716,  0.0874,  2.1424,  0.6340, -1.1587, -0.7870,  0.0656],
      [ 0.0000, -0.6558, -1.0056,  0.3061,  0.6340, -1.0931, -1.6178,  1.5740],
      [-1.7927,  0.6121, -0.7214,  0.6121,  0.3279, -1.5959, -0.5247,  0.3498],
      [-1.3773,  1.1149, -0.7870,  0.2842,  0.9182, -1.1805, -0.7433, -1.5522],
      [ 1.0056, -0.1093,  1.3991, -0.9182, -1.1805, -0.6777, -0.3061,  0.9838],
      [ 0.2186,  1.6396,  1.0712,  1.7489,  0.0874,  0.3498,  0.9838,  1.2024],
      [-0.3935, -0.6340,  1.9238,  1.2898,  0.0219,  0.3935,  1.4866, -0.9401]])

  assert torch.allclose(your_dequant, expected_dequant, atol=1e-4)
  assert torch.allclose(your_dequant, test_input, atol=5e-2)

  return test_input, your_dequant, your_quant

In [6]:
# Empirically, report the average and maximum quantization error for the test cases
def test():
  """Run both test cases and report their pooled quantization error.

  The absolute element-wise differences between each test input and its
  dequantized reconstruction are collected from ``test_case_0`` and
  ``test_case_1`` and summarised together, so neither case is dropped
  from the reported numbers.

  Returns:
    Tuple ``(avg_error, max_error)``: mean and maximum absolute error over
    all elements of both test cases, each as a scalar tensor.
  """
  abs_errors = []
  for run_case in (test_case_0, test_case_1):
    test_input, your_dequant, _ = run_case()
    # Flatten so that cases with different shapes can be concatenated.
    abs_errors.append(torch.abs(test_input - your_dequant).flatten())

  pooled_errors = torch.cat(abs_errors)

  avg_error = torch.mean(pooled_errors)
  max_error = torch.max(pooled_errors)

  return avg_error, max_error

# Run the checks; the returned (avg_error, max_error) tuple is shown as the cell output.
test()

(tensor(0.0059), tensor(0.0115))

In [8]:
# Save the original fp32 tensor and quantized tensor to disk with torch.save. Report the difference in disk utilization
# Relative directory that receives test_input.pt and your_quant.pt.
output_folder = "data/lab3"

def save_to_disk(test_input, your_quant, output_folder):
    """Save both tensors with ``torch.save`` and report their payload sizes.

    Args:
        test_input: Original tensor, written to
            ``<output_folder>/test_input.pt``.
        your_quant: Quantized tensor, written to
            ``<output_folder>/your_quant.pt``.
        output_folder: Destination directory; created (including parents)
            when it does not exist yet.

    Returns:
        Tuple ``(test_input_size, your_quant_size)`` with the number of bytes
        of raw tensor data (element size times element count). These are not
        measured file sizes; the files written by ``torch.save`` may be
        larger.
    """
    import os

    # torch.save does not create missing parent directories, so do it here.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    torch.save(test_input, f"{output_folder}/test_input.pt")
    torch.save(your_quant, f"{output_folder}/your_quant.pt")

    test_input_size = test_input.element_size() * test_input.nelement()
    your_quant_size = your_quant.element_size() * your_quant.nelement()

    return test_input_size, your_quant_size

# Re-run test case 1 to get the input tensor and its quantized counterpart,
# then save both; the returned (test_input_size, your_quant_size) tuple is the cell output.
test_input, your_dequant, your_quant = test_case_1()
save_to_disk(test_input, your_quant, output_folder)

(256, 64)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt
from torchvision import transforms
from itertools import product

In [None]:
class MNISTDataset(Dataset):
    """Image dataset read from a CSV file with one sample per row.

    The first CSV column is used as the label; all remaining columns are
    pixel values that are reshaped into 28x28 images and standardised with
    the mean and standard deviation computed over the whole file.
    """

    def __init__(self, csv_file, transform=None):
        """Load and standardise the data.

        Args:
            csv_file: Path (or anything ``pandas.read_csv`` accepts) of the
                CSV file to load.
            transform: Optional callable applied in ``__getitem__`` to each
                image tensor of shape (1, 28, 28).
        """
        data = pd.read_csv(csv_file)
        self.labels = data.iloc[:, 0].values
        raw_pixels = data.iloc[:, 1:].values.astype('float32')
        raw_pixels = raw_pixels.reshape(-1, 28, 28)  # Reshape to 28x28 images

        # Normalize the pixel values
        self.pixels_mean = raw_pixels.mean()
        self.pixels_std = raw_pixels.std()
        # A zero standard deviation (all pixels identical) would turn every
        # value into NaN; divide by 1 in that case instead.
        divisor = self.pixels_std if self.pixels_std > 0 else 1.0
        self.pixels = (raw_pixels - self.pixels_mean) / divisor

        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx`` as tensors.

        The image is always converted to a tensor, so the result has the
        same type whether or not a transform was supplied.
        """
        # Add a leading channel axis for the transform: (1, 28, 28).
        image = torch.tensor(self.pixels[idx]).unsqueeze(0)
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        # Drop the channel axis again before returning.
        return image.squeeze(0), torch.tensor(label)