Skip to content

Commit

Permalink
feat: add support for subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 2, 2020
1 parent cd7f26b commit 5f4f2e0
Show file tree
Hide file tree
Showing 39 changed files with 1,954 additions and 1,349 deletions.
6 changes: 5 additions & 1 deletion config/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ roles:
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
filters:
["{ user_id: { eq: $user_id } }"]
# This is a role table level config that blocks aggregation functions
# like `count_id` or custom postgres functions that you can use in your query
# (https://supergraph.dev/docs/graphql/#custom-functions)
disable_functions: false

insert:
Expand Down
2 changes: 2 additions & 0 deletions core/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
"hash/maphash"
_log "log"
"os"
"sync"

"github.com/chirino/graphql"
"github.com/dosco/super-graph/core/internal/allow"
Expand Down Expand Up @@ -93,6 +94,7 @@ type SuperGraph struct {
qc *qcode.Compiler
pc *psql.Compiler
ge *graphql.Engine
subs sync.Map
}

// NewSuperGraph creates the SuperGraph struct, this involves querying the database to learn its
Expand Down
29 changes: 16 additions & 13 deletions core/args.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"encoding/json"
"fmt"

Expand All @@ -11,15 +12,17 @@ import (
// argList function is used to create a list of arguments to pass
// to a prepared statement.

func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
func (sg *SuperGraph) argList(c context.Context, md psql.Metadata, vars []byte) (
[]interface{}, error) {

params := md.Params()
vars := make([]interface{}, len(params))
vl := make([]interface{}, len(params))

var fields map[string]json.RawMessage
var err error

if len(c.vars) != 0 {
fields, _, err = jsn.Tree(c.vars)
if len(vars) != 0 {
fields, _, err = jsn.Tree(vars)

if err != nil {
return nil, err
Expand All @@ -30,34 +33,34 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
switch p.Name {
case "user_id":
if v := c.Value(UserIDKey); v != nil {
vars[i] = v.(string)
vl[i] = v.(string)
} else {
return nil, argErr(p)
}

case "user_id_provider":
if v := c.Value(UserIDProviderKey); v != nil {
vars[i] = v.(string)
vl[i] = v.(string)
} else {
return nil, argErr(p)
}

case "user_role":
if v := c.Value(UserRoleKey); v != nil {
vars[i] = v.(string)
vl[i] = v.(string)
} else {
return nil, argErr(p)
}

case "cursor":
if v, ok := fields["cursor"]; ok && v[0] == '"' {
v1, err := c.sg.decrypt(string(v[1 : len(v)-1]))
v1, err := sg.decrypt(string(v[1 : len(v)-1]))
if err != nil {
return nil, err
}
vars[i] = v1
vl[i] = v1
} else {
vars[i] = nil
vl[i] = nil
}

default:
Expand All @@ -72,14 +75,14 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {

switch v[0] {
case '[', '{':
vars[i] = v
vl[i] = v

default:
var val interface{}
if err := json.Unmarshal(v, &val); err != nil {
return nil, err
}
vars[i] = val
vl[i] = val
}

} else {
Expand All @@ -88,7 +91,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
}
}

return vars, nil
return vl, nil
}

func argErr(p psql.Param) error {
Expand Down
67 changes: 27 additions & 40 deletions core/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"

"github.com/dosco/super-graph/core/internal/psql"
"github.com/dosco/super-graph/core/internal/qcode"
Expand All @@ -18,28 +17,15 @@ type stmt struct {
sql string
}

func (sg *SuperGraph) buildStmt(qt qcode.QType, query, vars []byte, role string) ([]stmt, error) {
switch qt {
case qcode.QTMutation:
return sg.buildRoleStmt(query, vars, role)

case qcode.QTQuery:
if role == "anon" {
return sg.buildRoleStmt(query, vars, "anon")
}

if sg.abacEnabled {
return sg.buildMultiStmt(query, vars)
}

return sg.buildRoleStmt(query, vars, "user")

default:
return nil, fmt.Errorf("unknown query type '%d'", qt)
func (sg *SuperGraph) buildStmt(qt qcode.QType, query, vars []byte, role string, poll bool) ([]stmt, error) {
if qt == qcode.QTQuery && sg.abacEnabled {
return sg.buildMultiStmt(query, vars, poll)
}

return sg.buildRoleStmt(query, vars, role, poll)
}

func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string) ([]stmt, error) {
func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string, poll bool) ([]stmt, error) {
ro, ok := sg.roles[role]
if !ok {
return nil, fmt.Errorf(`roles '%s' not defined in c.sg.config`, role)
Expand All @@ -61,8 +47,9 @@ func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string) ([]stmt, er

stmts := []stmt{{role: ro, qc: qc}}
w := &bytes.Buffer{}
md := psql.Metadata{Poll: poll}

stmts[0].md, err = sg.pc.Compile(w, qc, psql.Variables(vm))
stmts[0].md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
if err != nil {
return nil, err
}
Expand All @@ -71,7 +58,7 @@ func (sg *SuperGraph) buildRoleStmt(query, vars []byte, role string) ([]stmt, er
return stmts, nil
}

func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
func (sg *SuperGraph) buildMultiStmt(query, vars []byte, poll bool) ([]stmt, error) {
var vm map[string]json.RawMessage
var err error

Expand All @@ -87,7 +74,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {

stmts := make([]stmt, 0, len(sg.conf.Roles))
w := &bytes.Buffer{}
md := psql.Metadata{}
md := psql.Metadata{Poll: poll}

for i := 0; i < len(sg.conf.Roles); i++ {
role := &sg.conf.Roles[i]
Expand Down Expand Up @@ -129,40 +116,40 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
w := &bytes.Buffer{}

io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
w.WriteString(`SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)

for _, s := range stmts {
if s.role.Match == "" &&
s.role.Name != "user" && s.role.Name != "anon" {
continue
}
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
io.WriteString(w, s.sql)
io.WriteString(w, `) `)
w.WriteString(`WHEN '`)
w.WriteString(s.role.Name)
w.WriteString(`' THEN (`)
w.WriteString(s.sql)
w.WriteString(`) `)
}

io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
w.WriteString(`END) FROM (SELECT (CASE WHEN EXISTS (`)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) THEN `)
w.WriteString(`) THEN `)

io.WriteString(w, `(SELECT (CASE`)
w.WriteString(`(SELECT (CASE`)
for _, s := range stmts {
if s.role.Match == "" {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
w.WriteString(` WHEN `)
w.WriteString(s.role.Match)
w.WriteString(` THEN '`)
w.WriteString(s.role.Name)
w.WriteString(`'`)
}

io.WriteString(w, ` ELSE 'user' END) FROM (`)
w.WriteString(` ELSE 'user' END) FROM (`)
md.RenderVar(w, sg.conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
w.WriteString(`) AS "_sg_auth_roles_query" LIMIT 1) `)
w.WriteString(`ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)

return w.String(), nil
}
9 changes: 9 additions & 0 deletions core/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"path"
"path/filepath"
"strings"
"time"

"github.com/spf13/viper"
)
Expand Down Expand Up @@ -69,6 +70,14 @@ type Config struct {

// Log warnings and other debug information
Debug bool

// Useful for quickly debugging. Please set to false in production
CredsInVars bool `mapstructure:"creds_in_vars"`

// Subscriptions poll the database to query for updates
// this sets the duration (in seconds) between requests.
// Defaults to 5 seconds
PollDuration time.Duration `mapstructure:"poll_every_seconds"`
}

// Table struct defines a database table
Expand Down
9 changes: 5 additions & 4 deletions core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (
type OpType int

const (
OpQuery OpType = iota
OpUnknown OpType = iota
OpQuery
OpMutation
)

Expand Down Expand Up @@ -196,7 +197,7 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) {
var root []byte
var row *sql.Row

varsList, err := c.argList(q.st.md)
varsList, err := c.sg.argList(c, q.st.md, c.vars)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -262,14 +263,14 @@ func (c *scontext) resolveSQL() ([]byte, *stmt, error) {
c.role = v.(string)
}

stmts, err := c.sg.buildStmt(c.res.op, []byte(c.query), c.vars, c.role)
stmts, err := c.sg.buildStmt(c.res.op, []byte(c.query), c.vars, c.role, false)
if err != nil {
return nil, nil, err
}
st := &stmts[0]
c.res.sql = st.sql

varList, err := c.argList(st.md)
varList, err := c.sg.argList(c, st.md, c.vars)
if err != nil {
return nil, nil, err
}
Expand Down
23 changes: 16 additions & 7 deletions core/internal/psql/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,28 @@ func (md *Metadata) RenderVar(w io.Writer, vv string) {
}

func (md *Metadata) renderParam(w io.Writer, p Param) {
_, _ = io.WriteString(w, `$`)
if v, ok := md.pindex[p.Name]; ok {
int32String(w, int32(v))
var id int
var ok bool

} else {
if !md.Poll {
_, _ = io.WriteString(w, `$`)
}

if id, ok = md.pindex[p.Name]; !ok {
md.params = append(md.params, p)
n := len(md.params)
id = len(md.params)

if md.pindex == nil {
md.pindex = make(map[string]int)
}
md.pindex[p.Name] = n
int32String(w, int32(n))
md.pindex[p.Name] = id
}

if md.Poll {
_, _ = io.WriteString(w, `"_sg_sub".`)
quoted(w, p.Name)
} else {
int32String(w, int32(id))
}
}

Expand Down
4 changes: 2 additions & 2 deletions core/internal/psql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Param struct {
}

type Metadata struct {
Poll bool
remoteCount int
params []Param
pindex map[string]int
Expand Down Expand Up @@ -81,7 +82,6 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte

func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
return co.CompileWithMetadata(w, qc, vars, Metadata{})

}

func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) {
Expand Down Expand Up @@ -524,7 +524,7 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo, vars
}
}

io.WriteString(c.w, `AS "json"`)
io.WriteString(c.w, `AS "json" `)

if sel.Paging.Type != qcode.PtOffset {
for i := range sel.OrderBy {
Expand Down
Loading

0 comments on commit 5f4f2e0

Please sign in to comment.