Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove some more overhead from GPT3Tokenizer #675

Merged
merged 4 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ private static List<string> BytePairEncoding(string token)
word.Add(c.ToString());
}

var minPairs = new SortedDictionary<long, (string, string)>();
long smallestRank = long.MaxValue;
(string, string) smallestPair = ("", "");
List<string>? newWord = null;

while (word.Count >= 2)
Expand All @@ -207,23 +208,22 @@ private static List<string> BytePairEncoding(string token)
{
(string, string) pair = (word[pairIndex], word[pairIndex + 1]);

long minPairsRank = 100000000000;
if (GPT3Settings.BpeRanks.TryGetValue(pair, out int rank))
long pairRank = GPT3Settings.BpeRanks.TryGetValue(pair, out int rank) ? rank : 100_000_000_000;

if (pairRank <= smallestRank)
{
minPairsRank = rank;
smallestRank = pairRank;
smallestPair = pair;
}

minPairs[minPairsRank] = pair;
}

(string, string) biGram = minPairs[minPairs.Keys.Min()];
if (!GPT3Settings.BpeRanks.ContainsKey(biGram))
if (!GPT3Settings.BpeRanks.ContainsKey(smallestPair))
{
break;
}

string first = biGram.Item1;
string second = biGram.Item2;
string first = smallestPair.Item1;
string second = smallestPair.Item2;

newWord ??= new List<string>(word.Count);
for (int i = 0; i < word.Count; i++)
Expand Down Expand Up @@ -261,7 +261,7 @@ private static List<string> BytePairEncoding(string token)

// And reset state for the next go-around
newWord.Clear();
minPairs.Clear();
smallestRank = long.MaxValue;
}

s_bpeCache.TryAdd(token, word);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,52 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Diagnostics;
using System.Text.Json;
using Microsoft.SemanticKernel.AI;

namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.Tokenizers.Settings;

internal static class GPT3Settings
{
// Lazy load and cache encoding table (encoder.json)
/// <summary>Gets the cached encoding table (encoder.json).</summary>
internal static Dictionary<string, int> Encoder => s_encoder.Value;

// Lazy load and cache byte pair encoding table (vocab.bpe)
/// <summary>Gets the cached byte pair encoding table (vocab.bpe).</summary>
internal static Dictionary<(string, string), int> BpeRanks => s_bpeRanks.Value;

// Lazy load and cache encoding table (encoder.json)
private static readonly Lazy<Dictionary<string, int>> s_encoder = new(BuildEncoder);

// Lazy load and cache byte pair encoding table (vocab.bpe)
private static readonly Lazy<Dictionary<(string, string), int>> s_bpeRanks = new(BuildBpeRanks);

private static Dictionary<(string, string), int> BuildBpeRanks()
{
string[] lines = EmbeddedResource.ReadBytePairEncodingTable().Split('\n');
List<(string, string)> bpeMerges = new ArraySegment<string>(lines, 1, lines.Length - 1)
.Where(x => x.Trim().Length > 0)
.Select(x =>
{
string[] y = x.Split(' ');
return (y[0], y[1]);
}).ToList();
return DictZip(bpeMerges, Range(0, bpeMerges.Count));
}

private static Dictionary<string, int> BuildEncoder()
/// <summary>Lazy load the cached encoding table (encoder.json).</summary>
private static readonly Lazy<Dictionary<string, int>> s_encoder = new(() =>
{
string json = EmbeddedResource.ReadEncodingTable();
var encoder = JsonSerializer.Deserialize<Dictionary<string, int>>(json);

return encoder
return JsonSerializer.Deserialize<Dictionary<string, int>>(EmbeddedResource.ReadEncodingTable())
?? throw new AIException(AIException.ErrorCodes.InvalidConfiguration,
"Encoding table deserialization returned NULL");
}
});

private static Dictionary<(string, string), int> DictZip(List<(string, string)> x, List<int> y)
/// <summary>Lazy load the cached byte pair encoding table (vocab.bpe).</summary>
private static readonly Lazy<Dictionary<(string, string), int>> s_bpeRanks = new(() =>
{
string table = EmbeddedResource.ReadBytePairEncodingTable();
Debug.Assert(table.StartsWith("#version: 0.2", StringComparison.Ordinal));

// Skip past the header line
int pos = table.IndexOf('\n') + 1;
Debug.Assert(pos > 0);

// For each line, split on the first space and add the pair to the dictionary as a key with the value being the entry number.
var result = new Dictionary<(string, string), int>();
for (int i = 0; i < x.Count; i++)
int nextPos;
while ((nextPos = table.IndexOf('\n', pos)) >= 0)
{
result.Add(x[i], y[i]);
ReadOnlySpan<char> span = table.AsSpan(pos, nextPos - pos).Trim();
pos = span.IndexOf(' ');
if (pos >= 0)
{
result.Add((span.Slice(0, pos).ToString(), span.Slice(pos + 1).ToString()), result.Count);
}
pos = nextPos + 1;
}

return result;
}

private static List<int> Range(int x, int y)
{
return Enumerable.Range(x, y - x).ToList();
}
});
}