Skip to content

Commit

Permalink
.Net: Use state for the tokenizer/encoding for all examples with a to…
Browse files Browse the repository at this point in the history
…kencount (#5519)

Fixes #5515 

### Motivation and Context

The examples for the text splitter all instantiate the encoder/tokenizer
*every*-time the count function is called. This updates the samples to
have a count method against a class that stores the encoding.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄

Co-authored-by: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
  • Loading branch information
tonybaloney and RogerBarreto committed Mar 19, 2024
1 parent b302550 commit 561a9be
Showing 1 changed file with 66 additions and 38 deletions.
104 changes: 66 additions & 38 deletions dotnet/samples/KernelSyntaxExamples/Example55_TextChunker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,75 +84,103 @@ public enum TokenCounterType
/// Custom token counter implementation using SharpToken.
/// Note: SharpToken is used for demonstration purposes only, it's possible to use any available or custom tokenization logic.
/// </summary>
private static TokenCounter SharpTokenTokenCounter => (string input) =>
public class SharpTokenTokenCounter
{
// Initialize encoding by encoding name
var encoding = GptEncoding.GetEncoding("cl100k_base");
private readonly GptEncoding _encoding;

// Initialize encoding by model name
// var encoding = GptEncoding.GetEncodingForModel("gpt-4");
public SharpTokenTokenCounter()
{
this._encoding = GptEncoding.GetEncoding("cl100k_base");
// Initialize encoding by model name
// this._encoding = GptEncoding.GetEncodingForModel("gpt-4");
}

var tokens = encoding.Encode(input);
public int Count(string input)
{
var tokens = this._encoding.Encode(input);

return tokens.Count;
};
return tokens.Count;
}
}

/// <summary>
/// MicrosoftML token counter implementation.
/// </summary>
private static TokenCounter MicrosoftMLTokenCounter => (string input) =>
public class MicrosoftMLTokenCounter
{
Tokenizer tokenizer = new(new Bpe());
var tokens = tokenizer.Encode(input).Tokens;
private readonly Tokenizer _tokenizer;

return tokens.Count;
};
public MicrosoftMLTokenCounter()
{
this._tokenizer = new(new Bpe());
}

public int Count(string input)
{
var tokens = this._tokenizer.Encode(input).Tokens;

return tokens.Count;
}
}

/// <summary>
/// MicrosoftML token counter implementation using Roberta and local vocab
/// </summary>
private static TokenCounter MicrosoftMLRobertaTokenCounter => (string input) =>
public class MicrosoftMLRobertaTokenCounter
{
var encoder = EmbeddedResource.ReadStream("EnglishRoberta.encoder.json");
var vocab = EmbeddedResource.ReadStream("EnglishRoberta.vocab.bpe");
var dict = EmbeddedResource.ReadStream("EnglishRoberta.dict.txt");
private readonly Tokenizer _tokenizer;

if (encoder is null || vocab is null || dict is null)
public MicrosoftMLRobertaTokenCounter()
{
throw new FileNotFoundException("Missing required resources");
}
var encoder = EmbeddedResource.ReadStream("EnglishRoberta.encoder.json");
var vocab = EmbeddedResource.ReadStream("EnglishRoberta.vocab.bpe");
var dict = EmbeddedResource.ReadStream("EnglishRoberta.dict.txt");

EnglishRoberta model = new(encoder, vocab, dict);
if (encoder is null || vocab is null || dict is null)
{
throw new FileNotFoundException("Missing required resources");
}

model.AddMaskSymbol(); // Not sure what this does, but it's in the example
Tokenizer tokenizer = new(model, new RobertaPreTokenizer());
var tokens = tokenizer.Encode(input).Tokens;
EnglishRoberta model = new(encoder, vocab, dict);

return tokens.Count;
};
model.AddMaskSymbol(); // Not sure what this does, but it's in the example
this._tokenizer = new(model, new RobertaPreTokenizer());
}

public int Count(string input)
{
var tokens = this._tokenizer.Encode(input).Tokens;

return tokens.Count;
}
}

/// <summary>
/// DeepDev token counter implementation.
/// </summary>
private static TokenCounter DeepDevTokenCounter => (string input) =>
public class DeepDevTokenCounter
{
// Initialize encoding by encoding name
var tokenizer = TokenizerBuilder.CreateByEncoderNameAsync("cl100k_base").GetAwaiter().GetResult();
private readonly ITokenizer _tokenizer;

// Initialize encoding by model name
// var tokenizer = TokenizerBuilder.CreateByModelNameAsync("gpt-4").GetAwaiter().GetResult();
public DeepDevTokenCounter()
{
this._tokenizer = TokenizerBuilder.CreateByEncoderNameAsync("cl100k_base").GetAwaiter().GetResult();
}

var tokens = tokenizer.Encode(input, new HashSet<string>());
return tokens.Count;
};
public int Count(string input)
{
var tokens = this._tokenizer.Encode(input, new HashSet<string>());
return tokens.Count;
}
}

private static readonly Func<TokenCounterType, TokenCounter> s_tokenCounterFactory = (TokenCounterType counterType) =>
counterType switch
{
TokenCounterType.SharpToken => (string input) => SharpTokenTokenCounter(input),
TokenCounterType.MicrosoftML => (string input) => MicrosoftMLTokenCounter(input),
TokenCounterType.DeepDev => (string input) => DeepDevTokenCounter(input),
TokenCounterType.MicrosoftMLRoberta => (string input) => MicrosoftMLRobertaTokenCounter(input),
TokenCounterType.SharpToken => new SharpTokenTokenCounter().Count,
TokenCounterType.MicrosoftML => new MicrosoftMLTokenCounter().Count,
TokenCounterType.DeepDev => new DeepDevTokenCounter().Count,
TokenCounterType.MicrosoftMLRoberta => new MicrosoftMLRobertaTokenCounter().Count,
_ => throw new ArgumentOutOfRangeException(nameof(counterType), counterType, null),
};

Expand Down

0 comments on commit 561a9be

Please sign in to comment.