Skip to content

Commit

Permalink
Merge #96451
Browse files Browse the repository at this point in the history
96451: server: only forward the SQL identity in gRPC metadata r=andreimatei a=knz

Requested by `@andreimatei` .
Informs #96427.
Informs #45018.

Prior to this patch, we were forwarding any and all gRPC metdata during a RPC fanout. This was creating doubt and confusion, about how much data is really important/useful to forward.

Analysis suggests we only care about the SQL user identity resulting from HTTP authentication. So this patch limits the forwarding to just that information.

This specialization makes the forwarding logic easier to understand.


This patch additionally renames functions as follows:

| Old name                        | New name                              |
|---------------------------------|---------------------------------------|
| `userFromContext`               | `userFromIncomingRPCContext`          |
| `getSQLUsername`                | `userFromHTTPAuthInfoContext`         |
| `apiToOutgoingGatewayCtx`       | `forwardHTTPAuthInfoToRPCCalls`       |
| `forwardAuthenticationMetadata` | `translateHTTPAuthInfoToGRPCMetadata` |
| `propagateGatewayMetadata`      | `forwardSQLIdentityThroughRPCCalls`   |


Release note: None
Epic: None

Co-authored-by: Raphael 'kena' Poss <knz@thaumogen.net>
  • Loading branch information
craig[bot] and knz committed Feb 3, 2023
2 parents cebffa1 + b8518eb commit 682ad1b
Show file tree
Hide file tree
Showing 18 changed files with 192 additions and 153 deletions.
34 changes: 17 additions & 17 deletions pkg/server/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (s *adminServer) Databases(
) (_ *serverpb.DatabasesResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)

sessionUser, err := userFromContext(ctx)
sessionUser, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -431,7 +431,7 @@ func (s *adminServer) DatabaseDetails(
ctx context.Context, req *serverpb.DatabaseDetailsRequest,
) (_ *serverpb.DatabaseDetailsResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)
userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -799,7 +799,7 @@ func (s *adminServer) TableDetails(
ctx context.Context, req *serverpb.TableDetailsRequest,
) (_ *serverpb.TableDetailsResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)
userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -1202,7 +1202,7 @@ func (s *adminServer) TableStats(
) (*serverpb.TableStatsResponse, error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -1443,7 +1443,7 @@ func (s *adminServer) Users(
ctx context.Context, req *serverpb.UsersRequest,
) (_ *serverpb.UsersResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)
userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -1647,7 +1647,7 @@ func (s *adminServer) RangeLog(
ctx = s.AnnotateCtx(ctx)

// Range keys, even when pretty-printed, contain PII.
user, err := userFromContext(ctx)
user, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1881,7 +1881,7 @@ func (s *adminServer) SetUIData(
) (*serverpb.SetUIDataResponse, error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -1920,7 +1920,7 @@ func (s *adminServer) GetUIData(
) (*serverpb.GetUIDataResponse, error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -2188,7 +2188,7 @@ func getLivenessResponse(
func (s *adminServer) Liveness(
ctx context.Context, req *serverpb.LivenessRequest,
) (*serverpb.LivenessResponse, error) {
ctx = propagateGatewayMetadata(ctx)
ctx = forwardSQLIdentityThroughRPCCalls(ctx)
ctx = s.AnnotateCtx(ctx)

return s.sqlServer.tenantConnect.Liveness(ctx, req)
Expand All @@ -2210,7 +2210,7 @@ func (s *adminServer) Jobs(
) (_ *serverpb.JobsResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -2405,7 +2405,7 @@ func (s *adminServer) Job(
) (_ *serverpb.JobResponse, retErr error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -2466,7 +2466,7 @@ func (s *adminServer) Locations(
ctx = s.AnnotateCtx(ctx)

// Require authentication.
_, err := userFromContext(ctx)
_, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -2536,7 +2536,7 @@ func (s *adminServer) QueryPlan(
) (*serverpb.QueryPlanResponse, error) {
ctx = s.AnnotateCtx(ctx)

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -2579,7 +2579,7 @@ func (s *adminServer) QueryPlan(
// getStatementBundle retrieves the statement bundle with the given id and
// writes it out as an attachment.
func (s *adminServer) getStatementBundle(ctx context.Context, id int64, w http.ResponseWriter) {
sessionUser, err := userFromContext(ctx)
sessionUser, err := userFromIncomingRPCContext(ctx)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down Expand Up @@ -2904,7 +2904,7 @@ func (s *adminServer) DataDistribution(
return nil, err
}

userName, err := userFromContext(ctx)
userName, err := userFromIncomingRPCContext(ctx)
if err != nil {
return nil, serverError(ctx, err)
}
Expand Down Expand Up @@ -3109,7 +3109,7 @@ func (s *adminServer) dataDistributionHelper(
func (s *systemAdminServer) EnqueueRange(
ctx context.Context, req *serverpb.EnqueueRangeRequest,
) (*serverpb.EnqueueRangeResponse, error) {
ctx = propagateGatewayMetadata(ctx)
ctx = forwardSQLIdentityThroughRPCCalls(ctx)
ctx = s.AnnotateCtx(ctx)

if _, err := s.requireAdminUser(ctx); err != nil {
Expand Down Expand Up @@ -3981,7 +3981,7 @@ func (c *adminPrivilegeChecker) requireViewDebugPermission(ctx context.Context)
func (c *adminPrivilegeChecker) getUserAndRole(
ctx context.Context,
) (userName username.SQLUsername, isAdmin bool, err error) {
userName, err = userFromContext(ctx)
userName, err = userFromIncomingRPCContext(ctx)
if err != nil {
return userName, false, err
}
Expand Down
10 changes: 1 addition & 9 deletions pkg/server/api_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
"strconv"

"github.com/cockroachdb/cockroach/pkg/kv"
"github.com/cockroachdb/cockroach/pkg/security/username"
"github.com/cockroachdb/cockroach/pkg/server/serverpb"
"github.com/cockroachdb/cockroach/pkg/server/telemetry"
"github.com/cockroachdb/cockroach/pkg/sql/roleoption"
Expand All @@ -70,13 +69,6 @@ func writeJSONResponse(ctx context.Context, w http.ResponseWriter, code int, pay
_, _ = w.Write(res)
}

// Returns a SQL username from the request context of a route requiring login.
// Only use in routes that require login (requiresAuth = true in its route
// definition).
func getSQLUsername(ctx context.Context) username.SQLUsername {
return username.MakeSQLUsernameFromPreNormalizedString(ctx.Value(webSessionUserKey{}).(string))
}

type ApiV2System interface {
health(w http.ResponseWriter, r *http.Request)
listNodes(w http.ResponseWriter, r *http.Request)
Expand Down Expand Up @@ -325,7 +317,7 @@ func (a *apiV2Server) listSessions(w http.ResponseWriter, r *http.Request) {
reqExcludeClosed := r.URL.Query().Get("exclude_closed_sessions") == "true"
req := &serverpb.ListSessionsRequest{Username: reqUsername, ExcludeClosedSessions: reqExcludeClosed}
response := &listSessionsResponse{}
outgoingCtx := apiToOutgoingGatewayCtx(ctx, r)
outgoingCtx := forwardHTTPAuthInfoToRPCCalls(ctx, r)

responseProto, pagState, err := a.status.listSessionsHelper(outgoingCtx, req, limit, start)
if err != nil {
Expand Down
18 changes: 4 additions & 14 deletions pkg/server/api_v2_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/cockroachdb/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -353,13 +352,11 @@ func (a *authenticationV2Mux) ServeHTTP(w http.ResponseWriter, req *http.Request
}
// Valid session found, or insecure. Set the username in the request context,
// so child http.Handlers can access it.
ctx := req.Context()
ctx = context.WithValue(ctx, webSessionUserKey{}, u)
var sessionID int64
if cookie != nil {
ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID)
sessionID = cookie.ID
}
req = req.WithContext(ctx)

req = req.WithContext(contextWithHTTPAuthInfo(req.Context(), u, sessionID))
a.inner.ServeHTTP(w, req)
}

Expand Down Expand Up @@ -443,8 +440,7 @@ func (r *roleAuthorizationMux) hasRoleOption(
func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// The username is set in authenticationV2Mux, and must correspond with a
// logged-in user.
username := username.MakeSQLUsernameFromPreNormalizedString(
req.Context().Value(webSessionUserKey{}).(string))
username := userFromHTTPAuthInfoContext(req.Context())
if role, err := r.getRoleForUser(req.Context(), username); err != nil || role < r.role {
if err != nil {
apiV2InternalError(req.Context(), err, w)
Expand All @@ -465,9 +461,3 @@ func (r *roleAuthorizationMux) ServeHTTP(w http.ResponseWriter, req *http.Reques
}
r.inner.ServeHTTP(w, req)
}

// apiToOutgoingGatewayCtx converts an HTTP API (v1 or v2) context, to one that
// can issue outgoing RPC requests under the same logged-in user.
func apiToOutgoingGatewayCtx(ctx context.Context, r *http.Request) context.Context {
return metadata.NewOutgoingContext(ctx, forwardAuthenticationMetadata(ctx, r))
}
8 changes: 4 additions & 4 deletions pkg/server/api_v2_ranges.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ type nodesResponse struct {
func (a *apiV2SystemServer) listNodes(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
limit, offset := getSimplePaginationValues(r)
ctx = apiToOutgoingGatewayCtx(ctx, r)
ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r)

nodes, next, err := a.systemStatus.nodesHelper(ctx, limit, offset)
if err != nil {
Expand Down Expand Up @@ -195,7 +195,7 @@ type rangeResponse struct {
// "$ref": "#/definitions/rangeResponse"
func (a *apiV2Server) listRange(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = apiToOutgoingGatewayCtx(ctx, r)
ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r)
vars := mux.Vars(r)
rangeID, err := strconv.ParseInt(vars["range_id"], 10, 64)
if err != nil {
Expand Down Expand Up @@ -378,7 +378,7 @@ type nodeRangesResponse struct {
// "$ref": "#/definitions/nodeRangesResponse"
func (a *apiV2SystemServer) listNodeRanges(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = apiToOutgoingGatewayCtx(ctx, r)
ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r)
vars := mux.Vars(r)
nodeIDStr := vars["node_id"]
if nodeIDStr != "local" {
Expand Down Expand Up @@ -497,7 +497,7 @@ type hotRangeInfo struct {
// "$ref": "#/definitions/hotRangesResponse"
func (a *apiV2Server) listHotRanges(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = apiToOutgoingGatewayCtx(ctx, r)
ctx = forwardHTTPAuthInfoToRPCCalls(ctx, r)
nodeIDStr := r.URL.Query().Get("node_id")
limit, start := getRPCPaginationValues(r)

Expand Down
2 changes: 1 addition & 1 deletion pkg/server/api_v2_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func (a *apiV2Server) execSQL(w http.ResponseWriter, r *http.Request) {
}

// The SQL username that owns this session.
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)

options := []isql.TxnOption{
isql.WithPriority(admissionpb.NormalPri),
Expand Down
14 changes: 7 additions & 7 deletions pkg/server/api_v2_sql_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type usersResponse struct {
func (a *apiV2Server) listUsers(w http.ResponseWriter, r *http.Request) {
limit, offset := getSimplePaginationValues(r)
ctx := r.Context()
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)

query := `SELECT username FROM system.users WHERE "isRole" = false ORDER BY username`
Expand Down Expand Up @@ -149,7 +149,7 @@ type eventsResponse struct {
func (a *apiV2Server) listEvents(w http.ResponseWriter, r *http.Request) {
limit, offset := getSimplePaginationValues(r)
ctx := r.Context()
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)
queryValues := r.URL.Query()

Expand Down Expand Up @@ -213,7 +213,7 @@ type databasesResponse struct {
func (a *apiV2Server) listDatabases(w http.ResponseWriter, r *http.Request) {
limit, offset := getSimplePaginationValues(r)
ctx := r.Context()
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)

var resp databasesResponse
Expand Down Expand Up @@ -263,7 +263,7 @@ type databaseDetailsResponse struct {
// description: Database not found
func (a *apiV2Server) databaseDetails(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)
pathVars := mux.Vars(r)
req := &serverpb.DatabaseDetailsRequest{
Expand Down Expand Up @@ -337,7 +337,7 @@ type databaseGrantsResponse struct {
func (a *apiV2Server) databaseGrants(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
limit, offset := getSimplePaginationValues(r)
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)
pathVars := mux.Vars(r)
req := &serverpb.DatabaseDetailsRequest{
Expand Down Expand Up @@ -412,7 +412,7 @@ type databaseTablesResponse struct {
func (a *apiV2Server) databaseTables(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
limit, offset := getSimplePaginationValues(r)
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)
pathVars := mux.Vars(r)
req := &serverpb.DatabaseDetailsRequest{
Expand Down Expand Up @@ -473,7 +473,7 @@ type tableDetailsResponse serverpb.TableDetailsResponse
// description: Database or table not found
func (a *apiV2Server) tableDetails(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
username := getSQLUsername(ctx)
username := userFromHTTPAuthInfoContext(ctx)
ctx = a.sqlServer.AnnotateCtx(ctx)
pathVars := mux.Vars(r)
req := &serverpb.TableDetailsRequest{
Expand Down

0 comments on commit 682ad1b

Please sign in to comment.