Skip to content

Commit

Permalink
Merge pull request #2193 from cockroachdb/marc/restrict_kv_endpoint
Browse files Browse the repository at this point in the history
Restrict KV endpoint to root and node users.
  • Loading branch information
mberhault committed Aug 22, 2015
2 parents f2db92f + 83dfb1b commit 54f54b4
Show file tree
Hide file tree
Showing 14 changed files with 165 additions and 195 deletions.
109 changes: 0 additions & 109 deletions acceptance/multiuser_test.go

This file was deleted.

10 changes: 6 additions & 4 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (ss *notifyingSender) Send(ctx context.Context, call proto.Call) {
}

func createTestClient(addr string) *client.DB {
return createTestClientFor(addr, testUser)
return createTestClientFor(addr, security.RootUser)
}

func createTestClientFor(addr, user string) *client.DB {
Expand Down Expand Up @@ -554,6 +554,8 @@ func TestConcurrentIncrements(t *testing.T) {
}

// TestClientPermissions verifies permission enforcement.
// Only root and node users are now allowed to issue kv commands.
// We still enforce the permissions config through.
// This relies on:
// - r/w permissions config for 'testUser' on the 'testUser' prefix.
// - permissive checks for 'root' on all paths
Expand All @@ -576,13 +578,13 @@ func TestClientPermissions(t *testing.T) {
{"foo", test, false},
{"foo", root, true},

{testUser + "/foo", test, true},
{testUser + "/foo", test, false},
{testUser + "/foo", root, true},

{testUser + "foo", test, true},
{testUser + "foo", test, false},
{testUser + "foo", root, true},

{testUser, test, true},
{testUser, test, false},
{testUser, root, true},

{"unknown/foo", test, false},
Expand Down
2 changes: 1 addition & 1 deletion kv/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (s *DBServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// Check request user against client certificate user.
if err := authenticationHook(args); err != nil {
if err := authenticationHook(args, false /*not public*/); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down
2 changes: 1 addition & 1 deletion multiraft/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (lt *localRPCTransport) Listen(id proto.RaftNodeID, server ServerInterface)
Stopper: lt.stopper,
DisableCache: true,
})
err := rpcServer.RegisterAsync(raftMessageName,
err := rpcServer.RegisterAsync(raftMessageName, false, /*not public*/
func(argsI gogoproto.Message, callback func(gogoproto.Message, error)) {
protoArgs := argsI.(*proto.RaftMessageRequest)
args := RaftMessageRequest{
Expand Down
4 changes: 2 additions & 2 deletions rpc/codec/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func listenAndServeArithAndEchoService(network, addr string) (net.Addr, error) {
log.Infof("clients.Accept(): %v\n", err)
continue
}
go srv.ServeCodec(NewServerCodec(conn, nil /* no request hook */))
go srv.ServeCodec(NewServerCodec(conn))
}
}()
return clients.Addr(), nil
Expand Down Expand Up @@ -379,7 +379,7 @@ func benchmarkEchoProtoRPC(b *testing.B, size int) {
if *startEchoServer {
l, err := listenAndServeEchoService("tcp", *echoAddr,
func(srv *rpc.Server, conn io.ReadWriteCloser) {
go srv.ServeCodec(NewServerCodec(conn, nil /* no request hook */))
go srv.ServeCodec(NewServerCodec(conn))
})
if err != nil {
b.Fatal("could not start server")
Expand Down
14 changes: 2 additions & 12 deletions rpc/codec/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ type serverCodec struct {

methods []string

// Post body-decoding hook. May be nil in tests.
requestBodyHook func(proto.Message) error

// temporary work space
respBodyBuf bytes.Buffer
respHeaderBuf bytes.Buffer
Expand All @@ -49,14 +46,13 @@ type serverCodec struct {

// NewServerCodec returns a serverCodec that communicates with the ClientCodec
// on the other end of the given conn.
func NewServerCodec(conn io.ReadWriteCloser, requestBodyHook func(proto.Message) error) rpc.ServerCodec {
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
return &serverCodec{
baseConn: baseConn{
r: bufio.NewReader(conn),
w: bufio.NewWriter(conn),
c: conn,
},
requestBodyHook: requestBodyHook,
}
}

Expand Down Expand Up @@ -101,13 +97,7 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error {
return err
}
c.reqHeader.Reset()

if c.requestBodyHook == nil || x == nil {
// Only call requestBodyHook if we are actually decoding a frame
// instead of discarding it.
return nil
}
return c.requestBodyHook(request)
return nil
}

func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error {
Expand Down
2 changes: 1 addition & 1 deletion rpc/send_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestClientNotReady(t *testing.T) {
addr: addr,
methods: map[string]method{},
}
if err := s.Register("Heartbeat.Ping", (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil {
if err := s.RegisterPublic("Heartbeat.Ping", (&Heartbeat{}).Ping, &proto.PingRequest{}); err != nil {
t.Fatal(err)
}
if err := s.Start(); err != nil {
Expand Down
40 changes: 30 additions & 10 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
type method struct {
handler func(proto.Message, func(proto.Message, error))
reqType reflect.Type
public bool
}

type serverResponse struct {
Expand Down Expand Up @@ -94,9 +95,18 @@ func NewServer(addr net.Addr, context *Context) *Server {
// argument of the same type as `reqPrototype`. Both the argument and
// return value of 'handler' should be a pointer to a protocol message
// type. The handler function will be executed in a new goroutine.
func (s *Server) Register(name string, handler func(proto.Message) (proto.Message, error),
// Only system users (root and node) are allowed to use these endpoints.
func (s *Server) Register(name string,
handler func(proto.Message) (proto.Message, error),
reqPrototype proto.Message) error {
return s.RegisterAsync(name, syncAdapter(handler).exec, reqPrototype)
return s.RegisterAsync(name, false /*not public*/, syncAdapter(handler).exec, reqPrototype)
}

// RegisterPublic is similar to Register, but allows non-system users.
func (s *Server) RegisterPublic(name string,
handler func(proto.Message) (proto.Message, error),
reqPrototype proto.Message) error {
return s.RegisterAsync(name, true /*public*/, syncAdapter(handler).exec, reqPrototype)
}

// RegisterAsync registers an asynchronous method handler. Instead of
Expand All @@ -107,7 +117,10 @@ func (s *Server) Register(name string, handler func(proto.Message) (proto.Messag
// channel promptly). However, the fact that they are started in the
// RPC server's goroutine guarantees that the order of requests as
// they were read from the connection is preserved.
func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(proto.Message, error)),
// If 'public' is true, all users may call this method, otherwise
// system users only (root and node).
func (s *Server) RegisterAsync(name string, public bool,
handler func(proto.Message, func(proto.Message, error)),
reqPrototype proto.Message) error {
s.mu.Lock()
defer s.mu.Unlock()
Expand All @@ -124,6 +137,7 @@ func (s *Server) RegisterAsync(name string, handler func(proto.Message, func(pro
s.methods[name] = method{
handler: handler,
reqType: reqType,
public: public,
}
return nil
}
Expand Down Expand Up @@ -171,15 +185,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
security.LogTLSState("RPC", r.TLS)
io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")

codec := codec.NewServerCodec(conn, authHook)
codec := codec.NewServerCodec(conn)
responses := make(chan serverResponse)
var wg sync.WaitGroup
wg.Add(1)
go func() {
s.sendResponses(codec, responses)
wg.Done()
}()
s.readRequests(codec, responses)
s.readRequests(codec, authHook, responses)
wg.Wait()

codec.Close()
Expand Down Expand Up @@ -334,15 +348,15 @@ func (s *Server) Close() {
// when the handler finishes the response is written to the responses
// channel. When the connection is closed (and any pending requests
// have finished), we close the responses channel.
func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResponse) {
func (s *Server) readRequests(codec rpc.ServerCodec, authHook func(proto.Message, bool) error, responses chan<- serverResponse) {
var wg sync.WaitGroup
defer func() {
wg.Wait()
close(responses)
}()

for {
req, meth, args, err := s.readRequest(codec)
req, meth, args, err := s.readRequest(codec, authHook)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnection(err) {
return
Expand Down Expand Up @@ -372,8 +386,8 @@ func (s *Server) readRequests(codec rpc.ServerCodec, responses chan<- serverResp
}

// readRequest reads a single request from a connection.
func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method,
args proto.Message, err error) {
func (s *Server) readRequest(codec rpc.ServerCodec, authHook func(proto.Message, bool) error) (
req rpc.Request, m method, args proto.Message, err error) {
if err = codec.ReadRequestHeader(&req); err != nil {
return
}
Expand All @@ -388,7 +402,13 @@ func (s *Server) readRequest(codec rpc.ServerCodec) (req rpc.Request, m method,
if ok {
args = reflect.New(m.reqType.Elem()).Interface().(proto.Message)
}
err = codec.ReadRequestBody(args)
if err = codec.ReadRequestBody(args); err != nil {
return
}
if args == nil {
return
}
err = authHook(args, m.public)
return
}

Expand Down
4 changes: 2 additions & 2 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ func TestDuplicateRegistration(t *testing.T) {

s := NewServer(util.CreateTestAddr("tcp"), NewNodeTestContext(nil, stopper))
heartbeat := &Heartbeat{}
if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err != nil {
if err := s.RegisterPublic("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err != nil {
t.Fatalf("unexpected failure on first registration: %s", err)
}
if err := s.Register("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err == nil {
if err := s.RegisterPublic("Foo.Bar", heartbeat.Ping, &proto.PingRequest{}); err == nil {
t.Fatalf("unexpected success on second registration")
}
}
Expand Down
10 changes: 7 additions & 3 deletions security/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ func GetCertificateUser(tlsState *tls.ConnectionState) (string, error) {
// AuthenticationHook builds an authentication hook based on the
// security mode and client certificate.
// Must be called at connection time and passed the TLS state.
// Returns a func(proto.Message) error. The passed-in proto must implement
// Returns a func(proto.Message,bool) error. The passed-in proto must implement
// the GetUser interface.
func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
func(request proto.Message) error, error) {
func(request proto.Message, public bool) error, error) {
var certUser string
var err error

Expand All @@ -99,7 +99,7 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
}
}

return func(request proto.Message) error {
return func(request proto.Message, public bool) error {
// userRequest is an interface for RPC requests that have a "requested user".
type userRequest interface {
// GetUser returns the user from the request.
Expand All @@ -119,6 +119,10 @@ func AuthenticationHook(insecureMode bool, tlsState *tls.ConnectionState) (
return util.Errorf("missing User in request: %+v", request)
}

if !public && requestedUser != RootUser && requestedUser != NodeUser {
return util.Errorf("user %s is not allowed", requestedUser)
}

// If running in insecure mode, we have nothing to verify it against.
if insecureMode {
return nil
Expand Down
Loading

0 comments on commit 54f54b4

Please sign in to comment.