
Adding to_tf_dataset method #2731

Merged: 46 commits from tf_dataset_conversion into master, Sep 16, 2021
Conversation

Rocketknight1 (Member)

Oh my god do not merge this yet, it's just a draft.

I've added a method (via a mixin) to the arrow_dataset.Dataset class that automatically converts our Dataset classes to TF Dataset classes ready for training. It hopefully has most of the features we want, including streaming from disk (no need to load the whole dataset into memory!), correct shuffling, variable-length batches to reduce compute, and correct support for unusual padding. It achieves this by calling the tokenizer's pad method in the middle of a TF compute graph via a very hacky call to tf.py_function, which is heretical but seems to work.
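For concreteness, a minimal sketch of that draft approach (hypothetical helper, not the actual PR code; it assumes the requested columns are already tokenizer outputs and all integer-typed):

```python
import tensorflow as tf

def to_tf_dataset_draft(dataset, tokenizer, columns, batch_size):
    def pad_batch(indices):
        # Runs eagerly inside the graph via tf.py_function: gather rows from the
        # Arrow-backed dataset and let the tokenizer pad them to a common length.
        batch = dataset[indices.numpy().tolist()]
        padded = tokenizer.pad(batch, return_tensors="np")
        return [padded[col] for col in columns]

    index_ds = (
        tf.data.Dataset.range(len(dataset))
        .shuffle(len(dataset))
        .batch(batch_size)
    )
    # Tout assumes integer columns (input_ids, attention_mask, ...).
    return index_ds.map(
        lambda idx: tf.py_function(pad_batch, inp=[idx], Tout=[tf.int64] * len(columns))
    )
```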

A number of issues need to be resolved before it's ready to merge, though:

  1. Is a mixin the right way to do this? Do other classes besides arrow_dataset.Dataset need this method too?
  2. Needs an argument to support constant-length batches for TPU training - this is easy to add and I'll do it soon.
  3. Needs the user to supply the list of columns to drop from the arrow Dataset. Is there some automatic way to get the columns we want, or see which columns were added by the tokenizer?
  4. Assumes the label column is always present and always called "label" - this is probably not great, but I'm not sure what the 'correct' thing to do here is.

@sgugger (Contributor) left a comment

This is very... PyTorchic XD
I like the design, I just think it can be made more general by using a data_collator instead of a tokenizer. My only concern is how it will go in terms of performance (since TF might not like the PyTorch-iness of it all), but since we're just grabbing tokenized texts and maybe padding, this shouldn't be too much of a problem.

For computer vision though, we should see if there is a way to make sure several processes are used to prepare the batches, or check whether the final map does that automatically. Knowing TF, I doubt it, but one can hope.
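For illustration, the collator idea in isolation (this assumes a transformers version where collators accept return_tensors, and a tokenizer variable in scope):

```python
from transformers import DataCollatorWithPadding

# A data_collator generalizes the tokenizer-only hook: the same callable can
# pad text, apply masking, or batch vision features.
collator = DataCollatorWithPadding(tokenizer, return_tensors="np")
batch = collator([dataset[i] for i in range(8)])  # dict of padded numpy arrays
```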

(Three inline review threads on src/datasets/arrow_dataset.py, since resolved.)
@Rocketknight1 (Member, Author)

This seems to be working reasonably well in testing, and performance is way better. tf.py_function has been dropped in favor of an input generator, but I moved as much of the code as possible outside the generator to allow TF to compile it correctly. I also avoid tf.RaggedTensor at all costs, and do the shuffle in the dataset followed by accessing sequential chunks, instead of shuffling an index tensor. The combination of all of these gives us a more flexible data loader as well as a ~20× performance boost compared to the first solution.
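Roughly, the new pipeline looks like this (a sketch under the same assumptions as before, not the exact PR code):

```python
import tensorflow as tf

def to_tf_dataset_v2(dataset, tokenizer, columns, batch_size, shuffle):
    # Shuffle the underlying dataset once, then read sequential chunks;
    # this avoids shuffling a large index tensor inside the TF graph.
    if shuffle:
        dataset = dataset.shuffle()

    def gen():
        for start in range(0, len(dataset), batch_size):
            batch = dataset[start : start + batch_size]
            padded = tokenizer.pad(batch, return_tensors="np")
            yield {col: padded[col] for col in columns}

    # Plain dense tensors in the signature: no tf.RaggedTensor anywhere.
    signature = {
        col: tf.TensorSpec(shape=(None, None), dtype=tf.int64) for col in columns
    }
    return tf.data.Dataset.from_generator(
        gen, output_signature=signature
    ).prefetch(tf.data.AUTOTUNE)
```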

@sgugger (Contributor) left a comment

Looking good! Just a few more comments on the API.

(Two inline review threads on src/datasets/arrow_dataset.py, since resolved.)
@Rocketknight1 (Member, Author)

I made a change to the TFFormatter in this PR that will need some changes to the tests, so I wanted to ping @lhoestq and anyone else before I made those changes.

The key problem is that, up until now, the TFFormatter has always returned RaggedTensor, created using the very slow tf.ragged.constant function. This is a big performance penalty, but it's also (imo) surprising for users: RaggedTensor is meant for tensors where one dimension has variable length. That's a good fit for tokenized datasets with variable sequence length, but it's an odd choice when the non-batch dimensions are constant, such as in image datasets, or in datasets where all samples are padded to the same length (e.g. for TPU training).

The change I made is to return standard Tensor objects instead of RaggedTensor when all the samples in the batch have the same shape, to fall back to fast RaggedTensor creation with tf.ragged.stack when they don't, and to use the very slow tf.ragged.constant function only as a last resort. I think this will match user expectations in most cases and greatly improve performance, but it's a (very slightly) breaking change, so any feedback is welcome!
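In pseudocode, the fallback chain is roughly this (a sketch of the idea, not the exact TFFormatter code):

```python
import numpy as np
import tensorflow as tf

def batch_to_tf(batch):
    try:
        # Uniform shapes: np.array succeeds, so return a plain dense Tensor.
        return tf.convert_to_tensor(np.array(batch))
    except ValueError:
        pass
    try:
        # Variable-length samples: tf.ragged.stack is much faster than
        # tf.ragged.constant because it stacks per-sample tensors directly.
        return tf.ragged.stack([tf.convert_to_tensor(sample) for sample in batch])
    except (ValueError, tf.errors.InvalidArgumentError):
        # Last resort for deeply nested or otherwise irregular data.
        return tf.ragged.constant(batch)
```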

@Rocketknight1 (Member, Author)

Also, I really can't emphasize enough how slow tf.ragged.constant is; it's bad enough to create a data pipeline bottleneck in more or less any training setup:

[screenshot: timing comparison illustrating the tf.ragged.constant bottleneck]
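The gap is easy to reproduce with a toy benchmark (illustrative only; absolute numbers depend on the machine):

```python
import timeit
import tensorflow as tf

rows = [[1] * n for n in range(1, 257)]  # a toy variable-length batch

t_constant = timeit.timeit(lambda: tf.ragged.constant(rows), number=10)
t_stack = timeit.timeit(
    lambda: tf.ragged.stack([tf.constant(r) for r in rows]), number=10
)
print(f"tf.ragged.constant: {t_constant:.3f}s  tf.ragged.stack: {t_stack:.3f}s")
```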

@lhoestq (Member) left a comment

I'm fine with this change to use tf tensors instead of ragged tensors, and it's nice to see that there is a fallback on ragged tensors anyway. I'm very impressed by the speed gains.

This is indeed a breaking change, but I agree with you that in the end it's the only way to get proper speed. It's also always better to get actual tensors rather than ragged tensors when possible.

The API looks fine to me :)

Maybe in the future people will be happy to have more control over the shuffling (setting the parameters to pass to Dataset.shuffle), but for now I think it's fine.

Comment on lines 219 to 220:

```python
columns,
batch_size,
shuffle,
```
Member

Could they be optional parameters?

Rocketknight1 (Member, Author)

I was thinking about that! It's unclear to me what the defaults should be for columns or batch_size, though, and I really wanted shuffle to be a required parameter to ensure people are aware of it, so that they don't accidentally shuffle, or skip shuffling, their data when they didn't mean to.

I could maybe set batch_size to something like 32 by default and leave the other two as required parameters?

@lhoestq (Member) commented Aug 30, 2021

Oh, I see your point about shuffle.
And actually, thinking more about it, it looks like we should require batch_size as well, no?

Maybe if columns is not specified, then all of them are used?

(These are just some random ideas; in the end we should just pick whatever fits the TF paradigm best.)

Rocketknight1 (Member, Author)

I was thinking about that, but usually our datasets have one or more string columns that we don't want, so defaulting to all columns will probably not work most of the time. It'd be nice if we had some way to auto-detect the relevant columns, but I can't think of how we'd do that, so I think the safest thing is to just ask users to specify them.
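For illustration, the kind of call this converges on, with all three parameters explicit (hypothetical column names):

```python
tf_train = dataset.to_tf_dataset(
    columns=["input_ids", "attention_mask", "label"],  # string columns left out
    batch_size=32,
    shuffle=True,
)
```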

@Rocketknight1 (Member, Author)

Hi @lhoestq, the tests have been modified and everything is passing. The Windows tests look to be failing for an unrelated reason, but other than that I'm ready to merge if you are!

@Rocketknight1 changed the title from "First draft of a method to auto-convert our datasets to TF datasets!" to "Adding to_tf_dataset method" on Sep 2, 2021
@lhoestq (Member) commented Sep 6, 2021

Hi @Rocketknight1! Feel free to merge master into this branch to fix and run the full CI :)

@Rocketknight1 (Member, Author)

@lhoestq rebased onto master and it looks good! I'm doing some testing with new notebook examples, but are you happy to merge if that looks good?

@mariosasko (Collaborator) left a comment

This feature seems super cool.

A few nits in terms of style:

(Three inline review threads on src/datasets/arrow_dataset.py, since resolved.)
@lhoestq (Member) left a comment

Thanks for pushing this :)
Feel free to add docstrings + type hints + tests.
Let me know if I can help you with this.

Also, what do you think of adding it to the documentation as well?

```python
# We assume that if you're shuffling it's the train set, so we drop the remainder unless told not to
drop_remainder = shuffle
dataset = self.remove_columns([col for col in self.features if col not in cols_to_retain])
dataset.set_format("python")
```
Member

Note that it is faster to use the numpy format rather than python, especially for tensors (there's a zero-copy conversion from the Arrow data to numpy).

Rocketknight1 (Member, Author)

Noted! I'll try to make it work in numpy format.
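The suggested change is small at the call site (a sketch of the idea):

```python
dataset.set_format("numpy")  # zero-copy views of the Arrow data for numeric columns
batch = dataset[:32]         # columns come back as np.ndarray rather than Python lists
```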

Comment on lines +303 to +347:

```python
cast_dtype = np.int64 if np.issubdtype(array.dtype, np.integer) else np.float32
array = array.astype(cast_dtype)
```
Member

Would this work for string types or nested types?

Rocketknight1 (Member, Author)

I've had some success with nested dtypes (in multiple-choice datasets). It does fail on string types, though: the tf.data.Dataset is intended to be passed straight to a model, so the assumption was that everything coming out of it would be convertible to a tf.Tensor. We could possibly make strings work in this context, but I'd need to think about a more generic approach to building the dataset and doing shape inference.
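A quick illustration of why string columns break the cast above:

```python
import numpy as np

arr = np.array(["hello", "world"])
np.issubdtype(arr.dtype, np.integer)  # False, so the code falls through...
arr.astype(np.float32)                # ...and this raises ValueError for strings
```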

Member

Ok! Maybe we can mention this in the docstring?

Member

I just mentioned in the docstring that only numeric data is expected :)

@lhoestq (Member) left a comment

Thanks a lot! I think the PR is ready to be merged now :)

After that we may want to update parts of the documentation:

  • add the method to the list of documented Dataset methods in main_classes.rst
  • update the demo google colab
  • update the tensorflow parts of the documentation

Are there other changes that you wanted to do before merging ?

@Rocketknight1 (Member, Author)

@lhoestq No, I'm happy to merge it as-is and add documentation afterwards!

@lhoestq (Member) left a comment

Perfect then :)

@Rocketknight1 Rocketknight1 merged commit fa09d37 into master Sep 16, 2021
@Rocketknight1 Rocketknight1 deleted the tf_dataset_conversion branch September 16, 2021 13:50
Successfully merging this pull request may close these issues:

  • Mutable columns argument breaks set_format