[WIP] TensorFlow variant of DataCollatorForLanguageModeling. #12199
Conversation
Co-authored-by: Dalton Walker <dalton_walker@icloud.com>
Thanks a lot for your PR! Before I review in more detail, could you provide an example of how this API is used? Data collators are very PyTorch-ic, so I want to make sure this is something that can actually be used in TensorFlow without too many contortions.
Absolutely! We are currently in the process of pretraining BERT with a custom dataset in a domain-specific language. We are going to make use of the TFBertForPreTraining model (https://huggingface.co/transformers/model_doc/bert.html#tfbertforpretraining) to achieve this, as well as a custom-trained tokenizer.
Do you have an example of data preprocessing a bit similar to the run_mlm script we have in PyTorch? That would be helpful to see this TF data collator in action.
We are going to mark this PR as a WIP so we can address your question.
In answer to your question @sgugger, our objective is to integrate the collator with TFTrainer. Currently PyTorch users enjoy this functionality, but TensorFlow users do not have the built-in support that deserves to be there (unless we are mistaken, in which case we apologize). Our idea is to implement the following change in TFTrainer/get_train_tfdataset:
or we could implement the dataset conversion in the collator:
This would provide an avenue for TensorFlow users to train any models requiring collator functionality in TFTrainer. Any advice or alternative solutions are welcome!
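The snippets referenced in the comment above are not preserved in this transcript. As a rough, hypothetical sketch of the callable-collator pattern being discussed (names, the stubbed masking step, and the mask-token id are all assumptions, written in plain Python so the same callable could be applied eagerly or handed to tf.data.Dataset.map):

```python
# Hypothetical sketch: a collator as a plain callable mapping a batch dict
# to a batch dict. In TF it could then be applied either eagerly or via
# tf.data.Dataset.map; the masking step is stubbed out for brevity.
MASK_TOKEN_ID = 103  # assumption: BERT-style [MASK] id

def collate(batch):
    input_ids = list(batch["input_ids"])
    labels = list(input_ids)          # labels start as a copy of the inputs
    input_ids[1] = MASK_TOKEN_ID      # stub: mask one position deterministically
    return {"input_ids": input_ids, "labels": labels}

batch = collate({"input_ids": [101, 7, 8, 102]})
```

A real implementation would of course mask positions randomly rather than deterministically; the point is only the dict-in/dict-out shape of the callable.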
We plan to drop TFTrainer pretty soon in favor of using Keras, but this could still be useful as we will still rely on the datasets.
Our intention is to drop TFTrainer and do training through Keras instead, and as a result in TF we want the input to come from tf.data.Dataset objects rather than custom collators. A lot of things like multi-GPU or TPU training in Keras expect tf.data.Dataset input, and will coerce the input into a Dataset if you don't supply it as one.
@Rocketknight1 Understood. So providing a collator that could be passed to Dataset.map is the way to go if we want the option. Or are you saying that such an operation should be performed before TFTrainer? I just want to clarify before we continue with the PR.
We want to avoid TFTrainer entirely in the future, so yeah - any kind of custom collator should return a Dataset, or should work through Dataset.map(). This is something we're in the process of updating throughout our library - there are still a lot of usages of TFTrainer that I'm cleaning up over time!
Thank you for your quick response! We will continue with the PR going down the .map route. Even though TFTrainer is being deprecated, some may still find it beneficial in the meantime. Cheers!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hey! This isn't something I want to go stale, but I lost track of it when I saw you were still adding commits! Are you happy with it as-is, and ready for a review?
That is no problem! And we are ready for a review at your convenience.
Hi! I'm reviewing now. This is actually quite timely - we're planning a general revamp of all the data collators to support both TensorFlow and JAX, as well as support for our Dataset objects to automatically convert to tf.data.Dataset. The downside is that we haven't decided how exactly to structure the code yet, so we might ask you to move or rename this class, but hopefully we can use almost all of the code here as part of the revamp!
```python
self.mlm_probability = 0.15

def __post_init__(self):
    if self.tokenizer.mask_token is None:
```
I can see this collator only supports MLM rather than CLM, so we might need to either include CLM support, or rename it to something like TFDataCollatorForMaskedLanguageModeling and have a separate CLM collator. My preference would be for the former, but I realize that might cause some code complexity.
This is something that we've discussed ourselves as well. We wanted to submit what we had to make sure that you and the team agreed with our approach, and had planned on finishing out the TFDataCollator suite so it is on par with the PyTorch data collators. If a CLM collator is the highest priority, that can be the next data collator @sdwalker62 and I tackle.
```python
@tf.function
def pseudo_bernoulli(self, prob_matrix, labels):
    return tf.cast(prob_matrix - tf.random.uniform(tf.shape(labels), 0, 1) >= 0, tf.bool)
```
This is a clever hack to get around the lack of tf.random.bernoulli, lol.
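The trick in pseudo_bernoulli above can be illustrated without TensorFlow: drawing a uniform sample in [0, 1) and comparing it against the probability is exactly a Bernoulli draw. A minimal pure-Python sketch (the function name mirrors the one above; the list-of-lists layout is an assumption for illustration):

```python
import random

def pseudo_bernoulli(prob_matrix):
    # For each entry p, random.random() < p holds with probability p,
    # mirroring prob_matrix - uniform(0, 1) >= 0 in the TF version.
    return [[random.random() < p for p in row] for row in prob_matrix]

# With p = 0.0 the draw is always False, and with p = 1.0 always True,
# since random.random() returns values in [0, 1).
mask = pseudo_bernoulli([[0.0, 1.0], [1.0, 0.0]])
```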
```python
@tf.function
def mask_special_tokens(self, labels, special_tokens):
    # Finds all special tokens within labels
    x = tf.map_fn(lambda b: tf.cast(tf.math.equal(labels, b), tf.int32), special_tokens)
```
I believe it should be possible to do this without needing map_fn and a lambda, which can be slow - just use tf.math.equal. You can pass multiple elements to that, which will be broadcast. So if you have e.g. 3 special tokens, then you'll get an extra dimension of length 3 in the output, and you can just squash that with tf.math.reduce_any() to see if any of the special tokens matched, without needing a cast to tf.int32.
We've attempted something similar before taking this approach; this was definitely our last resort 😆. We found tf.math.equal doesn't accept incompatible shapes, since the special tokens mask usually doesn't match the dimensionality of the labels. Unless there is an additional step we are missing?
Here is a Google Colab with the aforementioned warning: https://colab.research.google.com/drive/1n8C72CGlbgr5R9bN_3BaRDFkMgADFlr7?usp=sharing
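For what it's worth, the shape mismatch goes away if an explicit trailing axis is added before comparing, so the two shapes become broadcast-compatible. A sketch of that pattern, shown with NumPy (whose broadcasting rules match TensorFlow's); the token values are made up for illustration:

```python
import numpy as np

labels = np.array([[101, 5, 6, 102],
                   [101, 7, 102, 0]])
special_tokens = np.array([101, 102, 0])

# labels[..., None] has shape (2, 4, 1); comparing against shape (3,)
# broadcasts to (2, 4, 3): one boolean per (position, special token) pair.
matches = labels[..., None] == special_tokens

# Reduce over the last axis: True wherever any special token matched.
special_tokens_mask = matches.any(axis=-1)
```

In TF the same idea would use tf.math.equal on the expanded tensor followed by tf.math.reduce_any(..., axis=-1), with no map_fn and no integer cast.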
That is perfect, we are glad we could help out! We will happily move, rename, or restructure the code in any way that best suits your revamp and the rest of your codebase 😄
```python
    return tf.math.greater(tf.reduce_sum(x, axis=0), 0)

@tf.function
def tf_pad_tokens(
```
This looks like it's just reproducing tokenizer.pad - can I ask why you want to do it here instead of calling that function? So you can compile the whole thing into tf.function or a tf.data graph?
Not long after posting this PR we ended up doing exactly that. We will be sure to upload this change after you finish your review.
Review is done! I'm ignoring the tests for now, since I realize those might change as we change the rest. Feel free to make any changes you like now, and I'll take another look tomorrow.
```python
encoded_batch["input_ids"], encoded_batch["labels"] = self.tf_mask_tokens(padded_output)
return encoded_batch

@tf.function
```
We've had issues where ragged tensors have extremely poor performance, particularly when creating them from input lists or similar - as a result, we're considering building our pipelines so that we do all the padding before converting to tf.Tensor. Don't remove this code yet, but just be aware that we may want to avoid tf.ragged as much as possible in our final implementation.
We are both in agreement. Not long after submitting the PR, we ran into additional issues ourselves when trying to integrate the collator with ragged tensors that we didn't initially catch. We decided to do the work to square off the tensors before passing them into the collator.
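"Squaring off" the batch before it reaches the collator can be done with ordinary list padding, before anything is converted to a tf.Tensor, so no ragged tensor is ever created. A minimal pure-Python sketch (the function name and default pad id are assumptions):

```python
def pad_batch(sequences, pad_id=0):
    # Pad variable-length token lists to a rectangular batch so that
    # conversion to a dense tensor never needs tf.ragged; pad_id fills
    # the short rows up to the longest sequence in the batch.
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]

padded = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
```

The resulting rectangular list can be converted directly with tf.constant (or, in practice, produced by tokenizer.pad), sidestepping the ragged-tensor performance issues discussed above.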
```python
else:
    return examples.to_tensor(0)

def __call__(self, examples: tf.data.Dataset, is_ragged=False) -> tf.data.Dataset:
```
Whoops, I overlooked this bit! My intuition is that we should probably not be taking tf.data.Dataset as input, but instead writing this as a function we can call with either dataset.map() or with data as tensors or a dict of tensors. We don't want to lock people into tf.data if they're trying to write eager code, or doing their own data input.
Good catch! Originally we wrote this as a function that took in a tf.data.Dataset due to the inclusion of ragged tensors. Now that we are planning on squaring off the tensors before inputting them, we can revisit a mapping approach.
So I've been thinking this over a bit more. Given that we need to call a block of arbitrary Python code, we can't guarantee that the collation function will be compilable with tf.function. I think we should go for the following:
That's a lot of changes, though I'm hopeful we could keep a lot of your code here as-is. Do you think it makes sense, or do you have any objections to any of it?
In the meantime, I'm going to be working on this too - I'll take a different …
Hey, I've rewritten a few of the classes in our preferred style, but left the language modelling ones alone for now; you can see them here: #13105. We'd like to push ahead with this fairly soon, so if you'd like, you can try adjusting this PR to a similar style. If not, we can close this PR and I'll add the rest to my PR tomorrow. Either way, thank you for the contribution - whether or not we use the code directly, this PR was helpful in drawing our attention to the problem and to possible approaches for writing data collators that support frameworks besides Torch!
This afternoon we started finalizing and adding some of those changes you've suggested in another branch. Once done, we will also adjust the code to match your preferred style shown in your new PR. We can merge those changes into this PR, and you can feel free to use this code in your PR or as a starting point for your revisions. Either way, no hard feelings, and we are glad we could help out in any way!
I'm happy for you to submit your code, and I'll avoid any classes you're touching when I make my own PR! Which ones would you like to handle?
Hey! We'd like to push to get this in soon, so we can proceed with a general overhaul of our TF data pipelines. At the same time, I know you're contributing code for free, and the rush is mostly caused by my own disorganization, so I don't want to force deadlines on you or anything! We'd like to move on and merge everything by Monday, so if you want to add any code today or this weekend, I'll grab it at that point and pull it into my PR. If not, then don't worry - what you've added up to now will already be quite helpful for the final PR, and we'll make sure that both of you get correct author/contributor credits for it regardless!
Hey there! 😃 We just made some code changes to integrate more closely with your style and had all of our tests pass. We are finishing up lunch and then will go through a final review before updating the PR.
Merging review changes made by sdwalker62 and aromans
@sdwalker62 and I just pushed up our revisions based on your review and recent PR. We changed the name of the file to TFDataCollatorForMaskedLanguageModeling. Hopefully this helps with your upcoming merge this Monday! Let us know if you need anything else, and we look forward to contributing more in the future! 😄
Thank you! We're just finishing off an upstream PR to …
Hey, just to update you: the code has been incorporated into my local copy, and I'm working on adding some other methods we need before I push it all to the other PR. I'll tag you as soon as that commit is in!
Code is all in at #13105. I'm very likely to steal some of the test code from this PR too once we incorporate tests for all the classes, so I'll make sure you're acknowledged as contributors for that too!
Co-authored-by: Dalton Walker dalton_walker@icloud.com
What does this PR do?
We didn't see any support for TensorFlow within the DataCollatorForLanguageModeling data class. Integrating directly with TensorFlow seems useful for TensorFlow users and avoids the necessity for tensor conversion.
This PR adds a TFDataCollatorForLanguageModeling data class that integrates directly with TensorFlow tensors and paves the way for further TFDataCollator conversions.
(Reopened PR #12179)
Who can review?
@LysandreJik @Rocketknight1 @sgugger
Anyone in the community is free to review the PR.