TF Model train and eval step metrics for seq2seq models. #14009
Rocketknight1 merged 2 commits into huggingface:master
Conversation
This seems like a good fix, thank you! The changes to …

I did some quick testing - the impact of the change seems limited as …

@Rocketknight1 I did some more testing. When the model is compiled multiple times, the metric doesn't work correctly. If the model is compiled only once, it does work. I don't know whether this is a concern.
Force-pushed 91930d1 to cd678e8
When using a model with a seq2seq output compute metrics against logits.
Force-pushed cd678e8 to 7c8022f
@Rocketknight1 Please take another look. The test makes sure that the metric now works as expected. I also removed a call to compute the loss on …
@pedro-r-marques You're correct that some vestigial code made it into test_step - that's not great on our end (you did a good job spotting it!). Let me double-check that bit and tidy it up in your branch before we merge.

Done! The old code comes from a period when we were experimenting with a different way of handling the model's internal loss computations, and it should have been removed. I've fixed it now, and the rest of the code looks good. If you're happy with my changes, we can merge once the tests are green.

(That torch failure has nothing to do with this PR, don't worry.)
Force-pushed 05af297 to 610beb9
@Rocketknight1 Checks are green now. Please merge the PR if you are happy with it. Thanks a lot!
The rebase reverted the fix to the vestigial code, so I took it out again. Will merge once it's all green!
Force-pushed 93d0a01 to f3b7299
@Rocketknight1 Apologies for squashing your changes unintentionally :-(. I cleaned up the git log, hopefully preserving your changes this time, and took another roll of the CI dice.
@pedro-r-marques It's okay, I wrecked your changes too. Git is just hard, lol. Anyway, it looks good now and tests are green, so merging!
When using a model with a seq2seq output compute metrics against logits.
What does this PR do?
This PR changes the TF train and test steps so that metrics can be correctly computed when using Keras `model.fit`.
The Keras Model train/test step functions are supposed to compare the labels (`y_true`) with the predictions (`y_pred`).
The previous code passed the `ModelOutput` dataclass (essentially a dict of output tensors) as `y_pred`, which caused
TF/Keras to attempt to compute metrics between `y_true` and each of the elements in the `ModelOutput` dict, rather than against the logits alone.
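To illustrate the idea (this is a simplified sketch, not the actual `transformers` implementation; the helper name `extract_y_pred` is hypothetical), the fix amounts to selecting the logits out of the dict-like model output before handing `y_pred` to the Keras metric machinery:

```python
def extract_y_pred(model_output):
    """Return the value Keras metrics should see as y_pred.

    If the model returned a dict-like output (as seq2seq ModelOutput
    dataclasses behave), keep only the "logits" entry; otherwise pass
    the prediction through unchanged.
    """
    if isinstance(model_output, dict):
        return model_output["logits"]
    return model_output


# Hypothetical seq2seq-style output: logits plus auxiliary tensors.
# Passing this whole dict as y_pred would make Keras try to match
# y_true against *every* entry, which is what the fix avoids.
output = {
    "logits": [[0.1, 0.9], [0.8, 0.2]],
    "encoder_last_hidden_state": [[0.0, 0.0], [0.0, 0.0]],
}
print(extract_y_pred(output))
```

With this extraction in place, a metric such as accuracy is updated once, against the logits tensor, instead of once per entry of the output dict.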
Who can review?
@sgugger
@Rocketknight1