diff --git a/config/settings.yml b/config/settings.yml
index 78df0ebd8..a2bff9f69 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -171,6 +171,9 @@ discourse_ai:
     default: ""
     hidden: true
   ai_llava_api_key: ""
+  ai_strict_token_counting:
+    default: false
+    hidden: true
   composer_ai_helper_enabled:
     default: false
     client: true
diff --git a/lib/tokenizer/basic_tokenizer.rb b/lib/tokenizer/basic_tokenizer.rb
index f4afd6753..8adc4d270 100644
--- a/lib/tokenizer/basic_tokenizer.rb
+++ b/lib/tokenizer/basic_tokenizer.rb
@@ -17,14 +17,19 @@ def size(text)
         end
 
         def truncate(text, max_length)
-          # Fast track the common case where the text is already short enough.
-          return text if text.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
           tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
         end
 
         def can_expand_tokens?(text, addition, max_length)
-          return true if text.size + addition.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+            return true
+          end
 
           tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
         end
diff --git a/lib/tokenizer/open_ai_tokenizer.rb b/lib/tokenizer/open_ai_tokenizer.rb
index 0a31ffce9..daf7d4530 100644
--- a/lib/tokenizer/open_ai_tokenizer.rb
+++ b/lib/tokenizer/open_ai_tokenizer.rb
@@ -13,8 +13,9 @@ def tokenize(text)
         end
 
         def truncate(text, max_length)
-          # Fast track the common case where the text is already short enough.
-          return text if text.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
           tokenizer.decode(tokenize(text).take(max_length))
         rescue Tiktoken::UnicodeError
@@ -23,7 +24,11 @@
         end
 
         def can_expand_tokens?(text, addition, max_length)
-          return true if text.size + addition.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+            return true
+          end
 
           tokenizer.encode(text).length + tokenizer.encode(addition).length < max_length
         end
diff --git a/spec/shared/tokenizer_spec.rb b/spec/shared/tokenizer_spec.rb
index 74366b0d3..2c0900ed5 100644
--- a/spec/shared/tokenizer_spec.rb
+++ b/spec/shared/tokenizer_spec.rb
@@ -81,6 +81,31 @@
       sentence = "foo bar 👨🏿‍👩🏿‍👧🏿‍👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
       expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
     end
+
+    it "truncates unicode characters properly when they use more than one token per char" do
+      sentence = "我喜欢吃比萨"
+      original_size = described_class.size(sentence)
+      expect(described_class.size(described_class.truncate(sentence, original_size - 1))).to be <
+        original_size
+    end
+  end
+
+  describe "#can_expand_tokens?" do
+    it "returns true when the tokens can be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 6)).to eq(true)
+    end
+
+    it "returns false when the tokens cannot be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 3)).to eq(false)
+    end
+
+    it "returns false when the tokens cannot be expanded due to multibyte unicode characters" do
+      expect(described_class.can_expand_tokens?("foo bar 👨🏿", "baz qux", 6)).to eq(false)
+    end
+
+    it "handles unicode characters properly when they use more than one token per char" do
+      expect(described_class.can_expand_tokens?("我喜欢吃比萨", "萨", 10)).to eq(false)
+    end
   end
 end
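For context, here is a minimal standalone sketch of the heuristic the diff adopts. This is not Discourse code: the toy cost function that charges two tokens per non-ASCII character is an assumption, standing in for a real BPE tokenizer that often splits CJK text into multiple tokens per character.

# Toy model: ASCII chars cost 1 token, other chars cost 2 (assumed here,
# as a rough stand-in for how a BPE tokenizer prices much CJK text).
toy_token_count = ->(text) { text.each_char.sum { |c| c.ascii_only? ? 1 : 2 } }

text = "我喜欢吃比萨" # 6 characters
max_length = 8

# Old fast path: character count is under the limit, so real
# tokenization was skipped...
text.size < max_length        # => true
# ...even though the actual token count exceeds the limit.
toy_token_count.call(text)    # => 12

# New fast path: only skip tokenization when chars < max_length / 2,
# which is safe whenever no character expands to more than 2 tokens.
text.size < max_length / 2    # => false, so the exact (slow) path runs

Note that the halved limit is still a heuristic: some characters, emoji sequences in particular, can encode to more than two tokens each, which appears to be why the diff also adds the hidden ai_strict_token_counting setting to disable the fast path entirely and always count exact tokens.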