// Patch summary. This is a unified git diff whose hunk line breaks have been collapsed; the
// payload below is unchanged, and these leading lines sit before the first diff header.
//
// graphql/resolve/query_rewriter.go (rewriteAsSimilarByIdQuery, rewriteAsSimilarByEmbeddingQuery):
//   - default (euclidian) formula: the squared difference is now wrapped in sqrt(...).
//   - dotproduct formula: the raw dot product becomes (1.0 - dot) / 2.0.
//   - cosine formula: the denominator gains a sqrt(...) and the ratio is mapped through
//     (1.0 - ratio) / 2.0.
//   The by-embedding variants are the same formulas with $search_vector in place of v1.
//   NOTE(review): the expected queries sort with orderasc: val(distance); presumably the new
//   formulas are meant to make smaller values mean "more similar" for every metric - confirm.
//
// graphql/resolve/query_test.yaml:
//   - the six expected "distance as math(...)" strings are updated to the new formulas.
//
// query/vector/vector_graphql_test.go:
//   - imports fmt and math/rand.
//   - graphQLVectorSchema gets a %v placeholder for the hnsw metric and is filled in with
//     fmt.Sprintf by each caller.
//   - the fixed five-entry projects fixture (all entries shared one vector) is replaced by
//     generateProjects, which builds entries with a random 10-character title not already used
//     and a 5-element vector of rand.Float32 values.
//   - queryProjectsSimilarByEmbedding takes a topk argument, no longer selects id, and decodes
//     the response under the JSON key querySimilarProjectByEmbedding instead of queryProject.
//   - TestVectorGraphQlMutationAndQuery is split into euclidean, cosine and dotproduct tests that
//     share the helper testVectorGraphQlMutationAndQuery, which inserts 100 generated projects.
//   - TestVectorSchema is moved earlier in the file and drops its trailing DropAll + schema re-add.
//   NOTE(review): the tests pass the metric spelling "euclidean", while the removed schema line
//   and the rewriter comments spell it "euclidian" - confirm which spelling the index option
//   parser accepts.
diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index ced445bdcb6..509e593a772 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -659,12 +659,12 @@ func rewriteAsSimilarByIdQuery( topK := query.ArgValue(schema.SimilarTopKArgName) similarByField := typ.Field(similarBy) metric := similarByField.EmbeddingSearchMetric() - distanceFormula := "math((v2 - v1) dot (v2 - v1))" // default - euclidian + distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidian if metric == schema.SimilarSearchMetricDotProduct { - distanceFormula = "math(v1 dot v2)" + distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)" } else if metric == schema.SimilarSearchMetricCosine { - distanceFormula = "math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))" + distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)" } // First generate the query to fetch the uid @@ -819,12 +819,13 @@ func rewriteAsSimilarByEmbeddingQuery( similarByField := typ.Field(similarBy) metric := similarByField.EmbeddingSearchMetric() - distanceFormula := "math((v2 - $search_vector) dot (v2 - $search_vector))" // default = euclidian + distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidian if metric == schema.SimilarSearchMetricDotProduct { - distanceFormula = "math($search_vector dot v2)" + distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)" } else if metric == schema.SimilarSearchMetricCosine { - distanceFormula = "math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))" + distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" + + " * (v2 dot v2) ) )) / 2.0)" } // Save vectorString as a query variable, $search_vector diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index c74b9d77471..ab15599d020 100644 --- 
a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -3367,7 +3367,7 @@ query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) { v2 as Product.productVector - distance as math((v2 - $search_vector) dot (v2 - $search_vector)) + distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector))) } querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) { Product.id : Product.id @@ -3397,7 +3397,7 @@ } var(func: similar_to(Product.productVector, 3, val(v1))) { v2 as Product.productVector - distance as math((v2 - v1) dot (v2 - v1)) + distance as math(sqrt((v2 - v1) dot (v2 - v1))) } querySimilarProductById(func: uid(distance), orderasc: val(distance)) { Product.id : Product.id @@ -3428,7 +3428,7 @@ } var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) { v2 as ProjectCosine.description_v - distance as math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2))) + distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0) } querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3453,7 +3453,7 @@ query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) { v2 as ProjectCosine.description_v - distance as math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2))) + distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0) } querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectCosine.id : ProjectCosine.id @@ -3483,7 +3483,7 @@ } var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) { v2 as ProjectDotProduct.description_v - distance as 
math(v1 dot v2) + distance as math((1.0 - (v1 dot v2)) /2.0) } querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id @@ -3508,7 +3508,7 @@ query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") { var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) { v2 as ProjectDotProduct.description_v - distance as math($search_vector dot v2) + distance as math(( 1.0 - (($search_vector) dot v2)) /2.0) } querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) { ProjectDotProduct.id : ProjectDotProduct.id diff --git a/query/vector/vector_graphql_test.go b/query/vector/vector_graphql_test.go index 058e89dca04..defa04278cd 100644 --- a/query/vector/vector_graphql_test.go +++ b/query/vector/vector_graphql_test.go @@ -20,6 +20,8 @@ package query import ( "encoding/json" + "fmt" + "math/rand" "testing" "github.com/dgraph-io/dgraph/dgraphtest" @@ -36,29 +38,56 @@ const ( type Project { id: ID! title: String! @search(by: [exact]) - title_v: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"]) - } - ` + title_v: [Float!] 
@embedding @search(by: ["hnsw(metric: %v, exponent: 4)"]) + } ` ) -var ( - projects = []ProjectInput{ProjectInput{ - Title: "iCreate with a Mini iPad", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Resistive Touchscreen", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Fitness Band", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Smart Watch", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }, ProjectInput{ - Title: "Smart Ring", - TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2}, - }} -) +func generateProjects(count int) []ProjectInput { + var projects []ProjectInput + for i := 0; i < count; i++ { + title := generateUniqueRandomTitle(projects) + titleV := generateRandomTitleV(5) // Assuming size is fixed at 5 + project := ProjectInput{ + Title: title, + TitleV: titleV, + } + projects = append(projects, project) + } + return projects +} + +func isTitleExists(title string, existingTitles []ProjectInput) bool { + for _, project := range existingTitles { + if project.Title == title { + return true + } + } + return false +} + +func generateUniqueRandomTitle(existingTitles []ProjectInput) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + const titleLength = 10 + title := make([]byte, titleLength) + for { + for i := range title { + title[i] = charset[rand.Intn(len(charset))] + } + titleStr := string(title) + if !isTitleExists(titleStr, existingTitles) { + return titleStr + } + } +} + +func generateRandomTitleV(size int) []float32 { + var titleV []float32 + for i := 0; i < size; i++ { + value := rand.Float32() + titleV = append(titleV, value) + } + return titleV +} func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) { query := ` @@ -79,6 +108,7 @@ func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) { _, err := hc.RunGraphqlQuery(params, false) require.NoError(t, err) } + func 
queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title string) ProjectInput { query := ` query QueryProject($title: String!) { queryProject(filter: { title: { eq: $title } }) { @@ -96,7 +126,6 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin type QueryResult struct { QueryProject []ProjectInput `json:"queryProject"` } - var resp QueryResult err = json.Unmarshal([]byte(string(response)), &resp) require.NoError(t, err) @@ -104,11 +133,10 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin return resp.QueryProject[0] } -func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32) []ProjectInput { +func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32, topk int) []ProjectInput { // query similar project by embedding queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) { querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) { - id title title_v } @@ -120,20 +148,19 @@ func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, ve Query: queryProduct, Variables: map[string]interface{}{ "by": "title_v", - "topK": 3, + "topK": topk, "vector": vector, }} response, err := hc.RunGraphqlQuery(params, false) require.NoError(t, err) type QueryResult struct { - QueryProject []ProjectInput `json:"queryProject"` + QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"` } var resp QueryResult err = json.Unmarshal([]byte(string(response)), &resp) require.NoError(t, err) return resp.QueryProject - } func TestVectorGraphQLAddVectorPredicate(t *testing.T) { @@ -143,21 +170,67 @@ func TestVectorGraphQLAddVectorPredicate(t *testing.T) { require.NoError(t, err) hc.LoginIntoNamespace("groot", "password", 0) // add schema - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) + require.NoError(t, 
hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) +} + +func TestVectorSchema(t *testing.T) { + require.NoError(t, client.DropAll()) + + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := `type Project { + id: ID! + title: String! @search(by: [exact]) + title_v: [Float!] + }` + + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) +} + +func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + + schema := fmt.Sprintf(graphQLVectorSchema, "euclidean") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) } -func TestVectorGraphQlMutationAndQuery(t *testing.T) { +func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) { require.NoError(t, client.DropAll()) + hc, err := dc.HTTPClient() + require.NoError(t, err) + hc.LoginIntoNamespace("groot", "password", 0) + schema := fmt.Sprintf(graphQLVectorSchema, "cosine") + // add schema + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) +} + +func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) { + require.NoError(t, client.DropAll()) hc, err := dc.HTTPClient() require.NoError(t, err) hc.LoginIntoNamespace("groot", "password", 0) + schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct") // add schema - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) + require.NoError(t, hc.UpdateGQLSchema(schema)) + testVectorGraphQlMutationAndQuery(t, hc) +} - // add project +func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphtest.HTTPClient) { var vectors [][]float32 + numProjects := 100 + projects := generateProjects(numProjects) for _, project := range projects { vectors = 
append(vectors, project.TitleV) addProject(t, hc, project) @@ -177,30 +250,9 @@ func TestVectorGraphQlMutationAndQuery(t *testing.T) { // query similar project by embedding for _, project := range projects { - similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV) - + similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects) for _, similarVec := range similarProjects { require.Contains(t, vectors, similarVec.TitleV) } } } - -func TestVectorSchema(t *testing.T) { - require.NoError(t, client.DropAll()) - - hc, err := dc.HTTPClient() - require.NoError(t, err) - hc.LoginIntoNamespace("groot", "password", 0) - - schema := `type Project { - id: ID! - title: String! @search(by: [exact]) - title_v: [Float!] - }` - - // add schema - require.NoError(t, hc.UpdateGQLSchema(schema)) - require.Error(t, hc.UpdateGQLSchema(graphQLVectorSchema)) - require.NoError(t, client.DropAll()) - require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema)) -}