diff --git a/crates/bpe/Cargo.toml b/crates/bpe/Cargo.toml index 4050236..3e2e190 100644 --- a/crates/bpe/Cargo.toml +++ b/crates/bpe/Cargo.toml @@ -8,8 +8,8 @@ crate-type = ["lib", "staticlib"] bench = false [[bench]] -name = "counting" -path = "benches/counting.rs" +name = "performance" +path = "benches/performance.rs" harness = false [features] diff --git a/crates/bpe/README.md b/crates/bpe/README.md index 0042795..0cd4c58 100644 --- a/crates/bpe/README.md +++ b/crates/bpe/README.md @@ -4,6 +4,7 @@ The main purpose of this library is to provide fast and correct token counting f As a by-product, it can also be used to efficiently encode those chunks if desired. For chunking the following operations are of interest: + 1) Split text after exactly n tokens at a character boundary. 1) Count tokens for sub-ranges of a text. 1) Incrementally count tokens while appending text to a chunk. @@ -29,6 +30,7 @@ This library presents novel algorithms to compute BPE encodings which address th ## Prior Art There are mostly three strategies for BPE encoding. + 1) Trivial solution. Search brute force for the most frequent pair in the encoded text according to the dictionary and replace those occurrences. This has `O(n^2)` complexity and is therefore not very appealing in production. 2) Heap based. Set up a heap with the frequencies. This improves the linear search time to a logarithmic factor. If done properly, the overall complexity now reduces to `O(n log n)`. 3) Split the input into sections of a maximum size first and then process each section individually. In theory this shrinks the complexity to `O(n)` if the section size is small enough. But it will then in general produce different results. In order to produce the "correct" encoding, one would need to choose split points at token boundaries. But without having the text encoded already, this is in general impossible.
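The trivial strategy (1) can be sketched in a few lines. This is an illustrative toy implementation, not code from this crate; `encode_naive` and the rank map are made up for the example, with a lower rank meaning a higher merge priority:

```rust
use std::collections::HashMap;

/// Naive BPE: repeatedly merge the adjacent token pair whose concatenation
/// has the best (lowest) rank until no mergeable pair remains.
/// Every merge rescans the whole sequence, hence the `O(n^2)` behavior.
fn encode_naive(text: &[u8], ranks: &HashMap<Vec<u8>, u32>) -> Vec<Vec<u8>> {
    // Start with one single-byte token per input byte.
    let mut tokens: Vec<Vec<u8>> = text.iter().map(|b| vec![*b]).collect();
    loop {
        // Linear scan for the best mergeable adjacent pair.
        let best = (0..tokens.len().saturating_sub(1))
            .filter_map(|i| {
                let pair = [tokens[i].as_slice(), tokens[i + 1].as_slice()].concat();
                ranks.get(&pair).map(|rank| (*rank, i))
            })
            .min();
        match best {
            Some((_, i)) => {
                // Merge tokens[i] and tokens[i + 1] into one token.
                let right = tokens.remove(i + 1);
                tokens[i].extend(right);
            }
            None => return tokens,
        }
    }
}
```

With the example dictionary `a b c ab cb ac` (ranked by frequency), `encode_naive(b"abacb", &ranks)` produces the tokens `ab a cb`.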
@@ -89,38 +91,38 @@ If BPE wants to make a different merge decision when it sees the full input, the Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`, then `e_0..e_i e_j` is also a valid encoding sequence. - ## Novel Algorithm At first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision. -For instance, the sequence `abab` would be encoded as `ab ab` when the dictionary contains the tokens `a b ab ba bc abc babc ababc` ordered by frequency. But appending a single character `ababc` would result in a pretty different tokenization: `ab a bc`. So without looking ahead it seems impossible to properly tokenize the text. +For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character, giving `abacb`, would result in a quite different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text. + +The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get: -The solution is to track the encodings of ALL text prefixes.
For our example `ababc` we would get: - `a` ------> `a` - `ab` -----> `ab` - `aba` ----> `ab a` -- `abab` ---> `ab ab` -- `ababc` --> `ab a bc` +- `abac` ---> `ab ac` +- `abacb` --> `ab a cb` This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered: - `a` ------> `a` - `ab` -----> `ab` - `aba` ----> `a` -- `abab` ---> `ab` -- `ababc` --> `bc` +- `abac` ---> `ac` +- `abacb` --> `cb` In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached. -For our example prefix `ababc`, this procedure executes the following steps and determines the correct encoding in reverse order: +For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order: -- `ababc` -> `bc` +- `abacb` -> `cb` - `aba` ---> `a` - `ab` ----> `ab` - `` The actual challenge is to determine for every prefix this last token efficiently. -The prefix `abab` could for instance end with either the token `b` or `ab`, but only `ab` leads to a valid encoding sequence. +The prefix `abac` could for instance end with either the token `c` or `ac`, but only `ac` leads to a valid encoding sequence. But, Corollary IIa tells us that **one and only one** last token can be the correct one and Corollary IIIa shows us how to find it: We only have to check whether a possible next token is "compatible" with its previous token, i.e. whether the two tokens form a valid encoding sequence. @@ -136,6 +138,7 @@ Once that happens the reencoding will be different and the algorithm can stop. The actual implementation needs essentially at most 14 lookups for the most complex cases to determine whether two tokens are compatible or not.
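The backward reconstruction just described can be sketched directly. This is an illustrative sketch, not this crate's code; `reconstruct` is a made-up helper, and the last-token table is the one from the worked example (one entry per prefix length):

```rust
/// Rebuilds the full encoding of `text` from a table that stores, for every
/// prefix length `i`, the last token of the encoding of `text[..i]`.
/// It walks backwards: emit the last token, shorten the prefix, repeat.
fn reconstruct(text: &str, last_token: &[&str]) -> Vec<String> {
    let mut encoding = Vec::new();
    let mut end = text.len();
    while end > 0 {
        let token = last_token[end - 1];
        encoding.push(token.to_string());
        end -= token.len();
    }
    // Tokens were collected back-to-front.
    encoding.reverse();
    encoding
}
```

Running it on the table for `abacb` (last tokens `a`, `ab`, `a`, `ac`, `cb`) reproduces the encoding `ab a cb`.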
Putting all these pieces together leads to the following algorithmic sketch: + ```rust let last_tokens = vec![]; for pos in 0..text.len() { @@ -166,6 +169,7 @@ The main observation is that often the greedy heuristic picks already the correc In the cases where it doesn't, the algorithm has to backtrack to the next tokenization until it converges to the correct solution. Our backtracking implementation solves the enumeration problem as follows: + 1) If the current tokenization sequence is valid, then append the longest matching token to the right. 2) Otherwise, replace the rightmost token with the next longest prefix token. 3) If there is no such token, then remove that token and go back to step 2. @@ -179,18 +183,96 @@ On average it is about ~4x faster, since the short-cuts usually pay off. ## Benchmarks -We compared our implementations with the tiktoken implementation on a MacBook Pro on a random input sequence: - -| Algorithm | Runtime | correct BPE output | -| ------------ | -------- | ---------- | -| Greedy | 100 µs | ✘ | -| Minimal | 300 µs | ✘ | -| Backtracking | 400 µs | ✔ | -| Dynamic Programming | 1300 µs | ✔ | -| TikToken | 1500 µs | ✘ | -| Heap | 1900 µs | ✔ | - -As can be seen, our Backtracking implementation beats the TikToken Rust implementation by ~4x. -And even the fully dynamic programming solution is faster with a more consistent runtime. -The tuned heap implementation is still quite competitive to TikToken (especially for smaller inputs). -If the requirement of correct BPE output can be relaxed, then the Greedy approach or the minimal encoding approach are the clear winners. \ No newline at end of file +We ran several benchmarks to compare the performance of our different encoders and a tiktoken implementation. +For the tiktoken implementation we used the [tiktoken-rs](https://crates.io/crates/tiktoken-rs) library, a wrapper around OpenAI's tiktoken implementation. +Note that tiktoken does not run BPE on the full input text.
+Instead it splits it into large chunks using a regex and runs BPE on the individual chunks. +We have not tried to see if that approach is compatible with our BPE implementation. +We benchmarked the following scenarios: + +- The first measures encoding runtime for our different encoders and the tiktoken Rust implementation. + This shows a ~3.5x performance improvement for our fastest correct encoder compared to the tiktoken library. + +- The second measures incremental encoding runtime, where the text is built up byte-by-byte. + This mode is not available in tiktoken, which only supports counting/encoding a complete text. + +- The third measures interval counting runtime, where tokens of sub-slices of a fixed text are counted. + The data structure we built specifically for this purpose can answer those interval counting requests in typically constant time after the initial linear preprocessing of the text. + This mode is not available in tiktoken, which only supports counting/encoding a complete text. + +All benchmarks were run single-threaded on a MacBook Pro M1. + +### Encoding + +Encoding is computing the tokens for a given text. +This benchmark compares several encoders: + +- The backtracking encoder uses the backtracking algorithm with memoization on top of a string matching automaton. +- The heap encoder uses a priority heap and a bitmask to represent token positions to implement the traditional BPE algorithm. +- The table encoder implements the raw dynamic programming algorithm proposed above. + +Two additional encoders are included that are faster but deviate from the original BPE encoding strategy: + +- The greedy encoder picks the left-longest token. +- The minimal encoder computes an encoding with the minimal number of tokens. + +The benchmark measured the runtime of encoding slices of lengths 10, 100, 1000, and 10000 from a random 20000 token original text using the o200k token set. +(All encodings were computed from scratch for each slice.)
+ +The graph below shows encoding runtime vs slice length. +All encoders (except the heap encoder) show the expected linear runtime complexity. +The backtracking encoder, the fastest encoder that still returns correct results, shows a performance gain of approximately 3.5x compared to tiktoken. +The fully dynamic programming solution and the heap implementation are still quite competitive with tiktoken (especially for smaller inputs). +If the requirement of correct BPE output can be relaxed, then the greedy approach or the minimal encoding approach are the clear winners. + +![encoding runtime comparison](./benches/result/encoding-o200k.svg) + +### Incremental encoding + +Incremental encoding tokenizes a text while appending bytes. +This type of algorithm is interesting for use cases where a certain token budget must not be exceeded. +This benchmark shows the runtime for the appending encoder when a text is encoded byte-by-byte. +For comparison we show the runtime of the backtracking encoder when it encodes the whole text at once. + +The benchmark measured the runtime of encoding slices of lengths 10, 100, 1000, and 10000 from a random 20000 token original text using the o200k token set. + +The graph below shows encoding runtime vs slice length. +The overall runtime of the byte-by-byte incremental encoder for encoding the full text is comparable to the runtime of the backtracking encoder, with only a constant factor overhead. +Note that this is a huge win for incremental use cases, which would otherwise require retokenization after each append, resulting in a quadratic slowdown. + +![appending runtime comparison](./benches/result/appending-o200k.svg) + +### Interval counting + +Interval counting is counting the tokens for a slice of an original text. +This benchmark uses two encoders: + +- The backtracking encoder encodes the slice from scratch. + This is similar to what one has to do with other libraries, like `tiktoken`.
+- The interval encoder encodes the original text once and reuses that encoding to count tokens for intervals of the original text. + The initial encoding time for the interval encoder is comparable to that of the backtracking encoder. + +The benchmark measured the runtime of counting o200k tokens on slices of lengths 10, 100, 1000, and 10000 from a random 20000 token original text. + +The graph below shows counting runtime vs slice length. +The runtime of the backtracking encoder grows with the length of the slice. +The interval encoder counts any interval in typically constant time. + +![counting runtime comparison](./benches/result/counting-o200k.svg) + +### Running the benchmarks + +Run the benchmarks as follows (requires [cargo-criterion](https://crates.io/crates/cargo-criterion) to be installed): + +```sh +cargo criterion +``` + +(Using `cargo bench` ignores the settings in `criterion.toml`!) +Open the full report, which should be located in `target/criterion/reports/index.html`. + +Update the figures in this repo as follows (requires `rsvg-convert` from `librsvg` to be installed): + +```sh +script/copy-benchmark-results +``` diff --git a/crates/bpe/benches/counting.rs b/crates/bpe/benches/counting.rs deleted file mode 100644 index 9b746d3..0000000 --- a/crates/bpe/benches/counting.rs +++ /dev/null @@ -1,139 +0,0 @@ -use std::time::Duration; - -use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding}; -use bpe::interval_encoding::IntervalEncoding; -use criterion::{criterion_group, criterion_main, Criterion}; -use rand::{thread_rng, Rng}; - -fn counting_benchmark(c: &mut Criterion) { - for (name, bpe) in [ - ("cl100k", BytePairEncoding::cl100k()), - ("o200k", BytePairEncoding::o200k()), - ] { - let text = create_test_bytes(&bpe, 20000); - let fast = IntervalEncoding::new(&bpe, &text); - - for bytes in [10, 100, 1000, 10000] { - let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}")); - group.bench_function("hybrid counting", |b| { - b.iter_batched(
- || thread_rng().gen_range(0..text.len() - bytes), - |start| fast.count(start..start + bytes), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("backtrack counting", |b| { - b.iter_batched( - || thread_rng().gen_range(0..text.len() - bytes), - |start| bpe.count(&text[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - } - } -} - -fn encoding_benchmark(c: &mut Criterion) { - for (name, bpe, tiktoken) in [ - ( - "cl100k", - BytePairEncoding::cl100k(), - tiktoken_rs::cl100k_base().unwrap(), - ), - ( - "o200k", - BytePairEncoding::o200k(), - tiktoken_rs::o200k_base().unwrap(), - ), - ] { - let text = create_test_string(&bpe, 20000); - let input = text.as_bytes(); - - for bytes in [10, 100, 1000, 10000] { - let mut group = c.benchmark_group(format!("bpe-{name}-bytes-{bytes}")); - group.bench_function("backtracking", |b| { - b.iter_batched( - || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.encode_via_backtracking(&input[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("heap", |b| { - b.iter_batched( - || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.encode_via_bitfield(&input[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("dynamic programming", |b| { - b.iter_batched( - || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.encode_via_table(&input[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("greedy", |b| { - b.iter_batched( - || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.encode_greedy(&input[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("minimal", |b| { - b.iter_batched( - || thread_rng().gen_range(0..input.len() - bytes), - |start| bpe.encode_minimal(&input[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("tiktoken", |b| { - b.iter_batched( - 
|| loop { - let start = thread_rng().gen_range(0..input.len() - bytes - 1); - if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) - { - return start; - } - }, - |start| tiktoken.encode_ordinary(&text[start..start + bytes]), - criterion::BatchSize::SmallInput, - ) - }); - } - } -} - -fn is_char_boundary(b: u8) -> bool { - // Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128 - // Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. b < 192 - // The rest are bytes belonging to the first byte of multi byte encodings (11xxxxxx): b >= 192 - // When interpreting the byte representation as signed integers, then numbers in the range 128..192 - // correspond to the smallest representable numbers. I.e. the two ranges [0, 128) and [192, 256) can - // be tested with a single signed comparison. - b as i8 >= -0x40 // NB: b < 128 || b >= 192 -} - -fn create_test_string(bpe: &BytePairEncoding, tokens: usize) -> String { - use rand::{thread_rng, Rng}; - let mut text = String::new(); - for _ in 0..tokens { - loop { - let i = thread_rng().gen_range(0..bpe.num_tokens()); - let s = bpe.token_bytes(i as u32); - if s.iter().all(|b| is_char_boundary(*b)) { - if let Ok(s) = std::str::from_utf8(s) { - text.push_str(s); - break; - } - } - } - } - text -} - -criterion_group!( - name = benches; - config = Criterion::default().warm_up_time(Duration::from_millis(500)).measurement_time(Duration::from_millis(500)).nresamples(1000); - targets = counting_benchmark, encoding_benchmark -); -criterion_main!(benches); diff --git a/crates/bpe/benches/performance.rs b/crates/bpe/benches/performance.rs new file mode 100644 index 0000000..4cff09c --- /dev/null +++ b/crates/bpe/benches/performance.rs @@ -0,0 +1,199 @@ +use std::sync::LazyLock; +use std::time::Duration; + +use bpe::appendable_encoder::AppendableEncoder; +use bpe::byte_pair_encoding::{create_test_bytes, BytePairEncoding}; +use bpe::interval_encoding::IntervalEncoding; +use criterion::{ + 
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration, +}; +use rand::{thread_rng, Rng}; +use tiktoken_rs::CoreBPE; + +static TOKENIZERS: LazyLock<[(&'static str, &'static BytePairEncoding, CoreBPE); 2]> = + LazyLock::new(|| { + [ + ( + "cl100k", + BytePairEncoding::cl100k(), + tiktoken_rs::cl100k_base().unwrap(), + ), + ( + "o200k", + BytePairEncoding::o200k(), + tiktoken_rs::o200k_base().unwrap(), + ), + ] + }); + +fn counting_benchmark(c: &mut Criterion) { + for (name, bpe, _) in TOKENIZERS.iter() { + let input = create_test_bytes(bpe, 20000); + let fast = IntervalEncoding::new(bpe, &input); + + let mut group = c.benchmark_group(format!("counting-{name}")); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for bytes in [10, 100, 1000, 10000] { + group.throughput(criterion::Throughput::Bytes(bytes as u64)); + group.bench_with_input(BenchmarkId::new("interval", bytes), &bytes, |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| fast.count(start..start + bytes), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input( + BenchmarkId::new("backtracking", bytes), + &bytes, + |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.count(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }, + ); + } + group.finish(); + } +} + +fn encoding_benchmark(c: &mut Criterion) { + for (name, bpe, tiktoken) in TOKENIZERS.iter() { + let text = create_test_string(bpe, 20000); + let input = text.as_bytes(); + + let mut group = c.benchmark_group(format!("encoding-{name}")); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for bytes in [10, 100, 1000, 10000] { + group.throughput(criterion::Throughput::Bytes(bytes as u64)); + group.bench_with_input( + BenchmarkId::new("backtracking", bytes), + &bytes, + |b, bytes| { + b.iter_batched( + || 
thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.encode_via_backtracking(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }, + ); + group.bench_with_input(BenchmarkId::new("heap", bytes), &bytes, |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.encode_via_bitfield(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input(BenchmarkId::new("table", bytes), &bytes, |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.encode_via_table(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input(BenchmarkId::new("greedy", bytes), &bytes, |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.encode_greedy(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input(BenchmarkId::new("minimal", bytes), &bytes, |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.encode_minimal(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input(BenchmarkId::new("tiktoken", bytes), &bytes, |b, bytes| { + b.iter_batched( + || loop { + let start = thread_rng().gen_range(0..input.len() - bytes - 1); + if is_char_boundary(input[start]) && is_char_boundary(input[start + bytes]) + { + return start; + } + }, + |start| tiktoken.encode_ordinary(&text[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }); + } + group.finish(); + } +} + +fn appending_benchmark(c: &mut Criterion) { + for (name, bpe, _) in TOKENIZERS.iter() { + let input = create_test_bytes(bpe, 20000); + + let mut group = c.benchmark_group(format!("appending-{name}")); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for bytes in [10, 100, 1000, 10000] { + 
group.throughput(criterion::Throughput::Bytes(bytes as u64)); + group.bench_with_input(BenchmarkId::new("appending", bytes), &bytes, |b, bytes| { + b.iter_batched( + || { + ( + thread_rng().gen_range(0..input.len() - bytes), + AppendableEncoder::new(bpe), + ) + }, + |(start, mut enc)| enc.extend(input[start..start + bytes].iter().copied()), + criterion::BatchSize::SmallInput, + ) + }); + group.bench_with_input( + BenchmarkId::new("backtracking", bytes), + &bytes, + |b, bytes| { + b.iter_batched( + || thread_rng().gen_range(0..input.len() - bytes), + |start| bpe.count(&input[start..start + bytes]), + criterion::BatchSize::SmallInput, + ) + }, + ); + } + group.finish(); + } +} + +fn is_char_boundary(b: u8) -> bool { + // Single byte encodings satisfy the bit pattern 0xxxxxxx, i.e. b < 128 + // Continuation bytes satisfy the bit pattern 10xxxxxx, i.e. b < 192 + // The rest are bytes belonging to the first byte of multi byte encodings (11xxxxxx): b >= 192 + // When interpreting the byte representation as signed integers, then numbers in the range 128..192 + // correspond to the smallest representable numbers. I.e. the two ranges [0, 128) and [192, 256) can + // be tested with a single signed comparison. 
+ b as i8 >= -0x40 // NB: b < 128 || b >= 192 +} + +fn create_test_string(bpe: &BytePairEncoding, tokens: usize) -> String { + use rand::{thread_rng, Rng}; + let mut text = String::new(); + for _ in 0..tokens { + loop { + let i = thread_rng().gen_range(0..bpe.num_tokens()); + let s = bpe.token_bytes(i as u32); + if s.iter().all(|b| is_char_boundary(*b)) { + if let Ok(s) = std::str::from_utf8(s) { + text.push_str(s); + break; + } + } + } + } + text +} + +criterion_group!( + name = benches; + config = Criterion::default() + .warm_up_time(Duration::from_millis(500)) + .measurement_time(Duration::from_millis(1000)) + .nresamples(1000); + targets = counting_benchmark, encoding_benchmark, appending_benchmark +); +criterion_main!(benches); diff --git a/crates/bpe/benches/result/appending-o200k.svg b/crates/bpe/benches/result/appending-o200k.svg new file mode 100644 index 0000000..f358527 --- /dev/null +++ b/crates/bpe/benches/result/appending-o200k.svg @@ -0,0 +1,52 @@ diff --git a/crates/bpe/benches/result/counting-o200k.svg b/crates/bpe/benches/result/counting-o200k.svg new file mode 100644 index 0000000..deaf497 --- /dev/null +++ b/crates/bpe/benches/result/counting-o200k.svg @@ -0,0 +1,48 @@ diff --git a/crates/bpe/benches/result/encoding-o200k.svg b/crates/bpe/benches/result/encoding-o200k.svg new file mode 100644 index 0000000..468755c --- /dev/null +++ b/crates/bpe/benches/result/encoding-o200k.svg @@ -0,0 +1,76 @@ diff --git a/crates/bpe/criterion.toml b/crates/bpe/criterion.toml new file mode 100644 index 0000000..a954003 --- /dev/null +++ b/crates/bpe/criterion.toml @@ -0,0 +1,16 @@ +# save report in this directory,
even if a custom target directory is set +criterion_home = "./target/criterion" + +# The colors table allows users to configure the colors used by the charts +# cargo-criterion generates. +[colors] +# Color-blind friendly color scheme from https://personal.sron.nl/~pault/. +comparison_colors = [ + {r = 102, g = 204, b = 238}, # cyan + {r = 204, g = 187, b = 68}, # yellow + {r = 238, g = 102, b = 119}, # red + {r = 68, g = 119, b = 170}, # blue + {r = 170, g = 51, b = 119}, # purple + {r = 34, g = 136, b = 51}, # green +# {r = 187, g = 187, b = 187}, # grey +] diff --git a/crates/bpe/script/copy-benchmark-results b/crates/bpe/script/copy-benchmark-results new file mode 100755 index 0000000..df9e97f --- /dev/null +++ b/crates/bpe/script/copy-benchmark-results @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -eu + +result_dir="benches/result" + +mkdir -p "$result_dir" + +for i in {counting,encoding,appending}-o200k; do + rsvg-convert --format svg --output "$result_dir/$i.svg" --background-color white "target/criterion/reports/$i/lines.svg" +done diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs index d66b8bd..72fa946 100644 --- a/crates/bpe/src/byte_pair_encoding.rs +++ b/crates/bpe/src/byte_pair_encoding.rs @@ -176,12 +176,12 @@ pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> /// Find a suitable hash factor for a set of given tokens that prevents collisions when /// constructing a [`BytePairEncoding`] from those tokens. 
#[cfg(feature = "rand")] -pub fn find_hash_factor_for_dictionary(iter: impl Iterator<Item = Vec<u8>>) -> u64 { +pub fn find_hash_factor_for_dictionary(tokens: impl IntoIterator<Item = Vec<u8>>) -> u64 { use std::collections::HashSet; use rand::Rng; - let all_tokens = iter.collect_vec(); + let all_tokens = tokens.into_iter().collect_vec(); let mut rnd = rand::thread_rng(); loop { let factor: u64 = rnd.gen(); @@ -244,7 +244,10 @@ impl BytePairEncoding { /// /// The recommended approach is to store the serialized value and reuse that, /// to prevent repeating the cost of computing the hash factor and encoding. - pub fn from_dictionary(iter: impl Iterator<Item = Vec<u8>>, hash_factor: Option<u64>) -> Self { + pub fn from_dictionary( + tokens: impl IntoIterator<Item = Vec<u8>>, + hash_factor: Option<u64>, + ) -> Self { let hash_factor = hash_factor .inspect(|f| assert_ne!(*f, 0, "hash factor must be larger than zero")) .unwrap_or(1); @@ -252,7 +255,7 @@ impl BytePairEncoding { let mut all_tokens_rev = Vec::new(); let mut token_starts = vec![0]; let mut bytes_hash_to_token = FnvHashMap::default(); - for (i, token) in iter.enumerate() { + for (i, token) in tokens.into_iter().enumerate() { bytes_hash_to_token.insert(hash_bytes(&token, hash_factor), i as u32); all_tokens_rev.extend(token.iter().copied().rev()); all_tokens.extend(token); diff --git a/crates/bpe/src/lib.rs b/crates/bpe/src/lib.rs index 452024e..2c7ab43 100644 --- a/crates/bpe/src/lib.rs +++ b/crates/bpe/src/lib.rs @@ -4,3 +4,64 @@ mod bitfield; pub mod byte_pair_encoding; pub mod interval_encoding; pub mod prependable_encoder; + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::byte_pair_encoding::BytePairEncoding; + + /// This test produces the output for the encoding example in the README.
+ #[test] + fn readme_example() { + let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec()); + let bpe = BytePairEncoding::from_dictionary(tokens, None); + let text = "abacb"; + let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec(); + let all_prefix_tokens = prefixes + .iter() + .map(|prefix| { + bpe.encode_via_backtracking(prefix.as_bytes()) + .into_iter() + .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) }) + .collect_vec() + }) + .collect_vec(); + let last_prefix_tokens = all_prefix_tokens + .iter() + .map(|tokens| tokens.last().unwrap()) + .collect_vec(); + + println!("All tokens for each prefix of `{text}`:\n"); + for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) { + println!( + "- `{prefix}` {}> `{}`", + "-".repeat(text.len() + 2 - prefix.len()), + tokens.join(" ") + ); + } + println!(); + + println!("Last token for each prefix of `{text}`:\n"); + for (prefix, token) in prefixes.iter().zip(&last_prefix_tokens) { + println!( + "- `{prefix}` {}> `{token}`", + "-".repeat(text.len() + 2 - prefix.len()), + ); + } + println!(); + + println!("Tokenization of `{text}`:\n"); + let mut remaining = text.len(); + while remaining > 0 { + let prefix = &text[..remaining]; + let token = last_prefix_tokens[remaining - 1]; + println!( + "- `{prefix}` {}> `{token}`", + "-".repeat(text.len() + 2 - prefix.len()), + ); + remaining -= token.len(); + } + println!("- ``"); + } +} diff --git a/criterion.toml b/crates/geo_filters/criterion.toml similarity index 100% rename from criterion.toml rename to crates/geo_filters/criterion.toml