In [1]:
import torch
import torchvision
import pretrained
import nni
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner

In [2]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset_train = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=16)
dataset_test = torchvision.datasets.CIFAR10(".data", download=True, train=False, transform=transform)
dataloader_test = torch.utils.data.DataLoader(dataset_train, batch_size=16)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
training_device = "cuda" if torch.cuda.is_available() else "cpu"
pruning_device = "cpu" # do pruning on the CPU to conserve VRAM. Working on 2GB is pain.
pretrained_model = pretrained.vgg11_bn(device=training_device)

In [4]:
def eval_accuracy(model, dataset="train"):
    dataloader = dataloader_train if dataset == "train" else dataloader_test
    with torch.no_grad():
        model.to(training_device)
        correct = 0
        all_so_far = 0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(training_device), labels.to(training_device)
            pred = torch.argmax(model(inputs), dim=1)

            all_so_far += labels.size().numel()
            correct += torch.sum(pred.eq(labels))
    return correct/all_so_far
eval_accuracy(pretrained_model)

tensor(0.9684, device='cuda:0')

In [5]:
def training_func(model, optimizers, criterion, *_args, **_kwargs):
    model.train()
    model.to(training_device)
    torch.cuda.empty_cache()
    for epoch in range(3):
        for inputs, labels in dataloader_train:
            inputs, labels = inputs.to(training_device), labels.to(training_device)
            optimizers.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizers.step()
    model.to(pruning_device)
    torch.cuda.empty_cache()

In [6]:
def prune_model(model, sparsity, iterations):
    optimizer_pruner = nni.trace(torch.optim.Adam)(pretrained_model.parameters(), lr=1e-3)
    dummy_input = torch.rand(8, 32, 32).to(training_device)

    evaluator = TorchEvaluator(
        training_func=training_func,
        optimizers=optimizer_pruner,
        criterion=torch.nn.CrossEntropyLoss(),
        dummy_input=dummy_input)

    config_list = [{
        "sparsity": sparsity,
        "op_types": ["Linear"]
    }]

    model.to(pruning_device)

    itpruner = LinearPruner(
        model,
        config_list,
        total_iteration=iterations,
        pruning_algorithm="level",
        evaluator=evaluator,
        log_dir=".nni_log/")

    torch.cuda.empty_cache()
    itpruner.compress()
    return itpruner.get_best_result()

In [7]:
_, pruned_model, masks, *_ = prune_model(pretrained_model, .90, 10)
_, very_pruned_model, very_masks, *_ = prune_model(pretrained_model, .99, 10)

[2022-12-06 12:59:57] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-12-06 12:59:57] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-12-06 12:59:57] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-12-06 13:09:00] [32msimulated prune classifier.0 remain/total: 3811/4096[0m
[2022-12-06 13:09:01] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-12-06 13:09:01] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-12-06 13:18:16] [32msimulated prune classifier.0 remain/total: 3715/4096[0m
[2022-12-06 13:18:16] [32msimulated prune classifier.3 remain/total: 4024/4096[0m
[2022-12-06 13:18:16] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-12-06 13:27:29] [32msimulated prune classifier.0 remain/total: 3645/4096[0m
[2022-12-06 13:27:29] [32msimulated prune classifier.3 remain/total: 4001/4096[0m
[2022-12-06 13:27:29] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-12

In [9]:
import copy
extra_trained_model = copy.deepcopy(pretrained_model)
normal_optimizer = torch.optim.Adam(extra_trained_model.parameters(), lr=1e-3)
for _ in range(10):
    training_func(extra_trained_model, normal_optimizer, torch.nn.CrossEntropyLoss())

In [10]:
print("pretrained: ", eval_accuracy(pretrained_model, "test"))
print("extra train: ", eval_accuracy(extra_trained_model, "test"))
print("0.9 prune: ", eval_accuracy(pruned_model, "test"))
print("0.99 prune: ", eval_accuracy(very_pruned_model, "test"))

pretrained:  tensor(0.9686, device='cuda:0')
extra train:  tensor(0.9809, device='cuda:0')
0.9 prune:  tensor(0.9959, device='cuda:0')
0.99 prune:  tensor(0.8350, device='cuda:0')


In [12]:
torch.save(pretrained_model, ".weights/full/pretrained")
torch.save(extra_trained_model, ".weights/full/extra_trained")
torch.save(pruned_model, ".weights/full/pruned")
torch.save(very_pruned_model, ".weights/full/very_pruned")