Skip to content

Commit

Permalink
Remove some more overhead from GPT3Tokenizer (#675)
Browse files Browse the repository at this point in the history
### Motivation and Context

Some more low-hanging fruit reduction in GPT3Tokenizer.

### Description

- We can both simplify and make faster the parsing of the vocab.bpe file
- We can remove the SortedDictionary from BytePairEncoding, including a
full O(N) iteration of the dictionary on every iteration of the outer
loop (as part of the Min() call).


Co-authored-by: Devis Lucato <dluc@users.noreply.github.com>
Co-authored-by: Devis Lucato <devis@microsoft.com>
  • Loading branch information
3 people committed Apr 27, 2023
1 parent eea509e commit 8fc9d7a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ private static List<string> BytePairEncoding(string token)
word.Add(c.ToString());
}

var minPairs = new SortedDictionary<long, (string, string)>();
long smallestRank = long.MaxValue;
(string, string) smallestPair = ("", "");
List<string>? newWord = null;

while (word.Count >= 2)
Expand All @@ -207,23 +208,22 @@ private static List<string> BytePairEncoding(string token)
{
(string, string) pair = (word[pairIndex], word[pairIndex + 1]);

long minPairsRank = 100000000000;
if (GPT3Settings.BpeRanks.TryGetValue(pair, out int rank))
long pairRank = GPT3Settings.BpeRanks.TryGetValue(pair, out int rank) ? rank : 100_000_000_000;

if (pairRank <= smallestRank)
{
minPairsRank = rank;
smallestRank = pairRank;
smallestPair = pair;
}

minPairs[minPairsRank] = pair;
}

(string, string) biGram = minPairs[minPairs.Keys.Min()];
if (!GPT3Settings.BpeRanks.ContainsKey(biGram))
if (!GPT3Settings.BpeRanks.ContainsKey(smallestPair))
{
break;
}

string first = biGram.Item1;
string second = biGram.Item2;
string first = smallestPair.Item1;
string second = smallestPair.Item2;

newWord ??= new List<string>(word.Count);
for (int i = 0; i < word.Count; i++)
Expand Down Expand Up @@ -261,7 +261,7 @@ private static List<string> BytePairEncoding(string token)

// And reset state for the next go-around
newWord.Clear();
minPairs.Clear();
smallestRank = long.MaxValue;
}

s_bpeCache.TryAdd(token, word);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,52 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Diagnostics;
using System.Text.Json;
using Microsoft.SemanticKernel.AI;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.Tokenizers.Settings;

internal static class GPT3Settings
{
// Lazy load and cache encoding table (encoder.json)
/// <summary>Gets the cached encoding table (encoder.json).</summary>
internal static Dictionary<string, int> Encoder => s_encoder.Value;

// Lazy load and cache byte pair encoding table (vocab.bpe)
/// <summary>Gets the cached byte pair encoding table (vocab.bpe).</summary>
internal static Dictionary<(string, string), int> BpeRanks => s_bpeRanks.Value;

// Lazy load and cache encoding table (encoder.json)
private static readonly Lazy<Dictionary<string, int>> s_encoder = new(BuildEncoder);

// Lazy load and cache byte pair encoding table (vocab.bpe)
private static readonly Lazy<Dictionary<(string, string), int>> s_bpeRanks = new(BuildBpeRanks);

private static Dictionary<(string, string), int> BuildBpeRanks()
{
string[] lines = EmbeddedResource.ReadBytePairEncodingTable().Split('\n');
List<(string, string)> bpeMerges = new ArraySegment<string>(lines, 1, lines.Length - 1)
.Where(x => x.Trim().Length > 0)
.Select(x =>
{
string[] y = x.Split(' ');
return (y[0], y[1]);
}).ToList();
return DictZip(bpeMerges, Range(0, bpeMerges.Count));
}

private static Dictionary<string, int> BuildEncoder()
/// <summary>Lazy load the cached encoding table (encoder.json).</summary>
private static readonly Lazy<Dictionary<string, int>> s_encoder = new(() =>
{
string json = EmbeddedResource.ReadEncodingTable();
var encoder = JsonSerializer.Deserialize<Dictionary<string, int>>(json);

return encoder
return JsonSerializer.Deserialize<Dictionary<string, int>>(EmbeddedResource.ReadEncodingTable())
?? throw new AIException(AIException.ErrorCodes.InvalidConfiguration,
"Encoding table deserialization returned NULL");
}
});

private static Dictionary<(string, string), int> DictZip(List<(string, string)> x, List<int> y)
/// <summary>Lazy load the cached byte pair encoding table (vocab.bpe).</summary>
private static readonly Lazy<Dictionary<(string, string), int>> s_bpeRanks = new(() =>
{
string table = EmbeddedResource.ReadBytePairEncodingTable();
Debug.Assert(table.StartsWith("#version: 0.2", StringComparison.Ordinal));
// Skip past the header line
int pos = table.IndexOf('\n') + 1;
Debug.Assert(pos > 0);
// For each line, split on the first space and add the pair to the dictionary as a key with the value being the entry number.
var result = new Dictionary<(string, string), int>();
for (int i = 0; i < x.Count; i++)
int nextPos;
while ((nextPos = table.IndexOf('\n', pos)) >= 0)
{
result.Add(x[i], y[i]);
ReadOnlySpan<char> span = table.AsSpan(pos, nextPos - pos).Trim();
pos = span.IndexOf(' ');
if (pos >= 0)
{
result.Add((span.Slice(0, pos).ToString(), span.Slice(pos + 1).ToString()), result.Count);
}
pos = nextPos + 1;
}
return result;
}

private static List<int> Range(int x, int y)
{
return Enumerable.Range(x, y - x).ToList();
}
});
}

0 comments on commit 8fc9d7a

Please sign in to comment.