Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Custom metrics from compute_loss in TGA #4905

Closed
wants to merge 12 commits into from
Closed

Conversation

mojtaba-komeili
Copy link
Contributor

Patch description
The compute_loss function in TGA generates the loss_per_token values but they are not accessible from outside this method. This PR adds a handle for calculating custom metrics that require loss_per_token. The added method (custom_loss_metrics) can be overridden by its children to work with the loss_per_token and batch for generating custom metrics.

@mojtaba-komeili mojtaba-komeili changed the title Custom metrics for loss values in TGA Custom metrics from compute_loss in TGA Dec 2, 2022
@@ -34,7 +34,7 @@
from parlai.utils.misc import warn_once
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
from parlai.core.metrics import Metric, SumMetric, AverageMetric, FairseqBleuMetric
from parlai.core.metrics import SumMetric, AverageMetric, FairseqBleuMetric
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for the lint error.

mojtaba-komeili and others added 2 commits December 5, 2022 13:55
* cat or concat

* back to cat

* Only add the metric if it is not None

* lint
* zero3 init commit

* minor cleanup:

* handle mpeval

* remove fairscale dependence

* fsdp avail

* update reqs

* better reqs

* autoformat

* autofromat
@@ -715,6 +715,12 @@ def _encoder_input(self, batch):
"""
return self._model_input(batch)

def custom_loss_metrics(self, batch, loss_per_token):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: can we call this something more related to what it's doing?

e.g., compute_per_token_metrics?

dependabot bot and others added 5 commits December 5, 2022 14:47
Bumps [decode-uri-component](https://github.com/SamVerschueren/decode-uri-component) from 0.2.0 to 0.2.2.
- [Release notes](https://github.com/SamVerschueren/decode-uri-component/releases)
- [Commits](SamVerschueren/decode-uri-component@v0.2.0...v0.2.2)

---
updated-dependencies:
- dependency-name: decode-uri-component
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…o compute-loss

* 'compute-loss' of github.com:facebookresearch/ParlAI:
  lint
  added the custom_loss_metrics
@klshuster klshuster self-requested a review December 13, 2022 16:24
klshuster and others added 3 commits December 13, 2022 16:16
* delay loading of ngram blocking codde

* tga
…rlAI into compute-loss

* 'compute-loss' of https://github.com/facebookresearch/ParlAI:
  added the custom_loss_metrics
  lint
  added the custom_loss_metrics
@mojtaba-komeili
Copy link
Contributor Author

Rebase added some unwanted changes to this PR. Closing it and opening a new clean one.

@mojtaba-komeili mojtaba-komeili deleted the compute-loss branch December 15, 2022 01:38
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants