update
lz02k committed Nov 28, 2022
1 parent 8397203 commit 7189c8a
Showing 7 changed files with 22 additions and 47 deletions.
29 changes: 12 additions & 17 deletions examples/quantization_aware_training/cifar10/basecase/main.py
@@ -22,6 +22,7 @@
 
 from model import resnet20
 from sparsebit.quantization import QuantModel, parse_qconfig
+from sparsebit.quantization.regularizers import build_regularizer
 
 
 parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -148,8 +149,6 @@ def main():
 
     qconfig = parse_qconfig(args.config)
 
-    is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-
     qmodel = QuantModel(model, qconfig).cuda()  # convert the model into a quantized model to support the QAT operations that follow
 
     # set head and tail of model is 8bit
@@ -182,6 +181,11 @@ def main():
         optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
     )
 
+    if qconfig.REGULARIZER.ENABLE:
+        regularizer = build_regularizer(qconfig)
+    else:
+        regularizer = None
+
     best_acc1 = 0
     for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
@@ -191,7 +195,7 @@
            criterion,
            optimizer,
            epoch,
-           is_pact,
+           regularizer,
            args.regularizer_lambda,
            args.print_freq,
        )
@@ -226,18 +230,9 @@ def main():
    )
 
 
-# PACT adds an L2-regularization on alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p ** 2).sum()
-    return loss
-
-
-def get_regularizer_loss(model, is_pact, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
+def get_regularizer_loss(model, regularizer, _lambda=0.0):
+    if regularizer:
+        return regularizer(model) * _lambda
    else:
        return torch.tensor(0.0).cuda()
 
@@ -248,7 +243,7 @@ def train(
    criterion,
    optimizer,
    epoch,
-   is_pact,
+   regularizer,
    regularizer_lambda,
    print_freq,
):
@@ -279,7 +274,7 @@
        # compute output
        output = model(images)
        ce_loss = criterion(output, target)
-       regular_loss = model.get_regularizer_loss() * schedule_value
+       regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
        loss = ce_loss + regular_loss
 
        # measure accuracy and record loss
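Taken together, the main.py changes replace the hard-coded PACT special case with a regularizer object built once from the parsed config. A minimal sketch of the resulting wiring, assuming model, criterion, images, target, and args are set up as elsewhere in the script (the config path below is illustrative):

    from sparsebit.quantization import QuantModel, parse_qconfig
    from sparsebit.quantization.regularizers import build_regularizer

    qconfig = parse_qconfig("qconfig.yaml")  # illustrative path
    qmodel = QuantModel(model, qconfig).cuda()

    # build the regularizer once, outside the training loop
    regularizer = build_regularizer(qconfig) if qconfig.REGULARIZER.ENABLE else None

    # inside train(): the raw penalty is scaled by the script-provided lambda
    output = qmodel(images)
    ce_loss = criterion(output, target)
    loss = ce_loss + get_regularizer_loss(qmodel, regularizer, args.regularizer_lambda)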
1 change: 0 additions & 1 deletion (example dampen qconfig; file path not captured)
@@ -12,4 +12,3 @@ A:
 REGULARIZER:
   ENABLE: True
   TYPE: dampen
-  LAMBDA: 0.01
1 change: 0 additions & 1 deletion (example pact qconfig; file path not captured)
@@ -12,4 +12,3 @@ A:
 REGULARIZER:
   ENABLE: True
   TYPE: pact
-  LAMBDA: 0.0001
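Dropping LAMBDA from both example configs moves the scaling factor out of the config file entirely: the strength now comes from the training script's args.regularizer_lambda. Assuming the argparse flag is spelled --regularizer-lambda and the config is passed as a positional argument (neither definition is part of this diff), a run might look like:

    python main.py qconfig.yaml --regularizer-lambda 0.0001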
3 changes: 1 addition & 2 deletions sparsebit/quantization/quant_config.py
@@ -5,7 +5,7 @@
 
 _C = CN()
 _C.BACKEND = "virtual"
-_C.SKIP_TRACE_MODULES = [] # a list of modules_name
+_C.SKIP_TRACE_MODULES = []  # a list of modules_name
 
 _C.SCHEDULE = CN()
 _C.SCHEDULE.FUSE_BN = False  # use ``with torch.no_grad()`` if it's enabled
@@ -46,7 +46,6 @@
 _C.REGULARIZER = CN()
 _C.REGULARIZER.ENABLE = False
 _C.REGULARIZER.TYPE = ""
-_C.REGULARIZER.LAMBDA = 0.0
 
 
 def parse_qconfig(cfg_file):
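With LAMBDA gone from the defaults, the REGULARIZER node carries only the on/off switch and the regularizer type. A minimal config fragment that parse_qconfig should accept under these defaults (values taken from the example configs above):

    REGULARIZER:
      ENABLE: True
      TYPE: pact   # or "dampen"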
29 changes: 7 additions & 22 deletions sparsebit/quantization/quant_model.py
@@ -20,11 +20,7 @@
 from sparsebit.quantization.quantizers import Quantizer
 from sparsebit.quantization.tools import QuantizationErrorProfiler
 from sparsebit.quantization.converters import simplify, fuse_operations
-<<<<<<< HEAD
 from sparsebit.quantization.quant_tracer import QTracer
-=======
-from sparsebit.quantization.regularizers import build_regularizer
->>>>>>> 8bd7bbc... add regularizer
 
 
 __all__ = ["QuantModel"]
@@ -39,7 +35,6 @@ def __init__(self, model: nn.Module, config):
         self._run_simplifiers()
         self._convert2quantmodule()
         self._build_quantizer()
-        self._build_regularizer()
         self._run_fuse_operations()
 
     def _convert2quantmodule(self):
@@ -138,22 +133,18 @@ def _sub_build(src, module_name):
             update_config(_config, "A", _sub_build(self.cfg.A, node.target))
             identity_module.build_quantizer(_config)
 
-<<<<<<< HEAD
     def _trace(self, model):
         skipped_modules = self.cfg.SKIP_TRACE_MODULES
         tracer = QTracer(skipped_modules)
         graph = tracer.trace(model)
-        name = model.__class__.__name__ if isinstance(model, torch.nn.Module) else model.__name__
+        name = (
+            model.__class__.__name__
+            if isinstance(model, torch.nn.Module)
+            else model.__name__
+        )
         traced = fx.GraphModule(tracer.root, graph, name)
         traced.graph.print_tabular()
         return traced
-=======
-    def _build_regularizer(self):
-        if self.cfg.REGULARIZER.ENABLE:
-            self.regularizer = build_regularizer(self.cfg)
-        else:
-            self.regularizer = None
->>>>>>> 8bd7bbc... add regularizer
 
     def _run_simplifiers(self):
         self.model = simplify(self.model)
@@ -204,8 +195,8 @@ def calc_qparams(self):
     def init_QAT(self):
         named_modules = dict(self.model.named_modules())
         # TODO: disable quant of input, note: not full-test
-        #input_nodes = [n for n in self.model.graph.nodes if n.op == "placeholder"]
-        #for input_node in input_nodes:
+        # input_nodes = [n for n in self.model.graph.nodes if n.op == "placeholder"]
+        # for input_node in input_nodes:
         #     input_users = list(input_node.users)
         #     while len(input_users) > 0:
         #         _user = input_users.pop()  # pop the last element
@@ -250,12 +241,6 @@ def set_quant(self, w_quant=False, a_quant=False):
             if isinstance(m, QuantOpr):
                 m.set_quant(w_quant, a_quant)
 
-    def get_regularizer_loss(self):
-        if self.regularizer is None:
-            return torch.tensor(0.).to(self.device)
-        else:
-            return self.regularizer(self.model)
-
     def export_onnx(
         self,
         dummy_data,
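This hunk resolves the leftover merge-conflict markers in favor of the tracer path, and QuantModel no longer owns a regularizer at all. The surviving _trace helper is also the consumer of the SKIP_TRACE_MODULES option touched in quant_config.py. A sketch of that relationship, with the module name "layer1" used purely for illustration:

    # in the yaml config (illustrative): leave listed modules untraced
    # SKIP_TRACE_MODULES: ["layer1"]

    tracer = QTracer(qconfig.SKIP_TRACE_MODULES)  # QTracer skips the listed modules
    graph = tracer.trace(model)
    traced = fx.GraphModule(tracer.root, graph, model.__class__.__name__)
    traced.graph.print_tabular()                  # dump the traced graph for inspection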
3 changes: 1 addition & 2 deletions sparsebit/quantization/regularizers/dampen.py
@@ -11,7 +11,6 @@ class Regularizer(BaseRegularizer):
     def __init__(self, config):
         super(Regularizer, self).__init__(config)
         self.config = config
-        self._lambda = config.REGULARIZER.LAMBDA
 
     def _get_loss(self, x, quantizer):
 
@@ -46,4 +45,4 @@ def __call__(self, model):
                 and m.weight_quantizer.is_enable
             ):
                 loss += self._get_loss(m.weight, m.weight_quantizer)
-        return loss * self._lambda
+        return loss
3 changes: 1 addition & 2 deletions sparsebit/quantization/regularizers/pact.py
@@ -11,11 +11,10 @@ class Regularizer(BaseRegularizer):
     def __init__(self, config):
         super(Regularizer, self).__init__(config)
         self.config = config
-        self._lambda = config.REGULARIZER.LAMBDA
 
     def __call__(self, model):
         loss = 0.0
         for n, p in model.named_parameters():
             if "alpha" in n:
                 loss += (p ** 2).sum()
-        return loss * self._lambda
+        return loss
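With _lambda removed from both regularizer classes, __call__ now returns the raw penalty, and the trainer applies the scale exactly once inside get_regularizer_loss. For PACT the penalty is the squared L2 norm of the learnable clipping values, i.e. the sum of alpha ** 2 over every parameter whose name contains "alpha". A toy check of that contract (the module and values are illustrative, not part of sparsebit):

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.alpha = nn.Parameter(torch.tensor(2.0))  # stands in for a PACT clip value

    penalty = sum((p ** 2).sum() for n, p in Toy().named_parameters() if "alpha" in n)
    assert penalty.item() == 4.0  # raw penalty; the caller multiplies by regularizer_lambda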
