Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions tests/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CHUNK_SIZE = 512
# END CONFIG #

def bench() -> dict[str, float]:
def bench() -> dict[str, float]:
# Initialise the chunkers.
semchunk_chunker = semchunk.chunkerify(tiktoken.encoding_for_model('gpt-4'), CHUNK_SIZE)
sts_chunker = TextSplitter.from_tiktoken_model('gpt-4', CHUNK_SIZE)
Expand All @@ -21,7 +21,7 @@ def bench_semchunk(texts: list[str]) -> None:
semchunk_chunker(texts)

def bench_sts(texts: list[str]) -> None:
[sts_chunker.chunks(text) for text in texts]
sts_chunker.chunk_all(texts)

libraries = {
'semchunk': bench_semchunk,
Expand All @@ -31,22 +31,23 @@ def bench_sts(texts: list[str]) -> None:
# Download the Gutenberg corpus.
try:
gutenberg = nltk.corpus.gutenberg

except Exception:
nltk.download('gutenberg')
gutenberg = nltk.corpus.gutenberg

# Benchmark the libraries.
benchmarks = dict.fromkeys(libraries.keys(), 0)
benchmarks = dict.fromkeys(libraries.keys(), 0.0)
texts = [gutenberg.raw(fileid) for fileid in gutenberg.fileids()]

for library, function in libraries.items():
start = time.time()
function(texts)
benchmarks[library] = time.time() - start

return benchmarks

if __name__ == '__main__':
nltk.download('gutenberg')
for library, time_taken in bench().items():
print(f'{library}: {time_taken:.2f}s')
print(f'{library}: {time_taken:.2f}s')