This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
iterative_pruning_torch.py
138 lines (115 loc) · 5.02 KB
/
iterative_pruning_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
NNI example for supported iterative pruning algorithms.
In this example, we show the end-to-end iterative pruning process: pre-training -> pruning -> fine-tuning.
'''
import sys
import argparse
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from nni.compression.pytorch.pruning import (
LinearPruner,
AGPPruner,
LotteryTicketPruner
)
from pathlib import Path
sys.path.append(str(Path(__file__).absolute().parents[1] / 'models'))
from cifar10.vgg import VGG
# Use the GPU when torch reports one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel normalization shared by the training and test pipelines.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training split: random flip + random crop augmentation, then tensor
# conversion and normalization. This call is the one that requests the download.
_train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
_train_set = datasets.CIFAR10('./data', train=True, transform=_train_transform, download=True)
train_loader = torch.utils.data.DataLoader(_train_set, batch_size=128, shuffle=True)

# Test split: no augmentation and no shuffling. No download flag is passed
# here, so the files are expected to already be under './data'.
_test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
_test_set = datasets.CIFAR10('./data', train=False, transform=_test_transform)
test_loader = torch.utils.data.DataLoader(_test_set, batch_size=128, shuffle=False)

# Module-level default loss.
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        model: Network to train; it is switched to training mode first.
        optimizer: Optimizer whose ``zero_grad()``/``step()`` drive the updates.
        criterion: Callable mapping ``(output, target)`` to a loss.
        epoch: Epoch index, used only in the progress-bar label.
    """
    model.train()
    progress = tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch))
    for inputs, labels in progress:
        # Move the batch to the same device the script selected at import time.
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
def finetuner(model):
    """Fine-tune ``model`` in place for one pass over ``train_loader``.

    A new SGD optimizer and a new cross-entropy loss are created on every
    call, so no optimizer state is carried over between calls.

    Args:
        model: Network to fine-tune; it is switched to training mode first.
    """
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    for inputs, labels in tqdm(iterable=train_loader, desc='Epoch PFs'):
        inputs = inputs.to(device)
        labels = labels.to(device)
        sgd.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        sgd.step()
def evaluator(model):
    """Measure top-1 accuracy of ``model`` on ``test_loader``.

    The model is put in evaluation mode and gradients are disabled for the
    whole pass. The accuracy is printed and also returned.

    Args:
        model: Network to evaluate.

    Returns:
        Accuracy as a percentage of ``len(test_loader.dataset)``.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for inputs, labels in tqdm(iterable=test_loader, desc='Test'):
            inputs, labels = inputs.to(device), labels.to(device)
            # Index of the highest-scoring class for each sample.
            predicted = model(inputs).argmax(dim=1, keepdim=True)
            matches = predicted.eq(labels.view_as(predicted))
            hits += matches.sum().item()
    acc = 100 * hits / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc
def str2bool(value):
    """Interpret a command-line token as a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``'False'``) becomes ``True``. This
    converter reads the text instead.

    Args:
        value: The raw argument string, or an actual ``bool``.

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 or an empty
        string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to give.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Iterative Example for model comporession')
    parser.add_argument('--pruner', type=str, default='linear',
                        choices=['linear', 'agp', 'lottery'],
                        help='pruner to use')
    parser.add_argument('--pretrain-epochs', type=int, default=10,
                        help='number of epochs to pretrain the model')
    parser.add_argument('--total-iteration', type=int, default=10,
                        help='number of iteration to iteratively prune the model')
    parser.add_argument('--pruning-algo', type=str, default='l1',
                        choices=['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz',
                                 'mean_activation', 'taylorfo', 'admm'],
                        help='algorithm to evaluate weights to prune')
    # str2bool instead of bool: with type=bool, '--speedup False' and
    # '--reset-weight False' were both parsed as True.
    parser.add_argument('--speedup', type=str2bool, default=False,
                        help='Whether to speedup the pruned model')
    parser.add_argument('--reset-weight', type=str2bool, default=True,
                        help='Whether to reset weight during each iteration')
    args = parser.parse_args()

    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # Pre-train the model, reporting test accuracy after every epoch.
    for i in range(args.pretrain_epochs):
        trainer(model, optimizer, criterion, i)
        evaluator(model)

    # Prune every Conv2d op type with a sparsity setting of 0.8.
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]

    # If you just want to keep the final result as the best result, pass
    # evaluator as None; otherwise the result with the highest score (given by
    # the evaluator) will be the best result.
    kw_args = {'pruning_algorithm': args.pruning_algo,
               'total_iteration': args.total_iteration,
               'evaluator': None,
               'finetuner': finetuner}

    if args.speedup:
        # A dummy input is only passed along together with the speedup flag,
        # so it is only built here.
        kw_args['speedup'] = args.speedup
        kw_args['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)

    # reset_weight is forwarded to the lottery-ticket pruner only.
    if args.pruner == 'lottery':
        kw_args['reset_weight'] = args.reset_weight

    # argparse `choices` guarantees args.pruner is one of these keys.
    pruner_classes = {
        'linear': LinearPruner,
        'agp': AGPPruner,
        'lottery': LotteryTicketPruner,
    }
    iterative_pruner = pruner_classes[args.pruner]

    pruner = iterative_pruner(model, config_list, **kw_args)
    pruner.compress()
    # Only the model (second element of the 5-tuple) is used afterwards.
    _, model, _, _, _ = pruner.get_best_result()
    evaluator(model)