From 8b8b0adc549c4c382c16942889f0cdfe414ba894 Mon Sep 17 00:00:00 2001 From: vmrajas Date: Fri, 11 Dec 2020 16:56:47 +0530 Subject: [PATCH] Fix(GraphQL): Fix Aggregate queries on empty data (#7119) * Fix aggregate queries on empty data * Empty commit to make the CLA pass --- graphql/e2e/auth/auth_test.go | 14 +----- graphql/e2e/common/common.go | 2 + graphql/e2e/common/query.go | 39 +++++++++++++---- graphql/resolve/auth_query_test.yaml | 5 ++- graphql/resolve/query_rewriter.go | 48 ++++++++++++-------- graphql/resolve/query_test.yaml | 65 ++++++++++++++++++++++++++++ graphql/resolve/resolver.go | 29 +++++++++++++ 7 files changed, 162 insertions(+), 40 deletions(-) diff --git a/graphql/e2e/auth/auth_test.go b/graphql/e2e/auth/auth_test.go index a8044761ede..8eeac15e636 100644 --- a/graphql/e2e/auth/auth_test.go +++ b/graphql/e2e/auth/auth_test.go @@ -1620,12 +1620,7 @@ func TestChildAggregateQueryWithDeepRBAC(t *testing.T) { [ { "username": "user1", - "issuesAggregate": - { - "count": null, - "msgMax": null, - "msgMin": null - } + "issuesAggregate": null } ] }`}, @@ -1687,12 +1682,7 @@ func TestChildAggregateQueryWithOtherFields(t *testing.T) { { "username": "user1", "issues":[], - "issuesAggregate": - { - "count": null, - "msgMin": null, - "msgMax": null - } + "issuesAggregate": null } ] }`}, diff --git a/graphql/e2e/common/common.go b/graphql/e2e/common/common.go index f2ea9daa4fd..63f221779a4 100644 --- a/graphql/e2e/common/common.go +++ b/graphql/e2e/common/common.go @@ -644,10 +644,12 @@ func RunAll(t *testing.T) { t.Run("persisted query", persistedQuery) t.Run("query aggregate without filter", queryAggregateWithoutFilter) t.Run("query aggregate with filter", queryAggregateWithFilter) + t.Run("query aggregate on empty data", queryAggregateOnEmptyData) t.Run("query aggregate with alias", queryAggregateWithAlias) t.Run("query aggregate with repeated fields", queryAggregateWithRepeatedFields) t.Run("query aggregate at child level", queryAggregateAtChildLevel) t.Run("query aggregate at child level with filter", queryAggregateAtChildLevelWithFilter) + t.Run("query aggregate at child level with empty data", queryAggregateAtChildLevelWithEmptyData) t.Run("query aggregate at child level with multiple alias", queryAggregateAtChildLevelWithMultipleAlias) t.Run("query aggregate at child level with repeated fields", queryAggregateAtChildLevelWithRepeatedFields) t.Run("query aggregate and other fields at child level", queryAggregateAndOtherFieldsAtChildLevel) diff --git a/graphql/e2e/common/query.go b/graphql/e2e/common/query.go index c5a7322d07f..642c622c282 100644 --- a/graphql/e2e/common/query.go +++ b/graphql/e2e/common/query.go @@ -2833,8 +2833,10 @@ func queryAggregateWithFilter(t *testing.T) { } }`, string(gqlResponse.Data)) +} - queryPostParams = &GraphQLParams{ +func queryAggregateOnEmptyData(t *testing.T) { + queryPostParams := &GraphQLParams{ Query: `query { aggregatePost (filter: {title : { anyofterms : "Nothing" }} ) { count @@ -2844,16 +2846,11 @@ func queryAggregateWithFilter(t *testing.T) { }`, } - gqlResponse = queryPostParams.ExecuteAsPost(t, GraphqlURL) + gqlResponse := queryPostParams.ExecuteAsPost(t, GraphqlURL) RequireNoGQLErrors(t, gqlResponse) testutil.CompareJSON(t, `{ - "aggregatePost": - { - "count":0, - "numLikesMax": 0, - "titleMin": "0.000000" - } + "aggregatePost": null }`, string(gqlResponse.Data)) } @@ -3011,6 +3008,32 @@ func queryAggregateAtChildLevelWithFilter(t *testing.T) { string(gqlResponse.Data)) } +func queryAggregateAtChildLevelWithEmptyData(t *testing.T) { + queryNumberOfIndianStates := &GraphQLParams{ + Query: `query + { + queryCountry(filter: { name: { eq: "India" } }) { + name + ag : statesAggregate(filter: {xcode: {in: ["nothing"]}}) { + count + nameMin + } + } + }`, + } + gqlResponse := queryNumberOfIndianStates.ExecuteAsPost(t, GraphqlURL) + RequireNoGQLErrors(t, gqlResponse) + testutil.CompareJSON(t, + ` + { + "queryCountry": [{ + "name": "India", + "ag": null + }] + }`, + string(gqlResponse.Data)) +} + func queryAggregateAtChildLevelWithMultipleAlias(t *testing.T) { queryNumberOfIndianStates := &GraphQLParams{ Query: `query diff --git a/graphql/resolve/auth_query_test.yaml b/graphql/resolve/auth_query_test.yaml index 07173d8df3e..38e3f59c24f 100644 --- a/graphql/resolve/auth_query_test.yaml +++ b/graphql/resolve/auth_query_test.yaml @@ -617,13 +617,13 @@ dgquery: |- query { aggregateProject() { - nameMin : min(val(nameVar)) count : max(val(countVar)) + nameMin : min(val(nameVar)) randomMin : min(val(randomVar)) } var(func: uid(ProjectRoot)) { - nameVar as Project.name countVar as count(uid) + nameVar as Project.name randomVar as Project.random } ProjectRoot as var(func: uid(Project1)) @@ -1194,6 +1194,7 @@ ticketsAggregate_titleVar as Ticket.title dgraph.uid : uid } + count_ticketsAggregate : count(User.tickets) @filter(uid(TicketAggregateResult1)) titleMin_ticketsAggregate : min(val(ticketsAggregate_titleVar)) issuesAggregate : User.issues @filter(uid(IssueAggregateResult4)) { issuesAggregate_msgVar as Issue.msg diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index d0edc5b6954..457884ca665 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -191,23 +191,30 @@ func aggregateQuery(query schema.Query, authRw *authRewriter) []*gql.GraphQuery // isAggregateFunctionVisited stores if the aggregate function for a field has been added or not. // So the map entries would contain keys as nameMin, ageMin, nameName, etc. isAggregateFunctionVisited := make(map[string]bool) + + // Add count field to aggregateQuery by default. This is done to ensure that null is + // returned in case the count of nodes is 0. + child := &gql.GraphQuery{ + Var: "countVar", + Attr: "count(uid)", + } + finalQueryChild := &gql.GraphQuery{ + Alias: "count", + Attr: "max(val(countVar))", + } + mainQuery.Children = append(mainQuery.Children, child) + finalMainQuery.Children = append(finalMainQuery.Children, finalQueryChild) + for _, f := range query.SelectionSet() { + // fldName stores Name of the field f. fldName := f.Name() if _, visited := isAggregateFunctionVisited[fldName]; visited { continue } isAggregateFunctionVisited[fldName] = true if fldName == "count" { - child := &gql.GraphQuery{ - Var: "countVar", - Attr: "count(uid)", - } - finalQueryChild := &gql.GraphQuery{ - Alias: fldName, - Attr: "max(val(countVar))", - } - mainQuery.Children = append(mainQuery.Children, child) - finalMainQuery.Children = append(finalMainQuery.Children, finalQueryChild) + // We continue in case of a count field in Aggregate Query as count has already + // been added by default just before the for loop. continue } @@ -1038,6 +1045,16 @@ func buildAggregateFields( // contain "scoreVar as Tweets.score" only once. isAggregateFieldVisited := make(map[string]bool) + // Add the default count field. Count field is part of an AggregateField by default + // as this makes it possible to return null field in case the count of nodes is 0 + aggregateChild := &gql.GraphQuery{ + Alias: "count_" + fieldAlias, + Attr: "count(" + constructedForDgraphPredicate + ")", + } + // Add filter to count aggregation field. + _ = addFilter(aggregateChild, constructedForType, fieldFilter) + aggregateChildren = append(aggregateChildren, aggregateChild) + // Iterate over fields queried inside aggregate. for _, aggregateField := range f.SelectionSet() { // Don't add the same field twice @@ -1046,15 +1063,10 @@ func buildAggregateFields( } addedAggregateField[aggregateField.DgraphAlias()] = true - // Handle count fields inside aggregate fields. + // As count fields are always part of an AggregateField by + // default (added just before this for loop). We continue + // in case of a count field. if aggregateField.DgraphAlias() == "count" { - aggregateChild := &gql.GraphQuery{ - Alias: "count_" + fieldAlias, - Attr: "count(" + constructedForDgraphPredicate + ")", - } - // Add filter to count aggregation field. - _ = addFilter(aggregateChild, constructedForType, fieldFilter) - aggregateChildren = append(aggregateChildren, aggregateChild) continue } // Handle other aggregate functions than count diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index 6969175909b..2201d9289cf 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -1435,6 +1435,29 @@ } } +- + name: "Aggregate Query with no count field" + gqlquery: | + query { + aggregateCountry(filter: { name: { regexp: "/.*ust.*/" }}) { + nameMin + nm : nameMin + nameMax + } + } + dgquery: |- + query { + aggregateCountry() { + count : max(val(countVar)) + nameMin : min(val(nameVar)) + nameMax : max(val(nameVar)) + } + var(func: type(Country)) @filter(regexp(Country.name, /.*ust.*/)) { + countVar as count(uid) + nameVar as Country.name + } + } + - name: "Skip directive" variables: @@ -2956,6 +2979,48 @@ statesAggregate_nameVar as State.name dgraph.uid : uid } + count_statesAggregate : count(Country.states) + nameMin_statesAggregate : min(val(statesAggregate_nameVar)) + nameMax_statesAggregate : max(val(statesAggregate_nameVar)) + statesAggregate1 : Country.states @filter(eq(State.code, "state code")) { + statesAggregate1_nameVar as State.name + statesAggregate1_capitalVar as State.capital + dgraph.uid : uid + } + count_statesAggregate1 : count(Country.states) @filter(eq(State.code, "state code")) + nameMin_statesAggregate1 : min(val(statesAggregate1_nameVar)) + nameMax_statesAggregate1 : max(val(statesAggregate1_nameVar)) + capitalMin_statesAggregate1 : min(val(statesAggregate1_capitalVar)) + dgraph.uid : uid + } + } + +- + name: "Aggregate query at child level with no count field" + gqlquery: | + query { + queryCountry { + nm : name + ag : statesAggregate { + nMin : nameMin + nMax : nameMax + } + statesAggregate(filter: { code: { eq: "state code" } }) { + nMin : nameMin + nMax : nameMax + cMin : capitalMin + } + } + } + dgquery: |- + query { + queryCountry(func: type(Country)) { + name : Country.name + statesAggregate : Country.states { + statesAggregate_nameVar as State.name + dgraph.uid : uid + } + count_statesAggregate : count(Country.states) nameMin_statesAggregate : min(val(statesAggregate_nameVar)) nameMax_statesAggregate : max(val(statesAggregate_nameVar)) statesAggregate1 : Country.states @filter(eq(State.code, "state code")) { diff --git a/graphql/resolve/resolver.go b/graphql/resolve/resolver.go index 73fbfbe7f22..54a79ea4eae 100644 --- a/graphql/resolve/resolver.go +++ b/graphql/resolve/resolver.go @@ -1429,6 +1429,35 @@ func completeObject( }} } } + + // Handle the case of empty data in Aggregate Queries. If count of data is equal + // to 0, set the val map to nil. This makes the aggregateField return null instead + // of returning "0.0000" for Min, Max function on strings and 0 for Min, Max functions + // on integers/float. + if strings.HasSuffix(f.Type().Name(), "AggregateResult") && val != nil { + var count json.Number + countVal := val.(map[string]interface{})["count"] + if countVal == nil { + // This case may happen in case of auth queries when the user does not have + // sufficient permission to query aggregate fields. We set val to nil in this + // case + val = nil + } else { + if count, ok = countVal.(json.Number); !ok { + // This is to handle case in which countVal is of any other type than + // json.Number. This should never happen. We return an error. + return nil, x.GqlErrorList{&x.GqlError{ + Message: "Expected count field of type json.Number inside Aggregate Field", + Locations: []x.Location{f.Location()}, + Path: copyPath(path), + }} + } + if count == "0" { + val = nil + } + } + } + completed, err := completeValue(append(path, f.ResponseName()), f, val) errs = append(errs, err...) if completed == nil {