Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding needed Tokenizer's APIs #7047

Merged
merged 7 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 46 additions & 6 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -100,6 +101,43 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTo
}
}

/// <summary>
/// Create a Tiktoken tokenizer based on model name and vocab file.
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateByModelName(
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
string modelName,
Stream vocabStream,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
int cacheSize = LruCache<int[]>.DefaultCacheSize,
Normalizer? normalizer = null)
{
if (string.IsNullOrEmpty(modelName))
{
throw new ArgumentNullException(nameof(modelName));
}

(Dictionary<string, int> SpecialTokens, Regex Regex) tiktokenConfiguration = Tokenizer.GetTiktokenConfigurations(modelName);

if (extraSpecialTokens is not null)
{
foreach (var extraSpecialToken in extraSpecialTokens)
{
tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
}
}

return new Tokenizer(
new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize),
new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
normalizer);
}

private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?) CreateEncoderDecoder(IReadOnlyDictionary<string, int>? specialTokens)
{
if (specialTokens is not null)
Expand Down Expand Up @@ -233,7 +271,7 @@ private static (Dictionary<StringSpanOrdinalKey, int>?, Dictionary<int, string>?
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Existing: what token does this refer to? The only other thing specified is text

Copy link
Member Author

@tarekgh tarekgh Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

token is the string word. can be any word like dog or it can be special token like <|endoftext|>. I can change Indicate if the token is a special token. to Indicate if the text represent a special token. or similar.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that the parameter text, and does that text need to represent a single token? Or does it refer to all tokens within text?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the text can represent multiple tokens or represent one special token.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so how does isSpecialToken apply to 'text` in the case it is multiple tokens?

Copy link
Member Author

@tarekgh tarekgh Mar 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a flag telling if the input text is representing a special token or not so the encoder can treat it differently. Here is some example how this is used

if (isSpecialToken && _specialTokensEncoder is not null)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so we should probably refer to text in the docs.

Here's my attempt:

Indicates if the <paramRef name="text"/> in it's entirety is a special token.  This method will throw if <paramRef name="isSpecialToken"/> is `true` and the specified <paramRef name="text"/> is not a special token.

Similarly it looks like the Count and EncodeToIds just return default values ignoring the text if it's not a special token so they could get a slightly different version of this.


return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0;

I raised this issue because this parameter confused me when adopting the tokenizer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious as to the use case for setting isSpecialToken to true...?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the only scenario is when someone wants to pass in a single special token string and get the value for that. If we're making them specify a parameter to this API to do that they might as well just call a different API to do it and avoid the confusing parameter on this API.

I wonder what happens if someone specifies a special token string but forgets to set isSpecialToken?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tarekgh and I talked about this offline - we'll update the docs for these methods in a separate PR and discuss this during API review.

/// <returns>The list of tokens generated from the text tokenization.</returns>
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken)
public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = false)
{
Token[] tokens;

Expand Down Expand Up @@ -462,12 +500,14 @@ public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken)
/// <returns>The decoded string.</returns>
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
{
// Tiktoken does not ensure a one-to-one mapping between IDs and tokens. Consequently, decoding individual IDs into tokens is not supported;
// instead, decoding all IDs must be done collectively.
// Here is example of case that map one character to multiple Ids:
// '⭐' U-2B50 is mapped to Ids [2928, 99834] in the Tiktoken model.
// In other words, the character '⭐' has UTF-8 code point 0xE2, 0xAD, 0x90, Tiktoken will map 0xE2 to [2928] and 0xAD, 0x90 to [99834].

// Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
// Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively.
// Here's an example case that maps one character to multiple IDs:
// '⭐' U-2B50 is mapped to IDs [2928, 99834] in the Tiktoken model.
// In other words, the character '⭐' with UTF-8 code point 0xE2, 0xAD, 0x90 will be mapped by Tiktoken as follows: 0xE2 to [2928]
// and 0xAD, 0x90 to [99834]. Decoding 2928 and 99834 individually won't reconstruct the original UTF-16 string '⭐' U-2B50;
// decoding all IDs together is required to get the expected result.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could imagine someone wanting an API like IEnumerable<byte> Decode(IEnumerable<ids> ids, ...). Presumably if that was desired we could always add it in the future.

if (ids is null)
{
return null;
Expand Down
142 changes: 122 additions & 20 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading;
Expand Down Expand Up @@ -136,6 +137,76 @@ public int CountTokens(string text, bool considerSpecialTokens = true)
return idsCount;
}

/// <summary>
/// Find the maximum encoding capacity within the input text without surpassing the token limit.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="maxTokenCount">The maximum token count to limit the encoding capacity.</param>
/// <param name="fromStart">Indicate if want to trim from the start of the text.</param>
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>
/// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count,
/// and the token count can be generated using the provided subtext offset and length.
/// </returns>
/// <exception cref="ArgumentNullException">The input text is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">The maximum token count must be greater than 0.</exception>
/// <remarks>
/// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ever the case that someone might have pre-normalized text and want a method that doesn't do this normalization?

Copy link
Member Author

@tarekgh tarekgh Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, when creating the tokenizer you have the option to provide the normalizer object. If you don't then the tokenizer will not do any normalization before processing the text.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just imagining that someone might want to call the normalizer once, then tell this method that they've already done the normalization and avoid double-normalization / allocation. My understanding of the use case for this API is to do a minimal amount of work so I was just asking myself "is there anything else that I can imagine someone might not want this method to do?" and normalization was the only thing I could imagine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well written normalizer will return the original text without new allocation if there is no change from the original text. But I think processing time will be counted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, users can create a copy of the tokenizer without the normalization and can be used in such scenario too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but I was more concerned with this API doing two things where someone might want it to do just one. It's the same feeling @stephentoub shared offline

The normalization stuff sneaking in here still rubs me the wrong way

I don't feel like it needs to be solved now, but it may be a topic during API review.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The normalization is part of tokenization. It is optional for some scenario but important to other scenarios. So the API is not really doing two things more than communicating the encoding results including the change in the text. Anyway, I am open to any better suggestion that can make us avoid any confusion or to get cleaner API shape.

/// If <paramref name="fromStart"/> is true, the returned offset will be 0. Otherwise the returned offset will be the starting index of the subtext.
/// If the provided <paramref name="maxTokenCount"/> is greater than the token count of the input text, the returned length will be the length of the input text.
/// If the provided <paramref name="maxTokenCount"/> is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0.
/// </remarks>
public (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool fromStart = true, bool considerSpecialTokens = true)
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
{
if (text is null)
{
throw new ArgumentNullException(nameof(text));
}

if (maxTokenCount <= 0)
{
throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
}

string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
int idsCount = 0;

if (fromStart)
{
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
{
int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken);

if (tokenCount + idsCount > maxTokenCount)
{
return (normalized, 0, split.Offset.Index, idsCount);
}

idsCount += tokenCount;
}

return (normalized, 0, normalized.Length, idsCount);
}

// from end
Split[] splits = PreTokenizer.PreTokenize(normalized, considerSpecialTokens).ToArray();

for (int i = splits.Length - 1; i >= 0; i--)
{
Split split = splits[i];
int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken);

if (tokenCount + idsCount > maxTokenCount)
{
return (normalized, split.Offset.Index + split.Offset.Length, normalized.Length - split.Offset.Index - split.Offset.Length, idsCount);
}

idsCount += tokenCount;
}

return (normalized, 0, normalized.Length, idsCount);
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Decodes the Id to the mapped token.
/// </summary>
Expand Down Expand Up @@ -230,6 +301,56 @@ private enum ModelEncoding
{ "gpt2", ModelEncoding.GPT2 }
};
tarekgh marked this conversation as resolved.
Show resolved Hide resolved

private static ModelEncoding GetModelEncoding(string modelName)
{
if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder))
{
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
{
encoder = Encoding;
break;
}
}
}

if (encoder == ModelEncoding.None)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
}

return encoder;
}

internal static (Dictionary<string, int> SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName)
{
ModelEncoding modelEncoding = GetModelEncoding(modelName);

switch (modelEncoding)
{
case ModelEncoding.Cl100kBase:
return (new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex());

case ModelEncoding.P50kBase:
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex());

case ModelEncoding.P50kEdit:
return (new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex());

case ModelEncoding.R50kBase:
return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex());

case ModelEncoding.GPT2:
return (new Dictionary<string, int> { { EndOfText, 50256 }, }, P50kBaseRegex());

default:
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
throw new NotImplementedException($"Doesn't support model '{modelName}'");
}
}

/// <summary>
/// Create tokenizer based on model name
Expand All @@ -247,26 +368,7 @@ private enum ModelEncoding
{
try
{
ModelEncoding encoder;

if (!_modelToEncoding.TryGetValue(modelName, out encoder))
{
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
{
encoder = Encoding;
break;
}
}
}

if (encoder == ModelEncoding.None)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
}

return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
return CreateByEncoderNameAsync(GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken);
}
catch (Exception ex)
{
Expand Down
2 changes: 2 additions & 0 deletions test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ public async void TestGpt2Vocab()
Assert.Equal(12, encoding.Ids.Count);
Assert.Equal(encoding.Ids, ids);
Assert.Equal(12, tokenizer.CountTokens(text));

TokenizerTests.TestTokenLimits(tokenizer);
}

private static string WriteToMergeFile((string, string)[] mergeEntries)
Expand Down
1 change: 1 addition & 0 deletions test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public async void TokenizationTest()

Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
TokenizerTests.TestTokenLimits(tokenizer);

tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
Expand Down
19 changes: 14 additions & 5 deletions test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ public async void TestTokenizerCreation()
tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer);
}
TestGPT4TokenizationEncoding(tokenizer);

using (Stream stream = File.OpenRead(tokenizerDataFileName))
{
tokenizer = Tiktoken.CreateByModelName("gpt-4", stream);
}
TestGPT4TokenizationEncoding(tokenizer);
}
finally
{
Expand All @@ -82,6 +88,8 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer)
Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets);
Assert.Equal(encoded.Count, idsCount);
Assert.Equal(encoded, result.Ids);

TestGPT4Tokenizer(tokenizer);
}

[Fact]
Expand All @@ -101,13 +109,12 @@ public void TestEncode1()
Assert.Equal(encoded, result.Ids);
}

[Fact]
public void TestEncode2()
private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
{
string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
IReadOnlyList<int> encoded = GPT4.EncodeToIds(text, considerSpecialTokens: false);
IReadOnlyList<int> encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false);
Assert.Equal(5584, encoded.Count);
int idsCount = GPT4.CountTokens(text, considerSpecialTokens: false);
int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false);
Assert.Equal(encoded.Count, idsCount);

using (Stream stream = File.OpenRead("./Data/tokens.json"))
Expand All @@ -116,8 +123,10 @@ public void TestEncode2()
Assert.Equal(expected!, encoded.ToArray());
}

string? decoded = GPT4.Decode(encoded.ToArray());
string? decoded = gpt4Tokenizer.Decode(encoded.ToArray());
Assert.Equal(text, decoded!);

TokenizerTests.TestTokenLimits(gpt4Tokenizer);
}

[Fact]
Expand Down
67 changes: 67 additions & 0 deletions test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Tokenizers;
using System;
using System.Collections.Generic;
using System.Linq;
using Xunit;

namespace Microsoft.ML.Tokenizers.Tests
{
public class TokenizerTests
{
internal static void TestTokenLimits(Tokenizer tokenizer)
{
string input = @"
OpenAI's large language models (sometimes referred to as GPT's) process text using tokens, which are common sequences of characters found in a set of text.
The models learn to understand the statistical relationships between these tokens, and excel at producing the next token in a sequence of tokens.
You can use the tool below to understand how a piece of text might be tokenized by a language model, and the total count of tokens in that piece of text.
It's important to note that the exact tokenization process varies between models. Newer models like GPT-3.5 and GPT-4 use a different tokenizer than previous models,
and will produce different tokens for the same input text.
";

IReadOnlyList<int> fullIdsList = tokenizer.EncodeToIds(input);

for (int i = 1; i <= fullIdsList.Count; i++)
{
(string Text, int Offset, int Length, int TokenCount) result1 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: true);
(string Text, int Offset, int Length, int TokenCount) result2 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: false);

IReadOnlyList<int>? prefixIds = null;
IReadOnlyList<int>? suffixIds = null;

if (result1.TokenCount > 0)
{
Assert.Equal(0, result1.Offset);
string prefixString = result1.Text.Substring(result1.Offset, result1.Length);
prefixIds = tokenizer.EncodeToIds(prefixString);
Assert.Equal(result1.TokenCount, prefixIds.Count);
Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count));
}

if (result2.TokenCount > 0)
{
Assert.Equal(result2.Text.Length, result2.Offset + result2.Length);
string suffixString = result2.Text.Substring(result2.Offset, result2.Length);
suffixIds = tokenizer.EncodeToIds(suffixString);
Assert.Equal(result2.TokenCount, suffixIds.Count);
Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count));
}

if (i == fullIdsList.Count)
{
Assert.Equal(result1.Text.Length, result1.Length);
Assert.Equal(result2.Text.Length, result2.Length);
Assert.Equal(fullIdsList, prefixIds);
Assert.Equal(fullIdsList, suffixIds);
}
}

Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: 0, fromStart: true));
Assert.Throws<ArgumentOutOfRangeException>(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: -1, fromStart: true));
Assert.Throws<ArgumentNullException>(() => tokenizer.TrimWithinTokenLimit(null!, maxTokenCount: 0, fromStart: false));
}
}
}