[RFC] Dynamic Shape Support - Graph Dispatching #4118

Closed · 3 tasks
kevinthesun opened this issue Oct 14, 2019 · 17 comments

@kevinthesun (Contributor)

Overview

There are more and more deployment requirements involving dynamic input shapes, such as dynamic-batch CV models and dynamic BERT. While dynamic input graphs are supported in eager mode (PyTorch, TensorFlow Eager, MXNet Gluon) for model development, TVM still only supports models with static input shapes. In this thread I'll discuss a possible solution for ahead-of-time (AOT) compilation of dynamic-shape models.

Let's start by considering support for a single operator with a dynamic shape. TVM already supports tensor expressions with symbolic variables well, which means we have no difficulty expressing a dynamic-shape kernel with the existing compute and schedule system. However, a single schedule cannot achieve the desired performance for all possible values of a symbolic axis. For example, a dynamic-batch conv2d on CUDA can require quite different values of block_z and thread_z for different batch sizes. A possible way to solve this problem is to split symbolic axes into several buckets:
[Figure dg1: splitting a symbolic axis into buckets]
For each bucket, we select a representative kernel which performs well over the corresponding range of the symbolic axis.
I won't focus on this topic in this thread; @icemelon9, @comaniac, and @sxjscience will dive into it in other threads.
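As a reference point, here is a minimal sketch of declaring such a dynamic-shape kernel with a symbolic variable, using the tvm.te API (spellings vary across TVM versions; 2019-era code used tvm.var and tvm.compute). The matmul workload and the names used here are illustrative, not part of the RFC.

import tvm
from tvm import te

n = te.var("n")                                   # symbolic batch dimension
A = te.placeholder((n, 1024), name="A", dtype="float32")
B = te.placeholder((1024, 1000), name="B", dtype="float32")
k = te.reduce_axis((0, 1024), name="k")
C = te.compute((n, 1000),
               lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
               name="C")

# A single schedule compiles for any n, but one fixed tiling or thread binding
# rarely performs well over the whole range of n, which motivates the bucketing above.
s = te.create_schedule(C.op)
func = tvm.build(s, [A, B, C], target="llvm")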

In this thread, we will discuss graph dispatching for dynamic shapes. The bucketing method works well at runtime for operators which don't require layout transformation, such as dense and batch_matmul (as of today's TVM implementation). However, in computer vision models, conv2d usually requires a layout transformation to achieve better performance. Two issues arise when a kernel dispatch function is used at runtime:

  1. A generic layout transform function and a runtime layout tracking system are needed, which introduces a lot of complexity.
  2. Graph tuning is not well defined if a kernel dispatch function is used, which causes performance degradation.

To resolve these issues, instead of a kernel dispatch function we use a graph dispatch function, which splits the input shape of the whole graph into buckets and clones the graph for each bucket:
[Figure dg2: the graph dispatch function selecting a cloned graph per bucket]
The graph dispatch function is a nested IfThenElse block which selects a copy of the graph depending on the actual input shape. Thanks to the functional nature of Relay, we can easily create global functions in the Relay module to represent the different clones of the graph, and share parameters through function calls. The advantages of the graph dispatch function are (see the sketch after this list):

  1. Modifications are done at the Relay level by inserting a dispatch function that calls the original graph functions. No extra change to the VM runtime is required (though kernel shape functions are required anyway).
  2. Parameter sharing is achieved naturally by function calls. No runtime change is required.
  3. Graph tuning can be done for each copy of the graph, and no extra layout tracking system is required.
  4. The AutoTVM dispatch context ApplyGraphBest can be easily extended to support this feature.
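For concreteness, here is a minimal hand-written sketch of such a dispatcher as plain Relay, with two buckets on the batch axis. The helper name, the single boundary of 16, and the trivial stand-in "clones" are illustrative assumptions, not the proposed implementation of Dispatch.

import tvm
from tvm import relay

def make_dispatcher(data, clones, boundaries):
    # Nest relay.If so that clones[i] handles batch < boundaries[i];
    # the last clone covers the open-ended bucket.
    batch = relay.take(relay.shape_of(data), relay.const(0))
    expr = clones[-1](data)
    for clone, bound in reversed(list(zip(clones[:-1], boundaries))):
        expr = relay.If(relay.less(batch, relay.const(bound)), clone(data), expr)
    return expr

# Two trivial global functions standing in for tuned copies of the real graph.
mod = tvm.IRModule()
f_small, f_large = relay.GlobalVar("graph_small"), relay.GlobalVar("graph_large")
x = relay.var("x", shape=(relay.Any(), 16), dtype="float32")
mod[f_small] = relay.Function([x], relay.add(x, relay.const(1.0)))
y = relay.var("y", shape=(relay.Any(), 16), dtype="float32")
mod[f_large] = relay.Function([y], relay.add(y, relay.const(2.0)))

data = relay.var("data", shape=(relay.Any(), 16), dtype="float32")
mod["main"] = relay.Function([data], make_dispatcher(data, [f_small, f_large], [16]))
mod = relay.transform.InferType()(mod)
print(mod["main"])   # a nested If selecting between the two graph clones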

API

We will add a new member function, Dispatch, to the Relay Module:

void Dispatch(const std::string& func_name, const InputShapeDict& input_shape, const PackedFunc& dispatch_func);

This function updates a global function inside the module to become a dispatching block followed by the copied functions. dispatch_func decides how to generate the buckets.

Dispatch Function

A dispatch function maps an input shape dictionary to a nested map: input name, then symbolic axis index, then a list of intervals. For example, consider an input shape dictionary which represents a CV model allowing arbitrary image sizes:

{"data": (1, 3, tvm.relay.Any(), tvm.relay.Any())}

A logarithmic dispatch function returns a dictionary:

{
  "data":
      {
          2: [(1, 2), (2, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, None)],
          3: [(1, 2), (2, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, None)],
      }
}

As a result, the final main function will contain 9 * 9 = 81 copies of the original graph. This introduces a tradeoff between overall performance and the number of generated kernels.

We will provide two pre-defined dispatching functions, splitting uniformly and logarithmically. Users can also define their own customized dispatching function, as sketched below.
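For illustration, a minimal sketch of what such a custom logarithmic dispatcher could look like, following the interface described above (input shape dictionary in, input name to axis to interval list out). The function name, the non-int test for symbolic dimensions, and the cap of 256 are assumptions made for this sketch.

from tvm import relay

def my_log_dispatcher(input_shape_dict, max_dim=256):
    buckets = {}
    for name, shape in input_shape_dict.items():
        axis_map = {}
        for axis, dim in enumerate(shape):
            if not isinstance(dim, int):           # symbolic axis, e.g. relay.Any()
                intervals, low = [], 1
                while low < max_dim:
                    intervals.append((low, low * 2))
                    low *= 2
                intervals.append((max_dim, None))  # open-ended last bucket
                axis_map[axis] = intervals
        if axis_map:
            buckets[name] = axis_map
    return buckets

# Reproduces the dictionary shown above for the CV example:
print(my_log_dispatcher({"data": (1, 3, relay.Any(), relay.Any())}))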

Pruning buckets through boundaries on symbolic axes

In most practical cases, we don't really need the complete range [1, +inf) for a symbolic axis. Boundaries on tvm.var can greatly reduce the number of buckets and thus the number of kernel functions. In this design we don't consider any boundary pruning yet. We might want to leverage the idea in this topic: https://discuss.tvm.ai/t/discuss-embed-more-bound-information-into-var-or-expr/4079.

A working example:

# Imports added for completeness; get_model is assumed to come from the MXNet
# Gluon model zoo (gluoncv.model_zoo works similarly), and ctx was not defined
# in the original snippet, so a CPU context is assumed here.
import numpy as np
import tvm
from tvm import relay
from mxnet.gluon.model_zoo.vision import get_model

input_name = "data"
input_shape = [tvm.relay.Any(), 3, 224, 224]    # dynamic batch dimension
dtype = "float32"
block = get_model('resnet50_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)
# Proposed API: split "main" into per-bucket clones using the logarithmic dispatcher.
mod.dispatch("main", {input_name: input_shape}, tvm.relay.vm.log_dispatcher)
vmc = relay.backend.vm.VMCompiler()
with tvm.autotvm.apply_graph_best("resnet50_v1_graph_opt.log"):
    vm = vmc.compile(mod, "llvm")

ctx = tvm.cpu()
vm.init(ctx)
vm.load_params(params)

# Both batch sizes run through the same compiled module; each falls into its bucket.
data = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
out = vm.run(data)

data = np.random.uniform(size=(4, 3, 224, 224)).astype("float32")
out = vm.run(data)

TODO

  • Relay module dispatch function.
  • Shape functions for most common operators in CV models.
  • Graph tuner changes to tune a dispatched graph.

@tqchen @jroesch @icemelon9 @comaniac @sxjscience @yzhliu @wweic @zhiics @yongwww @antinucleon @junrushao1994

@soiferj (Contributor) commented Oct 23, 2019

Thanks a lot for working on this; it is going to be really impactful, especially for supporting NLP models. I have a couple of questions:

  1. Can you please explain the shape function in a little more detail? What exactly is its purpose? Will it have to be registered for every op?
  2. Some ops, like full, take their shape argument as a constant list. With this change, we could potentially support either a constant list or a relay expression that is unknown at compile time. How would that work? Would the operator definition have to change?

@tqchen (Member) commented Oct 24, 2019

Thanks for the proposal. One high-level comment: ideally we want to keep the module API minimal, and move transformation-like operations to the transform namespace :)

@icemelon (Member)

@soiferj

  1. The shape function is used to compute the output shape(s) of an op at runtime, when they cannot be determined at compile time. And yes, for now, we have to register a shape function for every op to support dynamic shapes.
  2. We could do this, but we need to change the attribute of the full op to let it take non-constant shapes.

@kevinthesun (Author)

@soiferj For the full op, we can change the input shape argument to be a relay.Expr. We use hybrid script to register shape functions, since most of them are not easy to write as tensor expressions. We only add CPU versions of the shape functions, and rely on heterogeneous execution for GPU.
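For reference, here is a rough sketch of what registering a hybrid-script shape function can look like, loosely modeled on the ones in python/tvm/relay/op/_tensor.py. The import paths, the assumption that full takes its shape as the last input, and the exact registration call vary across TVM versions, so treat this as illustrative rather than the actual implementation.

from tvm.te.hybrid import script
from tvm.relay.op.op import register_shape_func

@script
def _full_shape_func(shape):
    # The output shape of full() is just the runtime value of its `shape` input.
    ndim = shape.shape[0]
    out = output_tensor((ndim,), "int64")
    for i in const_range(ndim):
        out[i] = int64(shape[i])
    return out

def full_shape_func(attrs, inputs, _):
    # For a data-dependent op, `inputs` carries input values rather than input
    # shapes; here we assume the requested shape is the last input.
    return [_full_shape_func(inputs[-1])]

# True marks the shape function as data dependent: it needs the input value
# (the requested shape), not just the input's shape.
register_shape_func("full", True, full_shape_func)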

@kevinthesun (Author)

@tqchen Sure. The dispatch function doesn't need to be coupled with relay::Module.

@zhanghaohit (Contributor)

Any progress update on this feature? Thanks

@zhanghaohit (Contributor)

[Figure dg1, quoted from the RFC above]

I'm still curious what will happen if we have conv2d(5, 3, 224, 224). Will we use conv2d(8, 3, 224, 224)? Do we need to do some padding to use the kernel conv2d(8, 3, 224, 224)?

Thanks @kevinthesun for the clarification.

@kevinthesun (Author)

@zhanghaohit This feature is one part of dynamic codegen. We are working on some significant backend features and will update this later.

@cloudhan commented Jun 3, 2020

@kevinthesun Any timeframe?

Off topic, but I want to mention that TensorRT supports dynamic shapes since 7.0. To provide better performance, it supports multiple optimization profiles for different shape ranges. Say your input is 1-D and ranges from 1 to 1024. You can create profiles for whatever shapes you specify, overlapping or non-overlapping, e.g.

(1, 8, 16)    -> this profile is valid for 1 to 16, specifically optimized for 8
(16, 24, 32)
(33, 48, 63)
...

Users can manually select the profile to use. Regarding https://sampl.cs.washington.edu/tvmconf/slides/2019/Jared-Roesch-Haichen-Shen-RelayVM.pdf, I have some questions: representing dynamism is solved by the shape function, but how do I manually represent constrained dynamism, i.e. an input that is only valid from 1 to 1024, or from 16 to 32?
The dynamic shape may vary in n dimensions; how do you plan to efficiently handle a large number of buckets?
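For reference, a minimal sketch of the TensorRT (7.0+) optimization-profile API being described, assuming a 1-D input named "input"; the profile boundaries are the illustrative ones listed above.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

for lo, opt, hi in [(1, 8, 16), (16, 24, 32), (33, 48, 63)]:
    profile = builder.create_optimization_profile()
    # min / opt / max shapes: valid for [lo, hi], specifically optimized for opt.
    profile.set_shape("input", min=(lo,), opt=(opt,), max=(hi,))
    config.add_optimization_profile(profile)
# At inference time the application selects a profile on the execution context,
# e.g. via set_optimization_profile_async in recent TensorRT releases.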

@kevinthesun (Author)

@cloudhan Thanks for the info. @icemelon9 Do we have any work related to dynamic axis ranges?

In terms of codegen, efficiency (and also how to limit the number of buckets while losing as little performance as possible) is indeed one of the difficult parts. We are working on improving some fundamental infrastructure to see how we can achieve a practical workflow in TVM. Currently it is still in the research stage.

@zhanghaohit (Contributor)

Quoting my earlier question: I'm still curious what will happen if we have conv2d(5, 3, 224, 224). Will we use conv2d(8, 3, 224, 224)? Do we need to do some padding to use the kernel conv2d(8, 3, 224, 224)?

Hi @kevinthesun. For this question, do we plan to do padding or resizing? Thanks

@kevinthesun (Author) commented Jun 12, 2020

@zhanghaohit It's still under investigation which option to take, but most likely a static shape will fall into a bucket and call the corresponding kernel, roughly as sketched below.
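To illustrate the likely behaviour (an assumption based on the answer above, not a committed design): each bucket's kernel is compiled with a symbolic axis, so a batch of 5 simply runs the kernel selected for its bucket, with no padding or resizing.

def select_bucket(batch, buckets=((1, 2), (2, 4), (4, 8), (8, 16), (16, None))):
    # Return the half-open interval whose kernel should handle `batch`.
    for low, high in buckets:
        if high is None or batch < high:
            return (low, high)
    raise ValueError("batch size out of range")

print(select_bucket(5))   # (4, 8): conv2d with batch 5 calls the (4, 8)-bucket kernel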

@xutianming (Contributor)

How is this RFC going? Are there any follow-up pull requests?

@tiandiao123 (Contributor)

May I ask whether there are some working code examples for graph dispatch available? Thank you!

@monklof (Contributor) commented Nov 8, 2020

@tiandiao123 this feature is still WIP.

@wxyhv commented Mar 24, 2021

Is there any progress on dynamic input shape inference? I am looking forward to it.

@masahi (Member) commented Jan 9, 2022

I assume this is no longer active.

masahi closed this as completed on Jan 9, 2022