Pruning
Based on the official PyTorch pruning tutorial: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

In [None]:
import os
import sys

import torch
import torch.nn.utils.prune as prune

sys.path.append("../1-classification_mnist/py")
from model import LeNet

In [None]:
# Build a fresh (untrained) LeNet and look at the parameters of the first
# layer in `features` before any pruning is applied.
model = LeNet()
first_layer_params = list(model.features[0].named_parameters())
print(first_layer_params)

In [None]:
# Local pruning (per module).
# Randomly prune 40% (amount=0.4) of the entries of the `weight` parameter of
# the first layer in `features`.
# Seed the RNG first so that the random pruning mask is reproducible between
# runs of the notebook.
RANDOM_PRUNE_SEED = 0
torch.manual_seed(RANDOM_PRUNE_SEED)
prune.random_unstructured(model.features[0], name="weight", amount=0.4)

# After pruning, the parameter being pruned is renamed from `weight` to
# `weight_orig`.
print(list(model.features[0].named_parameters()))

In [None]:
first_layer = model.features[0]
# The pruning mask is stored among the module's buffers; entries equal to 0
# mark pruned positions.
print(list(first_layer.named_buffers()))
# The module's `weight` attribute now holds the values with the mask applied.
print(first_layer.weight)

In [None]:
# Pruning is applied through a forward pre-hook on the module, so forward()
# uses the pruned (masked) weight.
pre_hooks = model.features[0]._forward_pre_hooks
print(pre_hooks)

In [None]:
# Prune the `bias` of the first layer in `features`: the 2 entries with the
# smallest L1 norm (absolute value) are removed. An integer `amount` is an
# absolute number of entries rather than a fraction.
prune.l1_unstructured(
    model.features[0],
    name="bias",
    amount=2,
)

In [None]:
# After pruning both `weight` and `bias`, a mask buffer exists for each of them.
buffers = list(model.features[0].named_buffers())
print(buffers)

In [None]:
first_layer = model.features[0]
print("before remove")
print(list(first_layer.named_parameters()))
# Make the pruning of `weight` permanent: prune.remove deletes the
# reparametrization (the `weight_orig` parameter and the `weight_mask` buffer)
# and keeps the masked values as a plain `weight` parameter. The state_dict
# then has the same keys as before pruning, but the pruned entries stay zero.
# Only `weight` is removed here; `bias` (pruned above) is still
# reparametrized as `bias_orig` + `bias_mask`.
prune.remove(first_layer, "weight")
print("after remove")
print(list(first_layer.named_parameters()))

In [None]:
# Start over from the pretrained MNIST LeNet checkpoint.
weight_name = "./../1-classification_mnist/weight/MNIST_lenet_10.pth"
print("use pretrained model : %s" % weight_name)
# map_location="cpu" loads every tensor onto the CPU, whatever device it was
# saved from (equivalent to the identity-lambda form).
param = torch.load(weight_name, map_location="cpu")
model = LeNet()
model.load_state_dict(param)

In [None]:
# Local pruning of the pretrained model: every Conv2d / Linear layer is pruned
# independently, removing the given fraction of its own weights with the
# smallest L1 magnitude.
CONV_PRUNE_AMOUNT = 0.4    # fraction of weights pruned in each Conv2d layer
LINEAR_PRUNE_AMOUNT = 0.4  # fraction of weights pruned in each Linear layer

for _, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=CONV_PRUNE_AMOUNT)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=LINEAR_PRUNE_AMOUNT)
# Each pruned layer now has a `weight_mask` buffer.
print(dict(model.named_buffers()).keys())

# Make the pruning permanent so the saved state_dict uses the same keys as the
# unpruned model (plain `weight` tensors that contain the zeros).
for _, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, "weight")
# torch.save does not create missing directories.
os.makedirs("weight", exist_ok=True)
torch.save(model.state_dict(), "weight/MNIST_lenet_10_local_pruning.pth")

In [None]:
# Global Pruning (ネットワーク全体)
weight_name = "./../1-classification_mnist/weight/MNIST_lenet_10.pth"
print("use pretrained model : %s" % weight_name)
param = torch.load(weight_name, map_location=lambda storage, loc: storage)
model = LeNet()
model.load_state_dict(param)

parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        parameters_to_prune.append((module, 'weight'))
    elif isinstance(module, torch.nn.Linear):
        parameters_to_prune.append((module, 'weight'))
parameters_to_prune = tuple(parameters_to_prune)

print(parameters_to_prune)

# ネットワーク全体でL1normが小さい順に20%枝刈り 
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.4,
)

In [None]:
# Report the percentage of zero weights for every Conv2d / Linear layer (the
# same layers selected for global pruning) and for all of them combined.
# Layer names are the dotted names from named_modules(), e.g. "features.0".
total_zeros = 0
total_elements = 0
for layer_name, module in model.named_modules():
    if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    num_zeros = int(torch.sum(module.weight == 0))
    num_elements = module.weight.nelement()
    print("Sparsity in {}.weight: {:.2f}%".format(
        layer_name, 100. * float(num_zeros) / float(num_elements)))
    total_zeros += num_zeros
    total_elements += num_elements

print("Global sparsity: {:.2f}%".format(
    100. * float(total_zeros) / float(total_elements)))

In [None]:
# Make the global pruning permanent so the saved state_dict uses the same keys
# as the unpruned model (plain `weight` tensors that contain the zeros).
for _, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, "weight")
# torch.save does not create missing directories.
os.makedirs("weight", exist_ok=True)
torch.save(model.state_dict(), "weight/MNIST_lenet_10_global_pruning.pth")

Evaluate Accuracy (run these commands to compare the original and pruned checkpoints)  
cd ./../1-classification_mnist  
python py/main.py --evaluate --resume='weight/MNIST_lenet_10.pth'  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_local_pruning.pth'  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_global_pruning.pth'  

python py/main.py --evaluate --resume='weight/MNIST_lenet_10.pth'  
use pretrained model : weight/MNIST_lenet_10.pth  
Validate: [10/10] Time  0.081 ( 0.205)    Data  0.001 ( 0.108)    Loss 0.03638 (0.03311)  Acc@1  98.80 ( 98.92)   Acc@5 100.00 ( 99.99)  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_local_pruning.pth'  
use pretrained model : ./../notebook/weight/MNIST_lenet_10_local_pruning.pth  
Validate: [10/10] Time  0.080 ( 0.186)    Data  0.001 ( 0.086)    Loss 0.03094 (0.03337)  Acc@1  98.90 ( 98.95)   Acc@5 100.00 ( 99.99)  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_global_pruning.pth'  
use pretrained model : ./../notebook/weight/MNIST_lenet_10_global_pruning.pth  
Validate: [10/10] Time  0.087 ( 0.180)    Data  0.001 ( 0.082)    Loss 0.03372 (0.03378)  Acc@1  98.90 ( 98.94)   Acc@5 100.00 ( 99.99)  