Conversation
younesbelkada left a comment
Thanks a lot for this great addition! I left a few comments and questions as a first pass!
    mini_batch_data,
    batch_size=self.config.mini_batch_size,
    shuffle=True,
    collate_fn=collator,
    collate_fn=collator,
    drop_last=True,
Maybe we can add this to avoid some corner cases, such as the one described in a previous issue.
Sounds good, let's also set a warning if that's the case so the user knows that a batch will be dropped.
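A minimal sketch of what this could look like (the helper name and signature here are illustrative, not TRL's actual API), assuming the current PPO batch is wrapped in a torch Dataset:

```python
# Illustrative sketch: build the mini-batch DataLoader with drop_last=True
# and warn when leftover samples will be dropped.
import warnings
from torch.utils.data import DataLoader

def build_mini_batch_dataloader(mini_batch_data, mini_batch_size, collator):
    remainder = len(mini_batch_data) % mini_batch_size
    if remainder != 0:
        warnings.warn(
            f"{remainder} samples do not fit into a full mini-batch of size "
            f"{mini_batch_size} and will be dropped (drop_last=True)."
        )
    return DataLoader(
        mini_batch_data,
        batch_size=mini_batch_size,
        shuffle=True,
        collate_fn=collator,
        drop_last=True,  # avoids corner cases caused by a smaller final mini-batch
    )
```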
    - bs = self.config.batch_size
    - fbs = self.config.forward_batch_size
    + bs = len(queries)
    + fbs = min(bs, self.config.forward_batch_size)
So is this the case where the last batch has fewer instances than the mini_batch_size, or the case where a user sets a batch_size that is smaller than mini_batch_size in the config? If it's the second case we could maybe add a warning in the config; if it's the first case, since we have drop_last=True set here I don't think we'll face it, but I am not sure.
It's for the case where mini_batch_size is smaller than forward_batch_size during the forward passes inside the minibatch loop. I am also not quite happy with how we do it actually.
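For context, a rough illustration of the clamping (the helper and argument names below are hypothetical, not the actual batched_forward_pass implementation): the forward passes inside the mini-batch loop are chunked by forward_batch_size, so the chunk size has to be clamped to the number of queries actually present.

```python
# Illustrative sketch only: chunk a mini-batch of queries into forward passes.
import torch

def chunked_forward(model_fn, input_ids, forward_batch_size):
    bs = len(input_ids)                # actual number of queries in this mini-batch
    fbs = min(bs, forward_batch_size)  # clamp so a chunk never exceeds the batch itself
    outputs = []
    for i in range(0, bs, fbs):
        # model_fn is assumed to return a tensor with the batch dimension first
        outputs.append(model_fn(input_ids[i : i + fbs]))
    return torch.cat(outputs, dim=0)
```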
younesbelkada left a comment
Also, what about completely removing forward_batch_size from the config? I don't think this is a breaking change as the configs cannot be pushed to the Hub; we just need to update the examples accordingly. I believe this can be done in a follow-up PR too.
The breaking change actually also happens for users who currently use the library with `forward_batch_size`.
This solution makes a lot of sense, yes!
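A hedged sketch of what such a deprecation path could look like (the class and field handling below are illustrative, not necessarily how TRL's PPOConfig does it): keep accepting forward_batch_size for now, but warn and map it onto mini_batch_size.

```python
# Illustrative deprecation sketch; not TRL's actual PPOConfig.
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class PPOConfigSketch:
    batch_size: int = 256
    mini_batch_size: int = 1
    forward_batch_size: Optional[int] = None  # deprecated, kept for backward compatibility

    def __post_init__(self):
        if self.forward_batch_size is not None:
            warnings.warn(
                "`forward_batch_size` is deprecated, use `mini_batch_size` instead.",
                DeprecationWarning,
            )
            self.mini_batch_size = self.forward_batch_size
```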
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
younesbelkada left a comment
Thanks a lot for your great work on this! 💯
* add minibatching
* all the fixes i missed
* ore fixes
* add dedicated variable for mini batch size
* style
* minor fixes
* fix rewards
* unbiased variance estimation
* mask values/returns
* moar fixes
* style
* change structure and add moar tests
* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* deprecate `forward_batch_size`
* remove out of date warning about batching s2s and left padding models
* make style
* fixed failed merge
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Until now the PPO mini-batch size has been hardcoded to 1. This PR aims to change that by refactoring the forward/backward pass logic.
In summary this PR does the following things:
* `batched_forward_pass` returns a new `mask` which can be used to mask parts of the sequence to be ignored
* a `dataloader` with the `mini_batch_size` is used to sample from the current PPO batch
* in the `loss` method we replace all operations affected by masked parts of the sequence with masked ones (`masked_mean`, `masked_whiten`); see the sketch after this list
* `compute_logits_vpred` is removed and `batched_forward_pass` is used for everything

W&B logs:
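As a rough illustration of the masked operations mentioned above (a simplified sketch with assumed signatures, not necessarily TRL's exact helpers, e.g. without the unbiased variance correction mentioned in the commit list):

```python
# Simplified sketches of masked statistics: only unmasked positions contribute.
# `mask` is assumed to be a 0/1 tensor with the same shape as `values`.
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return (values * mask).sum() / mask.sum()

def masked_var(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mean = masked_mean(values, mask)
    return masked_mean((values - mean) ** 2, mask)

def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    # Normalize to zero mean / unit variance over unmasked positions only.
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    return whitened if shift_mean else whitened + mean
```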