diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
index 9874709eb7..f17845adf8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
@@ -132,6 +132,11 @@ public static BpeTokenizer Create(
             return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
         }
 
+        /// <summary>
+        /// Create a new Bpe tokenizer object to use for text encoding.
+        /// </summary>
+        /// <param name="options">The options used to create the Bpe tokenizer.</param>
+        /// <returns>The Bpe tokenizer object.</returns>
         public static BpeTokenizer Create(BpeOptions options)
         {
             if (options is null)
@@ -146,9 +151,9 @@ public static BpeTokenizer Create(BpeOptions options)
 
             Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>(1000);
 
-            foreach ((string token, int id) in options.Vocabulary)
+            foreach (KeyValuePair<string, int> kvp in options.Vocabulary)
             {
-                vocab.Add(new StringSpanOrdinalKey(token), id);
+                vocab.Add(new StringSpanOrdinalKey(kvp.Key), kvp.Value);
             }
 
             if (vocab.Count == 0)
@@ -395,7 +400,7 @@ private BpeTokenizer(
         /// <summary>
         /// Gets the optional beginning of sentence token.
         /// </summary>
-        internal string? BeginningOfSentenceToken { get; }
+        public string? BeginningOfSentenceToken { get; }
 
         /// <summary>
         /// The id of the beginning of sentence token.
diff --git a/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs b/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
index 94c3e0913b..8eee50ac66 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
@@ -4,6 +4,8 @@
 
 using System;
 using System.Collections.Generic;
+using System.IO;
+using System.Text.Json;
 
 namespace Microsoft.ML.Tokenizers
 {
@@ -15,7 +17,9 @@ public sealed class BpeOptions
         /// <summary>
         /// Initializes a new instance of the <see cref="BpeOptions"/> class.
         /// </summary>
-        public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
+        /// <param name="vocabulary">The vocabulary to use.</param>
+        /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabulary"/> is null.</exception>
+        public BpeOptions(IEnumerable<KeyValuePair<string, int>> vocabulary)
         {
             if (vocabulary == null)
             {
@@ -25,10 +29,74 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
             Vocabulary = vocabulary;
         }
 
+        /// <summary>
+        /// Initializes a new instance of the <see cref="BpeOptions"/> class.
+        /// </summary>
+        /// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
+        /// <param name="mergesFile">The file path containing the tokens' pairs list.</param>
+        public BpeOptions(string vocabFile, string? mergesFile = null)
+        {
+            if (vocabFile is null)
+            {
+                throw new ArgumentNullException(nameof(vocabFile));
+            }
+
+            if (!File.Exists(vocabFile))
+            {
+                throw new ArgumentException($"Could not find the vocabulary file '{vocabFile}'.");
+            }
+
+            using Stream vocabStream = File.OpenRead(vocabFile);
+            Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabStream);
+
+            if (dictionary is null)
+            {
+                throw new InvalidOperationException($"The content of the vocabulary file '{vocabFile}' is not valid.");
+            }
+
+            Vocabulary = dictionary;
+
+            if (mergesFile is not null)
+            {
+                if (!File.Exists(mergesFile))
+                {
+                    throw new ArgumentException($"Could not find the merges file '{mergesFile}'.");
+                }
+
+                using Stream mergesStream = File.OpenRead(mergesFile);
+                using StreamReader reader = new(mergesStream);
+
+                List<string> merges = new();
+
+                int lineNumber = 0;
+                string? line;
+
+                while ((line = reader.ReadLine()) is not null)
+                {
+                    lineNumber++;
+                    if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
+                    {
+                        continue;
+                    }
+
+                    // validate the merges format
+                    int index = line.IndexOf(' ');
+                    if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
+                    {
+                        throw new InvalidOperationException($"Invalid merge file format at line: {lineNumber}");
+                    }
+
+                    merges.Add(line);
+                }
+
+                Merges = merges;
+            }
+        }
+
         /// <summary>
         /// Gets or sets the vocabulary to use.
         /// </summary>
-        public IEnumerable<(string Token, int Id)> Vocabulary { get; }
+        public IEnumerable<KeyValuePair<string, int>> Vocabulary { get; }
 
         /// <summary>
         /// Gets or sets the list of the merge strings used to merge tokens during encoding.
@@ -38,7 +106,7 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
         /// <summary>
         /// Gets or sets the optional special tokens to use.
         /// </summary>
-        public Dictionary<string, int>? SpecialTokens { get; set; }
+        public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
 
         /// <summary>
         /// Gets or sets the optional normalizer to normalize the input text before encoding it.
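For reference, a minimal sketch of how the reworked BpeOptions surface might be used once this change is in. The token strings, merge pairs, and file paths below are placeholders for illustration, not values from this PR, and the EncodeToIds call assumes the existing base Tokenizer API.

using System;
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// Option 1: build the vocabulary in memory. The constructor now takes
// IEnumerable<KeyValuePair<string, int>>, so a Dictionary<string, int> works directly.
var vocab = new Dictionary<string, int> { ["<unk>"] = 0, ["a"] = 1, ["b"] = 2, ["ab"] = 3 };

BpeOptions inMemoryOptions = new BpeOptions(vocab)
{
    Merges = new[] { "a b" },   // each merge line is "left right", matching the validated format
    UnknownToken = "<unk>"
};

// Option 2: load the vocabulary (JSON token -> id map) and the optional merges file from disk
// through the new constructor. "vocab.json" and "merges.txt" are placeholder paths.
BpeOptions fileOptions = new BpeOptions("vocab.json", "merges.txt")
{
    UnknownToken = "<unk>"
};

BpeTokenizer tokenizer = BpeTokenizer.Create(fileOptions);
IReadOnlyList<int> ids = tokenizer.EncodeToIds("ab");
Console.WriteLine(string.Join(", ", ids));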
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index f41516e270..cb945d24fa 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -436,7 +436,7 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
         /// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto.
         /// </summary>
         /// <param name="modelStream">The stream containing the SentencePiece Bpe or Unigram model.</param>
-        /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+        /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
         /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
         /// <param name="specialTokens">The additional tokens to add to the vocabulary.</param>
         ///
@@ -444,7 +444,7 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
         ///
         public static SentencePieceTokenizer Create(
             Stream modelStream,
-            bool addBeginOfSentence = true,
+            bool addBeginningOfSentence = true,
             bool addEndOfSentence = false,
             IReadOnlyDictionary<string, int>? specialTokens = null)
         {
@@ -455,7 +455,7 @@ public static SentencePieceTokenizer Create(
                 throw new ArgumentNullException(nameof(modelProto));
             }
 
-            return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
+            return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens);
         }
     }
 }
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
index 5e1422bfab..5081296098 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
@@ -10,6 +10,8 @@
 using System.Linq;
 using System.Text.RegularExpressions;
 
+namespace Microsoft.ML.Tokenizers;
+
 /// <summary>
 /// CompositePreTokenizer is a pre-tokenizer that applies multiple pre-tokenizers in sequence.
 /// </summary>
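The addBeginOfSentence to addBeginningOfSentence rename is source-breaking for callers that pass the flag as a named argument. A minimal sketch of an updated call site follows; the model path is a placeholder, and EncodeToIds again assumes the existing base Tokenizer API.

using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Tokenizers;

// "tokenizer.model" is a placeholder; any SentencePiece Bpe or Unigram model file applies.
using Stream modelStream = File.OpenRead("tokenizer.model");

SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(
    modelStream,
    addBeginningOfSentence: true,   // was addBeginOfSentence before this change
    addEndOfSentence: false);

IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello world");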
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 6a0bdcb52b..7394464b90 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -257,6 +257,19 @@ public void SimpleTestWithUnknownToken(
                     continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
                 SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
+
+                BpeOptions bpeOptions = new BpeOptions(vocabFile, mergesFile)
+                {
+                    PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
+                    Normalizer = null,
+                    UnknownToken = unknownToken,
+                    ContinuingSubwordPrefix = continuingSubwordPrefix,
+                    EndOfWordSuffix = endOfWordSuffix,
+                    FuseUnknownTokens = fuseUnknownToken
+                };
+
+                bpe = BpeTokenizer.Create(bpeOptions);
+                SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
             }
             finally
             {
@@ -267,7 +280,7 @@ public void SimpleTestWithUnknownToken(
                 }
             }
 
-            BpeOptions bpeOptions = new BpeOptions(vocab.Select(kvp => (kvp.Key, kvp.Value)))
+            BpeOptions bpeOptions1 = new BpeOptions(vocab)
             {
                 Merges = merges?.Select(kvp => $"{kvp.Item1} {kvp.Item2}"),
                 PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
@@ -278,7 +291,7 @@ public void SimpleTestWithUnknownToken(
                 FuseUnknownTokens = fuseUnknownToken
             };
 
-            BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions);
+            BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions1);
             SimpleWithUnknownTokenTest(bpe1, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
         }
 
@@ -387,7 +400,7 @@ public async Task TestBpeCreation()
             Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(jsonString);
 
             bpe = BpeTokenizer.Create(
-                new BpeOptions(dictionary!.Select(kvp => (kvp.Key, kvp.Value)))
+                new BpeOptions(dictionary!)
                 {
                     Merges = File.ReadAllLines(mergesFile).Skip(1).ToArray() // Skip the first line which is the header "#version".
                 });
@@ -928,11 +941,11 @@ private static BpeTokenizer CreateBpeTokenizerFromJson()
             return BpeTokenizer.Create(bpeOptions);
         }
 
-        private static IEnumerable<(string Token, int Id)> GetVocabulary(JsonElement vocabElement)
+        private static IEnumerable<KeyValuePair<string, int>> GetVocabulary(JsonElement vocabElement)
         {
             foreach (JsonProperty token in vocabElement.EnumerateObject())
             {
-                yield return (token.Name, token.Value.GetInt32());
+                yield return new KeyValuePair<string, int>(token.Name, token.Value.GetInt32());
             }
         }
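Since BpeOptions.Vocabulary moves from (string Token, int Id) tuples to KeyValuePair<string, int>, existing callers need a small adjustment, as the updated tests show. A sketch of the two common migration paths; the vocabulary entries below are illustrative only.

using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Tokenizers;

// A Dictionary<string, int> already implements IEnumerable<KeyValuePair<string, int>>,
// so it can now be passed to BpeOptions directly instead of being projected into tuples.
Dictionary<string, int> vocab = new() { ["hello"] = 0, ["world"] = 1 };
BpeOptions fromDictionary = new BpeOptions(vocab);

// Tuple-based data built for the old API shape can be projected into KeyValuePair entries.
(string Token, int Id)[] tuples = { ("hello", 0), ("world", 1) };
BpeOptions fromTuples = new BpeOptions(
    tuples.Select(t => new KeyValuePair<string, int>(t.Token, t.Id)));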