forked from wuzhe71/CPD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prune.py
51 lines (43 loc) · 1.63 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torchvision.transforms as transforms
import torch.utils.tensorboard as tensorboard
import torch.nn.utils.prune as prune
import os, argparse
from datetime import datetime
from model.models import CPD, CPD_A, CPD_darknet19, CPD_darknet19_A, CPD_darknet_A
# Run pruning on CPU; pruning and saving do not need a GPU.
device = torch.device('cpu')
# Load the pretrained, unpruned CPD checkpoint once. Each pruning level in the
# loop below restarts from these weights.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('CPD-O.pth', map_location=device)
model = CPD().to(device)
# Directory that receives one pruned checkpoint per sparsity level.
save_path = 'pruned/'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_path, exist_ok=True)
# For each target fraction, globally L1-prune all Conv2d weights of the
# original model, report the achieved sparsity, bake the mask into the
# weights and save the result as '<save_path>/<model.name>_<sparsity%>.pth'.
for amount in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
    # Restore the unpruned weights so every level starts from the original
    # checkpoint instead of compounding on the previous level's pruning.
    model.load_state_dict(state_dict)

    # Every 2D-conv weight takes part in a single global pruning pass.
    parameters_to_prune = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, torch.nn.Conv2d)
    ]
    # Zero the `amount` fraction of conv weights with the smallest absolute
    # value, ranked across all conv layers together (not per layer).
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    # Count exactly-zero conv weights to measure the achieved global sparsity.
    nelements = 0
    nzeros = 0
    for module, _ in parameters_to_prune:
        nelements += module.weight.nelement()
        nzeros += int(torch.sum(module.weight == 0))
    # True division: the old `100 * x // n` floored to an int, so the .2f
    # print was always N.00 and the filename could under-report (e.g. 9 vs 10).
    gs = 100.0 * nzeros / nelements
    print('Global sparsity: {:.2f}%'.format(gs))

    # Make pruning permanent: drop weight_orig/weight_mask so the saved
    # state_dict has plain 'weight' keys loadable by an unpruned model.
    for para in parameters_to_prune:
        print(para)
        prune.remove(para[0], para[1])

    # NOTE(review): assumes the CPD model exposes a `name` attribute -- confirm.
    # Fixed: the old format string ended in '.pth.' (stray trailing dot).
    torch.save(
        model.state_dict(),
        os.path.join(save_path, '{}_{:.0f}.pth'.format(model.name, gs)),
    )