[RFC] NNVMv2 IR - Relay #1673

Closed
jroesch opened this Issue Aug 29, 2018 · 73 comments

@jroesch
Contributor

jroesch commented Aug 29, 2018

[RFC]: Relay, a new high-level IR for TVM

Relay is a new high level intermediate representation (IR) intended to act as v2.0 of NNVM.

Motivation

Computation graphs are a powerful program representation as demonstrated by the first generation of DL frameworks. Most popular frameworks have employed computation graphs as their input, intermediate representation, and execution data structure.

However, as workloads continue to evolve, the design of our high-level IRs needs to evolve to better support the needs of developers and users.

Graph-level challenges such as control flow and sub-graphs have become features that must be natively supported and optimized.

The tight coupling between runtime representation and compile-time representation has limited flexibility and frustrated developers; Relay will decouple the representations.

Finally, we believe the high-level IR must be designed in tandem with the low-level IR, allowing the two layers to communicate during compilation to achieve optimal performance.

Design

The first version of NNVM set out to solve some of these challenges, and we view Relay as a second-generation IR designed specifically for integration into the TVM stack as its input layer. Our goal is to focus on TVM as our primary backend, easing development and maintenance for both TVM developers and current NNVM users, as well as enabling new features.

In order to address the challenges presented above, we designed Relay to build on the things computation graphs are good at (purity, dataflow, compositionality) and improve on the things they struggle with (control flow, sub-graphs, the runtime/compile-time distinction).

Core IR

Relay is a typed, pure, functional IR with a few basic features: functions, if-then-else control flow, recursion, operator and function calls, and variable binding.

We have iterated on Relay's design over the past 8 months. This version represents the culmination of our experiments. This PR does not contain all the pieces of the previous version; instead we focus on introducing the core IR, its associated data structures, and a few integral passes.

The core IR is defined in just a few files:

  • include/tvm/relay/base.h (the base classes and common data)
  • include/tvm/relay/type.h (the type system and all relevant nodes)
  • include/tvm/relay/expr.h (the expression language)

Typing

All Relay programs are typed, similar to more conventional languages such as C++.
A type system allows us to statically (i.e., at compile time) distinguish between different sorts of values. This means we know whether an expression will evaluate to a tensor, a function (e.g., (float32, float32) -> float32), or a tuple (float32, int32). Furthermore, our type system is shape-generic (i.e., polymorphic, in the spirit of templating).
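To make this concrete, here is a toy sketch in plain Python (illustrative only, not Relay's actual implementation) of how a type system can statically distinguish tensors, functions, and tuples, and reject an ill-typed call before anything runs:

```python
# Toy sketch (not Relay's real type nodes): tagged types for tensors,
# functions, and tuples, plus a checker for function calls.

class TensorType:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = tuple(shape), dtype
    def __repr__(self):
        return f"Tensor[{self.shape}, {self.dtype}]"

class FuncType:
    def __init__(self, arg_types, ret_type):
        self.arg_types, self.ret_type = arg_types, ret_type

class TupleType:
    def __init__(self, fields):
        self.fields = fields

def check_call(fn_ty, arg_tys):
    # Statically reject calls whose arguments don't match the signature.
    if len(arg_tys) != len(fn_ty.arg_types):
        raise TypeError("arity mismatch")
    for got, want in zip(arg_tys, fn_ty.arg_types):
        if (got.shape, got.dtype) != (want.shape, want.dtype):
            raise TypeError(f"expected {want}, got {got}")
    return fn_ty.ret_type

f32 = TensorType([], "float32")
add_ty = FuncType([f32, f32], f32)     # (float32, float32) -> float32
print(check_call(add_ty, [f32, f32]))  # Tensor[(), float32]
```

In Relay itself, this role is played by the type nodes defined in include/tvm/relay/type.h.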

Type inference and checking take the place of shape inference in traditional computation-graph-style IRs.

This PR implements type inference and checking for Relay; the code can be found in src/tvm/relay/pass/type_infer.cc, with relevant helper utilities in src/tvm/relay/pass.

Control Flow

Relay adds a notion of control flow to the IR, in the form of a simple if (cond) { true_branch } else { false_branch }. Relay requires that the condition compute a single boolean value controlling which branch is taken. if is an expression in Relay, meaning the result of the entire expression is the result of the branch taken.

We introduce this to add a formal way to distinguish between data flow and control flow without conflating the two in the representation. Because we separate the control signal, we can easily batch a program without affecting control flow.

The definition of control flow can be found in include/tvm/relay/expr.h.
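A minimal interpreter sketch (illustrative node names, not Relay's C++ definitions) shows what it means for if to be a value-producing expression:

```python
# Sketch: an expression-if. Unlike a statement-style if, the whole
# If node evaluates to a value. Node names here are illustrative.

class Const:
    def __init__(self, value):
        self.value = value

class If:
    def __init__(self, cond, true_branch, false_branch):
        self.cond = cond
        self.true_branch = true_branch
        self.false_branch = false_branch

def evaluate(expr):
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, If):
        cond = evaluate(expr.cond)
        # The condition must compute a single boolean value.
        assert isinstance(cond, bool), "condition must be a single boolean"
        branch = expr.true_branch if cond else expr.false_branch
        # The result of the whole If is the result of the taken branch.
        return evaluate(branch)
    raise NotImplementedError(type(expr))

e = If(Const(True), Const(1.0), Const(-1.0))
print(evaluate(e))  # 1.0
```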

Abstraction

Relay supports the definition of functions, which can be used to represent "sub-graphs" (i.e., chunks of reusable computation).

Relay functions are like traditional functions: they have a set of parameters (i.e., placeholders) and a body, a chunk of computation involving the parameters (i.e., a sub-graph). We can build a full network/model by composing functions.
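The idea of building a model by composing reusable sub-graphs can be sketched in plain Python (names like dense, relu, and mlp are illustrative, not Relay's operator API):

```python
# Sketch: "sub-graphs" as reusable functions composed into a model.

def dense(x, w):
    # One output per weight row: dot product of input with that row.
    return [sum(a * b for a, b in zip(x, row)) for row in w]

def relu(x):
    return [max(0.0, v) for v in x]

def mlp(x, w1, w2):
    # The full model is just a composition of sub-graphs.
    return dense(relu(dense(x, w1)), w2)

x = [1.0, 2.0]
w1 = [[1.0, 0.0], [0.0, 1.0]]  # identity weights
w2 = [[1.0, 1.0]]
print(mlp(x, w1, w2))  # [3.0]
```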

Compilation

The Relay IR is designed as a compile time representation of models. The new features are exposed only in Relay's abstract syntax tree, and used for compile time program manipulation. We do not intend to use Relay's IR as a data structure for serious interpretation or execution.

Runtime

These new features increase the expressivity of the current computation model, and one may ask how to execute programs using these features with the existing runtime. Our goal is to introduce Relay as the compiler representation in this PR and reuse the existing runtime, maintaining compatibility on both the frontend and backend. We anticipate a future version of the runtime with native support for Relay's new constructs.

TVM Co-design

We made an effort to model Relay's implementation after TVM and reuse much of the existing infrastructure in order to provide better compatibility between TOPI operators and Relay programs. One big design decision is reusing the TVM node system to expose the Relay language to Python in the style of TVM. Users who are familiar with TVM's expression language should feel comfortable working with the Relay AST's definition in C++ and Python. We also share representations for many data structures; for example, tensor containers (i.e., tvm::runtime::NDArray) and generic attributes are shared between Relay and TVM.

Transitioning from NNVM

We plan on adding a guide for transitioning programs from NNVM to Relay. This is one of the remaining work items before releasing the Relay alpha. The goal is that users can use the Relay operator and builder APIs to construct Relay programs, and we will follow up with a compatibility layer to make transitioning from NNVM smooth.

For an implementation, see #1672.

@jroesch jroesch changed the title from [RFC] Relay to [RFC] NNVMv2 IR - Relay Aug 29, 2018

@jroesch jroesch referenced this issue Aug 29, 2018

Merged

[High level OPT][RFC] NNVMv2 IR - Relay #1672

14 of 20 tasks complete
@masahi


Contributor

masahi commented Aug 29, 2018

hi @jroesch, looks cool. I am hoping that this is TVM's answer to TensorFlow's somewhat awkward tf.while_loop and other control flow constructs.

Regarding this sentence,

Finally we believe the high level must be designed in tandem with the low level IR, allowing for the two layers to communicate during compilation to achieve optimal performance.

does HalideIR correspond to the "low level IR"? This sounds very interesting.

@junrushao1994


Contributor

junrushao1994 commented Aug 29, 2018

We know that the goal of designing a new IR is to enable potential optimizations, so could you be more specific about what kinds of optimization Relay is planning to support?

@tqchen


Member

tqchen commented Aug 29, 2018

@junrushao1994 can you also elaborate on your use-cases and how easy/hard it would be to support them under the current proposal?

@tqchen tqchen added the status: RFC label Aug 29, 2018

@tqchen


Member

tqchen commented Aug 29, 2018

This is one of our major design changes, and we would love participation from the community to review and improve the proposal @dmlc/tvm-team

@tqchen

Member

tqchen commented Aug 29, 2018

Thanks @jroesch for the proposal. I am going to elaborate on my take on it. Note that no design is perfect, and that is why we need everyone's help to work together to evolve the IR.

Specific Technical Points

  • Tight integration with the TVM node system. NNVM was designed pre-TVM, so we did not take the TVM runtime into consideration. This makes registering Python callbacks, traversing the IR, and general interaction hard in current NNVM. The Relay refactor directly brings this tight integration: now every IR node can be visited and inspected, and we can prototype easily in both Python and C++. This feature has nothing to do with the IR spec but is nevertheless very significant for developers.
  • Shape/dtype integrated as TensorType, reusing symbolic integers (tvm::Expr). One advantage of the TOPI description is that it can support symbolic integers in many cases, so code can be generic in one dimension (e.g., batch); TensorType reuses this. This keeps things consistent with TOPI and allows declaring programs with specific input sizes like (n, 128, 128).
  • Control flow: if, for (via tail recursion), and function recursion.
  • Separation of compiler IR and runtime. This is good for two reasons:
    • We do not have to push too many compiler considerations into the runtime, keeping the runtime minimal
    • We can keep the current graph runtime.

Some Possible Points to Discuss

These are things that popped into my head; feel free to add more.

  • If anything needs to be clarified, please say so, since we need to make this accessible
  • How to support a specific pass, and how easy/hard it is in the current proposal
  • Specific use-case scenarios (e.g., transformer) and what we need for them
  • What constitutes a minimum runtime for inference
  • Any considerations needed to build a JIT runtime for training.
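The "for (via tail recursion)" point above can be illustrated with a small sketch (plain Python, not Relay syntax): a bounded loop becomes a self-recursive function whose recursive call is in tail position, so no mutable loop variable is needed.

```python
# Sketch of "for via tail recursion": the loop
#   s = 0
#   for i in range(n): s += i
# becomes a self-recursive function with the call in tail position.

def for_loop(i, n, acc):
    if i == n:                           # loop exit: return the accumulator
        return acc
    return for_loop(i + 1, n, acc + i)   # tail call = next iteration

print(for_loop(0, 5, 0))  # 10
```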
@junrushao1994

Contributor

junrushao1994 commented Aug 29, 2018

I would say the design itself is perfect, addressing almost all the problems it targets. Yes, this is the first systematic approach to addressing the lack of Turing-completeness in deep learning frameworks, rather than quick hacks like TensorFlow's while_loop. Also, the implementation is elegant and I love it.

So please allow me to take the liberty of raising some concerns that might fall outside the target of the current design. Briefly, I will comment on the following aspects:

  • frontend requirements from deep learning practitioners
  • optimization techniques
  • runtime
  • possible opportunities for graph-level auto-tuning
@junrushao1994

Contributor

junrushao1994 commented Aug 29, 2018

Part 1: Frontend

Sometimes I prefer to think that worse is better, and I suspect it might be somewhat restrictive to ask deep learning practitioners to do "the right thing". After all, we cannot assume each user has a PhD in PL. Here are several things we might need to consider in the future.

Container types

I would love to discuss the necessity of having containers like List<T>, Dict<K, V>. This concern arises from some deep learning models in NLP, for example self-attention, which is used in sequence-generation tasks.

def self_attentive_generator(inputs):
    states = initialize_states()
    outputs = []
    while True:
        prev_outputs = concat(outputs)
        context_vector = self_attention(prev_outputs, inputs, states)
        step_output, states = DecoderStep(states, prev_outputs)
        outputs.append(step_output)
        if some_condition:
            break
    outputs = concat(outputs)
    return outputs, states

We already have Tuple<T_1, ...> in Relay, which is great. We could certainly ask users to convert this to FP style so that everything is side-effect free, and performance would not be much affected, but would we still be able to reduce the memory footprint in this case?

Supporting Dict<K, V> is a somewhat unusual requirement, but it seems to me that states in RNNs are often represented using a dict (needs confirmation from @szha). I guess we could probably replace this with something like a namedtuple.

Other unusual use cases include Fast WaveNet, PixelCNN++, and beam search in decoding, many of which require users to write a normal program that manipulates containers. This is easier for common users to write in an imperative style rather than FP. I am a big fan of FP languages like Haskell, but I am still a little worried about

  1. whether the market likes this style of programming, and
  2. whether this IR could support such optimizations.

Side effects, e.g. I/O and random monad

For example, in dropout, we definitely need to introduce something with randomness.

This is just a reminder that this is something we should take into consideration.

Whether to incorporate context into the type system

I am also wondering whether ctx could be put into the type system. This is a kind of co-design with TVM.

Primitives related to distributed system

This is somewhat off topic, but I am personally interested in seeing ideas about how we could handle these situations.

A. SyncBN: this introduces allreduce in the forward pass. Of course, ugly things like explicit forward/backward will never exist in Relay, which is great, but it seems that we should introduce synchronization primitives into the design.

B. Timeout: it is common practice on an edge device to store a smaller DL model offline, which is used to produce a coarse result; in the meantime, an online model on remote servers computes some part of the result and sends it back. In case the online model fails to send the fine-grained result in time, the edge device uses the offline result. This is another kind of side effect: should we handle it in a deep learning framework, or leave it to others?

@tqchen


Member

tqchen commented Aug 29, 2018

+1 on making things accessible (worse is better); this is exactly what we should push for.

Making most parts functional makes differentiation easy and allows building things around it. It is important to be able to support mutation in some of the outer loops, where differentiation is not needed. Having such a clear distinction is important.

The List/Dict programming style is more like a multi-stage program, where the graph is staged in the data structure before we call backward. While it is definitely possible to do this via imperative autodiff, it is an interesting question whether we can desugar it into some form of functions. Note that this is different from mutation (because differentiation is necessary).

Distributed sync primitives and timeout can likely be implemented via special core operators (like add and sub), and there should be no problem handling this: making things more side-effect free will make the distributed parts easier.
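As a sketch of the desugaring question (plain Python with hypothetical step/done callbacks, not a real Relay pass), the staged append-loop from the earlier example can become a pure recursive function that threads the list through as an argument instead of mutating it:

```python
# Sketch: desugaring a staged append-loop into a pure function.
# Instead of mutating `outputs`, each step returns an extended list.

def generate(step, done, state, outputs):
    if done(state):
        return outputs, state
    out, state = step(state, outputs)
    return generate(step, done, state, outputs + [out])  # no mutation

# Hypothetical "decoder" that emits its state and counts down to zero.
outs, final = generate(
    step=lambda s, prev: (s, s - 1),
    done=lambda s: s == 0,
    state=3,
    outputs=[],
)
print(outs)  # [3, 2, 1]
```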

@jroesch


Contributor

jroesch commented Aug 29, 2018

Frontend

@tqchen and @junrushao1994, in some other experimentation we have worked on rewriting imperative Python code into the IR directly. This should allow users to write imperative-looking code which is translated into a form that is easy to optimize and perform automatic differentiation on. In general we could expose an API which looks destructive to the user but is actually pure under the covers.

Side effects, e.g. I/O and random monad

I would like to track effects as an operator attribute so we can use this information during optimization. Roughly, we will respect effects when transforming programs; for example, no reordering over I/O or stateful computations.
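A toy sketch of what tracking effects as an operator attribute could look like (the attribute name and the op table here are hypothetical, not TVM's operator registry):

```python
# Sketch: effects recorded as a per-operator attribute, consulted by
# a transform before reordering. "stateful" is an illustrative name.

OP_ATTRS = {
    "add":     {"stateful": False},
    "dropout": {"stateful": True},   # randomness is an effect
    "print":   {"stateful": True},   # I/O is an effect
}

def can_reorder(op_a, op_b):
    # A pass may swap two adjacent ops only if neither has effects.
    return not (OP_ATTRS[op_a]["stateful"] or OP_ATTRS[op_b]["stateful"])

print(can_reorder("add", "add"))      # True
print(can_reorder("add", "dropout"))  # False
```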

@zheng-da


Contributor

zheng-da commented Aug 29, 2018

@tqchen, since you mentioned the transformer, I'm curious how Relay is going to handle this kind of workload, where shapes change at runtime and can't be statically inferred. Is there a plan to support it in TVM and Relay?

@junrushao1994


Contributor

junrushao1994 commented Aug 29, 2018

@zheng-da in the standard way.

@tqchen, as you mentioned transformer, I'm curious how Relay is going to handle this kind of workloads where shapes changes at runtime and can't be statically inferred. Is there a plan to support it in TVM and Relay?

@tqchen


Member

tqchen commented Aug 29, 2018

TVM supports variable-length input via symbolic variables, so in theory we can build an op that takes input shape (n, 128), where n is a symbolic variable. Relay also adopts this in its type system, which allows handling cases with a fixed number of dimensions but symbolic shape. How to do generic code generation is another question we can follow up on, but the IR itself can handle shape inference of this kind.
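A toy sketch of the kind of reasoning involved (plain Python, not Relay's type inference): unifying a declared symbolic shape against a concrete one binds the symbolic variable, while mismatched concrete dimensions are rejected.

```python
# Sketch: shape inference with a symbolic dimension. Unify the declared
# type (n, 128) against a concrete input (32, 128), binding n = 32.

class Sym:
    def __init__(self, name):
        self.name = name

def unify_shape(declared, concrete, bindings=None):
    bindings = {} if bindings is None else bindings
    if len(declared) != len(concrete):
        raise TypeError("rank mismatch")
    for d, c in zip(declared, concrete):
        if isinstance(d, Sym):
            # Bind the symbol, or check it is bound consistently.
            if bindings.setdefault(d.name, c) != c:
                raise TypeError(f"{d.name} bound inconsistently")
        elif d != c:
            raise TypeError(f"dimension mismatch: {d} vs {c}")
    return bindings

n = Sym("n")
print(unify_shape((n, 128), (32, 128)))  # {'n': 32}
```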

@junrushao1994


Contributor

junrushao1994 commented Aug 29, 2018

@jroesch I see. That looks cool.

Frontend

@tqchen and @junrushao1994 in some of other experimentation we have worked on rewriting imperative Python code into the IR directly. This should allow users to write imperative looking code which is translated into a form that is easy to optimize and perform automatic differentiation on. In general we could expose an API which looks destructive to the user but is actually pure under the covers.

Side effects, e.g. I/O and random monad

I would like to track effects as an operator attribute so we can use this information during optimization. Roughly we will respect effects when transforming programs, for example no reordering over I/O or stateful computations.

@zheng-da


Contributor

zheng-da commented Aug 30, 2018

Do I understand it correctly?

Suppose I have an operator that increases the size of all dimensions of the output by 1 compared with the input, i.e., if the input shape is (x, y), the output shape will be (x+1, y+1).
Now if one of the input dimensions is unknown, e.g., the shape is (n, 128) as you mentioned, the output shape will be inferred as (n+1, 129). In other words, the system can infer the size of all known dimensions and capture the relation of the unknown dimensions.

Now, I have this operator op described as above (it increases the size of all dimensions by 1) and use it in a while loop as follows:

i = 0
data = mx.nd.random.uniform(shape=(10, 10))
out = mx.nd.zeros((10, 10))
while sum(data[i]) > 1:
    out = op(out)
    i = i + 1

Can the shape of out still be inferred, maybe expressed in a symbolic way?
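To make the question concrete, here is an illustrative sketch (plain Python, not something Relay implements): if op runs k times, the most precise static answer is the symbolic shape (10 + k, 10 + k), where k depends on runtime data.

```python
# Sketch: the loop-dependent shape expressed symbolically in the
# iteration count k. op maps (x, y) -> (x+1, y+1).

def op_shape(shape):
    return tuple(d + 1 for d in shape)

def shape_after(shape, k):
    # Closed form of k applications of op_shape.
    return tuple(d + k for d in shape)

# The closed form agrees with repeated application.
assert op_shape(op_shape((10, 10))) == shape_after((10, 10), 2)
print(shape_after((10, 10), 3))  # (13, 13)
```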

@yzhliu


Member

yzhliu commented Aug 30, 2018

+1 for incorporating context or target into the type system, so that it can directly support heterogeneous runtime.

Shall we provide an approach to convert Relay IR to a graph representation (where possible)? I'm thinking about passing subsets to accelerators like TensorRT.

@junrushao1994


Contributor

junrushao1994 commented Aug 30, 2018

@zheng-da If we convert your code to a functional one, then yes; this is what being "shape generic" means.

Do I understand it correct?

Suppose I have an operator that increases the size of all dimensions of the output by 1 compared with the input, i.e., if the input shape is (x, y), the output shape will be (x+1, y+1).
Now if one of the input shape is unknown, e.g., (n, 128), as you mentioned, the output shape will be inferred as (n+1, 129). In other words, the system can infer the size of all known dimensions and capture the relation of the unknown dimensions.

Now, I have this operator op described as above (it increases the size of all dimensions by 1) and use it in a while loop as follows:

i = 0
data = mx.nd.random.uniform(shape=(10, 10))
out = mx.nd.zeros((10, 10))
while (sum(data[i]) > 1):
    out = op(out)
    i = i + 1

Can the shape of out still be inferred, maybe expressed in a symbolic way?


Contributor

zheng-da commented Aug 30, 2018

@junrushao1994 I don't know what it means. Could you show me how the shape of out is expressed? Also, how is the shape used after it's inferred?


Contributor

junrushao1994 commented Aug 30, 2018

@zheng-da Sorry for the confusion. There are two steps: the first step is to convert the code to a purely functional one, which means you use pattern matching + recursion to substitute the loop. The second step is to look at the function; you will see that the function representing the loop body is generic.

@junrushao1994 I don't know what it means. Could you show me how the shape of out is expressed? Also, how is the shape used after it's inferred?


Contributor

junrushao1994 commented Aug 30, 2018

@zheng-da It is called MGU (most general unifier; @jroesch correct me if it is not). The shape of out is an IncompleteType somewhere in the code (I briefly glanced through the code, but don't remember the exact name). A simple union-find set perfectly solves this problem.

@junrushao1994 I don't know what it means. Could you show me how the shape of out is expressed? Also, how is the shape used after it's inferred?
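For readers unfamiliar with the idea, the union-find approach to unifying shape variables can be sketched in a few lines. This is illustrative only, not Relay's actual IncompleteType machinery: each dimension is either concrete or an unknown variable, and unification merges them.

```python
class Dim:
    """A dimension: concrete if `value` is set, an unknown variable otherwise."""
    def __init__(self, value=None):
        self.value = value   # int if known, None if unknown
        self.parent = self   # union-find parent pointer

def find(d):
    # Path-compressing find: follow parents to the set's representative.
    while d.parent is not d:
        d.parent = d.parent.parent
        d = d.parent
    return d

def unify(a, b):
    a, b = find(a), find(b)
    if a is b:
        return a
    if a.value is not None and b.value is not None:
        if a.value != b.value:
            raise TypeError(f"shape mismatch: {a.value} vs {b.value}")
        a.parent = b
    elif a.value is not None:
        b.parent = a          # keep the concrete side as representative
    else:
        a.parent = b
    return find(a)

# (n, 128) unified with (m, 128): n and m become the same variable.
n, m = Dim(), Dim()
unify(n, m)
assert find(n) is find(m)
# Later unifying n with a concrete dim makes the whole set concrete.
unify(n, Dim(10))
assert find(m).value == 10
```

Once two unknown dimensions are unified, learning the concrete size of either one resolves both, which is exactly the behavior wanted for the (n, 128) examples above.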


Member

tqchen commented Aug 30, 2018

In my understanding, type inference was designed to understand things at compile time, so in the case of random and dimension expansion, it is impossible to decide the final dimension, and inference will likely return things like an incomplete type, or use a shape-of node that defers things to runtime. The fixed-dimension symbolic shape case is the most common one where we can still benefit from such static info.


Contributor

zheng-da commented Aug 30, 2018

Agreed. The fixed-dimension symbolic shape is very useful. I think mxnet can greatly benefit from it. Could you point me to the code in TVM that does it?

I think my original question was whether there is a plan to support the case where the shape really can't be inferred in TVM and Relay. For example, in mxnet, I'm thinking of doing something like this: apache/incubator-mxnet#12400. Of course, my solution is hacky.


Member

tqchen commented Aug 30, 2018

Supporting runtime dynamic shapes is really a problem for the runtime, since in terms of the high-level IR we can always introduce a type (incomplete or dynamic) that indicates it can be anything. And we will need to add a JIT in the runtime to build and run the stages when we have more information.


Contributor

junrushao1994 commented Aug 30, 2018

@tqchen I don't think it is a big deal for the runtime if we support only PackedFuncs wrapping libs like cuDNN. Many passes could be implemented using sparse conditional constant propagation, or other very mature compiler techniques.

However, if we want auto-tuning in such scenario, it could be cutting edge research.

Supporting runtime dynamic shape is really a problem of the runtime, since in terms of high level ir, we can always introduce a type(incomplete or dynamic) that indicate it can be anything. And we will need to add JIT in runtime to build and run the stages when we have more information
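One way to picture the JIT-at-runtime idea quoted above (a hypothetical sketch, not TVM's actual runtime API): compile a shape-specialized kernel the first time a concrete shape is seen, and cache it for reuse.

```python
# Cache of compiled kernels, keyed by (op, concrete shape).
compiled_cache = {}

def compile_for_shape(op_name, shape):
    # Stand-in for an expensive shape-specialized build (e.g. a TVM kernel).
    def kernel(x):
        return [v + 1 for v in x]   # toy "op": elementwise add-one
    return kernel

def run(op_name, x):
    key = (op_name, len(x))          # shape key: here just the length
    if key not in compiled_cache:    # JIT: build on first encounter
        compiled_cache[key] = compile_for_shape(op_name, len(x))
    return compiled_cache[key](x)

print(run("add_one", [1, 2, 3]))     # compiles for shape (3,), prints [2, 3, 4]
print(run("add_one", [4, 5, 6]))     # cache hit: reuses the compiled kernel
```

The high-level IR only needs to say "this dimension is dynamic"; the runtime defers the expensive specialization until the shape is actually known.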


Member

tqchen commented Aug 30, 2018

To follow up on @yzhliu's comment on whether ctx should be part of the type system.

We don't have to enforce everything as part of the type in order to do such optimization. Context assignments (or machine assignments) in a distributed setting can also be presented as column metadata (like in NNVM). We have quite a lot of cases like this: alternative data layouts, distributed machine assignments, etc.

The possible pros/cons of the type system vs the additional metadata are

  • Putting things in the type helps cases when we need frequent pattern matching on them.
  • Making things a mandatory part of the type, however, will force the developer to always keep them in mind.

So in the case of shape vs context

  • Shapes are needed in many passes.
  • Context, on the other hand, is normally not needed in most passes: once we build a partition, most optimizations do not need to consider context.

For this reason, we can argue that context is preferred not as part of the type, but more like metadata of, say, a function or a call.


Member

tqchen commented Aug 30, 2018

@masahi I think in here low-level IR refers to the tensor expression part of TVM, including autoTVM, topi, compute primitives.


Contributor

masahi commented Aug 30, 2018

@tqchen thanks, makes sense. Those are in turn based on HalideIR, so in some sense HalideIR is the foundation for everything.


Contributor

zheng-da commented Aug 30, 2018

@tqchen Another question is how to integrate with backend libraries in Relay. Maybe this isn't really a Relay question, but it's something we need to consider after TVM moves to Relay. I suppose Relay is good at pattern matching. Is it easy to take out the matched pattern and put it somewhere (maybe in an operator) to invoke TensorRT? What do you think about supporting stateful operators, both from the perspective of Relay and TVM? Having a stateful operator may make it easier for us to integrate with TensorRT.


Contributor

junrushao1994 commented Aug 31, 2018

@zheng-da Relay has some notion of tracking effects, so why not put this arbitrary stuff inside something like a PackedFunc?

Update: as @MarisaKirisame mentioned, I am wrong. Please just ignore this reply.

@junrushao1994 when I look at the type system (in the Relay paper), it supports base types, shapes, Tensor, function, type, reference, and tuple. Do you suggest representing the data structure for any arbitrary external library with the Relay type system? For example, MKLDNN requires data structures like mkldnn::memory::primitive_desc. It's a class that contains a std::shared_ptr. It's probably doable to store this data structure in Relay, but it might be more convenient to support something like an OpaqueType for arbitrary operator states.

The other problem is that these external libraries may change their state after each invocation. However, we don't know if they really change, or how they change, their states. Therefore, the operator can't be purely functional. Does Relay need to deal with this?


Contributor

junrushao1994 commented Aug 31, 2018

@kevinthesun Hey Yao, could you kindly share more thoughts about what information you think must be put into the type system? It will be very helpful!

This is a good chance to look at the data layout system. I think @yzhliu is currently working on refactoring layout in TVM: https://discuss.tvm.ai/t/datalayout-structure/80

To enable graph-level optimization, every operator will require layout information. Maybe we can consider adding it to the Relay type system.


Contributor

MarisaKirisame commented Aug 31, 2018

@junrushao1994 we haven't settled on an effect system yet. We would love to know what the actual use cases are!


Contributor

junrushao1994 commented Aug 31, 2018

@MarisaKirisame This is the concern raised from colleagues working on external library integration, as mentioned by @zheng-da.

@junrushao1994 we haven't settled on an effect system yet. We would love to know what the actual use cases are!


Contributor

kevinthesun commented Aug 31, 2018

@junrushao1994 For NNVM/Relay, the layout information is mainly used to insert a layout transformation op when necessary. Currently this is achieved by the FCorrectLayout attribute. It's like an "InferLayout" attr. We might want to preserve the latest valid layout of each op, so that we can easily fall back to the last valid layout when the new layout pass is illegal for some ops. The logic should be similar to the current NNVM implementation, but we might be able to manage it better in Relay.
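As a rough illustration of the kind of pass described here (hypothetical; not the actual FCorrectLayout implementation), a minimal version over a linear chain of ops might look like:

```python
def insert_layout_transforms(ops):
    """ops: list of (name, preferred_layout) pairs.
    Inserts an explicit layout_transform op whenever the producer's
    output layout differs from the consumer's preferred input layout."""
    result = []
    prev_layout = None
    for name, layout in ops:
        if prev_layout is not None and layout != prev_layout:
            result.append(("layout_transform", f"{prev_layout}->{layout}"))
        result.append((name, layout))
        prev_layout = layout
    return result

chain = [("conv2d", "NCHW"), ("conv2d", "NCHW4c"), ("softmax", "NCHW")]
for op in insert_layout_transforms(chain):
    print(op)
# ('conv2d', 'NCHW')
# ('layout_transform', 'NCHW->NCHW4c')
# ('conv2d', 'NCHW4c')
# ('layout_transform', 'NCHW4c->NCHW')
# ('softmax', 'NCHW')
```

A real pass would then try to minimize the number of inserted transforms; making the transforms explicit first is what lets a later optimization reason about (and cancel) them.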


Contributor

junrushao1994 commented Aug 31, 2018

@kevinthesun IMO the concern is valid, @tqchen what do you think?

@junrushao1994 For NNVM/Relay, the layout information is mainly used to insert a layout transformation op when necessary. Currently this is achieved by the FCorrectLayout attribute. It's like an "InferLayout" attr. We might want to preserve the latest valid layout of each op, so that we can easily fall back to the last valid layout when the new layout pass is illegal for some ops. The logic should be similar to the current NNVM implementation, but we might be able to manage it better in Relay.


Member

tqchen commented Aug 31, 2018

Let us hear the opinions and possible proposals from everyone in the community :) At this moment, for things that are in flux, we don't necessarily need decisions. Instead we would love to see clear, actionable options.


Contributor

grwlf commented Sep 1, 2018

Hi. This is a very interesting discussion. I have some questions related maybe more to the Relay paper than to this exact RFC. AFAIK they are closely connected, but please let me know if we have a better place to discuss the article.

Python typing

I'm glad that you are going to employ Python typing facilities! In the article, type specifications look like def lenet(x: Tensor[Float, (1, 28, 28)]) -> Tensor[Float, 10]. Since python doesn't specify the exact typechecking algorithm, I assume we are going to extract type information from Python and pass it to C++ core to do shape math during graph compilation passes. Am I correct with that?

Next, if we use other typing tools (probably mypy) in the same project, then two things are going to happen: 1) mypy will see Relay types while performing before-runtime typechecking. 2) Relay will see mypy types at runtime.

AFAIK, mypy will not accept Relay types as they are, it needs some stubs/definitions or maybe it needs to know which places in code to avoid looking at. Do you plan to maintain some degree of toolset compatibility?


Contributor

junrushao1994 commented Sep 1, 2018

Another thing I am concerned about is auto-batching (https://arxiv.org/abs/1705.07860), which is an important feature that could potentially speed up irregular workloads. This involves some scalar arithmetic, indexing, and concatenation.

Relay could definitely represent this easily. Just curious: if you want to support auto-batching, would you prefer to write it in the frontend language, parse it out, and represent it with Relay, or write a low-level kernel or a PackedFunc?


jekbradbury commented Sep 1, 2018

Autobatching is probably best implemented as a transformation on Relay IR? That’s pretty much what’s working for PyTorch.


Member

szha commented Sep 1, 2018

On the other hand, DyNet seems happy with putting it in the scheduler and batching whatever is available to run. I think whether IR transformation is better depends on the granularity of the graph, as well as the ability to discover reusable sub-graphs.


Member

szha commented Sep 1, 2018

Oh, I forgot: it also depends on whether you need to keep the dependency graph for AD or not.


jekbradbury commented Sep 1, 2018

DyNet’s original approach (essentially a runtime scheduling pass) doesn’t scale well with accelerators. But the next version of DyNet will likely use Cavs, which solves that by dramatically reducing the complexity of the scheduling problem at the cost of a new user-visible level of abstraction (an annotated static subgraph they call a vertex). In the long run a combination of IR transformation for SPMD-style batching and Cavs-like scheduling for more highly divergent cases seems like it might be the right approach.


Member

szha commented Sep 1, 2018

Indeed, a combination of these two approaches will likely cover more cases.


Contributor

junrushao1994 commented Sep 3, 2018

A quick question: do you want to strictly enforce a homogeneous type for lists, or allow some generics? For example, will all tensors inside a list be of the same shape?

Update: seems like a silly question...


Contributor

MarisaKirisame commented Sep 3, 2018

Very probably. The generic-across-shapes part will probably be done using existential types.


Contributor

joshpoll commented Sep 4, 2018

@grwlf I'm one of the people working on the Python frontend. You’re correct. Type information is passed directly from the frontend to the C++ compiler. We can handle mypy type annotations in Relay code, since those annotations are encoded in the Python AST that we compile to Relay. I’m not sure that Relay will ever encounter mypy types at runtime, but I may be missing something.

The Relay IR nodes (including types) are mirrored in Python via the TVM FFI, allowing us to construct Relay programs in Python. As a side effect, mypy can correctly handle some Relay types; however, it won’t be able to handle more complicated types like tensors. PEP 484 included @no_type_check_decorator, which allows us to disable mypy type checking for code annotated with the relay decorators. In fact we check the frontend with mypy, so you should expect good compatibility between Relay and mypy.
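For the curious, the PEP 484 mechanism mentioned above can be demonstrated with a toy decorator. `relay_compile` here is a hypothetical stand-in, not TVM's API; note also that `typing.no_type_check_decorator` is deprecated as of Python 3.12.

```python
from typing import no_type_check_decorator

@no_type_check_decorator
def relay_compile(fn):
    # A real frontend would parse fn's AST and annotations into Relay IR;
    # here we just tag the function. The outer no_type_check_decorator
    # additionally marks the result so type checkers skip its body.
    fn.is_relay = True
    return fn

@relay_compile
def lenet(x):   # annotations like Tensor[Float, (1, 28, 28)] would go here
    return x

assert lenet.is_relay        # our tag survived
assert lenet(5) == 5         # the function still runs normally
```

The point is that the same decorator both hands the function to the Relay frontend and tells mypy not to type-check it against Python's own rules.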

Contributor

jroesch commented Sep 4, 2018

Great feedback everyone. Sorry for being absent the last couple of days, right after I shipped the initial version of the PR I had to move apartments and have been hastily moving things from one apartment to the next and trying to get my desk set back up so I could work 😄.

I want to address a bunch of the feedback in order.

Typing dynamic dimensions

This is referring to @zheng-da, @tqchen and @junrushao1994's line of questions and discussions.

The type system design is general enough to support writing the functions you describe. For example, we can write down a function like (n, s, k) -> (s + 1, k * n) (where n, s, and k are symbolic); the problem is if we want to compile that function ahead of time without knowing n, s, and k. We need to generate code which works over arbitrary dimensions. In the current system we can still write such code; we just rely on knowing n, s, and k at compile time.

Fortunately in deep learning many input dimensions are known at compile time allowing us to efficiently compile that code today.

The bigger issue is typing a loop which applies that function an unknown number of times. If you can convert the code to functional form like @junrushao1994 suggested, then it's possible to assign a symbolic relationship between the loop condition and the type. Realistically, many of these need to be dynamic, which requires us to introduce a concept of an unknown dimension (which we do not currently support) and makes code generation significantly more complex if we want efficient kernels.
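The symbolic shape relation described above can be sketched in plain Python (illustrative only, not Relay's type machinery): known dimensions fold to constants, while unknown ones remain symbolic expressions.

```python
def shape_rel(n, s, k):
    """The shape relation (n, s, k) -> (s + 1, k * n)."""
    return (s + 1, k * n)

# With all dims known at compile time, the output shape is concrete:
assert shape_rel(2, 4, 8) == (5, 16)

# A tiny symbolic stand-in for an unknown dimension:
class Sym:
    def __init__(self, expr): self.expr = expr
    def __add__(self, o): return Sym(f"({self.expr} + {o})")
    def __rmul__(self, o): return Sym(f"({o} * {self.expr})")
    def __repr__(self): return self.expr

n = Sym("n")
print(shape_rel(n, 4, 8))   # prints (5, (8 * n)): known dims fold, unknown stay symbolic
```

Compiling ahead of time then amounts to deciding what to do with the residual expressions like `(8 * n)` that cannot be folded.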

Graph Extraction from Relay

It is straightforward enough to extract a graph from Relay; I'm working on porting the Relay-to-TVM compiler to the newest version of Relay on the currently open PR. For example, in the pure dataflow case you would only accept Let and Call nodes; a program containing anything else would be rejected (this is roughly a basic block). I think extraction of graphs would be relatively easy to set up, and ideally the next-generation runtime would allow us to generate and schedule code like this.

ONNX as IR and interchange format

There was some discussion about ONNX and a unified interchange format, but the problem is that ONNX is insufficiently expressive. For example, PyTorch can't export dynamic graphs of any form due to ONNX limitations. There are also issues like full support for broadcasting operations in the IR.

Data Structures

I think data structures are interesting; if we can assign them semantics ourselves, optimizing them will be much easier. For example, if we use data structures in a linear fashion, it is possible to heavily fuse/elide allocations. I think we should follow up on this after the initial IR lands.

On Optimizations

In terms of doing lower-level optimizations, my vision has always been that most users will interact with the IR before the "suffix" optimizations are applied. That is, the user's perception will be of a high-level functional graph, mostly free of effects, which can be reasoned about algebraically like most computation graphs. The "suffix" phases will do things like concretize the operators, do memory planning, convert from functional operations to destructive ones, etc., and then we will generate low-level VM instructions/LLVM and a set of customized TVM operators. This is the place where we could break the uniform constant representation, schedule scalar code differently, etc.

Data Layout

@kevinthesun We have talked about extending typing with data layout, but it wasn't clear how to insert it into the type system. Last time @tqchen and I talked, we were planning on storing it as an attribute. My thought was to have a layout pass which inserts explicit layout-change operators everywhere and then tries to minimize the number of layout transforms. I could imagine employing some search here; for example, maybe it's good to keep some data around in both layouts? I would be interested in hearing more.

Pass Manager

@zhiics The older version of Relay had a basic version of the pass manager. I think one really interesting use case for it is learning pass ordering. Pass ordering traditionally in compilers has been something that is relatively fixed, can have substantial impact on performance, and the optimal ordering is often application specific. I think having the ability to apply auto-tvm style techniques here would be great, especially as the number of passes grow.
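A minimal sketch of such a pass manager, where the pass ordering is plain data that a tuner could permute (hypothetical; not Relay's actual design):

```python
class PassManager:
    def __init__(self):
        self.passes = []

    def register(self, fn):
        # Register a pass; default ordering is registration order.
        self.passes.append(fn)
        return fn

    def run(self, module, order=None):
        # `order` lets an auto-tvm-style tuner try a different permutation.
        for p in (order or self.passes):
            module = p(module)
        return module

pm = PassManager()

@pm.register
def fold_constants(mod):        # toy pass: fold a known constant expression
    return mod.replace("1 + 2", "3")

@pm.register
def dead_code_elim(mod):        # toy pass: drop an unused statement
    return mod.replace(" unused;", "")

print(pm.run("x = 1 + 2; unused;"))   # prints: x = 3;
```

Because `run` takes the ordering as an argument, searching over pass orders reduces to searching over permutations of a list, which is exactly the shape of problem auto-tvm-style tuning handles.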

Auto-batching

We have discussed implementing auto-batching as a compiler transform, but in general it sounds like a mixture of scheduling and program transform makes sense. It seems like there are many examples of this static/dynamic split in CS and especially compilers. One simple example is stack-frame/stack, we can simply compute the amount of storage for a single frame but in order to support recursion we allow some of the work to happen dynamically. We have discussed a similar strategy for the runtime, and I could imagine trying to heavily optimize static sub-graphs where possible and then bucket then dynamically during runtime. I'm not an expert in auto-batching so I would love to hear more ideas from the community.

MyPy and Relay

@joshpoll is right about the type annotations we have setup them up to not interfere with MyPy, if MyPy gained support to type numpy computations we could probably design annotations which work both in MyPy and Relay. The previous version of Relay (not the one in the PR) had the required stubs to make everything work.

Contributor

jroesch commented Sep 4, 2018

Great feedback everyone. Sorry for being absent the last couple of days, right after I shipped the initial version of the PR I had to move apartments and have been hastily moving things from one apartment to the next and trying to get my desk set back up so I could work 😄.

I want to address a bunch of the feedback in order.

Typing dynamic dimensions

This is referring to @zheng-da, @tqchen and @junrushao1994's line of questions and discussions.

The type system design is general enough to support writing the functions you describe. For example, we can write down a function like (n, s, k) -> (s + 1, k * n) (where n, s, and k are symbolic); the problem is compiling that function ahead of time without knowing n, s, and k, since we would need to generate code which works over arbitrary dimensions. In the current system we can still write such code; we just rely on knowing n, s, and k at compile time.

Fortunately in deep learning many input dimensions are known at compile time allowing us to efficiently compile that code today.

The bigger issue is typing a loop which applies that function an unknown number of times. If you can convert the code to functional form like @junrushao1994 suggested, then it's possible to assign a symbolic relationship between the loop condition and the type. Realistically, many of these need to be dynamic, which requires us to introduce a concept of an unknown dimension (which we do not currently support) and makes code generation significantly more complex if we want efficient kernels.
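As a toy illustration of the shape relation mentioned above, here is a plain-Python sketch (all names are mine, nothing from the Relay API) of resolving (n, s, k) -> (s + 1, k * n) when every dimension is a known integer, and refusing to specialize when one is symbolic:

```python
def shape_rel(n, s, k):
    # The relation (n, s, k) -> (s + 1, k * n) from the discussion above.
    return (s + 1, k * n)

def try_resolve(rel, dims):
    # Ahead-of-time resolution only succeeds when all dims are concrete
    # ints; a symbolic dim would force codegen over arbitrary dimensions.
    if all(isinstance(d, int) for d in dims):
        return rel(*dims)
    return None  # unknown dimension: cannot specialize at compile time
```

This mirrors the point above: as long as n, s, and k are known at compile time the output shape is fully determined, and efficient code can be generated for it.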

Graph Extraction from Relay

It is straightforward enough to extract a graph from Relay; I'm working on porting the Relay-to-TVM compiler to the newest version of Relay on the currently open PR. For example, in the pure dataflow case you would only accept Let and Call nodes, and a program containing anything else would be rejected (this is roughly a basic block). I think extraction of graphs would be relatively easy to set up, and ideally the next generation runtime would allow us to generate and schedule code like this.
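To make the extraction idea concrete, here is a minimal sketch in plain Python (toy node classes, not Relay's actual AST) of a checker that accepts only Let/Call/Var nodes and rejects anything with control flow:

```python
class Var:   # reference to a bound name
    def __init__(self, name): self.name = name

class Call:  # operator application
    def __init__(self, op, args): self.op, self.args = op, args

class Let:   # let name = value in body
    def __init__(self, name, value, body):
        self.name, self.value, self.body = name, value, body

class If:    # control flow: not allowed in a pure dataflow fragment
    def __init__(self, cond, t, f): self.cond, self.t, self.f = cond, t, f

def is_dataflow(expr):
    # A program is a dataflow "basic block" if it consists only of
    # Var, Call, and Let nodes; any other node breaks the invariant.
    if isinstance(expr, Var):
        return True
    if isinstance(expr, Call):
        return all(is_dataflow(a) for a in expr.args)
    if isinstance(expr, Let):
        return is_dataflow(expr.value) and is_dataflow(expr.body)
    return False
```

A graph extractor would then run only over programs this predicate accepts.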

ONNX as IR and interchange format

There was some discussion about ONNX and a unified interchange format, but the problem is that ONNX is insufficiently expressive. For example, PyTorch can't export dynamic graphs of any form due to ONNX limitations. There are also issues such as incomplete support for broadcasting operations in the IR.

Data Structures

I think data structures are interesting. If we could assign them semantics ourselves, I think optimizing them would be much easier; for example, if we use data structures in a linear fashion it is possible to heavily fuse or elide allocations. I think we should follow up on this after the initial IR lands.
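A toy model of the linearity point (my own sketch, nothing Relay-specific): when the caller promises a value is consumed exactly once, its storage can be reused in place and the fresh allocation elided:

```python
def map_maybe_inplace(f, xs, linear):
    """Apply f elementwise to the list xs. When `linear` promises that
    xs has no other consumers, reuse its storage instead of allocating
    a new list (the allocation elision described above)."""
    out = xs if linear else list(xs)
    for i, v in enumerate(xs):
        out[i] = f(v)
    return out
```

The interesting part is that the optimization is only sound because of the linear-use promise; a type system tracking linearity could establish it automatically.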

On Optimizations

In terms of doing lower level optimizations, my vision has always been that most users will interact with the IR before the "suffix" optimizations are applied. That is, the user's perception will be of a high level functional graph, mostly free of effects, which can be reasoned about algebraically like most computation graphs. The "suffix" phases will do things like concretize the operators, do memory planning, convert functional operations to destructive ones, etc., and then we will generate low level VM instructions/LLVM and a set of customized TVM operators. This is the place where we could break the uniform constant representation, schedule scalar code differently, etc.
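As a toy model of the functional-to-destructive step (illustrative only; the real lowering involves full liveness analysis and a proper memory planner), a straight-line chain of ops can be lowered to destination-passing instructions that ping-pong between two buffers instead of allocating a fresh tensor per op:

```python
def lower_chain(ops):
    """Lower a functional chain y_i = op_i(y_{i-1}) into destination-
    passing form: each instruction is (op, src_buffer, dst_buffer).
    Every intermediate dies right after it is consumed, so two buffers
    suffice regardless of chain length."""
    insts = []
    for i, op in enumerate(ops):
        dst = f"buf{i % 2}"
        src = "input" if i == 0 else f"buf{(i - 1) % 2}"
        insts.append((op, src, dst))
    return insts
```

The functional view (each op "returns" a new value) is preserved for the user; only this suffix phase introduces destructive writes.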

Data Layout

@kevinthesun We have talked about extending typing with data layout, but it wasn't clear how to insert it into the type system. Last time @tqchen and I talked we were planning on storing it as an attribute. My thought was to have a layout pass which inserts explicit layout change operators everywhere and then tries to minimize the number of layout transforms. I could imagine employing some search here; for example, maybe it's good to keep some data around in both layouts? I would be interested in hearing more.
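A minimal sketch of the minimization step described above (a hypothetical pass of my own, not the actual Relay implementation): once explicit layout-transform operators have been inserted everywhere, adjacent transforms that undo each other can be elided in a single peephole sweep:

```python
def minimize_layouts(ops):
    """ops is a flat op sequence; a layout change is modeled as a tuple
    ("transform", src_layout, dst_layout). Cancel any back-to-back pair
    A->B, B->A, since the two compose to the identity. (A single sweep;
    a real pass would iterate or search, e.g. to decide whether keeping
    data in both layouts is worthwhile.)"""
    out = []
    for op in ops:
        if (out and isinstance(op, tuple) and isinstance(out[-1], tuple)
                and out[-1][1:] == (op[2], op[1])):
            out.pop()  # previous transform is undone by this one
        else:
            out.append(op)
    return out
```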

Pass Manager

@zhiics The older version of Relay had a basic version of the pass manager. I think one really interesting use case for it is learning pass ordering. Pass ordering in compilers has traditionally been relatively fixed, yet it can have a substantial impact on performance, and the optimal ordering is often application specific. I think having the ability to apply auto-tvm style techniques here would be great, especially as the number of passes grows.
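A toy sketch of treating pass ordering as a search problem (the passes, cost function, and exhaustive enumeration are all stand-ins; an AutoTVM-style tuner would sample orderings and measure real runtime rather than enumerate permutations):

```python
import itertools

def best_order(passes, cost, program):
    """Try every pass ordering and keep the one minimizing `cost` on
    the transformed program. Exponential, so only for illustration."""
    best, best_cost = None, float("inf")
    for order in itertools.permutations(passes):
        p = program
        for opt in order:
            p = opt(p)
        c = cost(p)
        if c < best_cost:
            best, best_cost = list(order), c
    return best, best_cost

# Toy passes over a string "program": fusion creates opportunities
# that deduplication can only exploit if it runs afterwards.
fuse = lambda p: p.replace("xy", "z")
dedup = lambda p: p.replace("zz", "z")
```

On the toy program "xyxy", running fuse before dedup yields a smaller program than the reverse order, which is exactly the phase-ordering effect under discussion.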

Auto-batching

We have discussed implementing auto-batching as a compiler transform, but in general it sounds like a mixture of scheduling and program transformation makes sense. There are many examples of this static/dynamic split in CS, and especially in compilers. One simple example is the stack frame/stack: we can statically compute the amount of storage for a single frame, but in order to support recursion we allow some of the work to happen dynamically. We have discussed a similar strategy for the runtime, and I could imagine trying to heavily optimize static sub-graphs where possible and then bucket them dynamically during runtime. I'm not an expert in auto-batching, so I would love to hear more ideas from the community.
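The dynamic half of that split can be sketched as runtime bucketing (bucket boundaries and zero-padding here are illustrative choices of mine, not anything Relay specifies): each bucket reuses a statically-compiled kernel for its padded length.

```python
def bucket_for(length, boundaries=(8, 16, 32, 64)):
    """Pick the smallest bucket that fits, bounding padding waste."""
    for b in boundaries:
        if length <= b:
            return b
    raise ValueError(f"length {length} exceeds the largest bucket")

def batch_by_bucket(seqs):
    """Group variable-length sequences by bucket, zero-padding each to
    its bucket length, so every group is a uniform static batch."""
    buckets = {}
    for s in seqs:
        b = bucket_for(len(s))
        buckets.setdefault(b, []).append(s + [0] * (b - len(s)))
    return buckets
```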

MyPy and Relay

@joshpoll is right about the type annotations: we have set them up to not interfere with MyPy, and if MyPy gained support for typing numpy computations we could probably design annotations which work in both MyPy and Relay. The previous version of Relay (not the one in the PR) had the required stubs to make everything work.

Contributor

zhiics commented Sep 4, 2018

@jroesch Thanks for the summary. The phase ordering problem is also one of my major concerns, but I haven't looked into the auto-tvm style technique. I will read it.

Contributor

grwlf commented Sep 4, 2018

Can't stop myself from adding 2 cents about pass managers. TL;DR: IMHO, passes of one kind may not need a pass manager, but passes of the other kind should have one.

I think (an obvious thing) that the absence of a pass manager may lead to hard-to-detect bugs during compiler development. I also think it is correct to separate passes which actually change the IR but preserve the algorithm (typechecking of all sorts, maybe loop unrolling) from passes which keep the IR the same but change the program written in it (the ones which remove or insert function calls, like fusing).

The passes which change the IR but preserve the algorithm may really benefit from some kind of pass manager, better static than dynamic. I imagine that in C++ this may be done by template metaprogramming on AST classes. In Haskell there are solutions employing fixed-point datatypes that can attach or remove pass-specific details on the AST datatype while keeping it statically typed. Dynamic pass managers are simpler but should already make life easier, so they may be a compromise solution.

The passes which change the algorithm but preserve the IR should be the playground of autotuning. I don't have experience here and need to learn this area better. Maybe one should think of something like first-class passes, if that is ever possible. Generally, I believe one should avoid playing with optimizations which can't be represented within a language, because of the trouble they cause for formal verification in the future. Thank you.

Contributor

kevinthesun commented Sep 5, 2018

@jroesch Thank you for the summary. Adding a layout attribute and a layout pass should be able to deal with current use cases.

Contributor

JammyZhou commented Sep 10, 2018

@jroesch As to the ONNX limitations, can you share more background or details on why it cannot satisfy the requirements of NNVM/TVM? I'm really interested in this. Thanks in advance!

Contributor

sgrechanik-h commented Sep 10, 2018

@jroesch @MarisaKirisame What will the interface for defining gradients for operations look like? Currently in nnvm there is the FGradient attribute, which expresses the gradient in terms of other nnvm operations. Will it be something like this for Relay too, or maybe the gradients will be expressed directly in terms of topi and tvm operations?

Currently we are working on AD at the TVM level, so we are thinking about the possibility of integrating it into Relay. In our solution we calculate gradients for tensors produced by the tvm.compute primitive. Are you going to support elementwise construction of tensors like this in Relay? And if so, are you going to implement automatic differentiation for tensors defined this way?

Contributor

MarisaKirisame commented Sep 14, 2018

@sgrechanik-h
Somewhat like FGradient, a Relay gradient will be expressed as another Relay expression, for the closure property (we can differentiate an already-differentiated expression an arbitrary number of times).
Suppose an operator is of type x -> y; we are planning to expose an API that lets you define its gradient as an expression (probably) of type x -> (y, y -> x). It might be changed to some other similar type for efficiency/composability reasons, but there is no plan to release an API based on compute, as we have no idea of its consequences on performance.
If you can show good performance numbers we are very happy to change our minds, as it is certainly more principled to define gradients once and for all!
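The x -> (y, y -> x) shape can be illustrated with plain Python scalars (toy code of mine, not the Relay API): the gradient of an op returns the primal output together with a backward closure mapping an output sensitivity to an input sensitivity, and composing two such functions yields another function of the same shape, which is the closure property mentioned above.

```python
def square_with_grad(x):
    """Gradient-carrying version of x -> x*x, of shape x -> (y, y -> x)."""
    y = x * x
    def backward(dy):
        return 2.0 * x * dy  # dx = dy * d(x*x)/dx
    return y, backward

def compose_with_grad(f, g):
    """Compose two gradient-carrying functions; the result has the same
    shape, so it can itself be differentiated again."""
    def h(x):
        y1, back1 = f(x)
        y2, back2 = g(y1)
        return y2, lambda dy: back1(back2(dy))
    return h
```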

Contributor

sgrechanik-h commented Sep 17, 2018

@MarisaKirisame Thanks for the clarification. It seems like there will always be operations for which automatic differentiation does a poor job and whose gradients should be defined manually, so I think this is the right approach anyway.

Contributor

grwlf commented Oct 5, 2018

@jroesch Could you please clarify the place of scheduling operations in Relay? Do we plan to allow users to specify schedules manually, or do we mean automatic scheduling during some phases of the Relay->TVM transformation?

I'm asking because I am working on performance measurements for a large model exported from TensorFlow to NNVM. I am facing a case where automatic schedulers were able to accelerate some parts of the model (~2x faster than TF on CPU), but for other parts they showed poor results. In order to figure out the cause I had to "reverse engineer" the NNVM graph back to sources. (Here is draft code for GraphDef staging.) I am not finished yet, but it seems I will need to schedule some parts manually, and AFAIK that is not that simple in NNVM. What about Relay? What would be the workflow for cases where manual scheduling is required?

Member

tqchen commented Oct 18, 2018

Thanks for the very helpful discussions; the Relay IR has been mainlined. Let us open new issues for discussing specific aspects of the development.

@tqchen tqchen closed this Oct 18, 2018
