diff --git a/gossip/server.go b/gossip/server.go index d08203a225db..d149630d359c 100644 --- a/gossip/server.go +++ b/gossip/server.go @@ -150,7 +150,7 @@ func (s *server) jitteredGossipInterval() time.Duration { func (s *server) start(rpcServer *rpc.Server, stopper *stop.Stopper) { addr := rpcServer.Addr() s.is.NodeAddr = util.MakeUnresolvedAddr(addr.Network(), addr.String()) - if err := rpcServer.Register("Gossip.Gossip", s.Gossip, &Request{}); err != nil { + if err := rpcServer.Register("Gossip.Gossip", true /*restricted*/, s.Gossip, &Request{}); err != nil { log.Fatalf("unable to register gossip service with RPC server: %s", err) } rpcServer.AddCloseCallback(s.onClose) diff --git a/kv/db.go b/kv/db.go index 0a766dc86234..2486ca911592 100644 --- a/kv/db.go +++ b/kv/db.go @@ -171,7 +171,7 @@ func (s *DBServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Check request user against client certificate user. - if err := authenticationHook(args); err != nil { + if err := authenticationHook(args, true /*restricted*/); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } @@ -206,7 +206,7 @@ func (s *DBServer) RegisterRPC(rpcServer *rpc.Server) error { &proto.AdminMergeRequest{}, } for _, r := range requests { - if err := rpcServer.Register("Server."+r.Method().String(), + if err := rpcServer.Register("Server."+r.Method().String(), true, /*restricted*/ s.executeCmd, r); err != nil { return err } diff --git a/multiraft/transport.go b/multiraft/transport.go index d99e76c3b382..a90365e76c72 100644 --- a/multiraft/transport.go +++ b/multiraft/transport.go @@ -111,7 +111,7 @@ func (lt *localRPCTransport) Listen(id proto.RaftNodeID, server ServerInterface) Stopper: lt.stopper, DisableCache: true, }) - err := rpcServer.RegisterAsync(raftMessageName, + err := rpcServer.RegisterAsync(raftMessageName, true, /*restricted*/ func(argsI gogoproto.Message, callback func(gogoproto.Message, error)) { protoArgs := argsI.(*proto.RaftMessageRequest) args := RaftMessageRequest{ diff --git a/rpc/codec/server.go b/rpc/codec/server.go index e3731404e1d4..05ac2617cb59 100644 --- a/rpc/codec/server.go +++ b/rpc/codec/server.go @@ -37,9 +37,6 @@ type serverCodec struct { methods []string - // Post body-decoding hook. May be nil in tests. - requestBodyHook func(proto.Message) error - // temporary work space respBodyBuf bytes.Buffer respHeaderBuf bytes.Buffer @@ -49,14 +46,13 @@ type serverCodec struct { // NewServerCodec returns a serverCodec that communicates with the ClientCodec // on the other end of the given conn. -func NewServerCodec(conn io.ReadWriteCloser, requestBodyHook func(proto.Message) error) rpc.ServerCodec { +func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec { return &serverCodec{ baseConn: baseConn{ r: bufio.NewReader(conn), w: bufio.NewWriter(conn), c: conn, }, - requestBodyHook: requestBodyHook, } } @@ -101,13 +97,7 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error { return err } c.reqHeader.Reset() - - if c.requestBodyHook == nil || x == nil { - // Only call requestBodyHook if we are actually decoding a frame - // instead of discarding it. - return nil - } - return c.requestBodyHook(request) + return nil } func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { diff --git a/rpc/heartbeat.go b/rpc/heartbeat.go index 52113be5f2c5..680685d9b933 100644 --- a/rpc/heartbeat.go +++ b/rpc/heartbeat.go @@ -38,7 +38,7 @@ type HeartbeatService struct { // Register this service on the given RPC server. func (hs *HeartbeatService) Register(server *Server) error { - if err := server.Register("Heartbeat.Ping", hs.Ping, + if err := server.Register("Heartbeat.Ping", true /*restricted*/, hs.Ping, &proto.PingRequest{}); err != nil { return err } @@ -73,7 +73,7 @@ type ManualHeartbeatService struct { // Register this service on the given RPC server. func (mhs *ManualHeartbeatService) Register(server *Server) error { - if err := server.Register("Heartbeat.Ping", mhs.Ping, + if err := server.Register("Heartbeat.Ping", true /*restricted*/, mhs.Ping, &proto.PingRequest{}); err != nil { return err } diff --git a/rpc/send_test.go b/rpc/send_test.go index c099cc583254..7480e66c1cc4 100644 --- a/rpc/send_test.go +++ b/rpc/send_test.go @@ -211,7 +211,7 @@ func TestClientNotReady(t *testing.T) { addr: addr, methods: map[string]method{}, } - if err := s.Register("Heartbeat.Ping", (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil { + if err := s.Register("Heartbeat.Ping", false, (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil { t.Fatal(err) } if err := s.Start(); err != nil { diff --git a/rpc/server.go b/rpc/server.go index b61d267c5327..3c5ea23c3b87 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -35,8 +35,9 @@ import ( ) type method struct { - handler func(proto.Message, func(proto.Message, error)) - reqType reflect.Type + handler func(proto.Message, func(proto.Message, error)) + reqType reflect.Type + restricted bool } type serverResponse struct { @@ -94,9 +95,12 @@ func NewServer(addr net.Addr, context *Context) *Server { // argument of the same type as `reqPrototype`. Both the argument and // return value of 'handler' should be a pointer to a protocol message // type. The handler function will be executed in a new goroutine. -func (s *Server) Register(name string, handler func(proto.Message) (proto.Message, error), +// If 'restricted' is true, only system users (RootUser and NodeUser) +// may call this method. +func (s *Server) Register(name string, restricted bool, + handler func(proto.Message) (proto.Message, error), reqPrototype proto.Message) error { - return s.RegisterAsync(name, syncAdapter(handler).exec, reqPrototype) + return s.RegisterAsync(name, restricted, syncAdapter(handler).exec, reqPrototype) } // RegisterAsync registers an asynchronous method handler. Instead of @@ -107,7 +111,10 @@ func (s *Server) Register(name string, handler func(proto.Message) (proto.Messag // channel promptly). However, the fact that they are started in the // RPC server's goroutine guarantees that the order of requests as // they were read from the connection is preserved. -func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(proto.Message, error)), +// If 'restricted' is true, only system users (RootUser and NodeUser) +// may call this method. +func (s *Server) RegisterAsync(name string, restricted bool, + handler func(proto.Message, func(proto.Message, error)), reqPrototype proto.Message) error { s.mu.Lock() defer s.mu.Unlock() @@ -122,8 +129,9 @@ func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(pro return util.Errorf("request type not a pointer") } s.methods[name] = method{ - handler: handler, - reqType: reqType, + handler: handler, + reqType: reqType, + restricted: restricted, } return nil } @@ -171,7 +179,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { security.LogTLSState("RPC", r.TLS) io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") - codec := codec.NewServerCodec(conn, authHook) + codec := codec.NewServerCodec(conn) responses := make(chan serverResponse) var wg sync.WaitGroup wg.Add(1) @@ -179,7 +187,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.sendResponses(codec, responses) wg.Done() }() - s.readRequests(codec, responses) + s.readRequests(codec, authHook, responses) wg.Wait() codec.Close() @@ -334,7 +342,7 @@ func (s *Server) Close() { // when the handler finishes the response is written to the responses // channel. When the connection is closed (and any pending requests // have finished), we close the responses channel. -func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResponse) { +func (s *Server) readRequests(codec rpc.ServerCodec, authHook func(proto.Message, bool) error, responses chan<- serverResponse) { var wg sync.WaitGroup defer func() { wg.Wait() @@ -342,7 +350,7 @@ func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResp }() for { - req, meth, args, err := s.readRequest(codec) + req, meth, args, err := s.readRequest(codec, authHook) if err != nil { if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnection(err) { return @@ -372,8 +380,8 @@ func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResp } // readRequest reads a single request from a connection. -func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method, - args proto.Message, err error) { +func (s *Server) readRequest(codec rpc.ServerCodec, authHook func(proto.Message, bool) error) ( + req rpc.Request, m method, args proto.Message, err error) { if err = codec.ReadRequestHeader(&req); err != nil { return } @@ -388,7 +396,10 @@ func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method, if ok { args = reflect.New(m.reqType.Elem()).Interface().(proto.Message) } - err = codec.ReadRequestBody(args) + if err = codec.ReadRequestBody(args); err != nil || args == nil { + return + } + err = authHook(args, m.restricted) return } diff --git a/rpc/server_test.go b/rpc/server_test.go index 0e45d85806fb..2f1916f3bd69 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -73,10 +73,10 @@ func TestDuplicateRegistration(t *testing.T) { s := NewServer(util.CreateTestAddr("tcp"), NewNodeTestContext(nil, stopper)) heartbeat := &Heartbeat{} - if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err != nil { + if err := s.Register("Foo.Bar", false, heartbeat.Ping, &proto.PingRequest{}); err != nil { t.Fatalf("unexpected failure on first registration: %s", err) } - if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err == nil { + if err := s.Register("Foo.Bar", false, heartbeat.Ping, &proto.PingRequest{}); err == nil { t.Fatalf("unexpected success on second registration") } } diff --git a/security/auth.go b/security/auth.go index 17df68160c9c..8835d5597693 100644 --- a/security/auth.go +++ b/security/auth.go @@ -85,10 +85,10 @@ func GetCertificateUser(tlsState *tls.ConnectionState) (string, error) { // AuthenticationHook builds an authentication hook based on the // security mode and client certificate. // Must be called at connection time and passed the TLS state. -// Returns a func(proto.Message) error. The passed-in proto must implement +// Returns a func(proto.Message,bool) error. The passed-in proto must implement // the GetUser interface. func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) ( - func(request proto.Message) error, error) { + func(request proto.Message, restricted bool) error, error) { var certUser string var err error @@ -99,7 +99,7 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) ( } } - return func(request proto.Message) error { + return func(request proto.Message, restricted bool) error { // userRequest is an interface for RPC requests that have a "requested user". type userRequest interface { // GetUser returns the user from the request. @@ -119,6 +119,10 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) ( return util.Errorf("missing User in request: %+v", request) } + if restricted && requestedUser != RootUser && requestedUser != NodeUser { + return util.Errorf("user %s is not allowed", requestedUser) + } + // If running in insecure mode, we have nothing to verify it against. if insecureMode { return nil diff --git a/server/node.go b/server/node.go index f0a0d9ed1dbc..7c27e54126ed 100644 --- a/server/node.go +++ b/server/node.go @@ -249,7 +249,7 @@ func (n *Node) start(rpcServer *rpc.Server, engines []engine.Engine, &proto.LeaderLeaseRequest{}, } for _, r := range requests { - if err := rpcServer.Register("Node."+r.Method().String(), n.executeCmd, r); err != nil { + if err := rpcServer.Register("Node."+r.Method().String(), true /*restricted*/, n.executeCmd, r); err != nil { log.Fatalf("unable to register node service with RPC server: %s", err) } } diff --git a/server/raft_transport.go b/server/raft_transport.go index 3527211900fd..ade5098eed7c 100644 --- a/server/raft_transport.go +++ b/server/raft_transport.go @@ -66,8 +66,8 @@ func newRPCTransport(gossip *gossip.Gossip, rpcServer *rpc.Server, rpcContext *r } if t.rpcServer != nil { - if err := t.rpcServer.RegisterAsync(raftMessageName, t.RaftMessage, - &proto.RaftMessageRequest{}); err != nil { + if err := t.rpcServer.RegisterAsync(raftMessageName, true, /*restricted*/ + t.RaftMessage, &proto.RaftMessageRequest{}); err != nil { return nil, err } } diff --git a/sql/server.go b/sql/server.go index 9636992d05df..fdca7a79c65b 100644 --- a/sql/server.go +++ b/sql/server.go @@ -101,7 +101,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Check request user against client certificate user. - if err := authenticationHook(&args); err != nil { + if err := authenticationHook(&args, false); err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return }