
[Op][NN] cross_entropy, log_softmax, nll_loss #94

Merged
MasterJH5574 merged 10 commits into mlc-ai:relax from mlc-dev/2023-01-09-two-nn-ops on Jan 15, 2023

Conversation

SiriusNEO
Contributor

@SiriusNEO SiriusNEO commented Jan 10, 2023

After discussing the loss computation, a good approach is `log_softmax` + `nll_loss`. This PR introduces these two operators and tests for them.
As for `nll_loss`, here are some basic shape descriptions which may help review, along with an important reference: https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss

```
def nll_loss(
    predictions: Expr,
    targets: Expr,
    weights: Optional[Expr] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
) -> Expr:

Notations:
N: minibatch size
C: number of classes
K: number of input dimensions

Shape:
    weights: (C,) (always)

  without minibatch:
    predictions: (C,)
    targets: ()
    output: ()

  with minibatch N:
    predictions: (N, C)
    targets: (N,)
    output: (N,) (reduction=none)
    output: () (reduction=mean/sum)

  with minibatch N and high dimension input d1, d2, ..., dk:
    predictions: (N, C, d1, d2, ..., dk)
    targets: (N, d1, d2, ..., dk)
    output: (N, d1, d2, ..., dk) (reduction=none)
    output: () (reduction=mean/sum)
```
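For review, here is a small NumPy reference of the intended semantics for the common (N, C) minibatch case, mirroring torch.nn.NLLLoss. This is an illustrative sketch only (the helper name `nll_loss_ref` is made up), not the implementation added in this PR:

```python
import numpy as np

def nll_loss_ref(predictions, targets, weights=None, reduction="mean", ignore_index=-100):
    N, C = predictions.shape                          # predictions: (N, C) log-probabilities
    if weights is None:
        weights = np.ones(C, dtype=predictions.dtype)  # weights: (C,)
    mask = targets != ignore_index                    # ignored entries contribute nothing
    safe_targets = np.where(mask, targets, 0)
    w = weights[safe_targets] * mask                  # per-example weight, 0 where ignored
    loss = -w * predictions[np.arange(N), safe_targets]  # (N,)
    if reduction == "none":
        return loss                                   # output: (N,)
    if reduction == "sum":
        return loss.sum()                             # output: ()
    return loss.sum() / w.sum()                       # "mean": weighted mean, output: ()

# Shapes follow the table above.
preds = np.log(np.full((4, 3), 1 / 3))                # (N=4, C=3) log-probabilities
tgts = np.array([0, 2, 1, 2])
print(nll_loss_ref(preds, tgts).shape)                # () for reduction="mean"
print(nll_loss_ref(preds, tgts, reduction="none").shape)  # (4,)
```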

Our inference rule trusts `predictions`, asserts shape equality where the other arguments carry enough information, and otherwise does best-effort inference. Please check the code for details.
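In other words, the struct-info inference roughly behaves like the following Python sketch for the minibatch case (the function name and structure are illustrative only; the real checks live in the C++ code in `src/relax/op/nn/nn.cc`):

```python
def infer_nll_loss_shapes(pred_shape, target_shape=None, weight_shape=None,
                          reduction="mean"):
    """Illustrative sketch: trust `predictions`, cross-check the rest when known."""
    # predictions: (N, C, d1, ..., dk)  =>  targets should be (N, d1, ..., dk)
    expected_targets = (pred_shape[0],) + tuple(pred_shape[2:])
    if target_shape is not None:                        # equal assertion when known
        assert tuple(target_shape) == expected_targets, "targets shape mismatch"
    if weight_shape is not None:                        # weights are always (C,)
        assert tuple(weight_shape) == (pred_shape[1],), "weights shape mismatch"
    # the output shape follows the reduction mode
    return expected_targets if reduction == "none" else ()
```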

This PR also reintroduces the cross entropy operator, since it was dropped when rebasing onto tlc. Given that torch defines cross entropy differently from ours, we use the names `cross_entropy_without_logits` and `cross_entropy_with_logits` to make the distinction less confusing and to align with relay.
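As a sanity check of the `log_softmax` + `nll_loss` composition proposed at the top, the pair reproduces the usual cross-entropy-from-logits loss (what `torch.nn.CrossEntropyLoss` computes). A small NumPy sketch, reusing the illustrative `nll_loss_ref` helper from above:

```python
def log_softmax_ref(x, axis=-1):
    # numerically stable log-softmax
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

logits = np.random.randn(4, 3)               # raw scores, (N, C)
targets = np.array([0, 2, 1, 2])
loss = nll_loss_ref(log_softmax_ref(logits), targets)
# `loss` matches the standard cross-entropy over logits with mean reduction.
```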

@Ubospica Ubospica requested a review from tqchen January 11, 2023 06:29
@SiriusNEO SiriusNEO changed the base branch from structinfo to relax January 11, 2023 15:32
@MasterJH5574
Member

One thing occurs to me. Let’s remove softmax_cross_entropy and cross_entropy from our codebase in this PR.

@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-09-two-nn-ops branch from 048ac8d to 2c9c8b9 on January 12, 2023 10:33
@SiriusNEO SiriusNEO changed the title [Op][NN] log_softmax, nll_loss [Op][NN] cross_entropy, log_softmax, nll_loss Jan 12, 2023
@SiriusNEO SiriusNEO force-pushed the mlc-dev/2023-01-09-two-nn-ops branch from f8cff75 to bc5a3f7 on January 12, 2023 10:42
@MasterJH5574 MasterJH5574 merged commit 97bcded into mlc-ai:relax Jan 15, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 16, 2023
* Change call_tir convention and fix shape/type deduction.

* test

* output shape as 3rd arg.

* address comments.

* lint
MasterJH5574 pushed a commit that referenced this pull request Jan 16, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 19, 2023
This PR migrates #46 to the new struct
info infra, as part of our AD migration.

Because we need to do numerical testing for gradients, this PR depends on
the operator legalizer #96. Also, because the original version of the
legalizer did not handle the negative indexing case of `relax.mean`,
this PR fixes that as well.

To lower `collapse_sum_to` and `collapse_sum_like` properly, this PR
migrates a previous patch #43, which introduces `collapse_sum` in topi.
Now we can remove the skip marker in the legalizer test for
`collapse_sum_to` and `collapse_sum_like`.
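For context, `collapse_sum` sums a tensor back down to a target shape over the axes that broadcasting would have expanded, which is why it is needed when computing gradients. A rough NumPy sketch of that semantics (the name `collapse_sum_ref` is illustrative, not the topi API):

```python
import numpy as np

def collapse_sum_ref(data, target_shape):
    # Sum out the leading axes that broadcasting would have added, then sum
    # (keeping dims) over axes whose target extent is 1. Semantics sketch only.
    while data.ndim > len(target_shape):
        data = data.sum(axis=0)
    for axis, extent in enumerate(target_shape):
        if extent == 1 and data.shape[axis] != 1:
            data = data.sum(axis=axis, keepdims=True)
    return data

grad = np.ones((2, 3, 4))                      # gradient w.r.t. a broadcast result
print(collapse_sum_ref(grad, (3, 1)).shape)    # (3, 1): original operand shape
```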

The gradients of `cross_entropy` and `softmax_cross_entropy` are
removed; the former will be added back and adjusted to the new
`cross_entropy` introduced in #96.

Further plans in this PR:
- [x] Add gradients for `log_softmax` and `nll_loss` once
#94 is merged.
- [x] Gradients for some tuple-related operators such as `split` and
`concat`. These help us test the correctness of AD when there are
Tuple-I/O operators.
- (Not in this PR) "Undefined Gradient" representation. As we know, the
gradients of some operators w.r.t. certain inputs are undefined or
meaningless, such as the partial gradient of `indices` in `take(x,
indices)`. Relay directly uses `zeros_like` in this case, as it won't
affect gradient propagation. Another choice is to introduce a dummy Expr
named `UndefinedGradient` to represent it. How do we handle this case in
relax?
MasterJH5574 pushed a commit that referenced this pull request Jan 28, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 28, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 31, 2023
MasterJH5574 pushed a commit that referenced this pull request Jan 31, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
spectrometerHBH pushed a commit to spectrometerHBH/relax that referenced this pull request Feb 9, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this pull request Feb 12, 2023
MasterJH5574 pushed a commit to MasterJH5574/tlc-relax that referenced this pull request Feb 12, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023