In [45]:
import torch
import torchvision
import pretrained
import nni

In [41]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = torchvision.datasets.CIFAR10(".data", download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

Files already downloaded and verified


In [42]:
model = pretrained.vgg11_bn()
model.eval()
None

In [43]:
%%time
with torch.no_grad():
    correct = 0
    for input, label in dataloader:
        pred = torch.argmax(model(input), dim=1)
        correct += torch.sum(pred == label)
        print(correct, end="\r")
correct

tensor(30553)

KeyboardInterrupt: 

In [31]:
49717/50000

0.99434

In [48]:
def training_func(model, optimizers, criterion, lr_schedulers, max_steps, max_epochs, *_, **__):
    model.train()
    total_epochs = max_epochs if max_epochs else 10
    total_steps = max_steps if max_steps else 100
    current_step = 0

    for epoch in range(total_epochs):
        print(f"Epoch {epoch}")
        for inputs, labels in dataloader:
            optimizers.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizers.step()

            current_step += 1
            if current_step >= total_steps:
                return
        lr_schedulers.step()

In [49]:
from nni.algorithms.compression.v2.pytorch import TorchEvaluator
from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner

optimizer_pruner = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)
dummy_input = torch.rand(8, 32, 32)

evaluator = TorchEvaluator(
    training_func=training_func,
    optimizers=optimizer_pruner,
    criterion=torch.nn.CrossEntropyLoss(),
    dummy_input=dummy_input)

config_list = [{
    "sparsity": 0.5,
    "op_types": ["Linear"]
}]
itpruner = LinearPruner(
    model,
    config_list,
    total_iteration=10,
    pruning_algorithm="level",
    evaluator=evaluator,
    log_dir=".nni_log/")

itpruner.compress()

[2022-11-29 15:16:43] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-11-29 15:16:43] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 15:16:43] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 15:17:20] [32msimulated prune classifier.0 remain/total: 4096/4096[0m
[2022-11-29 15:17:20] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 15:17:20] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 15:17:58] [32msimulated prune classifier.0 remain/total: 3796/4096[0m
[2022-11-29 15:17:58] [32msimulated prune classifier.3 remain/total: 4096/4096[0m
[2022-11-29 15:17:58] [32msimulated prune classifier.6 remain/total: 10/10[0m
Epoch 0
[2022-11-29 15:18:36] [32msimulated prune classifier.0 remain/total: 3739/4096[0m
[2022-11-29 15:18:36] [32msimulated prune classifier.3 remain/total: 4034/4096[0m
[2022-11-29 15:18:36] [32msimulated prune classifier.6 remain/t

In [50]:
it, compressed_model, masks, *_ = itpruner.get_best_result()
it

10

In [51]:
%%time
correct = 0
with torch.no_grad():
    for input, label in dataloader:
        pred = torch.argmax(compressed_model(input), dim=1)
        correct += torch.sum(pred == label)
        print(correct, end="\r")
correct/50_000

CPU times: user 20min 34s, sys: 946 ms, total: 20min 35s
Wall time: 5min 8s


tensor(0.3371)