Skip to content

Commit

Permalink
Local auth OAUTHBEARER
Browse files Browse the repository at this point in the history
  • Loading branch information
everesio committed Oct 20, 2018
1 parent f3f0b40 commit c405aff
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 65 deletions.
8 changes: 4 additions & 4 deletions proxy/client.go
Expand Up @@ -93,10 +93,10 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne
ResponseBufferSize: c.Proxy.ResponseBufferSize,
ReadTimeout: c.Kafka.ReadTimeout,
WriteTimeout: c.Kafka.WriteTimeout,
LocalSasl: &LocalSasl{
enabled: c.Auth.Local.Enable,
timeout: c.Auth.Local.Timeout,
localAuthenticator: passwordAuthenticator},
LocalSasl: NewLocalSasl(
c.Auth.Local.Enable,
c.Auth.Local.Timeout,
passwordAuthenticator),
AuthServer: &AuthServer{
enabled: c.Auth.Gateway.Server.Enable,
magic: c.Auth.Gateway.Server.Magic,
Expand Down
9 changes: 5 additions & 4 deletions proxy/processor_default.go
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"github.com/sirupsen/logrus"
"io"
"strconv"
"time"
Expand Down Expand Up @@ -32,7 +33,7 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil {
return true, err
}
//logrus.Printf("Kafka request length %v, key %v, version %v", requestKeyVersion.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
logrus.Debugf("Kafka request key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, requestKeyVersion.Length)

if requestKeyVersion.ApiKey < minRequestApiKey || requestKeyVersion.ApiKey > maxRequestApiKey {
return true, fmt.Errorf("api key %d is invalid", requestKeyVersion.ApiKey)
Expand All @@ -55,11 +56,11 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
case apiKeySaslHandshake:
switch requestKeyVersion.ApiVersion {
case 0:
if err = ctx.localSasl.receiveAndSendSASLPlainAuthV0(src, keyVersionBuf); err != nil {
if err = ctx.localSasl.receiveAndSendSASLAuthV0(src, keyVersionBuf); err != nil {
return true, err
}
case 1:
if err = ctx.localSasl.receiveAndSendSASLPlainAuthV1(src, keyVersionBuf); err != nil {
if err = ctx.localSasl.receiveAndSendSASLAuthV1(src, keyVersionBuf); err != nil {
return true, err
}
default:
Expand Down Expand Up @@ -132,7 +133,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De
return true, err
}
proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4))
//logrus.Printf("Kafka response lenght %v for key %v, version %v", responseHeader.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
logrus.Debugf("Kafka response key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, responseHeader.Length)

responseDeadline := time.Now().Add(ctx.timeout)
err = dst.SetWriteDeadline(responseDeadline)
Expand Down
3 changes: 2 additions & 1 deletion proxy/sasl.go
Expand Up @@ -11,7 +11,8 @@ import (
)

const (
SASLPlain = "PLAIN"
SASLPlain = "PLAIN"
SASLOAuthBearer = "OAUTHBEARER"
)

type SASLPlainAuth struct {
Expand Down
112 changes: 56 additions & 56 deletions proxy/sasl_local.go
Expand Up @@ -8,98 +8,116 @@ import (
"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"io"
"strconv"
"strings"
"time"
)

type LocalSasl struct {
enabled bool
timeout time.Duration
localAuthenticator apis.PasswordAuthenticator
enabled bool
timeout time.Duration
localAuthenticators map[string]LocalSaslAuth
}

func (p *LocalSasl) receiveAndSendSASLPlainAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) {
if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil {
func NewLocalSasl(enabled bool, timeout time.Duration, passwordAuthenticator apis.PasswordAuthenticator) *LocalSasl {
localAuthenticators := make(map[string]LocalSaslAuth)
if passwordAuthenticator != nil {
localAuthenticators[SASLPlain] = NewLocalSaslPlain(passwordAuthenticator)
}
localAuthenticators[SASLOAuthBearer] = NewLocalSaslOauth()
return &LocalSasl{
enabled: enabled,
timeout: timeout,
localAuthenticators: localAuthenticators,
}
}

func (p *LocalSasl) receiveAndSendSASLAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) {
var localSaslAuth LocalSaslAuth
if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil {
return err
}
if err = p.receiveAndSendAuthV1(conn); err != nil {
if err = p.receiveAndSendAuthV1(conn, localSaslAuth); err != nil {
return err
}
return nil
}

func (p *LocalSasl) receiveAndSendSASLPlainAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) {
if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil {
func (p *LocalSasl) receiveAndSendSASLAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) {
var localSaslAuth LocalSaslAuth
if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil {
return err
}
if err = p.receiveAndSendAuthV0(conn); err != nil {
if err = p.receiveAndSendAuthV0(conn, localSaslAuth); err != nil {
return err
}
return nil
}

func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (err error) {
func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (localSaslAuth LocalSaslAuth, err error) {
requestDeadline := time.Now().Add(p.timeout)
err = conn.SetDeadline(requestDeadline)
if err != nil {
return err
return nil, err
}

if len(keyVersionBuf) != 8 {
return errors.New("length of keyVersionBuf should be 8")
return nil, errors.New("length of keyVersionBuf should be 8")
}
// keyVersionBuf has already been read from connection
requestKeyVersion := &protocol.RequestKeyVersion{}
if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil {
return err
return nil, err
}
if !(requestKeyVersion.ApiKey == 17 && requestKeyVersion.ApiVersion == version) {
return fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion)
return nil, fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion)
}

if int32(requestKeyVersion.Length) > protocol.MaxRequestSize {
return protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)}
return nil, protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)}
}

resp := make([]byte, int(requestKeyVersion.Length-4))
if _, err = io.ReadFull(conn, resp); err != nil {
return err
return nil, err
}
payload := bytes.Join([][]byte{keyVersionBuf[4:], resp}, nil)

saslReqV0orV1 := &protocol.SaslHandshakeRequestV0orV1{Version: version}
req := &protocol.Request{Body: saslReqV0orV1}
if err = protocol.Decode(payload, req); err != nil {
return err
return nil, err
}

var saslResult error
saslErr := protocol.ErrNoError
if saslReqV0orV1.Mechanism != SASLPlain {
saslResult = fmt.Errorf("PLAIN mechanism expected, but got %s", saslReqV0orV1.Mechanism)
localSaslAuth = p.localAuthenticators[saslReqV0orV1.Mechanism]
if localSaslAuth == nil {
mechanisms := make([]string, 0)
for mechanism := range p.localAuthenticators {
mechanisms = append(mechanisms, mechanism)
}
saslResult = fmt.Errorf("PLAIN or OAUTHBEARER mechanism expected, %v are configured, but got %s", mechanisms, saslReqV0orV1.Mechanism)
saslErr = protocol.ErrUnsupportedSASLMechanism
}

saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{SASLPlain}}
saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{saslReqV0orV1.Mechanism}}
newResponseBuf, err := protocol.Encode(saslResV0)
if err != nil {
return err
return nil, err
}
newHeaderBuf, err := protocol.Encode(&protocol.ResponseHeader{Length: int32(len(newResponseBuf) + 4), CorrelationID: req.CorrelationID})
if err != nil {
return err
return nil, err
}
if _, err := conn.Write(newHeaderBuf); err != nil {
return err
return nil, err
}
if _, err := conn.Write(newResponseBuf); err != nil {
return err
return nil, err
}
return saslResult
return localSaslAuth, saslResult
}

func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) {
func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) {
requestDeadline := time.Now().Add(p.timeout)
err = conn.SetDeadline(requestDeadline)
if err != nil {
Expand Down Expand Up @@ -134,14 +152,15 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error)
return err
}

authErr := p.doLocalAuth(saslAuthReqV0.SaslAuthBytes)
authErr := localSaslAuth.doLocalAuth(saslAuthReqV0.SaslAuthBytes)

var saslAuthResV0 *protocol.SaslAuthenticateResponseV0
if authErr == nil {
saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 4)}
// Length of SaslAuthBytes !=0 for OAUTHBEARER causes that java SaslClientAuthenticator in INTERMEDIATE state will sent SaslAuthenticate(36) second time
saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 0)}
} else {
errMsg := authErr.Error()
saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 4)}
saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 0)}
}

newResponseBuf, err := protocol.Encode(saslAuthResV0)
Expand All @@ -162,7 +181,7 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error)

}

func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) {
func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) {
requestDeadline := time.Now().Add(p.timeout)
err = conn.SetDeadline(requestDeadline)
if err != nil {
Expand All @@ -185,7 +204,11 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error)
return err
}

if err = p.doLocalAuth(saslAuthBytes); err != nil {
if localSaslAuth == nil {
return errors.New("localSaslAuth is nil")
}

if err = localSaslAuth.doLocalAuth(saslAuthBytes); err != nil {
return err
}
// If the credentials are valid, we would write a 4 byte response filled with null characters.
Expand All @@ -196,26 +219,3 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error)
}
return nil
}

func (p *LocalSasl) doLocalAuth(saslAuthBytes []byte) (err error) {
tokens := strings.Split(string(saslAuthBytes), "\x00")
if len(tokens) != 3 {
return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens))
}
if p.localAuthenticator == nil {
return protocol.PacketDecodingError{Info: "Listener authenticator is not set"}
}

// logrus.Infof("user: %s , password: %s", tokens[1], tokens[2])
ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2])
if err != nil {
proxyLocalAuthTotal.WithLabelValues("error", "1").Inc()
return err
}
proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc()

if !ok {
return fmt.Errorf("user %s authentication failed", tokens[1])
}
return nil
}
68 changes: 68 additions & 0 deletions proxy/sasl_local_auth.go
@@ -0,0 +1,68 @@
package proxy

import (
"fmt"
"github.com/grepplabs/kafka-proxy/pkg/apis"
"github.com/grepplabs/kafka-proxy/proxy/protocol"
"strconv"
"strings"
)

type LocalSaslAuth interface {
doLocalAuth(saslAuthBytes []byte) (err error)
}

type LocalSaslPlain struct {
localAuthenticator apis.PasswordAuthenticator
}

func NewLocalSaslPlain(localAuthenticator apis.PasswordAuthenticator) *LocalSaslPlain {
return &LocalSaslPlain{
localAuthenticator: localAuthenticator,
}
}

// implements LocalSaslAuth
func (p *LocalSaslPlain) doLocalAuth(saslAuthBytes []byte) (err error) {
tokens := strings.Split(string(saslAuthBytes), "\x00")
if len(tokens) != 3 {
return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens))
}
if p.localAuthenticator == nil {
return protocol.PacketDecodingError{Info: "Listener authenticator is not set"}
}

// logrus.Infof("user: %s , password: %s", tokens[1], tokens[2])
ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2])
if err != nil {
proxyLocalAuthTotal.WithLabelValues("error", "1").Inc()
return err
}
proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc()

if !ok {
return fmt.Errorf("user %s authentication failed", tokens[1])
}
return nil
}

type LocalSaslOauth struct {
saslOAuthBearer SaslOAuthBearer
}

func NewLocalSaslOauth() *LocalSaslOauth {
return &LocalSaslOauth{
saslOAuthBearer: SaslOAuthBearer{},
}
}

// implements LocalSaslAuth
func (p *LocalSaslOauth) doLocalAuth(saslAuthBytes []byte) (err error) {
token, err := p.saslOAuthBearer.GetToken(saslAuthBytes)
if err != nil {
return err
}
//TODO: implement TokenAuthenticator
_ = token
return nil
}

0 comments on commit c405aff

Please sign in to comment.