Skip to content

Commit

Permalink
fix: enforce allow list with subscriptions and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 3, 2020
1 parent e10c814 commit 9c19db2
Show file tree
Hide file tree
Showing 18 changed files with 407 additions and 523 deletions.
1 change: 1 addition & 0 deletions config/dev.yml
Expand Up @@ -52,6 +52,7 @@ cors_debug: true
# on POST requests (does not work with not mutations)
# cache_control: "public, max-age=300, s-maxage=600"

poll_every_seconds: 2s
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT
Expand Down
56 changes: 34 additions & 22 deletions core/api.go
Expand Up @@ -49,6 +49,7 @@ import (
"crypto/sha256"
"database/sql"
"encoding/json"
"errors"
"hash/maphash"
_log "log"
"os"
Expand Down Expand Up @@ -86,9 +87,9 @@ type SuperGraph struct {
allowList *allow.List
encKey [32]byte
hashSeed maphash.Seed
queries map[uint64]*query
queries map[string]*cquery
roles map[string]*Role
getRole *sql.Stmt
roleStmt string
rmap map[uint64]resolvFn
abacEnabled bool
qc *qcode.Compiler
Expand Down Expand Up @@ -129,10 +130,6 @@ func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph,
return nil, err
}

if err := sg.initPrepared(); err != nil {
return nil, err
}

if err := sg.initResolvers(); err != nil {
return nil, err
}
Expand All @@ -141,6 +138,8 @@ func newSuperGraph(conf *Config, db *sql.DB, dbinfo *psql.DBInfo) (*SuperGraph,
return nil, err
}

sg.prepareRoleStmt()

if conf.SecretKey != "" {
sk := sha256.Sum256([]byte(conf.SecretKey))
conf.SecretKey = ""
Expand Down Expand Up @@ -172,43 +171,56 @@ type Result struct {
// In developer mode all names queries are saved into a file `allow.list` and in production mode only
// queries from this file can be run.
func (sg *SuperGraph) GraphQL(c context.Context, query string, vars json.RawMessage) (*Result, error) {
var res Result
ct := scontext{
Context: c,
sg: sg,
op: qcode.GetQType(query),
name: Name(query),
}

res.op = qcode.GetQType(query)
res.name = allow.QueryName(query)
res := &Result{
op: ct.op,
name: ct.name,
}

if ct.op == qcode.QTSubscription {
return nil, errors.New("use 'core.Subscribe' for subscriptions")
}

// use the chirino/graphql library for introspection queries
// disabled when allow list is enforced
if !sg.conf.UseAllowList && res.name == "IntrospectionQuery" {
if !sg.conf.UseAllowList && ct.name == "IntrospectionQuery" {
r := sg.ge.ServeGraphQL(&graphql.Request{Query: query})
res.Data = r.Data

if r.Error() != nil {
res.Error = r.Error().Error()
}
return &res, r.Error()
return res, r.Error()
}

ct := scontext{Context: c, sg: sg, query: query, vars: vars, res: res}

if len(vars) <= 2 {
ct.vars = nil
}
var role string

if keyExists(c, UserIDKey) {
ct.role = "user"
role = "user"
} else {
ct.role = "anon"
role = "anon"
}

data, err := ct.execQuery()
qr, err := ct.execQuery(query, vars, role)

if err != nil {
return &ct.res, err
res.Error = err.Error()
}

if qr.q != nil {
res.sql = qr.q.st.sql
}

ct.res.Data = json.RawMessage(data)
res.Data = json.RawMessage(qr.data)
res.role = qr.role

return &ct.res, nil
return res, nil
}

// GraphQLSchema function return the GraphQL schema for the underlying database connected
Expand Down
10 changes: 9 additions & 1 deletion core/args.go
Expand Up @@ -3,6 +3,7 @@ package core
import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/dosco/super-graph/core/internal/psql"
Expand Down Expand Up @@ -33,7 +34,14 @@ func (sg *SuperGraph) argList(c context.Context, md psql.Metadata, vars []byte)
switch p.Name {
case "user_id":
if v := c.Value(UserIDKey); v != nil {
vl[i] = v.(string)
switch v1 := v.(type) {
case string:
vl[i] = v1
case int:
vl[i] = v1
default:
return nil, errors.New("user_id must be an integer or a string")
}
} else {
return nil, argErr(p)
}
Expand Down
99 changes: 77 additions & 22 deletions core/build.go
Expand Up @@ -17,59 +17,110 @@ type stmt struct {
sql string
}

func (sg *SuperGraph) buildStmt(qt qcode.QType, query, vars []byte, role string, poll bool) ([]stmt, error) {
if qt == qcode.QTQuery && sg.abacEnabled {
return sg.buildMultiStmt(query, vars, poll)
func (sg *SuperGraph) compileQuery(cq *cquery, role string) error {
var err error

// In production mode enforce the allow list and
// compile and cache the result else compile each time
if sg.conf.UseAllowList {
if cq1, ok := sg.queries[(cq.q.name + role)]; ok {
cq.q = cq1.q
} else {
return errNotFound
}

if cq.st.sql == "" {
cq.Do(func() {
err = sg.compileQueryFn(cq, role)
})
}

} else {
err = sg.compileQueryFn(cq, role)
}

return err
}

func (sg *SuperGraph) compileQueryFn(cq *cquery, role string) error {
var err error

switch cq.q.op {
case qcode.QTQuery:
if sg.abacEnabled {
cq.stmts, cq.st, err = sg.buildMultiStmt(cq.q.query, cq.q.vars, false)
} else {
cq.st, err = sg.buildRoleStmt(cq.q.query, cq.q.vars, role, false)
}

case qcode.QTSubscription:
if sg.abacEnabled {
cq.stmts, cq.st, err = sg.buildMultiStmt(cq.q.query, cq.q.vars, true)
} else {
cq.st, err = sg.buildRoleStmt(cq.q.query, cq.q.vars, role, true)
}

case qcode.QTMutation:
cq.st, err = sg.buildRoleStmt(cq.q.query, cq.q.vars, role, true)

default:
err = errors.New("unknown query")
}

return sg.buildRoleStmt(query, vars, role, poll)
cq.roleArg = (len(cq.stmts) > 0)
return err
}

func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string, poll bool) ([]stmt, error) {
func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string, poll bool) (stmt, error) {
var st stmt

ro, ok := sg.roles[role]
if !ok {
return nil, fmt.Errorf(`roles '%s' not defined in c.sg.config`, role)
return st, fmt.Errorf(`roles '%s' not defined in c.sg.config`, role)
}

var vm map[string]json.RawMessage
var err error

if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
return st, err
}
}

qc, err := sg.qc.Compile(query, ro.Name)
if err != nil {
return nil, err
return st, err
}

stmts := []stmt{{role: ro, qc: qc}}
w := &bytes.Buffer{}
md := psql.Metadata{Poll: poll}

stmts[0].md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
st.md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil {
return nil, err
return st, err
}
stmts[0].sql = w.String()

return stmts, nil
st.role = ro
st.qc = qc
st.sql = w.String()

return st, nil
}

func (sg *SuperGraph) buildMultiStmt(query, vars []byte, poll bool) ([]stmt, error) {
func (sg *SuperGraph) buildMultiStmt(query, vars []byte, poll bool) ([]stmt, stmt, error) {
var vm map[string]json.RawMessage
var err error
var st stmt

if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
return nil, st, err
}
}

if sg.conf.RolesQuery == "" {
return nil, errors.New("roles_query not defined")
return nil, st, errors.New("roles_query not defined")
}

stmts := make([]stmt, 0, len(sg.conf.Roles))
Expand All @@ -86,34 +137,38 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte, poll bool) ([]stmt, err

qc, err := sg.qc.Compile(query, role.Name)
if err != nil {
return nil, err
return nil, st, err
}

stmts = append(stmts, stmt{role: role, qc: qc})
s := &stmts[len(stmts)-1]

md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil {
return nil, err
return nil, st, err
}

s.sql = w.String()
s.md = md

w.Reset()
}
st = stmts[0]

sql, err := sg.renderUserQuery(md, stmts)
st.sql, err = sg.renderUserQuery(md, stmts)
if err != nil {
return nil, err
return nil, st, err
}

stmts[0].sql = sql
return stmts, nil
return stmts, st, nil
}

//nolint: errcheck
func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
if sg.conf.RolesQuery == "" {
return "", errors.New("roles_query not defined")
}

w := &bytes.Buffer{}

w.WriteString(`SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
Expand Down

0 comments on commit 9c19db2

Please sign in to comment.