Add LRU cache, add faster tokenization #37
Adding an LRU cache and speeding up tokenization.
Removing the _old method. Note that the Chinese token processing is optional and not currently used in training.
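For context, a minimal sketch of the caching idea (class and method names here are illustrative, not the exact code in this PR): the expensive step is the per-token BPE merge loop, and since the same words recur constantly in natural text, memoizing its result pays off.

from functools import lru_cache

class CachedBPESketch:
    # Illustrative only: memoize the per-token BPE computation.

    @lru_cache(maxsize=65536)  # bounded cache, least-recently-used eviction
    def bpe(self, token):
        # The real method repeatedly fuses the highest-priority byte
        # pair; elided here because only the caching is the point.
        return ' '.join(self._apply_merges(token))

    def _apply_merges(self, token):
        # placeholder for the actual merge algorithm
        return [token]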
Let's remove the Chinese code! Otherwise looks good! Some small modifications to improve the quality of the code, and I think we're good to go.
Also, I think it's safe to say that we won't be running Python 2. Typically that's enforced in setup.py.
def normalize_token_py3(self, token):
    token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
    ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
    return ret
Suggested change:

def normalize_token(self, token):
    # I think it's safe to say that sys.version_info[0] is constant across the whole run.
    if sys.version_info[0] == 2:
        token = ''.join(self.byte_encoder[ord(b)] for b in token)
    else:
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
    ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
    return ret
I think having the if statement outside the for loop is better.
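For what it's worth, a hedged sketch of the hoisted version (function names are illustrative, not the PR's actual code): pick the normalization function once, then call it unconditionally inside the loop.

import sys

def _normalize_py2(byte_encoder, token):
    # Python 2: iterating a str yields 1-char strings, hence ord().
    return ''.join(byte_encoder[ord(b)] for b in token)

def _normalize_py3(byte_encoder, token):
    # Python 3: iterating a bytes object yields ints directly.
    return ''.join(byte_encoder[b] for b in token.encode('utf-8'))

# Chosen once; sys.version_info cannot change mid-run.
_normalize = _normalize_py2 if sys.version_info[0] == 2 else _normalize_py3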
token = ''.join(self.byte_encoder[ord(b)] for b in token)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
Suggested change:

bpe_tokens.extend(self.normalize_token(token))
I think having the if statement outside the for loop is better.
I think a better version would be to just ignore that case, since Python 2 is not supported.
Yeah, but then it would be inconsistent, since we STILL have code that tests for Python 2.
Remove that code, and anywhere else similar branches are found?
Should I create another PR and close this one?
No, you can keep working on that branch, and the changes will be available in this PR directly.
Just a small comment: I recently processed the mC4 dataset. I found that on a machine with a fair amount of CPUs (~65-128), the given pre-processing is fairly fast. The processing speed is around 25-30 MiB/s.
Note: on JZ we have 40 CPU cores per node. Also note, we don't need to support Python 2. Almost no project does any longer.
@sbmaruf hopefully with #18 we should obtain higher speeds; it usually reaches 40-45 MB/s on JZ using 40 physical CPU cores. When @ontocord and I tried to combine this PR with #18, we obtained 60 MB/s (though there's still the task of merging datasets afterwards). I'll check again once #18 is merged. We've benchmarked some preprocessing, though that branch is gone now. Basically, this tokenizer was much faster than the HF tokenizer (even the fast version) thanks to the cache mechanism.
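For anyone re-running the comparison, a rough harness along these lines is enough to see the cache effect (the tokenize callable and the sample path are placeholders, not names from this repo):

import time

def throughput_mb_s(tokenize, path):
    # Returns MB/s for tokenizing the file at `path` line by line.
    n_bytes = 0
    start = time.perf_counter()
    with open(path, encoding='utf-8') as f:
        for line in f:
            n_bytes += len(line.encode('utf-8'))
            tokenize(line)
    return n_bytes / (time.perf_counter() - start) / 1e6

Calling it twice with the same tokenizer instance shows cold-cache vs. warm-cache throughput.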
FYI, I think I saw this in another branch, but I don't know when this was changed in preprocess_data: this line

from megatron.data.indexed_dataset import best_fitting_dtype

needs to come after this line:

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),

This doesn't have anything to do with the tokenizer. I can push it in this PR too if you guys want. It's a different issue.
The path needs to be set before we can find the "megatron" package.
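Concretely, the top of preprocess_data.py would have to look something like this (the exact relative path appended is an assumption; use whatever the script already appends):

import os
import sys

# The repo root must be on sys.path *before* any `megatron` import runs,
# otherwise `from megatron.data.indexed_dataset import ...` cannot resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron.data.indexed_dataset import best_fitting_dtype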
Adding comments about max_token_len_cache
Also, as a note so we have a record of it: in the new gpt2 tokenize method, we use the class variable max_token_len_cache, fixed at 9 for now, to determine whether to cache a normalized token. We can change this by setting tokenizer.max_token_len_cache = X to test performance vs. cache usage if we want. Thanks to Tal Perry for the suggestion to memoize the token.
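A sketch of that idea (the attribute name follows the comment above; everything else, including the plain-dict cache, is assumed): only short tokens are worth memoizing, since long tokens rarely repeat and would only evict useful entries.

class LengthGatedCacheSketch:
    max_token_len_cache = 9  # tune via tokenizer.max_token_len_cache = X

    def __init__(self, normalize_token):
        self._cache = {}
        self._normalize_token = normalize_token  # e.g. the method shown earlier

    def normalize_cached(self, token):
        # Long tokens are mostly unique: skip the cache entirely.
        if len(token) > self.max_token_len_cache:
            return self._normalize_token(token)
        hit = self._cache.get(token)
        if hit is None:
            hit = self._cache[token] = self._normalize_token(token)
        return hit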
@thomasw21 @stas00 lmk if you need anything else; otherwise we can merge.
Some minor comments. Otherwise LGTM! Thanks!
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Thank you, @ontocord
Left a few minor style suggestions.
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Looks like GitHub hides suggestions when they are made on code that has already been resolved; please also merge this:
Can we merge this now?
Note that this supports LRU caching in Python 3 only. In Python 2, the caching is removed from BPE altogether.
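A minimal sketch of how such version-gating could be wired (assumed structure, not the exact diff): apply functools.lru_cache where it exists, and leave the function untouched on Python 2, whose stdlib lacks it.

try:
    from functools import lru_cache  # Python 3
except ImportError:                  # Python 2: no stdlib lru_cache
    lru_cache = None

def maybe_lru_cache(maxsize=65536):
    # Decorator: cache on Python 3, plain passthrough on Python 2.
    def decorate(fn):
        if lru_cache is not None:
            return lru_cache(maxsize=maxsize)(fn)
        return fn
    return decorate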