diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs index f17845adf8..b9592d2e2b 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs @@ -320,7 +320,7 @@ private BpeTokenizer( if (beginningOfSentenceToken is not null) { - if (!_vocab.TryGetValue(beginningOfSentenceToken, out int aId)) + if (_vocab.TryGetValue(beginningOfSentenceToken, out int aId) is false && specialTokens?.TryGetValue(beginningOfSentenceToken, out aId) is false) { throw new InvalidOperationException($"The beginning of sentence token '{beginningOfSentenceToken}' was not present in the vocabulary."); } @@ -331,7 +331,7 @@ private BpeTokenizer( if (endOfSentenceToken is not null) { - if (!_vocab.TryGetValue(endOfSentenceToken, out int aId)) + if (_vocab.TryGetValue(endOfSentenceToken, out int aId) is false && specialTokens?.TryGetValue(endOfSentenceToken, out aId) is false) { throw new InvalidOperationException($"The end of sentence token '{endOfSentenceToken}' was not present in the vocabulary."); } @@ -792,31 +792,30 @@ public string Decode(IEnumerable<int> ids, bool considerSpecialTokens) ValueStringBuilder sb = new ValueStringBuilder(); - bool decodeUnknownToken = _unknownTokenId.HasValue && considerSpecialTokens; - - if (decodeUnknownToken) + foreach (int id in ids) { - foreach (int id in ids) + if (_specialTokensReverse?.TryGetValue(id, out string? token) is true) { - if (MapIdToToken(id) is string s) + if (considerSpecialTokens) { - sb.Append(s); + sb.Append(token); } + continue; } - } - else - { - foreach (int id in ids) + + if (id == _unknownTokenId) { - if (id == _unknownTokenId) + if (considerSpecialTokens) { - continue; + Debug.Assert(UnknownToken is not null); + sb.Append(UnknownToken); } + continue; + } - if (MapIdToToken(id) is string s) - { - sb.Append(s); - } + if (MapIdToToken(id) is string s) + { + sb.Append(s); } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 7394464b90..5c2da4aece 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -885,6 +885,66 @@ public void TestDeepSeekR1Tokenizer(string text, int[] ids, string[] tokens, (in Assert.Equal(text, tokenizer.Decode(ids, considerSpecialTokens: false)); } + [Fact] + public void TestTokenizerWithSpecialTokens() + { + // "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json"; + // "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt"; + + BpeOptions options = new BpeOptions(Path.Combine(@"Gpt-2", "vocab.json"), Path.Combine(@"Gpt-2", "merges.txt")) + { + UnknownToken = "unk", + + SpecialTokens = new Dictionary<string, int> // SpecialTokens not part of the original vocab.json + { + { "<|sos|>", 50257 }, + { "<|eos|>", 50258 } + }, + BeginningOfSentenceToken = "<|sos|>", + EndOfSentenceToken = "<|eos|>" + }; + + BpeTokenizer bpeTokenizer = BpeTokenizer.Create(options); + Assert.True(bpeTokenizer.Vocabulary.TryGetValue(options.UnknownToken, out int unkId)); + + string text = "Hello world!\uD800"; + + var ids = bpeTokenizer.EncodeToIds(text, considerPreTokenization: false); + Assert.Equal([50257, 15496, 2954, 6894, 0, 2954, 50258], ids); // space and u+D800 couldn't be encoded and produced unk tokens + Assert.Equal(unkId, ids[ids.Count - 2]); + Assert.Equal(options.SpecialTokens["<|sos|>"], ids[0]); + 
Assert.Equal(options.SpecialTokens["<|eos|>"], ids[^1]); + + var tokens = bpeTokenizer.EncodeToTokens(text, out _, considerPreTokenization: false).Select(t => t.Value).ToArray(); + Assert.Equal(["<|sos|>", "Hello", "unk", "world", "!", "unk", "<|eos|>"], tokens); + + Assert.Equal("<|sos|>Hellounkworld!unk<|eos|>", bpeTokenizer.Decode(ids)); + Assert.Equal("Helloworld!", bpeTokenizer.Decode(ids, considerSpecialTokens: false)); + + BpeOptions options1 = new BpeOptions(options.Vocabulary) + { + // Null UnknownToken means no unknown token support + Merges = options.Merges, + SpecialTokens = options.SpecialTokens, + BeginningOfSentenceToken = options.BeginningOfSentenceToken, + EndOfSentenceToken = options.EndOfSentenceToken + }; + + bpeTokenizer = BpeTokenizer.Create(options1); + ids = bpeTokenizer.EncodeToIds(text, considerPreTokenization: false); + + // Because Unknown is not supported in this encoding, the encoding will produce different encoding results + Assert.Equal([50257, 39, 5037, 1764, 0, 50258], ids); + Assert.Equal(options.SpecialTokens["<|sos|>"], ids[0]); + Assert.Equal(options.SpecialTokens["<|eos|>"], ids[^1]); + + tokens = bpeTokenizer.EncodeToTokens(text, out _, considerPreTokenization: false).Select(t => t.Value).ToArray(); + Assert.Equal(["<|sos|>", "H", "ellow", "orld", "!", "<|eos|>"], tokens); + + Assert.Equal("<|sos|>Helloworld!<|eos|>", bpeTokenizer.Decode(ids)); + Assert.Equal("Helloworld!", bpeTokenizer.Decode(ids, considerSpecialTokens: false)); + } + private static BpeTokenizer CreateBpeTokenizerFromJson() { // @"https://huggingface.co/deepseek-ai/DeepSeek-R1/resolve/main/tokenizer.json?download=true"