### Import required packages and limit GPU usage

In [None]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
sys.path.append('/content/KD')
# Import the module
import networks
import utils

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cuda', index=0)

In [None]:
# Device selection. `fast_device` is where training and inference run;
# `cpu_device` is always the host CPU.
use_gpu = False    # set use_gpu to True if system has gpu
gpu_id = 0         # id of gpu to be used
cpu_device = torch.device('cpu')
if use_gpu:
    # Restrict which physical GPUs are visible; adjust to the local machine.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
    fast_device = torch.device(f'cuda:{gpu_id}')
else:
    fast_device = torch.device('cpu')

In [None]:
def reproducibilitySeed():
    """
    Reset the torch and numpy global random number generators to seed 0.

    When the module-level flag `use_gpu` is set, cuDNN is additionally put in
    deterministic mode and its benchmark autotuner is switched off.
    """
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not use_gpu:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed the RNGs once up front, before any dataset splitting or model construction below.
reproducibilitySeed()

In [None]:
# Output directories for model checkpoints (relative to the working directory).
checkpoints_path_teacher = 'checkpoints_teacher/'
checkpoints_path_student = 'checkpoints_student/'
# Create both directories up front; exist_ok makes the cell safe to re-run.
for checkpoints_dir in (checkpoints_path_teacher, checkpoints_path_student):
    os.makedirs(checkpoints_dir, exist_ok=True)

### Load dataset

In [None]:
import torchvision
import torchvision.transforms as transforms

# Per-channel normalization constants shared by the train and test pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)
dataset_root = './CIFAR10_dataset/'

# Training pipeline: augmentation (pad-4 random crop, horizontal flip) + normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

# Evaluation pipeline: normalization only, no augmentation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]
)

# Load CIFAR-10 dataset
train_val_dataset = torchvision.datasets.CIFAR10(root=dataset_root, train=True,
                                                 download=True, transform=transform_train)

test_dataset = torchvision.datasets.CIFAR10(root=dataset_root, train=False,
                                            download=True, transform=transform_test)

# Second view of the same training images WITHOUT augmentation. Validation
# samples are drawn from this view so validation metrics are not distorted by
# random crops/flips. The files were fetched by the call above, hence download=False.
train_val_dataset_eval = torchvision.datasets.CIFAR10(root=dataset_root, train=True,
                                                      download=False, transform=transform_test)

# Split the training images 95% / 5% into training and validation subsets.
num_train = int(0.95 * len(train_val_dataset))
num_val = len(train_val_dataset) - num_train
# One shared random permutation (drawn from torch's global RNG, as random_split
# does by default) keeps the two subsets disjoint across the two dataset views.
shuffled_indices = torch.randperm(len(train_val_dataset)).tolist()
train_dataset = torch.utils.data.Subset(train_val_dataset, shuffled_indices[:num_train])
val_dataset = torch.utils.data.Subset(train_val_dataset_eval, shuffled_indices[num_train:])
assert len(train_dataset) == num_train and len(val_dataset) == num_val

# DataLoader setup
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


### Train teacher network

In [None]:
num_epochs = 10      # Number of training epochs passed to the teacher training call
print_every = 100    # Interval size for which to print statistics of training

In [None]:
# Hyperparameters can be tuned by setting the required ranges below; every
# combination of the lists is trained once.
# learning_rates = list(np.logspace(-4, -2, 3))
learning_rates = [1e-2]
learning_rate_decays = [0.95]    # learning rate decays at every epoch
# weight_decays = [0.0] + list(np.logspace(-5, -1, 5))
weight_decays = [1e-5]           # regularization weight
momentums = [0.9]
# dropout_probabilities = [(0.2, 0.5), (0.0, 0.0)]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout) pairs

# Expand the grid into one dict per hyperparameter combination.
hparams_list = []
for (dropout_input, dropout_hidden), weight_decay, lr_decay, momentum, lr in itertools.product(
        dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates):
    hparams_list.append({
        'dropout_input': dropout_input,
        'dropout_hidden': dropout_hidden,
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    })

# Make sure the checkpoint directory exists before any (long) training run starts,
# so a missing folder cannot discard a finished model at save time.
os.makedirs(checkpoints_path_teacher, exist_ok=True)

results = {}
for hparam in hparams_list:
    print('Training with hparams' + utils.hparamToString(hparam))
    # Re-seed so every configuration starts from the same RNG state.
    reproducibilitySeed()
    teacher_net = networks.TeacherNetworkVGG()
    teacher_net = teacher_net.to(fast_device)
    hparam_tuple = utils.hparamDictToTuple(hparam)
    # Trains on the full training set (train_val_loader); no validation loader is passed.
    results[hparam_tuple] = utils.trainTeacherOnHparam(teacher_net, hparam, num_epochs,
                                                        train_val_loader, None,
                                                        print_every=print_every,
                                                        fast_device=fast_device)
    # Fixed: the original referenced the undefined name `checkpoints_path`,
    # raising NameError only after training had already finished.
    save_path = checkpoints_path_teacher + utils.hparamToString(hparam) + '_final.tar'
    torch.save({'results' : results[hparam_tuple],
                'model_state_dict' : teacher_net.state_dict(),
                'epoch' : num_epochs}, save_path)

Training with hparamsdropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05




[1,   100/  391] train loss: 0.882 train accuracy: 0.711
[1,   200/  391] train loss: 0.764 train accuracy: 0.711
[1,   300/  391] train loss: 0.610 train accuracy: 0.805


NameError: name 'checkpoints_path' is not defined

In [None]:
# Evaluate the teacher on the held-out CIFAR-10 test loader; the first element
# of the returned pair is discarded and only the accuracy is reported.
_, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, fast_device)
print('test accuracy: ', test_accuracy)

test accuracy:  0.8654


### Train student network

In [None]:
# NOTE: these rebind the values set for teacher training above.
num_epochs = 30      # Number of training epochs passed to the student training call
print_every = 100    # Interval size for which to print statistics of training

In [None]:
# Step size of the pruning-factor sweep used below (factors are i/32 for i = 1..32).
1/32

0.03125

In [None]:
# Distillation hyperparameter grid: every combination of the lists below is
# trained once for every pruning factor.
temperatures = [10]              # distillation temperature(s), stored under key 'T'
alphas = [0.5]                   # stored under key 'alpha'
learning_rates = [1e-2]
learning_rate_decays = [0.95]
weight_decays = [1e-5]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]    # (input dropout, hidden dropout) pairs

# Expand the grid into one dict per hyperparameter combination.
hparams_list = []
for alpha, temperature, (dropout_input, dropout_hidden), weight_decay, lr_decay, momentum, lr in itertools.product(
        alphas, temperatures, dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates):
    hparams_list.append({
        'alpha': alpha,
        'T': temperature,
        'dropout_input': dropout_input,
        'dropout_hidden': dropout_hidden,
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': lr,
    })

results_distill = {}
student_test_accuracies = {}    # (hparam_tuple, pruning_factor) -> test accuracy
pruning_factors = [i/32 for i in range(1, 33)]
#pruning_factors = [0.2]  # Define desired pruning factors

# Make sure the checkpoint directory exists before the first save.
os.makedirs(checkpoints_path_student, exist_ok=True)
# The teacher is shared by every run, so move it to the compute device once.
teacher_net = teacher_net.to(fast_device)

for pruning_factor in pruning_factors:
    for hparam in hparams_list:
        print('Training with hparams' + utils.hparamToString(hparam) + f' and pruning factor {pruning_factor}')
        # Re-seed so every configuration starts from the same RNG state.
        reproducibilitySeed()
        student_net = networks.StudentNetwork(pruning_factor)
        student_net = student_net.to(fast_device)
        hparam_tuple = utils.hparamDictToTuple(hparam)

        # Percentage of trainable parameters that are exactly zero before training.
        # Computed inline so this cell does not depend on helper functions that
        # are only defined in a later cell of the notebook.
        trainable_params = [p for p in student_net.parameters() if p.requires_grad]
        student_total_params = sum(p.numel() for p in trainable_params)
        student_zero_params = sum((p.detach() == 0).sum().item() for p in trainable_params)
        print(100 * student_zero_params / student_total_params)

        # Trains on the full training set (train_val_loader); no validation loader is passed.
        results_distill[(hparam_tuple, pruning_factor)] = utils.trainStudentOnHparam(teacher_net, student_net, hparam, num_epochs,
                                                                                    train_val_loader, None,
                                                                                    print_every=print_every,
                                                                                    fast_device=fast_device)
        save_path = checkpoints_path_student + utils.hparamToString(hparam) + f'_pruning_{pruning_factor}_final.tar'
        torch.save({'results': results_distill[(hparam_tuple, pruning_factor)],
                    'model_state_dict': student_net.state_dict(),
                    'epoch': num_epochs}, save_path)

        # Keep and report the test accuracy instead of silently overwriting it
        # on the next iteration.
        _, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)
        student_test_accuracies[(hparam_tuple, pruning_factor)] = test_accuracy
        print(f'student test accuracy for pruning factor = {pruning_factor}:', test_accuracy)

Training with hparamsT=10, alpha=0.5, dropout_hidden=0.0, dropout_input=0.0, lr=0.01, lr_decay=0.95, momentum=0.9, weight_decay=1e-05 and pruning factor 0.03125
3.1168690456815966
[1,   100/  391] train loss: 7.448 train accuracy: 0.578
[1,   200/  391] train loss: 4.315 train accuracy: 0.672
[1,   300/  391] train loss: 5.264 train accuracy: 0.602
[2,   100/  391] train loss: 4.616 train accuracy: 0.727
[2,   200/  391] train loss: 6.824 train accuracy: 0.531
[2,   300/  391] train loss: 5.004 train accuracy: 0.648
[3,   100/  391] train loss: 3.450 train accuracy: 0.805
[3,   200/  391] train loss: 3.186 train accuracy: 0.758
[3,   300/  391] train loss: 3.757 train accuracy: 0.734
[4,   100/  391] train loss: 2.790 train accuracy: 0.805
[4,   200/  391] train loss: 3.134 train accuracy: 0.805
[4,   300/  391] train loss: 2.854 train accuracy: 0.797
[5,   100/  391] train loss: 2.686 train accuracy: 0.773
[5,   200/  391] train loss: 2.679 train accuracy: 0.805
[5,   300/  391] train

In [None]:
def count_parameters(model):
    """
    Return how many scalar parameters of `model` are trainable.

    Only tensors yielded by `model.parameters()` whose `requires_grad` flag is
    set contribute to the total.

    Args:
        model (torch.nn.Module): Model to inspect.

    Returns:
        int: Sum of `numel()` over all trainable parameter tensors.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

def count_zero_parameters(model):
    """
    Return how many trainable scalar parameters of `model` are exactly zero.

    Parameter tensors with `requires_grad` unset are skipped entirely.

    Args:
        model (torch.nn.Module): Model to inspect.

    Returns:
        int: Number of entries equal to 0 across all trainable parameter tensors.
    """
    zero_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # Element-wise comparison against 0, reduced to a Python number.
        zero_total += (param.data == 0).sum().item()
    return zero_total


In [None]:
# Sparsity report for freshly constructed networks (student pruning factor 0.1).
# NOTE(review): this rebinds `teacher_net` and `student_net` to new, untrained
# instances, and the teacher here is `TeacherNetwork`, not the
# `TeacherNetworkVGG` used in the training cell -- confirm that is intended.
teacher_net = networks.TeacherNetwork()
student_net = networks.StudentNetwork(0.1)

# Teacher: trainable-parameter total and how many of them are exactly zero.
teacher_total_params = count_parameters(teacher_net)
teacher_zero_params = count_zero_parameters(teacher_net)

# Student: same two counts.
student_total_params = count_parameters(student_net)
student_zero_params = count_zero_parameters(student_net)

print(f"Teacher Network: {teacher_total_params} total parameters, {teacher_zero_params} are zero.")
print(f"Student Network: {student_total_params} total parameters, {student_zero_params} are zero.")

# Share of zero-valued parameters in each model, as a percentage.
teacher_zero_percent = 100 * teacher_zero_params / teacher_total_params
student_zero_percent = 100 * student_zero_params / student_total_params

print(f"Percentage of zero parameters in Teacher Network: {teacher_zero_percent:.2f}%")
print(f"Percentage of zero parameters in Student Network: {student_zero_percent:.2f}%")



Teacher Network: 58164298 total parameters, 0 are zero.
Student Network: 58164298 total parameters, 5801373 are zero.
Percentage of zero parameters in Teacher Network: 0.00%
Percentage of zero parameters in Student Network: 9.97%


In [None]:
# Recompute the sparsity statistics for whatever model `teacher_net` is
# currently bound to: trainable total, exact-zero count, and zero percentage.
teacher_total_params = count_parameters(teacher_net)
teacher_zero_params = count_zero_parameters(teacher_net)
teacher_zero_percent = 100 * teacher_zero_params / teacher_total_params

In [None]:
import pandas as pd
# Results table for the pruning sweep (one row per pruning factor); it is
# populated by the checkpoint-evaluation cell below.
result = pd.DataFrame(columns=['Pruning Factor', 'Accuracy'])

In [None]:
import itertools

import pandas as pd
import torch

import networks  # Ensure the correct import of your networks module
import utils  # Utilities for hyperparameter string conversion and more


# Hyperparameters identifying the checkpoints to evaluate; they must match the
# values used when the students were trained, because the checkpoint file name
# is derived from them.
t = [10]
alpha = [0.5]
dropout_probabilities = [(0.0, 0.0)]
weight_decays = [1e-5]
learning_rate_decays = [0.95]
momentums = [0.9]
learning_rates = [1e-2]
pruning_factors = [i/32 for i in range(1, 33)]
#pruning_factors = [0.1, 0.2]  # Example pruning factors

# Expand the grid into one dict per hyperparameter combination.
hparams_list = []
for hparam_tuple in itertools.product(t, alpha, dropout_probabilities, weight_decays, learning_rate_decays, momentums, learning_rates):
    hparam = {
        'T': hparam_tuple[0],
        'alpha': hparam_tuple[1],
        'dropout_input': hparam_tuple[2][0],
        'dropout_hidden': hparam_tuple[2][1],
        'weight_decay': hparam_tuple[3],
        'lr_decay': hparam_tuple[4],
        'momentum': hparam_tuple[5],
        'lr': hparam_tuple[6]
    }
    hparams_list.append(hparam)

# Define the path to your checkpoints
# NOTE(review): this rebinds `checkpoints_path_student`; the value set earlier
# in the notebook is 'checkpoints_student/'. The two only point at the same
# folder when the working directory is named `content` -- confirm.
checkpoints_path_student = "../content/checkpoints_student/"

# Evaluate every saved student on the test set. Rows are collected in a plain
# list and turned into a DataFrame once at the end: this avoids growing a frame
# with pd.concat inside the loop, and makes the cell idempotent (re-running it
# rebuilds `result` instead of appending duplicate rows).
evaluation_rows = []
for hparam in hparams_list:
    for prune_factor in pruning_factors:
        filename = utils.hparamToString(hparam) + f'_pruning_{prune_factor}_final.tar'
        load_path = checkpoints_path_student + filename

        # Rebuild the student architecture, then restore its trained weights.
        student_net = networks.StudentNetwork(prune_amount=prune_factor)
        checkpoint = torch.load(load_path, map_location=fast_device, weights_only=True)
        student_net.load_state_dict(checkpoint['model_state_dict'])
        student_net = student_net.to(fast_device)

        _, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, fast_device)

        evaluation_rows.append({'Pruning Factor': prune_factor, 'Accuracy': test_accuracy})
        print('student test accuracy for ' + f'pruning factor = {prune_factor}:', test_accuracy)

# Explicit column list keeps the schema stable even if no checkpoint was evaluated.
result = pd.DataFrame(evaluation_rows, columns=['Pruning Factor', 'Accuracy'])

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Accuracy of each pruned student as a function of its pruning factor.
# Requires `result` to have been filled by the evaluation cell above.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(result['Pruning Factor'], result['Accuracy'], marker='o')
ax.set_title('Accuracy vs Pruning Factor')
ax.set_xlabel('Pruning Factor')
ax.set_ylabel('Accuracy')
ax.grid(True)
plt.show()


NameError: name 'result' is not defined

<Figure size 1000x600 with 0 Axes>

In [None]:
# Sparsity report for freshly constructed networks (student pruning factor 0.2).
# NOTE(review): this rebinds `teacher_net` and `student_net` to new, untrained
# instances -- any previously trained or loaded model under those names is replaced.
teacher_net = networks.TeacherNetwork()
student_net = networks.StudentNetwork(0.2)

# Teacher: trainable-parameter total and how many of them are exactly zero.
teacher_total_params = count_parameters(teacher_net)
teacher_zero_params = count_zero_parameters(teacher_net)

# Student: same two counts.
student_total_params = count_parameters(student_net)
student_zero_params = count_zero_parameters(student_net)

print(f"Teacher Network: {teacher_total_params} total parameters, {teacher_zero_params} are zero.")
print(f"Student Network: {student_total_params} total parameters, {student_zero_params} are zero.")

# Share of zero-valued parameters in each model, as a percentage.
teacher_zero_percent = 100 * teacher_zero_params / teacher_total_params
student_zero_percent = 100 * student_zero_params / student_total_params

print(f"Percentage of zero parameters in Teacher Network: {teacher_zero_percent:.2f}%")
print(f"Percentage of zero parameters in Student Network: {student_zero_percent:.2f}%")

Teacher Network: 58164298 total parameters, 0 are zero.
Student Network: 58164298 total parameters, 11602625 are zero.
Percentage of zero parameters in Teacher Network: 0.00%
Percentage of zero parameters in Student Network: 19.95%
