Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(vector): Update query_rewriter to fix dotproduct and cosine query conversion #9083

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,12 +659,12 @@ func rewriteAsSimilarByIdQuery(
topK := query.ArgValue(schema.SimilarTopKArgName)
similarByField := typ.Field(similarBy)
metric := similarByField.EmbeddingSearchMetric()
distanceFormula := "math((v2 - v1) dot (v2 - v1))" // default - euclidian
distanceFormula := "math(sqrt((v2 - v1) dot (v2 - v1)))" // default - euclidian

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math(v1 dot v2)"
distanceFormula = "math((1.0 - (v1 dot v2)) /2.0)"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))"
distanceFormula = "math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)"
}

// First generate the query to fetch the uid
Expand Down Expand Up @@ -819,12 +819,13 @@ func rewriteAsSimilarByEmbeddingQuery(

similarByField := typ.Field(similarBy)
metric := similarByField.EmbeddingSearchMetric()
distanceFormula := "math((v2 - $search_vector) dot (v2 - $search_vector))" // default = euclidian
distanceFormula := "math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))" // default = euclidian

if metric == schema.SimilarSearchMetricDotProduct {
distanceFormula = "math($search_vector dot v2)"
distanceFormula = "math(( 1.0 - (($search_vector) dot v2)) /2.0)"
} else if metric == schema.SimilarSearchMetricCosine {
distanceFormula = "math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))"
distanceFormula = "math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector))" +
" * (v2 dot v2) ) )) / 2.0)"
}

// Save vectorString as a query variable, $search_vector
Expand Down
12 changes: 6 additions & 6 deletions graphql/resolve/query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3367,7 +3367,7 @@
query querySimilarProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(Product.productVector, 1, $search_vector)) @filter(type(Product)) {
v2 as Product.productVector
distance as math((v2 - $search_vector) dot (v2 - $search_vector))
distance as math(sqrt((v2 - $search_vector) dot (v2 - $search_vector)))
}
querySimilarProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3397,7 +3397,7 @@
}
var(func: similar_to(Product.productVector, 3, val(v1))) {
v2 as Product.productVector
distance as math((v2 - v1) dot (v2 - v1))
distance as math(sqrt((v2 - v1) dot (v2 - v1)))
}
querySimilarProductById(func: uid(distance), orderasc: val(distance)) {
Product.id : Product.id
Expand Down Expand Up @@ -3428,7 +3428,7 @@
}
var(func: similar_to(ProjectCosine.description_v, 3, val(v1))) {
v2 as ProjectCosine.description_v
distance as math((v1 dot v2) / ((v1 dot v1) * (v2 dot v2)))
distance as math((1.0 - ((v1 dot v2) / sqrt( (v1 dot v1) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineById(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand All @@ -3453,7 +3453,7 @@
query querySimilarProjectCosineByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectCosine.description_v, 1, $search_vector)) @filter(type(ProjectCosine)) {
v2 as ProjectCosine.description_v
distance as math(($search_vector dot v2) / (($search_vector dot $search_vector) * (v2 dot v2)))
distance as math((1.0 - ( (($search_vector) dot v2) / sqrt( (($search_vector) dot ($search_vector)) * (v2 dot v2) ) )) / 2.0)
}
querySimilarProjectCosineByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectCosine.id : ProjectCosine.id
Expand Down Expand Up @@ -3483,7 +3483,7 @@
}
var(func: similar_to(ProjectDotProduct.description_v, 3, val(v1))) {
v2 as ProjectDotProduct.description_v
distance as math(v1 dot v2)
distance as math((1.0 - (v1 dot v2)) /2.0)
}
querySimilarProjectDotProductById(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand All @@ -3508,7 +3508,7 @@
query querySimilarProjectDotProductByEmbedding($search_vector: float32vector = "[0.1,0.2,0.3,0.4,0.5]") {
var(func: similar_to(ProjectDotProduct.description_v, 1, $search_vector)) @filter(type(ProjectDotProduct)) {
v2 as ProjectDotProduct.description_v
distance as math($search_vector dot v2)
distance as math(( 1.0 - (($search_vector) dot v2)) /2.0)
}
querySimilarProjectDotProductByEmbedding(func: uid(distance), orderasc: val(distance)) {
ProjectDotProduct.id : ProjectDotProduct.id
Expand Down
158 changes: 105 additions & 53 deletions query/vector/vector_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package query

import (
"encoding/json"
"fmt"
"math/rand"
"testing"

"github.com/dgraph-io/dgraph/dgraphtest"
Expand All @@ -36,29 +38,56 @@ const (
type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!] @embedding @search(by: ["hnsw(metric: euclidian, exponent: 4)"])
}
`
title_v: [Float!] @embedding @search(by: ["hnsw(metric: %v, exponent: 4)"])
} `
)

// projects is the fixed list of sample projects used by the vector tests.
// Every entry carries the same five-element TitleV embedding; only the
// titles differ.
var projects = []ProjectInput{
	{
		Title:  "iCreate with a Mini iPad",
		TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
	},
	{
		Title:  "Resistive Touchscreen",
		TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
	},
	{
		Title:  "Fitness Band",
		TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
	},
	{
		Title:  "Smart Watch",
		TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
	},
	{
		Title:  "Smart Ring",
		TitleV: []float32{0.7, 0.8, 0.9, 0.1, 0.2},
	},
}
// generateProjects builds count ProjectInput values. Each one gets a random
// title that does not repeat among the projects generated so far and a random
// five-element TitleV vector. The result is nil when count is not positive.
func generateProjects(count int) []ProjectInput {
	var generated []ProjectInput
	for remaining := count; remaining > 0; remaining-- {
		// Draw the title before the vector so the order of random draws
		// stays title-then-vector for every project.
		name := generateUniqueRandomTitle(generated)
		generated = append(generated, ProjectInput{
			Title:  name,
			TitleV: generateRandomTitleV(5), // vector size is fixed at 5
		})
	}
	return generated
}

// isTitleExists reports whether any entry in existingTitles has a Title
// equal to title.
func isTitleExists(title string, existingTitles []ProjectInput) bool {
	for idx := range existingTitles {
		if existingTitles[idx].Title == title {
			return true
		}
	}
	return false
}

// generateUniqueRandomTitle returns a random 10-character alphanumeric title
// that is not already used by any entry of existingTitles. It keeps drawing
// new candidates until one passes the isTitleExists check.
func generateUniqueRandomTitle(existingTitles []ProjectInput) string {
	const (
		alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		length   = 10
	)
	// The buffer is reused across attempts; string(buf) copies its contents.
	buf := make([]byte, length)
	for {
		for pos := 0; pos < length; pos++ {
			buf[pos] = alphabet[rand.Intn(len(alphabet))]
		}
		if candidate := string(buf); !isTitleExists(candidate, existingTitles) {
			return candidate
		}
	}
}

// generateRandomTitleV returns a slice of size values drawn from
// rand.Float32. The result is nil when size is not positive.
func generateRandomTitleV(size int) []float32 {
	var vec []float32
	for remaining := size; remaining > 0; remaining-- {
		vec = append(vec, rand.Float32())
	}
	return vec
}

func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
query := `
Expand All @@ -79,6 +108,7 @@ func addProject(t *testing.T, hc *dgraphtest.HTTPClient, project ProjectInput) {
_, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
}

func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title string) ProjectInput {
query := ` query QueryProject($title: String!) {
queryProject(filter: { title: { eq: $title } }) {
Expand All @@ -96,19 +126,17 @@ func queryProjectUsingTitle(t *testing.T, hc *dgraphtest.HTTPClient, title strin
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
}

var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
require.NoError(t, err)

return resp.QueryProject[0]
}

func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32) []ProjectInput {
func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, vector []float32, topk int) []ProjectInput {
// query similar project by embedding
queryProduct := `query QuerySimilarProjectByEmbedding($by: ProjectEmbedding!, $topK: Int!, $vector: [Float!]!) {
querySimilarProjectByEmbedding(by: $by, topK: $topK, vector: $vector) {
id
title
title_v
}
Expand All @@ -120,20 +148,19 @@ func queryProjectsSimilarByEmbedding(t *testing.T, hc *dgraphtest.HTTPClient, ve
Query: queryProduct,
Variables: map[string]interface{}{
"by": "title_v",
"topK": 3,
"topK": topk,
"vector": vector,
}}
response, err := hc.RunGraphqlQuery(params, false)
require.NoError(t, err)
type QueryResult struct {
QueryProject []ProjectInput `json:"queryProject"`
QueryProject []ProjectInput `json:"querySimilarProjectByEmbedding"`
}
var resp QueryResult
err = json.Unmarshal([]byte(string(response)), &resp)
require.NoError(t, err)

return resp.QueryProject

}

func TestVectorGraphQLAddVectorPredicate(t *testing.T) {
Expand All @@ -143,21 +170,67 @@ func TestVectorGraphQLAddVectorPredicate(t *testing.T) {
require.NoError(t, err)
hc.LoginIntoNamespace("groot", "password", 0)
// add schema
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
}

// TestVectorSchema drops all data, installs a schema in which title_v is a
// plain [Float!] list, and then expects an update to graphQLVectorSchema
// (formatted with the "euclidean" metric) to be rejected.
func TestVectorSchema(t *testing.T) {
	require.NoError(t, client.DropAll())

	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Fail immediately on a login error instead of letting it surface later
	// as a confusing schema-update failure.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

	// add schema
	require.NoError(t, hc.UpdateGQLSchema(schema))
	require.Error(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))
}

// TestVectorGraphQlEuclidianIndexMutationAndQuery drops all data, installs
// graphQLVectorSchema with the metric placeholder set to "euclidean", and runs
// testVectorGraphQlMutationAndQuery against it.
func TestVectorGraphQlEuclidianIndexMutationAndQuery(t *testing.T) {
	require.NoError(t, client.DropAll())
	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Fail immediately on a login error instead of letting it surface later
	// as a confusing schema-update failure.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := fmt.Sprintf(graphQLVectorSchema, "euclidean")
	// add schema
	require.NoError(t, hc.UpdateGQLSchema(schema))
	testVectorGraphQlMutationAndQuery(t, hc)
}

func TestVectorGraphQlMutationAndQuery(t *testing.T) {
// TestVectorGraphQlCosineIndexMutationAndQuery drops all data, installs
// graphQLVectorSchema with the metric placeholder set to "cosine", and runs
// testVectorGraphQlMutationAndQuery against it.
func TestVectorGraphQlCosineIndexMutationAndQuery(t *testing.T) {
	require.NoError(t, client.DropAll())
	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Fail immediately on a login error instead of letting it surface later
	// as a confusing schema-update failure.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := fmt.Sprintf(graphQLVectorSchema, "cosine")
	// add schema
	require.NoError(t, hc.UpdateGQLSchema(schema))
	testVectorGraphQlMutationAndQuery(t, hc)
}

// TestVectorGraphQlDotProductIndexMutationAndQuery drops all data, installs
// graphQLVectorSchema with the metric placeholder set to "dotproduct", and
// runs testVectorGraphQlMutationAndQuery against it.
func TestVectorGraphQlDotProductIndexMutationAndQuery(t *testing.T) {
	require.NoError(t, client.DropAll())
	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	// Fail immediately on a login error instead of letting it surface later
	// as a confusing schema-update failure.
	require.NoError(t, hc.LoginIntoNamespace("groot", "password", 0))

	schema := fmt.Sprintf(graphQLVectorSchema, "dotproduct")
	// add schema
	// Only the formatted schema is applied: graphQLVectorSchema is a format
	// template and must not be submitted with its metric placeholder unfilled.
	require.NoError(t, hc.UpdateGQLSchema(schema))
	testVectorGraphQlMutationAndQuery(t, hc)
}

// add project
func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphtest.HTTPClient) {
var vectors [][]float32
numProjects := 100
projects := generateProjects(numProjects)
for _, project := range projects {
vectors = append(vectors, project.TitleV)
addProject(t, hc, project)
Expand All @@ -177,30 +250,9 @@ func TestVectorGraphQlMutationAndQuery(t *testing.T) {

// query similar project by embedding
for _, project := range projects {
similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV)

similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects)
for _, similarVec := range similarProjects {
require.Contains(t, vectors, similarVec.TitleV)
}
}
}

// TestVectorSchema drops all data, installs a schema in which title_v is a
// plain [Float!] list, and expects a following update to graphQLVectorSchema
// to fail. After a second DropAll the same graphQLVectorSchema update is
// expected to succeed.
func TestVectorSchema(t *testing.T) {
require.NoError(t, client.DropAll())

hc, err := dc.HTTPClient()
require.NoError(t, err)
// NOTE(review): any value returned by LoginIntoNamespace is discarded here — confirm whether it should be checked.
hc.LoginIntoNamespace("groot", "password", 0)

schema := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!]
}`

// add schema
require.NoError(t, hc.UpdateGQLSchema(schema))
// With title_v already declared as a plain list, the vector schema update must be rejected.
require.Error(t, hc.UpdateGQLSchema(graphQLVectorSchema))
// Once everything is dropped, the same vector schema must be accepted.
require.NoError(t, client.DropAll())
require.NoError(t, hc.UpdateGQLSchema(graphQLVectorSchema))
}