Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) #239

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

paplorinc
Copy link
Contributor

@paplorinc paplorinc commented Jan 15, 2024

Continuing the optimizations started in #237 and #234, migrated from knuddelsgmbh/jtokkit#75, knuddelsgmbh/jtokkit#76, knuddelsgmbh/jtokkit#77.

This commit is mainly meant to address the issue of really big tokens spiraling out of control, see: #195

The original byte pair merge algorithm diverges quickly for longer character sequences in a superlinear way - e.g. a 20_000 character word (having 2500 tokens) can take several seconds to be tokenized.

Or on https://platform.openai.com/tokenizer:
image

The new algorithm scales so well that it could theoretically process the whole text in a single byte-pair-merge loop without any regex splitting (though it would need a different token set to be optimal since it produces slightly different results, mostly whitespaces, though - and it also consumes a lot more memory and is slower that the current one):
image

The new algorithm does the minimum search logarithmically and duplicates in constant time, but has a higher setup cost, so we're only using it for extreme cases (if the piece given by the regex is > 500 characters):

The benchmarking was done step-by-step in the Java clone and here retested in the way described in #234

110 multilingual books + some source codes + some big token files:

Before:

num_threads: 1, num_bytes: 98379144
tiktoken 	7,891,071 bytes / s
tiktoken 	7,927,160 bytes / s
tiktoken 	7,936,709 bytes / s
tiktoken 	7,912,032 bytes / s
tiktoken 	7,928,872 bytes / s

After:

num_threads: 1, num_bytes: 98379144
tiktoken 	9,494,719 bytes / s
tiktoken 	9,547,619 bytes / s
tiktoken 	9,525,891 bytes / s
tiktoken 	9,506,282 bytes / s
tiktoken 	9,563,429 bytes / s

From which only the big token files (the purpose of this PR):

Before:

num_threads: 1, num_bytes: 392784
tiktoken 	195,554 bytes / s
tiktoken 	195,998 bytes / s
tiktoken 	196,051 bytes / s
tiktoken 	194,724 bytes / s
tiktoken 	195,748 bytes / s

After:

num_threads: 1, num_bytes: 392784
tiktoken 	8,604,360 bytes / s
tiktoken 	8,628,191 bytes / s
tiktoken 	8,561,823 bytes / s
tiktoken 	8,675,756 bytes / s
tiktoken 	8,618,370 bytes / s

i.e. 40x faster for 20k character words.

And if we combine this with the previous regex optimizations, we're getting the following for the 110 books + sources + big tokens case:

num_threads: 1, num_bytes: 98379144
tiktoken 	12,043,907 bytes / s
tiktoken 	12,153,199 bytes / s
tiktoken 	12,173,271 bytes / s
tiktoken 	12,085,368 bytes / s
tiktoken 	12,147,123 bytes / s

i.e. 50% faster on average after all optimizations.

I recommend reviewing commit-by-commit:
image

@@ -61,13 +60,16 @@ def test_simple_regex():
def test_basic_encode():
enc = tiktoken.get_encoding("r50k_base")
assert enc.encode("hello world") == [31373, 995]
assert enc.encode("a" * 1000) == [24794] * 250
Copy link
Contributor Author

@paplorinc paplorinc Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to cover the big encoder as well

src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
src/lib.rs Outdated Show resolved Hide resolved
@@ -15,9 +16,22 @@ use rustc_hash::FxHashMap as HashMap;

type Rank = u32;

const LARGE_ENCODER_CHARACTER_LIMIT: usize = 500;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the Java version this was controlled by an environmental variable, which enabled us to run all tests against both implementations - should I do it here as well?

.unwrap_or(&Rank::MAX)
};

let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grouped by rank, the values ordered by index (basically a LinkedHashSet inside)

};

let mut rank_indexes = BTreeMap::<Rank, BTreeSet<usize>>::new();
let mut index_rank = vec![Rank::MAX; piece.len() + 1];
Copy link
Contributor Author

@paplorinc paplorinc Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mutations seemed easier this way, compared to creating a struct with index/rank/prev/next - especially in Rust

let mut token_count = piece.len();
while token_count > 2 && rank_indexes.len() > 1 {
let mut skip_next = false;
if let Some((_, nodes)) = rank_indexes.pop_first() {
Copy link
Contributor Author

@paplorinc paplorinc Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

next min is popped off in logarithmic time instead of linearly

while token_count > 2 && rank_indexes.len() > 1 {
let mut skip_next = false;
if let Some((_, nodes)) = rank_indexes.pop_first() {
for &min_node in &nodes {
Copy link
Contributor Author

@paplorinc paplorinc Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicates are processed in bulk (since the next min is strictly greater than equal), no need to remove them one-by-one

if let Some((_, nodes)) = rank_indexes.pop_first() {
for &min_node in &nodes {
if skip_next {
skip_next = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when merging neighboring elements with the same ranks

@paplorinc paplorinc changed the title 2) Optimize byte pair merge for really big tokens (40x faster for a 2500 token piece) 2) Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) Jan 15, 2024
}
) -> Vec<(usize, Rank)> {
if piece.len() < LARGE_ENCODER_CHARACTER_LIMIT {
_byte_pair_merge_small(ranks, piece) // Quadratic, but lightweight
Copy link
Contributor Author

@paplorinc paplorinc Jan 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quadtratic algo is usually faster for very small words - which is always the case for natural language, but e.g. DNA sequences or a DOS attack can be avoided by switching to the linearithmic algo

src/lib.rs Outdated Show resolved Hide resolved
Comment on lines +114 to +126
let prev_node = index_prev[min_node];
let next_node = index_next[min_node];
let next_next_node = index_next[next_node];
let next_next_next_node = index_next[next_next_node];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're keeping track of the order of characters inside the rank-balanced tree, providing logarithmic access to the minimum rank and constant access to the previous/next

@hauntsaninja
Copy link
Collaborator

hauntsaninja commented Feb 9, 2024

This is great!

Thanks for keeping a simple to follow history. Most of the commits here are straightforward, I've separated them into different PRs (I've preserved authorship information, but let me know if you'd prefer to re-open them yourself)

I'm looking at 8f5dd7d and d24b67b now...

@paplorinc
Copy link
Contributor Author

paplorinc commented Feb 9, 2024

Thanks a lot for the thorough review, Shantanu.
Let me know if you need any help in speeding up the process! :)

After merging you may want to update the benchmark results in the readme.

hauntsaninja added a commit that referenced this pull request Feb 11, 2024
Based on suggestion in #239
(specifically 8f5dd7d)

Like that commit, this:
- Does the init in a single loop and saves a loop if there are no merges
- Simplifies get_rank and no longer uses it in init (so you don't need
multiple skip values)

Unlike that commit:
- We drop optimisations enabled by ignoring single tokens. These didn't
show any benefit on benchmarks for me (this makes sense given typical
piece sizes, but let me know if that's unexpected!). Given this, I opted
for the simpler version.
- I preserve some of the comments from the original that I think are
still useful

Co-authored-by: @paplorinc

---------

Co-authored-by: Lőrinc Pap <1841944+paplorinc@users.noreply.github.com>
@paplorinc paplorinc force-pushed the paplorinc/add-linearithmic-byte-pair-merge branch from 24d68bd to b7c6ac8 Compare February 11, 2024 13:42
}
}

successors(Some(0), |&n| index_next.get(n).filter(|&&x| x != usize::MAX).copied())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iterate until there's a valid rank

@paplorinc paplorinc changed the title 2) Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) Optimize byte pair merge for really big tokens (40x faster for a 2500 token word) Feb 11, 2024
@paplorinc
Copy link
Contributor Author

@hauntsaninja, I've rebased this PR, removing the merged commits and adjusting the result a bit based on your previous preferences.
Hope it helps. Feel free to push on top of this PR or open a different one, whichever's easier.

let mut index_next = vec![usize::MAX; piece.len() + 1];

let get_rank = |start_idx: usize, end_idx: usize| -> Rank {
*piece.get(start_idx..end_idx)
Copy link
Contributor Author

@paplorinc paplorinc Feb 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when end_idx is out of bounds we're defaulting to Rank::MAX

Lőrinc added 2 commits February 11, 2024 17:03
We're storing the ranks in a sorted tree of sorted (or linked) trees.
Getting the minimum rank is logarithmic and each subsequent occurrence is constant time.
To know the previous and next indexes (and the corresponding ranks), we're storing them in arrays (the keys are the indexes). We're updating each after finding the minimum via the tree.
We're iterating duplicates without removing them one-by-one, but if they are neighbors, we're skipping them manually.
@paplorinc paplorinc force-pushed the paplorinc/add-linearithmic-byte-pair-merge branch from b7c6ac8 to 5af8058 Compare February 11, 2024 16:04
let min_rank = index_rank[min_node];

let prev_node = index_prev[min_node];
let next_node = index_next[min_node];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getting next and previous requires lookups now

@vvolhejn
Copy link

vvolhejn commented Apr 4, 2024

Hi! Any updates on this PR? It'd be great to have this 🙏 @hauntsaninja

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants