Skip to content

Commit 58ce809

Browse files
parassshabhimanyusinghgaur
andauthored
feat(graphql+-): Allow parameterized @cascade directive. (#5607)
Co-authored-by: Abhimanyu Singh Gaur <12651351+abhimanyusinghgaur@users.noreply.github.com>
1 parent 53abd09 commit 58ce809

File tree

7 files changed

+644
-18
lines changed

7 files changed

+644
-18
lines changed

gql/parser.go

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ type GraphQuery struct {
6767
Recurse bool
6868
RecurseArgs RecurseArgs
6969
ShortestPathArgs ShortestPathArgs
70-
Cascade bool
70+
Cascade []string
7171
IgnoreReflex bool
7272
Facets *pb.FacetParams
7373
FacetsFilter *FilterTree
@@ -961,7 +961,9 @@ L:
961961
case "normalize":
962962
gq.Normalize = true
963963
case "cascade":
964-
gq.Cascade = true
964+
if err := parseCascade(it, gq); err != nil {
965+
return nil, err
966+
}
965967
case "groupby":
966968
gq.IsGroupby = true
967969
if err := parseGroupby(it, gq); err != nil {
@@ -2094,6 +2096,69 @@ func tryParseFacetList(it *lex.ItemIterator) (res facetRes, parseOk bool, err er
20942096
}
20952097
}
20962098

2099+
// parseCascade parses the cascade directive.
2100+
// Two formats:
2101+
// 1. @cascade
2102+
// 2. @cascade(pred1, pred2, ...)
2103+
func parseCascade(it *lex.ItemIterator, gq *GraphQuery) error {
2104+
item := it.Item()
2105+
items, err := it.Peek(1)
2106+
if err != nil {
2107+
return item.Errorf("Unable to peek lexer after cascade")
2108+
}
2109+
2110+
// check if it is without any args:
2111+
// 1. @cascade {
2112+
// 2. @cascade }
2113+
// 3. @cascade @
2114+
// 4. @cascade\n someOtherPred
2115+
if items[0].Typ == itemLeftCurl || items[0].Typ == itemRightCurl || items[0].
2116+
Typ == itemAt || items[0].Typ == itemName {
2117+
// __all__ implies @cascade i.e. implies values for all the children are mandatory.
2118+
gq.Cascade = append(gq.Cascade, "__all__")
2119+
return nil
2120+
}
2121+
2122+
count := 0
2123+
expectArg := true
2124+
it.Next()
2125+
item = it.Item()
2126+
if item.Typ != itemLeftRound {
2127+
return item.Errorf("Expected a left round after cascade, got: %s", item.String())
2128+
}
2129+
2130+
loop:
2131+
for it.Next() {
2132+
item := it.Item()
2133+
switch item.Typ {
2134+
case itemRightRound:
2135+
break loop
2136+
case itemComma:
2137+
if expectArg {
2138+
return item.Errorf("Expected a predicate but got comma")
2139+
}
2140+
expectArg = true
2141+
case itemName:
2142+
if !expectArg {
2143+
return item.Errorf("Expected a comma or right round but got: %v", item.Val)
2144+
}
2145+
gq.Cascade = append(gq.Cascade, collectName(it, item.Val))
2146+
count++
2147+
expectArg = false
2148+
default:
2149+
return item.Errorf("Unexpected item while parsing: %v", item.Val)
2150+
}
2151+
}
2152+
if expectArg {
2153+
// use the initial item to report error line and column numbers
2154+
return item.Errorf("Unnecessary comma in cascade()")
2155+
}
2156+
if count == 0 {
2157+
return item.Errorf("At least one predicate required in parameterized cascade()")
2158+
}
2159+
return nil
2160+
}
2161+
20972162
// parseGroupby parses the groupby directive.
20982163
func parseGroupby(it *lex.ItemIterator, gq *GraphQuery) error {
20992164
count := 0
@@ -2430,7 +2495,9 @@ func parseDirective(it *lex.ItemIterator, curp *GraphQuery) error {
24302495
return item.Errorf("Facets parsing failed.")
24312496
}
24322497
case item.Val == "cascade":
2433-
curp.Cascade = true
2498+
if err := parseCascade(it, curp); err != nil {
2499+
return err
2500+
}
24342501
case item.Val == "normalize":
24352502
curp.Normalize = true
24362503
case peek[0].Typ == itemLeftRound:

gql/parser_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5275,3 +5275,86 @@ func TestFilterWithEmpty(t *testing.T) {
52755275
require.NoError(t, err)
52765276
require.Equal(t, gq.Query[0].Filter.Func.Args[0].Value, "")
52775277
}
5278+
5279+
func TestCascade(t *testing.T) {
5280+
query := `{
5281+
names(func: has(name)) @cascade {
5282+
name
5283+
}
5284+
}`
5285+
gq, err := Parse(Request{
5286+
Str: query,
5287+
})
5288+
require.NoError(t, err)
5289+
require.Equal(t, gq.Query[0].Cascade[0], "__all__")
5290+
}
5291+
5292+
func TestCascadeParameterized(t *testing.T) {
5293+
query := `{
5294+
names(func: has(name)) @cascade(name, age) {
5295+
name
5296+
age
5297+
dob
5298+
}
5299+
}`
5300+
gq, err := Parse(Request{
5301+
Str: query,
5302+
})
5303+
require.NoError(t, err)
5304+
require.Equal(t, gq.Query[0].Cascade[0], "name")
5305+
require.Equal(t, gq.Query[0].Cascade[1], "age")
5306+
}
5307+
5308+
func TestBadCascadeParameterized(t *testing.T) {
5309+
badQueries := []string{
5310+
`{
5311+
names(func: has(name)) @cascade( {
5312+
name
5313+
age
5314+
dob
5315+
}
5316+
}`,
5317+
`{
5318+
names(func: has(name)) @cascade) {
5319+
name
5320+
age
5321+
dob
5322+
}
5323+
}`,
5324+
`{
5325+
names(func: has(name)) @cascade() {
5326+
name
5327+
age
5328+
dob
5329+
}
5330+
}`,
5331+
`{
5332+
names(func: has(name)) @cascade(,) {
5333+
name
5334+
age
5335+
dob
5336+
}
5337+
}`,
5338+
`{
5339+
names(func: has(name)) @cascade(name,) {
5340+
name
5341+
age
5342+
dob
5343+
}
5344+
}`,
5345+
`{
5346+
names(func: has(name)) @cascade(,name) {
5347+
name
5348+
age
5349+
dob
5350+
}
5351+
}`,
5352+
}
5353+
5354+
for _, query := range badQueries {
5355+
_, err := Parse(Request{
5356+
Str: query,
5357+
})
5358+
require.Error(t, err)
5359+
}
5360+
}

graphql/dgraph/graphquery.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ func writeQuery(b *strings.Builder, query *gql.GraphQuery, prefix string) {
6969
x.Check2(b.WriteRune(')'))
7070
}
7171

72-
if query.Cascade {
73-
x.Check2(b.WriteString(" @cascade"))
72+
if len(query.Cascade) != 0 {
73+
if query.Cascade[0] == "__all__" {
74+
x.Check2(b.WriteString(" @cascade"))
75+
}
7476
}
7577

7678
switch {

graphql/resolve/query_rewriter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ func (authRw *authRewriter) rewriteRuleNode(
544544
r1 := rewriteAsQuery(qry, authRw)
545545
r1.Var = varName
546546
r1.Attr = "var"
547-
r1.Cascade = true
547+
r1.Cascade = append(r1.Cascade, "__all__")
548548

549549
return []*gql.GraphQuery{r1}, &gql.FilterTree{
550550
Func: &gql.Function{

graphql/schema/wrappers.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ type Field interface {
122122
SetArgTo(arg string, val interface{})
123123
Skip() bool
124124
Include() bool
125-
Cascade() bool
125+
Cascade() []string
126126
HasCustomDirective() (bool, map[string]bool)
127127
Type() Type
128128
SelectionSet() []Field
@@ -659,8 +659,12 @@ func (f *field) Include() bool {
659659
return dir.ArgumentMap(f.op.vars)["if"].(bool)
660660
}
661661

662-
func (f *field) Cascade() bool {
663-
return f.field.Directives.ForName(cascadeDirective) != nil
662+
func (f *field) Cascade() []string {
663+
664+
if f.field.Directives.ForName(cascadeDirective) == nil {
665+
return nil
666+
}
667+
return []string{"__all__"}
664668
}
665669

666670
func (f *field) HasCustomDirective() (bool, map[string]bool) {
@@ -1049,7 +1053,7 @@ func (q *query) Include() bool {
10491053
return true
10501054
}
10511055

1052-
func (q *query) Cascade() bool {
1056+
func (q *query) Cascade() []string {
10531057
return (*field)(q).Cascade()
10541058
}
10551059

@@ -1162,7 +1166,7 @@ func (m *mutation) Include() bool {
11621166
return true
11631167
}
11641168

1165-
func (m *mutation) Cascade() bool {
1169+
func (m *mutation) Cascade() []string {
11661170
return (*field)(m).Cascade()
11671171
}
11681172

@@ -1197,7 +1201,7 @@ func (m *mutation) QueryField() Field {
11971201
}
11981202
// if @cascade was given on mutation itself, then it should get applied for the query which
11991203
// gets executed to fetch the results of that mutation, so propagating it to the QueryField.
1200-
if m.Cascade() && !f.Cascade() {
1204+
if len(m.Cascade()) != 0 && len(f.Cascade()) == 0 {
12011205
field := f.(*field).field
12021206
field.Directives = append(field.Directives, &ast.Directive{Name: cascadeDirective})
12031207
}

query/query.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,10 @@ type params struct {
156156
Recurse bool
157157
// RecurseArgs stores the arguments passed to the @recurse directive.
158158
RecurseArgs gql.RecurseArgs
159-
// Cascade is true if the @cascade directive is specified.
160-
Cascade bool
159+
// Cascade is the list of predicates to apply @cascade to.
160+
// __all__ is special to mean @cascade i.e. all the children of this subgraph are mandatory
161+
// and should have values otherwise the node will be excluded.
162+
Cascade []string
161163
// IgnoreReflex is true if the @ignorereflex directive is specified.
162164
IgnoreReflex bool
163165

@@ -532,7 +534,6 @@ func treeCopy(gq *gql.GraphQuery, sg *SubGraph) error {
532534

533535
args := params{
534536
Alias: gchild.Alias,
535-
Cascade: gchild.Cascade || sg.Params.Cascade,
536537
Expand: gchild.Expand,
537538
Facet: gchild.Facets,
538539
FacetsOrder: gchild.FacetsOrder,
@@ -549,6 +550,15 @@ func treeCopy(gq *gql.GraphQuery, sg *SubGraph) error {
549550
IsInternal: gchild.IsInternal,
550551
}
551552

553+
// If parent has @cascade (with or without params), inherit @cascade (with no params)
554+
if len(sg.Params.Cascade) > 0 {
555+
args.Cascade = append(args.Cascade, "__all__")
556+
}
557+
// Allow over-riding at this level.
558+
if len(gchild.Cascade) > 0 {
559+
args.Cascade = gchild.Cascade
560+
}
561+
552562
if gchild.IsCount {
553563
if len(gchild.Children) != 0 {
554564
return errors.New("Node with count cannot have child attributes")
@@ -1296,6 +1306,13 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
12961306
if sg.DestUIDs == nil || sg.IsGroupBy() {
12971307
return nil
12981308
}
1309+
1310+
cascadeArgMap := make(map[string]bool)
1311+
for _, pred := range sg.Params.Cascade {
1312+
cascadeArgMap[pred] = true
1313+
}
1314+
cascadeAllPreds := cascadeArgMap["__all__"]
1315+
12991316
out := make([]uint64, 0, len(sg.DestUIDs.Uids))
13001317
if sg.Params.Alias == "shortest" {
13011318
goto AssignStep
@@ -1311,7 +1328,7 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
13111328
return err
13121329
}
13131330
sgPath = sgPath[:len(sgPath)-1] // Backtrack
1314-
if !child.Params.Cascade {
1331+
if len(child.Params.Cascade) == 0 {
13151332
continue
13161333
}
13171334

@@ -1320,7 +1337,7 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
13201337
child.updateUidMatrix()
13211338
}
13221339

1323-
if !sg.Params.Cascade {
1340+
if len(sg.Params.Cascade) == 0 {
13241341
goto AssignStep
13251342
}
13261343

@@ -1336,7 +1353,8 @@ func (sg *SubGraph) populateVarMap(doneVars map[string]varValue, sgPath []*SubGr
13361353

13371354
// If the length of child UID list is zero and it has no valid value, then the
13381355
// current UID should be removed from this level.
1339-
if !child.IsInternal() &&
1356+
if (cascadeAllPreds || cascadeArgMap[child.Attr]) &&
1357+
!child.IsInternal() &&
13401358
// Check len before accessing index.
13411359
(len(child.valueMatrix) <= i || len(child.valueMatrix[i].Values) == 0) &&
13421360
(len(child.counts) <= i) &&

0 commit comments

Comments
 (0)