In [1]:
import torch
import torchvision
import pretrained
import nni

In [2]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)

Files already downloaded and verified


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
pruning_device = "cpu" # do pruning on the CPU to conserve VRAM. Working on 2GB is pain.

In [4]:
model = pretrained.vgg11_bn(device=device)
model.eval()


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

In [5]:
%%time
with torch.no_grad():
    model.to(device)
    correct = 0
    all_so_far = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = torch.argmax(model(inputs), dim=1)

        all_so_far += labels.size().numel()
        correct += torch.sum(pred.eq(labels))
        print(f"Accuracy so far: {correct/all_so_far:.2f}", end="\r")
print(f"Accuracy: {correct/all_so_far:.2f}                   ")

Accuracy: 0.91                   
CPU times: user 59.8 s, sys: 195 ms, total: 60 s
Wall time: 59.9 s


In [6]:
def training_func(model, optimizers, criterion, *_args, **_kwargs):
    model.train()
    model.to(device)
    torch.cuda.empty_cache()
    for epoch in range(3):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizers.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizers.step()
    model.to(pruning_device)
    torch.cuda.empty_cache()

In [7]:
torch.cuda.empty_cache()

In [8]:
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner

optimizer_pruner = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)
dummy_input = torch.rand(8, 32, 32).to(device)

evaluator = TorchEvaluator(
    training_func=training_func,
    optimizers=optimizer_pruner,
    criterion=torch.nn.CrossEntropyLoss(),
    dummy_input=dummy_input)

config_list = [{
    "sparsity": 0.99,
    "op_types": ["Linear"]
}]

model.to(pruning_device)

itpruner = LinearPruner(
    model,
    config_list,
    total_iteration=10,
    pruning_algorithm="level",
    evaluator=evaluator,
    log_dir=".nni_log/")

itpruner.compress()

[2022-11-29 20:38:55] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-11-29 20:38:55] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 20:38:55] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-11-29 20:48:07] [32msimulated prune classifier.0 remain/total: 3797/4096[0m
[2022-11-29 20:48:07] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 20:48:07] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-11-29 20:57:03] [32msimulated prune classifier.0 remain/total: 3704/4096[0m
[2022-11-29 20:57:03] [32msimulated prune classifier.3 remain/total: 4019/4096[0m
[2022-11-29 20:57:03] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-11-29 21:05:14] [32msimulated prune classifier.0 remain/total: 3633/4096[0m
[2022-11-29 21:05:14] [32msimulated prune classifier.3 remain/total: 3988/4096[0m
[2022-11-29 21:05:15] [32msimulated prune classifier.6 remain/total: 10/10[0m
[2022-11

In [9]:
it, compressed_model, masks, *_ = itpruner.get_best_result()
it

10

In [10]:
dataset_test = torchvision.datasets.CIFAR10(".data", download=True, train=False, transform=transform)
dataloader_test = torch.utils.data.DataLoader(dataset, batch_size=16)

Files already downloaded and verified


In [11]:
%%time
torch.cuda.empty_cache()
with torch.no_grad():
    compressed_model.to(device)
    correct = 0
    all_so_far = 0
    for inputs, labels in dataloader_test:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = torch.argmax(compressed_model(inputs), dim=1)

        all_so_far += labels.size().numel()
        correct += torch.sum(pred.eq(labels))
        print(f"Accuracy so far: {correct/all_so_far:.5f}", end="\r")
print(f"Accuracy: {correct/all_so_far:.5f}                  ")

Accuracy: 0.83724                  
CPU times: user 1min 1s, sys: 56.3 ms, total: 1min 1s
Wall time: 1min 1s


In [12]:
torch.sum(compressed_model.classifier[0].weight.eq(0.0))/compressed_model.classifier[0].weight.size().numel()

tensor(0.9900, device='cuda:0')

In [13]:
compressed_model.classifier[0].weight

Parameter containing:
tensor([[-0.0000, -0.0000, 0.0090,  ..., -0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [-0.0000, -0.0000, 0.0078,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0089, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, -0.0000, 0.0000]],
       device='cuda:0', requires_grad=True)

In [14]:
compressed_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke