Skip to content

Commit 3572772

Browse files
feat(GraphQL): Support auth with custom DQL (#7775)
Fixes GRAPHQL-1171. This PR adds support for @auth on custom DQL queries. This PR also adds the parsing step to the DQL query. The Algorithm tries to figure out the `queried type` at the root of the query either from the `root func` or `filters`. Suppose if the root func is `func: type(Person)` or `func: has(Person.name)` or `func: eq(Person.age, 10)` then it infers that the queried type is `Person`. However, there are certain DQL queries on which @auth rules are not added currently since the queried type is not clear. for eg: ``` me(func: uid(0x1, 0x2)) { ... } ``` or the query: ``` me(func: has(name@en) { ... } ```
1 parent 556d7fa commit 3572772

File tree

10 files changed

+1253
-56
lines changed

10 files changed

+1253
-56
lines changed

graphql/dgraph/graphquery.go

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"strings"
2222

2323
"github.com/dgraph-io/dgraph/gql"
24+
"github.com/dgraph-io/dgraph/graphql/schema"
2425
"github.com/dgraph-io/dgraph/x"
2526
)
2627

@@ -65,10 +66,24 @@ func writeQuery(b *strings.Builder, query *gql.GraphQuery, prefix string) {
6566
x.Check2(b.WriteString(query.Alias))
6667
x.Check2(b.WriteString(" : "))
6768
}
68-
x.Check2(b.WriteString(query.Attr))
69+
70+
if query.IsCount {
71+
x.Check2(b.WriteString(fmt.Sprintf("count(%s)", query.Attr)))
72+
} else if query.Attr != "val" {
73+
x.Check2(b.WriteString(query.Attr))
74+
} else if isAggregateFn(query.Func) {
75+
x.Check2(b.WriteString("sum(val("))
76+
writeNeedVar(b, query)
77+
x.Check2(b.WriteRune(')'))
78+
} else {
79+
x.Check2(b.WriteString("val("))
80+
writeNeedVar(b, query)
81+
x.Check2(b.WriteRune(')'))
82+
}
6983

7084
if query.Func != nil {
7185
writeRoot(b, query)
86+
x.Check2(b.WriteRune(')'))
7287
}
7388

7489
if query.Filter != nil {
@@ -93,6 +108,12 @@ func writeQuery(b *strings.Builder, query *gql.GraphQuery, prefix string) {
93108
}
94109
}
95110

111+
if query.IsGroupby {
112+
x.Check2(b.WriteString(" @groupby("))
113+
writeGroupByAttributes(b, query.GroupbyAttrs)
114+
x.Check2(b.WriteRune(')'))
115+
}
116+
96117
switch {
97118
case len(query.Children) > 0:
98119
prefixAdd := ""
@@ -112,7 +133,43 @@ func writeQuery(b *strings.Builder, query *gql.GraphQuery, prefix string) {
112133
}
113134
}
114135

115-
func writeUIDFunc(b *strings.Builder, uids []uint64, args []gql.Arg) {
136+
// writeNeedVar writes the NeedsVar of the query. For eg :-
137+
// `userFollowerCount as sum(val(followers))` has `followers`
138+
// as NeedsVar.
139+
func writeNeedVar(b *strings.Builder, query *gql.GraphQuery) {
140+
for i, v := range query.NeedsVar {
141+
if i != 0 {
142+
x.Check2(b.WriteString(", "))
143+
}
144+
x.Check2(b.WriteString(v.Name))
145+
}
146+
}
147+
148+
func isAggregateFn(f *gql.Function) bool {
149+
if f == nil {
150+
return false
151+
}
152+
switch f.Name {
153+
case "min", "max", "avg", "sum":
154+
return true
155+
}
156+
return false
157+
}
158+
159+
func writeGroupByAttributes(b *strings.Builder, attrList []gql.GroupByAttr) {
160+
for i, attr := range attrList {
161+
if i != 0 {
162+
x.Check2(b.WriteString(", "))
163+
}
164+
if attr.Alias != "" {
165+
x.Check2(b.WriteString(attr.Alias))
166+
x.Check2(b.WriteString(" : "))
167+
}
168+
x.Check2(b.WriteString(attr.Attr))
169+
}
170+
}
171+
172+
func writeUIDFunc(b *strings.Builder, uids []uint64, args []gql.Arg, needVar []gql.VarContext) {
116173
x.Check2(b.WriteString("uid("))
117174
if len(uids) > 0 {
118175
// uid function with uint64 - uid(0x123, 0x456, ...)
@@ -122,14 +179,21 @@ func writeUIDFunc(b *strings.Builder, uids []uint64, args []gql.Arg) {
122179
}
123180
x.Check2(b.WriteString(fmt.Sprintf("%#x", uid)))
124181
}
125-
} else {
182+
} else if len(args) > 0 {
126183
// uid function with a Dgraph query variable - uid(Post1)
127184
for i, arg := range args {
128185
if i != 0 {
129186
x.Check2(b.WriteString(", "))
130187
}
131188
x.Check2(b.WriteString(arg.Value))
132189
}
190+
} else {
191+
for i, v := range needVar {
192+
if i != 0 {
193+
x.Check2(b.WriteString(", "))
194+
}
195+
x.Check2(b.WriteString(v.Name))
196+
}
133197
}
134198
x.Check2(b.WriteString(")"))
135199
}
@@ -145,25 +209,38 @@ func writeRoot(b *strings.Builder, q *gql.GraphQuery) {
145209
}
146210

147211
switch {
212+
case q.Func.Name == "has":
213+
x.Check2(b.WriteString(fmt.Sprintf("(func: has(%s)", q.Func.Attr)))
148214
case q.Func.Name == "uid":
149215
x.Check2(b.WriteString("(func: "))
150-
writeUIDFunc(b, q.Func.UID, q.Func.Args)
216+
writeUIDFunc(b, q.Func.UID, q.Func.Args, q.Func.NeedsVar)
151217
case q.Func.Name == "type" && len(q.Func.Args) == 1:
152218
x.Check2(b.WriteString(fmt.Sprintf("(func: type(%s)", q.Func.Args[0].Value)))
153219
case q.Func.Name == "eq":
154220
x.Check2(b.WriteString("(func: eq("))
155-
writeFilterArguments(b, q.Func.Args)
221+
writeFilterArguments(b, q.Func)
156222
x.Check2(b.WriteRune(')'))
157223
}
158224
writeOrderAndPage(b, q, true)
159-
x.Check2(b.WriteRune(')'))
160225
}
161226

162-
func writeFilterArguments(b *strings.Builder, args []gql.Arg) {
163-
for i, arg := range args {
164-
if i != 0 {
227+
// writeFilterArguments writes the filter arguments. If the filter
228+
// is constructed in graphql query rewriting then `Attr` is an empty
229+
// string since we add Attr in the argument itself.
230+
func writeFilterArguments(b *strings.Builder, q *gql.Function) {
231+
if q.Attr != "" {
232+
x.Check2(b.WriteString(q.Attr))
233+
}
234+
235+
for i, arg := range q.Args {
236+
if i != 0 || q.Attr != "" {
165237
x.Check2(b.WriteString(", "))
166238
}
239+
if q.Attr != "" {
240+
// quote the arguments since this is the case of
241+
// @custom DQL string.
242+
arg.Value = schema.MaybeQuoteArg(q.Name, arg.Value)
243+
}
167244
x.Check2(b.WriteString(arg.Value))
168245
}
169246
}
@@ -175,10 +252,10 @@ func writeFilterFunction(b *strings.Builder, f *gql.Function) {
175252

176253
switch {
177254
case f.Name == "uid":
178-
writeUIDFunc(b, f.UID, f.Args)
255+
writeUIDFunc(b, f.UID, f.Args, f.NeedsVar)
179256
default:
180257
x.Check2(b.WriteString(fmt.Sprintf("%s(", f.Name)))
181-
writeFilterArguments(b, f.Args)
258+
writeFilterArguments(b, f)
182259
x.Check2(b.WriteRune(')'))
183260
}
184261
}
@@ -215,6 +292,15 @@ func hasOrderOrPage(q *gql.GraphQuery) bool {
215292
return len(q.Order) > 0 || hasFirst || hasOffset
216293
}
217294

295+
func IsValueVar(attr string, q *gql.GraphQuery) bool {
296+
for _, vars := range q.NeedsVar {
297+
if attr == vars.Name && vars.Typ == 2 {
298+
return true
299+
}
300+
}
301+
return false
302+
}
303+
218304
func writeOrderAndPage(b *strings.Builder, query *gql.GraphQuery, root bool) {
219305
var wroteOrder, wroteFirst bool
220306

@@ -227,7 +313,13 @@ func writeOrderAndPage(b *strings.Builder, query *gql.GraphQuery, root bool) {
227313
} else {
228314
x.Check2(b.WriteString("orderasc: "))
229315
}
230-
x.Check2(b.WriteString(ord.Attr))
316+
if IsValueVar(ord.Attr, query) {
317+
x.Check2(b.WriteString("val("))
318+
x.Check2(b.WriteString(ord.Attr))
319+
x.Check2(b.WriteRune(')'))
320+
} else {
321+
x.Check2(b.WriteString(ord.Attr))
322+
}
231323
wroteOrder = true
232324
}
233325

graphql/e2e/auth/auth_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,159 @@ func (s Student) add(t *testing.T) {
322322
require.JSONEq(t, result, string(gqlResponse.Data))
323323
}
324324

325+
func TestAuthWithCustomDQL(t *testing.T) {
326+
TestCases := []TestCase{
327+
{
328+
name: "RBAC OR filter query; RBAC Pass",
329+
query: `
330+
query{
331+
queryProjectsOrderByName{
332+
name
333+
}
334+
}
335+
`,
336+
role: "ADMIN",
337+
result: `{"queryProjectsOrderByName":[{"name": "Project1"},{"name": "Project2"}]}`,
338+
},
339+
{
340+
name: "RBAC OR filter query; RBAC false OR `user1` projects",
341+
query: `
342+
query{
343+
queryProjectsOrderByName{
344+
name
345+
}
346+
}
347+
`,
348+
role: "USER",
349+
user: "user1",
350+
result: `{"queryProjectsOrderByName":[{"name": "Project1"}]}`,
351+
},
352+
{
353+
name: "RBAC OR filter query; missing jwt",
354+
query: `
355+
query{
356+
queryProjectsOrderByName{
357+
name
358+
}
359+
}
360+
`,
361+
result: `{"queryProjectsOrderByName":[]}`,
362+
},
363+
{
364+
name: "var query; RBAC AND filter query; RBAC pass",
365+
query: `
366+
query{
367+
queryIssueSortedByOwnerAge{
368+
msg
369+
}
370+
}
371+
`,
372+
role: "ADMIN",
373+
user: "user2",
374+
result: `{"queryIssueSortedByOwnerAge": [{"msg": "Issue2"}]}`,
375+
},
376+
{
377+
name: "var query; RBAC AND filter query; RBAC fail",
378+
query: `
379+
query{
380+
queryIssueSortedByOwnerAge{
381+
msg
382+
}
383+
}
384+
`,
385+
role: "USER",
386+
user: "user2",
387+
result: `{"queryIssueSortedByOwnerAge": []}`,
388+
},
389+
{
390+
name: "DQL query with @cascade and pagination",
391+
query: `
392+
query{
393+
queryFirstTwoMovieWithNonNullRegion{
394+
content
395+
code
396+
regionsAvailable{
397+
name
398+
}
399+
}
400+
}
401+
`,
402+
role: "ADMIN",
403+
user: "user1",
404+
result: `{"queryFirstTwoMovieWithNonNullRegion": [
405+
{
406+
"content": "Movie3",
407+
"code": "m3",
408+
"regionsAvailable": [
409+
{
410+
"name": "Region1"
411+
}
412+
]
413+
},
414+
{
415+
"content": "Movie4",
416+
"code": "m4",
417+
"regionsAvailable": [
418+
{
419+
"name": "Region5"
420+
}
421+
]
422+
}
423+
]
424+
}`,
425+
},
426+
{
427+
name: "query interface; auth rules pass for all the implementing types",
428+
query: `
429+
query{
430+
queryQuestionAndAnswer{
431+
text
432+
}
433+
}
434+
`,
435+
ans: true,
436+
user: "user1@dgraph.io",
437+
result: `{"queryQuestionAndAnswer": [{"text": "A Answer"},{"text": "A Question"}]}`,
438+
},
439+
{
440+
name: "query interface; auth rules fail for some implementing types",
441+
query: `
442+
query{
443+
queryQuestionAndAnswer{
444+
text
445+
}
446+
}
447+
`,
448+
user: "user2@dgraph.io",
449+
result: `{"queryQuestionAndAnswer": [{"text": "B Answer"}]}`,
450+
},
451+
{
452+
name: "query interface; auth rules fail for the interface",
453+
query: `
454+
query{
455+
queryQuestionAndAnswer{
456+
text
457+
}
458+
}
459+
`,
460+
ans: true,
461+
result: `{"queryQuestionAndAnswer": []}`,
462+
},
463+
}
464+
465+
for _, tcase := range TestCases {
466+
t.Run(tcase.name, func(t *testing.T) {
467+
getUserParams := &common.GraphQLParams{
468+
Headers: common.GetJWTForInterfaceAuth(t, tcase.user, tcase.role, tcase.ans, metaInfo),
469+
Query: tcase.query,
470+
}
471+
gqlResponse := getUserParams.ExecuteAsPost(t, common.GraphqlURL)
472+
common.RequireNoGQLErrors(t, gqlResponse)
473+
require.JSONEq(t, tcase.result, string(gqlResponse.Data))
474+
})
475+
}
476+
}
477+
325478
func TestAddMutationWithXid(t *testing.T) {
326479
mutation := `
327480
mutation addTweets($tweet: AddTweetsInput!){

0 commit comments

Comments
 (0)