Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 138 additions & 29 deletions src/Microsoft.ML.Tokenizers/Model/BertTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> ids = new List<int>(capacity: capacity) { ClsTokenId };
List<int> ids;

if (tokenIds0 is ICollection<int> c1)
{
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.

if (tokenIds1 is not null)
{
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

ids = new(capacity) { ClsTokenId };
}
else
{
// slow path
ids = new List<int>(10) { ClsTokenId };
}

ids.AddRange(tokenIds0);
ids.Add(SepTokenId);

Expand Down Expand Up @@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
written = 0;
if (buffer.Length < 1)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}

written = 0;
buffer[written++] = ClsTokenId;
foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}

buffer[written++] = id;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = SepTokenId;

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = SepTokenId;
}

Expand All @@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
throw new ArgumentNullException(nameof(tokenIds0));
}

int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
List<int> mask;
if (tokenIds0 is ICollection<int> c1)
{
int capacity = c1.Count + 2;

if (tokenIds1 is not null)
{
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

mask = new List<int>(capacity);
}
else
{
mask = new List<int>(10);
}

if (!alreadyHasSpecialTokens)
{
Expand Down Expand Up @@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
throw new ArgumentNullException(nameof(tokenIds0));
}

int capacity = alreadyHasSpecialTokens ?
tokenIds0.Count() + (tokenIds1?.Count() ?? 0) :
tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1); // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.

written = 0;
if (buffer.Length < capacity)
{
return OperationStatus.DestinationTooSmall;
}

if (!alreadyHasSpecialTokens)
{
if (buffer.Length < 1)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // CLS

foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // SEP

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // SEP
}

Expand All @@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int

foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}

if (tokenIds1 is not null)
{
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = id == ClsTokenId || id == SepTokenId || id == PadTokenId || id == MaskTokenId || id == UnknownTokenId ? 1 : 0;
}
}
Expand All @@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
throw new ArgumentNullException(nameof(tokenIds0));
}

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
List<int> typeIds;
if (tokenIds0 is ICollection<int> c1)
{
int capacity = c1.Count + 2; // Add 2 for [CLS] and [SEP] tokens.

if (tokenIds1 is not null)
{
capacity += tokenIds1 is ICollection<int> c2 ? c2.Count + 1 : c1.Count + 1;
}

List<int> typeIds = new List<int>(capacity);
for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
typeIds = new List<int>(capacity);
}
else
{
typeIds = new List<int>(10);
}

foreach (var id in tokenIds0)
{
typeIds.Add(0);
}
typeIds.Add(0); // [CLS]
typeIds.Add(0); // [SEP]

if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
foreach (int id in tokenIds1)
{
typeIds.Add(1);
}

typeIds.Add(1); // [SEP]
}

return typeIds;
Expand All @@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds

// Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null.
int capacity = tokenIds0.Count() + 2 + (tokenIds1 is null ? 0 : tokenIds1.Count() + 1);
if (buffer.Length < capacity)
if (buffer.Length < 2)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0; // [CLS]
buffer[written++] = 0; // [SEP]

for (int i = 0; i < tokenIds0.Count() + 2; i++) // Add 2 for [CLS] and [SEP] tokens.
foreach (int id in tokenIds0)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 0;
}

if (tokenIds1 is not null)
{
for (int i = 0; i < tokenIds1.Count() + 1; i++) // Add 1 for [SEP] token.
foreach (int id in tokenIds1)
{
if (buffer.Length <= written)
{
written = 0;
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1;
}

if (buffer.Length < written)
{
return OperationStatus.DestinationTooSmall;
}
buffer[written++] = 1; // [SEP]
}

return OperationStatus.Done;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/WordPieceTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ await CreateAsync(
continuingSubwordPrefix,
maxInputCharsPerWord,
cancellationToken,
disposeStream: true);
disposeStream: true).ConfigureAwait(false);

/// <summary>
/// Create a new instance of the <see cref="WordPieceTokenizer"/> class asynchronously.
Expand All @@ -259,7 +259,7 @@ public static async Task<WordPieceTokenizer> CreateAsync(
string continuingSubwordPrefix = DefaultContinuingSubwordPrefix,
int maxInputCharsPerWord = DefaultMaxInputCharsPerWord,
CancellationToken cancellationToken = default) =>
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false);
await CreateAsync(vocabStream, preTokenizer, normalizer, specialTokens, unknownToken, continuingSubwordPrefix, maxInputCharsPerWord, cancellationToken, disposeStream: false).ConfigureAwait(false);

private static async Task<WordPieceTokenizer> CreateAsync(
Stream vocabStream,
Expand Down
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Tokenizers/Normalizer/BertNormalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public override string Normalize(string original)

if (category == UnicodeCategory.SpaceSeparator)
{
InsertChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, ' ');
i += inc;
continue;
}
Expand All @@ -85,30 +85,30 @@ public override string Normalize(string original)
int length = original.AsSpan().Slice(i, inc + 1).ToLowerInvariant(casingBuffer);
Debug.Assert(length > 0);

InsertSpan(ref buffer, ref index, casingBuffer.Slice(0, length));
AddSpan(ref buffer, ref index, casingBuffer.Slice(0, length));

i += inc;
continue;
}

if (_tokenizeChineseChars && IsChineseChar(codePoint))
{
InsertChar(ref buffer, ref index, ' ');
InsertChar(ref buffer, ref index, c);
AddChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
AddChar(ref buffer, ref index, original[i + 1]);
}
InsertChar(ref buffer, ref index, ' ');
AddChar(ref buffer, ref index, ' ');

i += inc;
continue;
}

InsertChar(ref buffer, ref index, c);
AddChar(ref buffer, ref index, c);
if (inc > 0)
{
InsertChar(ref buffer, ref index, original[i + 1]);
AddChar(ref buffer, ref index, original[i + 1]);
}
i += inc;
}
Expand Down Expand Up @@ -147,7 +147,7 @@ public BertNormalizer(bool doLowerCase, bool tokenizeChineseChars, bool stripAcc
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertChar(ref char[] buffer, ref int index, char c)
private static void AddChar(ref char[] buffer, ref int index, char c)
{
if (index >= buffer.Length)
{
Expand All @@ -158,9 +158,9 @@ private static void InsertChar(ref char[] buffer, ref int index, char c)
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void InsertSpan(ref char[] buffer, ref int index, Span<char> chars)
private static void AddSpan(ref char[] buffer, ref int index, Span<char> chars)
{
if (index + buffer.Length >= buffer.Length)
if (index + chars.Length >= buffer.Length)
{
Helpers.ArrayPoolGrow(ref buffer, index + buffer.Length + 10);
}
Expand Down
Loading