[OPT] Low-bit Quantization #2116

Open. Wants to merge 10 commits into base: master from ZihengJiang:dev.

Conversation
@ZihengJiang (Member) commented Nov 15, 2018

Thanks for contributing to TVM! Please refer to the guidelines at https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from reviewers.

@ajtulloch (Contributor) commented Nov 15, 2018

This is all assuming a symmetric quantization scheme, correct? Have you considered generalizing this slightly to an asymmetric quantization scheme like the one used in GEMMLOWP, QNNPACK, FBGEMM, NNAPI, etc?

@tqchen (Member) commented Nov 15, 2018

Since quantization is a major feature, it is better to send an RFC first.

@ZihengJiang (Member) commented Nov 17, 2018

I will propose an RFC next week. Thanks @ajtulloch @tqchen .

@ZihengJiang ZihengJiang force-pushed the ZihengJiang:dev branch from bff439a to 43fb017 Nov 27, 2018

@tqchen tqchen force-pushed the ZihengJiang:dev branch from 81669fb to 963c758 Dec 2, 2018

@ZihengJiang ZihengJiang force-pushed the ZihengJiang:dev branch 2 times, most recently from 81669fb to d92f41e Dec 3, 2018

@ZihengJiang ZihengJiang requested review from Huyuwei and Laurawly as code owners Dec 4, 2018


@ZihengJiang ZihengJiang force-pushed the ZihengJiang:dev branch from 37ee257 to a4c7b65 Dec 5, 2018

@ajtulloch (Contributor) commented Dec 5, 2018

Has there been an RFC posted btw? This comment probably belongs there.

FWIW, I'm a little concerned about some of the directions this PR is taking, or at least there are some use cases that it would be good to see handled, and I don't currently see how they fit in.

For background on my perspective, a standard training flow for quantized models in TF/C2 (at least the frameworks I'm familiar with that implement this) is to:

  1. Implement a model in a standard ML framework, generally using fp16/bfloat16/fp32 compute precision, as this has the highest throughput on most commonly used training hardware.
  2. (optionally) Insert fake quantization (here, called simulated quantization) nodes at quantization boundaries (i.e. if your backend implements a fused Int8Conv + Int8Relu, you'd insert them after a Conv + Relu block) to simulate the quantization numerics at training time; a sketch of this follows the list.
  3. Train the model as usual
  4. Implement a graph rewriting pass (i.e. TF's toco, C2's int8_converter, MXNet's quantization, etc) that rewrites the graph to target the int8 operators directly — i.e. remapping subgraphs of e.g. FP32Conv + FP32Relu to be a fused Int8ConvRelu operator. This requires computing output quantization parameters at requantization boundaries, which can be done either by
    • calibration against an example set of activations, via e.g. Lp-norm or KL-divergence minimization (C2/TF/MXNet/TensorRT), or
    • using activation ranges learned during training (C2/TF).
  5. Using this quantized graph, evaluate various metrics to verify the quantization-induced error/loss is acceptable.
  6. Deploy the quantized graph.
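As a rough illustration of the simulated quantization in step 2, a fake-quant node computes something like the following. This is a minimal NumPy sketch assuming a symmetric per-tensor int8 scheme; the function name and details are illustrative, not taken from this PR.

import numpy as np

def simulated_quantize(x, num_bits=8):
    # Quantize and immediately dequantize, so training sees the rounding
    # and clipping error while everything stays in float32.
    qmax = 2 ** (num_bits - 1) - 1                 # 127 for int8
    scale = max(np.abs(x).max(), 1e-8) / qmax      # symmetric per-tensor scale
    q = np.clip(np.round(x / scale), -qmax, qmax)
    return q * scale

x = np.random.randn(4, 4).astype('float32')
print(np.abs(simulated_quantize(x) - x).max())     # error bounded by about scale / 2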

Does this workflow make sense to folks? If not, could folks please elaborate on where we differ?

Given this flow, we'd like to insert TVM into this process. One key use case that I'd like TVM to consider supporting is to allow frameworks to continue to use their existing approaches for Steps 1-5, and involve TVM in Step 6. There are several reasons for this: for example, calibration-based quantization isn't always sufficient, and we'd like to support importing from existing int8 graph IRs like TFLite or C2.

I think requiring TVM to take on Steps 4 and 5 in order to implement quantized models is unnecessarily opinionated, and moves it towards being a fully-fledged framework in its own right (which I thought was not the goal).

I would have thought one natural (and minimalistic) direction for TVM to support quantized models (which isn't precluded by this diff, but I want to see what folks think about this) would be something like:

  1. Implement (in topi) support for int8 ops (i.e. (u)int8 inputs, int32 accumulation, int32 output). This is partially done already by the great work from folks in the community. If we generalize to asymmetric quantization (which IMO is quite important), then it's arguably more natural to represent the inputs/outputs as tuples of (uint8 tensor, float min, float max) or equivalently (uint8 tensor, int32 bias, float scale), and implement operators using this representation (a sketch of this mapping follows this list).
  2. Add some kind of requantize op in NNVM that performs an int32 -> (u)int8 requantization with the appropriate output float min/max obtained via calibration or training.
  3. Implement, in the NNVM frontend, an importer for e.g. TFLite models (which would mostly involve mapping ops like TFLiteConv into an nnvm::Conv + nnvm::Requantize sequence, and ensuring that TVM/NNVM fuses away requantize/pointwise/requantize sequences), and demonstrate a) bitwise numerical equivalence, and b) speedups vs TFLite's runtime for models like MobileNetV2 or similar.
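A minimal NumPy sketch of the asymmetric (affine) mapping referenced in point 1 above; the function names and tuple ordering are illustrative assumptions, not an API from this PR:

import numpy as np

def quantize_asymmetric(x, num_bits=8):
    # q = round(x / scale) + zero_point, with the range forced to contain 0
    # so that real 0.0 maps exactly to an integer value.
    qmin, qmax = 0, 2 ** num_bits - 1
    x_min, x_max = min(x.min(), 0.0), max(x.max(), 0.0)
    scale = (x_max - x_min) / (qmax - qmin) if x_max > x_min else 1.0
    zero_point = int(round(qmin - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax).astype('uint8')
    return q, scale, zero_point        # i.e. (uint8 tensor, float scale, int32 bias)

def dequantize_asymmetric(q, scale, zero_point):
    return (q.astype('float32') - zero_point) * scale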

Concretely, my concerns with this approach (assuming the goal is for it to be 'the one true way' to execute quantized models in TVM) are that it a) integrates too early in the pipeline, which unnecessarily requires some assumptions, b) those assumptions aren't the most general ones (it requires symmetric quantization as used by e.g. MKLDNN), which precludes asymmetric quantization as used in TF, TFLite, C2, GEMMLOWP, and QNNPACK, as well as channel-wise quantization as in TF/C2, which is very useful for pushing bitwidths lower (see e.g. https://arxiv.org/pdf/1806.08342.pdf), and c) is less modular than other approaches, which makes it harder to target from existing frameworks that already support quantization.
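As an aside on the channel-wise option mentioned above, a minimal sketch of per-output-channel weight quantization (illustrative only; it assumes an OIHW conv weight layout, which is not something this PR specifies):

import numpy as np

def quantize_weight_per_channel(w, num_bits=8):
    # One symmetric scale per output channel of an OIHW conv weight,
    # which usually gives much tighter ranges than a single per-tensor scale.
    qmax = 2 ** (num_bits - 1) - 1
    scales = np.abs(w).reshape(w.shape[0], -1).max(axis=1) / qmax
    scales = np.maximum(scales, 1e-8)
    q = np.clip(np.round(w / scales[:, None, None, None]), -qmax, qmax).astype('int8')
    return q, scales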

I don't think our goals are in conflict, I just thought that I should put this on the radar. Happy to send out an RFC (and dedicate engineering effort) to the alternative approach as well if folks are on board.

@tqchen (Member) commented Dec 6, 2018

@ajtulloch an RFC needs to be sent out, and we won't merge the PR before the RFC gets discussed, so we can move the discussion there after it is posted.

@ZihengJiang (Member) commented Dec 6, 2018

Hi @ajtulloch, I have a paper deadline, so I pushed this PR forward in a hurry to get a workable quantization workflow. Let me send out an RFC tomorrow. This PR won't be merged before we have a discussion in the community.

@tqchen (Member) commented Dec 6, 2018

x

@lixiaoquan (Contributor) commented Dec 7, 2018

Currently, it seems NNVM requires the inputs of an op to have the same data type, but a quantization scheme may produce inputs with different types. Any suggestions about that?

@ajtulloch (Contributor) commented Dec 8, 2018

@lixiaoquan there's no such requirement today AFAIK; it's user-controlled in the implementation of attr<FInferType>(..) for the relevant NNVM op.

@@ -213,3 +214,16 @@ def select_array(i, j):
        return now

    return tvm.compute(matrix.shape, select_array, name=name)


+@tvm.register_func("print_tensor")

@tqchen (Member) commented Dec 10, 2018

Sure, maybe we can add it as a util later in a separate PR, but we need documentation for these.


@tqchen tqchen force-pushed the ZihengJiang:dev branch from a4c7b65 to ac6b3d4 Dec 10, 2018

@ZihengJiang ZihengJiang force-pushed the ZihengJiang:dev branch from ac6b3d4 to a4c7b65 Dec 10, 2018

@vinx13 (Member) commented Dec 17, 2018

Running on inception-v3 produces a segmentation fault. It seems to be a memory issue (freed memory accessed by PyObject_Free); I'm not sure about the exact cause. Could you take a look? @ZihengJiang @tqchen

import tvm
from tvm import relay
import tvm.relay.testing

# target is defined elsewhere in the original snippet, e.g. target = 'llvm'
sym, params = tvm.relay.testing.inception_v3.get_workload(1)
sym = relay.quantize.quantize(sym, params)
graph, lib, params = relay.build(sym, target, params=params)

@ZihengJiang ZihengJiang changed the title from [WIP] Low-bit Quantization to [OPT] Low-bit Quantization on Dec 18, 2018

rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
    lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

@vinx13 (Member) commented Jan 4, 2019

Can/should we avoid duplicated quantization in parallel branches? E.g.:

   data (fp32) 
    /       \
conv1    conv2

@tqchen (Member) commented Jan 4, 2019

We should and we can, possibly via memoization. In theory, forward rewrite already memoizes; if there is any problem, please provide a minimal test case and let us double-check.

@vinx13 (Member) commented Jan 5, 2019

This is a test case. To reproduce, you need to set the opt level of "CombineParallelConv2D" to 4 to disable that pass.

import tvm
import tvm.relay as relay
import tvm.relay.testing

def get_workload():
    data = relay.var("data", shape=(1, 3, 224, 224), dtype='float32')
    conv1 = relay.testing.layers.conv2d(data=data, channels=16, kernel_size=(1, 1), name='conv1')
    shortcut = relay.testing.layers.conv2d(data=data, channels=16, kernel_size=(1, 1), name='sc')
    net = relay.add(conv1, shortcut)
    f = relay.Function(relay.ir_pass.free_vars(net), net)
    return relay.testing.init.create_workload(f)
    
sym, params = get_workload()
with tvm.relay.quantize.qconfig(skip_k_conv=0):
    sym = relay.quantize.quantize(sym, params)
print(sym.astext(show_meta_data=False))
tvm.relay.build(sym, 'llvm', params=params)

Result:

fn (%data: Tensor[(1, 3, 224, 224), float32])
    -> Tensor[(1, 16, 224, 224), float32] {
  %0 = multiply(%data, 16f) # ty=Tensor[(1, 3, 224, 224), float32]
  %1 = round(%0) # ty=Tensor[(1, 3, 224, 224), float32]
  %2 = clip(%1, a_min=-127, a_max=127) # ty=Tensor[(1, 3, 224, 224), float32]
  %3 = cast(%2, dtype="int8") # ty=Tensor[(1, 3, 224, 224), int8]
  %4 = meta.relay.Constant(id=0) # ty=Tensor[(16, 3, 1, 1), int8]
  %5 = nn.conv2d(%3, %4, channels=16, kernel_size=[1, 1], out_dtype="int32")
  %6 = multiply(%data, 16f) # ty=Tensor[(1, 3, 224, 224), float32]
  %7 = round(%6) # ty=Tensor[(1, 3, 224, 224), float32]
  %8 = clip(%7, a_min=-127, a_max=127) # ty=Tensor[(1, 3, 224, 224), float32]
  %9 = cast(%8, dtype="int8") # ty=Tensor[(1, 3, 224, 224), int8]
  %10 = meta.relay.Constant(id=1) # ty=Tensor[(16, 3, 1, 1), int8]
  %11 = nn.conv2d(%9, %10, channels=16, kernel_size=[1, 1], out_dtype="int32")
  %12 = add(%5, %11)
  %13 = add(%12, 64)
  %14 = right_shift(%13, 7)
  %15 = clip(%14, a_min=-127, a_max=127)
  %16 = cast(%15, dtype="int8")
  %17 = cast(%16, dtype="float32")
  %18 = multiply(%17, 0.0625f)
  %18
}

@tqchen (Member) commented Jan 5, 2019

@ZihengJiang @merrymercy @vinx13 can you look into this? Let us open this test case as an issue to be fixed.

@vinx13 (Member) commented Jan 7, 2019

This is because quantization of the data happens during the rewrite of conv2d, so it won't be memoized. We need some message passing to quantize the data during the forward rewrite of data itself.

@jroesch (Member) commented Jan 9, 2019

Couldn't we also avoid this by using ANF?

@tqchen (Member) commented Jan 9, 2019

It does not have anything to do with ANF. The problem is that if two convs refer to the same input and both want to run the same transformation f on that input, there will be two copies of f.

One solution is to build a generic common subexpression combination (elimination) pass to create a concise DAG.
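A minimal sketch of the memoization idea in plain Python (a hypothetical helper, not Relay's actual pass API), so that two consumers of the same input share a single rewritten node:

_rewrite_cache = {}

def rewrite_input(expr, f):
    # Key on the identity of the input expression; a structural hash would
    # turn this into the generic common-subexpression-elimination version.
    key = id(expr)
    if key not in _rewrite_cache:
        _rewrite_cache[key] = f(expr)
    return _rewrite_cache[key]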

@liangfu (Member) commented Jan 18, 2019

As we have seen in @vinx13's test case, there are three multiply operations. The multipliers are typically 16 or 1/16 (which is 0.0625). In order to eliminate floating-point multiplication, can we convert them into shift operations on integers?
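For reference, a small sketch of the equivalence under discussion, assuming the scale is an exact power of two as in the dump above (16 = 2**4): once the tensor is in the integer domain, dividing by 2**k becomes an add-and-right-shift, and float multiplies remain only at the float32 boundaries (the initial quantize and the final dequantize). The function here is illustrative, not code from this PR.

def divide_by_pow2_rounded(x_int, shift=4):
    # Integer-only divide by 2**shift with round-to-nearest,
    # matching the add-then-right_shift pattern in the IR above.
    return (x_int + (1 << (shift - 1))) >> shift

assert divide_by_pow2_rounded(160, 4) == 10     # 160 / 16 = 10
assert divide_by_pow2_rounded(23, 4) == 1       # 23 / 16 = 1.44 -> 1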

@tqchen (Member) commented Jan 18, 2019

As far as I understand, this PR already does that.

@liangfu (Member) commented Jan 19, 2019

I thought so, and I think it should be configurable by disabling round_for_shift in qconfig. However, the argument doesn't actually work, or at least not for replacing the combination of multiply and round with a shift operation.
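For anyone trying to reproduce this, the flag would presumably be toggled as in the sketch below, based on the qconfig usage shown earlier in this thread; whether round_for_shift actually controls this behaviour is exactly what is in question here, and the workload used is an arbitrary choice for illustration.

import tvm
import tvm.relay as relay
import tvm.relay.testing

# Any workload works here; resnet-18 chosen only as an example.
sym, params = relay.testing.resnet.get_workload(num_layers=18, batch_size=1)
with tvm.relay.quantize.qconfig(skip_k_conv=0, round_for_shift=False):
    sym = relay.quantize.quantize(sym, params)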

@nhynes approved these changes Jan 4, 2019

@tqchen tqchen self-assigned this Jan 8, 2019

@vinx13 (Member) commented Jan 15, 2019

Currently, if conv2d has bias enabled, we have two adds which can possibly be fused (the other add is from round_bias). This requires some reordering of the quantization ops. We need to cast to int8 before relu. Below is part of VGG as an example.

    %9 = nn.conv2d(%p01, %p1, padding=[1, 1], channels=64, kernel_size=[3, 3], out_dtype="int32") # ty=Tensor[(1, 64, 224, 224), int32]
    %10 = add(%9, %p2) # ty=Tensor[(1, 64, 224, 224), int32]
    %11 = nn.relu(%10) # ty=Tensor[(1, 64, 224, 224), int32]
    %12 = add(%11, 512) # ty=Tensor[(1, 64, 224, 224), int32]
    %13 = right_shift(%12, 10) # ty=Tensor[(1, 64, 224, 224), int32]
    %14 = clip(%13, a_min=-127, a_max=127) # ty=Tensor[(1, 64, 224, 224), int32]
    %15 = cast(%14, dtype="int8") # ty=Tensor[(1, 64, 224, 224), int8]
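For readers following the dump, the %12-%15 tail is the requantize step: add half of the divisor (512 = 2**9), right-shift by 10, clip to the int8 range, and cast. A rough NumPy equivalent, for illustration only:

import numpy as np

def requantize_int32_to_int8(acc, shift=10):
    # Mirrors the add / right_shift / clip / cast tail shown above.
    rounded = (acc + (1 << (shift - 1))) >> shift      # add 512, shift by 10
    return np.clip(rounded, -127, 127).astype('int8')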
@tqchen (Member) commented Jan 15, 2019

@ajtulloch @vinx13 @merrymercy @jroesch Please take another round of review of this PR. I will assume lazy consensus and aim to get the comments addressed and merge it this week if there are no further comments.


ZihengJiang added some commits Jan 16, 2019


@tqchen (Member) commented Jan 16, 2019

@antinucleon @hlu1 can you also take a pass over the PR?

@liangfu (Member) left a comment

While wondering why Jenkins was still silent, I found that the test script sits in tests/python/quantize instead of tests/python/unittest
(see http://ci.tvm.ai:8080/blue/organizations/jenkins/tvm/detail/PR-2116/42/pipeline/144).
The test script is not run by Jenkins at all.
Please move the test script into the unittest folder. @ZihengJiang

@ZihengJiang (Member) commented Jan 17, 2019

@liangfu Thanks for catching this outdated test.

@@ -124,7 +124,7 @@ def _bind_params_by_name(func, params):
     return expr.bind(func, bind_dict)


-def optimize(func, target, params=None):
+def optimize(func, target=None, params=None):

@ZihengJiang (Member) commented Jan 17, 2019

It seems this API changed recently? It breaks some code. @tqchen

@ZihengJiang (Member) commented Jan 18, 2019

@tqchen (Member) commented Jan 18, 2019

@antinucleon @hlu1 @anijain2305 please also help take a look when you have time

@eqy (Contributor) commented Jan 18, 2019

@ZihengJiang sorry, this is a basic question, but is there support for mixed quantization levels? It looks like we currently specify only a global weight and activation precision. Since we can already skip the first k conv layers, it seems that this would be a useful generalization.
