In [1]:
import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import csv
import sys
#sys.path.append('/content/KD')
# Import the module
import networks
import utils

%load_ext autoreload
%autoreload 2

In [2]:
use_gpu = True    # set use_gpu to True if system has gpu
gpu_id = 0        # id of gpu to be used
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens
fast_device = torch.device('cpu')
if use_gpu:
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised,
    # so it must be assigned before the availability check below.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'    # set visible devices depending on system configuration
    # Fall back to the CPU instead of pointing at a GPU that does not exist;
    # otherwise every later `.to(fast_device)` would fail on a CPU-only machine.
    if torch.cuda.is_available():
        fast_device = torch.device('cuda:' + str(gpu_id))

In [3]:
def reproducibilitySeed(seed=0):
    """
    Seed every random number generator the pipeline uses so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (``torch.manual_seed`` seeds the CPU and all CUDA devices). When the
    module-level ``use_gpu`` flag is set, cuDNN is also forced into
    deterministic mode with autotuning disabled.

    Args:
        seed (int): Seed applied to every RNG. Defaults to 0, matching the
            previous hard-coded behaviour.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_gpu:
        # Benchmark mode picks kernels by timing, which can differ between runs.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Seed all RNGs once at start-up; later cells call reproducibilitySeed() again
# right before each stochastic step (teacher evaluation, student initialisation).
reproducibilitySeed()

In [4]:
# Output directories for teacher and student checkpoints.
checkpoints_path_teacher = 'checkpoints_teacher/'
checkpoints_path_student = 'checkpoints_student_DKD_new/'
# exist_ok=True makes this idempotent (safe to re-run the cell) and avoids the
# race between an os.path.exists check and the actual directory creation.
for checkpoint_dir in (checkpoints_path_teacher, checkpoints_path_student):
    os.makedirs(checkpoint_dir, exist_ok=True)

## Load Dataset

In [5]:
import torchvision
import torchvision.transforms as transforms
import PIL

# CIFAR-10 channel statistics shared by every transform below. Changing them also
# changes the inputs the teacher is evaluated on (test_loader uses transform_test).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
CIFAR10_ROOT = './CIFAR10_dataset/'

# Set up transformations for CIFAR-10
transform_train = transforms.Compose(
    [
        #transforms.RandomCrop(32, padding=4),  # Augment training data by padding 4 and random cropping
        transforms.RandomHorizontalFlip(),     # Randomly flip images horizontally
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)  # Normalization for CIFAR-10
    ]
)

# Deterministic transform: used for the test set AND the validation split.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)  # Normalization for CIFAR-10
    ]
)

# Upscaled 128x128 pipelines. NOTE(review): not used by any loader in this notebook.
import torchvision as tv
preprocess_train = tv.transforms.Compose([
    # InterpolationMode is the supported way to request bilinear resizing (the default).
    tv.transforms.Resize((160, 160), interpolation=tv.transforms.InterpolationMode.BILINEAR),
    tv.transforms.RandomCrop((128, 128)),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)  # Normalization for CIFAR-10
])

preprocess_eval = tv.transforms.Compose([
    tv.transforms.Resize((128, 128), interpolation=tv.transforms.InterpolationMode.BILINEAR),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)  # Normalization for CIFAR-10
])

# Load CIFAR-10 dataset
train_val_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True,
                                                 download=True, transform=transform_train)
# Second view of the same 50k training images with the deterministic transform, so the
# validation split is not randomly augmented like the training split.
train_val_dataset_eval = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True,
                                                      download=True, transform=transform_test)

test_dataset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=False,
                                            download=True, transform=transform_test)

# Split the training dataset into training and validation
num_train = int(0.95 * len(train_val_dataset))  # 95% of the dataset for training
num_val = len(train_val_dataset) - num_train  # Remaining 5% for validation
# random_split draws from the global torch RNG seeded by reproducibilitySeed().
train_dataset, val_split = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])
# Same held-out indices, but served through the eval transform.
val_dataset = torch.utils.data.Subset(train_val_dataset_eval, val_split.indices)

# DataLoader setup
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


## Load Teacher

In [6]:
# Instantiate the teacher networks
teacher_net_1 = networks.TeacherNetwork50()
# map_location='cpu' lets the checkpoint load even on a machine without the GPU it may
# have been saved from; load_state_dict then copies the tensors into the model.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('resnet50_cifar10_pretrained.bin', map_location='cpu')
teacher_net_1.model.load_state_dict(checkpoint)

# NOTE(review): teacher_net_2 is only referenced by the commented-out ensemble line below;
# constructing it presumably allocates a full model that is unused while the ensemble is off.
teacher_net_2 = networks.TeacherNetworkBiT()

# Create the ensemble model
#teacher_net = networks.EnsembleModel(teacher_net_1, teacher_net_2)

# Move the teacher to the GPU when one is available, otherwise keep it on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
teacher_net = teacher_net_1.to(device)

# Re-seed so the teacher evaluation is reproducible regardless of earlier cells.
reproducibilitySeed()
_, test_accuracy = utils.getLossAccuracyOnDataset(teacher_net, test_loader, device)
print('test accuracy: ', test_accuracy)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /tmp/xdg-cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 344MB/s]


cuda
test accuracy:  0.9225


## Train Student

In [7]:
# Training-loop settings: total epochs and the batch interval between progress logs.
num_epochs, print_every = 200, 100

In [8]:
def count_parameters(model, nonzero_only=True):
    """
    Counts the trainable parameters of a PyTorch model.

    By default only entries whose value is non-zero are counted, so the result
    reflects the effective size of a (possibly pruned) model rather than the
    raw tensor sizes.

    Args:
        model (torch.nn.Module): The model whose parameters need to be counted.
        nonzero_only (bool): If True (default, the historical behaviour), skip
            entries that are exactly zero. If False, count every entry of every
            trainable tensor.

    Returns:
        int: Number of trainable parameters (only the non-zero ones by default).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    if nonzero_only:
        return sum(int(torch.count_nonzero(p.detach())) for p in trainable)
    return sum(p.numel() for p in trainable)


def count_zero_parameters(model):
    """
    Return how many entries of the model's trainable tensors are exactly zero.

    Frozen parameters (``requires_grad == False``) are ignored.

    Args:
        model (torch.nn.Module): Model to inspect.

    Returns:
        int: Count of exactly-zero trainable parameter entries.
    """
    zero_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen tensors do not contribute
        zero_total += (param.data == 0).sum().item()
    return zero_total


In [None]:
import itertools
import os
import csv
import time
import torch

num_epochs = 200

# Hyper-parameter grid: every combination is trained once per pruning factor.
temperatures = [4]
alphas = [2]
betas = [4, 8, 10]
learning_rates = [5e-4]
learning_rate_decays = [0.95]
weight_decays = [1e-4]
momentums = [0.9]
dropout_probabilities = [(0.0, 0.0)]  # (input dropout, hidden dropout) pairs

checkpoints_path_student = 'checkpoints_student_DKD_new/'
# Ensure the directory exists even if the path above is edited after the setup cell ran.
os.makedirs(checkpoints_path_student, exist_ok=True)

hparams_list = [
    {
        'alpha': alpha_value,
        'beta': beta_value,
        'T': temperature,
        'dropout_input': dropout_pair[0],
        'dropout_hidden': dropout_pair[1],
        'weight_decay': weight_decay,
        'lr_decay': lr_decay,
        'momentum': momentum,
        'lr': learning_rate,
    }
    for alpha_value, beta_value, temperature, dropout_pair, weight_decay, lr_decay, momentum, learning_rate
    in itertools.product(alphas, betas, temperatures, dropout_probabilities,
                         weight_decays, learning_rate_decays, momentums, learning_rates)
]

results_distill = {}
pruning_factors = [0]

# CSV file setup: the header is only written when the file is first created,
# so repeated runs append to the same table.
csv_file = checkpoints_path_student + "results_student.csv"
if not os.path.exists(csv_file):
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([
            "Alpha", "Beta", "Temperature", "Dropout Input", "Dropout Hidden",
            "Weight Decay", "LR Decay", "Momentum", "Learning Rate",
            "Pruning Factor", "Zero Parameters", "Test Accuracy", "Training Time (s)"
        ])

# The teacher does not change during the sweep, so count its parameters once.
teacher_params_num = count_parameters(teacher_net)

# Training and logging
for pruning_factor in pruning_factors:
    for hparam in hparams_list:
        alpha = hparam['alpha']
        beta = hparam['beta']
        run_name = utils.hparamToString(hparam) + f'_{alpha}_{beta}'

        print('Training with hparams' + run_name + f' and pruning factor {pruning_factor}')

        # Measure training time
        start_time = time.time()

        # Re-seed so every run starts from the same student initialisation.
        reproducibilitySeed()
        student_net = networks.StudentNetwork()
        student_net.to(device)
        hparam_tuple = utils.hparamDictToTuple(hparam)

        # Non-zero trainable parameters of the freshly initialised student.
        student_params_num = count_parameters(student_net)

        print(pruning_factor, student_params_num, teacher_params_num)
        # NOTE(review): pruning_factor is only used as a results key / CSV field here;
        # it is not passed to trainStudentWithDKD — confirm that is intended.
        results_distill[(hparam_tuple, pruning_factor)] = utils.trainStudentWithDKD(
            teacher_net, student_net, hparam, num_epochs,
            train_loader, val_loader,
            print_every=print_every,
            fast_device=device, quant=False, checkpoint_save_path=checkpoints_path_student, a=alpha, b=beta
        )

        training_time = time.time() - start_time

        # Final model save
        final_save_path = checkpoints_path_student + run_name + '.tar'
        torch.save({
            'results': results_distill[(hparam_tuple, pruning_factor)],
            'model_state_dict': student_net.state_dict(),
            'epoch': num_epochs
        }, final_save_path)

        # Evaluate on `device`, the same device the student was trained on, so the
        # model and the evaluation batches always agree.
        _, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, device)
        print('Test accuracy: ', test_accuracy)

        # The CSV column is "Zero Parameters": log the zero count of the trained student
        # (student_params_num is the non-zero count, measured before training).
        zero_params_num = count_zero_parameters(student_net)

        # Write results to CSV
        with open(csv_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([
                alpha, beta, hparam['T'], hparam['dropout_input'], hparam['dropout_hidden'],
                hparam['weight_decay'], hparam['lr_decay'], hparam['momentum'], hparam['lr'],
                pruning_factor, zero_params_num, test_accuracy, training_time
            ])

print(f"Results saved to {csv_file}")


Training with hparamsT=4, alpha=2, beta=4, dropout_hidden=0.0, dropout_input=0.0, lr=0.0005, lr_decay=0.95, momentum=0.9, weight_decay=0.0001_2_4 and pruning factor 0


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /tmp/xdg-cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 330MB/s]


0 11173962 23520842
[Epoch 1, Batch 100/372] Loss: 12.036, Accuracy: 0.570


In [None]:
# Sanity check: show which device the notebook is computing on.
print(str(device))

In [None]:
# Re-evaluate the last trained student on `device`, the same device it was moved to
# in the training loop, so the model and the evaluation batches always agree.
_, test_accuracy = utils.getLossAccuracyOnDataset(student_net, test_loader, device)
print('Test accuracy: ', test_accuracy)

In [None]:
# Re-save the most recently trained student together with its training results.
run_suffix = f'_{alpha}_{beta}'
final_save_path = checkpoints_path_student + utils.hparamToString(hparam) + run_suffix + '.tar'
checkpoint_payload = {
    'results': results_distill[(hparam_tuple, pruning_factor)],
    'model_state_dict': student_net.state_dict(),
    'epoch': num_epochs,
}
torch.save(checkpoint_payload, final_save_path)

In [None]:
# Show the full hyper-parameter grid that the sweep iterated over.
print(str(hparams_list))