Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Custom metrics from compute_loss in TGA (2nd try) #4913

Merged
merged 2 commits into from Jan 17, 2023
Merged

Conversation

mojtaba-komeili
Copy link
Contributor

Patch description
The compute_loss function in TGA generates the loss_per_token values but they are not accessible from outside this method. This PR adds a handle for calculating custom metrics that require loss_per_token. The added method (custom_loss_metrics) can be overridden by its children to work with the loss_per_token and batch for generating custom metrics.

NOTE: this is cleaned up version of 4905 which became corrupted with other commits after some bad rebase and merge.

@mojtaba-komeili
Copy link
Contributor Author

@klshuster reminding you of this old PR.

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the changes!

@mojtaba-komeili mojtaba-komeili merged commit f249627 into main Jan 17, 2023
@mojtaba-komeili mojtaba-komeili deleted the compute-loss-2 branch January 17, 2023 18:32
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants