[Relax][Training] Add automatic differentiation pass #103

Merged

Conversation

@Ubospica Ubospica (Contributor) commented on Jan 18, 2023

This is the follow-up PR to #55 after the source branch moved to a personal repo.

This PR is based on #98.

This PR adds the new automatic differentiation API:

  • Gradient(func: GlobalVar, require_grads: Optional[Union[Var, List[Var]]] = None) -> tvm.ir.transform.Pass
    • transforms the given function in the IRModule and adds a new function that computes the gradients of the function's output with respect to the parameters specified by require_grads (all parameters when it is not given)

Currently, Gradient only supports differentiating a function in the IRModule that contains a single dataflow block, with respect to the function's sole return value, which must be a scalar. A minimal usage sketch is shown below.
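
For concreteness, here is a minimal usage sketch written against the Relax Python API; the exact operator spellings (`R.add`, `R.sum`) and the name of the generated gradient function are illustrative assumptions, not details confirmed by this PR:

```python
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
        # A single dataflow block with a scalar return value,
        # as currently required by Gradient.
        with R.dataflow():
            lv = R.add(x, y)
            gv = R.sum(lv)
            R.output(gv)
        return gv


# Apply the pass; require_grads=None differentiates with respect to all parameters.
mod = relax.transform.Gradient(Module.get_global_var("main"))(Module)
# The transformed module keeps the original `main` and adds a gradient
# function (e.g. `main_adjoint`) returning the output together with the gradients.
print(mod)
```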

This PR adds two unit-test files:

  • tests/python/relax/test_transform_gradient.py contains only assert_structural_equal assertions.
  • tests/python/relax/test_transform_gradient_numeric.py contains numeric checks, comparing against manually derived gradients and against numerical differentiation via check_numerical_grads (a sketch of this kind of check follows this list).
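
As a rough illustration of what check_numerical_grads does, here is a hedged sketch on a plain NumPy function; the real tests build and run the transformed Relax module instead, and the function and gradients below are made up for illustration:

```python
import numpy as np
import tvm.testing

x = np.random.rand(3, 3).astype("float32")
y = np.random.rand(3, 3).astype("float32")


def forward(x, y):
    # Scalar-valued function: f(x, y) = sum(x * y).
    return np.sum(x * y)


# Analytic gradients: df/dx = y, df/dy = x.
# check_numerical_grads compares them against finite-difference
# estimates of `forward` around the given inputs.
tvm.testing.check_numerical_grads(forward, {"x": x, "y": y}, {"x": y, "y": x})
```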

Checkpoints:

  • Refactor to use CopyWithNewParams and ExprFunctor
  • Check that int64/int32 tensors are not differentiated (currently only checked for parameters)
  • Rebase & migrate to StructInfo
  • Refactor Tuple handling
  • Refactor using NestedMsg
  • Support ops that take tuples as input or return tuples
  • Eliminate collapse_sum_to (done in [Op] Migration: Gradients for some operators #98)

Future:

  • (Not in this PR) Handle undefined gradients in add and in the return value
    • They are currently treated as zeros

debug finished

Change details on and modify document

polish test

unit test finished

reformat changed files

Fix problems in pr comments

move op and change them static in AD pass

fix some problems

fix for comments so far

update documents

formatted

update document1

draft 1

revise version 1

version 1 completed and formatted

revise on doc and detail

polish after refactor

remove zeros_tracker_

update doc

restructure and add float check for input

format

refactor done

refactor again

add test_tuple_ops

update comments

rebase onto struct info

fix log_softmax case

fix test

eliminator draft

fix after rebase

refactor draft

Revise doc & Epilogue & details

remove some irrelevant codes

remove collapse_sum_eliminator
normalize has problems to fix
@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from f2364e4 to da9f2d7 on January 19, 2023 09:25
@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 2eafa15 to ac98e81 on January 19, 2023 14:00
@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from bb035e4 to cd34f4d on January 19, 2023 15:48
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from f6d5622 to 79a9e94 on January 19, 2023 19:12
@Ubospica Ubospica changed the title from "[WIP][Relax][AD] Add automatic differentiation pass" to "[Relax][AD] Add automatic differentiation pass" on Jan 19, 2023
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 79a9e94 to ad2495c on January 19, 2023 19:18
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from ad2495c to 7eb39f8 on January 19, 2023 19:29
modify nested_msg.h, util.h
add testcases
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 9421fff to 9011549 on January 20, 2023 09:09
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 9011549 to d2854c6 on January 20, 2023 09:11
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 3b56962 to 0cdd5dc on January 22, 2023 09:25
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch 2 times, most recently from c6c10dc to e5f4504 on January 23, 2023 07:37
@Ubospica Ubospica force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from e5f4504 to 0389aaf on January 23, 2023 07:45
@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-18-ad_after_tuple_refactor branch from 04572e8 to 3c0d8fc on January 23, 2023 09:58
@SiriusNEO SiriusNEO (Contributor) left a comment:

Looks good to me now

@MasterJH5574 MasterJH5574 merged commit 5f31e8e into mlc-ai:relax Jan 23, 2023
@MasterJH5574 MasterJH5574 (Member) commented:

Thanks for the great effort in pushing this PR! Would you folks mind sending the nested msg changes, together with the test in this PR, to tlc-pack when you get time? That would ease our next sync. But no hurry at all - enjoy your holidays :-)

MasterJH5574 pushed a commit that referenced this pull request Jan 24, 2023
This PR is a small fix patch for #103, containing two small modifications:
- The AD PR introduced `NestedMsgToExpr`, where the signature of `fmapleaf` is `Expr fmapleaf(T)`, and a null nested msg directly throws an error. This patch changes the signature to `Expr fmapleaf(Optional<T>)` and passes `NullOpt` to `fmapleaf`, letting the user decide whether to throw an error or return a default value.
- The nested msg test forgot to change the signature of `fmapleaf` (originally `Expr fmapleaf(NestedMsg<T>)`). It still passed due to the implicit conversion between `NestedMsg<T>` and `T`, but it needed to be fixed.
MasterJH5574 pushed a commit that referenced this pull request Jan 28, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 28, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 31, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023