diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 75d3de31410e..b7e2bb40c44a 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -32,10 +32,22 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app } log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received") - + var requestTopN int32 + docs := int32(len(input.Documents)) + if input.TopN == nil { // omit top_n to get all + requestTopN = docs + } else { + requestTopN = int32(*input.TopN) + if requestTopN < 1 { + return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1") + } + if requestTopN > docs { // make it more obvious for backends + requestTopN = docs + } + } request := &proto.RerankRequest{ Query: input.Query, - TopN: int32(input.TopN), + TopN: requestTopN, Documents: input.Documents, } diff --git a/core/schema/jina.go b/core/schema/jina.go index 63d24556fe97..e4daba559d4b 100644 --- a/core/schema/jina.go +++ b/core/schema/jina.go @@ -5,7 +5,7 @@ type JINARerankRequest struct { BasicModelRequest Query string `json:"query"` Documents []string `json:"documents"` - TopN int `json:"top_n"` + TopN *int `json:"top_n,omitempty"` Backend string `json:"backend"` } diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index 371f2bedba7c..8421772a960f 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -286,45 +286,64 @@ var _ = Describe("E2E test", func() { Context("reranker", func() { It("correctly", func() { modelName := "jina-reranker-v1-base-en" - - req := schema.JINARerankRequest{ - BasicModelRequest: schema.BasicModelRequest{ - Model: modelName, - }, - Query: "Organic skincare products for sensitive skin", - Documents: []string{ - "Eco-friendly kitchenware for modern homes", - "Biodegradable cleaning supplies for eco-conscious consumers", - "Organic cotton baby clothes for sensitive skin", - "Natural organic skincare range for sensitive skin", - "Tech gadgets for smart homes: 2024 edition", - "Sustainable gardening tools and compost solutions", - "Sensitive skin-friendly facial cleansers and toners", - "Organic food wraps and storage solutions", - "All-natural pet food for dogs with allergies", - "Yoga mats made from recycled materials", - }, - TopN: 3, + const query = "Organic skincare products for sensitive skin" + var documents = []string{ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "All-natural pet food for dogs with allergies", + "Yoga mats made from recycled materials", + } + // Exceed len or requested results + randomValue := int(GinkgoRandomSeed()) % (len(documents) + 1) + requestResults := randomValue + 1 // at least 1 results + // Cap expectResults by the length of documents + expectResults := min(requestResults, len(documents)) + var maybeSkipTopN = &requestResults + if requestResults >= len(documents) && int(GinkgoRandomSeed())%2 == 0 { + maybeSkipTopN = nil } - serialized, err := json.Marshal(req) - Expect(err).To(BeNil()) - Expect(serialized).ToNot(BeNil()) - - rerankerEndpoint := apiEndpoint + "/rerank" - resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized)) - Expect(err).To(BeNil()) - Expect(resp).ToNot(BeNil()) - body, err := io.ReadAll(resp.Body) - Expect(err).ToNot(HaveOccurred()) + resp, body := requestRerank(modelName, query, documents, maybeSkipTopN, apiEndpoint) Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp)) deserializedResponse := schema.JINARerankResponse{} - err = json.Unmarshal(body, &deserializedResponse) + err := json.Unmarshal(body, &deserializedResponse) Expect(err).To(BeNil()) Expect(deserializedResponse).ToNot(BeZero()) Expect(deserializedResponse.Model).To(Equal(modelName)) - Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0)) + //Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0)) + Expect(len(deserializedResponse.Results)).To(Equal(expectResults)) + // Assert that relevance scores are in decreasing order + for i := 1; i < len(deserializedResponse.Results); i++ { + Expect(deserializedResponse.Results[i].RelevanceScore).To( + BeNumerically("<=", deserializedResponse.Results[i-1].RelevanceScore), + fmt.Sprintf("Result at index %d should have lower relevance score than previous result.", i), + ) + } + // Assert that each result's index points to the correct document + for i, result := range deserializedResponse.Results { + Expect(result.Index).To( + And( + BeNumerically(">=", 0), + BeNumerically("<", len(documents)), + ), + fmt.Sprintf("Result at position %d has index %d which should be within bounds [0, %d)", i, result.Index, len(documents)), + ) + Expect(result.Document.Text).To( + Equal(documents[result.Index]), + fmt.Sprintf("Result at position %d (index %d) should have document text '%s', but got '%s'", + i, result.Index, documents[result.Index], result.Document.Text), + ) + } + zeroOrNeg := int(GinkgoRandomSeed())%2 - 1 // Results in either -1 or 0 + resp, body = requestRerank(modelName, query, documents, &zeroOrNeg, apiEndpoint) + Expect(resp.StatusCode).To(Equal(422), fmt.Sprintf("body: %s, response: %+v", body, resp)) }) }) }) @@ -350,3 +369,26 @@ func downloadHttpFile(url string) (string, error) { return tmpfile.Name(), nil } + +func requestRerank(modelName, query string, documents []string, topN *int, apiEndpoint string) (*http.Response, []byte) { + req := schema.JINARerankRequest{ + BasicModelRequest: schema.BasicModelRequest{ + Model: modelName, + }, + Query: query, + Documents: documents, + TopN: topN, + } + + serialized, err := json.Marshal(req) + Expect(err).To(BeNil()) + Expect(serialized).ToNot(BeNil()) + rerankerEndpoint := apiEndpoint + "/rerank" + resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized)) + Expect(err).To(BeNil()) + Expect(resp).ToNot(BeNil()) + body, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred()) + + return resp, body +}