
Switch from using sum for flattening lists of lists in group_texts #14472

Merged: 4 commits merged into huggingface:master from nbroad1881:fastest_list_flatten on Nov 22, 2021

Conversation

nbroad1881
Contributor

Speed up list flattening in group_texts by changing sum(list_of_lists, []) to functools.reduce(operator.iconcat, list_of_lists, [])

I changed all list flattening from sum(list_of_lists, []) to functools.reduce(operator.iconcat, list_of_lists, []).

Here is a Stack Overflow answer comparing the speed of the flattening methods: https://stackoverflow.com/a/45323085

Here is a Colab notebook with a quick comparison between the old way and the new way, plus a couple of timed examples. The new way is about 5-6x faster: https://colab.research.google.com/drive/1Kxj_JbM9HMLFpjUduy6i3tfqDob_pYIp?usp=sharing
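For readers who haven't seen either idiom, here is a minimal sketch (toy data, my own illustration, not taken from the PR or the notebook) of the before and after, with a note on why sum() is slow:

```python
import functools
import operator

# Toy stand-in for the tokenized sequences handled by group_texts.
list_of_lists = [[1, 2], [3, 4], [5, 6]]

# Old way: sum() builds a brand-new list at every addition, so flattening
# n total elements costs O(n^2).
flat_old = sum(list_of_lists, [])

# New way: operator.iconcat(a, b) performs a += b in place and returns a,
# so the reduce extends one accumulator list and runs in O(n).
flat_new = functools.reduce(operator.iconcat, list_of_lists, [])

assert flat_old == flat_new == [1, 2, 3, 4, 5, 6]
```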

I discovered this while trying to use group_texts on many GB of data, and the speedup was greatly appreciated.

Nearly all of these changes are in the run_mlm or run_clm examples, but there are a couple in run_swag and another in file_utils.py, which might be unnecessary.

I don't know why make style is moving import functools to its own line above the other imports in examples/flax/language-modeling/run_t5_mlm_flax.py and examples/tensorflow/language-modeling/run_clm.py

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

I think @sgugger wrote the original group_texts.

@sgugger
Collaborator

sgugger commented Nov 20, 2021

Thanks for investigating the performance of that line! I initially used sum() because it's easy to read. functools.reduce(operator.iconcat, list_of_lists, []) is far less readable, but it gains so much speed that switching may be a good thing.

Wdyt @LysandreJik ?

@nbroad1881
Contributor Author

nbroad1881 commented Nov 21, 2021

I was actually confused at first because I didn't know that sum() could be used to flatten a list of lists like that. Perhaps a comment explaining that it flattens the list of lists into one list would suffice? Something along the lines of: "This concatenates all sequences together by flattening the list of lists."

Alternatively, would chain.from_iterable(list_of_lists) or chain(*list_of_lists) seem more readable?
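For comparison, a quick sketch (toy data, my own illustration) of the three spellings of the same flatten:

```python
from itertools import chain

list_of_lists = [["a", "b"], ["c"], ["d", "e"]]
expected = ["a", "b", "c", "d", "e"]

# sum() "adds" the sublists together starting from [] -- it does flatten,
# but that reading is not obvious at a glance.
assert sum(list_of_lists, []) == expected

# The itertools spellings return lazy iterators, hence the list() wrapper.
assert list(chain.from_iterable(list_of_lists)) == expected
assert list(chain(*list_of_lists)) == expected
```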

@sgugger
Collaborator

sgugger commented Nov 21, 2021

I think chain(*list_of_lists) is the most readable, and is even clearer than the sum trick.
It makes almost no difference in time in your notebook, so let's go with this one?
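For context, roughly what group_texts looks like with the chain spelling (a sketch modeled on the run_clm examples, not a verbatim copy; block_size is assumed to come from the enclosing scope, and the labels handling of some scripts is omitted):

```python
from itertools import chain

block_size = 128  # illustrative value; the scripts derive it from the tokenizer/args

def group_texts(examples):
    # Concatenate each field's list of token lists into one long list.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the small remainder so every chunk is exactly block_size long.
    total_length = (total_length // block_size) * block_size
    # Split each field into chunks of block_size.
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
```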

@nbroad1881
Contributor Author

OK, I'll make the changes. Do you know why make style moved import functools onto its own line in examples/flax/language-modeling/run_t5_mlm_flax.py and examples/tensorflow/language-modeling/run_clm.py?

@LysandreJik
Member

LysandreJik left a comment

This looks good to me! chain(*examples) is indeed way clearer than functools.reduce(operator.iconcat, list_of_lists, []) and sum(examples, []).

Thanks for looking into it, @nbroad1881!

@sgugger
Collaborator

sgugger left a comment

LGTM as well, thanks for amending your PR! Let's just remove all those blank lines before merging.

Files with resolved review comments:

  • examples/flax/language-modeling/run_clm_flax.py
  • examples/flax/language-modeling/run_mlm_flax.py
  • examples/flax/language-modeling/run_t5_mlm_flax.py
  • examples/pytorch/language-modeling/run_clm.py
  • examples/pytorch/multiple-choice/run_swag.py
  • examples/tensorflow/language-modeling/run_clm.py
@nbroad1881
Contributor Author

nbroad1881 commented Nov 22, 2021

I did a couple more tests in this notebook: https://colab.research.google.com/drive/1Kxj_JbM9HMLFpjUduy6i3tfqDob_pYIp

One way to improve readability would be to make a utility function like this: ravel = functools.partial(functools.reduce, operator.iconcat, []), so you could just use ravel(x) inside group_texts.

Edit: That actually didn't work. This works: ravel = functools.partial(functools.reduce, operator.iconcat), so you would use ravel(x, []) inside group_texts. ravel would mirror what torch.ravel and np.ravel do. I'm not sure if that is more or less confusing. This method actually performed the fastest out of everything I tried, though it is the same with or without partial.
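The failure mode of the first attempt is worth spelling out: functools.partial binds positional arguments left to right, so a pre-bound [] lands in reduce's iterable slot rather than its initializer slot. A sketch (my own, not from the notebook):

```python
import functools
import operator

x = [[1, 2], [3, 4]]

# Broken: partial(reduce, iconcat, [])(x) calls reduce(iconcat, [], x),
# i.e. [] is the iterable and x the initializer. Reducing an empty
# iterable just returns the initializer, so nothing is flattened.
broken_ravel = functools.partial(functools.reduce, operator.iconcat, [])
assert broken_ravel(x) == [[1, 2], [3, 4]]

# Working: bind only the function and pass the initializer per call,
# so ravel(x, []) calls reduce(iconcat, x, []).
ravel = functools.partial(functools.reduce, operator.iconcat)
assert ravel(x, []) == [1, 2, 3, 4]
```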

Here is a summary of the methods when using group_texts on SQuAD contexts, where x is a list of lists (each time is the best of 5 runs, except sum, which is the best of 3); a sketch of a comparable timing harness follows the list:

  1. sum(x, []) - 2min 47s
  2. list(chain.from_iterable(x)) - 27.2 s
  3. list(chain(*x)) - 27.1 s
  4. functools.reduce(operator.iconcat, x, []) - 26.8 s
  5. functools.partial(functools.reduce, operator.iconcat)(x, []) - 26.8 s
  6. np.ravel(x) - 28.9 s
  7. [b for a in x for b in a] - 28.4 s
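A timeit harness in the same spirit, on synthetic data (the notebook's SQuAD contexts and exact sizes are not reproduced here, so the absolute numbers will differ):

```python
import functools
import operator
import timeit
from itertools import chain

# Synthetic stand-in for the benchmark input: 2,000 sublists of 100 ints each.
x = [list(range(100)) for _ in range(2_000)]

methods = {
    "sum(x, [])": lambda: sum(x, []),
    "list(chain.from_iterable(x))": lambda: list(chain.from_iterable(x)),
    "list(chain(*x))": lambda: list(chain(*x)),
    "reduce(iconcat, x, [])": lambda: functools.reduce(operator.iconcat, x, []),
    "[b for a in x for b in a]": lambda: [b for a in x for b in a],
}

for name, fn in methods.items():
    best = min(timeit.repeat(fn, number=1, repeat=5))  # best of 5 runs
    print(f"{name:>30}: {best * 1000:.1f} ms")
```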

@sgugger
Collaborator

sgugger commented Nov 22, 2021

I think that option 3 (list(chain(*x))) is the best compromise in terms of readability vs. speed (only ~0.3 s slower than the best run, which might also be dataset-dependent).

Thanks a lot for benchmarking all options!

Commit "per sgugger's suggestions" (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
@sgugger sgugger merged commit 69e16ab into huggingface:master Nov 22, 2021
@nbroad1881 nbroad1881 deleted the fastest_list_flatten branch November 23, 2021 01:34