
[Op][Trans1] Reshape, Transpose, ExpandDims, Squeeze, Flatten #78

Merged (4 commits), Jan 1, 2023

Conversation

MasterJH5574 (Member)

This PR introduces transformation operators to the latest StructInfo codebase. It is part of our operator migration efforts as listed in #62.

Again, due to the large number of operators, we split the migration of transformation ops into multiple parts and PRs. This PR includes the following operators:

  • Reshape
  • Transpose
  • ExpandDims, Squeeze
  • Flatten

Most of the LOC comes from the test file

  • tests/python/relax/test_op_transform.py

@@ -325,6 +325,50 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
}
}; // struct ReduceAttrs

/*! \brief Attributes used in reshape operator */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<PrimExpr> new_shape;
@tqchen (Contributor) · Dec 31, 2022

Because we now support ShapeExpr as first class, let us make reshape take an Expr as the second argument and expect a ShapeStructInfo.

Note that this would require us to enable negative values in shape, which is OK.

@tqchen (Contributor) · Dec 31, 2022

This would require some special checks for -1 in ShapeExpr.

@tqchen (Contributor) · Dec 31, 2022

Perhaps an implicit op, relax.builtin.prepare_reshape_shape(src_tensor, shape), for the most general case.

@MasterJH5574 (Member, Author) · Dec 31, 2022

My rationale for making this an attribute is that having negative values in ShapeExpr doesn’t sound quite right. -1 is only used in Reshape, and any other case where non-positive values appear in ShapeExpr or ShapeStructInfo should be regarded as invalid (I’m even thinking about adding a check to the ShapeExpr constructor to enforce this).

Since the -1 inference is only for reshape, I think it’s reasonable to make Reshape treat the input specially, compared with other ops (like full/ones/zeros in #79), because those ops require a normal input shape, where any non-positive value is not allowed.

Allowing -1 to appear in ShapeExpr implies that we would need to check “whether the input ShapeExpr contains -1” at every other use site of ShapeExpr, which is not worth it IMHO.

@tqchen (Contributor) · Jan 1, 2023

One of the main problems in this case is having a PrimExpr appear in an Attrs, which breaks our abstraction level (Attrs should always remain constant).

Having a constant reshape is fine. But we won’t be able to update the prim expression if the IR gets mutated (say we want to remap n into some other shape variables), because the assumption is that Attrs only contain constants (and won’t get updated).

So having a symbolic new_shape as an Attr is not ideal from that point of view. We could have new_shape as constant ints.

As another alternative, perhaps we should just require this to be a ShapeExpr, and provide a helper conversion that constructs such a shape Expr in the reshape implementation at construction time:

Expr MakeReshape(Expr x, Array<PrimExpr> shape) {
  // Check for -1, consult x's shape, and construct such a ShapeExpr
  // when constructing the call.
  ShapeExpr new_shape = ...;
  return Call(reshape, {x, new_shape});
}

This is of course also not perfect, but it restricts the use cases to ones we can handle.
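For intuition, the -1 deduction being discussed can be sketched in plain Python. This is a framework-free sketch with a hypothetical helper name (`infer_reshape_shape`), not the actual Relax implementation:

```python
from math import prod

def infer_reshape_shape(old_shape, new_shape):
    """Resolve at most one -1 in new_shape against the element count of old_shape."""
    if new_shape.count(-1) > 1:
        raise ValueError("at most one -1 is allowed in the new shape")
    total = prod(old_shape)
    if -1 not in new_shape:
        # No inference requested: just verify the element counts match.
        if prod(new_shape) != total:
            raise ValueError("element counts do not match")
        return list(new_shape)
    # Product of the known dimensions; the -1 dimension must make up the rest.
    known = prod(d for d in new_shape if d != -1)
    if known == 0 or total % known != 0:
        raise ValueError("cannot infer the -1 dimension")
    return [total // known if d == -1 else d for d in new_shape]
```

The real implementation additionally has to handle symbolic PrimExpr dimensions, where this divisibility check cannot always be decided statically.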

Six further review threads on src/relax/op/tensor/transform.cc (outdated, resolved).
tqchen commented Dec 31, 2022

Some general remarks:

  • https://data-apis.org/array-api/latest/API_specification/ can be helpful for our naming, categorization, and API spec. Note that the numpy version can sometimes be slightly more general (e.g. squeeze); make a remark on those cases.
  • A few ops (reshape, zeros) take a shape as an input parameter. Previously these were treated as Attrs; we should turn them directly into function parameters. For reshape this would mean allowing a constant -1 in ShapeExpr, and I think that is fine as long as we normalize it.

@MasterJH5574 (Member, Author)

@tqchen Addressed most of your comments, with the last commit dedicated to renaming the APIs and categorizing them per the Data API.

@MasterJH5574 force-pushed the mlc-dev/2022-12-30-op-trans1 branch 2 times, most recently from 61bbfa8 to 1cac3f4, on December 31, 2022 22:53

tqchen commented Jan 1, 2023

Thanks @MasterJH5574, see my follow-up comment on reshape. In this case, I think we can make new_shape a ShapeExpr and do some best-effort -1 deduction at construction time for now.


MasterJH5574 commented Jan 1, 2023

Thanks @tqchen. Doing the -1 inference upon construction fits better in our case, and I have just updated it accordingly.

Still, I have a couple of notes here:

  • N1. By doing the inference before constructing the CallNode, we no longer put the “-1” into our AST. Therefore, this “request for inference” won’t be reflected in the IR. This hurts TVMScript roundtripping a bit, since the script being parsed may contain “-1” while the later printed script doesn’t.

  • N2. Please take a look at the implementation of the inference (the function ConvertNewShapeToExpr). I’m not entirely sure my way of manipulating Array and ArrayNode is the most proper. I found one interesting thing today:

    Say I have a List[relax.Var] on the Python side and pass the list to the C++ side via FFI, typed as ObjectRef. When I use Downcast to cast it from ObjectRef to Array<PrimExpr>, the downcast doesn’t fail.

    Therefore, it seems we can only ensure each Array element is a PrimExpr by walking through the entire Array, which is what I do now.
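The element-wise check described in N2 can be illustrated with a plain-Python analogy (illustrative only; the real code operates on TVM's Array/ObjectRef containers, where a container-level Downcast does not inspect elements):

```python
def validate_int_elements(arr):
    """A container-level cast cannot be trusted to validate element types,
    so walk the array and verify every element individually."""
    out = []
    for i, v in enumerate(arr):
        # Reject bool explicitly, since bool is a subclass of int in Python.
        if not isinstance(v, int) or isinstance(v, bool):
            raise TypeError(f"element {i} is {type(v).__name__}, expected int")
        out.append(v)
    return out
```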

@tqchen tqchen merged commit 77ec1d0 into mlc-ai:structinfo Jan 1, 2023

tqchen commented Jan 1, 2023

LGTM!

MasterJH5574 pushed a commit that referenced this pull request Jan 16, 2023
* Enable tests.

* Updated.

* Updated.

* Updated.
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
* Enable tests.

* Updated.

* Updated.

* Updated.
spectrometerHBH pushed a commit to spectrometerHBH/relax that referenced this pull request Feb 9, 2023
* Enable tests.

* Updated.

* Updated.

* Updated.