-
-
Notifications
You must be signed in to change notification settings - Fork 179
/
prepare.go
116 lines (92 loc) · 2.14 KB
/
prepare.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package core
import (
"bytes"
"fmt"
"io"
"strings"
"sync"
"github.com/dosco/graphjin/core/internal/allow"
"github.com/dosco/graphjin/core/internal/qcode"
)
// cquery is one allow-list operation as stored in GraphJin.queries.
// initAllowList creates entries with only q populated; the remaining
// fields are filled in elsewhere.
type cquery struct {
	// Embedded sync.Once. Presumably it guards one-time preparation of the
	// statement fields below (TODO: confirm against callers of Do).
	sync.Once
	// q is the raw allow-listed operation this entry was built from.
	q rquery
	// NOTE(review): stmts, st and roleArg are not assigned in this file.
	// They presumably hold the compiled statement(s) and whether a role
	// argument is required; verify against the code that populates them.
	stmts   []stmt
	st      stmt
	roleArg bool
}
// rquery is the raw form of an allow-listed GraphQL operation: its
// operation type, its name, and its query text and variables kept as bytes.
type rquery struct {
	op          qcode.QType // operation type as classified by qcode.GetQType
	name        string      // operation name taken from the allow-list entry
	query, vars []byte      // raw query text and its variables
}
// prepareRoleStmt builds the SQL statement used to look up the role of the
// current request and stores it in gj.roleStmt. It does nothing when
// gj.abacEnabled is false.
//
// conf.RolesQuery must reference $user_id. It is rendered twice through
// gj.pc.RenderVar, with gj.roleStmtMD passed as the metadata target, into a
// statement of this shape:
//
//	SELECT (CASE WHEN EXISTS (<roles_query>) THEN
//	  (SELECT (CASE WHEN <match> THEN '<role>' ... ELSE 'user' END)
//	   FROM (<roles_query>) AS "_sg_auth_roles_query" LIMIT 1)
//	ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1;
//
// The statement yields one of three values:
//   - 'anon' if the roles query returns no rows.
//   - Otherwise, the name of the first configured role whose Match
//     expression holds for a single row of the roles query (LIMIT 1 with no
//     ORDER BY). Roles with an empty Match are skipped.
//   - 'user' if no Match expression holds.
//
// Role names are emitted as SQL string literals with embedded single quotes
// doubled, so a name such as o'brien still produces valid SQL.
//
// It returns an error if conf.RolesQuery does not contain "$user_id".
// io.WriteString into a bytes.Buffer never fails, hence the errcheck
// suppression.
// nolint: errcheck
func (gj *GraphJin) prepareRoleStmt() error {
	if !gj.abacEnabled {
		return nil
	}

	if !strings.Contains(gj.conf.RolesQuery, "$user_id") {
		return fmt.Errorf("roles_query: $user_id variable missing")
	}

	w := &bytes.Buffer{}

	io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
	gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery)
	io.WriteString(w, `) THEN `)

	io.WriteString(w, `(SELECT (CASE`)
	for _, role := range gj.conf.Roles {
		if role.Match == "" {
			continue
		}
		io.WriteString(w, ` WHEN `)
		// Match is a raw SQL boolean expression supplied by configuration.
		io.WriteString(w, role.Match)
		io.WriteString(w, ` THEN '`)
		// Escape the name as a SQL string literal (' -> '') so that quotes in
		// a role name cannot terminate the literal early.
		io.WriteString(w, strings.ReplaceAll(role.Name, `'`, `''`))
		io.WriteString(w, `'`)
	}

	io.WriteString(w, ` ELSE 'user' END) FROM (`)
	gj.pc.RenderVar(w, &gj.roleStmtMD, gj.conf.RolesQuery)
	io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
	io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)

	gj.roleStmt = w.String()
	return nil
}
// initAllowList loads the allow list from conf.AllowListFile and registers
// each stored operation in gj.queries. It does nothing when
// conf.DisableAllowList is set.
//
// Each map key is the operation name concatenated with a role name:
//   - Queries and subscriptions are registered under "user" and "anon".
//   - Mutations are registered under every role in conf.Roles.
//
// Entries with an empty query are skipped. So are entries whose operation
// type is none of query, subscription or mutation.
//
// It returns a wrapped error if the allow list cannot be created, and the
// unwrapped error from Load if loading fails.
func (gj *GraphJin) initAllowList() error {
	if gj.conf.DisableAllowList {
		return nil
	}

	var err error
	// gj.allowList is assigned even when allow.New fails.
	if gj.allowList, err = allow.New(gj.conf.AllowListFile, allow.Config{Log: gj.log}); err != nil {
		return fmt.Errorf("failed to initialize allow list: %w", err)
	}

	gj.queries = make(map[string]*cquery)

	entries, err := gj.allowList.Load()
	if err != nil {
		return err
	}

	for _, entry := range entries {
		if entry.Query == "" {
			continue
		}

		// The error from GetQType is deliberately ignored. The returned type
		// only selects a switch case below, and any unhandled type registers
		// nothing.
		opType, _ := qcode.GetQType(entry.Query)

		raw := rquery{
			op:    opType,
			name:  entry.Name,
			query: []byte(entry.Query),
			vars:  []byte(entry.Vars),
		}

		switch raw.op {
		case qcode.QTMutation:
			for i := range gj.conf.Roles {
				gj.queries[entry.Name+gj.conf.Roles[i].Name] = &cquery{q: raw}
			}
		case qcode.QTQuery, qcode.QTSubscription:
			gj.queries[entry.Name+"anon"] = &cquery{q: raw}
			gj.queries[entry.Name+"user"] = &cquery{q: raw}
		}
	}

	return nil
}