Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Fix embeddings truncation strategy #139

Merged
merged 1 commit into from Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions lib/modules/embeddings/strategies/truncation.rb
Expand Up @@ -22,17 +22,17 @@ def initialize(target, model)
@model = model
@target = target
@tokenizer = @model.tokenizer
@max_length = @model.max_sequence_length
@processed_target = +""
@max_length = @model.max_sequence_length - 2
@processed_target = nil
end

# Need a better name for this method
def process!
case @target
when Topic
topic_truncation(@target)
@processed_target = topic_truncation(@target)
when Post
post_truncation(@target)
@processed_target = post_truncation(@target)
else
raise ArgumentError, "Invalid target type"
end
Expand All @@ -41,7 +41,7 @@ def process!
end

def topic_truncation(topic)
t = @processed_target
t = +""

t << topic.title
t << "\n\n"
Expand All @@ -54,15 +54,15 @@ def topic_truncation(topic)

topic.posts.find_each do |post|
t << post.raw
break if @tokenizer.size(t) >= @max_length
break if @tokenizer.size(t) >= @max_length #maybe keep a partial counter to speed this up?
t << "\n\n"
end

@tokenizer.truncate(t, @max_length)
end

def post_truncation(post)
t = processed_target
t = +""

t << post.topic.title
t << "\n\n"
Expand Down
31 changes: 31 additions & 0 deletions spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -0,0 +1,31 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
  describe "#process!" do
    context "when the model uses OpenAI to create embeddings" do
      # Raise the post-length cap so the long fabricated raw bodies below
      # are not rejected by validation before the truncation logic runs.
      before { SiteSetting.max_post_length = 100_000 }

      fab!(:topic) { Fabricate(:topic) }

      # NOTE: these must have distinct names. Repeated fab!(:post) blocks
      # shadow one another, so only the last post would ever be created and
      # the topic would be far shorter than the test intends.
      fab!(:post_1) do
        Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
      end
      fab!(:post_2) do
        Fabricate(
          :post,
          topic: topic,
          raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
        )
      end
      fab!(:post_3) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }

      # Pick any concrete embeddings model; the truncation contract must hold
      # regardless of which tokenizer / max_sequence_length is in play.
      let(:model) { DiscourseAi::Embeddings::Models::Base.descendants.sample }
      let(:truncation) { described_class.new(topic, model) }

      it "truncates a topic" do
        truncation.process!

        expect(model.tokenizer.size(truncation.processed_target)).to be <= model.max_sequence_length
      end
    end
  end
end