
Performance ideas #33

Open
christopher-hesse opened this issue Feb 20, 2023 · 6 comments

@christopher-hesse

I made a toy GPT-2 tokenizer as a Python extension written in Rust. It seems to be slightly faster than tiktoken in my tests. It looks like #31 may get most or all of the way there, but I thought I'd post the results from this script:

import os
import time
from typing import Any, cast

import numpy as np
import tiktoken

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

def benchmark_batch(documents: list[bytes]) -> None:
    num_threads = 1
    num_bytes = sum(map(len, documents))
    print(f"num_threads: {num_threads}, num_bytes: {num_bytes}")
    documents_decoded = [doc.decode("utf8") for doc in documents]

    enc = tiktoken.get_encoding("gpt2")
    enc.encode("warmup")

    start = time.perf_counter_ns()
    tiktoken_output = enc.encode_ordinary_batch(documents_decoded, num_threads=num_threads)
    end = time.perf_counter_ns()
    print(f"tiktoken \t{num_bytes / (end - start) * 1e9} bytes / s")

    import transformers

    hf_enc = cast(Any, transformers).GPT2TokenizerFast.from_pretrained("gpt2")
    hf_enc.model_max_length = 1e30  # silence!
    hf_enc.encode("warmup")

    start = time.perf_counter_ns()
    hf_enc_output = hf_enc(documents_decoded)
    end = time.perf_counter_ns()
    print(f"huggingface \t{num_bytes / (end - start) * 1e9} bytes / s")

    import csh_bpe.codec
    csh_bpe_enc = csh_bpe.codec.RustGPTCodec(word_encoder_kind="bigram", doc_splitter_kind="direct")
    csh_bpe_enc.encode(np.frombuffer(b"warmup", dtype=np.uint8))
    
    start = time.perf_counter_ns()
    csh_bpe_output = csh_bpe_enc.encode(np.frombuffer(documents[0], dtype=np.uint8))
    end = time.perf_counter_ns()
    print(f"csh_bpe \t{num_bytes / (end - start) * 1e9} bytes / s")

    assert hf_enc_output["input_ids"][0] == tiktoken_output[0]
    assert csh_bpe_output.tolist() == tiktoken_output[0]


def main():
    with open(os.path.join(SCRIPT_DIR, "..", "local-data", "64MB.txt"), "rb") as f:
        contents = f.read()
    benchmark_batch([contents])


if __name__ == "__main__":
    main()

The text is 64 MiB of Wikipedia wikitext, probably enwik8, but I just found it on my hard drive.

python -m csh_bpe.compare_tiktoken
num_threads: 1, num_bytes: 67108864
tiktoken        6004366.360373783 bytes / s
huggingface     1120214.7857500792 bytes / s
csh_bpe         17070974.6114367 bytes / s

There are no fancy optimizations here (like SIMD stuff), but the library does a few things that might differ from tiktoken:

  1. The word-splitting regular expression is implemented in hand-written Rust code instead of a regexp library. It uses Go's Unicode tables: https://github.com/golang/go/blob/19309779ac5e2f5a2fd3cbb34421dafb2855ac21/src/unicode/tables.go and this seems to produce the same output, at least for this 64 MB file. The splitting is done with a function that takes a u8 numpy array and a start offset and returns the end offset.
  2. The bigram encoder takes a u8 slice for the word, a HashMap<(i32, i32), i32> mergelist, an i32 slice mapping bytes to tokens (used to populate the initial output), and a mutable i32 slice of output tokens. It keeps a list of skip lengths for each index of the output tokens (initially all 1s), which it updates whenever it merges two tokens, then compacts the output tokens when it is done (a rough sketch appears at the end of this comment).
  3. (I think tiktoken does this too.) After splitting, before encoding a word, it checks the vocab hashmap to see whether the word is already a single token.
  4. The interface uses numpy arrays instead of bytes, and the output array is provided as one of the inputs so the caller can manage more of the memory allocation (not sure if this has any performance impact).

I didn't try a Rust regexp library, so I don't know how much the word splitting matters, though I could benchmark just the splitting part.
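
For reference, here is a minimal Python sketch of the skip-length merge loop from point 2. The function name and signature are illustrative, and it assumes the mergelist maps a pair of token ids to the merged token's id, with lower ids merging first (as in GPT-2-style vocabs); it is not the actual csh_bpe API.

import numpy as np

def bigram_encode(word: bytes, merges: dict[tuple[int, int], int],
                  byte_to_token: list[int], out: np.ndarray) -> int:
    """Encode one word into `out`, returning the number of tokens written."""
    n = len(word)
    for i, b in enumerate(word):
        out[i] = byte_to_token[b]  # initial output: one token per byte
    # skip[i] is the distance from live position i to the next live position.
    skip = [1] * n
    while True:
        # Scan adjacent live pairs for the highest-priority (lowest-id) merge.
        best_pos, best_tok = -1, -1
        i = 0
        while i + skip[i] < n:
            j = i + skip[i]
            tok = merges.get((int(out[i]), int(out[j])), -1)
            if tok != -1 and (best_tok == -1 or tok < best_tok):
                best_pos, best_tok = i, tok
            i = j
        if best_pos == -1:
            break  # no adjacent pair can be merged
        # Merge one occurrence: the left slot becomes the merged token and
        # its skip length grows to jump over the absorbed right slot.
        j = best_pos + skip[best_pos]
        out[best_pos] = best_tok
        skip[best_pos] += skip[j]
    # Compact the surviving tokens to the front of `out`.
    count, i = 0, 0
    while i < n:
        out[count] = out[i]
        count += 1
        i += skip[i]
    return count

Per point 3, the caller would first look the whole word up in the vocab and only fall back to this loop if the word is not already a single token; per point 4, out is a caller-provided int32 array at least as long as the word.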

@christopher-hesse
Author

I had a version using a trie to do single-pass-ish encoding of an input, but it wasn't correct. I'm not certain how fast a correct version of that trie would be.

@hauntsaninja
Collaborator

Thanks, those are really nice results!

  1. Last time I checked, regex splitting was the majority of the time — I'd be interested in benchmarking the splitting part if easy. I'm potentially interested in specialised code, but we do vary the splitting regexes. Hopefully the PCRE approach demonstrated in the original version of Improve performance by 2x #31 is viable and closes most of the gap.
  2. Nice! The repeated bytes hashing that tiktoken does is clearly not efficient (I was surprised it was viable). A skip list seems like a good way to avoid the O(n) deletes in the loop (see the toy illustration after this list).
  3. Yeah, tiktoken does this and it was a big piece in ensuring good perf.
  4. I thought about this. I wouldn't want the caller to provide the array because it's annoying to size. But even just returning a numpy array to get rid of the overhead of going from a Rust Vec to a Python list could be good (PyO3 has numpy bindings) — but I haven't benchmarked it.
  5. If you figure out how to do this, let me know! I don't see a way. I'm always uncertain about the perf characteristics of tries, since they're not CPU cache friendly.
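
To make point 2 concrete, here is a toy illustration of the two lookup styles (the dictionaries and values are made up; this is neither tiktoken's nor csh_bpe's actual code): keying the merge table by a freshly sliced bytes object versus keying it by a pair of fixed-width token ids.

word = b"hello"

# Byte-keyed ranks: every probe of a candidate pair hashes a new bytes slice.
byte_ranks: dict[bytes, int] = {b"he": 0, b"ll": 1, b"lo": 2}
rank_a = byte_ranks.get(word[2:4])  # slices and re-hashes b"ll" on each lookup

# Pair-of-ids ranks: the key is two small ints, cheaper to hash and compare.
id_ranks: dict[tuple[int, int], int] = {(104, 101): 0, (108, 108): 1, (108, 111): 2}
rank_b = id_ranks.get((word[2], word[3]))  # word[2], word[3] are the byte values of "l", "l"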

@christopher-hesse
Author

  1. Definitely easier if PCRE is fast enough, but if there's still a significant speed gain from hand-writing the most common regexps, it could be worth it.

Previous script, full encode: csh_bpe 16592057.7610188 bytes / s => 60.3 ns/byte
Previous script, splitting only (commented out the bigram part): csh_bpe 104345021.38634916 bytes / s => 9.6 ns/byte

  4. The caller supplies an array that is the length of the input, out_tokens = np.empty(input.shape, dtype=np.int32). This does cost more memory during encoding, though the caller can copy the used part of the array afterward if they want. Also unclear to me if this has any measurable performance advantage.
  5. Yeah, the cache unfriendliness is worrying, but having a correct trie is definitely the first step. It's not obviously impossible to me, but a correct trie could be gigantic.

@christopher-hesse
Author

Feel free to close this if the ideas have been ideated.

@alkoumpa

alkoumpa commented Apr 3, 2023

Hello,

It seems that the slow performance is due to an inefficient handling of the negative lookahead clause ("\s+(?!\S)") in the fancy_regex library.

A possible solution to mimic the negative lookahead functionality is to remove it from the regex and manually re-add the spaces to the matched parts, such as words or numbers (a rough sketch follows). Although this approach achieves the same performance as pcre2, it may not be the most elegant solution.
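
For illustration, here is a rough Python sketch of that idea using the third-party regex module. The split pattern is the standard GPT-2 one; the fix-up function is only my sketch of the approach, not the actual change.

import regex

# GPT-2 split pattern, including the negative lookahead that fancy_regex handles slowly.
PAT_FULL = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# The same pattern with the lookahead alternative removed.
PAT_PLAIN = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+""")

def split_without_lookahead(text: str) -> list[str]:
    r"""Emulate `\s+(?!\S)` by hand: when a whitespace run is followed by a
    non-space character, its last whitespace character belongs to the next piece."""
    parts: list[str] = []
    pos = 0
    while pos < len(text):
        m = PAT_PLAIN.match(text, pos)
        piece = m.group()
        # `\s+` is greedy, so a whitespace match is always followed by
        # non-whitespace or the end of the text; give one character back
        # so it can attach to the next word (e.g. "  foo" -> "  ", " foo").
        if len(piece) > 1 and piece.isspace() and m.end() < len(text):
            piece = piece[:-1]
        parts.append(piece)
        pos += len(piece)
    return parts

# Quick check against the original pattern.
assert split_without_lookahead("hello   world\n") == [m.group() for m in PAT_FULL.finditer("hello   world\n")]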

@l0rinc
Contributor

l0rinc commented Nov 23, 2023

I'm currently working on optimizing the tokenizer and the token counter (on the Java implementation at https://github.com/knuddelsgmbh/jtokkit, but most of the tricks should be applicable to other implementations as well).

Benchmark                                                      (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCountOriginal              data    ss   10  6.503 ± 0.053   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount                      data    ss   10  2.094 ± 0.042   s/op

So far it's 3x faster, but I still have a few ideas left.
I'll check afterwards whether the recommendations here are applicable.
