In [241]:
import torch
from torch import nn
from torch.autograd import Variable

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner
import nni

In [242]:
iris = load_iris()
X = iris["data"]
X = Variable(torch.from_numpy(StandardScaler().fit_transform(X))).float()
y = iris["target"]
y = Variable(torch.from_numpy(y)).long()

In [243]:
model = nn.Sequential(
    nn.Linear(4, 128),
    nn.Sigmoid(),
    nn.Linear(128, 3),
    nn.Softmax(dim=1)
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [244]:
def accuracy(y_p, y_a):
    return torch.sum(torch.argmax(y_p, dim=1) == y_a)/len(y_a)

In [245]:
with torch.no_grad():
    y_pred = model(X)
    loss_start_train = loss_fn(y_pred, y)
    acc_start_train = accuracy(y_pred, y)

for epoch in range(2000):
    y_pred = model(X)
    loss = loss_fn(y_pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    y_pred = model(X)
    loss_end_train = loss_fn(y_pred, y)
    acc_end_train = accuracy(y_pred, y)


print(f"loss reduction {(loss_start_train-loss_end_train)}")
print(f"accuracy before: {acc_start_train}, accuracy after: {acc_end_train}")

loss reduction 0.5107707977294922
accuracy before: 0.3333333432674408, accuracy after: 0.9866666793823242


In [246]:
def training_func(model, optimizers, criterion, _lr_schedulers, _max_steps, max_epochs, *_, **__):

    model.train()
    total_epochs = max_epochs if max_epochs else 2000

    for _ in range(total_epochs):
        optimizers.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizers.step()

In [247]:
from nni.algorithms.compression.v2.pytorch import TorchEvaluator

optimizer_pruner = nni.trace(torch.optim.Adam)(model.parameters(), lr=1e-3)
dummy_input = torch.rand(10, 4)

evaluator = TorchEvaluator(
    training_func=training_func,
    optimizers=optimizer_pruner,
    criterion=torch.nn.CrossEntropyLoss(),
    dummy_input=dummy_input)

config_list = [{
    "sparsity": 0.5,
    "op_types": ["Linear"]
}]
itpruner = LinearPruner(
    model,
    config_list,
    total_iteration=10,
    pruning_algorithm="level",
    evaluator=evaluator,
    log_dir=".nni_log/")

itpruner.compress()

[2022-11-29 12:51:47] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:47] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:49] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:49] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:50] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:50] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:52] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:52] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:54] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:54] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:55] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:55] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:57] [32msimulated prune 0 remain/total: 128/128[0m
[2022-11-29 12:51:57] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-29 12:51:59] [32msimulat

In [248]:
it, reduced_model, *_ = itpruner.get_best_result()
with torch.no_grad():
    loss_after_prune = loss_fn(reduced_model(X), y)
    acc_end_prune = accuracy(reduced_model(X), y)

print(f"best iteration {it}")
print(f"loss change after prune {loss_end_train- loss_after_prune}")
print(f"accuracy after prune: {acc_end_prune}")

best iteration 10
loss change after prune -0.18073707818984985
accuracy after prune: 0.7266666889190674


In [249]:
torch.sum(torch.isclose(reduced_model[0].weight, torch.tensor(0.0)))/reduced_model[0].weight.size().numel()

tensor(0.5000)

In [57]:
from actuallysparse import converter

In [59]:
sparse_model = nn.Sequential(
    converter.convert(model[0], "coo", masks["0"]["weight"]),
    model[1],
    converter.convert(model[2], "coo", masks["2"]["weight"]),
    model[3]
)

In [64]:
sparse_optimizer = torch.optim.Adam(sparse_model.parameters(), lr=0.001)

for epoch in range(1000):
    y_pred = sparse_model(X)
    loss = loss_fn(y_pred, y)

    sparse_optimizer.zero_grad()
    loss.backward()
    sparse_optimizer.step()

with torch.no_grad():
    y_pred = sparse_model(X)
    loss_fine_tuned = loss_fn(y_pred, y)

print(f"loss change {(loss_pruned - loss_fine_tuned)}")

RuntimeError: Adam does not support sparse gradients, please consider SparseAdam instead

In [31]:
model(X)

tensor([[2.2866e-01, 7.7130e-01, 4.2970e-05],
        [1.2582e-01, 8.7412e-01, 6.1340e-05],
        [2.0059e-01, 7.9935e-01, 5.2136e-05],
        [1.7629e-01, 8.2365e-01, 6.1827e-05],
        [2.6519e-01, 7.3477e-01, 4.1297e-05],
        [2.8343e-01, 7.1651e-01, 5.4384e-05],
        [2.6471e-01, 7.3524e-01, 5.6522e-05],
        [2.1195e-01, 7.8800e-01, 4.7899e-05],
        [1.4546e-01, 8.5447e-01, 7.1931e-05],
        [1.4797e-01, 8.5198e-01, 4.9855e-05],
        [2.3590e-01, 7.6406e-01, 3.9550e-05],
        [2.3170e-01, 7.6825e-01, 5.1242e-05],
        [1.3683e-01, 8.6312e-01, 5.2247e-05],
        [1.9759e-01, 8.0236e-01, 5.0313e-05],
        [2.5698e-01, 7.4299e-01, 3.0701e-05],
        [3.4152e-01, 6.5844e-01, 4.1310e-05],
        [2.9567e-01, 7.0429e-01, 4.4965e-05],
        [2.3083e-01, 7.6912e-01, 5.0398e-05],
        [2.2018e-01, 7.7977e-01, 4.6961e-05],
        [2.9955e-01, 7.0040e-01, 4.5136e-05],
        [1.6159e-01, 8.3836e-01, 4.9783e-05],
        [2.8102e-01, 7.1893e-01, 5

In [32]:
sparse_model(X)

tensor([[2.2866e-01, 7.7130e-01, 4.2970e-05],
        [1.2582e-01, 8.7412e-01, 6.1340e-05],
        [2.0059e-01, 7.9935e-01, 5.2136e-05],
        [1.7629e-01, 8.2365e-01, 6.1827e-05],
        [2.6519e-01, 7.3477e-01, 4.1297e-05],
        [2.8343e-01, 7.1651e-01, 5.4384e-05],
        [2.6471e-01, 7.3524e-01, 5.6522e-05],
        [2.1195e-01, 7.8800e-01, 4.7899e-05],
        [1.4546e-01, 8.5447e-01, 7.1931e-05],
        [1.4797e-01, 8.5198e-01, 4.9855e-05],
        [2.3590e-01, 7.6406e-01, 3.9550e-05],
        [2.3170e-01, 7.6825e-01, 5.1242e-05],
        [1.3683e-01, 8.6312e-01, 5.2247e-05],
        [1.9759e-01, 8.0236e-01, 5.0313e-05],
        [2.5698e-01, 7.4299e-01, 3.0701e-05],
        [3.4152e-01, 6.5844e-01, 4.1310e-05],
        [2.9567e-01, 7.0429e-01, 4.4965e-05],
        [2.3083e-01, 7.6912e-01, 5.0398e-05],
        [2.2018e-01, 7.7977e-01, 4.6961e-05],
        [2.9955e-01, 7.0040e-01, 4.5136e-05],
        [1.6159e-01, 8.3836e-01, 4.9783e-05],
        [2.8102e-01, 7.1893e-01, 5