diff --git a/dgraph/cmd/alpha/admin.go b/dgraph/cmd/alpha/admin.go index e608ff473a3..2c2482189d8 100644 --- a/dgraph/cmd/alpha/admin.go +++ b/dgraph/cmd/alpha/admin.go @@ -17,76 +17,100 @@ package alpha import ( - "bytes" - "context" + "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "strconv" - "github.com/dgraph-io/dgraph/posting" + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/dgraph/worker" "github.com/dgraph-io/dgraph/x" - "github.com/golang/glog" ) -// handlerInit does some standard checks. Returns false if something is wrong. -func handlerInit(w http.ResponseWriter, r *http.Request, allowedMethods map[string]bool) bool { - if _, ok := allowedMethods[r.Method]; !ok { - x.SetStatus(w, x.ErrorInvalidMethod, "Invalid method") - return false - } +type allowedMethods map[string]bool - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil || (!ipInIPWhitelistRanges(ip) && !net.ParseIP(ip).IsLoopback()) { - x.SetStatus(w, x.ErrorUnauthorized, fmt.Sprintf("Request from IP: %v", ip)) +// hasPoormansAuth checks if poorman's auth is required and if so whether the given http request has +// poorman's auth in it or not +func hasPoormansAuth(r *http.Request) bool { + if worker.Config.AuthToken != "" && worker.Config.AuthToken != r.Header.Get( + "X-Dgraph-AuthToken") { return false } return true } -func drainingHandler(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodPut, http.MethodPost: - enableStr := r.URL.Query().Get("enable") - - enable, err := strconv.ParseBool(enableStr) - if err != nil { - x.SetStatus(w, x.ErrorInvalidRequest, - "Found invalid value for the enable parameter") +func allowedMethodsHandler(allowedMethods allowedMethods, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := allowedMethods[r.Method]; !ok { + x.SetStatus(w, x.ErrorInvalidMethod, "Invalid method") + w.WriteHeader(http.StatusMethodNotAllowed) return } - x.UpdateDrainingMode(enable) - _, err = w.Write([]byte(fmt.Sprintf(`{"code": "Success",`+ - `"message": "draining mode has been set to %v"}`, enable))) - if err != nil { - glog.Errorf("Failed to write response: %v", err) + next.ServeHTTP(w, r) + }) +} + +// adminAuthHandler does some standard checks for admin endpoints. +// It returns if something is wrong. Otherwise, it lets the given handler serve the request. +func adminAuthHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !hasPoormansAuth(r) { + x.SetStatus(w, x.ErrorUnauthorized, "Invalid X-Dgraph-AuthToken") + return } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } + + next.ServeHTTP(w, r) + }) } -func shutDownHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodGet: true, - }) { +func drainingHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + enableStr := r.URL.Query().Get("enable") + + enable, err := strconv.ParseBool(enableStr) + if err != nil { + x.SetStatus(w, x.ErrorInvalidRequest, + "Found invalid value for the enable parameter") return } - close(worker.ShutdownCh) + gqlReq := &schema.Request{ + Query: ` + mutation draining($enable: Boolean) { + draining(enable: $enable) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"enable": enable}, + } + _ = resolveWithAdminServer(gqlReq, r, adminServer) w.Header().Set("Content-Type", "application/json") - x.Check2(w.Write([]byte(`{"code": "Success", "message": "Server is shutting down"}`))) + x.Check2(w.Write([]byte(fmt.Sprintf(`{"code": "Success",`+ + `"message": "draining mode has been set to %v"}`, enable)))) } -func exportHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodGet: true, - }) { - return +func shutDownHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + mutation { + shutdown { + response { + code + } + } + }`, } + _ = resolveWithAdminServer(gqlReq, r, adminServer) + w.Header().Set("Content-Type", "application/json") + x.Check2(w.Write([]byte(`{"code": "Success", "message": "Server is shutting down"}`))) +} + +func exportHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { if err := r.ParseForm(); err != nil { x.SetHttpStatus(w, http.StatusBadRequest, "Parse of export request failed.") return @@ -105,26 +129,37 @@ func exportHandler(w http.ResponseWriter, r *http.Request) { return } } - if err := worker.ExportOverNetwork(context.Background(), format); err != nil { - x.SetStatus(w, err.Error(), "Export failed.") + + gqlReq := &schema.Request{ + Query: ` + mutation export($format: String) { + export(input: {format: $format}) { + response { + code + } + } + }`, + Variables: map[string]interface{}{}, + } + resp := resolveWithAdminServer(gqlReq, r, adminServer) + if len(resp.Errors) != 0 { + x.SetStatus(w, resp.Errors[0].Message, "Export failed.") return } w.Header().Set("Content-Type", "application/json") x.Check2(w.Write([]byte(`{"code": "Success", "message": "Export completed."}`))) } -func memoryLimitHandler(w http.ResponseWriter, r *http.Request) { +func memoryLimitHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { switch r.Method { case http.MethodGet: - memoryLimitGetHandler(w, r) + memoryLimitGetHandler(w, r, adminServer) case http.MethodPut: - memoryLimitPutHandler(w, r) - default: - w.WriteHeader(http.StatusMethodNotAllowed) + memoryLimitPutHandler(w, r, adminServer) } } -func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request) { +func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { body, err := ioutil.ReadAll(r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -135,36 +170,45 @@ func memoryLimitPutHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + gqlReq := &schema.Request{ + Query: ` + mutation config($lruMb: Float) { + config(input: {lruMb: $lruMb}) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"lruMb": memoryMB}, + } + resp := resolveWithAdminServer(gqlReq, r, adminServer) - if err := worker.UpdateLruMb(memoryMB); err != nil { + if len(resp.Errors) != 0 { w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, err.Error()) + x.Check2(fmt.Fprint(w, resp.Errors[0].Message)) return } w.WriteHeader(http.StatusOK) } -func memoryLimitGetHandler(w http.ResponseWriter, r *http.Request) { - posting.Config.Lock() - memoryMB := posting.Config.AllottedMemory - posting.Config.Unlock() - - if _, err := fmt.Fprintln(w, memoryMB); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) +func memoryLimitGetHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + query { + config { + lruMb + } + }`, } -} - -func ipInIPWhitelistRanges(ipString string) bool { - ip := net.ParseIP(ipString) - - if ip == nil { - return false + resp := resolveWithAdminServer(gqlReq, r, adminServer) + var data struct { + Config struct { + LruMb float64 + } } + x.Check(json.Unmarshal(resp.Data.Bytes(), &data)) - for _, ipRange := range x.WorkerConfig.WhiteListedIPRanges { - if bytes.Compare(ip, ipRange.Lower) >= 0 && bytes.Compare(ip, ipRange.Upper) <= 0 { - return true - } + if _, err := fmt.Fprintln(w, data.Config.LruMb); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) } - return false } diff --git a/dgraph/cmd/alpha/admin_backup.go b/dgraph/cmd/alpha/admin_backup.go index 8919e0940da..9644b519d28 100644 --- a/dgraph/cmd/alpha/admin_backup.go +++ b/dgraph/cmd/alpha/admin_backup.go @@ -19,43 +19,45 @@ package alpha import ( - "context" "net/http" - "github.com/dgraph-io/dgraph/protos/pb" - "github.com/dgraph-io/dgraph/worker" + "github.com/dgraph-io/dgraph/graphql/schema" + + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/dgraph/x" ) func init() { - http.HandleFunc("/admin/backup", backupHandler) + http.Handle("/admin/backup", allowedMethodsHandler(allowedMethods{http.MethodPost: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backupHandler(w, r, adminServer) + })))) } // backupHandler handles backup requests coming from the HTTP endpoint. -func backupHandler(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodPost: true, - }) { - return +func backupHandler(w http.ResponseWriter, r *http.Request, adminServer web.IServeGraphQL) { + gqlReq := &schema.Request{ + Query: ` + mutation backup($input: BackupInput!) { + backup(input: $input) { + response { + code + } + } + }`, + Variables: map[string]interface{}{"input": map[string]interface{}{ + "destination": r.FormValue("destination"), + "accessKey": r.FormValue("access_key"), + "secretKey": r.FormValue("secret_key"), + "sessionToken": r.FormValue("session_token"), + "anonymous": r.FormValue("anonymous") == "true", + "forceFull": r.FormValue("force_full") == "true", + }}, } - - destination := r.FormValue("destination") - accessKey := r.FormValue("access_key") - secretKey := r.FormValue("secret_key") - sessionToken := r.FormValue("session_token") - anonymous := r.FormValue("anonymous") == "true" - forceFull := r.FormValue("force_full") == "true" - - req := pb.BackupRequest{ - Destination: destination, - AccessKey: accessKey, - SecretKey: secretKey, - SessionToken: sessionToken, - Anonymous: anonymous, - } - - if err := worker.ProcessBackupRequest(context.Background(), &req, forceFull); err != nil { - x.SetStatus(w, err.Error(), "Backup failed.") + resp := resolveWithAdminServer(gqlReq, r, adminServer) + if resp.Errors != nil { + x.SetStatus(w, resp.Errors.Error(), "Backup failed.") return } diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index ceeaf88cd9f..1419be98eb1 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -586,12 +586,8 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. return } - md := metadata.New(nil) - ctx := metadata.NewIncomingContext(context.Background(), md) - ctx = x.AttachAccessJwt(ctx, r) - - gqlReq := &schema.Request{} - gqlReq.Query = ` + gqlReq := &schema.Request{ + Query: ` mutation updateGqlSchema($sch: String!) { updateGQLSchema(input: { set: { @@ -602,12 +598,11 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. id } } - }` - gqlReq.Variables = map[string]interface{}{ - "sch": string(b), + }`, + Variables: map[string]interface{}{"sch": string(b)}, } - response := adminServer.Resolve(ctx, gqlReq) + response := resolveWithAdminServer(gqlReq, r, adminServer) if len(response.Errors) > 0 { x.SetStatus(w, x.Error, response.Errors.Error()) return @@ -616,6 +611,16 @@ func adminSchemaHandler(w http.ResponseWriter, r *http.Request, adminServer web. writeSuccessResponse(w, r) } +func resolveWithAdminServer(gqlReq *schema.Request, r *http.Request, + adminServer web.IServeGraphQL) *schema.Response { + md := metadata.New(nil) + ctx := metadata.NewIncomingContext(context.Background(), md) + ctx = x.AttachAccessJwt(ctx, r) + ctx = x.AttachRemoteIP(ctx, r) + + return adminServer.Resolve(ctx, gqlReq) +} + func writeSuccessResponse(w http.ResponseWriter, r *http.Request) { res := map[string]interface{}{} data := map[string]interface{}{} diff --git a/dgraph/cmd/alpha/http_test.go b/dgraph/cmd/alpha/http_test.go index 041dde4838e..b39446a05d3 100644 --- a/dgraph/cmd/alpha/http_test.go +++ b/dgraph/cmd/alpha/http_test.go @@ -735,7 +735,7 @@ func TestHealth(t *testing.T) { require.True(t, info[0].Uptime > int64(time.Duration(1))) } -func setDrainingMode(t *testing.T, enable bool) { +func setDrainingMode(t *testing.T, enable bool, accessJwt string) { drainingRequest := `mutation drain($enable: Boolean) { draining(enable: $enable) { response { @@ -743,22 +743,13 @@ func setDrainingMode(t *testing.T, enable bool) { } } }` - adminUrl := fmt.Sprintf("%s/admin", addr) - params := testutil.GraphQLParams{ + params := &testutil.GraphQLParams{ Query: drainingRequest, Variables: map[string]interface{}{"enable": enable}, } - b, err := json.Marshal(params) - require.NoError(t, err) - - resp, err := http.Post(adminUrl, "application/json", bytes.NewBuffer(b)) - require.NoError(t, err) - - defer resp.Body.Close() - b, err = ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.JSONEq(t, `{"data":{"draining":{"response":{"code":"Success"}}}}`, - string(b)) + resp := testutil.MakeGQLRequestWithAccessJwt(t, params, accessJwt) + resp.RequireNoGraphQLErrors(t) + require.JSONEq(t, `{"draining":{"response":{"code":"Success"}}}`, string(resp.Data)) } func TestDrainingMode(t *testing.T) { @@ -800,10 +791,12 @@ func TestDrainingMode(t *testing.T) { } - setDrainingMode(t, true) + grootJwt, _ := testutil.GrootHttpLogin(addr + "/admin") + + setDrainingMode(t, true, grootJwt) runRequests(true) - setDrainingMode(t, false) + setDrainingMode(t, false, grootJwt) runRequests(false) } diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index 0411cf2ceba..1c87ff1ccd8 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -35,6 +35,8 @@ import ( "syscall" "time" + "github.com/dgraph-io/dgraph/graphql/web" + "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgo/v200/protos/api" "github.com/dgraph-io/dgraph/edgraph" @@ -76,6 +78,9 @@ var ( // Alpha is the sub-command invoked when running "dgraph alpha". Alpha x.SubCommand + + // need this here to refer it in admin_backup.go + adminServer web.IServeGraphQL ) func init() { @@ -453,11 +458,6 @@ func setupServer(closer *y.Closer) { // TODO: Figure out what this is for? http.HandleFunc("/debug/store", storeStatsHandler) - http.HandleFunc("/admin/shutdown", shutDownHandler) - http.HandleFunc("/admin/draining", drainingHandler) - http.HandleFunc("/admin/export", exportHandler) - http.HandleFunc("/admin/config/lru_mb", memoryLimitHandler) - introspection := Alpha.Conf.GetBool("graphql_introspection") // Global Epoch is a lockless synchronization mechanism for graphql service. @@ -474,25 +474,43 @@ func setupServer(closer *y.Closer) { // The global epoch is set to maxUint64 while exiting the server. // By using this information polling goroutine terminates the subscription. globalEpoch := uint64(0) - mainServer, adminServer := admin.NewServers(introspection, &globalEpoch, closer) + var mainServer web.IServeGraphQL + mainServer, adminServer = admin.NewServers(introspection, &globalEpoch, closer) http.Handle("/graphql", mainServer.HTTPHandler()) - - whitelist := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !handlerInit(w, r, map[string]bool{ - http.MethodPost: true, - http.MethodGet: true, - http.MethodOptions: true, - }) { - return - } - h.ServeHTTP(w, r) - }) - } - http.Handle("/admin", whitelist(adminServer.HTTPHandler())) - http.HandleFunc("/admin/schema", func(w http.ResponseWriter, r *http.Request) { + http.Handle("/admin", allowedMethodsHandler(allowedMethods{ + http.MethodGet: true, + http.MethodPost: true, + http.MethodOptions: true, + }, adminAuthHandler(adminServer.HTTPHandler()))) + + http.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, + r *http.Request) { adminSchemaHandler(w, r, adminServer) - }) + }))) + + http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + shutDownHandler(w, r, adminServer) + })))) + + http.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{ + http.MethodPut: true, + http.MethodPost: true, + }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + drainingHandler(w, r, adminServer) + })))) + + http.Handle("/admin/export", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, + adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + exportHandler(w, r, adminServer) + })))) + + http.Handle("/admin/config/lru_mb", allowedMethodsHandler(allowedMethods{ + http.MethodGet: true, + http.MethodPut: true, + }, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + memoryLimitHandler(w, r, adminServer) + })))) addr := fmt.Sprintf("%s:%d", laddr, httpPort()) glog.Infof("Bringing up GraphQL HTTP API at %s/graphql", addr) diff --git a/edgraph/access.go b/edgraph/access.go index ff0a3857642..d8ffc434db3 100644 --- a/edgraph/access.go +++ b/edgraph/access.go @@ -65,7 +65,7 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er return nil } -func authorizeGuardians(ctx context.Context) error { +func AuthorizeGuardians(ctx context.Context) error { // always allow access return nil } diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go index 568a1500b17..826e03e8f87 100644 --- a/edgraph/access_ee.go +++ b/edgraph/access_ee.go @@ -794,8 +794,8 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er return nil } -// authorizeGuardians authorizes the operation for users which belong to Guardians group. -func authorizeGuardians(ctx context.Context) error { +// AuthorizeGuardians authorizes the operation for users which belong to Guardians group. +func AuthorizeGuardians(ctx context.Context) error { if len(worker.Config.HmacSecret) == 0 { // the user has not turned on the acl feature return nil diff --git a/edgraph/server.go b/edgraph/server.go index e7a3aaf5f50..6fa4703d7d2 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -699,7 +699,7 @@ func (s *Server) Health(ctx context.Context, all bool) (*api.Response, error) { var healthAll []pb.HealthInfo if all { - if err := authorizeGuardians(ctx); err != nil { + if err := AuthorizeGuardians(ctx); err != nil { return nil, err } pool := conn.GetPools().GetAll() @@ -738,7 +738,7 @@ func (s *Server) State(ctx context.Context) (*api.Response, error) { return nil, ctx.Err() } - if err := authorizeGuardians(ctx); err != nil { + if err := AuthorizeGuardians(ctx); err != nil { return nil, err } diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go index fcf2c07c5b7..9ff6da99738 100644 --- a/ee/acl/acl_test.go +++ b/ee/acl/acl_test.go @@ -13,13 +13,10 @@ package acl import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io/ioutil" - "net/http" "strconv" "testing" "time" @@ -631,28 +628,7 @@ type group struct { func makeRequest(t *testing.T, accessToken string, params testutil.GraphQLParams) *testutil. GraphQLResponse { - adminUrl := "http://" + testutil.SockAddrHttp + "/admin" - - b, err := json.Marshal(params) - require.NoError(t, err) - - req, err := http.NewRequest(http.MethodPost, adminUrl, bytes.NewBuffer(b)) - require.NoError(t, err) - req.Header.Set("X-Dgraph-AccessToken", accessToken) - req.Header.Set("Content-Type", "application/json") - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - - defer resp.Body.Close() - b, err = ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - var result *testutil.GraphQLResponse - err = json.Unmarshal(b, &result) - require.NoError(t, err) - - return result + return testutil.MakeGQLRequestWithAccessJwt(t, ¶ms, accessToken) } func addRulesToGroup(t *testing.T, accessToken, group string, rules []rule) { @@ -1750,49 +1726,28 @@ func TestWrongPermission(t *testing.T) { } func TestHealthForAcl(t *testing.T) { - resetUser(t) - - gqlQuery := ` - query { - health { - instance - address - lastEcho - status - version - uptime - group - } - }` - params := testutil.GraphQLParams{ - Query: gqlQuery, + Query: ` + query { + health { + instance + address + lastEcho + status + version + uptime + group + } + }`, } // assert errors for non-guardians - accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{ - Endpoint: adminEndpoint, - UserID: userid, - Passwd: userpassword, - }) - require.NoError(t, err, "login failed") - - resp := makeRequest(t, accessJwt, params) - expectedError := fmt.Sprintf("Error: rpc error: code"+ - " = PermissionDenied desc = Only guardians are allowed access. "+ - "User '%s' is not a member of guardians group.", userid) - require.Equal(t, x.GqlErrorList{{Message: expectedError}}, resp.Errors) - require.JSONEq(t, `{ "health": [] }`, string(resp.Data)) + assertNonGuardianFailure(t, "health", false, params) // assert data for guardians - accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{ - Endpoint: adminEndpoint, - UserID: "groot", - Passwd: "password", - }) - require.NoError(t, err, "groot login failed") + accessJwt, _ := testutil.GrootHttpLogin(adminEndpoint) - resp = makeRequest(t, accessJwt, params) + resp := makeRequest(t, accessJwt, params) resp.RequireNoGraphQLErrors(t) var guardianResp struct { Health []struct { @@ -1805,7 +1760,7 @@ func TestHealthForAcl(t *testing.T) { Group string } } - err = json.Unmarshal(resp.Data, &guardianResp) + err := json.Unmarshal(resp.Data, &guardianResp) require.NoError(t, err, "health request failed") // we have 9 instances of alphas/zeros in teamcity environment @@ -1821,6 +1776,247 @@ func TestHealthForAcl(t *testing.T) { } } +func assertNonGuardianFailure(t *testing.T, queryName string, respIsNull bool, + params testutil.GraphQLParams) { + resetUser(t) + + accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: adminEndpoint, + UserID: userid, + Passwd: userpassword, + }) + require.NoError(t, err, "login failed") + resp := makeRequest(t, accessJwt, params) + + require.Len(t, resp.Errors, 1) + require.Contains(t, resp.Errors[0].Message, + fmt.Sprintf("rpc error: code = PermissionDenied desc = Only guardians are allowed access."+ + " User '%s' is not a member of guardians group.", userid)) + if len(resp.Data) != 0 { + queryVal := "null" + if !respIsNull { + queryVal = "[]" + } + require.JSONEq(t, fmt.Sprintf(`{"%s": %s}`, queryName, queryVal), string(resp.Data)) + } +} + +type graphQLAdminEndpointTestCase struct { + name string + query string + queryName string + respIsArray bool + testGuardianAccess bool + guardianErrs x.GqlErrorList + // specifying this as empty string means it won't be compared with response data + guardianData string +} + +func TestGuardianOnlyAccessForAdminEndpoints(t *testing.T) { + tcases := []graphQLAdminEndpointTestCase{ + { + name: "backup has guardian auth", + query: ` + mutation { + backup(input: {destination: ""}) { + response { + code + message + } + } + }`, + queryName: "backup", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving backup failed because you must specify a 'destination' value", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"backup": null}`, + }, + { + name: "listBackups has guardian auth", + query: ` + query { + listBackups(input: {location: ""}) { + backupId + } + }`, + queryName: "listBackups", + respIsArray: true, + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving listBackups failed because Error: cannot read manfiests at " + + "location : The path \"\" does not exist or it is inaccessible.", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"listBackups": []}`, + }, + { + name: "config update has guardian auth", + query: ` + mutation { + config(input: {lruMb: 1}) { + response { + code + message + } + } + }`, + queryName: "config", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving config failed because lru_mb must be at least 1024\n", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"config": null}`, + }, + { + name: "config get has guardian auth", + query: ` + query { + config { + lruMb + } + }`, + queryName: "config", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: "", + }, + { + name: "draining has guardian auth", + query: ` + mutation { + draining(enable: false) { + response { + code + message + } + } + }`, + queryName: "draining", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: `{ + "draining": { + "response": { + "code": "Success", + "message": "draining mode has been set to false" + } + } + }`, + }, + { + name: "export has guardian auth", + query: ` + mutation { + export(input: {format: "invalid"}) { + response { + code + message + } + } + }`, + queryName: "export", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving export failed because invalid export format: invalid", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"export": null}`, + }, + { + name: "restore has guardian auth", + query: ` + mutation { + restore(input: {location: "", backupId: "", keyFile: ""}) { + response { + code + message + } + } + }`, + queryName: "restore", + testGuardianAccess: true, + guardianErrs: x.GqlErrorList{{ + Message: "resolving restore failed because failed to verify backup: while retrieving" + + " manifests: The path \"\" does not exist or it is inaccessible.", + Locations: []x.Location{{Line: 3, Column: 8}}, + }}, + guardianData: `{"restore": null}`, + }, + { + name: "getGQLSchema has guardian auth", + query: ` + query { + getGQLSchema { + id + } + }`, + queryName: "getGQLSchema", + testGuardianAccess: true, + guardianErrs: nil, + guardianData: "", + }, + { + name: "updateGQLSchema has guardian auth", + query: ` + mutation { + updateGQLSchema(input: {set: {schema: ""}}) { + gqlSchema { + id + } + } + }`, + queryName: "updateGQLSchema", + testGuardianAccess: false, + guardianErrs: nil, + guardianData: "", + }, + { + name: "shutdown has guardian auth", + query: ` + mutation { + shutdown { + response { + code + message + } + } + }`, + queryName: "shutdown", + testGuardianAccess: false, + guardianErrs: nil, + guardianData: "", + }, + } + + for _, tcase := range tcases { + t.Run(tcase.name, func(t *testing.T) { + params := testutil.GraphQLParams{Query: tcase.query} + + // assert ACL error for non-guardians + assertNonGuardianFailure(t, tcase.queryName, !tcase.respIsArray, params) + + // for guardians, assert non-ACL error or success + if tcase.testGuardianAccess { + accessJwt, _ := testutil.GrootHttpLogin(adminEndpoint) + resp := makeRequest(t, accessJwt, params) + + if tcase.guardianErrs == nil { + resp.RequireNoGraphQLErrors(t) + } else { + require.Equal(t, tcase.guardianErrs, resp.Errors) + } + + if tcase.guardianData != "" { + require.JSONEq(t, tcase.guardianData, string(resp.Data)) + } + } + }) + } +} + func TestAddUpdateGroupWithDuplicateRules(t *testing.T) { groupName := "testGroup" addedRules := []rule{ diff --git a/graphql/admin/admin.go b/graphql/admin/admin.go index d1963bf65b6..dfe8afa4933 100644 --- a/graphql/admin/admin.go +++ b/graphql/admin/admin.go @@ -236,12 +236,17 @@ const ( response: Response } + type Config { + lruMb: Float + } + ` + adminTypes + ` type Query { getGQLSchema: GQLSchema health: [NodeState] state: MembershipState + config: Config ` + adminQueries + ` } @@ -281,6 +286,55 @@ const ( ` ) +var ( + // commonAdminQueryMWs are the middlewares which should be applied to queries served by admin + // server unless some exceptional behaviour is required + commonAdminQueryMWs = resolve.QueryMiddlewares{ + resolve.IpWhitelistingMW4Query, // good to apply ip whitelisting before Guardian auth + resolve.GuardianAuthMW4Query, + } + // commonAdminMutationMWs are the middlewares which should be applied to mutations served by + // admin server unless some exceptional behaviour is required + commonAdminMutationMWs = resolve.MutationMiddlewares{ + resolve.IpWhitelistingMW4Mutation, // good to apply ip whitelisting before Guardian auth + resolve.GuardianAuthMW4Mutation, + } + adminQueryMWConfig = map[string]resolve.QueryMiddlewares{ + "health": {resolve.IpWhitelistingMW4Query}, // dgraph handles Guardian auth for health + "state": {resolve.IpWhitelistingMW4Query}, // dgraph handles Guardian auth for state + "config": commonAdminQueryMWs, + "listBackups": commonAdminQueryMWs, + // not applying ip whitelisting to keep it in sync with /alter + "getGQLSchema": {resolve.GuardianAuthMW4Query}, + // for queries and mutations related to User/Group, dgraph handles Guardian auth, + // so no need to apply GuardianAuth Middleware + "queryGroup": {resolve.IpWhitelistingMW4Query}, + "queryUser": {resolve.IpWhitelistingMW4Query}, + "getGroup": {resolve.IpWhitelistingMW4Query}, + "getCurrentUser": {resolve.IpWhitelistingMW4Query}, + "getUser": {resolve.IpWhitelistingMW4Query}, + } + adminMutationMWConfig = map[string]resolve.MutationMiddlewares{ + "backup": commonAdminMutationMWs, + "config": commonAdminMutationMWs, + "draining": commonAdminMutationMWs, + "export": commonAdminMutationMWs, + "login": {resolve.IpWhitelistingMW4Mutation}, + "restore": commonAdminMutationMWs, + "shutdown": commonAdminMutationMWs, + // not applying ip whitelisting to keep it in sync with /alter + "updateGQLSchema": {resolve.GuardianAuthMW4Mutation}, + // for queries and mutations related to User/Group, dgraph handles Guardian auth, + // so no need to apply GuardianAuth Middleware + "addUser": {resolve.IpWhitelistingMW4Mutation}, + "addGroup": {resolve.IpWhitelistingMW4Mutation}, + "updateUser": {resolve.IpWhitelistingMW4Mutation}, + "updateGroup": {resolve.IpWhitelistingMW4Mutation}, + "deleteUser": {resolve.IpWhitelistingMW4Mutation}, + "deleteGroup": {resolve.IpWhitelistingMW4Mutation}, + } +) + type gqlSchema struct { ID string `json:"id,omitempty"` Schema string `json:"schema,omitempty"` @@ -416,7 +470,7 @@ func newAdminResolverFactory() resolve.ResolverFactory { adminMutationResolvers := map[string]resolve.MutationResolverFunc{ "backup": resolveBackup, - "config": resolveConfig, + "config": resolveUpdateConfig, "draining": resolveDraining, "export": resolveExport, "login": resolveLogin, @@ -425,12 +479,20 @@ func newAdminResolverFactory() resolve.ResolverFactory { } rf := resolverFactoryWithErrorMsg(errResolverNotFound). + WithQueryMiddlewareConfig(adminQueryMWConfig). + WithMutationMiddlewareConfig(adminMutationMWConfig). WithQueryResolver("health", func(q schema.Query) resolve.QueryResolver { return resolve.QueryResolverFunc(resolveHealth) }). WithQueryResolver("state", func(q schema.Query) resolve.QueryResolver { return resolve.QueryResolverFunc(resolveState) }). + WithQueryResolver("config", func(q schema.Query) resolve.QueryResolver { + return resolve.QueryResolverFunc(resolveGetConfig) + }). + WithQueryResolver("listBackups", func(q schema.Query) resolve.QueryResolver { + return resolve.QueryResolverFunc(resolveListBackups) + }). WithMutationResolver("updateGQLSchema", func(m schema.Mutation) resolve.MutationResolver { return resolve.MutationResolverFunc( func(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { @@ -448,18 +510,13 @@ func newAdminResolverFactory() resolve.ResolverFactory { for gqlMut, resolver := range adminMutationResolvers { // gotta force go to evaluate the right function at each loop iteration // otherwise you get variable capture issues - func(f resolve.MutationResolverFunc) { + func(f resolve.MutationResolver) { rf.WithMutationResolver(gqlMut, func(m schema.Mutation) resolve.MutationResolver { return f }) }(resolver) } - // Add admin query endpoints. - rf = rf.WithQueryResolver("listBackups", func(q schema.Query) resolve.QueryResolver { - return resolve.QueryResolverFunc(resolveListBackups) - }) - return rf.WithSchemaIntrospection() } @@ -766,11 +823,3 @@ func response(code, msg string) map[string]interface{} { return map[string]interface{}{ "response": map[string]interface{}{"code": code, "message": msg}} } - -func emptyResult(f schema.Field, err error) *resolve.Resolved { - return &resolve.Resolved{ - Data: map[string]interface{}{f.Name(): nil}, - Field: f, - Err: err, - } -} diff --git a/graphql/admin/backup.go b/graphql/admin/backup.go index 36d4e9b6f5c..c2238dfef5d 100644 --- a/graphql/admin/backup.go +++ b/graphql/admin/backup.go @@ -41,7 +41,7 @@ func resolveBackup(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b input, err := getBackupInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } err = worker.ProcessBackupRequest(context.Background(), &pb.BackupRequest{ @@ -53,7 +53,7 @@ func resolveBackup(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b }, input.ForceFull) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/config.go b/graphql/admin/config.go index 141ecf2f128..d570f98cb4d 100644 --- a/graphql/admin/config.go +++ b/graphql/admin/config.go @@ -22,6 +22,7 @@ import ( "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/worker" "github.com/golang/glog" ) @@ -34,17 +35,17 @@ type configInput struct { LogRequest *bool } -func resolveConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { - glog.Info("Got config request through GraphQL admin API") +func resolveUpdateConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bool) { + glog.Info("Got config update through GraphQL admin API") input, err := getConfigInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } if input.LruMB > 0 { if err = worker.UpdateLruMb(input.LruMB); err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } } @@ -59,6 +60,21 @@ func resolveConfig(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b }, true } +func resolveGetConfig(ctx context.Context, q schema.Query) *resolve.Resolved { + glog.Info("Got config query through GraphQL admin API") + + conf := make(map[string]interface{}) + posting.Config.Lock() + conf["lruMb"] = posting.Config.AllottedMemory + posting.Config.Unlock() + + return &resolve.Resolved{ + Data: map[string]interface{}{q.Name(): conf}, + Field: q, + } + +} + func getConfigInput(m schema.Mutation) (*configInput, error) { inputArg := m.ArgValue(schema.InputArgName) inputByts, err := json.Marshal(inputArg) diff --git a/graphql/admin/export.go b/graphql/admin/export.go index cfb19686027..85bff5f274f 100644 --- a/graphql/admin/export.go +++ b/graphql/admin/export.go @@ -37,20 +37,20 @@ func resolveExport(ctx context.Context, m schema.Mutation) (*resolve.Resolved, b input, err := getExportInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } format := worker.DefaultExportFormat if input.Format != "" { format = worker.NormalizeExportFormat(input.Format) if format == "" { - return emptyResult(m, errors.Errorf("invalid export format: %v", input.Format)), false + return resolve.EmptyResult(m, errors.Errorf("invalid export format: %v", input.Format)), false } } err = worker.ExportOverNetwork(context.Background(), format) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/health.go b/graphql/admin/health.go index 9fc8871d49c..30cf6d2b845 100644 --- a/graphql/admin/health.go +++ b/graphql/admin/health.go @@ -33,7 +33,7 @@ func resolveHealth(ctx context.Context, q schema.Query) *resolve.Resolved { resp, err := (&edgraph.Server{}).Health(ctx, true) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } var health []map[string]interface{} diff --git a/graphql/admin/list_backups.go b/graphql/admin/list_backups.go index 2a81ba03e1f..e550e27a491 100644 --- a/graphql/admin/list_backups.go +++ b/graphql/admin/list_backups.go @@ -54,7 +54,7 @@ type manifest struct { func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { input, err := getLsBackupInput(q) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } creds := &worker.Credentials{ @@ -65,7 +65,7 @@ func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { } manifests, err := worker.ProcessListBackups(ctx, input.Location, creds) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } convertedManifests := convertManifests(manifests) @@ -73,12 +73,12 @@ func resolveListBackups(ctx context.Context, q schema.Query) *resolve.Resolved { for _, m := range convertedManifests { b, err := json.Marshal(m) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } var result map[string]interface{} err = json.Unmarshal(b, &result) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } results = append(results, result) } diff --git a/graphql/admin/login.go b/graphql/admin/login.go index bdc038ca863..f9c8967d160 100644 --- a/graphql/admin/login.go +++ b/graphql/admin/login.go @@ -42,12 +42,12 @@ func resolveLogin(ctx context.Context, m schema.Mutation) (*resolve.Resolved, bo RefreshToken: input.RefreshToken, }) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } jwt := &dgoapi.Jwt{} if err := jwt.Unmarshal(resp.GetJson()); err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/restore.go b/graphql/admin/restore.go index 6f478577e6e..881a28e6bf7 100644 --- a/graphql/admin/restore.go +++ b/graphql/admin/restore.go @@ -40,7 +40,7 @@ func resolveRestore(ctx context.Context, m schema.Mutation) (*resolve.Resolved, input, err := getRestoreInput(m) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } req := pb.RestoreRequest{ @@ -54,7 +54,7 @@ func resolveRestore(ctx context.Context, m schema.Mutation) (*resolve.Resolved, } err = worker.ProcessRestoreRequest(context.Background(), &req) if err != nil { - return emptyResult(m, err), false + return resolve.EmptyResult(m, err), false } return &resolve.Resolved{ diff --git a/graphql/admin/state.go b/graphql/admin/state.go index 06efb3cb4a5..5a2597990ec 100644 --- a/graphql/admin/state.go +++ b/graphql/admin/state.go @@ -37,7 +37,7 @@ type clusterGroup struct { func resolveState(ctx context.Context, q schema.Query) *resolve.Resolved { resp, err := (&edgraph.Server{}).State(ctx) if err != nil { - return emptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) + return resolve.EmptyResult(q, errors.Errorf("%s: %s", x.Error, err.Error())) } // unmarshal it back to MembershipState proto in order to map to graphql response @@ -46,19 +46,19 @@ func resolveState(ctx context.Context, q schema.Query) *resolve.Resolved { err = u.Unmarshal(bytes.NewReader(resp.GetJson()), &ms) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } // map to graphql response structure state := convertToGraphQLResp(ms) b, err := json.Marshal(state) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } var resultState map[string]interface{} err = json.Unmarshal(b, &resultState) if err != nil { - return emptyResult(q, err) + return resolve.EmptyResult(q, err) } return &resolve.Resolved{ diff --git a/graphql/resolve/middlewares.go b/graphql/resolve/middlewares.go new file mode 100644 index 00000000000..82af2112a1d --- /dev/null +++ b/graphql/resolve/middlewares.go @@ -0,0 +1,178 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "net" + + "github.com/pkg/errors" + + "google.golang.org/grpc/peer" + + "github.com/dgraph-io/dgraph/edgraph" + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/dgraph-io/dgraph/x" +) + +// QueryMiddleware represents a middleware for queries +type QueryMiddleware func(resolver QueryResolver) QueryResolver + +// MutationMiddleware represents a middleware for mutations +type MutationMiddleware func(resolver MutationResolver) MutationResolver + +// QueryMiddlewares represents a list of middlewares for queries, that get applied in the order +// they are present in the list. +// Inspired from: https://github.com/justinas/alice +type QueryMiddlewares []QueryMiddleware + +// MutationMiddlewares represents a list of middlewares for mutations, that get applied in the order +// they are present in the list. +// Inspired from: https://github.com/justinas/alice +type MutationMiddlewares []MutationMiddleware + +// Then chains the middlewares and returns the final QueryResolver. +// QueryMiddlewares{m1, m2, m3}.Then(r) +// is equivalent to: +// m1(m2(m3(r))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given resolverFunc +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// commonMiddlewares := QueryMiddlewares{authMiddleware, loggingMiddleware} +// healthResolver = commonMiddlewares.Then(resolveHealth) +// stateResolver = commonMiddlewares.Then(resolveState) +// Note that middlewares are called on every call to Then() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as a QueryResolverFunc that resolves to &Resolved{Field: query} +func (mws QueryMiddlewares) Then(resolver QueryResolver) QueryResolver { + if len(mws) == 0 { + return resolver + } + if resolver == nil { + resolver = QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + return &Resolved{Field: query} + }) + } + for i := len(mws) - 1; i >= 0; i-- { + resolver = mws[i](resolver) + } + return resolver +} + +// Then chains the middlewares and returns the final MutationResolver. +// MutationMiddlewares{m1, m2, m3}.Then(r) +// is equivalent to: +// m1(m2(m3(r))) +// When the request comes in, it will be passed to m1, then m2, then m3 +// and finally, the given resolverFunc +// (assuming every middleware calls the following one). +// +// A chain can be safely reused by calling Then() several times. +// commonMiddlewares := MutationMiddlewares{authMiddleware, loggingMiddleware} +// backupResolver = commonMiddlewares.Then(resolveBackup) +// configResolver = commonMiddlewares.Then(resolveConfig) +// Note that middlewares are called on every call to Then() +// and thus several instances of the same middleware will be created +// when a chain is reused in this way. +// For proper middleware, this should cause no problems. +// +// Then() treats nil as a MutationResolverFunc that resolves to (&Resolved{Field: mutation}, true) +func (mws MutationMiddlewares) Then(resolver MutationResolver) MutationResolver { + if len(mws) == 0 { + return resolver + } + if resolver == nil { + resolver = MutationResolverFunc(func(ctx context.Context, + mutation schema.Mutation) (*Resolved, bool) { + return &Resolved{Field: mutation}, true + }) + } + for i := len(mws) - 1; i >= 0; i-- { + resolver = mws[i](resolver) + } + return resolver +} + +// resolveGuardianAuth returns a Resolved with error if the context doesn't contain any Guardian auth, +// otherwise it returns nil +func resolveGuardianAuth(ctx context.Context, f schema.Field) *Resolved { + if err := edgraph.AuthorizeGuardians(ctx); err != nil { + return EmptyResult(f, err) + } + return nil +} + +func resolveIpWhitelisting(ctx context.Context, f schema.Field) *Resolved { + peerInfo, ok := peer.FromContext(ctx) + if !ok { + return EmptyResult(f, errors.New("unable to find source ip")) + } + ip, _, err := net.SplitHostPort(peerInfo.Addr.String()) + if err != nil { + return EmptyResult(f, err) + } + if !x.IsIpWhitelisted(ip) { + return EmptyResult(f, errors.Errorf("unauthorized ip address: %s", ip)) + } + return nil +} + +// GuardianAuthMW4Query blocks the resolution of resolverFunc if there is no Guardian auth +// present in context, otherwise it lets the resolverFunc resolve the query. +func GuardianAuthMW4Query(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + if resolved := resolveGuardianAuth(ctx, query); resolved != nil { + return resolved + } + return resolver.Resolve(ctx, query) + }) +} + +func IpWhitelistingMW4Query(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + if resolved := resolveIpWhitelisting(ctx, query); resolved != nil { + return resolved + } + return resolver.Resolve(ctx, query) + }) +} + +// GuardianAuthMW4Mutation blocks the resolution of resolverFunc if there is no Guardian auth +// present in context, otherwise it lets the resolverFunc resolve the mutation. +func GuardianAuthMW4Mutation(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + if resolved := resolveGuardianAuth(ctx, mutation); resolved != nil { + return resolved, false + } + return resolver.Resolve(ctx, mutation) + }) +} + +func IpWhitelistingMW4Mutation(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, + bool) { + if resolved := resolveIpWhitelisting(ctx, mutation); resolved != nil { + return resolved, false + } + return resolver.Resolve(ctx, mutation) + }) +} diff --git a/graphql/resolve/middlewares_test.go b/graphql/resolve/middlewares_test.go new file mode 100644 index 00000000000..2eaa5c18368 --- /dev/null +++ b/graphql/resolve/middlewares_test.go @@ -0,0 +1,97 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package resolve + +import ( + "context" + "testing" + + "github.com/dgraph-io/dgraph/graphql/schema" + "github.com/stretchr/testify/require" +) + +func TestQueryMiddlewares_Then_ExecutesMiddlewaresInOrder(t *testing.T) { + array := make([]int, 0) + addToArray := func(num int) { + array = append(array, num) + } + m1 := QueryMiddleware(func(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(1) + defer addToArray(5) + return resolver.Resolve(ctx, query) + }) + }) + m2 := QueryMiddleware(func(resolver QueryResolver) QueryResolver { + return QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(2) + resolved := resolver.Resolve(ctx, query) + addToArray(4) + return resolved + }) + }) + mws := QueryMiddlewares{m1, m2} + + resolver := mws.Then(QueryResolverFunc(func(ctx context.Context, query schema.Query) *Resolved { + addToArray(3) + return &Resolved{ + Field: query, + Extensions: &schema.Extensions{TouchedUids: 1}, + } + })) + resolved := resolver.Resolve(context.Background(), nil) + + require.Equal(t, &Resolved{Extensions: &schema.Extensions{TouchedUids: 1}}, resolved) + require.Equal(t, []int{1, 2, 3, 4, 5}, array) +} + +func TestMutationMiddlewares_Then_ExecutesMiddlewaresInOrder(t *testing.T) { + array := make([]int, 0) + addToArray := func(num int) { + array = append(array, num) + } + m1 := MutationMiddleware(func(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + addToArray(1) + defer addToArray(5) + return resolver.Resolve(ctx, mutation) + }) + }) + m2 := MutationMiddleware(func(resolver MutationResolver) MutationResolver { + return MutationResolverFunc(func(ctx context.Context, + mutation schema.Mutation) (*Resolved, bool) { + addToArray(2) + resolved, success := resolver.Resolve(ctx, mutation) + addToArray(4) + return resolved, success + }) + }) + mws := MutationMiddlewares{m1, m2} + + resolver := mws.Then(MutationResolverFunc(func(ctx context.Context, mutation schema.Mutation) (*Resolved, bool) { + addToArray(3) + return &Resolved{ + Field: mutation, + Extensions: &schema.Extensions{TouchedUids: 1}, + }, true + })) + resolved, succeeded := resolver.Resolve(context.Background(), nil) + + require.True(t, succeeded) + require.Equal(t, &Resolved{Extensions: &schema.Extensions{TouchedUids: 1}}, resolved) + require.Equal(t, []int{1, 2, 3, 4, 5}, array) +} diff --git a/graphql/resolve/resolver.go b/graphql/resolve/resolver.go index ce8016d9fc3..67b35874c1e 100644 --- a/graphql/resolve/resolver.go +++ b/graphql/resolve/resolver.go @@ -54,12 +54,12 @@ type ResolverFactory interface { mutationResolverFor(mutation schema.Mutation) MutationResolver // WithQueryResolver adds a new query resolver. Each time query name is resolved - // resolver is called to create a new instane of a QueryResolver to resolve the + // resolver is called to create a new instance of a QueryResolver to resolve the // query. WithQueryResolver(name string, resolver func(schema.Query) QueryResolver) ResolverFactory // WithMutationResolver adds a new query resolver. Each time mutation name is resolved - // resolver is called to create a new instane of a MutationResolver to resolve the + // resolver is called to create a new instance of a MutationResolver to resolve the // mutation. WithMutationResolver( name string, resolver func(schema.Mutation) MutationResolver) ResolverFactory @@ -68,6 +68,15 @@ type ResolverFactory interface { // factory. The registration happens only once. WithConventionResolvers(s schema.Schema, fns *ResolverFns) ResolverFactory + // WithQueryMiddlewareConfig adds the configuration to use to apply middlewares before resolving + // queries. The config should be a mapping of the name of query to its middlewares. + WithQueryMiddlewareConfig(config map[string]QueryMiddlewares) ResolverFactory + + // WithMutationMiddlewareConfig adds the configuration to use to apply middlewares before + // resolving mutations. The config should be a mapping of the name of mutation to its + // middlewares. + WithMutationMiddlewareConfig(config map[string]MutationMiddlewares) ResolverFactory + // WithSchemaIntrospection adds schema introspection capabilities to the factory. // So __schema and __type queries can be resolved. WithSchemaIntrospection() ResolverFactory @@ -97,6 +106,9 @@ type resolverFactory struct { queryResolvers map[string]func(schema.Query) QueryResolver mutationResolvers map[string]func(schema.Mutation) MutationResolver + queryMiddlewareConfig map[string]QueryMiddlewares + mutationMiddlewareConfig map[string]MutationMiddlewares + // returned if the factory gets asked for resolver for a field that it doesn't // know about. queryError QueryResolverFunc @@ -250,6 +262,22 @@ func (rf *resolverFactory) WithConventionResolvers( return rf } +func (rf *resolverFactory) WithQueryMiddlewareConfig( + config map[string]QueryMiddlewares) ResolverFactory { + if len(config) != 0 { + rf.queryMiddlewareConfig = config + } + return rf +} + +func (rf *resolverFactory) WithMutationMiddlewareConfig( + config map[string]MutationMiddlewares) ResolverFactory { + if len(config) != 0 { + rf.mutationMiddlewareConfig = config + } + return rf +} + // NewResolverFactory returns a ResolverFactory that resolves requests via // query/mutation rewriting and execution through Dgraph. If the factory gets asked // to resolve a query/mutation it doesn't know how to rewrite, it uses @@ -261,6 +289,9 @@ func NewResolverFactory( queryResolvers: make(map[string]func(schema.Query) QueryResolver), mutationResolvers: make(map[string]func(schema.Mutation) MutationResolver), + queryMiddlewareConfig: make(map[string]QueryMiddlewares), + mutationMiddlewareConfig: make(map[string]MutationMiddlewares), + queryError: queryError, mutationError: mutationError, } @@ -282,16 +313,18 @@ func StdDeleteCompletion(name string) CompletionFunc { } func (rf *resolverFactory) queryResolverFor(query schema.Query) QueryResolver { + mws := rf.queryMiddlewareConfig[query.Name()] if resolver, ok := rf.queryResolvers[query.Name()]; ok { - return resolver(query) + return mws.Then(resolver(query)) } return rf.queryError } func (rf *resolverFactory) mutationResolverFor(mutation schema.Mutation) MutationResolver { + mws := rf.mutationMiddlewareConfig[mutation.Name()] if resolver, ok := rf.mutationResolvers[mutation.Name()]; ok { - return resolver(mutation) + return mws.Then(resolver(mutation)) } return rf.mutationError @@ -1550,3 +1583,11 @@ func (h *httpMutationResolver) Resolve(ctx context.Context, mutation schema.Muta resolved := (*httpResolver)(h).Resolve(ctx, mutation) return resolved, resolved.Err == nil || resolved.Err.Error() == "" } + +func EmptyResult(f schema.Field, err error) *Resolved { + return &Resolved{ + Data: map[string]interface{}{f.Name(): nil}, + Field: f, + Err: schema.GQLWrapLocationf(err, f.Location(), "resolving %s failed", f.Name()), + } +} diff --git a/graphql/web/http.go b/graphql/web/http.go index 2a0ac84a48e..fd24c112c84 100644 --- a/graphql/web/http.go +++ b/graphql/web/http.go @@ -24,23 +24,19 @@ import ( "io" "io/ioutil" "mime" - "net" "net/http" - "strconv" "strings" - "github.com/golang/glog" - "github.com/graph-gophers/graphql-transport-ws/graphqlws" - "go.opencensus.io/trace" - "google.golang.org/grpc/peer" - "github.com/dgraph-io/dgraph/graphql/api" "github.com/dgraph-io/dgraph/graphql/authorization" "github.com/dgraph-io/dgraph/graphql/resolve" "github.com/dgraph-io/dgraph/graphql/schema" "github.com/dgraph-io/dgraph/graphql/subscription" "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/graph-gophers/graphql-transport-ws/graphqlws" "github.com/pkg/errors" + "go.opencensus.io/trace" ) // An IServeGraphQL can serve a GraphQL endpoint (currently only ons http) @@ -156,19 +152,9 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx = authorization.AttachAuthorizationJwt(ctx, r) ctx = x.AttachAccessJwt(ctx, r) - - if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil { - // Add remote addr as peer info so that the remote address can be logged - // inside Server.Login - if intPort, convErr := strconv.Atoi(port); convErr == nil { - ctx = peer.NewContext(ctx, &peer.Peer{ - Addr: &net.TCPAddr{ - IP: net.ParseIP(ip), - Port: intPort, - }, - }) - } - } + // Add remote addr as peer info so that the remote address can be logged + // inside Server.Login + ctx = x.AttachRemoteIP(ctx, r) var res *schema.Response gqlReq, err := getRequest(ctx, r) diff --git a/testutil/graphql.go b/testutil/graphql.go index 2a45180ebe8..8b720a4e252 100644 --- a/testutil/graphql.go +++ b/testutil/graphql.go @@ -78,6 +78,38 @@ func RequireNoGraphQLErrors(t *testing.T, resp *http.Response) { require.Nil(t, result.Errors) } +func MakeGQLRequest(t *testing.T, params *GraphQLParams) *GraphQLResponse { + return MakeGQLRequestWithAccessJwt(t, params, "") +} + +func MakeGQLRequestWithAccessJwt(t *testing.T, params *GraphQLParams, + accessToken string) *GraphQLResponse { + adminUrl := "http://" + SockAddrHttp + "/admin" + + b, err := json.Marshal(params) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, adminUrl, bytes.NewBuffer(b)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + if accessToken != "" { + req.Header.Set("X-Dgraph-AccessToken", accessToken) + } + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + + defer resp.Body.Close() + b, err = ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + var gqlResp GraphQLResponse + err = json.Unmarshal(b, &gqlResp) + require.NoError(t, err) + + return &gqlResp +} + type clientCustomClaims struct { Namespace string AuthVariables map[string]interface{} diff --git a/x/x.go b/x/x.go index 95c1c04474e..14d5cb14b74 100644 --- a/x/x.go +++ b/x/x.go @@ -37,6 +37,8 @@ import ( "syscall" "time" + "google.golang.org/grpc/peer" + "github.com/dgraph-io/badger/v2" "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgo/v200" @@ -369,6 +371,41 @@ func AttachAccessJwt(ctx context.Context, r *http.Request) context.Context { return ctx } +// AttachRemoteIP adds any incoming IP data into the grpc context metadata +func AttachRemoteIP(ctx context.Context, r *http.Request) context.Context { + if ip, port, err := net.SplitHostPort(r.RemoteAddr); err == nil { + if intPort, convErr := strconv.Atoi(port); convErr == nil { + ctx = peer.NewContext(ctx, &peer.Peer{ + Addr: &net.TCPAddr{ + IP: net.ParseIP(ip), + Port: intPort, + }, + }) + } + } + return ctx +} + +// IsIpWhitelisted checks if the given ipString is within the whitelisted ip range +func IsIpWhitelisted(ipString string) bool { + ip := net.ParseIP(ipString) + + if ip == nil { + return false + } + + if ip.IsLoopback() { + return true + } + + for _, ipRange := range WorkerConfig.WhiteListedIPRanges { + if bytes.Compare(ip, ipRange.Lower) >= 0 && bytes.Compare(ip, ipRange.Upper) <= 0 { + return true + } + } + return false +} + // Write response body, transparently compressing if necessary. func WriteResponse(w http.ResponseWriter, r *http.Request, b []byte) (int, error) { var out io.Writer = w