diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
index 9874709eb7..f17845adf8 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
@@ -132,6 +132,11 @@ public static BpeTokenizer Create(
return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}
+ /// <summary>
+ /// Create a new Bpe tokenizer object to use for text encoding.
+ /// </summary>
+ /// <param name="options">The options used to create the Bpe tokenizer.</param>
+ /// <returns>The Bpe tokenizer object.</returns>
public static BpeTokenizer Create(BpeOptions options)
{
if (options is null)
@@ -146,9 +151,9 @@ public static BpeTokenizer Create(BpeOptions options)
Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>(1000);
- foreach ((string token, int id) in options.Vocabulary)
+ foreach (KeyValuePair<string, int> kvp in options.Vocabulary)
{
- vocab.Add(new StringSpanOrdinalKey(token), id);
+ vocab.Add(new StringSpanOrdinalKey(kvp.Key), kvp.Value);
}
if (vocab.Count == 0)
@@ -395,7 +400,7 @@ private BpeTokenizer(
/// <summary>
/// Gets the optional beginning of sentence token.
/// </summary>
- internal string? BeginningOfSentenceToken { get; }
+ public string? BeginningOfSentenceToken { get; }
/// <summary>
/// The id of the beginning of sentence token.
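A minimal sketch (not part of this diff) of how the `Create(BpeOptions)` overload and the now-public `BeginningOfSentenceToken` read from the caller's side; the toy vocabulary and single merge below are invented purely for illustration:

```csharp
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// Toy vocabulary and merges, invented for illustration only.
var options = new BpeOptions(new Dictionary<string, int>
{
    ["<unk>"] = 0,
    ["h"] = 1,
    ["i"] = 2,
    ["hi"] = 3,
})
{
    Merges = new[] { "h i" },   // merge "h" + "i" => "hi"
    UnknownToken = "<unk>",
};

BpeTokenizer tokenizer = BpeTokenizer.Create(options);
IReadOnlyList<int> ids = tokenizer.EncodeToIds("hi");

// Now publicly readable; null here since no beginning-of-sentence token was configured.
string? bos = tokenizer.BeginningOfSentenceToken;
```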
diff --git a/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs b/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
index 94c3e0913b..8eee50ac66 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
@@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
+using System.IO;
+using System.Text.Json;
namespace Microsoft.ML.Tokenizers
{
@@ -15,7 +17,9 @@ public sealed class BpeOptions
/// <summary>
/// Initializes a new instance of the <see cref="BpeOptions"/> class.
/// </summary>
- public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
+ /// <param name="vocabulary">The vocabulary to use.</param>
+ /// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabulary"/> is null.</exception>
+ public BpeOptions(IEnumerable<KeyValuePair<string, int>> vocabulary)
{
if (vocabulary == null)
{
@@ -25,10 +29,74 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
Vocabulary = vocabulary;
}
+ /// <summary>
+ /// Initializes a new instance of the <see cref="BpeOptions"/> class.
+ /// </summary>
+ /// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
+ /// <param name="mergesFile">The file path containing the tokens' pairs list.</param>
+ public BpeOptions(string vocabFile, string? mergesFile = null)
+ {
+ if (vocabFile is null)
+ {
+ throw new ArgumentNullException(nameof(vocabFile));
+ }
+
+ if (!File.Exists(vocabFile))
+ {
+ throw new ArgumentException($"Could not find the vocabulary file '{vocabFile}'.");
+ }
+
+ using Stream vocabStream = File.OpenRead(vocabFile);
+ Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabStream);
+
+ if (dictionary is null)
+ {
+ throw new InvalidOperationException($"The content of the vocabulary file '{vocabFile}' is not valid.");
+ }
+
+ Vocabulary = dictionary;
+
+ if (mergesFile is not null)
+ {
+ if (!File.Exists(mergesFile))
+ {
+ throw new ArgumentException($"Could not find the merges file '{mergesFile}'.");
+ }
+
+ using Stream mergesStream = File.OpenRead(mergesFile);
+ using StreamReader reader = new(mergesStream);
+
+ List<string> merges = new();
+
+ int lineNumber = 0;
+ string? line;
+
+ while ((line = reader.ReadLine()) is not null)
+ {
+ lineNumber++;
+ if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
+ {
+ continue;
+ }
+
+ // validate the merges format
+ int index = line.IndexOf(' ');
+ if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
+ {
+ throw new InvalidOperationException($"Invalid merge file format at line: {lineNumber}");
+ }
+
+ merges.Add(line);
+ }
+
+ Merges = merges;
+ }
+ }
+
/// <summary>
/// Gets or sets the vocabulary to use.
/// </summary>
- public IEnumerable<(string Token, int Id)> Vocabulary { get; }
+ public IEnumerable<KeyValuePair<string, int>> Vocabulary { get; }
/// <summary>
/// Gets or sets the list of the merge strings used to merge tokens during encoding.
@@ -38,7 +106,7 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
/// <summary>
/// Gets or sets the optional special tokens to use.
/// </summary>
- public Dictionary<string, int>? SpecialTokens { get; set; }
+ public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }
/// <summary>
/// Gets or sets the optional normalizer to normalize the input text before encoding it.
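A sketch of the new file-based constructor; `vocab.json` and `merges.txt` are placeholder paths for a Hugging Face style BPE vocabulary/merges pair:

```csharp
using System.Collections.Generic;
using Microsoft.ML.Tokenizers;

// Placeholder paths; the constructor throws if the vocab file is missing
// and validates each merges line as exactly two space-separated tokens.
var options = new BpeOptions("vocab.json", "merges.txt")
{
    UnknownToken = "<unk>",   // optional, depends on the vocabulary
};

BpeTokenizer tokenizer = BpeTokenizer.Create(options);
IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello world");
```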
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index f41516e270..cb945d24fa 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -436,7 +436,7 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto.
/// </summary>
/// <param name="modelStream">The stream containing the SentencePiece Bpe or Unigram model.</param>
- /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
/// <param name="specialTokens">The additional tokens to add to the vocabulary.</param>
/// <remarks>
@@ -444,7 +444,7 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
/// </remarks>
public static SentencePieceTokenizer Create(
Stream modelStream,
- bool addBeginOfSentence = true,
+ bool addBeginningOfSentence = true,
bool addEndOfSentence = false,
IReadOnlyDictionary<string, int>? specialTokens = null)
{
@@ -455,7 +455,7 @@ public static SentencePieceTokenizer Create(
throw new ArgumentNullException(nameof(modelProto));
}
- return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
+ return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens);
}
}
}
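Since `addBeginOfSentence` is renamed rather than kept as an overload, callers passing it as a named argument need to update; a sketch with `tokenizer.model` as a placeholder path:

```csharp
using System.IO;
using Microsoft.ML.Tokenizers;

// "tokenizer.model" is a placeholder for a SentencePiece Bpe or Unigram model file.
using Stream modelStream = File.OpenRead("tokenizer.model");

// Named argument must now be addBeginningOfSentence (was addBeginOfSentence).
SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(
    modelStream,
    addBeginningOfSentence: true,
    addEndOfSentence: false);
```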
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
index 5e1422bfab..5081296098 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/CompositePreTokenizer.cs
@@ -10,6 +10,8 @@
using System.Linq;
using System.Text.RegularExpressions;
+namespace Microsoft.ML.Tokenizers;
+
/// <summary>
/// CompositePreTokenizer is a pre-tokenizer that applies multiple pre-tokenizers in sequence.
/// </summary>
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 6a0bdcb52b..7394464b90 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -257,6 +257,19 @@ public void SimpleTestWithUnknownToken(
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);
SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
+
+ BpeOptions bpeOptions = new BpeOptions(vocabFile, mergesFile)
+ {
+ PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
+ Normalizer = null,
+ UnknownToken = unknownToken,
+ ContinuingSubwordPrefix = continuingSubwordPrefix,
+ EndOfWordSuffix = endOfWordSuffix,
+ FuseUnknownTokens = fuseUnknownToken
+ };
+
+ bpe = BpeTokenizer.Create(bpeOptions);
+ SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
}
finally
{
@@ -267,7 +280,7 @@ public void SimpleTestWithUnknownToken(
}
}
- BpeOptions bpeOptions = new BpeOptions(vocab.Select(kvp => (kvp.Key, kvp.Value)))
+ BpeOptions bpeOptions1 = new BpeOptions(vocab)
{
Merges = merges?.Select(kvp => $"{kvp.Item1} {kvp.Item2}"),
PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
@@ -278,7 +291,7 @@ public void SimpleTestWithUnknownToken(
FuseUnknownTokens = fuseUnknownToken
};
- BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions);
+ BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions1);
SimpleWithUnknownTokenTest(bpe1, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
}
@@ -387,7 +400,7 @@ public async Task TestBpeCreation()
Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(jsonString);
bpe = BpeTokenizer.Create(
- new BpeOptions(dictionary!.Select(kvp => (kvp.Key, kvp.Value)))
+ new BpeOptions(dictionary!)
{
Merges = File.ReadAllLines(mergesFile).Skip(1).ToArray() // Skip the first line which is the header "#version".
});
@@ -928,11 +941,11 @@ private static BpeTokenizer CreateBpeTokenizerFromJson()
return BpeTokenizer.Create(bpeOptions);
}
- private static IEnumerable<(string Token, int Id)> GetVocabulary(JsonElement vocabElement)
+ private static IEnumerable<KeyValuePair<string, int>> GetVocabulary(JsonElement vocabElement)
{
foreach (JsonProperty token in vocabElement.EnumerateObject())
{
- yield return (token.Name, token.Value.GetInt32());
+ yield return new KeyValuePair<string, int>(token.Name, token.Value.GetInt32());
}
}