-
Notifications
You must be signed in to change notification settings - Fork 3
/
quantize.py
461 lines (351 loc) · 16.8 KB
/
quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
################################################################################
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
################################################################################
import os
import re
from typing import List, Callable, Union, Dict
from tqdm import tqdm
from copy import deepcopy
# PyTorch
import torch
import torch.optim as optim
from torch.cuda import amp
# Pytorch Quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.nn.modules import _utils as quant_nn_utils
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import quant_modules
from absl import logging as quant_logging
import torch.nn.functional as F
from utils.general import (LOGGER,colorstr)
# Custom Rules
from models.quantize_rules import find_quantizer_pairs
class QuantAdd(torch.nn.Module, quant_nn_utils.QuantMixin):
    """Addition of two tensors, each optionally passed through its own TensorQuantizer.

    Args:
        quantization: Truthy to create and use one input quantizer per operand;
            falsy to behave like a plain ``x + y``.
    """

    def __init__(self, quantization):
        super().__init__()
        self.quantization = quantization
        if not quantization:
            # No quantizers are created in the pass-through configuration.
            return
        self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
        self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())

    def forward(self, x, y):
        """Return ``x + y``, quantizing both operands first when enabled."""
        if not self.quantization:
            return x + y
        lhs = self._input0_quantizer(x)
        rhs = self._input1_quantizer(y)
        return lhs + rhs
class QuantUpsample(torch.nn.Module):
    """Interpolation layer that runs its input through a TensorQuantizer first.

    Args:
        size: Forwarded to ``F.interpolate`` as ``size``.
        scale_factor: Forwarded to ``F.interpolate`` as ``scale_factor``.
        mode: Forwarded to ``F.interpolate`` as ``mode``.
    """

    def __init__(self, size, scale_factor, mode):
        super().__init__()
        self.size, self.scale_factor, self.mode = size, scale_factor, mode
        self._input_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())

    def forward(self, x):
        """Quantize ``x`` and resample it with the stored interpolation settings."""
        quantized = self._input_quantizer(x)
        return F.interpolate(
            quantized,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
        )
class QuantConcat(torch.nn.Module):
    """Concatenation of two tensors, each passed through its own TensorQuantizer.

    Args:
        dim: Dimension along which the two quantized tensors are concatenated.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
        self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())

    def forward(self, x, dim):
        """Quantize ``x[0]`` and ``x[1]`` and concatenate them along ``self.dim``.

        NOTE: only the first two entries of ``x`` are read, and the ``dim``
        argument is accepted but not used; the dimension stored at
        construction time is the one applied.
        """
        quantized_pair = (
            self._input0_quantizer(x[0]),
            self._input1_quantizer(x[1]),
        )
        return torch.cat(quantized_pair, self.dim)
class disable_quantization:
    """Toggle the ``_disabled`` flag of every TensorQuantizer found in a model.

    Can be used directly via :meth:`apply` or as a context manager, in which
    case quantizers are disabled on entry and re-enabled on exit.

    Args:
        model: Object exposing ``named_modules()``.
    """

    def __init__(self, model):
        self.model = model

    def apply(self, disabled=True):
        """Set ``_disabled`` to ``disabled`` on every TensorQuantizer in the model."""
        for _, module in self.model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module._disabled = disabled

    def __enter__(self):
        self.apply(True)
        # Return the manager so ``with disable_quantization(m) as ctx`` binds
        # a usable object, matching enable_quantization.
        return self

    def __exit__(self, *args, **kwargs):
        # Quantizers are unconditionally re-enabled; their state before
        # entering the block is not remembered.
        self.apply(False)
class enable_quantization:
    """Switch every TensorQuantizer in a model on (or off) via its ``_disabled`` flag.

    As a context manager it enables the quantizers on entry and disables
    them again on exit.

    Args:
        model: Object exposing ``named_modules()``.
    """

    def __init__(self, model):
        self.model = model

    def apply(self, enabled=True):
        """Set ``_disabled`` to ``not enabled`` on every TensorQuantizer in the model."""
        disabled_flag = not enabled
        for _, submodule in self.model.named_modules():
            if not isinstance(submodule, quant_nn.TensorQuantizer):
                continue
            submodule._disabled = disabled_flag

    def __enter__(self):
        self.apply(True)
        return self

    def __exit__(self, *args, **kwargs):
        self.apply(False)
def have_quantizer(module):
    """Return True if ``module`` or any of its submodules is a TensorQuantizer.

    Args:
        module: Object exposing ``named_modules()``.

    Returns:
        bool: ``True`` when at least one TensorQuantizer is found, otherwise
        ``False`` (never ``None``).
    """
    return any(
        isinstance(submodule, quant_nn.TensorQuantizer)
        for _, submodule in module.named_modules()
    )
# Initialize PyTorch Quantization
def initialize():
    """Initialize pytorch_quantization and install histogram-calibrated input defaults.

    Calls ``quant_modules.initialize()``, registers a ``QuantDescriptor`` with
    ``calib_method="histogram"`` as the default input descriptor of the quantized
    Conv2d, MaxPool2d, AvgPool2d and Linear layers, and sets the absl logging
    verbosity to ERROR.
    """
    quant_modules.initialize()

    histogram_input_desc = QuantDescriptor(calib_method="histogram")
    quant_layer_types = (
        quant_nn.QuantConv2d,
        quant_nn.QuantMaxPool2d,
        quant_nn.QuantAvgPool2d,
        quant_nn.QuantLinear,
    )
    for quant_layer_type in quant_layer_types:
        quant_layer_type.set_default_quant_desc_input(histogram_input_desc)

    quant_logging.set_verbosity(quant_logging.ERROR)
def transfer_torch_to_quantization(nninstance : torch.nn.Module, quantmodule):
    """Build a ``quantmodule`` instance that carries over the state of ``nninstance``.

    The new object is created with ``__new__`` (its ``__init__`` is not run),
    every entry of ``vars(nninstance)`` is copied onto it, and its quantizers
    are then initialized from the class's default quant descriptors.

    Args:
        nninstance: Existing module whose attributes are copied.
        quantmodule: Quantized module class to instantiate.

    Returns:
        The newly created ``quantmodule`` instance.
    """
    quant_instance = quantmodule.__new__(quantmodule)
    for attr_name, attr_value in vars(nninstance).items():
        setattr(quant_instance, attr_name, attr_value)

    def attach_quantizers(target):
        quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(target.__class__)

        # Input-only layers get a single quantizer; the others get input + weight.
        input_only = isinstance(target, quant_nn_utils.QuantInputMixin)
        if input_only:
            target.init_quantizer(quant_desc_input)
        else:
            target.init_quantizer(quant_desc_input, quant_desc_weight)

        # Turn on torch_hist to enable higher calibration speeds
        if isinstance(target._input_quantizer._calibrator, calib.HistogramCalibrator):
            target._input_quantizer._calibrator._torch_hist = True
            if not input_only:
                target._weight_quantizer._calibrator._torch_hist = True

    attach_quantizers(quant_instance)
    return quant_instance
def quantization_ignore_match(ignore_policy : Union[str, List[str], Callable], path : str) -> bool:
    """Decide whether the module at ``path`` is excluded by ``ignore_policy``.

    Args:
        ignore_policy: ``None`` (nothing ignored), a callable taking the path
            and returning the decision, a single pattern string, or a list of
            pattern strings.
        path: Dotted module path to test.

    Returns:
        The callable's result when a callable is given; otherwise ``True`` if
        ``path`` equals one of the entries or ``re.match`` succeeds for one of
        them (match anchored at the start of ``path``), else ``False``.
    """
    if ignore_policy is None:
        return False

    if isinstance(ignore_policy, Callable):
        return ignore_policy(path)

    if isinstance(ignore_policy, str):
        patterns = [ignore_policy]
    elif isinstance(ignore_policy, List):
        patterns = ignore_policy
    else:
        # Unsupported policy types never match.
        return False

    # Exact membership is tested before any regex is evaluated.
    if path in patterns:
        return True
    return any(re.match(pattern, path) for pattern in patterns)
def set_module(model, submodule_key, module):
    """Assign ``module`` at the dotted attribute path ``submodule_key`` under ``model``.

    Args:
        model: Root object the path is resolved against.
        submodule_key: Dot-separated attribute path; the last component is the
            attribute that gets (re)assigned.
        module: Value to store.
    """
    *parent_names, leaf_name = submodule_key.split('.')
    owner = model
    for parent_name in parent_names:
        owner = getattr(owner, parent_name)
    setattr(owner, leaf_name, module)
def replace_to_quantization_module(model : torch.nn.Module, ignore_policy : Union[str, List[str], Callable] = None, prefixx=colorstr('QAT:')):
    """Swap supported submodules of ``model`` for their quantized counterparts, in place.

    The mapping from original class to quantized class is taken from
    ``quant_modules._DEFAULT_QUANT_MAP``. Children are processed before their
    parent is examined, and modules whose dotted path matches
    ``ignore_policy`` are logged and left untouched.

    Args:
        model: Root module whose ``_modules`` tree is rewritten.
        ignore_policy: Passed to ``quantization_ignore_match`` together with
            each candidate's path.
        prefixx: Prefix used in the log message for ignored modules.
    """
    # Keyed by id() of the original class object so the lookup is an identity match.
    quant_class_by_type_id = {
        id(getattr(entry.orig_mod, entry.mod_name)): entry.replace_mod
        for entry in quant_modules._DEFAULT_QUANT_MAP
    }

    def recursive_and_replace_module(parent, prefix=""):
        for child_name in parent._modules:
            child = parent._modules[child_name]
            path = f"{prefix}.{child_name}" if prefix else child_name

            # Depth first: handle the child's own children before the child itself.
            recursive_and_replace_module(child, path)

            type_key = id(type(child))
            if type_key not in quant_class_by_type_id:
                continue
            if quantization_ignore_match(ignore_policy, path):
                LOGGER.info(f'{prefixx} Quantization: {path} has ignored.')
                continue
            parent._modules[child_name] = transfer_torch_to_quantization(child, quant_class_by_type_id[type_key])

    recursive_and_replace_module(model)
def get_attr_with_path(m, path):
    """Resolve a dot-separated attribute ``path`` starting from ``m``.

    Args:
        m: Root object.
        path: Attribute names joined by ``"."``.

    Returns:
        The object found after following every component of ``path``.
    """
    current = m
    for attr_name in path.split("."):
        current = getattr(current, attr_name)
    return current
def repnbottleneck_quant_forward(self, x):
    """Forward pass for a bottleneck block with an optional quantized residual add.

    Computes ``cv2(cv1(x))``. When ``self.add`` is truthy the input is added
    back, through ``self.repaddop`` if that attribute exists, otherwise with a
    plain ``+``.
    """
    has_quant_add = hasattr(self, "repaddop")
    if not self.add:
        return self.cv2(self.cv1(x))

    branch = self.cv2(self.cv1(x))
    if has_quant_add:
        return self.repaddop(x, branch)
    return x + branch
def upsample_quant_forward(self, x):
    """Forward pass for an upsample module with an optional quantized implementation.

    Delegates to ``self.upsampleop`` when that attribute exists. Otherwise
    the input is interpolated with the module's own ``size``,
    ``scale_factor`` and ``mode``.

    This function is assigned to the module *class*, so instances that never
    received an ``upsampleop`` also reach the fallback. Calling
    ``F.interpolate(x)`` with neither ``size`` nor ``scale_factor`` raises, so
    the module's own settings must be forwarded.
    """
    if hasattr(self, "upsampleop"):
        return self.upsampleop(x)
    return F.interpolate(
        x,
        size=self.size,
        scale_factor=self.scale_factor,
        mode=self.mode,
        # Optional settings: fall back to F.interpolate's defaults when the
        # module does not define them.
        align_corners=getattr(self, "align_corners", None),
        recompute_scale_factor=getattr(self, "recompute_scale_factor", None),
    )
def concat_quant_forward(self, x):
    """Forward pass for a concat module with an optional quantized implementation.

    Uses ``self.concatop(x, self.d)`` when the attribute exists; otherwise
    falls back to ``torch.cat(x, self.d)``.
    """
    if not hasattr(self, "concatop"):
        return torch.cat(x, self.d)
    return self.concatop(x, self.d)
def apply_custom_rules_to_quantizer(model : torch.nn.Module, export_onnx : Callable):
    """Share input quantizers between related layers and quantize MaxPool2d modules.

    Steps:
      1. Export ``model`` to a temporary ONNX file and ask
         ``find_quantizer_pairs`` for ``(major, sub)`` module-path pairs; each
         ``sub`` module is given the ``major`` module's ``_input_quantizer``.
      2. For every ``RepNBottleneck`` with ``add`` set, point both inputs of
         its ``repaddop`` at the input quantizer of ``cv1.conv``.
      3. Replace every ``torch.nn.MaxPool2d`` with a ``QuantMaxPool2d`` built
         from the same pooling parameters and an 8-bit histogram descriptor.

    Args:
        model: Module modified in place.
        export_onnx: Callable invoked as ``export_onnx(model, file_path)``.
    """
    temp_onnx_file = "quantization-custom-rules-temp.onnx"
    try:
        export_onnx(model, temp_onnx_file)
        pairs = find_quantizer_pairs(temp_onnx_file)
        for major, sub in pairs:
            print(f"Rules: {sub} match to {major}")
            get_attr_with_path(model, sub)._input_quantizer = get_attr_with_path(model, major)._input_quantizer
    finally:
        # Always clean up the temporary export, even when exporting or rule
        # matching fails; the existence check keeps a failed export from
        # masking the original error.
        if os.path.exists(temp_onnx_file):
            os.remove(temp_onnx_file)

    # Snapshot the module list: set_module() below rewrites the tree, which
    # must not happen while the named_modules() generator is still walking it.
    for name, module in list(model.named_modules()):
        if module.__class__.__name__ == "RepNBottleneck":
            if module.add:
                print(f"Rules: {name}.add match to {name}.cv1")
                major = module.cv1.conv._input_quantizer
                module.repaddop._input0_quantizer = major
                module.repaddop._input1_quantizer = major

        if isinstance(module, torch.nn.MaxPool2d):
            quant_conv_desc_input = QuantDescriptor(num_bits=8, calib_method='histogram')
            quant_maxpool2d = quant_nn.QuantMaxPool2d(module.kernel_size,
                                                      module.stride,
                                                      module.padding,
                                                      module.dilation,
                                                      module.ceil_mode,
                                                      quant_desc_input=quant_conv_desc_input)
            set_module(model, name, quant_maxpool2d)
def replace_custom_module_forward(model):
    """Attach quantized helper ops to custom modules and patch their class ``forward``.

    Modules are recognised by class name:
      * ``RepNBottleneck`` with ``add`` set gets a ``repaddop`` (QuantAdd).
      * ``Concat`` gets a ``concatop`` (QuantConcat).
      * ``Upsample`` gets an ``upsampleop`` (QuantUpsample).

    A helper op is only created when the attribute is missing. The replacement
    ``forward`` is assigned on the module's class, so it affects every instance
    of that class.

    Args:
        model: Module whose ``named_modules()`` are scanned.
    """
    for name, module in model.named_modules():
        class_name = module.__class__.__name__

        if class_name == "RepNBottleneck":
            # Blocks without a residual add keep their original forward.
            if not module.add:
                continue
            if not hasattr(module, "repaddop"):
                print(f"Add QuantAdd to {name}")
                module.repaddop = QuantAdd(module.add)
            module.__class__.forward = repnbottleneck_quant_forward

        elif class_name == "Concat":
            if not hasattr(module, "concatop"):
                print(f"Add QuantConcat to {name}")
                module.concatop = QuantConcat(module.d)
            module.__class__.forward = concat_quant_forward

        elif class_name == "Upsample":
            if not hasattr(module, "upsampleop"):
                print(f"Add QuantUpsample to {name}")
                module.upsampleop = QuantUpsample(module.size, module.scale_factor, module.mode)
            module.__class__.forward = upsample_quant_forward
def calibrate_model(model : torch.nn.Module, dataloader, device, num_batch=25):
    """Calibrate the TensorQuantizers of ``model`` using batches from ``dataloader``.

    Statistics are collected with quantization switched off and calibration
    switched on, then each quantizer's amax is loaded from its calibrator
    (percentile method, 99.999) and moved to ``device``.

    Args:
        model: Module containing TensorQuantizer submodules.
        dataloader: Iterable of batches; element ``0`` of each batch is used
            as the image tensor and scaled by ``1 / 255``.
        device: Device the images and the resulting amax tensors are moved to.
        num_batch: Maximum number of batches fed to the model.
    """

    def compute_amax(model, **kwargs):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    # NOTE(review): the MaxCalibrator branch receives none of
                    # the kwargs, so ``strict=False`` does not apply to it;
                    # confirm this is intended.
                    if isinstance(module._calibrator, calib.MaxCalibrator):
                        module.load_calib_amax()
                    else:
                        module.load_calib_amax(**kwargs)
                    module._amax = module._amax.to(device)

    def collect_stats(model, data_loader, device, num_batch=200):
        """Feed data to the network and collect statistics"""
        # Enable calibrators
        model.eval()
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()

        # Feed data to the network for collecting stats
        with torch.no_grad():
            for i, datas in tqdm(enumerate(data_loader), total=num_batch, desc="Collect stats for calibrating"):
                # Stop before the forward pass so exactly ``num_batch``
                # batches are used (checking afterwards fed one extra batch).
                if i >= num_batch:
                    break
                imgs = datas[0].to(device, non_blocking=True).float() / 255.0
                model(imgs)

        # Disable calibrators
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.enable_quant()
                    module.disable_calib()
                else:
                    module.enable()

    with torch.no_grad():
        collect_stats(model, dataloader, device, num_batch=num_batch)
        compute_amax(model, method="percentile", percentile=99.999, strict=False)  # strict=False avoid Exception when some quantizer are never used
def finetune(
    model : torch.nn.Module, train_dataloader, no_last_layer, per_epoch_callback : Callable = None, preprocess : Callable = None,
    nepochs=10, early_exit_batchs_per_epoch=1000, lrschedule : Dict = None, fp16=True, learningrate=1e-5,
    supervision_policy : Callable = None, prefix=colorstr('QAT:')
):
    """Fine-tune a quantized model against a frozen copy of itself with quantization disabled.

    A deep copy of ``model`` is put in eval mode with all its quantizers
    disabled. Forward hooks capture the outputs of paired modules in both
    models, and the sum of MSE losses between those outputs is minimized with
    Adam.

    Args:
        model: Module to train. When ``no_last_layer`` is truthy it must expose
            an indexable ``model.model`` whose last entry is the layer to
            exclude.
        train_dataloader: Iterable of batches; each batch is passed through
            ``preprocess`` (if given) and then moved to the model's device.
        no_last_layer: If truthy, disable quantization in the last entry of
            ``model.model`` when it contains quantizers.
        per_epoch_callback: Called as ``callback(model, epoch_index, lr)``
            after each epoch; a truthy return value stops training.
        preprocess: Optional callable applied to each batch.
        nepochs: Number of epochs.
        early_exit_batchs_per_epoch: Maximum number of batches per epoch.
        lrschedule: Mapping from epoch index to learning rate; defaults to
            ``{0: 1e-6, 3: 1e-5, 6: 1e-6}``.
        fp16: Enables autocast and gradient scaling.
        learningrate: Initial learning rate for Adam.
        supervision_policy: Optional ``policy(name, module)`` filter; modules
            for which it returns a falsy value are not supervised.
        prefix: Prefix for log messages.
    """
    origin_model = deepcopy(model).eval()
    disable_quantization(origin_model).apply()

    model.train()
    model.requires_grad_(True)

    scaler = amp.GradScaler(enabled=fp16)
    optimizer = optim.Adam(model.parameters(), learningrate)
    quant_lossfn = torch.nn.MSELoss()
    device = next(model.parameters()).device

    if no_last_layer:
        last_layer_index = len(model.model) - 1
        last_layer = model.model[last_layer_index]
        if have_quantizer(last_layer):
            LOGGER.info(f'{prefix} Quantization disabled for Last Layer model.{last_layer_index}')
            disable_quantization(last_layer).apply()

    if lrschedule is None:
        lrschedule = {
            0: 1e-6,
            3: 1e-5,
            6: 1e-6
        }

    def make_layer_forward_hook(l):
        def forward_hook(m, input, output):
            l.append(output)
        return forward_hook

    # Pair each module of the trained model with its counterpart in the
    # reference copy; both iterate in the same order because one is a deep
    # copy of the other.
    supervision_module_pairs = []
    for ((mname, ml), (oriname, ori)) in zip(model.named_modules(), origin_model.named_modules()):
        if isinstance(ml, quant_nn.TensorQuantizer):
            continue

        if supervision_policy:
            if not supervision_policy(mname, ml):
                continue

        supervision_module_pairs.append([ml, ori])

    for iepoch in range(nepochs):

        if iepoch in lrschedule:
            learningrate = lrschedule[iepoch]
            for g in optimizer.param_groups:
                g["lr"] = learningrate

        model_outputs = []
        origin_outputs = []
        remove_handle = []

        # try/finally guarantees the hooks are detached even if a batch
        # raises; leaked hooks would keep appending outputs and break a later
        # onnx export or torch.save.
        try:
            for ml, ori in supervision_module_pairs:
                remove_handle.append(ml.register_forward_hook(make_layer_forward_hook(model_outputs)))
                remove_handle.append(ori.register_forward_hook(make_layer_forward_hook(origin_outputs)))

            model.train()
            pbar = tqdm(train_dataloader, desc="QAT", total=early_exit_batchs_per_epoch)
            for ibatch, imgs in enumerate(pbar):

                if ibatch >= early_exit_batchs_per_epoch:
                    break

                if preprocess:
                    imgs = preprocess(imgs)

                imgs = imgs.to(device)
                with amp.autocast(enabled=fp16):
                    model(imgs)

                    with torch.no_grad():
                        origin_model(imgs)

                    # NOTE: quant_loss stays the int 0 when no hooked output
                    # was captured, in which case backward() below fails.
                    quant_loss = 0
                    for mo, fo in zip(model_outputs, origin_outputs):
                        for m, f in zip(mo, fo):
                            quant_loss += quant_lossfn(m, f)

                    model_outputs.clear()
                    origin_outputs.clear()

                if fp16:
                    scaler.scale(quant_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    quant_loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                pbar.set_description(f"QAT Finetuning {iepoch + 1} / {nepochs}, Loss: {quant_loss.detach().item():.5f}, LR: {learningrate:g}")
        finally:
            # You must remove hooks during onnx export or torch.save
            for rm in remove_handle:
                rm.remove()

        if per_epoch_callback:
            if per_epoch_callback(model, iepoch, learningrate):
                break
def export_onnx(model, input, file, *args, **kwargs):
    """Export ``model`` to ONNX with ``TensorQuantizer.use_fb_fake_quant`` switched on.

    The model is put in eval mode and exported under ``torch.no_grad()``.
    The class-level ``use_fb_fake_quant`` flag is set to ``True`` for the
    duration of the export and reset to ``False`` afterwards, including when
    the export raises.

    Args:
        model: Module to export.
        input: Example input forwarded to ``torch.onnx.export``.
        file: Export destination forwarded to ``torch.onnx.export``.
        *args: Extra positional arguments for ``torch.onnx.export``.
        **kwargs: Extra keyword arguments for ``torch.onnx.export``.
    """
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    try:
        model.eval()
        with torch.no_grad():
            torch.onnx.export(model, input, file, *args, **kwargs)
    finally:
        # The flag is global state on the class; never leave it enabled after
        # a failed export.
        quant_nn.TensorQuantizer.use_fb_fake_quant = False