In [126]:
import numpy as np
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from model import StudentModel, BaseModel

In [127]:
from utils import accuracy_fn, print_time
from timeit import default_timer as timer

In [128]:
my_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
my_device

device(type='cuda')

In [129]:
BATCH_SIZE = 32
PRUNE_PERCENT = 0.75

In [130]:
train_data = datasets.CIFAR10(
    root="data", train=True, download=True, transform=ToTensor()
)


test_data = datasets.CIFAR10(
    root="data", train=False, download=True, transform=ToTensor()
)

Files already downloaded and verified
Files already downloaded and verified


In [131]:
class_names = train_data.classes

In [132]:
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

In [133]:
criterion = nn.NLLLoss()

In [134]:
loaded_model = StudentModel().to(device=my_device)
f = "models/best_student_model_0.7_2.0_78.90.pt"
loaded_model.load_state_dict(torch.load(f, weights_only=True))
loaded_model

StudentModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
  (relu): ReLU()
  (log_softmax): LogSoftmax(dim=1)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [135]:
# base_model = BaseModel().to(device=my_device)

In [136]:
test_loss = 0.0
class_correct = np.zeros(10)
class_total = np.zeros(10)

loaded_model.eval()
for X, y in test_dataloader:
    X, y = X.to(my_device), y.to(my_device)
    with torch.inference_mode():
        output = loaded_model(X)
    loss = criterion(output, y)
    test_loss += loss.item() * X.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(y.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    for i, _ in enumerate(correct):
        label = int(y.data[i])
        class_correct[label] += correct[i].item()
        class_total[label] += 1

test_loss = test_loss / len(test_dataloader.dataset)
print("Test Loss: {:.6f}\n".format(test_loss))

for i in range(10):
    if class_total[i] > 0:
        print(
            "Test Accuracy of %5s: %2d%% (%2d/%2d)"
            % (
                class_names[i],
                100 * class_correct[i] / class_total[i],
                class_correct[i],
                class_total[i],
            )
        )
    else:
        print("Test Accuracy of %5s: N/A (no training examples)" % (class_names[i]))

correct = np.sum(class_correct, dtype=int)
total = np.sum(class_total, dtype=int)
print(f"\nTest Accuracy (Overall): {100.0 * correct / total}% ({correct}/{total})")

Test Loss: 0.647316

Test Accuracy of airplane: 84% (840/1000)
Test Accuracy of automobile: 89% (891/1000)
Test Accuracy of  bird: 69% (697/1000)
Test Accuracy of   cat: 59% (590/1000)
Test Accuracy of  deer: 76% (765/1000)
Test Accuracy of   dog: 70% (704/1000)
Test Accuracy of  frog: 84% (842/1000)
Test Accuracy of horse: 83% (833/1000)
Test Accuracy of  ship: 87% (879/1000)
Test Accuracy of truck: 84% (848/1000)

Test Accuracy (Overall): 78.89% (7889/10000)


In [137]:
def get_sparsity(tensor: torch.Tensor) -> float:
    """
    calculate the sparsity of the given tensor
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    return 1 - float(tensor.count_nonzero()) / tensor.numel()


def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements


def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width


Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [138]:
def fine_grained_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """
    magnitude-based pruning for single tensor
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    :return:
        torch.(cuda.)Tensor, mask for zeros
    """
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)

    num_elements = tensor.numel()

    num_zeros = round(num_elements * sparsity)
    importance = tensor.abs()
    threshold = importance.view(-1).kthvalue(num_zeros).values
    mask = torch.gt(importance, threshold)
    tensor.mul_(mask)

    return mask

In [139]:
class FineGrainedPruner:
    def __init__(self, model, sparsity_dict):
        self.masks = FineGrainedPruner.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        for name, param in model.named_parameters():
            if name in self.masks:
                param *= self.masks[name]

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 1:  # we only prune conv and fc weights
                if isinstance(sparsity_dict, dict):
                    masks[name] = fine_grained_prune(param, sparsity_dict[name])
                else:
                    assert sparsity_dict < 1 and sparsity_dict >= 0
                    if sparsity_dict > 0:
                        masks[name] = fine_grained_prune(param, sparsity_dict)
        return masks

In [140]:
epochs = 25
optimizer = torch.optim.Adam(loaded_model.parameters(), lr=1e-3, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
# scheduler = None

In [141]:
best_test_acc = 0
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []
lr_list = []

In [142]:
dense_model_size = get_model_size(loaded_model)
pruner = FineGrainedPruner(loaded_model, PRUNE_PERCENT)
pruner.apply(loaded_model)
sparse_model_size = get_model_size(loaded_model, count_nonzero_only=True)
print(
    f"{PRUNE_PERCENT*100}% sparse model has size={sparse_model_size/MiB:.2f} MiB, "
    f"which is {dense_model_size/sparse_model_size:.2f}X smaller than "
    f"the {dense_model_size/MiB:.2f} MiB dense model"
)

75.0% sparse model has size=0.15 MiB, which is 3.98X smaller than the 0.60 MiB dense model


In [143]:
loaded_model.eval()
test_loss, test_acc = 0, 0
time_start = timer()

with torch.inference_mode():
    for X, y in test_dataloader:
        X, y = X.to(my_device), y.to(my_device)
        y_pred = loaded_model.forward(X)
        mean_batch_loss: torch.Tensor = criterion(y_pred, y)
        test_loss += mean_batch_loss.item()
        test_acc += accuracy_fn(y, torch.argmax(y_pred, dim=1))

test_loss /= len(test_dataloader)
test_acc /= len(test_dataloader)

print(f"test loss: {test_loss:.4f} | test_acc: {test_acc:.2f}%")

time_end = timer()
total_train_time = print_time("Inference", time_start, time_end, my_device)

test loss: 2.3879 | test_acc: 17.64%

Inference time on cuda: 0.738 seconds


In [144]:
for epoch in range(epochs):
    print("epoch:", epoch)
    loaded_model.train()
    train_loss, train_acc = 0, 0

    for X, y in train_dataloader:
        X, y = X.to(my_device), y.to(my_device)
        y_pred = loaded_model.forward(X)
        mean_batch_loss = criterion(y_pred, y)
        train_loss += mean_batch_loss.item()
        train_acc += accuracy_fn(y, torch.argmax(y_pred, dim=1))
        optimizer.zero_grad()
        mean_batch_loss.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step()

    pruner.apply(loaded_model)

    train_loss /= len(train_dataloader)  # loss per batch
    train_acc /= len(train_dataloader)  # accuracy per batch
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)
    lr_list.append(optimizer.param_groups[0]["lr"])

    print(f"train loss: {train_loss:.4f} | train_acc: {train_acc:.2f}%")

    loaded_model.eval()
    test_loss, test_acc = 0, 0

    with torch.inference_mode():
        for X, y in test_dataloader:
            X, y = X.to(my_device), y.to(my_device)
            y_pred = loaded_model.forward(X)
            mean_batch_loss: torch.Tensor = criterion(y_pred, y)
            test_loss += mean_batch_loss.item()
            test_acc += accuracy_fn(y, torch.argmax(y_pred, dim=1))

    test_loss /= len(test_dataloader)
    test_acc /= len(test_dataloader)

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save(
            loaded_model.state_dict(), f"models/best_pruned_model_{PRUNE_PERCENT}.pt"
        )

    test_loss_list.append(test_loss)
    test_acc_list.append(test_acc)
    print(f"test loss: {test_loss:.4f} | test_acc: {test_acc:.2f}%")
    print()

train_time_end = timer()
print("Best test accuracy: ", best_test_acc)
total_train_time = print_time("Train", time_start, train_time_end, my_device)

epoch: 0
train loss: 0.7344 | train_acc: 74.35%
test loss: 1.7939 | test_acc: 36.11%

epoch: 1
train loss: 0.7053 | train_acc: 75.23%
test loss: 1.3662 | test_acc: 51.54%

epoch: 2
train loss: 0.6726 | train_acc: 76.41%
test loss: 1.2125 | test_acc: 57.55%

epoch: 3
train loss: 0.6546 | train_acc: 77.03%
test loss: 1.1361 | test_acc: 60.28%

epoch: 4
train loss: 0.6346 | train_acc: 77.65%
test loss: 1.0176 | test_acc: 64.62%

epoch: 5
train loss: 0.6230 | train_acc: 77.98%
test loss: 0.9499 | test_acc: 66.89%

epoch: 6
train loss: 0.6106 | train_acc: 78.64%
test loss: 0.9368 | test_acc: 67.12%

epoch: 7
train loss: 0.5966 | train_acc: 79.10%
test loss: 0.9440 | test_acc: 66.59%

epoch: 8
train loss: 0.5830 | train_acc: 79.43%
test loss: 0.9285 | test_acc: 67.25%

epoch: 9
train loss: 0.5655 | train_acc: 80.14%
test loss: 0.8685 | test_acc: 69.37%

epoch: 10
train loss: 0.5536 | train_acc: 80.67%
test loss: 0.8340 | test_acc: 71.03%

epoch: 11
train loss: 0.5372 | train_acc: 81.12%
test