Optimize byte pair merge for small and big character sequences - 8.2s to 3.9s #76
Conversation
```java
private final Pattern pattern;
private final TokenEncoder encoder;
private final SpecialEncoder specialEncoder;
private final Map<Integer, byte[]> encodedToDecoded;
```
`TokenEncoder` is only used for encoding now - so we can eliminate the types
I would prefer to keep this (explicitly typed) `encodedToDecoded` map inside the modified `TokenEncoder`, to be consistent with the `SpecialEncoder` structure. As far as I can tell, the only performance implication is that we lose the benefit of caching decoding in a single map, so we have to make two map lookups for special tokens - which is negligible (especially since encoding special characters is unsupported anyway).
In my original over-optimized version there were multiple `TokenEncoder`s: when the byte count was < `Long.BYTES`, we stored the token in a primitive `long` and used a primitive map with a `long` key instead. That enabled squeezing out the last few drops of performance, since short tokens are by far the most common, and they get a lot faster once no byte arrays are involved. But I haven't committed that yet, so I should probably merge it back into `TokenEncoder` for now and maybe split it out again if there's a PR for it (#79).
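For illustration, a minimal sketch of that packing trick - the method name is hypothetical, since this variant was never committed:

```java
// Packs up to Long.BYTES - 1 token bytes into a primitive long, so that
// short tokens can be looked up in a long-keyed primitive map without
// allocating byte arrays. The leading 1 bit disambiguates lengths
// (e.g. {0} vs {0, 0} would otherwise collide).
static long packShortToken(byte[] bytes, int start, int end) {
    assert end - start < Long.BYTES;
    long key = 1;
    for (int i = start; i < end; i++) {
        key = (key << 8) | (bytes[i] & 0xFF);
    }
    return key;
}
```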
Done
```java
    }
}

public static int getMinRankIndex(List<Integer> ranks) {
```
When searching for the next minimum value, we've unrolled the loop to favor SIMD optimizations.
Out of curiosity: have you benchmarked this unrolling separately? I was under the impression that this kind of optimization (loop unrolling) is best left to the compiler.
It is still left to the JIT compiler, since I'm not using the Vector API - I'm just making it easier for the JIT to group similar instructions. If you run the benchmarks yourself, you can check whether it reproduces.
And yes, I have benchmarked everything separately (though I haven't committed every single benchmark); this was the fastest variant.
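A sketch of what such a hand-unrolled minimum search might look like - the body is illustrative, not the PR's exact code; only the signature and the `MAX_RANK` sentinel come from the diff:

```java
// 4-way unrolled minimum search: the straight-line comparisons give the
// JIT independent instruction chains it can vectorize or pipeline.
public static int getMinRankIndex(List<Integer> ranks) {
    int minRankIndex = -1;
    int minRank = MAX_RANK; // sentinel defined in TokenEncoder
    int i = 0;
    int unrolledBound = ranks.size() - 3;
    for (; i < unrolledBound; i += 4) { // process 4 elements per iteration
        int r0 = ranks.get(i);
        int r1 = ranks.get(i + 1);
        int r2 = ranks.get(i + 2);
        int r3 = ranks.get(i + 3);
        if (r0 < minRank) { minRankIndex = i;     minRank = r0; }
        if (r1 < minRank) { minRankIndex = i + 1; minRank = r1; }
        if (r2 < minRank) { minRankIndex = i + 2; minRank = r2; }
        if (r3 < minRank) { minRankIndex = i + 3; minRank = r3; }
    }
    for (; i < ranks.size(); i++) { // scalar tail for the remainder
        int r = ranks.get(i);
        if (r < minRank) { minRankIndex = i; minRank = r; }
    }
    return minRankIndex;
}
```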
Force-pushed from adeff4f to 9c332c4.
* When searching for the next minimum value, we've unrolled it to favor SIMD optimizations.
* PieceIndexToRank is removed, we're only storing the rank, since we're not deleting next values anymore (which avoids copying every subsequent value).
* Since we've replaced minimums with sentinels, previous and next indexes are replaced by iteration.
* The encoders map is split by input byte array size so that we're only querying small maps.
* Iteration stops before the last minimum is MAX_RANK by keeping track of merge results - resulting in one less minimum search at the end.

Before:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  8.947 ± 0.109   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  9.419 ± 0.082   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  9.365 ± 0.073   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  8.403 ± 0.080   s/op
```

After:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  7.313 ± 0.031   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  7.242 ± 0.027   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  7.742 ± 0.054   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  7.748 ± 0.121   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  7.017 ± 0.110   s/op
```
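A minimal sketch of the sentinel idea from the bullets above, assuming ranks are kept in a plain `int[]` and using the `DUMMY_RANK` constant defined later in the diff (the helper names are illustrative):

```java
// Instead of deleting a merged-away position - which would shift every
// subsequent element - its rank is overwritten with DUMMY_RANK, and
// neighbors are found by skipping sentinels during iteration.
static int nextIndex(int[] ranks, int index) {
    do {
        index++;
    } while (index < ranks.length && ranks[index] == DUMMY_RANK);
    return index;
}

static int previousIndex(int[] ranks, int index) {
    do {
        index--;
    } while (index >= 0 && ranks[index] == DUMMY_RANK);
    return index;
}
```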
Before:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  7.242 ± 0.027   s/op
```

After:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  6.885 ± 0.049   s/op
```
We're storing the ranks in a red-black tree of trees. Getting the minimum rank is basically constant time: we group by the rank itself (since the same rank can occur multiple times) and pop the first entry, which represents the first occurrence. After a merge we remove the node, which is also a basically constant-time operation. We also count the remaining valid ranks for the stopping condition. To know the previous and next values, we store everything in a RankNode that we update after finding the minimum via the tree.

Before:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  7.372 ± 0.063   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  6.885 ± 0.049   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  7.846 ± 0.051   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  7.850 ± 0.051   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  7.006 ± 0.066   s/op
```

After:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  4.592 ± 0.055   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  4.215 ± 0.036   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.598 ± 0.063   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.569 ± 0.044   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.178 ± 0.128   s/op
```
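A sketch of the "tree of trees" described above - names and layout are assumptions, not the PR's exact code (Java's `TreeMap` is a red-black tree):

```java
import java.util.TreeMap;

// The outer TreeMap groups nodes by rank; the inner TreeMap orders
// equal-rank nodes by position, so the global minimum is the first
// entry of the first entry.
final class RankTree {
    static final class RankNode {
        int rank, index;
        RankNode prev, next; // neighbors in the byte sequence

        RankNode(int rank, int index) {
            this.rank = rank;
            this.index = index;
        }
    }

    final TreeMap<Integer, TreeMap<Integer, RankNode>> rankMap = new TreeMap<>();

    // Assumes the structure is non-empty.
    RankNode pollMinimum() {
        TreeMap<Integer, RankNode> bucket = rankMap.firstEntry().getValue();
        RankNode min = bucket.pollFirstEntry().getValue();
        if (bucket.isEmpty()) {
            rankMap.pollFirstEntry(); // drop the exhausted rank bucket
        }
        return min;
    }
}
```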
It's simpler and in the current implementation it's basically just as fast.
…the large tokenizer as well
Force-pushed from 059f95f to ad7292f.
```java
validRanks -= (newRank == MAX_RANK) ? 1 : -1;
TreeMap<Integer, RankNode> minNodes = rankMap.firstEntry().getValue();
for (int i = 0; i < minNodes.size(); i++) {
    RankNode minNode = minNodes.firstEntry().getValue();
```
If there are multiple tokens with the same rank, the minimum we've just found will be the next minimum anyway, so we can simply consume the gathered ones.
```java
public static final int DUMMY_RANK = Integer.MAX_VALUE;
public static final int MAX_RANK = Integer.MAX_VALUE - 1;

public final class TokenEncoder {
    public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
```
`public` should be fine - users should also be able to override this from the outside if they really need to, though I wouldn't expose it through the API.
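Assuming the threshold is read from a system property (the key's name suggests as much), overriding it from the outside might look like:

```java
// Hypothetical override - would have to run before the encoding registry
// is initialized, so that the custom threshold is picked up.
System.setProperty(TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "1000");
```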
```diff
@@ -14,22 +14,26 @@ class Cl100kBaseTest {
 
     private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
 
     Encoding getEncoding() {
         return ENCODING;
```
We have to be able to override it from a child test to control the order of initializations.
```java
import static com.knuddels.jtokkit.TokenEncoder.VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY;

class Cl100kLargeTokenizerTest extends Cl100kBaseTest {
```
Checked it via test coverage: the first test only exercises the fast array-based path, while this one only exercises the map-based path.
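A hypothetical shape for such a child test (the threshold value and initialization order are assumptions): overriding `getEncoding()` lets the subclass lower the threshold before the encoding is created, forcing the map-based path.

```java
class Cl100kLargeTokenizerTest extends Cl100kBaseTest {
    private static final Encoding LARGE_ENCODING = createLargeEncoding();

    private static Encoding createLargeEncoding() {
        // Route every piece through the "very large tokenizer" code path.
        System.setProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "0");
        return Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
    }

    @Override
    Encoding getEncoding() {
        return LARGE_ENCODING;
    }
}
```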
```java
        validRanks--;
    }
    removeNode(rankMap, nextNode);
}
removeNode(minNodes, rankMap, minNode);
```
A minor optimization - we could probably do a few more tiny ones here, but this is a worst-case handler anyway.
```java
    validRanks--;
}
removeNode(rankMap, nextNode);
```
not sure what happened with the formatting here before :/
```diff
@@ -104,7 +104,7 @@ Kateri je tvoj najljubši okus sladoleda?,"[42, 977, 72, 4864, 259, 3415, 73, 30
 Quel est ton livre préféré?,"[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]","[2232, 301, 1826, 8941, 56984, 27389, 69, 68862, 30]"
 Qual é a tua cor favorita?,"[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]","[32129, 4046, 264, 64984, 1867, 4799, 6388, 30]"
 Koja ti je omiljena boja?,"[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712, 5697, 30]","[42, 78, 5697, 9165, 4864, 8019, 321, 73, 7304, 712]"
-Melyik a kedvenc étel?,"[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]","[44, 989, 1609, 264, 80142, 85, 967, 14240, 301, 30]"
+Melyik a kedvenc ételed?,"[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839, 30]","[44, 989, 1609, 264, 80142, 85, 967, 4046, 668, 839]"
```
…iteration

After:
```
Benchmark                                              (dataFolderPath)  Mode  Cnt  Score   Error  Units
SingleThreadedBenchmark.benchmarkCl100kBase                        data    ss   10  4.547 ± 0.056   s/op
SingleThreadedBenchmark.benchmarkCl100kBaseTokenCount              data    ss   10  3.944 ± 0.031   s/op
SingleThreadedBenchmark.benchmarkP50kBase                          data    ss   10  5.427 ± 0.065   s/op
SingleThreadedBenchmark.benchmarkP50kEdit                          data    ss   10  5.375 ± 0.062   s/op
SingleThreadedBenchmark.benchmarkR50kBase                          data    ss   10  5.073 ± 0.063   s/op
```
```diff
@@ -150,7 +150,7 @@ public static void main(String[] args) throws Exception {
     "'s", "'t", "'re", "'ve", "'m", "'ll", "'d", "'x",
     "x",
     "123",
-    "ő",
+    "a",
```
Testing a slightly different setup here, since `aaaaaa` is still a single token - this makes it a more difficult case.
```diff
@@ -177,7 +177,7 @@ public static void main(String[] args) throws Exception {
 }
 
 var totalSize = calculateTotalFileSize(rootFolder);
-if (totalSize != 99_945_290) {
+if (totalSize != 99_925_295) {
```
Re-ran the before/after benchmarks with this data.
```java
assert minNode.rank != MAX_RANK;
TreeMap<Integer, RankNode> minNodes = rankMap.pollFirstEntry().getValue();
int firstIndex;
for (Entry<Integer, RankNode> entry = minNodes.firstEntry(); entry != null; entry = minNodes.ceilingEntry(firstIndex)) {
```
Because of the tree structure we're actually storing every instance of the same token, so once we find any of them, we can be sure that the next few minimums will be the same token as well - so we just iterate over those instead of removing them one by one (polling once and iterating until consumed).
Continuing #75 - note that the first few commits are repeated here; they will be eliminated by a rebase once the other PR is merged.
The original byte pair merge algorithm degrades superlinearly for longer character sequences - e.g. a 20,000-character word (roughly 2,500 tokens) can take several seconds to tokenize.
For bigger character sequences we switch to a linear(ithmic) algorithm at around 500 characters, below which the current one is faster.
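A minimal sketch of that dispatch - the helper names and the exact threshold are illustrative, not the PR's actual code:

```java
// Hypothetical crossover between the two merge strategies.
private static final int LARGE_INPUT_THRESHOLD = 500; // approximate crossover point

int[] encodePiece(byte[] piece) {
    return piece.length < LARGE_INPUT_THRESHOLD
            ? mergeBytePairsArrayBased(piece)  // original algorithm, faster for short inputs
            : mergeBytePairsTreeBased(piece);  // linearithmic, wins on long inputs
}
```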
The change also includes an optimization for just token counting - when the tokens themselves aren't important.
Before (i.e. assuming #75 was merged):
After:
Please review commit by commit so the changes make sense.