WIP: support quadapter #118

Draft: wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions sparsebit/quantization/quantizers/__init__.py
@@ -13,6 +13,7 @@ def register_quantizer(quantizer):
from . import lsq_plus
from . import pact
from . import adaround
from . import quadapter


def build_quantizer(cfg):
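For reference, the new "from . import quadapter" line works because importing the submodule runs the @register_quantizer decorator, which adds the class to the package registry so build_quantizer can construct it from config. A rough sketch of that pattern, assuming a dict-based registry keyed by the quantizer's TYPE (QUANTIZERS_MAP and the cfg layout are assumptions, not sparsebit's actual names):

# Sketch only; QUANTIZERS_MAP and the cfg layout are assumptions.
QUANTIZERS_MAP = {}

def register_quantizer(quantizer):
    # Used as a class decorator: stores the class under its TYPE string.
    QUANTIZERS_MAP[quantizer.TYPE.lower()] = quantizer
    return quantizer

def build_quantizer(cfg):
    # Importing "from . import quadapter" runs the decorator above, which is
    # why the single added import makes TYPE "Quadapter" resolvable here.
    return QUANTIZERS_MAP[cfg.TYPE.lower()](cfg)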
1 change: 1 addition & 0 deletions sparsebit/quantization/quantizers/adaround.py
@@ -23,6 +23,7 @@ def __init__(self, config):
config.TARGET[0] == QuantTarget.WEIGHT
), "AdaRound only supports to quant weights"
self.zeta, self.gamma = 1.1, -0.1 # stretch-parameters
self.reconstruct_qlayer = reconstruct_qlayer

def init_variables(self, x):
x_floor = torch.floor(x / self.scale)
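For reference, binding reconstruct_qlayer onto the quantizer instance (here for AdaRound, and likewise for Quadapter below) lets the calibration pass invoke reconstruction through a uniform attribute instead of importing a specific algorithm. A minimal illustration of the dispatch this enables (the helper below is illustrative, not part of the diff):

# Illustrative helper, not sparsebit code: dispatch reconstruction if the
# quantizer exposes a reconstruct_qlayer callable (AdaRound or Quadapter).
def maybe_reconstruct(quantizer, layer, inputs, outputs):
    fn = getattr(quantizer, "reconstruct_qlayer", None)
    if fn is not None:
        fn(layer, inputs, outputs)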
65 changes: 65 additions & 0 deletions sparsebit/quantization/quantizers/quadapter.py
@@ -0,0 +1,65 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

from sparsebit.quantization.quantizers import Quantizer as BaseQuantizer
from sparsebit.quantization.quantizers import register_quantizer
from .quant_tensor import STE


@register_quantizer
class Quantizer(BaseQuantizer):
TYPE = "Quadapter"

def __init__(self, config):
super(Quantizer, self).__init__(config)
self.reconstruct_qlayer = reconstruct_qlayer

def init_variables(self, x: torch.Tensor):
alpha_shape = [1 for _ in range(self.dims)]
alpha_shape[self.qdesc._ch_axis] = x.shape[self.qdesc._ch_axis]
self.alpha = nn.Parameter(torch.ones(alpha_shape).to(self.device))

def update_observer(self, x):
self.dims = len(x.shape)
self.observer.data_cache.update(x.detach())

def _forward(self, x_f, scale, zero_point):
x_f = x_f * self.alpha
x_dq = STE.apply(x_f, scale, zero_point, self.qdesc, self.backend)
x_dq = x_dq / self.alpha
return x_dq


def reconstruct_qlayer(
layer,
inputs: torch.Tensor,
outputs: torch.Tensor,
batch_size=32,
max_steps=20000,
p=2.0,
):
# init
layer.eval()
layer.set_quant(w_quant=True, a_quant=True)
layer.input_quantizer.init_variables(inputs)
layer.input_quantizer.train()
opt_params = [layer.input_quantizer.alpha]
optimizer = torch.optim.Adam(opt_params)
print_freq = 500
# training
device = layer.input_quantizer.device
inputs, outputs = inputs.to(device), outputs.to(device)
for step in range(max_steps):
idx = torch.randperm(inputs.size(0))[:batch_size]
cur_input, cur_output = inputs[idx], outputs[idx]
optimizer.zero_grad()
quant_output = layer(cur_input)
loss = (quant_output - cur_output).abs().pow(p).sum(1).mean()
loss.backward(retain_graph=True)
optimizer.step()
if step % print_freq == 0:
print("Loss: {:.3f} step={}".format(loss, step))
torch.cuda.empty_cache()
layer.input_quantizer.eval()
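For reference, the idea behind Quadapter is a learnable per-channel scale alpha: activations are multiplied by alpha before fake quantization and divided by alpha afterwards, so the quantization grid is effectively adapted per channel while the layer's floating-point behaviour is recovered wherever alpha fits well. Alpha is then tuned by the Adam loop in reconstruct_qlayer to minimize the layer-output reconstruction error. A self-contained sketch of the forward transform, with a plain round-to-nearest fake quantizer standing in for the repository's STE.apply (assumed behaviour, not the actual sparsebit kernel):

import torch

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    # Round-to-nearest fake quantization, standing in for STE.apply.
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

def quadapter_forward(x, alpha, scale, zero_point):
    # alpha broadcasts over all dims except the channel axis,
    # e.g. shape [1, C, 1, 1] for an NCHW activation.
    x_dq = fake_quant(x * alpha, scale, zero_point)
    return x_dq / alpha

# Toy usage: per-channel alpha for a 4D activation.
x = torch.randn(8, 16, 14, 14)
alpha = torch.ones(1, 16, 1, 1, requires_grad=True)
y = quadapter_forward(x, alpha, scale=torch.tensor(0.05), zero_point=torch.tensor(0.0))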
35 changes: 31 additions & 4 deletions sparsebit/quantization/tools/calibration.py
@@ -3,7 +3,6 @@
from functools import partial

from sparsebit.quantization.modules import QuantOpr
from sparsebit.quantization.quantizers.adaround import reconstruct_qlayer
from .graph_wrapper import GraphVisitor, fx_symbolic_trace
from .tensor_wrapper import to_cpu, to_device, to_detach

Expand Down Expand Up @@ -89,6 +88,8 @@ def layerwise_calibration(self, device, asym=False, w_quant=False, a_quant=False
float_outputs = self.module_forward(batch_num, node, device)
self.builder.storage.set_output(node.target, float_outputs)
self.run_weight_calibration(node, asym, a_quant=a_quant)
# layerwise reconstruction
self.run_layerwise_reconstruction(node, asym, a_quant=a_quant)
# forward quant output
if asym:
quant_outputs = self.module_forward(
@@ -115,15 +116,41 @@ def run_weight_calibration(self, node, asym=False, a_quant=False):
if isinstance(module, QuantOpr) and getattr(module, "weight_quantizer", None):
module.weight_quantizer.update_observer(module.weight)
module.weight_quantizer.calc_qparams()
if module.weight_quantizer.TYPE.lower() == "adaround":

def run_layerwise_reconstruction(self, node, asym=False, a_quant=False):
module = self.model
for n in node.target.split("."):
module = getattr(module, n)
if isinstance(module, QuantOpr):
if (
getattr(module, "input_quantizer", None)
and not module.input_quantizer.fake_fused
and module.input_quantizer.TYPE.lower() == "quadapter"
):
assert (
len(node.all_input_nodes) == 1
), "Quadapter not supports the oprs which has more than one inputs"
_storage = self.builder.qstorage if asym else self.builder.storage
inp_tensors = _storage.get_output(node.all_input_nodes[0].target)
out_tensors = self.builder.storage.get_output(node.target)
print("Reconstruct input_quantizer of {}".format(node.target))
module.input_quantizer.reconstruct_qlayer(
module,
torch.cat(inp_tensors, dim=0),
torch.cat(out_tensors, dim=0),
)
if (
getattr(module, "weight_quantizer", None)
and module.weight_quantizer.TYPE.lower() == "adaround"
):
assert (
len(node.all_input_nodes) == 1
), "AdaRound not supports the oprs which has more than one inputs"
_storage = self.builder.qstorage if asym else self.builder.storage
inp_tensors = _storage.get_output(node.all_input_nodes[0].target)
out_tensors = self.builder.storage.get_output(node.target)
print("Reconstruct {}".format(node.target))
reconstruct_qlayer(
print("Reconstruct weight_quantizer of {}".format(node.target))
module.weight_quantizer.reconstruct_qlayer(
module,
torch.cat(inp_tensors, dim=0),
torch.cat(out_tensors, dim=0),
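For reference, run_layerwise_reconstruction walks to the node's module, pulls the cached calibration inputs (quantized predecessor outputs when asym is set, float otherwise) and the cached float outputs from the builder's storage, and hands them to the quantizer's own reconstruct_qlayer. A rough sketch of driving the same Quadapter reconstruction for a single layer outside the calibration framework, assuming the QuantOpr interface shown in this diff (set_quant, update_observer, calc_qparams):

# Sketch only; assumes a QuantOpr-style layer whose input_quantizer is the
# Quadapter quantizer added in this PR.
import torch

def reconstruct_single_layer(layer, calib_inputs):
    # 1. Collect a floating-point reference with quantization disabled.
    layer.set_quant(w_quant=False, a_quant=False)
    with torch.no_grad():
        float_outputs = layer(calib_inputs)
    # 2. Feed the observer and compute the initial scale / zero_point.
    layer.input_quantizer.update_observer(calib_inputs)
    layer.input_quantizer.calc_qparams()
    # 3. Optimize alpha against the float outputs (quant is re-enabled inside).
    layer.input_quantizer.reconstruct_qlayer(layer, calib_inputs, float_outputs)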