[Relax][Op] Add new relax operators #43

Merged
merged 8 commits into mlc-ai:relax on Dec 13, 2022

Conversation

Ubospica
Contributor

This PR adds these operators:

  • In relax.op.tensor
    • relax.negative
    • relax.log
    • relax.tanh
  • In relax.op.transform
    • relax.full_like
    • relax.ones
    • relax.zeros
    • relax.ones_like
    • relax.zeros_like
    • relax.collapse_sum_like
    • relax.collapse_sum_to
  • In topi.reduction
    • topi.collapse_sum
      • This operator is added to implement relax.collapse_sum_like and relax.collapse_sum_to.

In addition, this pull request fixes a bug in python/tvm/relax/block_builder.py so that BlockBuilder now supports functions without parameters.
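
Below is a minimal sketch of how the newly added operators can be used through `BlockBuilder`, including a function with no parameters, which is the case fixed in `block_builder.py`. The exact module paths and signatures (e.g. `relax.op.ones`, `bb.function(..., params=[])`) are assumptions based on the public relax API and may differ slightly on this branch.

```python
# Minimal sketch; API paths are assumptions (this branch may expose the ops
# under relax.op.tensor / relax.op.transform rather than relax.op).
from tvm import relax

bb = relax.BlockBuilder()
# A function without parameters: this exercises the block_builder.py fix.
with bb.function("main", params=[]):
    ones = bb.emit(relax.op.ones((2, 3), "float32"))  # newly added relax.ones
    out = bb.emit(relax.op.tanh(ones))                 # newly added relax.tanh
    bb.emit_func_output(out)

mod = bb.get()
print(mod)
```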

* Add LowerToTensorIRPass
* Register several high level ops
* Register gradients for ops
* Unit tests for Ops and gradients

[Pass] Add new SimpleAD pass (#3)

* Add SimpleAD pass

* Write unit tests

Co-authored-by: SiriusNEO <1713833595@qq.com>

[Relax][AD] Support tuple, modify APIs and move Ops to a new folder (#5)

This PR adds support for tuples in the automatic differentiation process of relax. Specifically, two operators, `Tuple()` and `TupleGetItem`, are supported. It also modifies the `SimpleAD` API to use `GlobalVar` instead of string arguments, and adds a `gradient` API, a new function-to-function interface.

Besides, we have moved the Ops created by us into a new folder `src/relax/op/training`.
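
To make the tuple handling concrete, here is a toy, framework-free sketch (not the actual SimpleAD implementation; all names are hypothetical) of the adjoint rule for `TupleGetItem`: the gradient flowing into `TupleGetItem(t, i)` is accumulated into slot `i` of the adjoint of `t`, and the other slots are left untouched.

```python
# Toy illustration of the TupleGetItem adjoint rule; names are hypothetical.
def tuple_get_item_grad(tuple_adjoint, index, out_grad):
    """Accumulate the output gradient of TupleGetItem(t, index) into
    slot `index` of the adjoint list for the tuple t."""
    tuple_adjoint[index] = tuple_adjoint[index] + out_grad
    return tuple_adjoint

adjoint_t = [0.0, 0.0]                               # adjoint placeholder for a 2-tuple t
adjoint_t = tuple_get_item_grad(adjoint_t, 1, 3.0)   # d(out)/d(t[1]) += 3.0
print(adjoint_t)                                     # [0.0, 3.0]
```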

[Testing] Relax LSTM demo and numeric gradient test (#6)

* lstm test changes

* relax training lstm demo

Rebase fix (#8)

* fix rebase error

* rebase fix: AD adapts to new parser

* rebase fix: return RuntimeDepShape now

Co-authored-by: ubospica <ubospica@gmail.com>

[Pass] Polish SimpleAD pass (#9)

* SimpleAD now copies params, which fixes the problem of Vars being shared between functions

* add error unit tests

* modify type of require_grads and error test

Optimizer API (#7)

Co-authored-by: SiriusNEO <1713833595@qq.com>

[Relax][Training] Relax Trainer API (#10)

* append call pass

* add test for append call

* successfully train LSTM using this trainer

* refactor training structure

* add tests for trainer

* remove a file

* Remove optimizer from this branch

* sync with optimizer api

* redesign API and add document

* fix some problems

Co-authored-by: ubospica <ubospica@gmail.com>

[FIX][Relax AD] replace ones_like with ones (#11)

* ones_like to ones

* fix rebase

Co-authored-by: ubospica <ubospica@gmail.com>

Op polish

revision
@Ubospica changed the title from "[Op] Add new relax operators" to "[Relax][Op] Add new relax operators" on Dec 12, 2022
@MasterJH5574 (Member) left a comment


Well done! Just a couple of nit-pickings.

Review comments were left on:
  • include/tvm/topi/reduction.h
  • tests/python/relax/test_relax_tensor_ops.py
  • tests/python/relax/test_relax_transform_ops.py
  • python/tvm/relax/block_builder.py
  • python/tvm/relax/op/transform.py
- Remove test_relax_ops running part
- Add op legalizer test
@MasterJH5574 (Member)

You seem to have changed the file apps/microtvm/zephyr/template_project/qemu-hack/qemu-system-riscv32 and some other files. Please take a look and recover them.

@MasterJH5574 (Member) left a comment


Let’s fix them and merge this PR!

Review comments were left on:
  • tests/python/relax/test_relax_tensor_ops.py
  • tests/python/relax/test_relax_transform_ops.py
@Ubospica (Contributor, Author)

@MasterJH5574 Thanks for your review!

@MasterJH5574 merged commit a3a22c1 into mlc-ai:relax on Dec 13, 2022
@Ubospica deleted the op_push_dyx branch on December 13, 2022 16:21
MasterJH5574 pushed a commit that referenced this pull request Dec 14, 2022
* [Op] High level op added

(squashed commit message, identical to the commit list above)

* fix ops

* debug for ops

* prepare for pr

* fix pr issues

* - Add collapse_sum document
- Remove test_relax_ops running part
- Add op legalizer test

* fix qemu-hack

* remove unnecessary imports from test
MasterJH5574 pushed a commit that referenced this pull request Jan 19, 2023
This PR migrates #46 to the new struct info infra, as part of our AD migration.

Because we need to do numerical testing for gradients, this PR depends on the operator legalizer #96. Also, because the original version of the legalizer did not handle the negative indexing case of `relax.mean`, this PR fixes it.

To lower `collapse_sum_to` and `collapse_sum_like` properly, this PR migrates a previous patch, #43, which introduces `collapse_sum` in topi. Now we can remove the skip marker in the legalizer test for `collapse_sum_to` and `collapse_sum_like`.
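
For reference, the semantics that `topi.collapse_sum` provides here can be written as a small NumPy helper (an illustrative sketch, not the topi implementation): the gradient of a broadcast operation is summed along the broadcast axes so that it collapses back to the source shape.

```python
import numpy as np

def collapse_sum(data: np.ndarray, target_shape: tuple) -> np.ndarray:
    """Illustrative reference: sum `data` over the axes introduced or
    expanded by broadcasting so the result has `target_shape`."""
    result = data
    # Sum away extra leading axes added by broadcasting.
    while result.ndim > len(target_shape):
        result = result.sum(axis=0)
    # Sum along axes whose target extent is 1 but were expanded.
    for axis, (dim, tgt) in enumerate(zip(result.shape, target_shape)):
        if tgt == 1 and dim != 1:
            result = result.sum(axis=axis, keepdims=True)
    return result

# e.g. the gradient of a broadcast add with shapes (4, 3) + (1, 3):
grad = np.ones((4, 3), dtype="float32")
print(collapse_sum(grad, (1, 3)).shape)  # (1, 3)
```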

The gradients of `cross_entropy` and `softmax_cross_entropy` are removed. The former will be added back and adjusted to the new `cross_entropy` introduced in #96.

Further plan in this PR:
- [x] Add gradients for `log_softmax` and `nll_loss` once #94 is merged.
- [x] Add gradients for some tuple-related operators such as `split` and `concat`, which helps us test the correctness of AD when Tuple-I/O operators are involved.
- (Not in this PR) "Undefined Gradient" representation. The gradients of some operators w.r.t. certain inputs are undefined or meaningless, such as the partial gradient of `indices` in `take(x, indices)`. Relay simply uses `zeros_like` in this case, as it does not affect gradient propagation. Another choice is to introduce a dummy Expr named `UndefinedGradient` to represent it. How do we handle this case in relax?
MasterJH5574 pushed a commit that referenced this pull request Jan 28, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 31, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this pull request Feb 12, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this pull request Feb 12, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023