Skip to content

Commit

Permalink
Do not merge: restrict KV endpoint to root and node users.
Browse files Browse the repository at this point in the history
Work towards #2089

This is a proposed implementation. The reason it's convoluted is that
we need per method restrictions (we'll eventually be doing sql over
rpc).

This rearranges the authentication hook usage (moves from rpc/codec/
to rpc/server.go) so that it can access a 'restricted' bool passed at
method registration time. For now, all RPC handlers are passing true.

The http handlers already call the authentication hook manually, so they
just specify restricted (kv) or not (sql) at the proper time.

If there are no objections to this, I'll go ahead and remove all the
multiuser tests and add plain unauthorized tests.
  • Loading branch information
marc committed Aug 20, 2015
1 parent b12285c commit 987dd69
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 42 deletions.
2 changes: 1 addition & 1 deletion gossip/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (s *server) jitteredGossipInterval() time.Duration {
func (s *server) start(rpcServer *rpc.Server, stopper *stop.Stopper) {
addr := rpcServer.Addr()
s.is.NodeAddr = util.MakeUnresolvedAddr(addr.Network(), addr.String())
if err := rpcServer.Register("Gossip.Gossip", s.Gossip, &Request{}); err != nil {
if err := rpcServer.Register("Gossip.Gossip", true /*restricted*/, s.Gossip, &Request{}); err != nil {
log.Fatalf("unable to register gossip service with RPC server: %s", err)
}
rpcServer.AddCloseCallback(s.onClose)
Expand Down
4 changes: 2 additions & 2 deletions kv/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (s *DBServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Check request user against client certificate user.
if err := authenticationHook(args); err != nil {
if err := authenticationHook(args, true /*restricted*/); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down Expand Up @@ -206,7 +206,7 @@ func (s *DBServer) RegisterRPC(rpcServer *rpc.Server) error {
&proto.AdminMergeRequest{},
}
for _, r := range requests {
if err := rpcServer.Register("Server."+r.Method().String(),
if err := rpcServer.Register("Server."+r.Method().String(), true, /*restricted*/
s.executeCmd, r); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion multiraft/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (lt *localRPCTransport) Listen(id proto.RaftNodeID, server ServerInterface)
Stopper: lt.stopper,
DisableCache: true,
})
err := rpcServer.RegisterAsync(raftMessageName,
err := rpcServer.RegisterAsync(raftMessageName, true, /*restricted*/
func(argsI gogoproto.Message, callback func(gogoproto.Message, error)) {
protoArgs := argsI.(*proto.RaftMessageRequest)
args := RaftMessageRequest{
Expand Down
14 changes: 2 additions & 12 deletions rpc/codec/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ type serverCodec struct {

methods []string

// Post body-decoding hook. May be nil in tests.
requestBodyHook func(proto.Message) error

// temporary work space
respBodyBuf bytes.Buffer
respHeaderBuf bytes.Buffer
Expand All @@ -49,14 +46,13 @@ type serverCodec struct {

// NewServerCodec returns a serverCodec that communicates with the ClientCodec
// on the other end of the given conn.
func NewServerCodec(conn io.ReadWriteCloser, requestBodyHook func(proto.Message) error) rpc.ServerCodec {
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
return &serverCodec{
baseConn: baseConn{
r: bufio.NewReader(conn),
w: bufio.NewWriter(conn),
c: conn,
},
requestBodyHook: requestBodyHook,
}
}

Expand Down Expand Up @@ -101,13 +97,7 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error {
return err
}
c.reqHeader.Reset()

if c.requestBodyHook == nil || x == nil {
// Only call requestBodyHook if we are actually decoding a frame
// instead of discarding it.
return nil
}
return c.requestBodyHook(request)
return nil
}

func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error {
Expand Down
4 changes: 2 additions & 2 deletions rpc/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type HeartbeatService struct {

// Register this service on the given RPC server.
func (hs *HeartbeatService) Register(server *Server) error {
if err := server.Register("Heartbeat.Ping", hs.Ping,
if err := server.Register("Heartbeat.Ping", true /*restricted*/, hs.Ping,
&proto.PingRequest{}); err != nil {
return err
}
Expand Down Expand Up @@ -73,7 +73,7 @@ type ManualHeartbeatService struct {

// Register this service on the given RPC server.
func (mhs *ManualHeartbeatService) Register(server *Server) error {
if err := server.Register("Heartbeat.Ping", mhs.Ping,
if err := server.Register("Heartbeat.Ping", true /*restricted*/, mhs.Ping,
&proto.PingRequest{}); err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion rpc/send_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestClientNotReady(t *testing.T) {
addr: addr,
methods: map[string]method{},
}
if err := s.Register("Heartbeat.Ping", (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil {
if err := s.Register("Heartbeat.Ping", false, (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil {
t.Fatal(err)
}
if err := s.Start(); err != nil {
Expand Down
39 changes: 25 additions & 14 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ import (
)

type method struct {
handler func(proto.Message, func(proto.Message, error))
reqType reflect.Type
handler func(proto.Message, func(proto.Message, error))
reqType reflect.Type
restricted bool
}

type serverResponse struct {
Expand Down Expand Up @@ -94,9 +95,12 @@ func NewServer(addr net.Addr, context *Context) *Server {
// argument of the same type as `reqPrototype`. Both the argument and
// return value of 'handler' should be a pointer to a protocol message
// type. The handler function will be executed in a new goroutine.
func (s *Server) Register(name string, handler func(proto.Message) (proto.Message, error),
// If 'restricted' is true, only system users (RootUser and NodeUser)
// may call this method.
func (s *Server) Register(name string, restricted bool,
handler func(proto.Message) (proto.Message, error),
reqPrototype proto.Message) error {
return s.RegisterAsync(name, syncAdapter(handler).exec, reqPrototype)
return s.RegisterAsync(name, restricted, syncAdapter(handler).exec, reqPrototype)
}

// RegisterAsync registers an asynchronous method handler. Instead of
Expand All @@ -107,7 +111,10 @@ func (s *Server) Register(name string, handler func(proto.Message) (proto.Messag
// channel promptly). However, the fact that they are started in the
// RPC server's goroutine guarantees that the order of requests as
// they were read from the connection is preserved.
func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(proto.Message, error)),
// If 'restricted' is true, only system users (RootUser and NodeUser)
// may call this method.
func (s *Server) RegisterAsync(name string, restricted bool,
handler func(proto.Message, func(proto.Message, error)),
reqPrototype proto.Message) error {
s.mu.Lock()
defer s.mu.Unlock()
Expand All @@ -122,8 +129,9 @@ func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(pro
return util.Errorf("request type not a pointer")
}
s.methods[name] = method{
handler: handler,
reqType: reqType,
handler: handler,
reqType: reqType,
restricted: restricted,
}
return nil
}
Expand Down Expand Up @@ -171,15 +179,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
security.LogTLSState("RPC", r.TLS)
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")

codec := codec.NewServerCodec(conn, authHook)
codec := codec.NewServerCodec(conn)
responses := make(chan serverResponse)
var wg sync.WaitGroup
wg.Add(1)
go func() {
s.sendResponses(codec, responses)
wg.Done()
}()
s.readRequests(codec, responses)
s.readRequests(codec, authHook, responses)
wg.Wait()

codec.Close()
Expand Down Expand Up @@ -334,15 +342,15 @@ func (s *Server) Close() {
// when the handler finishes the response is written to the responses
// channel. When the connection is closed (and any pending requests
// have finished), we close the responses channel.
func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResponse) {
func (s *Server) readRequests(codec rpc.ServerCodec, authHook func(proto.Message, bool) error, responses chan<- serverResponse) {
var wg sync.WaitGroup
defer func() {
wg.Wait()
close(responses)
}()

for {
req, meth, args, err := s.readRequest(codec)
req, meth, args, err := s.readRequest(codec, authHook)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnection(err) {
return
Expand Down Expand Up @@ -372,8 +380,8 @@ func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResp
}

// readRequest reads a single request from a connection.
func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method,
args proto.Message, err error) {
func (s *Server) readRequest(codec rpc.ServerCodec, authHook func(proto.Message, bool) error) (
req rpc.Request, m method, args proto.Message, err error) {
if err = codec.ReadRequestHeader(&req); err != nil {
return
}
Expand All @@ -388,7 +396,10 @@ func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method,
if ok {
args = reflect.New(m.reqType.Elem()).Interface().(proto.Message)
}
err = codec.ReadRequestBody(args)
if err = codec.ReadRequestBody(args); err != nil || args == nil {
return
}
err = authHook(args, m.restricted)
return
}

Expand Down
4 changes: 2 additions & 2 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ func TestDuplicateRegistration(t *testing.T) {

s := NewServer(util.CreateTestAddr("tcp"), NewNodeTestContext(nil, stopper))
heartbeat := &Heartbeat{}
if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err != nil {
if err := s.Register("Foo.Bar", false, heartbeat.Ping, &proto.PingRequest{}); err != nil {
t.Fatalf("unexpected failure on first registration: %s", err)
}
if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err == nil {
if err := s.Register("Foo.Bar", false, heartbeat.Ping, &proto.PingRequest{}); err == nil {
t.Fatalf("unexpected success on second registration")
}
}
Expand Down
10 changes: 7 additions & 3 deletions security/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ func GetCertificateUser(tlsState *tls.ConnectionState) (string, error) {
// AuthenticationHook builds an authentication hook based on the
// security mode and client certificate.
// Must be called at connection time and passed the TLS state.
// Returns a func(proto.Message) error. The passed-in proto must implement
// Returns a func(proto.Message,bool) error. The passed-in proto must implement
// the GetUser interface.
func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
func(request proto.Message) error, error) {
func(request proto.Message, restricted bool) error, error) {
var certUser string
var err error

Expand All @@ -99,7 +99,7 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
}
}

return func(request proto.Message) error {
return func(request proto.Message, restricted bool) error {
// userRequest is an interface for RPC requests that have a "requested user".
type userRequest interface {
// GetUser returns the user from the request.
Expand All @@ -119,6 +119,10 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
return util.Errorf("missing User in request: %+v", request)
}

if restricted && requestedUser != RootUser && requestedUser != NodeUser {
return util.Errorf("user %s is not allowed", requestedUser)
}

// If running in insecure mode, we have nothing to verify it against.
if insecureMode {
return nil
Expand Down
2 changes: 1 addition & 1 deletion server/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func (n *Node) start(rpcServer *rpc.Server, engines []engine.Engine,
&proto.LeaderLeaseRequest{},
}
for _, r := range requests {
if err := rpcServer.Register("Node."+r.Method().String(), n.executeCmd, r); err != nil {
if err := rpcServer.Register("Node."+r.Method().String(), true /*restricted*/, n.executeCmd, r); err != nil {
log.Fatalf("unable to register node service with RPC server: %s", err)
}
}
Expand Down
4 changes: 2 additions & 2 deletions server/raft_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ func newRPCTransport(gossip *gossip.Gossip, rpcServer *rpc.Server, rpcContext *r
}

if t.rpcServer != nil {
if err := t.rpcServer.RegisterAsync(raftMessageName, t.RaftMessage,
&proto.RaftMessageRequest{}); err != nil {
if err := t.rpcServer.RegisterAsync(raftMessageName, true, /*restricted*/
t.RaftMessage, &proto.RaftMessageRequest{}); err != nil {
return nil, err
}
}
Expand Down
2 changes: 1 addition & 1 deletion sql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Check request user against client certificate user.
if err := authenticationHook(&args); err != nil {
if err := authenticationHook(&args, false); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down

0 comments on commit 987dd69

Please sign in to comment.