From c405aff310ebcdce6a3a8b4c0f5245099c8601c5 Mon Sep 17 00:00:00 2001 From: Michal Budzyn Date: Sun, 14 Oct 2018 01:18:19 +0200 Subject: [PATCH] Local auth OAUTHBEARER --- proxy/client.go | 8 +-- proxy/processor_default.go | 9 +-- proxy/sasl.go | 3 +- proxy/sasl_local.go | 112 ++++++++++++++++----------------- proxy/sasl_local_auth.go | 68 ++++++++++++++++++++ proxy/sasl_oauthbearer.go | 75 ++++++++++++++++++++++ proxy/sasl_oauthbearer_test.go | 19 ++++++ 7 files changed, 229 insertions(+), 65 deletions(-) create mode 100644 proxy/sasl_local_auth.go create mode 100644 proxy/sasl_oauthbearer.go create mode 100644 proxy/sasl_oauthbearer_test.go diff --git a/proxy/client.go b/proxy/client.go index cc676ec1..312d7689 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -93,10 +93,10 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne ResponseBufferSize: c.Proxy.ResponseBufferSize, ReadTimeout: c.Kafka.ReadTimeout, WriteTimeout: c.Kafka.WriteTimeout, - LocalSasl: &LocalSasl{ - enabled: c.Auth.Local.Enable, - timeout: c.Auth.Local.Timeout, - localAuthenticator: passwordAuthenticator}, + LocalSasl: NewLocalSasl( + c.Auth.Local.Enable, + c.Auth.Local.Timeout, + passwordAuthenticator), AuthServer: &AuthServer{ enabled: c.Auth.Gateway.Server.Enable, magic: c.Auth.Gateway.Server.Magic, diff --git a/proxy/processor_default.go b/proxy/processor_default.go index 756b4600..5bfc4838 100644 --- a/proxy/processor_default.go +++ b/proxy/processor_default.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/grepplabs/kafka-proxy/proxy/protocol" + "github.com/sirupsen/logrus" "io" "strconv" "time" @@ -32,7 +33,7 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil { return true, err } - //logrus.Printf("Kafka request length %v, key %v, version %v", requestKeyVersion.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion) + logrus.Debugf("Kafka request key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, requestKeyVersion.Length) if requestKeyVersion.ApiKey < minRequestApiKey || requestKeyVersion.ApiKey > maxRequestApiKey { return true, fmt.Errorf("api key %d is invalid", requestKeyVersion.ApiKey) @@ -55,11 +56,11 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead case apiKeySaslHandshake: switch requestKeyVersion.ApiVersion { case 0: - if err = ctx.localSasl.receiveAndSendSASLPlainAuthV0(src, keyVersionBuf); err != nil { + if err = ctx.localSasl.receiveAndSendSASLAuthV0(src, keyVersionBuf); err != nil { return true, err } case 1: - if err = ctx.localSasl.receiveAndSendSASLPlainAuthV1(src, keyVersionBuf); err != nil { + if err = ctx.localSasl.receiveAndSendSASLAuthV1(src, keyVersionBuf); err != nil { return true, err } default: @@ -132,7 +133,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De return true, err } proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4)) - //logrus.Printf("Kafka response lenght %v for key %v, version %v", responseHeader.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion) + logrus.Debugf("Kafka response key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, responseHeader.Length) responseDeadline := time.Now().Add(ctx.timeout) err = dst.SetWriteDeadline(responseDeadline) diff --git a/proxy/sasl.go b/proxy/sasl.go index 6b5768b1..17b210e5 100644 --- a/proxy/sasl.go +++ b/proxy/sasl.go @@ -11,7 +11,8 @@ import ( ) const ( - SASLPlain = "PLAIN" + SASLPlain = "PLAIN" + SASLOAuthBearer = "OAUTHBEARER" ) type SASLPlainAuth struct { diff --git a/proxy/sasl_local.go b/proxy/sasl_local.go index e10c2445..e0300ae7 100644 --- a/proxy/sasl_local.go +++ b/proxy/sasl_local.go @@ -8,98 +8,116 @@ import ( "github.com/grepplabs/kafka-proxy/pkg/apis" "github.com/grepplabs/kafka-proxy/proxy/protocol" "io" - "strconv" - "strings" "time" ) type LocalSasl struct { - enabled bool - timeout time.Duration - localAuthenticator apis.PasswordAuthenticator + enabled bool + timeout time.Duration + localAuthenticators map[string]LocalSaslAuth } -func (p *LocalSasl) receiveAndSendSASLPlainAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { - if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil { +func NewLocalSasl(enabled bool, timeout time.Duration, passwordAuthenticator apis.PasswordAuthenticator) *LocalSasl { + localAuthenticators := make(map[string]LocalSaslAuth) + if passwordAuthenticator != nil { + localAuthenticators[SASLPlain] = NewLocalSaslPlain(passwordAuthenticator) + } + localAuthenticators[SASLOAuthBearer] = NewLocalSaslOauth() + return &LocalSasl{ + enabled: enabled, + timeout: timeout, + localAuthenticators: localAuthenticators, + } +} + +func (p *LocalSasl) receiveAndSendSASLAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { + var localSaslAuth LocalSaslAuth + if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil { return err } - if err = p.receiveAndSendAuthV1(conn); err != nil { + if err = p.receiveAndSendAuthV1(conn, localSaslAuth); err != nil { return err } return nil } -func (p *LocalSasl) receiveAndSendSASLPlainAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { - if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil { +func (p *LocalSasl) receiveAndSendSASLAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { + var localSaslAuth LocalSaslAuth + if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil { return err } - if err = p.receiveAndSendAuthV0(conn); err != nil { + if err = p.receiveAndSendAuthV0(conn, localSaslAuth); err != nil { return err } return nil } -func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (err error) { +func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (localSaslAuth LocalSaslAuth, err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { - return err + return nil, err } if len(keyVersionBuf) != 8 { - return errors.New("length of keyVersionBuf should be 8") + return nil, errors.New("length of keyVersionBuf should be 8") } // keyVersionBuf has already been read from connection requestKeyVersion := &protocol.RequestKeyVersion{} if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil { - return err + return nil, err } if !(requestKeyVersion.ApiKey == 17 && requestKeyVersion.ApiVersion == version) { - return fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion) + return nil, fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion) } if int32(requestKeyVersion.Length) > protocol.MaxRequestSize { - return protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)} + return nil, protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)} } resp := make([]byte, int(requestKeyVersion.Length-4)) if _, err = io.ReadFull(conn, resp); err != nil { - return err + return nil, err } payload := bytes.Join([][]byte{keyVersionBuf[4:], resp}, nil) saslReqV0orV1 := &protocol.SaslHandshakeRequestV0orV1{Version: version} req := &protocol.Request{Body: saslReqV0orV1} if err = protocol.Decode(payload, req); err != nil { - return err + return nil, err } var saslResult error saslErr := protocol.ErrNoError - if saslReqV0orV1.Mechanism != SASLPlain { - saslResult = fmt.Errorf("PLAIN mechanism expected, but got %s", saslReqV0orV1.Mechanism) + localSaslAuth = p.localAuthenticators[saslReqV0orV1.Mechanism] + if localSaslAuth == nil { + mechanisms := make([]string, 0) + for mechanism := range p.localAuthenticators { + mechanisms = append(mechanisms, mechanism) + } + saslResult = fmt.Errorf("PLAIN or OAUTHBEARER mechanism expected, %v are configured, but got %s", mechanisms, saslReqV0orV1.Mechanism) saslErr = protocol.ErrUnsupportedSASLMechanism } - saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{SASLPlain}} + saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{saslReqV0orV1.Mechanism}} newResponseBuf, err := protocol.Encode(saslResV0) if err != nil { - return err + return nil, err } newHeaderBuf, err := protocol.Encode(&protocol.ResponseHeader{Length: int32(len(newResponseBuf) + 4), CorrelationID: req.CorrelationID}) if err != nil { - return err + return nil, err } if _, err := conn.Write(newHeaderBuf); err != nil { - return err + return nil, err } if _, err := conn.Write(newResponseBuf); err != nil { - return err + return nil, err } - return saslResult + return localSaslAuth, saslResult } -func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) { +func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { @@ -134,14 +152,15 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) return err } - authErr := p.doLocalAuth(saslAuthReqV0.SaslAuthBytes) + authErr := localSaslAuth.doLocalAuth(saslAuthReqV0.SaslAuthBytes) var saslAuthResV0 *protocol.SaslAuthenticateResponseV0 if authErr == nil { - saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 4)} + // Length of SaslAuthBytes !=0 for OAUTHBEARER causes that java SaslClientAuthenticator in INTERMEDIATE state will sent SaslAuthenticate(36) second time + saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 0)} } else { errMsg := authErr.Error() - saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 4)} + saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 0)} } newResponseBuf, err := protocol.Encode(saslAuthResV0) @@ -162,7 +181,7 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) } -func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) { +func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { @@ -185,7 +204,11 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) return err } - if err = p.doLocalAuth(saslAuthBytes); err != nil { + if localSaslAuth == nil { + return errors.New("localSaslAuth is nil") + } + + if err = localSaslAuth.doLocalAuth(saslAuthBytes); err != nil { return err } // If the credentials are valid, we would write a 4 byte response filled with null characters. @@ -196,26 +219,3 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) } return nil } - -func (p *LocalSasl) doLocalAuth(saslAuthBytes []byte) (err error) { - tokens := strings.Split(string(saslAuthBytes), "\x00") - if len(tokens) != 3 { - return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens)) - } - if p.localAuthenticator == nil { - return protocol.PacketDecodingError{Info: "Listener authenticator is not set"} - } - - // logrus.Infof("user: %s , password: %s", tokens[1], tokens[2]) - ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2]) - if err != nil { - proxyLocalAuthTotal.WithLabelValues("error", "1").Inc() - return err - } - proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc() - - if !ok { - return fmt.Errorf("user %s authentication failed", tokens[1]) - } - return nil -} diff --git a/proxy/sasl_local_auth.go b/proxy/sasl_local_auth.go new file mode 100644 index 00000000..c31ae531 --- /dev/null +++ b/proxy/sasl_local_auth.go @@ -0,0 +1,68 @@ +package proxy + +import ( + "fmt" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/proxy/protocol" + "strconv" + "strings" +) + +type LocalSaslAuth interface { + doLocalAuth(saslAuthBytes []byte) (err error) +} + +type LocalSaslPlain struct { + localAuthenticator apis.PasswordAuthenticator +} + +func NewLocalSaslPlain(localAuthenticator apis.PasswordAuthenticator) *LocalSaslPlain { + return &LocalSaslPlain{ + localAuthenticator: localAuthenticator, + } +} + +// implements LocalSaslAuth +func (p *LocalSaslPlain) doLocalAuth(saslAuthBytes []byte) (err error) { + tokens := strings.Split(string(saslAuthBytes), "\x00") + if len(tokens) != 3 { + return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens)) + } + if p.localAuthenticator == nil { + return protocol.PacketDecodingError{Info: "Listener authenticator is not set"} + } + + // logrus.Infof("user: %s , password: %s", tokens[1], tokens[2]) + ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2]) + if err != nil { + proxyLocalAuthTotal.WithLabelValues("error", "1").Inc() + return err + } + proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc() + + if !ok { + return fmt.Errorf("user %s authentication failed", tokens[1]) + } + return nil +} + +type LocalSaslOauth struct { + saslOAuthBearer SaslOAuthBearer +} + +func NewLocalSaslOauth() *LocalSaslOauth { + return &LocalSaslOauth{ + saslOAuthBearer: SaslOAuthBearer{}, + } +} + +// implements LocalSaslAuth +func (p *LocalSaslOauth) doLocalAuth(saslAuthBytes []byte) (err error) { + token, err := p.saslOAuthBearer.GetToken(saslAuthBytes) + if err != nil { + return err + } + //TODO: implement TokenAuthenticator + _ = token + return nil +} diff --git a/proxy/sasl_oauthbearer.go b/proxy/sasl_oauthbearer.go new file mode 100644 index 00000000..327edf02 --- /dev/null +++ b/proxy/sasl_oauthbearer.go @@ -0,0 +1,75 @@ +package proxy + +import ( + "fmt" + "github.com/pkg/errors" + "regexp" + "strings" +) + +// https://tools.ietf.org/html/rfc7628#section-3.1 +// https://tools.ietf.org/html/rfc5801#section-4 +const ( + saslOauthSeparator = "\u0001" + saslOauthSaslName = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+" + saslOauthKey = "[A-Za-z]+" + saslOauthValue = "[\\x21-\\x7E \t\r\n]+" + saslOauthAuthKey = "auth" +) + +var ( + saslOauthKVPairs = fmt.Sprintf("(%s=%s%s)*", saslOauthKey, saslOauthValue, saslOauthSeparator) + saslOauthAuthPattern = regexp.MustCompile("(?P[\\w]+)[ ]+(?P[-_.a-zA-Z0-9]+)") + saslOauthClientInitialResponsePattern = regexp.MustCompile(fmt.Sprintf("n,(a=(?P%s))?,%s(?P%s)%s", saslOauthSaslName, saslOauthSeparator, saslOauthKVPairs, saslOauthSeparator)) +) + +type SaslOAuthBearer struct{} + +func (p SaslOAuthBearer) GetToken(saslAuthBytes []byte) (string, error) { + match := saslOauthClientInitialResponsePattern.FindSubmatch(saslAuthBytes) + + result := make(map[string][]byte) + for i, name := range saslOauthClientInitialResponsePattern.SubexpNames() { + if i != 0 && name != "" { + result[name] = match[i] + } + } + kvpairs := result["kvpairs"] + properties := p.parseMap(string(kvpairs), "=", saslOauthSeparator) + return p.parseToken(properties[saslOauthAuthKey]) +} + +func (SaslOAuthBearer) parseToken(auth string) (string, error) { + if auth == "" { + return "", errors.New("invalid OAUTHBEARER initial client response: 'auth' not specified") + } + match := saslOauthAuthPattern.FindStringSubmatch(auth) + result := make(map[string]string) + for i, name := range saslOauthAuthPattern.SubexpNames() { + if i != 0 && name != "" { + result[name] = match[i] + } + } + if !strings.EqualFold(result["scheme"], "bearer") { + return "", fmt.Errorf("invalid scheme in OAUTHBEARER initial client response: %s", result["scheme"]) + } + token := result["token"] + if token == "" { + return "", errors.New("invalid OAUTHBEARER initial client response: 'token' is missing") + } + return token, nil +} + +func (SaslOAuthBearer) parseMap(mapStr string, keyValueSeparator string, elementSeparator string) map[string]string { + result := make(map[string]string) + if mapStr == "" { + return result + } + for _, attrval := range strings.Split(mapStr, elementSeparator) { + kv := strings.SplitN(attrval, keyValueSeparator, 2) + if len(kv) == 2 { + result[kv[0]] = kv[1] + } + } + return result +} diff --git a/proxy/sasl_oauthbearer_test.go b/proxy/sasl_oauthbearer_test.go new file mode 100644 index 00000000..52710c4b --- /dev/null +++ b/proxy/sasl_oauthbearer_test.go @@ -0,0 +1,19 @@ +package proxy + +import ( + "encoding/hex" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSaslOauthParseToken(t *testing.T) { + a := assert.New(t) + + saslBytes := "6e2c2c01617574683d4265617265722065794a68624763694f694a756232356c496e302e65794a6c654841694f6a45754e544d354e5445324e6a6b304e44453452546b73496d6c68644349364d5334314d7a6b314d544d774f5451304d5468464f53776963335669496a6f695957787059325579496e302e0101" + saslAuthBytes, err := hex.DecodeString(saslBytes) + a.Nil(err) + + token, err := SaslOAuthBearer{}.GetToken(saslAuthBytes) + a.Nil(err) + a.Equal("eyJhbGciOiJub25lIn0.eyJleHAiOjEuNTM5NTE2Njk0NDE4RTksImlhdCI6MS41Mzk1MTMwOTQ0MThFOSwic3ViIjoiYWxpY2UyIn0.", token) +}