Pruning  
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html  
枝刈りの種類  
Unstructured Pruning  
- prune.l1_unstructured: 要素（個々の重み）単位の枝刈り  

Structured Pruning  
- prune.ln_structured: channel単位の枝刈り  

In [None]:
import os
import sys

import torch
import torch.nn.utils.prune as prune

sys.path.append("../1-classification_mnist/py")
from model import LeNet

In [None]:
# Build a fresh (untrained) LeNet and list the parameters of the first layer in `features`.
model = LeNet()
print([*model.features[0].named_parameters()])

In [None]:
# Pruning basics
## Randomly prune 40% (amount=0.4) of the entries of the parameter named "weight"
## in the first layer of `features`.
prune.random_unstructured(model.features[0], name="weight", amount=0.4)

## After pruning, the parameter being pruned is renamed from `weight` to `weight_orig`.
print(list(model.features[0].named_parameters()))

In [None]:
## The pruning mask is stored among the module's buffers; pruned entries are 0.
print([*model.features[0].named_buffers()])
## The module's `weight` attribute now holds the values with the mask applied.
print(model.features[0].weight)

In [None]:
# Pruning registers a forward pre-hook on the module, so the forward pass uses the
# pruned (mask-applied) weight. Print the registered pre-hooks to see it.
print(model.features[0]._forward_pre_hooks)

In [None]:
## Prune the 2 entries of the first `features` layer's "bias" with the smallest L1
## magnitude; an int `amount` is an absolute count of entries, not a fraction.
prune.l1_unstructured(module=model.features[0], name="bias", amount=2)

In [None]:
print([*model.features[0].named_buffers()])  # inspect the mask buffers registered so far

In [None]:
print("before remove")
print(list(model.features[0].named_parameters()))
# Make the pruning of `weight` permanent: drop `weight_orig` (name + '_orig') and
# `weight_mask` (name + '_mask') and store the masked values back as a plain `weight`
# parameter, so the state_dict has the same keys as before pruning.
# `bias` is still pruned (bias_orig / bias_mask remain) because only 'weight' is removed.
prune.remove(model.features[0], 'weight')
print("after remove")
print(list(model.features[0].named_parameters()))

In [None]:
# Structured pruning (per module): prune whole input channels of conv / linear weights.
weight_name = "./../1-classification_mnist/weight/MNIST_lenet_10.pth"
print("use pretrained model : %s" % weight_name)
param = torch.load(weight_name, map_location=lambda storage, loc: storage)
model = LeNet()
model.load_state_dict(param)

# ln_structured with dim=1 prunes entire *input* channels, ranked by their L2 norm (n=2).
# The first conv layer is skipped. It presumably has a single input channel (MNIST is
# grayscale; confirm against LeNet). With one channel, any amount above 50% rounds up
# and prunes every input weight of the layer.
prune_amount = 0.4
pruned_modules = []  # modules that actually got a pruning reparametrization
is_first_conv = True
for module in model.modules():
    if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    if isinstance(module, torch.nn.Conv2d) and is_first_conv:
        is_first_conv = False
        continue
    # Prune `prune_amount` of the input channels of every remaining conv / linear layer.
    prune.ln_structured(module, name='weight', amount=prune_amount, n=2, dim=1)
    pruned_modules.append(module)
print(dict(model.named_buffers()).keys())

# Make the pruning permanent on exactly the modules pruned above. prune.remove would
# raise on the skipped first conv, which has no reparametrization.
for module in pruned_modules:
    prune.remove(module, 'weight')
os.makedirs('weight', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), 'weight/MNIST_lenet_10_structured_pruning.pth')

# For each conv / linear layer, count the input channels whose weights are all exactly zero.
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        weight = module.weight.detach()
        in_channel = weight.shape[1]
        # Row i collects every weight that reads input channel i (over all outputs/kernel taps).
        per_input = weight.transpose(0, 1).reshape(in_channel, -1)
        # Test every element for zero, rather than a signed sum that can cancel to 0.
        zero_cnt = int((per_input == 0).all(dim=1).sum())
        print(name, "all zero weights ", zero_cnt, "/", in_channel)

In [None]:
# Unstructured pruning (whole network)
weight_name = "./../1-classification_mnist/weight/MNIST_lenet_10.pth"
print("use pretrained model : %s" % weight_name)
param = torch.load(weight_name, map_location=lambda storage, loc: storage)
model = LeNet()
model.load_state_dict(param)

# Collect the "weight" of every conv and linear layer as (module, name) pairs.
parameters_to_prune = tuple(
    (module, 'weight')
    for module in model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
)

print(parameters_to_prune)
prune_amount = 0.4
# Prune the prune_amount (40%) of weights with the smallest L1 magnitude. The ranking is
# done globally across all collected layers, so per-layer sparsity can differ.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=prune_amount,
)

In [None]:
# Report the sparsity (percentage of exactly-zero weights) of every conv / linear layer
# and over all of them together, then make the pruning permanent and save the weights.
# Iterating modules, rather than hard-coding LeNet indices, keeps the global figure
# consistent with global_unstructured, which pruned every conv / linear layer.
zero_total = 0
numel_total = 0
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        n_zero = int(torch.sum(module.weight == 0))
        n_elem = module.weight.nelement()
        print("Sparsity in {}.weight: {:.2f}%".format(name, 100. * n_zero / n_elem))
        zero_total += n_zero
        numel_total += n_elem

if numel_total:
    print("Global sparsity: {:.2f}%".format(100. * zero_total / numel_total))

# Bake the masks into plain `weight` parameters so the state_dict has the original keys.
for name, module in model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.remove(module, 'weight')
os.makedirs('weight', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), 'weight/MNIST_lenet_10_unstructured_pruning.pth')

精度のみの比較  
cd ./../1-classification_mnist  
python py/main.py --evaluate --resume='weight/MNIST_lenet_10.pth'  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_unstructured_pruning.pth'  
python py/main.py --evaluate --resume='./../notebook/weight/MNIST_lenet_10_structured_pruning.pth'  

use pretrained model : weight/MNIST_lenet_10.pth  
Validate: [10/10] Loss 0.03638 (0.03311)  Acc@1  98.80 ( 98.92)   Acc@5 100.00 ( 99.99)   
use pretrained model : ./../notebook/weight/MNIST_lenet_10_unstructured_pruning.pth  
Validate: [10/10] Loss 0.03372 (0.03378)  Acc@1  98.90 ( 98.94)   Acc@5 100.00 ( 99.99)  
use pretrained model : ./../notebook/weight/MNIST_lenet_10_structured_pruning.pth  
Validate: [10/10] Loss 0.07547 (0.08172)  Acc@1  98.00 ( 97.52)   Acc@5 100.00 ( 99.99)  