Commit

add regularizer
lz02k committed Nov 17, 2022
1 parent 126c30d commit 8bd7bbc
Showing 5 changed files with 39 additions and 21 deletions.
29 changes: 11 additions & 18 deletions examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
import time
import warnings
from enum import Enum
import math

import torch
import torch.nn as nn
@@ -27,7 +28,7 @@
raise NotImplementedError("This example should run on a GPU device.") # make sure this runs on a GPU


config = "qconfig_lsq.yaml" # QAT配置文件——包括量化方式(dorefa/lsq),权重和激活值的量化bit数等
config = "qconfig_lsq_dampen.yaml" # QAT配置文件——包括量化方式(dorefa/lsq),权重和激活值的量化bit数等
workers = 4
epochs = 200
start_epoch = 0
@@ -38,8 +39,7 @@
print_freq = 100
pretrained = ""
qconfig = parse_qconfig(config)
is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
regularizer_lambda = 1e-4
regularizer_schedule = "cosine" if qconfig.REGULARIZER.TYPE == "dampen" else "keep"

model = resnet20(num_classes=10) # use resnet20 as the base model
if pretrained: # optionally load the model parameters saved in pretrained
@@ -109,21 +109,8 @@
optimizer, milestones=[100, 150], last_epoch=start_epoch - 1
)

# PACT adds an L2 regularization term on alpha
def get_pact_regularizer_loss(model):
loss = 0
for n, p in model.named_parameters():
if "alpha" in n:
loss += (p ** 2).sum()
return loss

def get_regularizer_loss(model, scale=0):
if is_pact:
return get_pact_regularizer_loss(model) * scale
else:
return torch.tensor(0.).cuda()

def train(train_loader, model, criterion, optimizer, epoch):
def train(train_loader, model, criterion, optimizer, epoch, schedule_value=1.0):
batch_time = AverageMeter("Time", ":6.3f")
data_time = AverageMeter("Data", ":6.3f")
losses = AverageMeter("Loss", ":.4e")
@@ -151,7 +138,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
# compute output
output = model(images)
ce_loss = criterion(output, target)
regular_loss = get_regularizer_loss(model, scale=regularizer_lambda)
regular_loss = model.get_regularizer_loss() * schedule_value
loss = ce_loss + regular_loss

# measure accuracy and record loss
@@ -311,12 +298,18 @@ def accuracy(output, target, topk=(1,)):
best_acc1 = 0
for epoch in range(start_epoch, epochs):
# train for one epoch
if regularizer_schedule == "cosine":
coeff = (epoch - start_epoch + 1) / (epochs - start_epoch)
schedule_value = 1 - 0.5 * (1.0 + math.cos(math.pi * coeff))
else:
schedule_value = 1.0
train(
trainloader,
model,
criterion,
optimizer,
epoch,
schedule_value=schedule_value,
)

# evaluate on validation set
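The schedule logic above phases the dampen penalty in over training: schedule_value follows a half-cosine ramp from roughly 0 in the first epoch to exactly 1.0 in the last, so the regularizer barely perturbs early training and reaches full strength at the end. A minimal standalone sketch of the ramp, using the same formula and constants as this example script:

import math

epochs, start_epoch = 200, 0  # constants from this example script

def dampen_schedule(epoch):
    # fraction of training completed after this epoch
    coeff = (epoch - start_epoch + 1) / (epochs - start_epoch)
    # half-cosine ramp: ~0 at the first epoch, exactly 1.0 at the last
    return 1 - 0.5 * (1.0 + math.cos(math.pi * coeff))

for epoch in (0, 49, 99, 149, 199):
    print(f"epoch {epoch:3d}: schedule_value = {dampen_schedule(epoch):.4f}")
# epoch   0: schedule_value = 0.0001
# epoch  49: schedule_value = 0.1464
# epoch  99: schedule_value = 0.5000
# epoch 149: schedule_value = 0.8536
# epoch 199: schedule_value = 1.0000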
4 changes: 4 additions & 0 deletions
@@ -9,3 +9,7 @@ A:
QUANTIZER:
TYPE: pact
BIT: 4
REGULARIZER:
ENABLE: True
TYPE: pact
LAMBDA: 0.0001
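The yaml above enables the pact regularizer for a PACT-quantized config, while main.py now points at qconfig_lsq_dampen.yaml, which is not shown in this diff. Based on the REGULARIZER keys registered in quant_config.py below and the TYPE == "dampen" check in main.py, a plausible sketch of that config might look like the following; the W/A sections and bit widths are assumptions modeled on the example config above:

W:
  QUANTIZER:
    TYPE: lsq
    BIT: 4
A:
  QUANTIZER:
    TYPE: lsq
    BIT: 4
REGULARIZER:
  ENABLE: True
  TYPE: dampen
  LAMBDA: 0.0001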
5 changes: 5 additions & 0 deletions sparsebit/quantization/quant_config.py
@@ -38,6 +38,11 @@
_C.A.OBSERVER.LAYOUT = "NCHW" # NCHW / NLC
_C.A.SPECIFIC = []

_C.REGULARIZER = CN()
_C.REGULARIZER.ENABLE = False
_C.REGULARIZER.TYPE = ""
_C.REGULARIZER.LAMBDA = 0.0


def parse_qconfig(cfg_file):
qconfig = _parse_config(cfg_file, default_cfg=_C)
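Since the new REGULARIZER node defaults to ENABLE: False, existing configs parse unchanged and the feature is strictly opt-in. A quick sketch of how the parsed defaults surface (the config path is a placeholder; the attribute access follows from the yacs CN() nodes above):

from sparsebit.quantization.quant_config import parse_qconfig

qconfig = parse_qconfig("qconfig.yaml")  # placeholder path
print(qconfig.REGULARIZER.ENABLE)  # False unless the yaml enables it
print(qconfig.REGULARIZER.TYPE)    # "" by default; "pact" or "dampen" when set
print(qconfig.REGULARIZER.LAMBDA)  # 0.0 by default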
20 changes: 18 additions & 2 deletions sparsebit/quantization/quant_model.py
@@ -20,6 +20,7 @@
from sparsebit.quantization.quantizers import Quantizer
from sparsebit.quantization.tools import QuantizationErrorProfiler
from sparsebit.quantization.converters import simplify, fuse_operations
from sparsebit.quantization.regularizers import build_regularizer


__all__ = ["QuantModel"]
@@ -34,6 +35,7 @@ def __init__(self, model: nn.Module, config):
self._run_simplifiers()
self._convert2quantmodule()
self._build_quantizer()
self._build_regularizer()
self._run_fuse_operations()

def _convert2quantmodule(self):
@@ -119,11 +121,17 @@ def _sub_build(src, module_name):
update_config(_config, "A", _sub_build(self.cfg.A, node.target))
identity_module.build_quantizer(_config)

def _build_regularizer(self):
if self.cfg.REGULARIZER.ENABLE:
self.regularizer = build_regularizer(self.cfg)
else:
self.regularizer = None

def _run_simplifiers(self):
self.model = simplify(self.model)

def _run_fuse_operations(self):
if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
update_config(self.cfg.SCHEDULE, "FUSE_BN", False)
self.model = fuse_operations(self.model, self.cfg.SCHEDULE)
self.model.graph.print_tabular()
@@ -144,7 +152,9 @@ def batchnorm_tuning(self):
yield
self.model.eval()
update_config(self.cfg.SCHEDULE, "FUSE_BN", True)
self.model = fuse_operations(self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"])
self.model = fuse_operations(
self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"]
)
self.set_quant(w_quant=False, a_quant=False)

def prepare_calibration(self):
@@ -210,6 +220,12 @@ def set_quant(self, w_quant=False, a_quant=False):
if isinstance(m, QuantOpr):
m.set_quant(w_quant, a_quant)

def get_regularizer_loss(self):
if self.regularizer is None:
return torch.tensor(0.).to(self.device)
else:
return self.regularizer(self.model)

def export_onnx(
self,
dummy_data,
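Because get_regularizer_loss returns a zero tensor on the model's device when no regularizer is configured, callers can add the term unconditionally instead of branching, which is exactly how the updated main.py uses it. A minimal sketch of the pattern (qmodel, criterion, images, target, and schedule_value are placeholders for objects a training loop would already have):

# inside one training step of a QAT loop
output = qmodel(images)
loss = criterion(output, target) + qmodel.get_regularizer_loss() * schedule_value
loss.backward()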
2 changes: 1 addition & 1 deletion sparsebit/quantization/regularizers/__init__.py
@@ -7,7 +7,7 @@ def register_regularizer(regularizer):


from .base import Regularizer
from . import dampen
from . import dampen, pact


def build_regularizer(config):
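The new pact module itself is not shown in this diff, but the logic it replaces is the get_pact_regularizer_loss helper deleted from main.py above. A plausible reconstruction under the registry pattern visible here; the class name, the TYPE attribute, and the base-class interface are assumptions, while the alpha penalty mirrors the deleted helper:

# hypothetical sketch of sparsebit/quantization/regularizers/pact.py
from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.regularizers.base import Regularizer


@register_regularizer
class PactRegularizer(Regularizer):
    TYPE = "pact"  # assumed to match REGULARIZER.TYPE in the yaml

    def __init__(self, config):
        self.lambda_ = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        # L2 penalty on PACT's learnable clipping parameters (alpha),
        # as in the helper removed from main.py
        loss = 0
        for name, param in model.named_parameters():
            if "alpha" in name:
                loss = loss + (param ** 2).sum()
        return loss * self.lambda_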
