diff --git a/internal/rpcserver/rpcserver.go b/internal/rpcserver/rpcserver.go index 88899ac339..c56b972b5c 100644 --- a/internal/rpcserver/rpcserver.go +++ b/internal/rpcserver/rpcserver.go @@ -36,6 +36,7 @@ import ( "unicode/utf8" "github.com/gorilla/websocket" + "github.com/jrick/bitset" "github.com/decred/dcrd/blockchain/stake/v4" "github.com/decred/dcrd/blockchain/standalone/v2" @@ -54,7 +55,6 @@ import ( "github.com/decred/dcrd/txscript/v4/stdaddr" "github.com/decred/dcrd/txscript/v4/stdscript" "github.com/decred/dcrd/wire" - "github.com/jrick/bitset" ) // API version constants @@ -5849,6 +5849,39 @@ func (s *Server) authMAC(dst, auth []byte) []byte { return dst } +// checkAuthMAC checks the HTTP Basic authentication string by comparing +// it with the already generated hash. +// +// The first bool return value signifies auth success (true if successful) and +// the second bool return value specifies whether the user can change the state +// of the server (true) or whether the user is limited (false). +func (s *Server) checkAuthMAC(auth, remoteAddr string) (bool, bool) { + mac := make([]byte, 0, sha256.Size) + mac = s.authMAC(mac, []byte(auth)) + + cmp := subtle.ConstantTimeCompare(mac, s.authsha[:]) + limitcmp := subtle.ConstantTimeCompare(mac, s.limitauthsha[:]) + if cmp|limitcmp == 0 { + // Request's auth doesn't match either user + log.Warnf("RPC authentication failure from %s", remoteAddr) + return false, false + } + return true, cmp == 1 +} + +// checkAuthUserPass checks the correctness of username and password by +// generating the corresponding HTTP Basic authentication string then +// compare the string with the already generated hash. +// +// The first bool return value signifies auth success (true if successful) and +// the second bool return value specifies whether the user can change the state +// of the server (true) or whether the user is limited (false). +func (s *Server) checkAuthUserPass(user, pass, remoteAddr string) (bool, bool) { + login := user + ":" + pass + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) + return s.checkAuthMAC(auth, remoteAddr) +} + // checkAuth checks the HTTP Basic authentication supplied by a wallet or RPC // client in the HTTP request r. If the supplied authentication does not match // the username and password expected, a non-nil error is returned. @@ -5878,19 +5911,11 @@ func (s *Server) checkAuth(r *http.Request, require bool) (bool, bool, error) { return false, false, nil } - mac := make([]byte, 0, sha256.Size) - mac = s.authMAC(mac, []byte(authhdr[0])) - - cmp := subtle.ConstantTimeCompare(mac, s.authsha[:]) - limitcmp := subtle.ConstantTimeCompare(mac, s.limitauthsha[:]) - if cmp|limitcmp == 0 { - // Request's auth doesn't match either user - log.Warnf("RPC authentication failure from %s", r.RemoteAddr) + authed, isAdmin := s.checkAuthMAC(authhdr[0], r.RemoteAddr) + if !authed { return false, false, errors.New("auth failure") } - - isAdmin := cmp == 1 - return true, isAdmin, nil + return authed, isAdmin, nil } // parsedRPCCmd represents a JSON-RPC request object that has been parsed into diff --git a/internal/rpcserver/rpcserver_test.go b/internal/rpcserver/rpcserver_test.go index 588af08b99..7f5f350966 100644 --- a/internal/rpcserver/rpcserver_test.go +++ b/internal/rpcserver/rpcserver_test.go @@ -14,6 +14,7 @@ import ( "bytes" "context" "fmt" + "net/http" "runtime/debug" "testing" @@ -170,3 +171,129 @@ func TestRpcServer(t *testing.T) { currentTestNum++ } } + +func TestCheckAuthUserPass(t *testing.T) { + s, err := New(&Config{ + RPCUser: "user", + RPCPass: "pass", + RPCLimitUser: "limit", + RPCLimitPass: "limit", + }) + if err != nil { + t.Fatalf("unable to create RPC server: %v", err) + } + tests := []struct { + name string + user string + pass string + wantAuthed bool + wantAdmin bool + }{ + { + name: "correct admin", + user: "user", + pass: "pass", + wantAuthed: true, + wantAdmin: true, + }, + { + name: "correct limited user", + user: "limit", + pass: "limit", + wantAuthed: true, + wantAdmin: false, + }, + { + name: "invalid admin", + user: "user", + pass: "p", + wantAuthed: false, + wantAdmin: false, + }, + { + name: "invalid limited user", + user: "limit", + pass: "", + wantAuthed: false, + wantAdmin: false, + }, + { + name: "invalid empty user", + user: "", + pass: "", + wantAuthed: false, + wantAdmin: false, + }, + } + for _, test := range tests { + authed, isAdmin := s.checkAuthUserPass(test.user, test.pass, "addr") + if authed != test.wantAuthed { + t.Errorf("%q: unexpected authed -- got %v, want %v", test.name, authed, + test.wantAuthed) + } + if isAdmin != test.wantAdmin { + t.Errorf("%q: unexpected isAdmin -- got %v, want %v", test.name, isAdmin, + test.wantAdmin) + } + } +} + +func TestCheckAuth(t *testing.T) { + { + s, err := New(&Config{}) + if err != nil { + t.Fatalf("unable to create RPC server: %v", err) + } + for i := 0; i <= 1; i++ { + authed, isAdmin, err := s.checkAuth(&http.Request{}, i == 0) + if !authed { + t.Errorf(" unexpected authed -- got %v, want %v", authed, true) + } + if !isAdmin { + t.Errorf("unexpected isAdmin -- got %v, want %v", isAdmin, true) + } + if err != nil { + t.Errorf("unexpected err -- got %v, want %v", err, nil) + } + } + } + { + s, err := New(&Config{ + RPCUser: "user", + RPCPass: "pass", + RPCLimitUser: "limit", + RPCLimitPass: "limit", + }) + if err != nil { + t.Fatalf("unable to create RPC server: %v", err) + } + for i := 0; i <= 1; i++ { + authed, isAdmin, err := s.checkAuth(&http.Request{}, i == 0) + if authed { + t.Errorf(" unexpected authed -- got %v, want %v", authed, false) + } + if isAdmin { + t.Errorf("unexpected isAdmin -- got %v, want %v", isAdmin, false) + } + if i == 0 && err == nil { + t.Errorf("unexpected err -- got %v, want auth failure", err) + } else if i != 0 && err != nil { + t.Errorf("unexpected err -- got %v, want ", err) + } + } + for i := 0; i <= 1; i++ { + r := &http.Request{Header: make(map[string][]string, 1)} + r.Header["Authorization"] = []string{"Basic Nothing"} + authed, isAdmin, err := s.checkAuth(r, i == 0) + if authed { + t.Errorf(" unexpected authed -- got %v, want %v", authed, false) + } + if isAdmin { + t.Errorf("unexpected isAdmin -- got %v, want %v", isAdmin, false) + } + if err == nil { + t.Errorf("unexpected err -- got %v, want auth failure", err) + } + } + } +} diff --git a/internal/rpcserver/rpcwebsocket.go b/internal/rpcserver/rpcwebsocket.go index e89ab88dfc..121e34eacc 100644 --- a/internal/rpcserver/rpcwebsocket.go +++ b/internal/rpcserver/rpcwebsocket.go @@ -8,9 +8,6 @@ package rpcserver import ( "bytes" "context" - "crypto/sha256" - "crypto/subtle" - "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -1466,18 +1463,11 @@ out: break out case !c.authenticated: // Check credentials. - login := authCmd.Username + ":" + authCmd.Passphrase - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) - mac := make([]byte, 0, sha256.Size) - mac = c.rpcServer.authMAC(mac, []byte(auth)) - cmp := subtle.ConstantTimeCompare(mac, c.rpcServer.authsha[:]) - limitcmp := subtle.ConstantTimeCompare(mac, c.rpcServer.limitauthsha[:]) - if cmp|limitcmp != 0 { - log.Warnf("Auth failure.") + c.authenticated, c.isAdmin = c.rpcServer.checkAuthUserPass( + authCmd.Username, authCmd.Passphrase, c.addr) + if !c.authenticated { break out } - c.authenticated = true - c.isAdmin = cmp == 1 // Increase the read limits for authenticated connections. c.conn.SetReadLimit(websocketReadLimitAuthenticated) @@ -1682,19 +1672,12 @@ out: break out case !c.authenticated: // Check credentials. - login := authCmd.Username + ":" + authCmd.Passphrase - auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) - authSha := sha256.Sum256([]byte(auth)) - cmp := subtle.ConstantTimeCompare(authSha[:], c.rpcServer.authsha[:]) - limitcmp := subtle.ConstantTimeCompare(authSha[:], c.rpcServer.limitauthsha[:]) - if cmp != 1 && limitcmp != 1 { - log.Warnf("Auth failure.") + c.authenticated, c.isAdmin = c.rpcServer.checkAuthUserPass( + authCmd.Username, authCmd.Passphrase, c.addr) + if !c.authenticated { break out } - c.authenticated = true - c.isAdmin = cmp == 1 - // Marshal and send response. reply, err = createMarshalledReply(cmd.jsonrpc, cmd.id, nil, nil) if err != nil {