
[WIP] TensorFlow variant of DataCollatorForLanguageModeling. #12199

Closed
wants to merge 35 commits into from

Conversation

@aromans (Contributor) commented Jun 16, 2021

Co-authored-by: Dalton Walker <dalton_walker@icloud.com>

What does this PR do?

We didn't see any support for TensorFlow within the DataCollatorForLanguageModeling data class. Integrating directly with TensorFlow seems useful for TensorFlow users and avoids the need for tensor conversion.

This PR adds a TFDataCollatorForLanguageModeling data class that integrates directly with TensorFlow tensors and paves the way for further TFDataCollator conversions.

(Reopened PR #12179)

Before submitting

Who can review?

@LysandreJik @Rocketknight1 @sgugger

Anyone in the community is free to review the PR.

Co-authored-by: Dalton Walker <dalton_walker@icloud.com>
@sgugger (Collaborator) commented Jun 16, 2021

Thanks a lot for your PR!

Before I review more in detail, could you provide an example of use of this API? Data collators are very PyTorch-ic, so I want to make sure this is something that can actually be used in TensorFlow without too many contortions.

@aromans (Contributor Author) commented Jun 16, 2021

> Thanks a lot for your PR!
>
> Before I review more in detail, could you provide an example of use of this API? Data collators are very PyTorch-ic, so I want to make sure this is something that can actually be used in TensorFlow without too many contortions.

Absolutely! We are currently pretraining BERT on a custom dataset in a domain-specific language. To do this we are using the TFBertForPreTraining model (https://huggingface.co/transformers/model_doc/bert.html#tfbertforpretraining) together with a custom-trained tokenizer.
Specifically, we started with the collator for language modeling to make our training data consistent with the MLM and NSP tasks. The collator provided that functionality, along with batching, but only for PyTorch.
We wanted to bring the functionality that exists for PyTorch to TensorFlow users, and we plan on completing the entire API for TensorFlow support if desired.
If you need specific implementation details, we are happy to expand further.
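
For illustration, a minimal sketch of that setup (the tokenizer path and config here are hypothetical, not taken from this PR):

from transformers import BertConfig, BertTokenizerFast, TFBertForPreTraining

# Hypothetical custom-trained tokenizer for the domain-specific language.
tokenizer = BertTokenizerFast.from_pretrained("path/to/custom-tokenizer")

# Fresh BERT weights for pretraining from scratch; vocab size follows the tokenizer.
config = BertConfig(vocab_size=tokenizer.vocab_size)
model = TFBertForPreTraining(config)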

@sgugger (Collaborator) commented Jun 16, 2021

Do you have an example of data preprocessing a bit similar to the run_mlm script we have in PyTorch? That would be helpful to see this TF data collator in action.

@aromans (Contributor Author) commented Jun 16, 2021

> Do you have an example of data preprocessing a bit similar to the run_mlm script we have in PyTorch? That would be helpful to see this TF data collator in action.

We are going to move this PR into a WIP so we can address your question.

@aromans aromans changed the title TensorFlow variant of DataCollatorForLanguageModeling. [WIP] TensorFlow variant of DataCollatorForLanguageModeling. Jun 16, 2021
@sdwalker62 (Contributor)

In answer to your question @sgugger, our objective is to integrate the collator with TFTrainer. Currently, PyTorch users enjoy this functionality, but TensorFlow users lack the equivalent built-in support (unless we are mistaken, in which case we apologize). Our idea is to implement the following change in TFTrainer.get_train_tfdataset:

if tf_collate_fn is None:
    ds = (
        self.train_dataset.repeat()
        .shuffle(self.num_train_examples, seed=self.args.seed)
        .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
else:
    ds = (
        self.train_dataset.repeat()
        .shuffle(self.num_train_examples, seed=self.args.seed)
        .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
        .map(tf_collate_fn)  # apply the collator to each batch
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

or we could implement the dataset conversion in the collator:

if tf_collate_fn is not None:
    ds = tf_collate_fn(ds)
else:
    ds = (
        self.train_dataset.repeat()
        .shuffle(self.num_train_examples, seed=self.args.seed)
        .batch(self.total_train_batch_size, drop_remainder=self.args.dataloader_drop_last)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

This would provide an avenue for TensorFlow users to train any models requiring collator functionality in TFTrainer.

Any advice or alternative solutions are welcome!

@sgugger (Collaborator) commented Jun 17, 2021

We plan to drop TFTrainer pretty soon in favor of using Keras, but this could still be useful, as we will still rely on the datasets.
I think the best API would be to apply it to a TensorFlow dataset, but @Rocketknight1 might have other views.

@Rocketknight1 (Member)

Our intention is to drop TFTrainer to do training through Keras instead, and as a result in TF we want the input to come from tf.data.Dataset objects rather than custom collators.

A lot of things like multi-GPU or TPU training in Keras expect tf.data.Dataset input, and will coerce the input into a Dataset if you don't supply it as one.

@sdwalker62 (Contributor)

@Rocketknight1 Understood. So providing a collator that could be passed to Dataset.map is the way to go if we want the option. Or are you saying that such an operation should be performed before TFTrainer?

I just want to clarify before we continue with a PR.

@Rocketknight1 (Member)

We want to avoid TFTrainer entirely in the future, so yes - any kind of custom collator should return a Dataset, or should work through Dataset.map(). This is something we're in the process of updating throughout our library - there are still a lot of usages of TFTrainer that I'm cleaning up over time!
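
For concreteness, a minimal sketch of the Dataset.map() route being suggested (collate_fn is a stand-in for whatever interface this PR settles on, not the actual collator):

import tensorflow as tf

def collate_fn(batch):
    # Stand-in for the collator's __call__: mask tokens, build labels, etc.
    return batch

dataset = tf.data.Dataset.from_tensor_slices(
    {"input_ids": tf.constant([[101, 2054, 102], [101, 2003, 102]])}
)
dataset = dataset.batch(2).map(collate_fn).prefetch(tf.data.experimental.AUTOTUNE)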

@sdwalker62 (Contributor)

Thank you for your quick response!

We will continue with the PR, going down the .map route. Even though TFTrainer is being deprecated, some may still find it useful in the meantime.

Cheers!

aromans and others added 2 commits July 14, 2021 12:07
Co-authored-by: Dalton Walker <dalton_walker@icloud.com>
@github-actions bot commented Aug 8, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Rocketknight1 (Member)

Hey! This isn't something I want to go stale, but I lost track of it when I saw you were still adding commits! Are you happy with it as-is, and ready for a review?

@sdwalker62 (Contributor)

That is no problem! And we are ready for a review at your convenience.

@Rocketknight1 (Member)

Hi! I'm reviewing now. This is actually quite timely - we're planning a general revamp of all the data collators to support both TensorFlow and JAX, as well as support for our Dataset objects to automatically convert to tf.data.Dataset, which will almost certainly include the new data collation functions as part of the tf.data pipeline.

The downside is that we haven't decided how exactly to structure the code yet, so we might ask you to move or rename this class, but hopefully we can use almost all of the code here as part of the revamp!

self.mlm_probability = 0.15

def __post_init__(self):
    if self.tokenizer.mask_token is None:
Member

I can see this collator only supports MLM rather than CLM, so we might need to either include CLM support, or rename it to something like TFDataCollatorForMaskedLanguageModeling and have a separate CLM collator. My preference would be for the former, but I realize that might cause some code complexity.

Contributor Author

This is something we've discussed ourselves as well. We wanted to submit what we had to make sure you and the team agreed with our approach, and we had planned on finishing out the TFDataCollator suite so it is on par with the PyTorch data collators. If a CLM collator is the highest priority, that can be the next data collator @sdwalker62 and I tackle.


@tf.function
def pseudo_bernoulli(self, prob_matrix, labels):
    return tf.cast(prob_matrix - tf.random.uniform(tf.shape(labels), 0, 1) >= 0, tf.bool)
Member

This is a clever hack to get around the lack of tf.random.bernoulli, lol.
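
(As a quick sanity check of that trick, not code from this PR: each element comes out True with the probability given by prob_matrix, i.e. a Bernoulli draw built from a uniform sample.)

import tensorflow as tf

prob_matrix = tf.fill((1000, 1000), 0.15)
draws = tf.cast(prob_matrix - tf.random.uniform((1000, 1000), 0, 1) >= 0, tf.bool)
print(tf.reduce_mean(tf.cast(draws, tf.float32)))  # ~0.15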

@tf.function
def mask_special_tokens(self, labels, special_tokens):
    # Finds all special tokens within labels
    x = tf.map_fn(lambda b: tf.cast(tf.math.equal(labels, b), tf.int32), special_tokens)
Member

I believe it should be possible to do this without needing map_fn and a lambda, which can be slow - just use tf.math.equal. You can pass multiple elements to that, which will be broadcasted. So if you have e.g. 3 special tokens, then you'll get an extra dimension of length 3 in the output, and you can just squash that with tf.math.reduce_any() to see if any of the special tokens matched, without needing a cast to tf.int32.

Contributor Author

We attempted something similar before taking this approach; this was definitely our last resort 😆. We found tf.math.equal doesn't accept incompatible shapes, since the special tokens mask usually doesn't match the dimensionality of the labels. Unless there is an additional step we are missing?

Here is a Google Colab notebook with the aforementioned warning: https://colab.research.google.com/drive/1n8C72CGlbgr5R9bN_3BaRDFkMgADFlr7?usp=sharing
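
(For reference, one possible "additional step" that makes the broadcast work is an explicit new axis before the reduce_any; an illustrative sketch, not code from the PR:)

import tensorflow as tf

labels = tf.constant([[101, 2054, 102], [101, 2003, 102]])
special_tokens = tf.constant([101, 102])

# (batch, seq, 1) against (n_special,) broadcasts to (batch, seq, n_special)
matches = tf.math.equal(labels[:, :, tf.newaxis], special_tokens)
special_tokens_mask = tf.math.reduce_any(matches, axis=-1)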

@aromans (Contributor Author) commented Aug 10, 2021

> Hi! I'm reviewing now. This is actually quite timely - we're planning a general revamp of all the data collators to support both TensorFlow and JAX, as well as support for our Dataset objects to automatically convert to tf.data.Dataset, which will almost certainly include the new data collation functions as part of the tf.data pipeline.
>
> The downside is that we haven't decided how exactly to structure the code yet, so we might ask you to move or rename this class, but hopefully we can use almost all of the code here as part of the revamp!

That is perfect, we are glad we could help out! We will happily move/rename/or restructure the code in any way that best suits your revamp and the rest of your codebase 😄

    return tf.math.greater(tf.reduce_sum(x, axis=0), 0)

@tf.function
def tf_pad_tokens(
Member

This looks like it's just reproducing tokenizer.pad - can I ask why you want to do it here instead of calling that function? So you can compile the whole thing into tf.function or a tf.data graph?
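
(For reference, the built-in path being suggested looks roughly like this; the checkpoint name is illustrative:)

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
batch = tokenizer.pad(
    {"input_ids": [[101, 2054, 102], [101, 2003, 2009, 102]]},
    padding=True,
    return_tensors="tf",
)
print(batch["input_ids"].shape)  # (2, 4): shorter sequence padded with pad_token_id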

Contributor Author

Not long after posting this PR we ended up doing exactly that. We will be sure to upload this change after you finish your review.

Member

Review is done! I'm ignoring the tests for now, since I realize those might change as we change the rest. Feel free to make any changes you like now, and I'll take another look tomorrow.

encoded_batch["input_ids"], encoded_batch["labels"] = self.tf_mask_tokens(padded_output)
return encoded_batch

@tf.function
Member

We've had issues where ragged tensors have extremely poor performance, particularly when creating them from input lists or similar - we're considering building our pipelines so that we do all the padding before converting to tf.Tensor as a result. Don't remove this code yet, but just be aware that we may want to avoid tf.ragged as much as possible in our final implementation.

Contributor Author

We are both in agreement. Not long after submitting the PR, we ran into additional ragged-tensor issues that we didn't initially catch when integrating the collator. We decided to do the work to square off the tensors before passing them into the collator.

@Rocketknight1 Rocketknight1 self-assigned this Aug 10, 2021
else:
    return examples.to_tensor(0)

def __call__(self, examples: tf.data.Dataset, is_ragged=False) -> tf.data.Dataset:
Member

Whoops, I overlooked this bit! My intuition is that we should probably not be taking tf.data.Dataset as input, but instead writing this as a function we can call with either dataset.map() or with data as tensors or a dict of tensors. We don't want to lock people into tf.data if they're trying to write eager code, or doing their own data input.

Contributor

Good catch! Originally we wrote this as a function that took in a tf.data.Dataset due to the inclusion of ragged tensors. Now that we are planning on squaring off the tensors before inputting them, we can revisit a mapping approach.

@Rocketknight1 (Member)

So I've been thinking this over a bit more - my guess is that tokenizer.pad probably cannot/shouldn't be compiled with tf.function. It's effectively a totally arbitrary function, and every new model we add might have a different one, so we couldn't make any guarantee that AutoGraph will play nicely with it, even though in testing it seemed to work for me on a few common cases. For the same reasons, we shouldn't try to reimplement tokenizer.pad like you did with tf_pad_tokens, because at any moment a model could come along that would require a fresh rewrite of that.

Given that we need to call a block of arbitrary Python code, that means we can't guarantee that the collation function will be compilable with tf.function or Dataset.map(), although we could still use it in a tf.data pipeline by either using it when the data is loaded with from_generator, or wrapping it in py_function to allow it to be used in Dataset.map().

I think we should go for the following:

  1. The function should take input as either tf.Tensor or nested (possibly variable-length) lists. It could optionally accept np.ndarray or tf.ragged.RaggedTensor too.
  2. No tf.function anywhere - code is pure Python
  3. We can possibly have some kind of 'master' function that takes an argument like return_tensors and will call the framework-specific collators based on the argument value, but this is something we can implement later.

That's a lot of changes, though I'm hopeful we could keep a lot of your code here as-is. Do you think it makes sense, or do you have any objections to any of it?
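
For concreteness, a sketch of the py_function route mentioned above (python_collate is a stand-in for the eventual pure-Python collator, not the actual implementation):

import tensorflow as tf

def python_collate(input_ids):
    # Arbitrary Python code is fine here, e.g. tokenizer.pad plus masking.
    return input_ids

def tf_wrapper(input_ids):
    # py_function lets the pure-Python collator run inside Dataset.map().
    return tf.py_function(python_collate, [input_ids], Tout=tf.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices(tf.constant([[101, 2054, 102], [101, 2003, 102]]))
    .batch(2)
    .map(tf_wrapper)
)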

@Rocketknight1 (Member)

In the meantime, I'm going to be working on this too - I'll take a different DataCollator class and try to write a TF equivalent of it tomorrow. If I run into any issues there I'll let you know.

@Rocketknight1 (Member)

Hey, I've rewritten a few of the classes in our preferred style, but left the language modelling ones alone for now, you can see them here: #13105

We'd like to push ahead with this fairly soon, so if you'd like, you can try adjusting this PR to a similar style. If not, we can close this PR and I'll add the rest to my PR tomorrow. Either way, thank you for the contribution - whether or not we use the code directly, this PR was helpful in drawing our attention to the problem and to possible approaches for writing data collators that support frameworks besides Torch!

@aromans (Contributor Author) commented Aug 12, 2021

> Hey, I've rewritten a few of the classes in our preferred style, but left the language modelling ones alone for now, you can see them here: #13105
>
> We'd like to push ahead with this fairly soon, so if you'd like, you can try adjusting this PR to a similar style. If not, we can close this PR and I'll add the rest to my PR tomorrow. Either way, thank you for the contribution - whether or not we use the code directly, this PR was helpful in drawing our attention to the problem and to possible approaches for writing data collators that support frameworks besides Torch!

This afternoon we started finalizing and adding some of the changes you suggested in another branch. Once done, we will also adjust the code to match the preferred style shown in your new PR. We can merge those changes into this PR, and you can feel free to use this code in your PR or as a starting point for your revisions. Either way, no hard feelings, and we are glad we could help out in any way!

@Rocketknight1 (Member)

I'm happy for you to submit your code, and I'll avoid any classes you're touching when I make my own PR! Which ones would you like to handle?

@Rocketknight1 (Member)

Hey! We'd like to push to get this in soon, so we can proceed with a general overhaul of our TF data pipelines. At the same time, I know you're contributing code for free, and the rush is mostly caused by my own disorganization, so I don't want to force deadlines on you or anything!

We'd like to move on and merge everything by Monday, so if you want to add any code today or this weekend, I'll grab it at that point and pull it into my PR. If not, then don't worry - what you've added up to now will already be quite helpful for the final PR, and we'll make sure that both of you get correct author/contributor credits for it regardless!

@sdwalker62 (Contributor) commented Aug 13, 2021

Hey there! 😃 We just made some code changes to integrate more closely with your style and had all of our tests pass. We are finishing up lunch and then will go through a final review before updating the PR.

@aromans (Contributor Author) commented Aug 13, 2021

@sdwalker62 and I just pushed up our revisions based on your review and recent PR. We changed the name of the file to TFDataCollatorForMaskedLanguageModeling. Hopefully, this helps with your upcoming merge this Monday! Let us know if you need anything else, and we look forward to contributing to more things in the future! 😄

@Rocketknight1 (Member)

Thank you! We're just finishing off an upstream PR to Datasets, at which point I'll be merging your code into the other DataCollator PR and getting the rest of the team to review it.

@Rocketknight1 (Member)

Hey, just to update you: The code has been incorporated into my local copy, and I'm working on adding some other methods we need before I push it all to the other PR. I'll tag you as soon as that commit is in!

@Rocketknight1 (Member)

Code is all in at #13105. I'm very likely to steal some of the test code from this PR too once we incorporate tests for all the classes, so I'll make sure you're acknowledged as contributors for that too!

@github-actions bot

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Sep 20, 2021