In [1]:
import torch
from torch import nn
from torch.autograd import Variable

from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner
from nni.compression.pytorch import TorchEvaluator
import nni

In [2]:
iris = load_iris()
X = iris["data"]
X = Variable(torch.from_numpy(StandardScaler().fit_transform(X))).float()
y = iris["target"]
y = Variable(torch.from_numpy(y)).long()

In [3]:
model = nn.Sequential(
    nn.Linear(4, 32),
    nn.Sigmoid(),
    nn.Linear(32, 3),
    nn.Softmax(dim=1)
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [4]:
with torch.no_grad():
    y_pred = model(X)
    loss_start = loss_fn(y_pred, y)

for i in range(1000):
    y_pred = model(X)
    loss = loss_fn(y_pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    y_pred = model(X)
    loss_end = loss_fn(y_pred, y)

assert (loss_start-loss_end) >= 0.01, f"{(loss_start-loss_end)}"

In [27]:
config_list = [{
    'sparsity_per_layer': 0.5,
    'op_types': ['Linear']
}]
itpruner = LinearPruner(model, config_list, "level", log_dir="./.nni_log", total_iteration=5)
itpruner.compress()

[2022-11-20 16:11:24] [32msimulated prune 0 remain/total: 32/32[0m
[2022-11-20 16:11:24] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-20 16:11:24] [32msimulated prune 0 remain/total: 32/32[0m
[2022-11-20 16:11:24] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-20 16:11:25] [32msimulated prune 0 remain/total: 31/32[0m
[2022-11-20 16:11:25] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-20 16:11:25] [32msimulated prune 0 remain/total: 31/32[0m
[2022-11-20 16:11:25] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-20 16:11:25] [32msimulated prune 0 remain/total: 31/32[0m
[2022-11-20 16:11:25] [32msimulated prune 2 remain/total: 3/3[0m
[2022-11-20 16:11:25] [32msimulated prune 0 remain/total: 31/32[0m
[2022-11-20 16:11:25] [32msimulated prune 2 remain/total: 3/3[0m


In [28]:
iteration, model, model_masks, *_ = itpruner.get_best_result()

In [29]:
from actuallysparse import converter

In [30]:
# todo: zintegrować jako część konwertera
sparse_model = nn.Sequential(
    converter.convert(model[0], "coo", model_masks["0"]["weight"]),
    model[1],
    converter.convert(model[2], "coo", model_masks["2"]["weight"]),
    model[3]
)

In [31]:
model(X)

tensor([[2.2866e-01, 7.7130e-01, 4.2970e-05],
        [1.2582e-01, 8.7412e-01, 6.1340e-05],
        [2.0059e-01, 7.9935e-01, 5.2136e-05],
        [1.7629e-01, 8.2365e-01, 6.1827e-05],
        [2.6519e-01, 7.3477e-01, 4.1297e-05],
        [2.8343e-01, 7.1651e-01, 5.4384e-05],
        [2.6471e-01, 7.3524e-01, 5.6522e-05],
        [2.1195e-01, 7.8800e-01, 4.7899e-05],
        [1.4546e-01, 8.5447e-01, 7.1931e-05],
        [1.4797e-01, 8.5198e-01, 4.9855e-05],
        [2.3590e-01, 7.6406e-01, 3.9550e-05],
        [2.3170e-01, 7.6825e-01, 5.1242e-05],
        [1.3683e-01, 8.6312e-01, 5.2247e-05],
        [1.9759e-01, 8.0236e-01, 5.0313e-05],
        [2.5698e-01, 7.4299e-01, 3.0701e-05],
        [3.4152e-01, 6.5844e-01, 4.1310e-05],
        [2.9567e-01, 7.0429e-01, 4.4965e-05],
        [2.3083e-01, 7.6912e-01, 5.0398e-05],
        [2.2018e-01, 7.7977e-01, 4.6961e-05],
        [2.9955e-01, 7.0040e-01, 4.5136e-05],
        [1.6159e-01, 8.3836e-01, 4.9783e-05],
        [2.8102e-01, 7.1893e-01, 5

In [32]:
sparse_model(X)

tensor([[2.2866e-01, 7.7130e-01, 4.2970e-05],
        [1.2582e-01, 8.7412e-01, 6.1340e-05],
        [2.0059e-01, 7.9935e-01, 5.2136e-05],
        [1.7629e-01, 8.2365e-01, 6.1827e-05],
        [2.6519e-01, 7.3477e-01, 4.1297e-05],
        [2.8343e-01, 7.1651e-01, 5.4384e-05],
        [2.6471e-01, 7.3524e-01, 5.6522e-05],
        [2.1195e-01, 7.8800e-01, 4.7899e-05],
        [1.4546e-01, 8.5447e-01, 7.1931e-05],
        [1.4797e-01, 8.5198e-01, 4.9855e-05],
        [2.3590e-01, 7.6406e-01, 3.9550e-05],
        [2.3170e-01, 7.6825e-01, 5.1242e-05],
        [1.3683e-01, 8.6312e-01, 5.2247e-05],
        [1.9759e-01, 8.0236e-01, 5.0313e-05],
        [2.5698e-01, 7.4299e-01, 3.0701e-05],
        [3.4152e-01, 6.5844e-01, 4.1310e-05],
        [2.9567e-01, 7.0429e-01, 4.4965e-05],
        [2.3083e-01, 7.6912e-01, 5.0398e-05],
        [2.2018e-01, 7.7977e-01, 4.6961e-05],
        [2.9955e-01, 7.0040e-01, 4.5136e-05],
        [1.6159e-01, 8.3836e-01, 4.9783e-05],
        [2.8102e-01, 7.1893e-01, 5