-
Notifications
You must be signed in to change notification settings - Fork 708
/
compile.go
230 lines (200 loc) · 6.72 KB
/
compile.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
package regosql
import (
"encoding/json"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/rbac/regosql/sqltypes"
)
// ConvertConfig holds the settings required to generate SQL from the rego
// queries. It is passed to ConvertRegoAst and handed down to every conversion
// step.
type ConvertConfig struct {
// VariableConverter is called each time a rego reference (ast.Ref) is
// encountered and creates the SQL ast for that variable. Without this, the
// SQL generator does not know how to convert rego variables into SQL
// columns: converting any reference with a nil VariableConverter returns an
// error.
VariableConverter sqltypes.VariableMatcher
}
// ConvertRegoAst turns a set of partial rego queries into one SQL boolean
// expression intended for use as a WHERE clause. When the resulting
// expression evaluates to "true", the user should have access.
//
// The queries are alternatives: they are OR'd together, while the expressions
// inside each individual query are AND'd (see convertQuery). An error is
// returned if any query cannot be converted.
func ConvertRegoAst(cfg ConvertConfig, partial *rego.PartialQueries) (sqltypes.BooleanNode, error) {
	// No queries at all means there is no possible way for this user to
	// access the resources, so always deny.
	if len(partial.Queries) == 0 {
		return sqltypes.Bool(false), nil
	}

	// In rego an empty query means "true". Because the queries are OR'd, a
	// single empty one is enough to always allow, and nothing else needs to
	// be converted.
	for idx := range partial.Queries {
		if len(partial.Queries[idx]) == 0 {
			return sqltypes.Bool(true), nil
		}
	}

	clauses := make([]sqltypes.BooleanNode, 0, len(partial.Queries))
	sources := make([]string, 0, len(partial.Queries))
	for _, body := range partial.Queries {
		text := body.String()
		node, err := convertQuery(cfg, body)
		if err != nil {
			return nil, xerrors.Errorf("query %s: %w", text, err)
		}
		sources = append(sources, text)
		clauses = append(clauses, node)
	}

	// The rego source attached to the OR node is every query's text, one
	// query per line.
	combined := sqltypes.Or(sqltypes.RegoSource(strings.Join(sources, "\n")), clauses...)

	// Always wrap in parentheses so precedence stays correct when the caller
	// combines this clause with other SQL.
	return sqltypes.BoolParenthesis(combined), nil
}
// convertQuery converts a single rego query body into one SQL boolean node.
// Every expression of the body is converted individually and the results are
// AND'd together: all expressions must hold for the user to have access.
// The first expression that fails to convert aborts the conversion.
func convertQuery(cfg ConvertConfig, q ast.Body) (sqltypes.BooleanNode, error) {
	var conjuncts []sqltypes.BooleanNode
	for idx := range q {
		expr := q[idx]
		node, err := convertExpression(cfg, expr)
		if err != nil {
			return nil, xerrors.Errorf("expression %s: %w", expr.String(), err)
		}
		conjuncts = append(conjuncts, node)
	}

	// The AND node carries the text of the whole query as its rego source.
	return sqltypes.And(sqltypes.RegoSource(q.String()), conjuncts...), nil
}
// convertExpression converts one rego expression into a SQL boolean node.
//
// Two shapes of expression are supported:
//   - a call (operator plus arguments), which is delegated to convertCall;
//   - a single term, which is delegated to convertTerm.
//
// In both cases the converted node must be a sqltypes.BooleanNode, otherwise
// an error is returned. Any other expression shape is reported as unsupported.
func convertExpression(cfg ConvertConfig, e *ast.Expr) (sqltypes.BooleanNode, error) {
	if e.IsCall() {
		n, err := convertCall(cfg, e.Terms.([]*ast.Term))
		if err != nil {
			return nil, xerrors.Errorf("call: %w", err)
		}

		boolN, ok := n.(sqltypes.BooleanNode)
		if !ok {
			return nil, xerrors.Errorf("call %q: not a boolean expression", e.String())
		}
		return boolN, nil
	}

	// If it's not a call, it must be a single term.
	term, ok := e.Terms.(*ast.Term)
	if !ok {
		return nil, xerrors.Errorf("expression %s not supported", e.String())
	}

	ty, err := convertTerm(cfg, term)
	if err != nil {
		return nil, xerrors.Errorf("convert term %s: %w", term.String(), err)
	}

	tyBool, ok := ty.(sqltypes.BooleanNode)
	if !ok {
		// There is no underlying error to wrap on this path (err is nil), so
		// the message must not use %w.
		return nil, xerrors.Errorf("convert term %s is not a boolean", term.String())
	}
	return tyBool, nil
}
// convertCall converts a rego function call to a SQL expression.
//
// The first term of the call is the operator and the remaining terms are its
// arguments. Supported operators:
//   - "eq", "equals", "equal": equality of exactly two arguments;
//   - "neq": negated equality of exactly two arguments;
//   - "internal.member_2": membership of the first argument in the second.
//
// An empty call, a wrong number of arguments, an argument that cannot be
// converted, or any other operator results in an error.
func convertCall(cfg ConvertConfig, call ast.Call) (sqltypes.Node, error) {
	if len(call) == 0 {
		return nil, xerrors.Errorf("empty call")
	}

	// Operator is the first term, everything after it is an argument.
	opString := call[0].String()
	operands := call[1:]

	switch opString {
	case "neq", "eq", "equals", "equal":
		converted, err := convertTerms(cfg, operands, 2)
		if err != nil {
			return nil, xerrors.Errorf("arguments: %w", err)
		}

		// Of the operators admitted by this case, only "neq" negates the
		// comparison.
		not := opString == "neq"
		equality := sqltypes.Equality(not, converted[0], converted[1])
		return sqltypes.BoolParenthesis(equality), nil
	case "internal.member_2":
		converted, err := convertTerms(cfg, operands, 2)
		if err != nil {
			return nil, xerrors.Errorf("arguments: %w", err)
		}

		member := sqltypes.MemberOf(converted[0], converted[1])
		return sqltypes.BoolParenthesis(member), nil
	default:
		return nil, xerrors.Errorf("operator %s not supported", opString)
	}
}
// convertTerms converts a list of rego terms into SQL nodes, preserving
// their order. The list must contain exactly `expected` terms; any other
// length is an error. The first term that fails to convert aborts the
// conversion.
func convertTerms(cfg ConvertConfig, terms []*ast.Term, expected int) ([]sqltypes.Node, error) {
	if got := len(terms); got != expected {
		return nil, xerrors.Errorf("expected %d terms, got %d", expected, got)
	}

	nodes := make([]sqltypes.Node, len(terms))
	for idx := range terms {
		node, err := convertTerm(cfg, terms[idx])
		if err != nil {
			return nil, xerrors.Errorf("term: %w", err)
		}
		nodes[idx] = node
	}
	return nodes, nil
}
// convertTerm converts a single rego term into a SQL node, dispatching on the
// concrete type of the term's value:
//   - ast.Ref is handed to cfg.VariableConverter;
//   - ast.String, ast.Number and ast.Boolean become SQL literals;
//   - *ast.Array is converted element by element;
//   - ast.Set is sorted and then treated like an array;
//   - ast.Call is delegated to convertCall.
//
// Bare ast.Var values, ast.Object values and any other type are rejected
// with an error.
func convertTerm(cfg ConvertConfig, term *ast.Term) (sqltypes.Node, error) {
	source := sqltypes.RegoSource(term.String())

	switch val := term.Value.(type) {
	case ast.Var:
		// All vars should be contained in ast.Ref's.
		return nil, xerrors.New("var not yet supported")

	case ast.Ref:
		// A reference with no terms would be a variable without a name,
		// which makes no sense.
		if len(val) == 0 {
			return nil, xerrors.New("empty ref not supported")
		}
		if cfg.VariableConverter == nil {
			return nil, xerrors.New("no variable converter provided to handle variables")
		}

		// The structure of references is as follows:
		//  1. All variables start with a regoAst.Var as the first term.
		//  2. The next term is either a regoAst.String or a regoAst.Var.
		//     - regoAst.String if a static field name or index.
		//     - regoAst.Var if the field reference is a variable itself,
		//       such as the wildcard "[_]".
		//  3. Repeat 1-2 until the end of the reference.
		converted, matched := cfg.VariableConverter.ConvertVariable(val)
		if !matched {
			return nil, xerrors.Errorf("variable %q cannot be converted", val.String())
		}
		return converted, nil

	case ast.String:
		return sqltypes.String(string(val)), nil

	case ast.Number:
		return sqltypes.Number(source, json.Number(val)), nil

	case ast.Boolean:
		return sqltypes.Bool(bool(val)), nil

	case *ast.Array:
		items := make([]sqltypes.Node, val.Len())
		for idx := range items {
			item, err := convertTerm(cfg, val.Elem(idx))
			if err != nil {
				return nil, xerrors.Errorf("array element %d in %q: %w", idx, val.String(), err)
			}
			items[idx] = item
		}
		return sqltypes.Array(source, items...)

	case ast.Object:
		return nil, xerrors.New("object not yet supported")

	case ast.Set:
		// Just treat a set like an array for now: sort it and convert the
		// sorted array, keeping the location of the original term.
		asArray := ast.Term{
			Value:    val.Sorted(),
			Location: term.Location,
		}
		return convertTerm(cfg, &asArray)

	case ast.Call:
		// This is a function call.
		return convertCall(cfg, val)

	default:
		return nil, xerrors.Errorf("%T not yet supported", val)
	}
}