Commit

add regularizer
lz02k committed Nov 28, 2022
1 parent c512272 commit 8397203
Showing 5 changed files with 35 additions and 4 deletions.
@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math

 import torch
 import torch.nn as nn
@@ -278,7 +279,7 @@ def train(
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
+        regular_loss = model.get_regularizer_loss() * schedule_value
         loss = ce_loss + regular_loss

         # measure accuracy and record loss
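The new call site multiplies the model's regularizer loss by a schedule_value computed elsewhere in this training script, and the newly added import math hints at a smooth schedule for that weight. Below is a minimal sketch of how such a weighting could be produced and applied; get_schedule_value, total_epochs, and the cosine ramp are illustrative assumptions, not code from this commit.

import math

def get_schedule_value(lambda_max, epoch, total_epochs):
    # hypothetical cosine ramp from 0 up to lambda_max over training
    return lambda_max * 0.5 * (1.0 - math.cos(math.pi * epoch / total_epochs))

# inside the training loop, mirroring the diff above:
# schedule_value = get_schedule_value(cfg.REGULARIZER.LAMBDA, epoch, total_epochs)
# regular_loss = model.get_regularizer_loss() * schedule_value
# loss = ce_loss + regular_loss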
@@ -9,3 +9,7 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
+  LAMBDA: 0.0001
5 changes: 5 additions & 0 deletions sparsebit/quantization/quant_config.py
@@ -43,6 +43,11 @@
 _C.A.OBSERVER.LAYOUT = "NCHW" # NCHW / NLC
 _C.A.SPECIFIC = []

+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+_C.REGULARIZER.LAMBDA = 0.0
+

 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
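With these defaults in place, a YAML file like the fragment above only needs to override the fields it cares about. A short usage sketch; the file name is illustrative, not from this commit.

from sparsebit.quantization.quant_config import parse_qconfig

qconfig = parse_qconfig("qconfig_pact_w4a4.yaml")  # illustrative path to a config like the one above
if qconfig.REGULARIZER.ENABLE:
    print(qconfig.REGULARIZER.TYPE, qconfig.REGULARIZER.LAMBDA)  # e.g. "pact", 0.0001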
25 changes: 23 additions & 2 deletions sparsebit/quantization/quant_model.py
@@ -20,7 +20,11 @@
 from sparsebit.quantization.quantizers import Quantizer
 from sparsebit.quantization.tools import QuantizationErrorProfiler
 from sparsebit.quantization.converters import simplify, fuse_operations
+<<<<<<< HEAD
 from sparsebit.quantization.quant_tracer import QTracer
+=======
+from sparsebit.quantization.regularizers import build_regularizer
+>>>>>>> 8bd7bbc... add regularizer


 __all__ = ["QuantModel"]
@@ -35,6 +39,7 @@ def __init__(self, model: nn.Module, config):
         self._run_simplifiers()
         self._convert2quantmodule()
         self._build_quantizer()
+        self._build_regularizer()
         self._run_fuse_operations()

     def _convert2quantmodule(self):
@@ -133,6 +138,7 @@ def _sub_build(src, module_name):
            update_config(_config, "A", _sub_build(self.cfg.A, node.target))
            identity_module.build_quantizer(_config)

+<<<<<<< HEAD
     def _trace(self, model):
         skipped_modules = self.cfg.SKIP_TRACE_MODULES
         tracer = QTracer(skipped_modules)
@@ -141,12 +147,19 @@ def _trace(self, model):
         traced = fx.GraphModule(tracer.root, graph, name)
         traced.graph.print_tabular()
         return traced
+=======
+    def _build_regularizer(self):
+        if self.cfg.REGULARIZER.ENABLE:
+            self.regularizer = build_regularizer(self.cfg)
+        else:
+            self.regularizer = None
+>>>>>>> 8bd7bbc... add regularizer

     def _run_simplifiers(self):
         self.model = simplify(self.model)

     def _run_fuse_operations(self):
-        if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
+        if self.cfg.SCHEDULE.BN_TUNING:  # first disable fuse bn
             update_config(self.cfg.SCHEDULE, "FUSE_BN", False)
         self.model = fuse_operations(self.model, self.cfg.SCHEDULE)
         self.model.graph.print_tabular()
@@ -167,7 +180,9 @@ def batchnorm_tuning(self):
         yield
         self.model.eval()
         update_config(self.cfg.SCHEDULE, "FUSE_BN", True)
-        self.model = fuse_operations(self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"])
+        self.model = fuse_operations(
+            self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"]
+        )
         self.set_quant(w_quant=False, a_quant=False)

     def prepare_calibration(self):
@@ -235,6 +250,12 @@ def set_quant(self, w_quant=False, a_quant=False):
             if isinstance(m, QuantOpr):
                 m.set_quant(w_quant, a_quant)

+    def get_regularizer_loss(self):
+        if self.regularizer is None:
+            return torch.tensor(0.).to(self.device)
+        else:
+            return self.regularizer(self.model)
+
     def export_onnx(
         self,
         dummy_data,
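Together these hooks let QuantModel build a regularizer from the config at construction time and expose its loss with a zero-tensor fallback when REGULARIZER.ENABLE is off. A minimal end-to-end sketch, assuming QuantModel is importable from sparsebit.quantization and the config file resembles the YAML fragment above (both assumptions, not part of this commit):

import torch
import torch.nn as nn
from sparsebit.quantization import QuantModel  # assumed public import path
from sparsebit.quantization.quant_config import parse_qconfig

qconfig = parse_qconfig("qconfig_pact_w4a4.yaml")  # illustrative config with REGULARIZER.ENABLE: True
float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
qmodel = QuantModel(float_model, qconfig)

images = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))
ce_loss = nn.CrossEntropyLoss()(qmodel(images), target)
regular_loss = qmodel.get_regularizer_loss()  # zero tensor if no regularizer was built
loss = ce_loss + regular_loss  # the training script above additionally scales this term by a schedule value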
2 changes: 1 addition & 1 deletion sparsebit/quantization/regularizers/__init__.py
@@ -7,7 +7,7 @@ def register_regularizer(regularizer):


 from .base import Regularizer
-from . import dampen
+from . import dampen, pact


 def build_regularizer(config):
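The registry now picks up a pact regularizer next to the existing dampen one. Its implementation is not part of this diff; purely as an illustration of the registration pattern, a PACT-style regularizer usually puts an L2 penalty on the learnable clipping parameter alpha of each PACT quantizer, roughly along these lines (the class name, constructor, and the alpha parameter naming are assumptions, not the repository's code):

import torch

from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.regularizers.base import Regularizer


@register_regularizer
class PACTRegularizer(Regularizer):  # assumed name, registered under TYPE "pact"
    TYPE = "pact"

    def __init__(self, config):
        self.config = config  # assumed constructor; build_regularizer(self.cfg) suggests the config is passed in

    def __call__(self, model):
        # hypothetical L2 penalty over every clipping threshold parameter named "alpha"
        penalties = [
            (p ** 2).sum()
            for name, p in model.named_parameters()
            if name.endswith("alpha")
        ]
        if not penalties:
            return torch.tensor(0.0)
        return torch.stack(penalties).sum()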
