From 8992238be419c07045bf4ee3435f7c877bea480b Mon Sep 17 00:00:00 2001 From: Abhimanyu Singh Gaur <12651351+abhimanyusinghgaur@users.noreply.github.com> Date: Tue, 19 May 2020 17:30:34 +0530 Subject: [PATCH] Add authn for graphql and http admin endpoints (#5162) Fixes #4758. This PR adds authentication to following endpoints: /admin/backup (http & graphql) /admin/config/lru_mb (http [GET & PUT] & graphql [query & mutation]) /admin/draining (http & graphql) /admin/export (http & graphql) /admin/shutdown (http & graphql) /admin/restore (graphql only) /admin/listBackups (graphql only) Now, all the above http endpoints and their corresponding graphql versions have following kinds of auth: IP White-listing, if --whitelist flag is passed to alpha Poor-man's auth, if --auth_token flag is passed to alpha Guardian only access, if ACL is enabled This PR also adds query for config in graphql admin, as it was missing earlier. In addition to above points: All the /admin endpoints apply Poor-man's auth check at http level itself, while other auth checks are routed through graphql resolvers. GraphQL Resolvers for health/state and the ones related to ACL User/Group have IP whitelisting middleware applied, while dgraph handles Guardian auth for them. /alter has the existing behaviour of checking only Poor-man's and Guardian auth. GraphQL Resolvers related to schema don't apply IP whitelisting as to keep them in sync with /alter. They do apply Guardian auth. Any GraphQL admin introspection queries don't require IP whitelisting or Guardian auth. --- dgraph/cmd/alpha/admin.go | 186 ++++++++++------ dgraph/cmd/alpha/admin_backup.go | 56 ++--- dgraph/cmd/alpha/http.go | 25 ++- dgraph/cmd/alpha/http_test.go | 25 +-- dgraph/cmd/alpha/run.go | 62 ++++-- edgraph/access.go | 2 +- edgraph/access_ee.go | 4 +- edgraph/server.go | 4 +- ee/acl/acl_test.go | 320 ++++++++++++++++++++++------ graphql/admin/admin.go | 79 +++++-- graphql/admin/backup.go | 4 +- graphql/admin/config.go | 24 ++- graphql/admin/export.go | 6 +- graphql/admin/health.go | 2 +- graphql/admin/list_backups.go | 8 +- graphql/admin/login.go | 4 +- graphql/admin/restore.go | 4 +- graphql/admin/state.go | 8 +- graphql/resolve/middlewares.go | 178 ++++++++++++++++ graphql/resolve/middlewares_test.go | 97 +++++++++ graphql/resolve/resolver.go | 49 ++++- graphql/web/http.go | 26 +-- testutil/graphql.go | 32 +++ x/x.go | 37 ++++ 24 files changed, 968 insertions(+), 274 deletions(-) create mode 100644 graphql/resolve/middlewares.go create mode 100644 graphql/resolve/middlewares_test.go diff --git a/dgraph/cmd/alpha/admin.go b/dgraph/cmd/alpha/admin.go index e608ff473a3..2c2482189d8 100644 --- a/dgraph/cmd/alpha/admin.go +++ b/dgraph/cmd/alpha/admin.go @@ -17,76 +17,100 @@ package alpha import ( - "bytes" - "context" + "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "strconv" - "github.com/dgraph-io/dgraph/posting" + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/dgraph/worker" "github.com/dgraph-io/dgraph/x" - "github.com/golang/glog" ) -// handlerInit does some standard checks. Returns false if something is wrong. -func handlerInit(w http.ResponseWriter, r *http.Request, allowedMethods map[string]bool) bool { - if _, ok := allowedMethods[r.Method]; !ok { - x.SetStatus(w, x.ErrorInvalidMethod, "Invalid method") - return false - } +type allowedMethods map[string]bool - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil || (!ipInIPWhitelistRanges(ip) && !net.ParseIP(ip).IsLoopback()) { - x.SetStatus(w, x.ErrorUnauthorized, fmt.Sprintf("Request from IP: %v", ip)) +// hasPoormansAuth checks if poorman's auth is required and if so whether the given http request has +// poorman's auth in it or not +func hasPoormansAuth(r *http.Request) bool { + if worker.Config.AuthToken != "" && worker.Config.AuthToken != r.Header.Get( + "X-Dgraph-AuthToken") { return false } return true } -func drainingHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPut, http.MethodPost: - enableStr := r.URL.Query().Get("enable") - - enable, err := strconv.ParseBool(enableStr) - if err != nil { - x.SetStatus(w, x.ErrorInvalidRequest, - "Found invalid value for the enable parameter") +func allowedMethodsHandler(allowedMethods allowedMethods, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := allowedMethods[r.Method]; !ok { + x.SetStatus(w, x.ErrorInvalidMethod, "Invalid method") + w.WriteHeader(http.StatusMethodNotAllowed) return } - x.UpdateDrainingMode(enable) - _, err = w.Write([]byte(fmt.Sprintf(`{"code": "Success",`+ - `"message": "draining mode has been set to %v"}`, enable))) - if err != nil { - glog.Errorf("Failed to write response: %v", err) + next.ServeHTTP(w, r) + }) +} + +// adminAuthHandler does some standard checks for admin endpoints. +// It returns if something is wrong. Otherwise, it lets the given handler serve the request. +func adminAuthHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !hasPoormansAuth(r) { + x.SetStatus(w, x.ErrorUnauthorized, "Invalid X-Dgraph-AuthToken") + return } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } + + next.ServeHTTP(w, r) + }) } -func shutDownHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodGet: true, - }) { +func drainingHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + enableStr := r.URL.Query().Get("enable") + + enable, err := strconv.ParseBool(enableStr) + if err != nil { + x.SetStatus(w, x.ErrorInvalidRequest, + "Found invalid value for the enable parameter") return } - close(worker.ShutdownCh) + gqlReq := &schema.Request{ + Query: ` + mutation draining($enable: Boolean) { + draining(enable: $enable) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"enable": enable}, + } + _ = resolveWithAdminServer(gqlReq, r, adminServer) w.Header().Set("Content-Type", "application/json") - x.Check2(w.Write([]byte(`{"code": "Success", "message": "Server is shutting down"}`))) + x.Check2(w.Write([]byte(fmt.Sprintf(`{"code": "Success",`+ + `"message": "draining mode has been set to %v"}`, enable)))) } -func exportHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodGet: true, - }) { - return +func shutDownHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + mutation { + shutdown { + response { + code + } + } + }`, } + _ = resolveWithAdminServer(gqlReq, r, adminServer) + w.Header().Set("Content-Type", "application/json") + x.Check2(w.Write([]byte(`{"code": "Success", "message": "Server is shutting down"}`))) +} + +func exportHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { if err := r.ParseForm(); err != nil { x.SetHttpStatus(w, http.StatusBadRequest, "Parse of export request failed.") return @@ -105,26 +129,37 @@ func exportHandler(w http.ResponseWriter, r *http.Request) { return } } - if err := worker.ExportOverNetwork(context.Background(), format); err != nil { - x.SetStatus(w, err.Error(), "Export failed.") + + gqlReq := &schema.Request{ + Query: ` + mutation export($format: String) { + export(input: {format: $format}) { + response { + code + } + } + }`, + Variables: map[string]interface{}{}, + } + resp := resolveWithAdminServer(gqlReq, r, adminServer) + if len(resp.Errors) != 0 { + x.SetStatus(w, resp.Errors[0].Message, "Export failed.") return } w.Header().Set("Content-Type", "application/json") x.Check2(w.Write([]byte(`{"code": "Success", "message": "Export completed."}`))) } -func memoryLimitHandler(w http.ResponseWriter, r *http.Request) { +func memoryLimitHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { switch r.Method { case http.MethodGet: - memoryLimitGetHandler(w, r) + memoryLimitGetHandler(w, r, adminServer) case http.MethodPut: - memoryLimitPutHandler(w, r) - default: - w.WriteHeader(http.StatusMethodNotAllowed) + memoryLimitPutHandler(w, r, adminServer) } } -func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request) { +func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { body, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -135,36 +170,45 @@ func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + gqlReq := &schema.Request{ + Query: ` + mutation config($lruMb: Float) { + config(input: {lruMb: $lruMb}) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"lruMb": memoryMB}, + } + resp := resolveWithAdminServer(gqlReq, r, adminServer) - if err := worker.UpdateLruMb(memoryMB); err != nil { + if len(resp.Errors) != 0 { w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, err.Error()) + x.Check2(fmt.Fprint(w, resp.Errors[0].Message)) return } w.WriteHeader(http.StatusOK) } -func memoryLimitGetHandler(w http.ResponseWriter, r *http.Request) { - posting.Config.Lock() - memoryMB := posting.Config.AllottedMemory - posting.Config.Unlock() - - if _, err := fmt.Fprintln(w, memoryMB); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) +func memoryLimitGetHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + query { + config { + lruMb + } + }`, } -} - -func ipInIPWhitelistRanges(ipString string) bool { - ip := net.ParseIP(ipString) - - if ip == nil { - return false + resp := resolveWithAdminServer(gqlReq, r, adminServer) + var data struct { + Config struct { + LruMb float64 + } } + x.Check(json.Unmarshal(resp.Data.Bytes(), &data)) - for _, ipRange := range x.WorkerConfig.WhiteListedIPRanges { - if bytes.Compare(ip, ipRange.Lower) >= 0 && bytes.Compare(ip, ipRange.Upper) <= 0 { - return true - } + if _, err := fmt.Fprintln(w, data.Config.LruMb); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) } - return false } diff --git a/dgraph/cmd/alpha/admin_backup.go b/dgraph/cmd/alpha/admin_backup.go index 8919e0940da..9644b519d28 100644 --- a/dgraph/cmd/alpha/admin_backup.go +++ b/dgraph/cmd/alpha/admin_backup.go @@ -19,43 +19,45 @@ package alpha import ( - "context" "net/http" - "github.com/dgraph-io/dgraph/protos/pb" - "github.com/dgraph-io/dgraph/worker" + "github.com/dgraph-io/dgraph/graphql/schema" + + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/dgraph/x" ) func init() { - http.HandleFunc("/admin/backup", backupHandler) + http.Handle("/admin/backup", allowedMethodsHandler(allowedMethods{http.MethodPost: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backupHandler(w, r, adminServer) + })))) } // backupHandler handles backup requests coming from the HTTP endpoint. -func backupHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodPost: true, - }) { - return +func backupHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + mutation backup($input: BackupInput!) { + backup(input: $input) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"input": map[string]interface{}{ + "destination": r.FormValue("destination"), + "accessKey": r.FormValue("access_key"), + "secretKey": r.FormValue("secret_key"), + "sessionToken": r.FormValue("session_token"), + "anonymous": r.FormValue("anonymous") == "true", + "forceFull": r.FormValue("force_full") == "true", + }}, } - - destination := r.FormValue("destination") - accessKey := r.FormValue("access_key") - secretKey := r.FormValue("secret_key") - sessionToken := r.FormValue("session_token") - anonymous := r.FormValue("anonymous") == "true" - forceFull := r.FormValue("force_full") == "true" - - req := pb.BackupRequest{ - Destination: destination, - AccessKey: accessKey, - SecretKey: secretKey, - SessionToken: sessionToken, - Anonymous: anonymous, - } - - if err := worker.ProcessBackupRequest(context.Background(), &req, forceFull); err != nil { - x.SetStatus(w, err.Error(), "Backup failed.") + resp := resolveWithAdminServer(gqlReq, r, adminServer) + if resp.Errors != nil { + x.SetStatus(w, resp.Errors.Error(), "Backup failed.") return } diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index ceeaf88cd9f..1419be98eb1 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -586,12 +586,8 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. return } - md := metadata.New(nil) - ctx := metadata.NewIncomingContext(context.Background(), md) - ctx = x.AttachAccessJwt(ctx, r) - - gqlReq := &schema.Request{} - gqlReq.Query = ` + gqlReq := &schema.Request{ + Query: ` mutation updateGqlSchema($sch: String!) { updateGQLSchema(input: { set: { @@ -602,12 +598,11 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. id } } - }` - gqlReq.Variables = map[string]interface{}{ - "sch": string(b), + }`, + Variables: map[string]interface{}{"sch": string(b)}, } - response := adminServer.Resolve(ctx, gqlReq) + response := resolveWithAdminServer(gqlReq, r, adminServer) if len(response.Errors) > 0 { x.SetStatus(w, x.Error, response.Errors.Error()) return @@ -616,6 +611,16 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. writeSuccessResponse(w, r) } +func resolveWithAdminServer(gqlReq *schema.Request, r *http.Request, + adminServer web.IServeGraphQL) *schema.Response { + md := metadata.New(nil) + ctx := metadata.NewIncomingContext(context.Background(), md) + ctx = x.AttachAccessJwt(ctx, r) + ctx = x.AttachRemoteIP(ctx, r) + + return adminServer.Resolve(ctx, gqlReq) +} + func writeSuccessResponse(w http.ResponseWriter, r *http.Request) { res := map[string]interface{}{} data := map[string]interface{}{} diff --git a/dgraph/cmd/alpha/http_test.go b/dgraph/cmd/alpha/http_test.go index 041dde4838e..b39446a05d3 100644 --- a/dgraph/cmd/alpha/http_test.go +++ b/dgraph/cmd/alpha/http_test.go @@ -735,7 +735,7 @@ func TestHealth(t *testing.T) { require.True(t, info[0].Uptime > int64(time.Duration(1))) } -func setDrainingMode(t *testing.T, enable bool) { +func setDrainingMode(t *testing.T, enable bool, accessJwt string) { drainingRequest := `mutation drain($enable: Boolean) { draining(enable: $enable) { response { @@ -743,22 +743,13 @@ func setDrainingMode(t *testing.T, enable bool) { } } }` - adminUrl := fmt.Sprintf("%s/admin", addr) - params := testutil.GraphQLParams{ + params := &testutil.GraphQLParams{ Query: drainingRequest, Variables: map[string]interface{}{"enable": enable}, } - b, err := json.Marshal(params) - require.NoError(t, err) - - resp, err := http.Post(adminUrl, "application/json", bytes.NewBuffer(b)) - require.NoError(t, err) - - defer resp.Body.Close() - b, err = ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.JSONEq(t, `{"data":{"draining":{"response":{"code":"Success"}}}}`, - string(b)) + resp := testutil.MakeGQLRequestWithAccessJwt(t, params, accessJwt) + resp.RequireNoGraphQLErrors(t) + require.JSONEq(t, `{"draining":{"response":{"code":"Success"}}}`, string(resp.Data)) } func TestDrainingMode(t *testing.T) { @@ -800,10 +791,12 @@ func TestDrainingMode(t *testing.T) { } - setDrainingMode(t, true) + grootJwt, _ := testutil.GrootHttpLogin(addr + "/admin") + + setDrainingMode(t, true, grootJwt) runRequests(true) - setDrainingMode(t, false) + setDrainingMode(t, false, grootJwt) runRequests(false) } diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index 0411cf2ceba..1c87ff1ccd8 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -35,6 +35,8 @@ import ( "syscall" "time" + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgo/v200/protos/api" "github.com/dgraph-io/dgraph/edgraph" @@ -76,6 +78,9 @@ var ( // Alpha is the sub-command invoked when running "dgraph alpha". Alpha x.SubCommand + + // need this here to refer it in admin_backup.go + adminServer web.IServeGraphQL ) func init() { @@ -453,11 +458,6 @@ func setupServer(closer *y.Closer) { // TODO: Figure out what this is for? http.HandleFunc("/debug/store", storeStatsHandler) - http.HandleFunc("/admin/shutdown", shutDownHandler) - http.HandleFunc("/admin/draining", drainingHandler) - http.HandleFunc("/admin/export", exportHandler) - http.HandleFunc("/admin/config/lru_mb", memoryLimitHandler) - introspection := Alpha.Conf.GetBool("graphql_introspection") // Global Epoch is a lockless synchronization mechanism for graphql service. @@ -474,25 +474,43 @@ func setupServer(closer *y.Closer) { // The global epoch is set to maxUint64 while exiting the server. // By using this information polling goroutine terminates the subscription. globalEpoch := uint64(0) - mainServer, adminServer := admin.NewServers(introspection, &globalEpoch, closer) + var mainServer web.IServeGraphQL + mainServer, adminServer = admin.NewServers(introspection, &globalEpoch, closer) http.Handle("/graphql", mainServer.HTTPHandler()) - - whitelist := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodPost: true, - http.MethodGet: true, - http.MethodOptions: true, - }) { - return - } - h.ServeHTTP(w, r) - }) - } - http.Handle("/admin", whitelist(adminServer.HTTPHandler())) - http.HandleFunc("/admin/schema", func(w http.ResponseWriter, r *http.Request) { + http.Handle("/admin", allowedMethodsHandler(allowedMethods{ + http.MethodGet: true, + http.MethodPost: true, + http.MethodOptions: true, + }, adminAuthHandler(adminServer.HTTPHandler()))) + + http.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, + r *http.Request) { adminSchemaHandler(w, r, adminServer) - }) + }))) + + http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + shutDownHandler(w, r, adminServer) + })))) + + http.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{ + http.MethodPut: true, + http.MethodPost: true, + }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + drainingHandler(w, r, adminServer) + })))) + + http.Handle("/admin/export", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + exportHandler(w, r, adminServer) + })))) + + http.Handle("/admin/config/lru_mb", allowedMethodsHandler(allowedMethods{ + http.MethodGet: true, + http.MethodPut: true, + }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + memoryLimitHandler(w, r, adminServer) + })))) addr := fmt.Sprintf("%s:%d", laddr, httpPort()) glog.Infof("Bringing up GraphQL HTTP API at %s/graphql", addr) diff --git a/edgraph/access.go b/edgraph/access.go index ff0a3857642..d8ffc434db3 100644 --- a/edgraph/access.go +++ b/edgraph/access.go @@ -65,7 +65,7 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er return nil } -func authorizeGuardians(ctx context.Context) error { +func AuthorizeGuardians(ctx context.Context) error { // always allow access return nil } diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go index 568a1500b17..826e03e8f87 100644 --- a/edgraph/access_ee.go +++ b/edgraph/access_ee.go @@ -794,8 +794,8 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er return nil } -// authorizeGuardians authorizes the operation for users which belong to Guardians group. -func authorizeGuardians(ctx context.Context) error { +// AuthorizeGuardians authorizes the operation for users which belong to Guardians group. +func AuthorizeGuardians(ctx context.Context) error { if len(worker.Config.HmacSecret) == 0 { // the user has not turned on the acl feature return nil diff --git a/edgraph/server.go b/edgraph/server.go index e7a3aaf5f50..6fa4703d7d2 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -699,7 +699,7 @@ func (s *Server) Health(ctx context.Context, all bool) (*api.Response, error) { var healthAll []pb.HealthInfo if all { - if err := authorizeGuardians(ctx); err != nil { + if err := AuthorizeGuardians(ctx); err != nil { return nil, err } pool := conn.GetPools().GetAll() @@ -738,7 +738,7 @@ func (s *Server) State(ctx context.Context) (*api.Response, error) { return nil, ctx.Err() } - if err := authorizeGuardians(ctx); err != nil { + if err := AuthorizeGuardians(ctx); err != nil { return nil, err } diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index fcf2c07c5b7..9ff6da99738 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -13,13 +13,10 @@ package acl import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io/ioutil" - "net/http" "strconv" "testing" "time" @@ -631,28 +628,7 @@ type group struct { func makeRequest(t *testing.T, accessToken string, params testutil.GraphQLParams) *testutil. GraphQLResponse { - adminUrl := "http://" + testutil.SockAddrHttp + "/admin" - - b, err := json.Marshal(params) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, adminUrl, bytes.NewBuffer(b)) - require.NoError(t, err) - req.Header.Set("X-Dgraph-AccessToken", accessToken) - req.Header.Set("Content-Type", "application/json") - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - - defer resp.Body.Close() - b, err = ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - var result *testutil.GraphQLResponse - err = json.Unmarshal(b, &result) - require.NoError(t, err) - - return result + return testutil.MakeGQLRequestWithAccessJwt(t, ¶ms, accessToken) } func addRulesToGroup(t *testing.T, accessToken, group string, rules []rule) { @@ -1750,49 +1726,28 @@ func TestWrongPermission(t *testing.T) { } func TestHealthForAcl(t *testing.T) { - resetUser(t) - - gqlQuery := ` - query { - health { - instance - address - lastEcho - status - version - uptime - group - } - }` - params := testutil.GraphQLParams{ - Query: gqlQuery, + Query: ` + query { + health { + instance + address + lastEcho + status + version + uptime + group + } + }`, } // assert errors for non-guardians - accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{ - Endpoint: adminEndpoint, - UserID: userid, - Passwd: userpassword, - }) - require.NoError(t, err, "login failed") - - resp := makeRequest(t, accessJwt, params) - expectedError := fmt.Sprintf("Error: rpc error: code"+ - " = PermissionDenied desc = Only guardians are allowed access. "+ - "User '%s' is not a member of guardians group.", userid) - require.Equal(t, x.GqlErrorList{{Message: expectedError}}, resp.Errors) - require.JSONEq(t, `{ "health": [] }`, string(resp.Data)) + assertNonGuardianFailure(t, "health", false, params) // assert data for guardians - accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{ - Endpoint: adminEndpoint, - UserID: "groot", - Passwd: "password", - }) - require.NoError(t, err, "groot login failed") + accessJwt, _ := testutil.GrootHttpLogin(adminEndpoint) - resp = makeRequest(t, accessJwt, params) + resp := makeRequest(t, accessJwt, params) resp.RequireNoGraphQLErrors(t) var guardianResp struct { Health []struct { @@ -1805,7 +1760,7 @@ func TestHealthForAcl(t *testing.T) { Group string } } - err = json.Unmarshal(resp.Data, &guardianResp) + err := json.Unmarshal(resp.Data, &guardianResp) require.NoError(t, err, "health request failed") // we have 9 instances of alphas/zeros in teamcity environment @@ -1821,6 +1776,247 @@ func TestHealthForAcl(t *testing.T) { } } +func assertNonGuardianFailure(t *testing.T, queryName string, respIsNull bool, + params testutil.GraphQLParams) { + resetUser(t) + + accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: adminEndpoint, + UserID: userid, + Passwd: userpassword, + }) + require.NoError(t, err, "login failed") + resp := makeRequest(t, accessJwt, params) + + require.Len(t, resp.Errors, 1) + require.Contains(t, resp.Errors[0].Message, + fmt.Sprintf("rpc error: code = PermissionDenied desc = Only guardians are allowed access."+ + " User '%s' is not a member of guardians group.", userid)) + if len(resp.Data) != 0 { + queryVal := "null" + if !respIsNull { + queryVal = "[]" + } + require.JSONEq(t, fmt.Sprintf(`{"%s": %s}`, queryName, queryVal), string(resp.Data)) + } +} + +type graphQLAdminEndpointTestCase struct { + name string + query string + queryName string + respIsArray bool + testGuardianAccess bool + guardianErrs x.GqlErrorList + // specifying this as empty string means it won't be compared with response data + guardianData string +} + +func TestGuardianOnlyAccessForAdminEndpoints(t *testing.T) { + tcases := []graphQLAdminEndpointTestCase{ + { + name: "backup has guardian auth", + query: ` + mutation { + backup(input: {destination: ""}) { + response { + code + message + } + } + }`, + queryName: "backup", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving backup failed because you must specify a 'destination' value", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"backup": null}`, + }, + { + name: "listBackups has guardian auth", + query: ` + query { + listBackups(input: {location: ""}) { + backupId + } + }`, + queryName: "listBackups", + respIsArray: true, + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving listBackups failed because Error: cannot read manfiests at " + + "location : The path \"\" does not exist or it is inaccessible.", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"listBackups": []}`, + }, + { + name: "config update has guardian auth", + query: ` + mutation { + config(input: {lruMb: 1}) { + response { + code + message + } + } + }`, + queryName: "config", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving config failed because lru_mb must be at least 1024\n", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"config": null}`, + }, + { + name: "config get has guardian auth", + query: ` + query { + config { + lruMb + } + }`, + queryName: "config", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: "", + }, + { + name: "draining has guardian auth", + query: ` + mutation { + draining(enable: false) { + response { + code + message + } + } + }`, + queryName: "draining", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: `{ + "draining": { + "response": { + "code": "Success", + "message": "draining mode has been set to false" + } + } + }`, + }, + { + name: "export has guardian auth", + query: ` + mutation { + export(input: {format: "invalid"}) { + response { + code + message + } + } + }`, + queryName: "export", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving export failed because invalid export format: invalid", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"export": null}`, + }, + { + name: "restore has guardian auth", + query: ` + mutation { + restore(input: {location: "", backupId: "", keyFile: ""}) { + response { + code + message + } + } + }`, + queryName: "restore", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving restore failed because failed to verify backup: while retrieving" + + " manifests: The path \"\" does not exist or it is inaccessible.", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"restore": null}`, + }, + { + name: "getGQLSchema has guardian auth", + query: ` + query { + getGQLSchema { + id + } + }`, + queryName: "getGQLSchema", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: "", + }, + { + name: "updateGQLSchema has guardian auth", + query: ` + mutation { + updateGQLSchema(input: {set: {schema: ""}}) { + gqlSchema { + id + } + } + }`, + queryName: "updateGQLSchema", + testGuardianAccess: false, + guardianErrs: nil, + guardianData: "", + }, + { + name: "shutdown has guardian auth", + query: ` + mutation { + shutdown { + response { + code + message + } + } + }`, + queryName: "shutdown", + testGuardianAccess: false, + guardianErrs: nil, + guardianData: "", + }, + } + + for _, tcase := range tcases { + t.Run(tcase.name, func(t *testing.T) { + params := testutil.GraphQLParams{Query: tcase.query} + + // assert ACL error for non-guardians + assertNonGuardianFailure(t, tcase.queryName, !tcase.respIsArray, params) + + // for guardians, assert non-ACL error or success + if tcase.testGuardianAccess { + accessJwt, _ := testutil.GrootHttpLogin(adminEndpoint) + resp := makeRequest(t, accessJwt, params) + + if tcase.guardianErrs == nil { + resp.RequireNoGraphQLErrors(t) + } else { + require.Equal(t, tcase.guardianErrs, resp.Errors) + } + + if tcase.guardianData != "" { + require.JSONEq(t, tcase.guardianData, string(resp.Data)) + } + } + }) + } +} + func TestAddUpdateGroupWithDuplicateRules(t *testing.T) { groupName := "testGroup" addedRules := []rule{ diff --git a/graphql/admin/admin.go b/graphql/admin/admin.go index d1963bf65b6..dfe8afa4933 100644 --- a/graphql/admin/admin.go +++ b/graphql/admin/admin.go @@ -236,12 +236,17 @@ const ( response: Response } + type Config { + lruMb: Float + } + ` + adminTypes + ` type Query { getGQLSchema: GQLSchema health: [NodeState] state: MembershipState + config: Config ` + adminQueries + ` } @@ -281,6 +286,55 @@ const ( ` ) +var ( + // commonAdminQueryMWs are the middlewares which should be applied to queries served by admin + // server unless some exceptional behaviour is required + commonAdminQueryMWs = resolve.QueryMiddlewares{ + resolve.IpWhitelistingMW4Query, // good to apply ip whitelisting before Guardian auth + resolve.GuardianAuthMW4Query, + } + // commonAdminMutationMWs are the middlewares which should be applied to mutations served by + // admin server unless some exceptional behaviour is required + commonAdminMutationMWs = resolve.MutationMiddlewares{ + resolve.IpWhitelistingMW4Mutation, // good to apply ip whitelisting before Guardian auth + resolve.GuardianAuthMW4Mutation, + } + adminQueryMWConfig = map[string]resolve.QueryMiddlewares{ + "health": {resolve.IpWhitelistingMW4Query}, // dgraph handles Guardian auth for health + "state": {resolve.IpWhitelistingMW4Query}, // dgraph handles Guardian auth for state + "config": commonAdminQueryMWs, + "listBackups": commonAdminQueryMWs, + // not applying ip whitelisting to keep it in sync with /alter + "getGQLSchema": {resolve.GuardianAuthMW4Query}, + // for queries and mutations related to User/Group, dgraph handles Guardian auth, + // so no need to apply GuardianAuth Middleware + "queryGroup": {resolve.IpWhitelistingMW4Query}, + "queryUser": {resolve.IpWhitelistingMW4Query}, + "getGroup": {resolve.IpWhitelistingMW4Query}, + "getCurrentUser": {resolve.IpWhitelistingMW4Query}, + "getUser": {resolve.IpWhitelistingMW4Query}, + } + adminMutationMWConfig = map[string]resolve.MutationMiddlewares{ + "backup": commonAdminMutationMWs, + "config": commonAdminMutationMWs, + "draining": commonAdminMutationMWs, + "export": commonAdminMutationMWs, + "login": {resolve.IpWhitelistingMW4Mutation}, + "restore": commonAdminMutationMWs, + "shutdown": commonAdminMutationMWs, + // not applying ip whitelisting to keep it in sync with /alter + "updateGQLSchema": {resolve.GuardianAuthMW4Mutation}, + // for queries and mutations related to User/Group, dgraph handles Guardian auth, + // so no need to apply GuardianAuth Middleware + "addUser": {resolve.IpWhitelistingMW4Mutation}, + "addGroup": {resolve.IpWhitelistingMW4Mutation}, + "updateUser": {resolve.IpWhitelistingMW4Mutation}, + "updateGroup": {resolve.IpWhitelistingMW4Mutation}, + "deleteUser": {resolve.IpWhitelistingMW4Mutation}, + "deleteGroup": {resolve.IpWhitelistingMW4Mutation}, + } +) + type gqlSchema struct { ID string `json:"id,omitempty"` Schema string `json:"schema,omitempty"` @@ -416,7 +470,7 @@ func newAdminResolverFactory() resolve.ResolverFactory { adminMutationResolvers := map[string]resolve.MutationResolverFunc{ "backup": resolveBackup, - "config": resolveConfig, + "config": resolveUpdateConfig, "draining": resolveDraining, "export": resolveExport, "login": resolveLogin, @@ -425,12 +479,20 @@ func newAdminResolverFactory() resolve.ResolverFactory { } rf := resolverFactoryWithErrorMsg(errResolverNotFound). + WithQueryMiddlewareConfig(adminQueryMWConfig). + WithMutationMiddlewareConfig(adminMutationMWConfig). WithQueryResolver("health", func(q schema.Query) resolve.QueryResolver { return resolve.QueryResolverFunc(resolveHealth) }). WithQueryResolver("state", func(q schema.Query) resolve.QueryResolver { return resolve.QueryResolverFunc(resolveState) }). + WithQueryResolver("config", func(q schema.Query) resolve.QueryResolver { + return resolve.QueryResolverFunc(resolveGetConfig) + }). + WithQueryResolver("listBackups", func(q schema.Query) resolve.QueryResolver { + return resolve.QueryResolverFunc(resolveListBackups) + }). WithMutationResolver("updateGQLSchema", func(m schema.Mutation) resolve.MutationResolver { return resolve.MutationResolverFunc( func(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { @@ -448,18 +510,13 @@ func newAdminResolverFactory() resolve.ResolverFactory { for gqlMut, resolver := range adminMutationResolvers { // gotta force go to evaluate the right function at each loop iteration // otherwise you get variable capture issues - func(f resolve.MutationResolverFunc) { + func(f resolve.MutationResolver) { rf.WithMutationResolver(gqlMut, func(m schema.Mutation) resolve.MutationResolver { return f }) }(resolver) } - // Add admin query endpoints. - rf = rf.WithQueryResolver("listBackups", func(q schema.Query) resolve.QueryResolver { - return resolve.QueryResolverFunc(resolveListBackups) - }) - return rf.WithSchemaIntrospection() } @@ -766,11 +823,3 @@ func response(code, msg string) map[string]interface{} { return map[string]interface{}{ "response": map[string]interface{}{"code": code, "message": msg}} } - -func emptyResult(f schema.Field, err error) *resolve.Resolved { - return &resolve.Resolved{ - Data: map[string]interface{}{f.Name(): nil}, - Field: f, - Err: err, - } -} diff --git a/graphql/admin/backup.go b/graphql/admin/backup.go index 36d4e9b6f5c..c2238dfef5d 100644 --- a/graphql/admin/backup.go +++ b/graphql/admin/backup.go @@ -41,7 +41,7 @@ func resolveBackup(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b input, err := getBackupInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } err = worker.ProcessBackupRequest(context.Background(), &pb.BackupRequest{ @@ -53,7 +53,7 @@ func resolveBackup(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b }, input.ForceFull) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/config.go b/graphql/admin/config.go index 141ecf2f128..d570f98cb4d 100644 --- a/graphql/admin/config.go +++ b/graphql/admin/config.go @@ -22,6 +22,7 @@ import ( "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/worker" "github.com/golang/glog" ) @@ -34,17 +35,17 @@ type configInput struct { LogRequest *bool } -func resolveConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { - glog.Info("Got config request through GraphQL admin API") +func resolveUpdateConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { + glog.Info("Got config update through GraphQL admin API") input, err := getConfigInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } if input.LruMB > 0 { if err = worker.UpdateLruMb(input.LruMB); err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } } @@ -59,6 +60,21 @@ func resolveConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b }, true } +func resolveGetConfig(ctx context.Context, q schema.Query) *resolve.Resolved { + glog.Info("Got config query through GraphQL admin API") + + conf := make(map[string]interface{}) + posting.Config.Lock() + conf["lruMb"] = posting.Config.AllottedMemory + posting.Config.Unlock() + + return &resolve.Resolved{ + Data: map[string]interface{}{q.Name(): conf}, + Field: q, + } + +} + func getConfigInput(m schema.Mutation) (*configInput, error) { inputArg := m.ArgValue(schema.InputArgName) inputByts, err := json.Marshal(inputArg) diff --git a/graphql/admin/export.go b/graphql/admin/export.go index cfb19686027..85bff5f274f 100644 --- a/graphql/admin/export.go +++ b/graphql/admin/export.go @@ -37,20 +37,20 @@ func resolveExport(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b input, err := getExportInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } format := worker.DefaultExportFormat if input.Format != "" { format = worker.NormalizeExportFormat(input.Format) if format == "" { - return emptyResult(m, errors.Errorf("invalid export format: %v", input.Format)), false + return resolve.EmptyResult(m, errors.Errorf("invalid export format: %v", input.Format)), false } } err = worker.ExportOverNetwork(context.Background(), format) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/health.go b/graphql/admin/health.go index 9fc8871d49c..30cf6d2b845 100644 --- a/graphql/admin/health.go +++ b/graphql/admin/health.go @@ -33,7 +33,7 @@ func resolveHealth(ctx context.Context, q schema.Query) *resolve.Resolved { resp, err := (&edgraph.Server{}).Health(ctx, true) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } var health []map[string]interface{} diff --git a/graphql/admin/list_backups.go b/graphql/admin/list_backups.go index 2a81ba03e1f..e550e27a491 100644 --- a/graphql/admin/list_backups.go +++ b/graphql/admin/list_backups.go @@ -54,7 +54,7 @@ type manifest struct { func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { input, err := getLsBackupInput(q) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } creds := &worker.Credentials{ @@ -65,7 +65,7 @@ func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { } manifests, err := worker.ProcessListBackups(ctx, input.Location, creds) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } convertedManifests := convertManifests(manifests) @@ -73,12 +73,12 @@ func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { for _, m := range convertedManifests { b, err := json.Marshal(m) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } var result map[string]interface{} err = json.Unmarshal(b, &result) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } results = append(results, result) } diff --git a/graphql/admin/login.go b/graphql/admin/login.go index bdc038ca863..f9c8967d160 100644 --- a/graphql/admin/login.go +++ b/graphql/admin/login.go @@ -42,12 +42,12 @@ func resolveLogin(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bo RefreshToken: input.RefreshToken, }) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } jwt := &dgoapi.Jwt{} if err := jwt.Unmarshal(resp.GetJson()); err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/restore.go b/graphql/admin/restore.go index 6f478577e6e..881a28e6bf7 100644 --- a/graphql/admin/restore.go +++ b/graphql/admin/restore.go @@ -40,7 +40,7 @@ func resolveRestore(ctx context.Context, m schema.Mutation) (*resolve.Resolved, input, err := getRestoreInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } req := pb.RestoreRequest{ @@ -54,7 +54,7 @@ func resolveRestore(ctx context.Context, m schema.Mutation) (*resolve.Resolved, } err = worker.ProcessRestoreRequest(context.Background(), &req) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/state.go b/graphql/admin/state.go index 06efb3cb4a5..5a2597990ec 100644 --- a/graphql/admin/state.go +++ b/graphql/admin/state.go @@ -37,7 +37,7 @@ type clusterGroup struct { func resolveState(ctx context.Context, q schema.Query) *resolve.Resolved { resp, err := (&edgraph.Server{}).State(ctx) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } // unmarshal it back to MembershipState proto in order to map to graphql response @@ -46,19 +46,19 @@ func resolveState(ctx context.Context, q schema.Query) *resolve.Resolved { err = u.Unmarshal(bytes.NewReader(resp.GetJson()), &ms) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } // map to graphql response structure state := convertToGraphQLResp(ms) b, err := json.Marshal(state) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } var resultState map[string]interface{} err = json.Unmarshal(b, &resultState) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } return &resolve.Resolved{ diff --git a/graphql/resolve/middlewares.go b/graphql/resolve/middlewares.go new file mode 100644 index 00000000000..82af2112a1d --- /dev/null +++ b/graphql/resolve/middlewares.go @@ -0,0 +1,178 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "net" + + "github.com/pkg/errors" + + "google.golang.org/grpc/peer" + + "github.com/dgraph-io/dgraph/edgraph" + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/x" +) + +// QueryMiddleware represents a middleware for queries +type QueryMiddleware func(resolver QueryResolver) QueryResolver + +// MutationMiddleware represents a middleware for mutations +type MutationMiddleware func(resolver MutationResolver) MutationResolver + +// QueryMiddlewares represents a list of middlewares for queries, that get applied in the order +// they are present in the list. +// Inspired from: https://github.com/justinas/alice +type QueryMiddlewares []QueryMiddleware + +// MutationMiddlewares represents a list of middlewares for mutations, that get applied in the order +// they are present in the list. +// Inspired from: https://github.com/justinas/alice +type MutationMiddlewares []MutationMiddleware + +// Then chains the middlewares and returns the final QueryResolver. +// QueryMiddlewares{m1, m2, m3}.Then(r) +// is equivalent to: +// m1(m2(m3(r))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given resolverFunc +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// commonMiddlewares := QueryMiddlewares{authMiddleware, loggingMiddleware} +// healthResolver = commonMiddlewares.Then(resolveHealth) +// stateResolver = commonMiddlewares.Then(resolveState) +// Note that middlewares are called on every call to Then() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as a QueryResolverFunc that resolves to &Resolved{Field: query} +func (mws QueryMiddlewares) Then(resolver QueryResolver) QueryResolver { + if len(mws) == 0 { + return resolver + } + if resolver == nil { + resolver = QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + return &Resolved{Field: query} + }) + } + for i := len(mws) - 1; i >= 0; i-- { + resolver = mws[i](resolver) + } + return resolver +} + +// Then chains the middlewares and returns the final MutationResolver. +// MutationMiddlewares{m1, m2, m3}.Then(r) +// is equivalent to: +// m1(m2(m3(r))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given resolverFunc +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// commonMiddlewares := MutationMiddlewares{authMiddleware, loggingMiddleware} +// backupResolver = commonMiddlewares.Then(resolveBackup) +// configResolver = commonMiddlewares.Then(resolveConfig) +// Note that middlewares are called on every call to Then() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as a MutationResolverFunc that resolves to (&Resolved{Field: mutation}, true) +func (mws MutationMiddlewares) Then(resolver MutationResolver) MutationResolver { + if len(mws) == 0 { + return resolver + } + if resolver == nil { + resolver = MutationResolverFunc(func(ctx context.Context, + mutation schema.Mutation) (*Resolved, bool) { + return &Resolved{Field: mutation}, true + }) + } + for i := len(mws) - 1; i >= 0; i-- { + resolver = mws[i](resolver) + } + return resolver +} + +// resolveGuardianAuth returns a Resolved with error if the context doesn't contain any Guardian auth, +// otherwise it returns nil +func resolveGuardianAuth(ctx context.Context, f schema.Field) *Resolved { + if err := edgraph.AuthorizeGuardians(ctx); err != nil { + return EmptyResult(f, err) + } + return nil +} + +func resolveIpWhitelisting(ctx context.Context, f schema.Field) *Resolved { + peerInfo, ok := peer.FromContext(ctx) + if !ok { + return EmptyResult(f, errors.New("unable to find source ip")) + } + ip, _, err := net.SplitHostPort(peerInfo.Addr.String()) + if err != nil { + return EmptyResult(f, err) + } + if !x.IsIpWhitelisted(ip) { + return EmptyResult(f, errors.Errorf("unauthorized ip address: %s", ip)) + } + return nil +} + +// GuardianAuthMW4Query blocks the resolution of resolverFunc if there is no Guardian auth +// present in context, otherwise it lets the resolverFunc resolve the query. +func GuardianAuthMW4Query(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + if resolved := resolveGuardianAuth(ctx, query); resolved != nil { + return resolved + } + return resolver.Resolve(ctx, query) + }) +} + +func IpWhitelistingMW4Query(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + if resolved := resolveIpWhitelisting(ctx, query); resolved != nil { + return resolved + } + return resolver.Resolve(ctx, query) + }) +} + +// GuardianAuthMW4Mutation blocks the resolution of resolverFunc if there is no Guardian auth +// present in context, otherwise it lets the resolverFunc resolve the mutation. +func GuardianAuthMW4Mutation(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + if resolved := resolveGuardianAuth(ctx, mutation); resolved != nil { + return resolved, false + } + return resolver.Resolve(ctx, mutation) + }) +} + +func IpWhitelistingMW4Mutation(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, + bool) { + if resolved := resolveIpWhitelisting(ctx, mutation); resolved != nil { + return resolved, false + } + return resolver.Resolve(ctx, mutation) + }) +} diff --git a/graphql/resolve/middlewares_test.go b/graphql/resolve/middlewares_test.go new file mode 100644 index 00000000000..2eaa5c18368 --- /dev/null +++ b/graphql/resolve/middlewares_test.go @@ -0,0 +1,97 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "testing" + + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/stretchr/testify/require" +) + +func TestQueryMiddlewares_Then_ExecutesMiddlewaresInOrder(t *testing.T) { + array := make([]int, 0) + addToArray := func(num int) { + array = append(array, num) + } + m1 := QueryMiddleware(func(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(1) + defer addToArray(5) + return resolver.Resolve(ctx, query) + }) + }) + m2 := QueryMiddleware(func(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(2) + resolved := resolver.Resolve(ctx, query) + addToArray(4) + return resolved + }) + }) + mws := QueryMiddlewares{m1, m2} + + resolver := mws.Then(QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(3) + return &Resolved{ + Field: query, + Extensions: &schema.Extensions{TouchedUids: 1}, + } + })) + resolved := resolver.Resolve(context.Background(), nil) + + require.Equal(t, &Resolved{Extensions: &schema.Extensions{TouchedUids: 1}}, resolved) + require.Equal(t, []int{1, 2, 3, 4, 5}, array) +} + +func TestMutationMiddlewares_Then_ExecutesMiddlewaresInOrder(t *testing.T) { + array := make([]int, 0) + addToArray := func(num int) { + array = append(array, num) + } + m1 := MutationMiddleware(func(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + addToArray(1) + defer addToArray(5) + return resolver.Resolve(ctx, mutation) + }) + }) + m2 := MutationMiddleware(func(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, + mutation schema.Mutation) (*Resolved, bool) { + addToArray(2) + resolved, success := resolver.Resolve(ctx, mutation) + addToArray(4) + return resolved, success + }) + }) + mws := MutationMiddlewares{m1, m2} + + resolver := mws.Then(MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + addToArray(3) + return &Resolved{ + Field: mutation, + Extensions: &schema.Extensions{TouchedUids: 1}, + }, true + })) + resolved, succeeded := resolver.Resolve(context.Background(), nil) + + require.True(t, succeeded) + require.Equal(t, &Resolved{Extensions: &schema.Extensions{TouchedUids: 1}}, resolved) + require.Equal(t, []int{1, 2, 3, 4, 5}, array) +} diff --git a/graphql/resolve/resolver.go b/graphql/resolve/resolver.go index ce8016d9fc3..67b35874c1e 100644 --- a/graphql/resolve/resolver.go +++ b/graphql/resolve/resolver.go @@ -54,12 +54,12 @@ type ResolverFactory interface { mutationResolverFor(mutation schema.Mutation) MutationResolver // WithQueryResolver adds a new query resolver. Each time query name is resolved - // resolver is called to create a new instane of a QueryResolver to resolve the + // resolver is called to create a new instance of a QueryResolver to resolve the // query. WithQueryResolver(name string, resolver func(schema.Query) QueryResolver) ResolverFactory // WithMutationResolver adds a new query resolver. Each time mutation name is resolved - // resolver is called to create a new instane of a MutationResolver to resolve the + // resolver is called to create a new instance of a MutationResolver to resolve the // mutation. WithMutationResolver( name string, resolver func(schema.Mutation) MutationResolver) ResolverFactory @@ -68,6 +68,15 @@ type ResolverFactory interface { // factory. The registration happens only once. WithConventionResolvers(s schema.Schema, fns *ResolverFns) ResolverFactory + // WithQueryMiddlewareConfig adds the configuration to use to apply middlewares before resolving + // queries. The config should be a mapping of the name of query to its middlewares. + WithQueryMiddlewareConfig(config map[string]QueryMiddlewares) ResolverFactory + + // WithMutationMiddlewareConfig adds the configuration to use to apply middlewares before + // resolving mutations. The config should be a mapping of the name of mutation to its + // middlewares. + WithMutationMiddlewareConfig(config map[string]MutationMiddlewares) ResolverFactory + // WithSchemaIntrospection adds schema introspection capabilities to the factory. // So __schema and __type queries can be resolved. WithSchemaIntrospection() ResolverFactory @@ -97,6 +106,9 @@ type resolverFactory struct { queryResolvers map[string]func(schema.Query) QueryResolver mutationResolvers map[string]func(schema.Mutation) MutationResolver + queryMiddlewareConfig map[string]QueryMiddlewares + mutationMiddlewareConfig map[string]MutationMiddlewares + // returned if the factory gets asked for resolver for a field that it doesn't // know about. queryError QueryResolverFunc @@ -250,6 +262,22 @@ func (rf *resolverFactory) WithConventionResolvers( return rf } +func (rf *resolverFactory) WithQueryMiddlewareConfig( + config map[string]QueryMiddlewares) ResolverFactory { + if len(config) != 0 { + rf.queryMiddlewareConfig = config + } + return rf +} + +func (rf *resolverFactory) WithMutationMiddlewareConfig( + config map[string]MutationMiddlewares) ResolverFactory { + if len(config) != 0 { + rf.mutationMiddlewareConfig = config + } + return rf +} + // NewResolverFactory returns a ResolverFactory that resolves requests via // query/mutation rewriting and execution through Dgraph. If the factory gets asked // to resolve a query/mutation it doesn't know how to rewrite, it uses @@ -261,6 +289,9 @@ func NewResolverFactory( queryResolvers: make(map[string]func(schema.Query) QueryResolver), mutationResolvers: make(map[string]func(schema.Mutation) MutationResolver), + queryMiddlewareConfig: make(map[string]QueryMiddlewares), + mutationMiddlewareConfig: make(map[string]MutationMiddlewares), + queryError: queryError, mutationError: mutationError, } @@ -282,16 +313,18 @@ func StdDeleteCompletion(name string) CompletionFunc { } func (rf *resolverFactory) queryResolverFor(query schema.Query) QueryResolver { + mws := rf.queryMiddlewareConfig[query.Name()] if resolver, ok := rf.queryResolvers[query.Name()]; ok { - return resolver(query) + return mws.Then(resolver(query)) } return rf.queryError } func (rf *resolverFactory) mutationResolverFor(mutation schema.Mutation) MutationResolver { + mws := rf.mutationMiddlewareConfig[mutation.Name()] if resolver, ok := rf.mutationResolvers[mutation.Name()]; ok { - return resolver(mutation) + return mws.Then(resolver(mutation)) } return rf.mutationError @@ -1550,3 +1583,11 @@ func (h *httpMutationResolver) Resolve(ctx context.Context, mutation schema.Muta resolved := (*httpResolver)(h).Resolve(ctx, mutation) return resolved, resolved.Err == nil || resolved.Err.Error() == "" } + +func EmptyResult(f schema.Field, err error) *Resolved { + return &Resolved{ + Data: map[string]interface{}{f.Name(): nil}, + Field: f, + Err: schema.GQLWrapLocationf(err, f.Location(), "resolving %s failed", f.Name()), + } +} diff --git a/graphql/web/http.go b/graphql/web/http.go index 2a0ac84a48e..fd24c112c84 100644 --- a/graphql/web/http.go +++ b/graphql/web/http.go @@ -24,23 +24,19 @@ import ( "io" "io/ioutil" "mime" - "net" "net/http" - "strconv" "strings" - "github.com/golang/glog" - "github.com/graph-gophers/graphql-transport-ws/graphqlws" - "go.opencensus.io/trace" - "google.golang.org/grpc/peer" - "github.com/dgraph-io/dgraph/graphql/api" "github.com/dgraph-io/dgraph/graphql/authorization" "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" "github.com/dgraph-io/dgraph/graphql/subscription" "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/graph-gophers/graphql-transport-ws/graphqlws" "github.com/pkg/errors" + "go.opencensus.io/trace" ) // An IServeGraphQL can serve a GraphQL endpoint (currently only ons http) @@ -156,19 +152,9 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = authorization.AttachAuthorizationJwt(ctx, r) ctx = x.AttachAccessJwt(ctx, r) - - if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil { - // Add remote addr as peer info so that the remote address can be logged - // inside Server.Login - if intPort, convErr := strconv.Atoi(port); convErr == nil { - ctx = peer.NewContext(ctx, &peer.Peer{ - Addr: &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: intPort, - }, - }) - } - } + // Add remote addr as peer info so that the remote address can be logged + // inside Server.Login + ctx = x.AttachRemoteIP(ctx, r) var res *schema.Response gqlReq, err := getRequest(ctx, r) diff --git a/testutil/graphql.go b/testutil/graphql.go index 2a45180ebe8..8b720a4e252 100644 --- a/testutil/graphql.go +++ b/testutil/graphql.go @@ -78,6 +78,38 @@ func RequireNoGraphQLErrors(t *testing.T, resp *http.Response) { require.Nil(t, result.Errors) } +func MakeGQLRequest(t *testing.T, params *GraphQLParams) *GraphQLResponse { + return MakeGQLRequestWithAccessJwt(t, params, "") +} + +func MakeGQLRequestWithAccessJwt(t *testing.T, params *GraphQLParams, + accessToken string) *GraphQLResponse { + adminUrl := "http://" + SockAddrHttp + "/admin" + + b, err := json.Marshal(params) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, adminUrl, bytes.NewBuffer(b)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + if accessToken != "" { + req.Header.Set("X-Dgraph-AccessToken", accessToken) + } + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + b, err = ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + var gqlResp GraphQLResponse + err = json.Unmarshal(b, &gqlResp) + require.NoError(t, err) + + return &gqlResp +} + type clientCustomClaims struct { Namespace string AuthVariables map[string]interface{} diff --git a/x/x.go b/x/x.go index 95c1c04474e..14d5cb14b74 100644 --- a/x/x.go +++ b/x/x.go @@ -37,6 +37,8 @@ import ( "syscall" "time" + "google.golang.org/grpc/peer" + "github.com/dgraph-io/badger/v2" "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgo/v200" @@ -369,6 +371,41 @@ func AttachAccessJwt(ctx context.Context, r *http.Request) context.Context { return ctx } +// AttachRemoteIP adds any incoming IP data into the grpc context metadata +func AttachRemoteIP(ctx context.Context, r *http.Request) context.Context { + if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil { + if intPort, convErr := strconv.Atoi(port); convErr == nil { + ctx = peer.NewContext(ctx, &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP(ip), + Port: intPort, + }, + }) + } + } + return ctx +} + +// IsIpWhitelisted checks if the given ipString is within the whitelisted ip range +func IsIpWhitelisted(ipString string) bool { + ip := net.ParseIP(ipString) + + if ip == nil { + return false + } + + if ip.IsLoopback() { + return true + } + + for _, ipRange := range WorkerConfig.WhiteListedIPRanges { + if bytes.Compare(ip, ipRange.Lower) >= 0 && bytes.Compare(ip, ipRange.Upper) <= 0 { + return true + } + } + return false +} + // Write response body, transparently compressing if necessary. func WriteResponse(w http.ResponseWriter, r *http.Request, b []byte) (int, error) { var out io.Writer = w