add regularizer #73

Open · wants to merge 4 commits into main
30 changes: 13 additions & 17 deletions examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
import time
import warnings
from enum import Enum
+import math

import torch
import torch.nn as nn
@@ -21,6 +22,7 @@

from model import resnet20
from sparsebit.quantization import QuantModel, parse_qconfig
+from sparsebit.quantization.regularizers import build_regularizer


parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -147,8 +149,6 @@ def main():

    qconfig = parse_qconfig(args.config)

-    is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-
    qmodel = QuantModel(model, qconfig).cuda()  # convert the model into a quantized model to support the QAT quantization ops that follow

    # set the head and tail of the model to 8-bit
@@ -181,6 +181,11 @@
        optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
    )

+    if qconfig.REGULARIZER.ENABLE:
+        regularizer = build_regularizer(qconfig)
+    else:
+        regularizer = None

    best_acc1 = 0
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
@@ -190,7 +195,7 @@
            criterion,
            optimizer,
            epoch,
-            is_pact,
+            regularizer,
            args.regularizer_lambda,
            args.print_freq,
        )
@@ -225,18 +230,9 @@
    )


-# In the PACT algorithm, add L2 regularization on alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p**2).sum()
-    return loss
-
-
-def get_regularizer_loss(model, is_pact, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
+def get_regularizer_loss(model, regularizer, _lambda):
+    if regularizer is not None:
+        return regularizer(model) * _lambda
    else:
        return torch.tensor(0.0).cuda()

@@ -247,7 +243,7 @@ def train(
    criterion,
    optimizer,
    epoch,
-    is_pact,
+    regularizer,
    regularizer_lambda,
    print_freq,
):
@@ -278,7 +274,7 @@
        # compute output
        output = model(images)
        ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
+        regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
        loss = ce_loss + regular_loss

        # measure accuracy and record loss
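
Net effect on the example's training objective, with λ = args.regularizer_lambda:

    L_total = L_CE + λ · L_reg(model)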
@@ -0,0 +1,14 @@
BACKEND: virtual
W:
  QSCHEME: per-channel-symmetric
  QUANTIZER:
    TYPE: lsq
    BIT: 4
A:
  QSCHEME: per-tensor-affine
  QUANTIZER:
    TYPE: lsq
    BIT: 4
REGULARIZER:
  ENABLE: True
  TYPE: dampen
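
For context, a minimal sketch of how a config like this one flows through the new code; the YAML file name below is hypothetical, while parse_qconfig and build_regularizer are the entry points touched by this PR:

    from sparsebit.quantization import parse_qconfig
    from sparsebit.quantization.regularizers import build_regularizer

    qconfig = parse_qconfig("qconfig_lsq_dampen.yaml")  # hypothetical path to the YAML above
    regularizer = build_regularizer(qconfig) if qconfig.REGULARIZER.ENABLE else None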
@@ -9,3 +9,6 @@ A:
  QUANTIZER:
    TYPE: pact
    BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
4 changes: 4 additions & 0 deletions sparsebit/quantization/quant_config.py
@@ -47,6 +47,10 @@
_C.A.QADD.ENABLE_QUANT = False
_C.A.SPECIFIC = []

+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""


def parse_qconfig(cfg_file):
    qconfig = _parse_config(cfg_file, default_cfg=_C)
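
With these defaults, a YAML that omits the REGULARIZER block still parses cleanly and leaves the feature off. A rough sketch of the expected behavior, assuming the yacs-style merge semantics used by _parse_config (the file name is hypothetical):

    qconfig = parse_qconfig("qconfig_no_reg.yaml")  # hypothetical YAML without a REGULARIZER block
    assert qconfig.REGULARIZER.ENABLE is False  # falls back to the default, so no regularizer is built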
15 changes: 15 additions & 0 deletions sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
REGULARIZERS_MAP = {}


def register_regularizer(regularizer):
    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
    return regularizer


from .base import Regularizer
from . import dampen, pact


def build_regularizer(config):
    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
    return regularizer
6 changes: 6 additions & 0 deletions sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
class Regularizer(object):
    def __init__(self, config):
        self.config = config

    def __call__(self, model):  # interface implemented by each registered regularizer
        raise NotImplementedError
47 changes: 47 additions & 0 deletions sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,47 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.quantizers.quant_tensor import fake_qrange_factory


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Dampen"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config

    def _get_loss(self, x, quantizer):

Review comment (Member): please use black to reformat

        x_q = quantizer(x)  # fake-quantized tensor

        scale, zero_point = quantizer._qparams_preprocess(x)

        min_val, max_val = fake_qrange_factory[quantizer.backend](
            scale, zero_point, quantizer.qdesc
        )

        min_val = min_val.detach()
        max_val = max_val.detach()

        # clip x to the quantizer's representable range [min_val, max_val]
        x_c = torch.min(torch.max(x, min_val), max_val)

        # squared gap between the fake-quantized and the clipped tensor
        loss = ((x_q - x_c) ** 2).sum()

        return loss

    def __call__(self, model):
        loss = 0.0
        for n, m in model.named_modules():
            if (
                hasattr(m, "weight")
                and hasattr(m, "weight_quantizer")
                and m.weight_quantizer
                and m.weight_quantizer.is_enable
            ):
                loss += self._get_loss(m.weight, m.weight_quantizer)
        return loss
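
In effect, for each quantized weight tensor the dampen loss penalizes the squared gap between the fake-quantized values Q(x) and the values clipped to the quantizer's current range, pulling weights toward exactly representable values:

    L_dampen = Σ_i (Q(x_i) − clip(x_i, min_val, max_val))²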
20 changes: 20 additions & 0 deletions sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,20 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Pact"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config

    def __call__(self, model):
        # L2 penalty on every learned PACT clipping parameter (named "alpha")
        loss = 0.0
        for n, p in model.named_parameters():
            if "alpha" in n:
                loss += (p**2).sum()
        return loss
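
This matches the L2 penalty on the learned clipping levels α from the PACT paper, summed over all alpha parameters and scaled by args.regularizer_lambda in the training loop:

    L_pact = Σ_j α_j²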