# Homework 3 - Neural Network Post-Training Static Quantization

# Setup
> **TASKS:**
> 1. Run these cells to install the PyPI packages and import the notebook's dependencies. You can open the "Files" explorer in the sidebar to confirm that `./ClassyClassifierParams.pt` was downloaded.

In [None]:
!pip install torchinfo
!pip install gdown
!gdown https://drive.google.com/uc?id=1xrTg4FfhV_znq6g1KTThK44vbIAqEvmd -O ./ClassyClassifierParams.pt

In [None]:
import copy
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os
import matplotlib.pyplot as plt
import statistics
import torch
import torchinfo
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx, QuantizedGraphModule
from typing import Union, Tuple

# Data
In this assignment, you'll be working with the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). CIFAR-10 contains 60,000 tiny images (RGB, 32x32), divided among 10 object classes (there's a "frog" in the dataset - that's pretty terrific). The testing subset of CIFAR-10 contains 10,000 images.

> **TASKS:**
> 1. Read the code in the cells and run them, so that you understand how the CIFAR-10 code works.
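
As a quick illustration of the preprocessing in the cells below: `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` computes `(pixel - 0.5) / 0.5` per channel, so `ToTensor()` values in [0, 1] land in [-1, 1], and `display_cifar10_imgs` undoes this with `imgs / 2 + 0.5`. The plain-Python sketch below mirrors that arithmetic on a few made-up pixel values; the helper names are hypothetical and not part of the notebook.

```python
MEAN, STD = 0.5, 0.5  # same per-channel constants the notebook passes to transforms.Normalize

def normalize(pixel: float) -> float:
    """Mirror of Normalize: (pixel - mean) / std."""
    return (pixel - MEAN) / STD

def denormalize(value: float) -> float:
    """Mirror of the `imgs / 2 + 0.5` step used before plotting."""
    return value / 2 + 0.5

for pixel in (0.0, 0.25, 0.5, 1.0):  # made-up ToTensor() outputs in [0, 1]
    normalized = normalize(pixel)
    print(pixel, "->", normalized, "->", denormalize(normalized))
```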


In [None]:
CIFAR10_SAVE_LOCATION = "./cifar10_dataset/"
CIFAR10_CLASS_NAMES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
BATCH_SIZE = 5 # images get passed in 5 at a time

In [None]:
preprocessor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

test_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_SAVE_LOCATION,
                                            train=False,
                                            download=True,
                                            transform=preprocessor)
test_dataloader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False)

def display_cifar10_imgs(imgs: torch.Tensor):
    imgs = imgs / 2 + 0.5 # CIFAR-10 images are normalized; need to de-normalize first
    plt.imshow(np.transpose(imgs.numpy(), (1, 2, 0))) # Reorder axes from CxHxW to HxWxC, as matplotlib expects
    plt.show()

In [None]:
dataiter = iter(test_dataloader)
first_batch_of_imgs, first_batch_of_groundtruth_labels = next(dataiter)

print("SAMPLE IMAGES from CIFAR-10:")
display_cifar10_imgs(torchvision.utils.make_grid(first_batch_of_imgs))
print("CORRESPONDING GROUND TRUTH LABELS:")
print("    " + "     ".join([CIFAR10_CLASS_NAMES[idx] for idx in first_batch_of_groundtruth_labels]))

# Model (5pts)

We've whipped up a basic image classifier model, similar to the one from this [official PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). It's a relatively small network, yet it achieves reasonable accuracy (~68.7% correct predictions on CIFAR-10's 10,000 test images). We pre-trained the model for you (the weights are in the `./ClassyClassifierParams.pt` file you grabbed during the "Setup" step).

> **TASKS:**
> 1. Read the code in the cells and run them, so that you understand how the `ClassyClassifier` works.
> 2. Answer the concept question. (5pts)
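
For task 1, the output shapes annotated in the `ClassyClassifier` comments can be reproduced with the standard output-size formula for unpadded convolutions and pooling, `(size - kernel_size) // stride + 1`. The sketch below is plain Python; `out_size` is a hypothetical helper, not a PyTorch function.

```python
def out_size(size: int, kernel_size: int, stride: int) -> int:
    """Spatial output size of a Conv2d/MaxPool2d with no padding and no dilation."""
    return (size - kernel_size) // stride + 1

size = 32                                        # CIFAR-10 images are 32x32
size = out_size(size, kernel_size=5, stride=1)   # layer1_conv -> 28
size = out_size(size, kernel_size=2, stride=2)   # layer2_pool -> 14
size = out_size(size, kernel_size=5, stride=1)   # layer3_conv -> 10
size = out_size(size, kernel_size=2, stride=2)   # layer4_pool -> 5
flattened = 32 * size * size                     # 32 channels come out of layer3_conv
print(size, flattened)  # 5 800
```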

**Concept Question 1: In `ClassyClassifier`, which layers are suitable for post-training static quantization? Why?**

_[your answer here]_

In [None]:
CLASSY_CLASSIFIER_PARAMETERS_FILENAME = "./ClassyClassifierParams.pt"

class ClassyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # INPUT SHAPE: Bx3x32x32 (B is for "batch size," in this case: 5)
        self.layer1_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1) # output Bx16x28x28
        self.layer2_pool = nn.MaxPool2d(kernel_size=2, stride=2) # output Bx16x14x14
        self.layer3_conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1) # output Bx32x10x10
        self.layer4_pool = nn.MaxPool2d(kernel_size=2, stride=2) # output Bx32x5x5
        self.layer5_flat = nn.Flatten() # output Bx(32x5x5=800)
        self.layer6_fc = nn.Linear(in_features=800, out_features=128) # output Bx128
        self.layer7_fc = nn.Linear(in_features=128, out_features=84) # output Bx84
        self.layer8_fc = nn.Linear(in_features=84, out_features=10) # output Bx10

    def forward(self, x: torch.Tensor):
        x = F.relu(self.layer1_conv(x))
        x = self.layer2_pool(x)
        x = F.relu(self.layer3_conv(x))
        x = self.layer4_pool(x)
        x = self.layer5_flat(x)
        x = F.relu(self.layer6_fc(x))
        x = F.relu(self.layer7_fc(x))
        x = self.layer8_fc(x)
        return x

In [None]:
fp32_classifier = ClassyClassifier()
fp32_classifier.load_state_dict(torch.load(CLASSY_CLASSIFIER_PARAMETERS_FILENAME, map_location=torch.device("cpu")))

print("SUMMARY OF ClassyClassifier")
print("    input size 1x3x32x32, the size of an image from the CIFAR-10 dataset")
torchinfo.summary(fp32_classifier, input_size=(1, 3, 32, 32), device="cpu") # CIFAR-10 is 3x32x32

# Quantizing by hand (85pts)

Quantization is defined by the following equation:

`FP32 VALUE = scalefactor * (QUANT INT8 VALUE - zeropoint)`

Therefore,

`QUANT INT8 VALUE = FP32 VALUE/scalefactor + zeropoint` will give you the quantized value.
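
As a concrete instance of the two formulas above, the plain-Python sketch below quantizes and dequantizes a few made-up values. The `scalefactor` and `zeropoint` here are hypothetical numbers chosen for illustration, not derived from any tensor, and the rounding and clamping mirror what the helper functions later in this notebook do.

```python
# Hypothetical parameters chosen only for illustration
scalefactor = 0.5
zeropoint = 10

def quantize(fp32_value: float) -> int:
    """QUANT INT8 VALUE = FP32 VALUE / scalefactor + zeropoint, rounded and clamped to [0, 255]."""
    q = round(fp32_value / scalefactor + zeropoint)
    return max(0, min(255, q))

def dequantize(int8_value: int) -> float:
    """FP32 VALUE = scalefactor * (QUANT INT8 VALUE - zeropoint)."""
    return scalefactor * (int8_value - zeropoint)

for value in (-5.0, 0.0, 1.3, 200.0):
    q = quantize(value)
    print(value, "->", q, "->", dequantize(q))
```

Note that 1.3 comes back as 1.5 (rounding error) and 200.0 comes back as 122.5 (it falls outside the representable range and is clamped to 255).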

The scalefactor and zeropoint can be found with a little algebra. Suppose you are quantizing a floating-point weight tensor with minimum A and maximum B to the 8-bit integer range [0, 255]. Plug the two pairs (A, 0) and (B, 255) into the equation above to solve for the scalefactor, then substitute back to calculate the zeropoint.
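
One way to sanity-check whatever scalefactor and zeropoint you derive: before rounding and clamping, the range minimum should quantize to 0 and the maximum to 255. The sketch below is a hypothetical checker in plain Python (not part of the assignment code); the example range and candidate parameters are made up so that the arithmetic is exact.

```python
import math

def endpoints_map_correctly(min_val: float, max_val: float,
                            scalefactor: float, zeropoint: int) -> bool:
    """True if min_val quantizes to 0 and max_val to 255 (up to float tolerance)."""
    q_min = min_val / scalefactor + zeropoint
    q_max = max_val / scalefactor + zeropoint
    return (math.isclose(q_min, 0.0, abs_tol=1e-6)
            and math.isclose(q_max, 255.0, abs_tol=1e-6))

# Made-up range [-5.0, 122.5] with candidate parameters 0.5 and 10
print(endpoints_map_correctly(-5.0, 122.5, scalefactor=0.5, zeropoint=10))  # True
# A wrong candidate fails the check
print(endpoints_map_correctly(-5.0, 122.5, scalefactor=1.0, zeropoint=10))  # False
```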

Now you can store the weight tensor as a much smaller INT8 tensor plus a single scalefactor and zeropoint! Refer to the lecture and assignment slides for further details.

Note that you will **NOT** be doing true quantization: ByteTensors will **NOT** be returned.
    
Instead, you will do "pseudo-quantization." You'll do all the math with FloatTensors, but the values of the FloatTensor are integers in range [0, 255].
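
To make the distinction concrete, the sketch below (made-up values, assuming only that `torch` is installed) builds a pseudo-quantized FloatTensor and, for contrast, casts it to a true 8-bit tensor.

```python
import torch

values = torch.tensor([0.2, 1.7, 99.4, 300.0])  # made-up, already-rescaled values

# Pseudo-quantization: clamp to [0, 255] and round, but keep 32-bit float storage
pseudo_quantized = values.clamp(0, 255).round()
print(pseudo_quantized)        # integer-valued floats
print(pseudo_quantized.dtype)  # torch.float32

# True quantization would store one byte per element instead
true_quantized = pseudo_quantized.to(torch.uint8)
print(true_quantized.dtype)           # torch.uint8
print(true_quantized.element_size())  # 1 byte, versus 4 for float32
```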

You will calibrate the neural network by running forward passes on images from the calibration dataset (we simply reuse the test dataset for this) and recording the minimum/maximum values of each layer's input as the images pass through the network. Those values are then averaged and used to generate scalefactors and zeropoints.
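
The bookkeeping behind calibration can be sketched in plain Python. The example below is a simplified stand-in: the "activations" are made-up lists of floats, one list per calibration pass, whereas the notebook records its statistics from real tensors.

```python
import statistics

# Made-up values observed at one layer's input over three calibration passes
observed_passes = [
    [-0.9, 0.1, 0.8],
    [-1.0, 0.3, 0.6],
    [-0.8, 0.0, 1.0],
]

mins, maxes = [], []
for activations in observed_passes:
    mins.append(min(activations))   # record this pass's minimum
    maxes.append(max(activations))  # record this pass's maximum

# The averaged range is what would feed the scalefactor/zeropoint calculation
avg_min = statistics.mean(mins)
avg_max = statistics.mean(maxes)
print(avg_min, avg_max)
```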

> **TASKS:**
> 1. Make sure you understand the Quantization lecture slides. If you do, this assignment should be relatively simple!
> 2. Read the docstrings for each function so you understand what everything is supposed to do.
> 3. Fill out the `#TODO`s in the code below. The comments should guide you through the process; you can work top to bottom.
> 4. Run the code. You should see that you get the same predictions as the original, unquantized neural network.

*Grading details*
Each properly completed #TODO (18 of them) is worth 4 points except the "complete the forward pass" `#TODO` which is worth 13 points. 4x18+13=85 total

## Helper functions to calculate scalefactor/zeropoint, quantize/dequantize tensors

In [None]:
def calculate_scalefactor_and_zeropoint(min_val: float, max_val: float) -> Tuple[float, int]:
    """Calculates the scale factor and zero point to quantize from a [`min_val`, `max_val`] range in 32-bit float to [0, 255].
    
    Follows the quantization formula: FP32 VALUE = scalefactor * (QUANT INT8 VALUE - zeropoint).

    Using the formula, float value `min_val` quantizes to `INT8_MIN` of 0.

    Args:
        min_val (float): Minimum value in range of interest.
        max_val (float): Maximum value in range of interest.

    Returns:
        tuple[float, int]: Scale factor, zero point
    """
    INT8_MIN = 0
    INT8_MAX = 255

    scalefactor =  # TODO: fill this in such that you can scale from [min_val, max_val] to [0, 255]
    zeropoint =  # TODO: fill this in
    
    # This clamps the zero point appropriately into the range [0, 255]
    if zeropoint < INT8_MIN:
        zeropoint = INT8_MIN
    elif zeropoint > INT8_MAX:
        zeropoint = INT8_MAX

    return scalefactor, int(zeropoint)


def quantize_tensor(fp32_tensor: torch.Tensor, min_val: float, max_val: float) -> Tuple[torch.Tensor, float, int]:
    """Pseudo-quantizes a 32-bit float tensor with minimum of `min_val` and maximum of `max_val` to integer range [0, 255].

    This is done using the quantization formula `quantized_tensor = zeropoint + fp32_tensor / scalefactor`.
    
    Note that this function does **NOT** truly quantize - it does **NOT** return a ByteTensor.
    
    It still returns a FloatTensor, but the values of the FloatTensor are integers in range [0, 255].

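    Args:
        fp32_tensor (torch.Tensor): The tensor to quantize.
        min_val (float): Minimum value of the range to quantize from.
        max_val (float): Maximum value of the range to quantize from.

    Returns:
        tuple[torch.Tensor, float, int]: A copy of `fp32_tensor` with all values in integer range [0, 255], the scale factor, and the zero point.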
    """
    INT8_MIN = 0
    INT8_MAX = 255
    scalefactor, zeropoint = calculate_scalefactor_and_zeropoint(min_val, max_val)
    quantized_tensor =  # TODO: calculate this in terms of zeropoint, scalefactor, and fp32_tensor
    quantized_tensor = quantized_tensor.clamp(INT8_MIN, INT8_MAX).round() # Clamp to [0, 255] and round to int
    return quantized_tensor, scalefactor, int(zeropoint)


def dequantize_tensor(int8_tensor: torch.Tensor, scalefactor: float, zeropoint: int) -> torch.Tensor:
    """Dequantizes a tensor represented in integer range [0, 255] to 32-bit float.

    This is done by rearranging the quantization formula of `int8_tensor = zeropoint + fp32_tensor / scalefactor`.

    Args:
        int8_tensor (torch.Tensor): The quantized tensor.
        scalefactor (float): The scale factor of the quantized tensor.
        zeropoint (int): The zero point of the quantized tensor.

    Returns:
        torch.Tensor: The dequantized, 32-bit float tensor.
    """
    return  # TODO: return the dequantized tensor in terms of int8_tensor, scalefactor, and zeropoint

## Calibration and quantization functionality

In [None]:
class QuantizedLayer():
    """An additional construct to manage pseudo-quantized weights and biases for Conv2d and Linear layers
    """
    def __init__(self, fp32_layer: Union[nn.Conv2d, nn.Linear]) -> None:
        """Constructor for a QuantizedLayer for either torch.nn.Conv2d or a torch.nn.Linear with 32-bit float weights/biases.

        Args:
            fp32_layer (nn.Conv2d | nn.Linear): The layer to quantize.
        """
        self._layer_to_run = copy.deepcopy(fp32_layer) # Don't accidentally mess with the original. Make a copy
        self.int8_weight = copy.deepcopy(fp32_layer.weight.data)
        self.int8_bias = copy.deepcopy(fp32_layer.bias.data)

        # Pseudo-quantizes the weights and biases, and stores scale factors and zero points
        # Pass in self.int8_weight to quantize_tensor() as the input tensor (first argument)
        self.int8_weight, self.weight_scalefactor, self.weight_zeropoint =  #TODO call quantize_tensor() for the weight
        # Pass in self.int8_bias to quantize_tensor() as the input tensor (first argument)
        self.int8_bias, self.bias_scalefactor, self.bias_zeropoint =  #TODO call quantize_tensor() for the bias

    def run_quantized_layer(self, x_int8:torch.Tensor, x_scalefactor: float, x_zeropoint: int, output_scalefactor: float, output_zeropoint: int) -> torch.Tensor:
        """Runs the layer with a given input, quantized to integer-range [0, 255].

        The function first converts x_int8, int8_weight, and int8_bias back to 32-bit float using `fp32_tensor = scalefactor * (int8_tensor - zeropoint)`, then runs the layer in 32-bit float.

        Using `output_scalefactor` and `output_zeropoint`, the output tensor is quantized using the quantization formula:

        `quantized_tensor = zeropoint + fp32_tensor / scalefactor`

        Args:
            x_int8 (torch.Tensor): Input tensor, with values in range [0, 255].
            x_scalefactor (float): Scale factor for the input.
            x_zeropoint (int): Zero point for the input.
            output_scalefactor (float): Scale factor for the output.
            output_zeropoint (int): Zero point for the output.

        Returns:
            torch.Tensor: Output tensor, in integer range [0, 255].
        """

        # Convert the input, weight, and bias back to fp32 using: fp32 = scalefactor * (int8 - zeropoint)
        x = x_scalefactor * (x_int8 - x_zeropoint) # Convert the input back to fp32
        # Use self.int8_weight, self.weight_scalefactor, self.weight_zeropoint
        # You do not need to use the variable 'x' for this line
        weight =  # TODO: Apply the same formula to the weight
        # Use self.int8_bias, self.bias_scalefactor, self.bias_zeropoint
        # You do not need to use the variable 'x' for this line
        bias =  # TODO: Apply the same formula to the bias

        # Load up the layer with the re-scaled weights and biases (real hardware implements this differently)
        self._layer_to_run.weight.data = weight
        self._layer_to_run.bias.data = bias

        # Apply the quantization formula to the output
        int8_output = # TODO: Call self._layer_to_run(x) to retrieve fp32_output, and then quantize the output using 'output_scalefactor' and 'output_zeropoint'
        return int8_output

class QuantizerForClassyClassifier():
    """Applies pseudo-quantization to the ClassyClassifier for a given calibration dataset.
    """
    def __init__(self, fp32_model: ClassyClassifier, calibration_dataloader: DataLoader):
        """Constructor for a QuantizerForClassyClassifier for a given calibration dataset.

        Args:
            fp32_model (ClassyClassifier): Model to quantize.
            calibration_dataloader (DataLoader): Calibration dataset to use.
        """
        self.fp32_model = copy.deepcopy(fp32_model) # Don't accidentally mess with the original. Make a copy
        
        # Quantize weights and biases
        self.int8_layer1_conv = QuantizedLayer(fp32_model.layer1_conv)
        self.int8_layer3_conv = QuantizedLayer(fp32_model.layer3_conv)
        self.int8_layer6_fc = QuantizedLayer(fp32_model.layer6_fc)
        self.int8_layer7_fc = QuantizedLayer(fp32_model.layer7_fc)
        self.int8_layer8_fc = QuantizedLayer(fp32_model.layer8_fc)
        
        # Calibration
        self.calibration_dataloader = calibration_dataloader

        # Set up calibration stat-tracking
        self.calibration_input_stats = {
            "layer1_conv": {
                "mins": [],
                "maxes": [],
                "avg_min": 0,
                "avg_max": 0,
            },
            "layer3_conv": {
                "mins": [],
                "maxes": [],
                "avg_min": 0,
                "avg_max": 0,
                "input_scalefactor": 0.,
                "input_zeropoint": 0
            },
            "layer6_fc": {
                "mins": [],
                "maxes": [],
                "avg_min": 0,
                "avg_max": 0,
                "input_scalefactor": 0.,
                "input_zeropoint": 0
            },
            "layer7_fc": {
                "mins": [],
                "maxes": [],
                "avg_min": 0,
                "avg_max": 0,
                "input_scalefactor": 0.,
                "input_zeropoint": 0
            },
            "layer8_fc": {
                "mins": [],
                "maxes": [],
                "avg_min": 0,
                "avg_max": 0,
                "input_scalefactor": 0.,
                "input_zeropoint": 0
            },

        }

        # Calibrate
        self.calibrate_with_dataloader()

    def calibrate_with_dataloader(self): 
        """Inserts observers into the 32-bit float forward pass of the ClassyClassifier, and then runs the calibration dataset.

        Records average input mins/maxes for each Conv2d and Linear layer, and uses them to calculate input scale factors and zero points.

        Stores into `self.calibration_input_stats`.
        """
        with torch.no_grad():
            for x, _ in self.calibration_dataloader:
                batch_size = x.shape[0]
                
                ### Refer to the forward() function of ClassyClassifier to understand the order of layers and relu() activation functions used below
                ### Note that 'self.calibration_input_stats' is a dictionary, and is defined above in the constructor for the class QuantizerForClassyClassifier
                ### The string literals "layer1_conv" and "layer3_conv" used to index 'self.calibration_input_stats' correspond with the 'layer1_conv' and 'layer3_conv' in ClassyClassifier
                ### Use this information to fill out the statements for the remaining three layers

                # Get calibration data going into layer 1, then use ReLU activations and pooling, store to self.calibration_input_stats
                # collect the min and maxes stats for the input of layer 1 (aka the input x)
                self.calibration_input_stats["layer1_conv"]["mins"].append(x.view(batch_size, -1).min(dim=1)[0].sum().item())
                self.calibration_input_stats["layer1_conv"]["maxes"].append(x.view(batch_size, -1).max(dim=1)[0].sum().item())
                # use ReLU
                x = F.relu(self.fp32_model.layer1_conv(x))
                # do layer-2 pooling
                x = self.fp32_model.layer2_pool(x)

                # Get calibration data going into layer 3, then use ReLU activations and pooling, like above, store to self.calibration_input_stats               
                # Collect the min and maxes stats for the input of layer 3 (aka the output of layer 2)
                # Use ReLU, do layer-4 pooling
                self.calibration_input_stats["layer3_conv"]["mins"].append(x.view(batch_size, -1).min(dim=1)[0].sum().item())
                self.calibration_input_stats["layer3_conv"]["maxes"].append(x.view(batch_size, -1).max(dim=1)[0].sum().item())
                x = F.relu(self.fp32_model.layer3_conv(x))
                x = self.fp32_model.layer4_pool(x)
                x = self.fp32_model.layer5_flat(x)

                # Get calibration data going into layer 6, then use ReLU activations, like above, store to self.calibration_input_stats
                # TODO: collect the min and maxes stats for the input of layer 6 (aka the output of layer 5)
                # TODO: use ReLU
                

                # Get calibration data going into layer 7, then use ReLU activations, like above, store to self.calibration_input_stats
                # TODO: collect the min and maxes stats for the input of layer 7 (aka the output of layer 6)
                # TODO: use ReLU
                
                
                # Get calibration data going into layer 8, like above, store to self.calibration_input_stats
                # TODO: collect the min and maxes stats for the input of layer 8 (aka the output of layer 7)
                # No need to do ReLU and finish the forward pass because there are no further stats to collect
                


        # Convert the lists of mins and maxes into averages, layer-input scalefactors, and layer-input zeropoints
        for layer_name in self.calibration_input_stats:
            avg_min = statistics.mean(self.calibration_input_stats[layer_name]["mins"])
            avg_max = statistics.mean(self.calibration_input_stats[layer_name]["maxes"])
            # Use the two variables 'avg_min' and 'avg_max' to compute the scalefactor and zeropoint
            input_scalefactor, input_zeropoint = # TODO: call calculate_scalefactor_and_zeropoint()
            self.calibration_input_stats[layer_name]["avg_min"] = avg_min
            self.calibration_input_stats[layer_name]["avg_max"] = avg_max
            self.calibration_input_stats[layer_name]["input_scalefactor"] = input_scalefactor 
            self.calibration_input_stats[layer_name]["input_zeropoint"] = input_zeropoint
    
    def run_calibrated_quantized(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the pseudo-quantized ClassyClassifier on a 32-bit float tensor.
        
        Quantizes the tensor to [0, 255] prior to forward pass, and then dequantizes back to float before return.

        Args:
            x (torch.Tensor): Input 32-bit float tensor.

        Returns:
            torch.Tensor: Output 32-bit float tensor.
        """
        x = copy.deepcopy(x)

        # Use the calibration input stats for Layer 1 to quantize the tensor and run the layer
        x, x_scalefactor, x_zeropoint = quantize_tensor(x, self.calibration_input_stats["layer1_conv"]["avg_min"], self.calibration_input_stats["layer1_conv"]["avg_max"])
        # The \ in Python continues the statement on the next line - it is often used to make long, single line statements more readable
        x = self.int8_layer1_conv.run_quantized_layer(x, x_scalefactor, x_zeropoint, \
                                                      self.calibration_input_stats["layer3_conv"]["input_scalefactor"], self.calibration_input_stats["layer3_conv"]["input_zeropoint"])
        x = F.relu(x)
        x = self.fp32_model.layer2_pool(x)

        # TODO (worth 13 points): As above, use the calibration input stats for the appropriate layers to complete the rest of the forward pass: layer 4-7 (DON'T DO layer 8 yet)
        # Remember to look at the order of layers and activations in calibrate_with_dataloader() 
        # Don't forget to do your ReLU activations and pooling as necessary!
        # Notice the pattern between the first two and last two parameters in calls to run_quantized_layer()
        x = self.int8_layer3_conv.run_quantized_layer(x, self.calibration_input_stats["layer3_conv"]["input_scalefactor"], self.calibration_input_stats["layer3_conv"]["input_zeropoint"], \
                                                      self.calibration_input_stats["layer6_fc"]["input_scalefactor"], self.calibration_input_stats["layer6_fc"]["input_zeropoint"])
        x = F.relu(x)
      
        x = # TODO: dequantize the tensor 'x' using dequantize_tensor(), the Layer 8 scalefactor, and zeropoint
        
        # Run the last layer in FP32, dequantized
        x = self.fp32_model.layer8_fc(x)
        return x

In [None]:
# Just one line to set everything up using the test dataloader as the calibration dataset
quantizer = QuantizerForClassyClassifier(fp32_classifier.eval(), test_dataloader)

In [None]:
dataiter = iter(test_dataloader)
first_batch_of_imgs, first_batch_of_groundtruth_labels = next(dataiter)

print("SAMPLE IMAGES from CIFAR-10:")
display_cifar10_imgs(torchvision.utils.make_grid(first_batch_of_imgs))
print("CORRESPONDING GROUND TRUTH LABELS:")
print("    " + "     ".join([CIFAR10_CLASS_NAMES[idx] for idx in first_batch_of_groundtruth_labels]))

outputs = quantizer.run_calibrated_quantized(first_batch_of_imgs)

_, predicted = torch.max(outputs, 1)
print("PREDICTED CLASSES from PSEUDO-QUANTIZED NETWORK should match original neural network outputs (i.e., cat, car, plane, plane, frog):")
print("    " + "     ".join([CIFAR10_CLASS_NAMES[idx] for idx in predicted]))

# Quantizing with Torch.FX (10 pts)

Torch.FX is a PyTorch toolkit that captures the computational graph of a neural network so that it can be inspected and transformed. PyTorch also ships a quantization API built on top of it that makes quantization a snap!
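
As a small illustration of that graph representation, the sketch below traces a made-up two-layer module (`TinyNet` is hypothetical, not part of the assignment) with `torch.fx.symbolic_trace` and lists the nodes of the captured graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace

class TinyNet(nn.Module):
    """Hypothetical toy model, used only to show what a traced graph looks like."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(self.fc(x))

traced = symbolic_trace(TinyNet())

# Each node records an operation kind (op) and what it targets
for node in traced.graph.nodes:
    print(node.op, node.target)
```

The printed nodes follow the order of `forward`: an input placeholder, the `fc` module call, the `relu` function call, and the output.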

> **TASKS:**
> 1. Read and run the cells below. You should see that the Torch.FX-quantized model is much smaller (160kB) than the original model (540kB). You should also see that accuracy remains similar.
> 2. Answer the questions.
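
The size difference in task 1 is roughly what parameter counting predicts. The back-of-the-envelope sketch below counts the `ClassyClassifier` parameters from its layer shapes and compares 4 bytes per parameter (FP32) with 1 byte (INT8). It is an estimate under simplifying assumptions: it ignores file-format overhead and stored scale factors/zero points, and it treats every parameter as 8-bit, so the saved files come out somewhat larger.

```python
# (weight count, bias count) per layer, taken from the ClassyClassifier layer shapes
layers = {
    "layer1_conv": (3 * 16 * 5 * 5, 16),
    "layer3_conv": (16 * 32 * 5 * 5, 32),
    "layer6_fc": (800 * 128, 128),
    "layer7_fc": (128 * 84, 84),
    "layer8_fc": (84 * 10, 10),
}

total_params = sum(weights + biases for weights, biases in layers.values())
fp32_kb = total_params * 4 / 1024  # 4 bytes per 32-bit float
int8_kb = total_params * 1 / 1024  # 1 byte per 8-bit integer

print(total_params)                    # 128262
print(round(fp32_kb), round(int8_kb))  # 501 125
```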

**(2 pts) Concept Question 1: Explain what calibration does.**

_[your answer here]_

**(2 pts) Concept Question 2: Why might the accuracy of a quantized network be _better_ than that of the original? Doesn't quantization cause the network to lose precision?**

_[your answer here]_

**(6 pts) Ensure that the accuracy of both the Torch.FX model and your pseudo-quantized model is within ±0.08 of the original, non-quantized model.** If your accuracy is not within that range, you most likely have an incorrect implementation.



In [None]:
import os
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx, QuantizedGraphModule

fp32_classifier.eval()
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

# Fuse modules and insert observers to automatically collect min/max during calibration
prepared_model = prepare_fx(fp32_classifier, qconfig_dict)  

# Calibrate
with torch.no_grad():
    for x, _ in test_dataloader:
        prepared_model(x)

# Convert the calibrated model to a quantized model
quantized_model = convert_fx(prepared_model).eval()
print("FX-QUANTIZED MODEL:")
print(quantized_model)

In [None]:
dataiter = iter(test_dataloader)
first_batch_of_imgs, first_batch_of_groundtruth_labels = next(dataiter)

print("SAMPLE IMAGES from CIFAR-10:")
display_cifar10_imgs(torchvision.utils.make_grid(first_batch_of_imgs))
print("CORRESPONDING GROUND TRUTH LABELS:")
print("    " + "     ".join([CIFAR10_CLASS_NAMES[idx] for idx in first_batch_of_groundtruth_labels]))

outputs = quantized_model(first_batch_of_imgs)
_, predicted = torch.max(outputs, 1)

print("PREDICTED CLASSES from FX-QUANTIZED NETWORK should be similar to original neural network outputs (i.e., cat, car, plane, plane, bird):")
print("    " + "     ".join([CIFAR10_CLASS_NAMES[idx] for idx in predicted]))

In [None]:
def evaluate_ClassyClassifier(model: Union[ClassyClassifier, QuantizerForClassyClassifier, QuantizedGraphModule], test_dataloader: DataLoader, report_model_size:bool=False):
    num_images_correct = 0
    total_images_seen = 0
    timings = []
    with torch.no_grad():
        for images, labels in test_dataloader:
            batch_size = images.shape[0]
            
            if type(model) is QuantizerForClassyClassifier:
                outputs = model.run_calibrated_quantized(images)
            else:
                outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            
            total_images_seen += batch_size
            
            num_images_correct += (predicted == labels).sum().item()

    print("Accuracy: {0:.3f}, {1} correct/{2} total images".format(num_images_correct/total_images_seen, num_images_correct, total_images_seen))
    if report_model_size and type(model) is not QuantizerForClassyClassifier:
        torch.jit.save(torch.jit.script(model), "temp.p")
        print("Model Size (kB): {0:.1f}".format(os.path.getsize("temp.p")/1024))
        os.remove("temp.p")

print("ACCURACY:")

print("ORIGINAL ClassyClassifier:")
evaluate_ClassyClassifier(fp32_classifier, test_dataloader, True)

print("FX-QUANTIZED ClassyClassifier:")
evaluate_ClassyClassifier(torch.jit.script(quantized_model).eval(), test_dataloader, True)

print("PSEUDO-QUANTIZED ClassyClassifier:")
evaluate_ClassyClassifier(quantizer, test_dataloader)