Fix ModelCheckpoint trained-on batch counting when using steps_per_execution>1 #17632
Conversation
When setting the steps_per_execution argument to a value N>1 in model.compile(), the on_train_batch_end() method of model.fit()'s callbacks only gets called every N batches, with a batch argument equal to the 0-indexed index of the last batch that has been trained on. That is, after the first N trained-on batches, on_train_batch_end() is called with batch equal to N-1, then N trained-on batches later with 2N-1, and so on until the end of the epoch.

To handle this, ModelCheckpoint uses a _last_batch_seen member variable to record the value of the batch argument the last time on_train_batch_end() was called. When on_train_batch_end() is called again, ModelCheckpoint computes (in its _should_save_on_batch() method) add_batches = batch - self._last_batch_seen to determine how many batches have been trained on between two consecutive calls.

However, _last_batch_seen is initialized to 0, which means that with steps_per_execution=N, the first time on_train_batch_end() is called after N batches have been trained on (with batch equal to N-1), only N-1 batches are counted, since add_batches = batch - self._last_batch_seen = (N-1) - 0 = N-1 instead of N. ModelCheckpoint therefore misses one batch in its count, effectively offsetting its save_freq constructor argument by 1. An initialization value of -1 is needed instead.

In the special cases steps_per_execution=1 and steps_per_execution=None (which are equivalent), the bug was hidden because the new-epoch check (batch <= self._last_batch_seen) was true on the first call to on_train_batch_end(), since both the batch argument and _last_batch_seen were equal to 0. In that case the number of trained-on batches is computed as add_batches = batch + 1 = 1, which is indeed the correct result.
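A minimal sketch of the counting logic described above. Only the names quoted in the description (_last_batch_seen, _should_save_on_batch, add_batches, save_freq) come from the actual ModelCheckpoint code; the surrounding class and the _batches_seen_since_last_saving accumulator are illustrative simplifications:

```python
class CheckpointCounter:
    """Illustrative stand-in for ModelCheckpoint's batch counting."""

    def __init__(self, save_freq):
        self.save_freq = save_freq
        self._batches_seen_since_last_saving = 0
        # Initializing this to 0 made the first call under steps_per_execution=N
        # count only N-1 batches; -1 makes the first call count N.
        self._last_batch_seen = -1

    def _should_save_on_batch(self, batch):
        """Returns True once `save_freq` trained-on batches have accumulated."""
        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # Batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False
```

For example, with steps_per_execution=2 and save_freq=2, the first call is _should_save_on_batch(1): with the old initialization of 0 it adds only 1 batch, whereas with -1 it adds 2 and correctly triggers a save.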
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Thanks for the PR! Please add a unit test for the change.
I'm new to writing unit tests for Keras. Should I create a new method named like …?
Yes, exactly. You can use a tmp folder for saving the checkpoints (see how it's done in other tests).
I have just added unit tests which correctly fail without the fix and pass with it.
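For reference, a hedged sketch of the kind of test that can catch this; the test and class names, model, and data sizes below are illustrative and not necessarily those added in the PR. With 8 batches per epoch, steps_per_execution=2, and save_freq=8, the checkpoint is written only if all 8 trained-on batches are counted, so the assertion fails with the old off-by-one counting:

```python
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras


class ModelCheckpointStepsPerExecutionTest(tf.test.TestCase):
    def test_save_freq_with_steps_per_execution(self):
        filepath = os.path.join(self.get_temp_dir(), "checkpoint.h5")
        model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
        # 8 batches per epoch, executed 2 at a time.
        model.compile(optimizer="sgd", loss="mse", steps_per_execution=2)
        x = np.random.random((16, 4))
        y = np.random.random((16, 1))
        cb = keras.callbacks.ModelCheckpoint(
            filepath, save_freq=8, save_weights_only=True
        )
        model.fit(x, y, batch_size=2, epochs=1, callbacks=[cb], verbose=0)
        # All 8 trained-on batches must be counted for the save to trigger;
        # with the old counting only 7 are seen and no checkpoint is written.
        self.assertTrue(os.path.exists(filepath))
```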
LGTM, thank you!
Fix ModelCheckpoint trained-on batch counting when using steps_per_execution>1

Imported from GitHub PR #17632

Copybara import of the project:

-- 8b9f81d by Maël A <86840696+jasnyj@users.noreply.github.com>:

Fix ModelCheckpoint trained-on batch counting

-- b342b3a by Maël A <86840696+jasnyj@users.noreply.github.com>:

Test ModelCheckpoint with steps_per_execution

Merging this change closes #17632

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17632 from jasnyj:master b342b3a
PiperOrigin-RevId: 516873881
Here are the internal errors, @jasnyj can you please verify? Thank you!
Traceback (most recent call last): …
The steps_per_execution argument to model.compile(...) is only available on Keras>=2.4.0. Unit tests that use this argument therefore cause errors in v1 mode and should not be run in that mode.
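One way to express that constraint, sketched here as a plain runtime guard; the PR itself presumably relies on the Keras test infrastructure's own mechanism for skipping v1 mode, so this snippet is only an illustration:

```python
import tensorflow as tf


class ModelCheckpointStepsPerExecutionTest(tf.test.TestCase):
    def setUp(self):
        super().setUp()
        # Hypothetical guard: skip when TF2 behavior is disabled (v1 mode),
        # since steps_per_execution is not supported there.
        if not tf.compat.v1.executing_eagerly_outside_functions():
            self.skipTest("steps_per_execution requires TF2 behavior.")
```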
@gbaned Indeed. I was, however, unable to reproduce the error locally (and therefore could not test whether it went away). (If you are wondering why I force-pushed: I made the last commit with a wrong email address configured and had to amend the commit to change it.)
LGTM, thanks
Fix ModelCheckpoint trained-on batch counting when using steps_per_execution>1

Imported from GitHub PR #17632

Copybara import of the project:

-- 8b9f81d by Maël A <86840696+jasnyj@users.noreply.github.com>:

Fix ModelCheckpoint trained-on batch counting

-- b342b3a by Maël A <86840696+jasnyj@users.noreply.github.com>:

Test ModelCheckpoint with steps_per_execution

-- d290db4 by Maël A <86840696+jasnyj@users.noreply.github.com>:

Do not run steps_per_execution tests in v1 mode

Merging this change closes #17632

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17632 from jasnyj:master d290db4
PiperOrigin-RevId: 557831739