In [None]:
# Scratch cell: a bare string literal that is not assigned to any name, so
# running the cell only echoes the path back as the cell output.
# NOTE(review): looks like the student checkpoint of the 0.1 pruning run - confirm before relying on it.
'checkpoints_student/T=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=0.001_pruning_0.1_final.tar'

In [2]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import csv
import sys
#sys.path.append('/content/KD')
# Import the module
import networks
import utils

%load_ext autoreload
%autoreload 2

In [3]:
# Device configuration for the whole notebook.
use_gpu = True    # request a gpu; reset to False below when CUDA turns out to be unavailable
gpu_id = 0        # id of gpu to be used (an index into the visible devices set below)
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens
fast_device = torch.device('cpu')
if use_gpu:
    # Set the visible devices before querying CUDA so the query honours the restriction.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'    # set visible devices depending on system configuration
    if torch.cuda.is_available():
        fast_device = torch.device('cuda:' + str(gpu_id))
    else:
        # Without this fallback every later `.to(fast_device)` would raise on a
        # CPU-only machine; keep `use_gpu` in sync with the device actually used.
        print('CUDA is not available; falling back to CPU.')
        use_gpu = False

In [6]:
def reproducibilitySeed(seed=0):
    """
    Ensure reproducibility of results by seeding the global random number generators.

    Seeds Python's `random`, NumPy and PyTorch with the same value and asks cuDNN
    for deterministic behaviour. The cuDNN flags are set unconditionally so the
    function does not depend on any variable defined in another cell.

    Args:
        seed (int): Seed applied to every generator. Defaults to 0, which
            reproduces the behaviour of calling the function without arguments.
    """
    import random  # local import: keeps this cell runnable without touching the import cell

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

reproducibilitySeed()

In [4]:
import torchvision
import torchvision.transforms as transforms

# Per-channel normalization statistics used for CIFAR-10, shared by both pipelines
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
CIFAR10_ROOT = './CIFAR10_dataset/'

# Training pipeline: augmentation followed by normalization
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),  # Augment training data by padding 4 and random cropping
        transforms.RandomHorizontalFlip(),     # Randomly flip images horizontally
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: normalization only, no augmentation
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Load CIFAR-10 dataset
train_val_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True,
                                                 download=True, transform=transform_train)

# Same training images, but with the evaluation transform. Used only to build the
# validation subset so that validation samples are not randomly cropped/flipped.
train_val_eval_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True,
                                                      download=False, transform=transform_test)

test_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=False,
                                            download=True, transform=transform_test)

# Split the training dataset into training and validation
num_train = int(0.95 * len(train_val_dataset))  # 95% of the dataset for training
num_val = len(train_val_dataset) - num_train  # Remaining 5% for validation
# A dedicated, explicitly seeded generator makes the split independent of the
# order in which the notebook cells happen to be executed.
split_generator = torch.Generator().manual_seed(0)
train_dataset, val_split = torch.utils.data.random_split(
    train_val_dataset, [num_train, num_val], generator=split_generator
)
# Re-point the validation indices at the non-augmented view of the same images.
val_dataset = torch.utils.data.Subset(train_val_eval_dataset, val_split.indices)

# DataLoader setup
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
# Path to the saved teacher checkpoint
teacher_path = "checkpoints_teacher/dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=0.001_final.tar"

# Initialize the network
teacher_net = networks.TeacherNetwork()
teacher_net = teacher_net.to(fast_device)

# Load the checkpoint. `map_location` remaps the stored tensors onto fast_device,
# so a checkpoint written on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source; consider weights_only=True if the checkpoint holds only
# tensors and plain containers (not verifiable from this notebook).
checkpoint = torch.load(teacher_path, map_location=fast_device)

# Load the state dictionary into the model
teacher_net.load_state_dict(checkpoint['model_state_dict'])

# pre-trained teacher accuracy
reproducibilitySeed()
_, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)
print('test accuracy: ', test_accuracy)

  checkpoint = torch.load(teacher_path)


test accuracy:  0.8616


In [None]:
from quantize_neural_net import QuantizeNeuralNet
pruning_factors = [i/20 for i in range(1, 11)]

# Pruning factor of the student under test. It drives both the checkpoint file
# name and the network construction, so the two can no longer drift apart.
pruning_factor = 0.4

# Path to the saved student checkpoint
student_path = (
    "checkpoints_student/T=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, "
    "lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=0.001"
    f"_pruning_{pruning_factor}_final.tar"
)

# Initialize the network
student_model = networks.StudentNetwork(pruning_factor, teacher_net)
student_model = student_model.to(fast_device)

# Load the checkpoint. `map_location` remaps the stored tensors onto fast_device,
# so a checkpoint written on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
checkpoint = torch.load(student_path, map_location=fast_device)

student_model.load_state_dict(checkpoint['model_state_dict'])

# Ensure reproducibility and evaluate the pre-trained (not yet quantized) student
reproducibilitySeed()

_, test_accuracy = utils.getLossAccuracyOnDataset(student_model, test_loader, fast_device)
print('Test accuracy:', test_accuracy)

# Quantize the network wrapped inside the student, using the training loader as
# the data source handed to the quantizer.
quantizer = QuantizeNeuralNet(student_model.model.model,
    'resnet18',  # Default from `-model`
    batch_size=train_loader.batch_size,  # same value as the loader passed below (128)
    data_loader=train_loader,
    mlp_bits=4,  # Default from `--bits`
    cnn_bits=4,  # Default from `--bits`
    ignore_layers=[],  # Default from `--ignore_layer`
    mlp_alphabet_scalar=1.16,  # Default from `--scalar`
    cnn_alphabet_scalar=1.16,  # Default from `--scalar`
    mlp_percentile=1,  # Default from `--percentile`
    cnn_percentile=1,  # Default from `--percentile`
    reg=None,  # Default from `--regularizer`
    lamb=0.1,  # Default from `--lamb`
    retain_rate=0.25,  # Default from `--retain_rate`
    stochastic_quantization=False,  # Default from `--stochastic_quantization`
    device=fast_device
)

quantized_model = quantizer.quantize_network()

# Accuracy of the quantized model on the same test set
_, test_accuracy = utils.getLossAccuracyOnDataset(quantized_model, test_loader, fast_device)
print('Test accuracy:', test_accuracy)

  checkpoint = torch.load(student_path)


Test accuracy: 0.8604


In [22]:
def count_parameters(model):
    """
    Counts the non-zero entries among the trainable parameters of a PyTorch model.

    Entries that are exactly zero are left out of the count, and parameters with
    `requires_grad` set to False are skipped entirely.

    Args:
        model (torch.nn.Module): The model whose parameters are inspected.

    Returns:
        int: Number of non-zero entries across all trainable parameters.
    """
    nonzero_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters do not contribute
        nonzero_total += param.data.ne(0).sum().item()
    return nonzero_total


def count_zero_parameters(model):
    """
    Counts the entries that are exactly zero among the trainable parameters of a PyTorch model.

    Parameters with `requires_grad` set to False are skipped entirely.

    Args:
        model (torch.nn.Module): The model whose parameters are inspected.

    Returns:
        int: Number of exactly-zero entries across all trainable parameters.
    """
    zero_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters do not contribute
        zero_total += param.data.eq(0).sum().item()
    return zero_total

In [23]:
# Non-zero trainable entries of the student network (exact zeros are excluded by count_parameters).
count_parameters(student_model)

6712822

In [42]:
from quantize_neural_net import QuantizeNeuralNet


In [48]:
# Run the quantizer that was built in an earlier cell and keep the model it returns.
# NOTE(review): relies on `quantizer` still being alive in the kernel; on a fresh
# kernel the cell that constructs it must be run first.
quantized_model = quantizer.quantize_network()

Layer indices to quantize [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
Total number of layers to quantize 21

Quantizing layer with index: 0
Quantization progress: 0 out of 20

shape of W: torch.Size([64, 3, 7, 7])
shape of analog_layer_input: torch.Size([896, 147])
shape of quantized_layer_input: torch.Size([896, 147])
The number of groups: 1



100%|██████████| 147/147 [00:00<00:00, 549.03it/s]


The quantization error of layer 0 is 25.63368797302246.
The relative quantization error of layer 0 is 0.07626139372587204.


Quantizing layer with index: 1
Quantization progress: 1 out of 20

shape of W: torch.Size([64, 64, 3, 3])
shape of analog_layer_input: torch.Size([384, 576])
shape of quantized_layer_input: torch.Size([384, 576])
The number of groups: 1



100%|██████████| 576/576 [00:00<00:00, 2098.73it/s]


The quantization error of layer 1 is 3.9425508975982666.
The relative quantization error of layer 1 is 0.05762189254164696.


Quantizing layer with index: 2
Quantization progress: 2 out of 20

shape of W: torch.Size([64, 64, 3, 3])
shape of analog_layer_input: torch.Size([384, 576])
shape of quantized_layer_input: torch.Size([384, 576])
The number of groups: 1



100%|██████████| 576/576 [00:00<00:00, 2564.26it/s]


The quantization error of layer 2 is 1.7968082427978516.
The relative quantization error of layer 2 is 0.08124750852584839.


Quantizing layer with index: 3
Quantization progress: 3 out of 20

shape of W: torch.Size([64, 64, 3, 3])
shape of analog_layer_input: torch.Size([384, 576])
shape of quantized_layer_input: torch.Size([384, 576])
The number of groups: 1



100%|██████████| 576/576 [00:00<00:00, 2740.14it/s]


The quantization error of layer 3 is 4.898229598999023.
The relative quantization error of layer 3 is 0.08185645192861557.


Quantizing layer with index: 4
Quantization progress: 4 out of 20

shape of W: torch.Size([64, 64, 3, 3])
shape of analog_layer_input: torch.Size([384, 576])
shape of quantized_layer_input: torch.Size([384, 576])
The number of groups: 1



100%|██████████| 576/576 [00:00<00:00, 2577.88it/s]


The quantization error of layer 4 is 1.6156749725341797.
The relative quantization error of layer 4 is 0.09929851442575455.


Quantizing layer with index: 5
Quantization progress: 5 out of 20

shape of W: torch.Size([128, 64, 3, 3])
shape of analog_layer_input: torch.Size([384, 576])
shape of quantized_layer_input: torch.Size([384, 576])
The number of groups: 1



100%|██████████| 576/576 [00:00<00:00, 2834.06it/s]


The quantization error of layer 5 is 8.018511772155762.
The relative quantization error of layer 5 is 0.11040127277374268.


Quantizing layer with index: 6
Quantization progress: 6 out of 20

shape of W: torch.Size([128, 128, 3, 3])
shape of analog_layer_input: torch.Size([256, 1152])
shape of quantized_layer_input: torch.Size([256, 1152])
The number of groups: 1



100%|██████████| 1152/1152 [00:00<00:00, 3003.33it/s]


The quantization error of layer 6 is 1.252508521080017.
The relative quantization error of layer 6 is 0.08935167640447617.


Quantizing layer with index: 7
Quantization progress: 7 out of 20

shape of W: torch.Size([128, 64, 1, 1])
shape of analog_layer_input: torch.Size([2176, 64])
shape of quantized_layer_input: torch.Size([2176, 64])
The number of groups: 1



100%|██████████| 64/64 [00:00<00:00, 3126.14it/s]

The quantization error of layer 7 is 13.434507369995117.
The relative quantization error of layer 7 is 0.1808401495218277.







Quantizing layer with index: 8
Quantization progress: 8 out of 20

shape of W: torch.Size([128, 128, 3, 3])
shape of analog_layer_input: torch.Size([256, 1152])
shape of quantized_layer_input: torch.Size([256, 1152])
The number of groups: 1



100%|██████████| 1152/1152 [00:00<00:00, 3124.68it/s]


The quantization error of layer 8 is 2.0917251110076904.
The relative quantization error of layer 8 is 0.09697361290454865.


Quantizing layer with index: 9
Quantization progress: 9 out of 20

shape of W: torch.Size([128, 128, 3, 3])
shape of analog_layer_input: torch.Size([256, 1152])
shape of quantized_layer_input: torch.Size([256, 1152])
The number of groups: 1



100%|██████████| 1152/1152 [00:00<00:00, 3199.35it/s]


The quantization error of layer 9 is 0.667126476764679.
The relative quantization error of layer 9 is 0.09861595928668976.


Quantizing layer with index: 10
Quantization progress: 10 out of 20

shape of W: torch.Size([256, 128, 3, 3])
shape of analog_layer_input: torch.Size([256, 1152])
shape of quantized_layer_input: torch.Size([256, 1152])
The number of groups: 1



100%|██████████| 1152/1152 [00:00<00:00, 3223.77it/s]


The quantization error of layer 10 is 3.0722508430480957.
The relative quantization error of layer 10 is 0.12117790430784225.


Quantizing layer with index: 11
Quantization progress: 11 out of 20

shape of W: torch.Size([256, 256, 3, 3])
shape of analog_layer_input: torch.Size([128, 2304])
shape of quantized_layer_input: torch.Size([128, 2304])
The number of groups: 1



100%|██████████| 2304/2304 [00:00<00:00, 3480.09it/s]


The quantization error of layer 11 is 1.1908488273620605.
The relative quantization error of layer 11 is 0.08112328499555588.


Quantizing layer with index: 12
Quantization progress: 12 out of 20

shape of W: torch.Size([256, 128, 1, 1])
shape of analog_layer_input: torch.Size([640, 128])
shape of quantized_layer_input: torch.Size([640, 128])
The number of groups: 1



100%|██████████| 128/128 [00:00<00:00, 2684.54it/s]

The quantization error of layer 12 is 4.296163082122803.
The relative quantization error of layer 12 is 0.2490590214729309.







Quantizing layer with index: 13
Quantization progress: 13 out of 20

shape of W: torch.Size([256, 256, 3, 3])
shape of analog_layer_input: torch.Size([128, 2304])
shape of quantized_layer_input: torch.Size([128, 2304])
The number of groups: 1



100%|██████████| 2304/2304 [00:00<00:00, 3344.41it/s]


The quantization error of layer 13 is 0.8493529558181763.
The relative quantization error of layer 13 is 0.10086272656917572.


Quantizing layer with index: 14
Quantization progress: 14 out of 20

shape of W: torch.Size([256, 256, 3, 3])
shape of analog_layer_input: torch.Size([128, 2304])
shape of quantized_layer_input: torch.Size([128, 2304])
The number of groups: 1



100%|██████████| 2304/2304 [00:00<00:00, 3353.66it/s]


The quantization error of layer 14 is 0.34594428539276123.
The relative quantization error of layer 14 is 0.13153810799121857.


Quantizing layer with index: 15
Quantization progress: 15 out of 20

shape of W: torch.Size([512, 256, 3, 3])
shape of analog_layer_input: torch.Size([128, 2304])
shape of quantized_layer_input: torch.Size([128, 2304])
The number of groups: 1



100%|██████████| 2304/2304 [00:00<00:00, 3389.00it/s]


The quantization error of layer 15 is 1.9843748807907104.
The relative quantization error of layer 15 is 0.1347113400697708.


Quantizing layer with index: 16
Quantization progress: 16 out of 20

shape of W: torch.Size([512, 512, 3, 3])
shape of analog_layer_input: torch.Size([128, 4608])
shape of quantized_layer_input: torch.Size([128, 4608])
The number of groups: 1



100%|██████████| 4608/4608 [00:01<00:00, 3529.13it/s]


The quantization error of layer 16 is 0.7867875099182129.
The relative quantization error of layer 16 is 0.2127627283334732.


Quantizing layer with index: 17
Quantization progress: 17 out of 20

shape of W: torch.Size([512, 256, 1, 1])
shape of analog_layer_input: torch.Size([256, 256])
shape of quantized_layer_input: torch.Size([256, 256])
The number of groups: 1



100%|██████████| 256/256 [00:00<00:00, 3344.99it/s]

The quantization error of layer 17 is 2.458876609802246.
The relative quantization error of layer 17 is 0.3303937017917633.







Quantizing layer with index: 18
Quantization progress: 18 out of 20

shape of W: torch.Size([512, 512, 3, 3])
shape of analog_layer_input: torch.Size([128, 4608])
shape of quantized_layer_input: torch.Size([128, 4608])
The number of groups: 1



100%|██████████| 4608/4608 [00:01<00:00, 3574.08it/s]


The quantization error of layer 18 is 1.88277268409729.
The relative quantization error of layer 18 is 0.1320144683122635.


Quantizing layer with index: 19
Quantization progress: 19 out of 20

shape of W: torch.Size([512, 512, 3, 3])
shape of analog_layer_input: torch.Size([128, 4608])
shape of quantized_layer_input: torch.Size([128, 4608])
The number of groups: 1



100%|██████████| 4608/4608 [00:01<00:00, 3428.54it/s]


The quantization error of layer 19 is 0.6852036714553833.
The relative quantization error of layer 19 is 0.14334264397621155.


Quantizing layer with index: 20
Quantization progress: 20 out of 20

The number of groups: 1



100%|██████████| 512/512 [00:00<00:00, 3313.58it/s]


The quantization error of layer 20 is 20.633893966674805.
The relative quantization error of layer 20 is 0.19272233545780182.

