Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ public static BpeTokenizer Create(
return new BpeTokenizer(result.vocab, result.merges, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens);
}

/// <summary>
/// Create a new Bpe tokenizer object to use for text encoding.
/// </summary>
/// <param name="options">The options used to create the Bpe tokenizer.</param>
/// <returns>The Bpe tokenizer object.</returns>
public static BpeTokenizer Create(BpeOptions options)
{
if (options is null)
Expand All @@ -146,9 +151,9 @@ public static BpeTokenizer Create(BpeOptions options)

Dictionary<StringSpanOrdinalKey, int> vocab = new Dictionary<StringSpanOrdinalKey, int>(1000);

foreach ((string token, int id) in options.Vocabulary)
foreach (KeyValuePair<string, int> kvp in options.Vocabulary)
{
vocab.Add(new StringSpanOrdinalKey(token), id);
vocab.Add(new StringSpanOrdinalKey(kvp.Key), kvp.Value);
}

if (vocab.Count == 0)
Expand Down Expand Up @@ -395,7 +400,7 @@ private BpeTokenizer(
/// <summary>
/// Gets the optional beginning of sentence token.
/// </summary>
internal string? BeginningOfSentenceToken { get; }
public string? BeginningOfSentenceToken { get; }

/// <summary>
/// The id of the beginning of sentence token.
Expand Down
74 changes: 71 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/BpeOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Text.Json;

namespace Microsoft.ML.Tokenizers
{
Expand All @@ -15,7 +17,9 @@ public sealed class BpeOptions
/// <summary>
/// Initializes a new instance of the <see cref="BpeOptions"/> class.
/// </summary>
public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
/// <param name="vocabulary">The vocabulary to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabulary"/> is null.</exception>
public BpeOptions(IEnumerable<KeyValuePair<string, int>> vocabulary)
{
if (vocabulary == null)
{
Expand All @@ -25,10 +29,74 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
Vocabulary = vocabulary;
}

/// <summary>
/// Initializes a new instance of the <see cref="BpeOptions"/> class by loading the
/// vocabulary, and optionally the merges, from files on disk.
/// </summary>
/// <param name="vocabFile">The JSON file path containing the dictionary mapping token strings to their integer ids.</param>
/// <param name="mergesFile">The optional file path containing the token pairs to merge, one space-separated pair per line.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFile"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="vocabFile"/> or <paramref name="mergesFile"/> does not exist on disk.</exception>
/// <exception cref="InvalidOperationException">Thrown when the vocabulary file does not deserialize to a dictionary, or a merges line is malformed.</exception>
public BpeOptions(string vocabFile, string? mergesFile = null)
{
if (vocabFile is null)
{
throw new ArgumentNullException(nameof(vocabFile));
}

if (!File.Exists(vocabFile))
{
throw new ArgumentException($"Could not find the vocabulary file '{vocabFile}'.");
}

// The vocabulary file is expected to be a flat JSON object of the shape { "token": id, ... }.
using Stream vocabStream = File.OpenRead(vocabFile);
Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(vocabStream);

if (dictionary is null)
{
throw new InvalidOperationException($"The content of the vocabulary file '{vocabFile}' is not valid.");
}

Vocabulary = dictionary;

// The merges file is optional; when omitted, Merges is left unset.
if (mergesFile is not null)
{
if (!File.Exists(mergesFile))
{
throw new ArgumentException($"Could not find the merges file '{mergesFile}'.");
}

using Stream mergesStream = File.OpenRead(mergesFile);
using StreamReader reader = new(mergesStream);

List<string> merges = new();

// Track the 1-based line number solely for the error message below.
int lineNumber = 0;
string? line;

while ((line = reader.ReadLine()) is not null)
{
lineNumber++;
// Skip the "#version" header line and any empty lines.
if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
{
continue;
}

// validate the merges format:
// each kept line must contain exactly one space separating exactly two tokens
// (no leading/trailing space, no third token).
int index = line.IndexOf(' ');
if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
{
throw new InvalidOperationException($"Invalid merge file format at line: {lineNumber}");
}

merges.Add(line);
}

Merges = merges;
}
}

/// <summary>
/// Gets or sets the vocabulary to use.
/// </summary>
public IEnumerable<(string Token, int Id)> Vocabulary { get; }
public IEnumerable<KeyValuePair<string, int>> Vocabulary { get; }

/// <summary>
/// Gets or sets the list of the merge strings used to merge tokens during encoding.
Expand All @@ -38,7 +106,7 @@ public BpeOptions(IEnumerable<(string Token, int Id)> vocabulary)
/// <summary>
/// Gets or sets the optional special tokens to use.
/// </summary>
public Dictionary<string, int>? SpecialTokens { get; set; }
public IReadOnlyDictionary<string, int>? SpecialTokens { get; set; }

/// <summary>
/// Gets or sets the optional normalizer to normalize the input text before encoding it.
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,15 @@ public OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool
/// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto.
/// </summary>
/// <param name="modelStream">The stream containing the SentencePiece Bpe or Unigram model.</param>
/// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
/// <param name="addBeginningOfSentence">Indicates whether to emit the beginning-of-sentence token during encoding.</param>
/// <param name="addEndOfSentence">Indicates whether to emit the end-of-sentence token during encoding.</param>
/// <param name="specialTokens">The additional tokens to add to the vocabulary.</param>
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
public static SentencePieceTokenizer Create(
Stream modelStream,
bool addBeginOfSentence = true,
bool addBeginningOfSentence = true,
bool addEndOfSentence = false,
IReadOnlyDictionary<string, int>? specialTokens = null)
{
Expand All @@ -455,7 +455,7 @@ public static SentencePieceTokenizer Create(
throw new ArgumentNullException(nameof(modelProto));
}

return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
return new SentencePieceTokenizer(modelProto, addBeginningOfSentence, addEndOfSentence, specialTokens);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
using System.Linq;
using System.Text.RegularExpressions;

namespace Microsoft.ML.Tokenizers;

/// <summary>
/// CompositePreTokenizer is a pre-tokenizer that applies multiple pre-tokenizers in sequence.
/// </summary>
Expand Down
23 changes: 18 additions & 5 deletions test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ public void SimpleTestWithUnknownToken(
continuingSubwordPrefix: continuingSubwordPrefix, endOfWordSuffix: endOfWordSuffix, fuseUnknownTokens: fuseUnknownToken);

SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);

BpeOptions bpeOptions = new BpeOptions(vocabFile, mergesFile)
{
PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
Normalizer = null,
UnknownToken = unknownToken,
ContinuingSubwordPrefix = continuingSubwordPrefix,
EndOfWordSuffix = endOfWordSuffix,
FuseUnknownTokens = fuseUnknownToken
};

bpe = BpeTokenizer.Create(bpeOptions);
SimpleWithUnknownTokenTest(bpe, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
}
finally
{
Expand All @@ -267,7 +280,7 @@ public void SimpleTestWithUnknownToken(
}
}

BpeOptions bpeOptions = new BpeOptions(vocab.Select(kvp => (kvp.Key, kvp.Value)))
BpeOptions bpeOptions1 = new BpeOptions(vocab)
{
Merges = merges?.Select(kvp => $"{kvp.Item1} {kvp.Item2}"),
PreTokenizer = PreTokenizer.CreateWordOrNonWord(),
Expand All @@ -278,7 +291,7 @@ public void SimpleTestWithUnknownToken(
FuseUnknownTokens = fuseUnknownToken
};

BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions);
BpeTokenizer bpe1 = BpeTokenizer.Create(bpeOptions1);
SimpleWithUnknownTokenTest(bpe1, sentence, offsets, ids, expectedTokens, decodedTokens, decodedTokensWithoutUnknownToken);
}

Expand Down Expand Up @@ -387,7 +400,7 @@ public async Task TestBpeCreation()
Dictionary<string, int>? dictionary = JsonSerializer.Deserialize<Dictionary<string, int>>(jsonString);

bpe = BpeTokenizer.Create(
new BpeOptions(dictionary!.Select(kvp => (kvp.Key, kvp.Value)))
new BpeOptions(dictionary!)
{
Merges = File.ReadAllLines(mergesFile).Skip(1).ToArray() // Skip the first line which is the header "#version".
});
Expand Down Expand Up @@ -928,11 +941,11 @@ private static BpeTokenizer CreateBpeTokenizerFromJson()
return BpeTokenizer.Create(bpeOptions);
}

/// <summary>
/// Enumerates the properties of a JSON object element as vocabulary entries,
/// yielding each property name together with its integer id.
/// </summary>
/// <param name="vocabElement">A JSON object element of the shape { "token": id, ... }.</param>
/// <returns>The token/id pairs in document order, produced lazily.</returns>
private static IEnumerable<KeyValuePair<string, int>> GetVocabulary(JsonElement vocabElement)
{
    foreach (JsonProperty token in vocabElement.EnumerateObject())
    {
        // GetInt32 throws if the property value is not an integral JSON number,
        // surfacing malformed vocabulary entries at enumeration time.
        yield return new KeyValuePair<string, int>(token.Name, token.Value.GetInt32());
    }
}

Expand Down