In [1]:
import torch
import torchvision
import pretrained
import nni

In [2]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16)

Files already downloaded and verified


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
model = pretrained.vgg11_bn(device=device)
model.eval()
model.to("cpu")
None

In [19]:
# %%time
# with torch.no_grad():
#     correct = 0
#     all_so_far = 0
#     for inputs, labels in dataloader:
#         inputs, labels = inputs.to(device), labels.to(device)
#         pred = torch.argmax(model(inputs), dim=1)
#
#         all_so_far += labels.size().numel()
#         correct += torch.sum(pred.eq(labels))
#         print(f"Accuracy so far: {correct/all_so_far:.2f}", end="\r")
# print(f"Accuracy: {correct/all_so_far:.2f}")

Accuracy: 0.91r: 0.91
CPU times: user 56.3 s, sys: 139 ms, total: 56.4 s
Wall time: 56.4 s


In [9]:
def training_func(model, optimizers, criterion, *_args, **_kwargs):
    model.train()
    model.to("cuda")
    for epoch in range(1):
        print(f"Epoch {epoch}")
        torch.cuda.empty_cache()
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizers.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizers.step()
    model.to("cpu")

In [10]:
torch.cuda.empty_cache()

In [11]:
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner

optimizer_pruner = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)
dummy_input = torch.rand(8, 32, 32).to(device)

evaluator = TorchEvaluator(
    training_func=training_func,
    optimizers=optimizer_pruner,
    criterion=torch.nn.CrossEntropyLoss(),
    dummy_input=dummy_input)

config_list = [{
    "sparsity": 0.5,
    "op_types": ["Linear"]
}]
itpruner = LinearPruner(
    model,
    config_list,
    total_iteration=10,
    pruning_algorithm="level",
    evaluator=evaluator,
    log_dir=".nni_log/")

itpruner.compress()

[2022-11-29 17:21:49] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-11-29 17:21:49] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 17:21:49] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 17:24:52] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-11-29 17:24:52] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 17:24:52] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 17:27:51] [32msimulated prune classifier.0 remain/total: 3796/4096[0m
[2022-11-29 17:27:51] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 17:27:51] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 17:31:02] [32msimulated prune classifier.0 remain/total: 3739/4096[0m
[2022-11-29 17:31:03] [32msimulated prune classifier.3 remain/total: 4034/4096[0m
[2022-11-29 17:31:03] [32msimulated prune classifier.6 remain/t

In [12]:
it, compressed_model, masks, *_ = itpruner.get_best_result()
it

10

In [15]:
dataset_test = torchvision.datasets.CIFAR10(".data", download=True, train=False, transform=transform)
dataloader_test = torch.utils.data.DataLoader(dataset, batch_size=16)

Files already downloaded and verified


In [16]:
%%time
torch.cuda.empty_cache()
with torch.no_grad():
    compressed_model.to(device)
    correct = 0
    all_so_far = 0
    for inputs, labels in dataloader_test:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = torch.argmax(compressed_model(inputs), dim=1)

        all_so_far += labels.size().numel()
        correct += torch.sum(pred.eq(labels))
        print(f"Accuracy so far: {correct/all_so_far:.2f}", end="\r")
print(f"Accuracy: {correct/all_so_far:.2f}                  ")

Accuracy: 0.96                  
CPU times: user 1min 4s, sys: 32.1 ms, total: 1min 4s
Wall time: 1min 4s


In [24]:
torch.sum(compressed_model.classifier[0].weight.eq(0.0))/compressed_model.classifier[0].weight.size().numel()

tensor(0.5000, device='cuda:0')