Commit

add regularizer
lz02k committed Dec 15, 2022
1 parent 12fa56e commit 4c5ae67
Showing 8 changed files with 120 additions and 5 deletions.
15 changes: 10 additions & 5 deletions examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
import time
import warnings
from enum import Enum
import math

import torch
import torch.nn as nn
@@ -21,6 +22,7 @@

from model import resnet20
from sparsebit.quantization import QuantModel, parse_qconfig
from sparsebit.quantization.regularizers import build_regularizer


parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -147,8 +149,6 @@ def main():

qconfig = parse_qconfig(args.config)

is_pact = qconfig.A.QUANTIZER.TYPE == "pact"

qmodel = QuantModel(model, qconfig).cuda()  # convert the model into a quantized model so it supports the quantization ops needed for QAT

# set head and tail of model is 8bit
@@ -181,6 +181,11 @@ def main():
optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
)

if qconfig.REGULARIZER.ENABLE:
    regularizer = build_regularizer(qconfig)
else:
    regularizer = None

best_acc1 = 0
for epoch in range(args.start_epoch, args.epochs):
# train for one epoch
@@ -190,7 +195,7 @@
criterion,
optimizer,
epoch,
is_pact,
regularizer,
args.regularizer_lambda,
args.print_freq,
)
@@ -247,7 +252,7 @@ def train(
criterion,
optimizer,
epoch,
is_pact,
regularizer,
regularizer_lambda,
print_freq,
):
@@ -278,7 +283,7 @@
# compute output
output = model(images)
ce_loss = criterion(output, target)
regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
loss = ce_loss + regular_loss

# measure accuracy and record loss
@@ -0,0 +1,14 @@
BACKEND: virtual
W:
  QSCHEME: per-channel-symmetric
  QUANTIZER:
    TYPE: lsq
    BIT: 4
A:
  QSCHEME: per-tensor-affine
  QUANTIZER:
    TYPE: lsq
    BIT: 4
REGULARIZER:
  ENABLE: True
  TYPE: dampen
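
The yaml above enables the new dampen regularizer for an LSQ-quantized model. A minimal usage sketch (the config file name below is hypothetical; the wiring mirrors the main.py changes in this commit):

    from sparsebit.quantization import parse_qconfig
    from sparsebit.quantization.regularizers import build_regularizer

    qconfig = parse_qconfig("qconfig_lsq_dampen.yaml")  # hypothetical path to the yaml above
    regularizer = build_regularizer(qconfig) if qconfig.REGULARIZER.ENABLE else None
    # then, inside the training loop, as in main.py:
    #   regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
    #   loss = ce_loss + regular_loss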
@@ -9,3 +9,6 @@ A:
  QUANTIZER:
    TYPE: pact
    BIT: 4
REGULARIZER:
  ENABLE: True
  TYPE: pact
4 changes: 4 additions & 0 deletions sparsebit/quantization/quant_config.py
@@ -47,6 +47,10 @@
_C.A.QADD.ENABLE_QUANT = False
_C.A.SPECIFIC = []

_C.REGULARIZER = CN()
_C.REGULARIZER.ENABLE = False
_C.REGULARIZER.TYPE = ""


def parse_qconfig(cfg_file):
    qconfig = _parse_config(cfg_file, default_cfg=_C)
15 changes: 15 additions & 0 deletions sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
REGULARIZERS_MAP = {}


def register_regularizer(regularizer):
    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
    return regularizer


from .base import Regularizer
from . import dampen, pact


def build_regularizer(config):
    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
    return regularizer
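
build_regularizer looks the config's TYPE string up in REGULARIZERS_MAP, which the @register_regularizer decorator fills as a side effect of importing each module (hence the dampen and pact imports above). A sketch of adding a custom regularizer under this scheme (the L2 example below is hypothetical, not part of the commit):

    from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
    from sparsebit.quantization.regularizers import register_regularizer

    @register_regularizer
    class Regularizer(BaseRegularizer):
        TYPE = "L2"  # selected in yaml via REGULARIZER.TYPE: l2 (lookup is lowercased)

        def __call__(self, model):
            # hypothetical: plain L2 penalty over all weight parameters
            return sum((p ** 2).sum() for n, p in model.named_parameters() if n.endswith("weight"))

Note that a new module must also be imported in this __init__.py, as dampen and pact are, for its registration to run.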
6 changes: 6 additions & 0 deletions sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
class Regularizer(object):
    def __init__(self, config):
        self.config = config

    def __call__(self, model):
        raise NotImplementedError
48 changes: 48 additions & 0 deletions sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,48 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Dampen"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config

    def _get_loss(self, x, quantizer):
        x_q = quantizer(x)
        qmin, qmax = quantizer.qdesc.qrange
        # detach the qparams so the penalty only moves the weights,
        # not the quantizer parameters
        scale, zero_point = quantizer._qparams_preprocess(x)
        scale = scale.detach()
        zero_point = zero_point.detach()
        # float bounds of the representable quantization range
        min_val = (qmin - zero_point) * scale
        max_val = (qmax - zero_point) * scale
        # clamp x into that range and penalize its squared distance
        # to the quantized value
        x_c = torch.min(torch.max(x, min_val), max_val)
        loss = ((x_q - x_c) ** 2).sum()
        return loss

    def __call__(self, model):
        loss = 0.0
        for n, m in model.named_modules():
            if (
                hasattr(m, "weight")
                and hasattr(m, "weight_quantizer")
                and m.weight_quantizer
                and m.weight_quantizer.is_enable
            ):
                loss += self._get_loss(m.weight, m.weight_quantizer)
        return loss
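
Read off the code above: with quantizer Q, detached scale s and zero-point z, and integer range [q_min, q_max], the dampen penalty for each quantized weight tensor w is

    \mathcal{L}_{dampen}(w) = \sum \left( Q(w) - \mathrm{clip}\left(w,\ (q_{\min}-z)\,s,\ (q_{\max}-z)\,s\right) \right)^2

i.e. it pulls full-precision weights toward their quantized values within the representable range, so the rounding gap shrinks as training converges.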
20 changes: 20 additions & 0 deletions sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,20 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Pact"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config

    def __call__(self, model):
        # L2 penalty on PACT's learnable clipping parameters (alpha)
        loss = 0.0
        for n, p in model.named_parameters():
            if "alpha" in n:
                loss += (p ** 2).sum()
        return loss
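
This is the weight-decay-style regularizer proposed in the PACT paper, applied to the learnable clipping levels: each PACT activation quantizer exposes a parameter alpha, and penalizing its square keeps the clipping range tight. With the main.py wiring above (lambda being args.regularizer_lambda), the training objective becomes

    \mathcal{L} = \mathcal{L}_{CE} + \lambda \sum_{\alpha} \alpha^{2}

which this commit routes through the generic regularizer hook instead of the old hard-coded is_pact path.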
