diff --git a/.travis.yml b/.travis.yml index 21baac83..f4da0d1e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ sudo: false language: go go: - - "1.10.x" + - "1.11.x" env: global: diff --git a/Dockerfile.build b/Dockerfile.build index 64472c95..4e61dbe2 100644 --- a/Dockerfile.build +++ b/Dockerfile.build @@ -1,4 +1,4 @@ -FROM golang:1.10 as builder +FROM golang:1.11 as builder ARG GOOS=linux ARG GOARCH=amd64 diff --git a/Makefile b/Makefile index a1cdedba..0e9fb393 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ VERSION ?= $(shell git describe --tags --always --dirty) GOPKGS = $(shell go list ./... | grep -v /vendor/) BUILD_FLAGS ?= LDFLAGS ?= -X github.com/grepplabs/kafka-proxy/config.Version=$(VERSION) -w -s -TAG ?= "v0.0.8" +TAG ?= "v0.1.0" GOARCH ?= amd64 GOOS ?= linux @@ -46,11 +46,11 @@ release: clean protoc.local-auth: protoc -I plugin/local-auth/proto/ plugin/local-auth/proto/auth.proto --go_out=plugins=grpc:plugin/local-auth/proto/ -protoc.gateway-client: - protoc -I plugin/gateway-client/proto/ plugin/gateway-client/proto/token-provider.proto --go_out=plugins=grpc:plugin/gateway-client/proto/ +protoc.token-provider: + protoc -I plugin/token-provider/proto/ plugin/token-provider/proto/token-provider.proto --go_out=plugins=grpc:plugin/token-provider/proto/ -protoc.gateway-server: - protoc -I plugin/gateway-server/proto/ plugin/gateway-server/proto/token-info.proto --go_out=plugins=grpc:plugin/gateway-server/proto/ +protoc.token-info: + protoc -I plugin/token-info/proto/ plugin/token-info/proto/token-info.proto --go_out=plugins=grpc:plugin/token-info/proto/ plugin.auth-user: CGO_ENABLED=0 go build -o build/auth-user $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-auth-user/main.go @@ -64,7 +64,14 @@ plugin.google-id-provider: plugin.google-id-info: CGO_ENABLED=0 go build -o build/google-id-info $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-googleid-info/main.go -all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info +plugin.unsecured-jwt-info: + CGO_ENABLED=0 go build -o build/unsecured-jwt-info $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-unsecured-jwt-info/main.go + +plugin.unsecured-jwt-provider: + CGO_ENABLED=0 go build -o build/unsecured-jwt-provider $(BUILD_FLAGS) -ldflags "$(LDFLAGS)" cmd/plugin-unsecured-jwt-provider/main.go + + +all: build plugin.auth-user plugin.auth-ldap plugin.google-id-provider plugin.google-id-info plugin.unsecured-jwt-info plugin.unsecured-jwt-provider clean: @rm -rf build diff --git a/README.md b/README.md index fd886346..bd3b5cbd 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,11 @@ See: Linux - curl -Ls https://github.com/grepplabs/kafka-proxy/releases/download/v0.0.8/kafka-proxy_0.0.8_linux_amd64.tar.gz | tar xz + curl -Ls https://github.com/grepplabs/kafka-proxy/releases/download/v0.1.0/kafka-proxy_0.1.0_linux_amd64.tar.gz | tar xz macOS - curl -Ls https://github.com/grepplabs/kafka-proxy/releases/download/v0.0.8/kafka-proxy_0.0.8_darwin_amd64.tar.gz | tar xz + curl -Ls https://github.com/grepplabs/kafka-proxy/releases/download/v0.1.0/kafka-proxy_0.1.0_darwin_amd64.tar.gz | tar xz 2. Move the binary in to your PATH. @@ -75,6 +75,7 @@ See: --auth-local-command string Path to authentication plugin binary --auth-local-enable Enable local SASL/PLAIN authentication performed by listener - SASL handshake will not be passed to kafka brokers --auth-local-log-level string Log level of the auth plugin (default "trace") + --auth-local-mechanism string SASL mechanism used for local authentication: PLAIN or OAUTHBEARER (default "PLAIN") --auth-local-param stringArray Authentication plugin parameter --auth-local-timeout duration Authentication timeout (default 10s) --bootstrap-server-mapping stringArray Mapping of Kafka bootstrap server address to local address (host:port,host:port(,advhost:advport)) @@ -84,7 +85,7 @@ See: --dynamic-listeners-disable Disable dynamic listeners. --external-server-mapping stringArray Mapping of Kafka server address to external address (host:port,host:port). A listener for the external address is not started --forbidden-api-keys intSlice Forbidden Kafka request types. The restriction should prevent some Kafka operations e.g. 20 - DeleteTopics - --forward-proxy string URL of the forward proxy. Supported schemas are http and socks5 + --forward-proxy string URL of the forward proxy. Supported schemas are socks5 and http -h, --help help for server --http-disable Disable HTTP endpoints --http-health-path string Path on which to health endpoint (default "/health") @@ -112,9 +113,15 @@ See: --proxy-listener-write-buffer-size int Sets the size of the operating system's transmit buffer associated with the connection. If zero, system default is used --proxy-request-buffer-size int Request buffer size pro tcp connection (default 4096) --proxy-response-buffer-size int Response buffer size pro tcp connection (default 4096) - --sasl-enable Connect using SASL/PLAIN + --sasl-enable Connect using SASL --sasl-jaas-config-file string Location of JAAS config file with SASL username and password --sasl-password string SASL user password + --sasl-plugin-command string Path to authentication plugin binary + --sasl-plugin-enable Use plugin for SASL authentication + --sasl-plugin-log-level string Log level of the auth plugin (default "trace") + --sasl-plugin-mechanism string SASL mechanism used for proxy authentication: PLAIN or OAUTHBEARER (default "OAUTHBEARER") + --sasl-plugin-param stringArray Authentication plugin parameter + --sasl-plugin-timeout duration Authentication timeout (default 10s) --sasl-username string SASL user name --tls-ca-chain-cert-file string PEM encoded CA's certificate file --tls-client-cert-file string PEM encoded file with client certificate @@ -123,8 +130,6 @@ See: --tls-enable Whether or not to use TLS when connecting to the broker --tls-insecure-skip-verify It controls whether a client verifies the server's certificate chain and host name - - ### Usage example kafka-proxy server --bootstrap-server-mapping "192.168.99.100:32400,0.0.0.0:32399" @@ -144,14 +149,29 @@ See: --external-server-mapping "192.168.99.100:32402,127.0.0.1:32403" \ --forbidden-api-keys 20 + + export BOOTSTRAP_SERVER_MAPPING="192.168.99.100:32401,0.0.0.0:32402 192.168.99.100:32402,0.0.0.0:32403" && kafka-proxy server + +### SASL authentication initiated by proxy example + +SASL authentication is initiated by the proxy. SASL authentication is disabled on the clients and enabled on the Kafka brokers. + kafka-proxy server --bootstrap-server-mapping "kafka-0.grepplabs.com:9093,0.0.0.0:32399" \ --tls-enable --tls-insecure-skip-verify \ --sasl-enable --sasl-username myuser --sasl-password mysecret - export BOOTSTRAP_SERVER_MAPPING="192.168.99.100:32401,0.0.0.0:32402 192.168.99.100:32402,0.0.0.0:32403" && kafka-proxy server + make clean build plugin.unsecured-jwt-provider && build/kafka-proxy server \ + --sasl-enable \ + --sasl-plugin-enable \ + --sasl-plugin-mechanism "OAUTHBEARER" \ + --sasl-plugin-command build/unsecured-jwt-provider \ + --sasl-plugin-param "--claim-sub=alice" \ + --bootstrap-server-mapping "192.168.99.100:32400,127.0.0.1:32400" ### Proxy authentication example +SASL authentication is performed by the proxy. SASL authentication is enabled on the clients and disabled on the Kafka brokers. + make clean build plugin.auth-user && build/kafka-proxy server --proxy-listener-key-file "server-key.pem" \ --proxy-listener-cert-file "server-cert.pem" \ --proxy-listener-ca-chain-cert-file "ca.pem" \ @@ -169,6 +189,14 @@ See: --auth-local-param "--user-attr=uid" \ --bootstrap-server-mapping "192.168.99.100:32400,127.0.0.1:32400" + make clean build plugin.unsecured-jwt-info && build/kafka-proxy server \ + --auth-local-enable \ + --auth-local-command build/unsecured-jwt-info \ + --auth-local-mechanism "OAUTHBEARER" \ + --auth-local-param "--claim-sub=alice" \ + --auth-local-param "--claim-sub=bob" \ + --bootstrap-server-mapping "192.168.99.100:32400,127.0.0.1:32400" + ### Kafka Gateway example Authentication between Kafka Proxy Client and Kafka Proxy Server with Google-ID (service account JWT) diff --git a/cmd/kafka-proxy/server.go b/cmd/kafka-proxy/server.go index 3bb08233..f7a18dbc 100644 --- a/cmd/kafka-proxy/server.go +++ b/cmd/kafka-proxy/server.go @@ -21,9 +21,9 @@ import ( "errors" "github.com/grepplabs/kafka-proxy/pkg/apis" - gatewayclient "github.com/grepplabs/kafka-proxy/plugin/gateway-client/shared" - gatewayserver "github.com/grepplabs/kafka-proxy/plugin/gateway-server/shared" localauth "github.com/grepplabs/kafka-proxy/plugin/local-auth/shared" + tokeninfo "github.com/grepplabs/kafka-proxy/plugin/token-info/shared" + tokenprovider "github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-plugin" "strings" @@ -101,6 +101,7 @@ func initFlags() { // local authentication plugin Server.Flags().BoolVar(&c.Auth.Local.Enable, "auth-local-enable", false, "Enable local SASL/PLAIN authentication performed by listener - SASL handshake will not be passed to kafka brokers") Server.Flags().StringVar(&c.Auth.Local.Command, "auth-local-command", "", "Path to authentication plugin binary") + Server.Flags().StringVar(&c.Auth.Local.Mechanism, "auth-local-mechanism", "PLAIN", "SASL mechanism used for local authentication: PLAIN or OAUTHBEARER") Server.Flags().StringArrayVar(&c.Auth.Local.Parameters, "auth-local-param", []string{}, "Authentication plugin parameter") Server.Flags().StringVar(&c.Auth.Local.LogLevel, "auth-local-log-level", "trace", "Log level of the auth plugin") Server.Flags().DurationVar(&c.Auth.Local.Timeout, "auth-local-timeout", 10*time.Second, "Authentication timeout") @@ -142,12 +143,20 @@ func initFlags() { Server.Flags().StringVar(&c.Kafka.TLS.ClientKeyPassword, "tls-client-key-password", "", "Password to decrypt rsa private key") Server.Flags().StringVar(&c.Kafka.TLS.CAChainCertFile, "tls-ca-chain-cert-file", "", "PEM encoded CA's certificate file") - // SASL - Server.Flags().BoolVar(&c.Kafka.SASL.Enable, "sasl-enable", false, "Connect using SASL/PLAIN") + // SASL by Proxy + Server.Flags().BoolVar(&c.Kafka.SASL.Enable, "sasl-enable", false, "Connect using SASL") Server.Flags().StringVar(&c.Kafka.SASL.Username, "sasl-username", "", "SASL user name") Server.Flags().StringVar(&c.Kafka.SASL.Password, "sasl-password", "", "SASL user password") Server.Flags().StringVar(&c.Kafka.SASL.JaasConfigFile, "sasl-jaas-config-file", "", "Location of JAAS config file with SASL username and password") + // SASL by Proxy plugin + Server.Flags().BoolVar(&c.Kafka.SASL.Plugin.Enable, "sasl-plugin-enable", false, "Use plugin for SASL authentication") + Server.Flags().StringVar(&c.Kafka.SASL.Plugin.Command, "sasl-plugin-command", "", "Path to authentication plugin binary") + Server.Flags().StringVar(&c.Kafka.SASL.Plugin.Mechanism, "sasl-plugin-mechanism", "OAUTHBEARER", "SASL mechanism used for proxy authentication: PLAIN or OAUTHBEARER") + Server.Flags().StringArrayVar(&c.Kafka.SASL.Plugin.Parameters, "sasl-plugin-param", []string{}, "Authentication plugin parameter") + Server.Flags().StringVar(&c.Kafka.SASL.Plugin.LogLevel, "sasl-plugin-log-level", "trace", "Log level of the auth plugin") + Server.Flags().DurationVar(&c.Kafka.SASL.Plugin.Timeout, "sasl-plugin-timeout", 10*time.Second, "Authentication timeout") + // Web Server.Flags().BoolVar(&c.Http.Disable, "http-disable", false, "Disable HTTP endpoints") Server.Flags().StringVar(&c.Http.ListenAddress, "http-listen-address", "0.0.0.0:9080", "Address that kafka-proxy is listening on") @@ -172,47 +181,115 @@ func initFlags() { func Run(_ *cobra.Command, _ []string) { logrus.Infof("Starting kafka-proxy version %s", config.Version) - var passwordAuthenticator apis.PasswordAuthenticator + var localPasswordAuthenticator apis.PasswordAuthenticator + var localTokenAuthenticator apis.TokenInfo if c.Auth.Local.Enable { - var err error - factory, ok := registry.GetComponent(new(apis.PasswordAuthenticatorFactory), c.Auth.Local.Command).(apis.PasswordAuthenticatorFactory) - if ok { - logrus.Infof("Using built-in PasswordAuthenticator") - passwordAuthenticator, err = factory.New(c.Auth.Local.Parameters) - if err != nil { - logrus.Fatal(err) - } - } else { - client := NewPluginClient(localauth.Handshake, localauth.PluginMap, c.Auth.Local.LogLevel, c.Auth.Local.Command, c.Auth.Local.Parameters) - defer client.Kill() - - rpcClient, err := client.Client() - if err != nil { - logrus.Fatal(err) + switch c.Auth.Local.Mechanism { + case "PLAIN": + var err error + factory, ok := registry.GetComponent(new(apis.PasswordAuthenticatorFactory), c.Auth.Local.Command).(apis.PasswordAuthenticatorFactory) + if ok { + logrus.Infof("Using built-in '%s' PasswordAuthenticator for local PasswordAuthenticator", c.Auth.Local.Command) + localPasswordAuthenticator, err = factory.New(c.Auth.Local.Parameters) + if err != nil { + logrus.Fatal(err) + } + } else { + client := NewPluginClient(localauth.Handshake, localauth.PluginMap, c.Auth.Local.LogLevel, c.Auth.Local.Command, c.Auth.Local.Parameters) + defer client.Kill() + + rpcClient, err := client.Client() + if err != nil { + logrus.Fatal(err) + } + raw, err := rpcClient.Dispense("passwordAuthenticator") + if err != nil { + logrus.Fatal(err) + } + localPasswordAuthenticator, ok = raw.(apis.PasswordAuthenticator) + if !ok { + logrus.Fatal(errors.New("unsupported PasswordAuthenticator plugin type")) + } } - raw, err := rpcClient.Dispense("passwordAuthenticator") - if err != nil { - logrus.Fatal(err) + case "OAUTHBEARER": + var err error + factory, ok := registry.GetComponent(new(apis.TokenInfoFactory), c.Auth.Local.Command).(apis.TokenInfoFactory) + if ok { + logrus.Infof("Using built-in '%s' TokenInfo for local TokenAuthenticator", c.Auth.Local.Command) + + localTokenAuthenticator, err = factory.New(c.Auth.Local.Parameters) + if err != nil { + logrus.Fatal(err) + } + } else { + client := NewPluginClient(tokeninfo.Handshake, tokeninfo.PluginMap, c.Auth.Local.LogLevel, c.Auth.Local.Command, c.Auth.Local.Parameters) + defer client.Kill() + + rpcClient, err := client.Client() + if err != nil { + logrus.Fatal(err) + } + raw, err := rpcClient.Dispense("tokenInfo") + if err != nil { + logrus.Fatal(err) + } + localTokenAuthenticator, ok = raw.(apis.TokenInfo) + if !ok { + logrus.Fatal(errors.New("unsupported TokenInfo plugin type")) + } } - passwordAuthenticator, ok = raw.(apis.PasswordAuthenticator) - if !ok { - logrus.Fatal(errors.New("unsupported PasswordAuthenticator plugin type")) + default: + logrus.Fatal(errors.New("unsupported local auth mechanism")) + } + } + + var saslTokenProvider apis.TokenProvider + if c.Kafka.SASL.Plugin.Enable { + switch c.Kafka.SASL.Plugin.Mechanism { + case "OAUTHBEARER": + var err error + factory, ok := registry.GetComponent(new(apis.TokenProviderFactory), c.Kafka.SASL.Plugin.Command).(apis.TokenProviderFactory) + if ok { + logrus.Infof("Using built-in '%s' TokenProvider for sasl authentication", c.Kafka.SASL.Plugin.Command) + + saslTokenProvider, err = factory.New(c.Kafka.SASL.Plugin.Parameters) + if err != nil { + logrus.Fatal(err) + } + } else { + client := NewPluginClient(tokenprovider.Handshake, tokenprovider.PluginMap, c.Kafka.SASL.Plugin.LogLevel, c.Kafka.SASL.Plugin.Command, c.Kafka.SASL.Plugin.Parameters) + defer client.Kill() + + rpcClient, err := client.Client() + if err != nil { + logrus.Fatal(err) + } + raw, err := rpcClient.Dispense("tokenProvider") + if err != nil { + logrus.Fatal(err) + } + saslTokenProvider, ok = raw.(apis.TokenProvider) + if !ok { + logrus.Fatal(errors.New("unsupported TokenProvider plugin type")) + } } + default: + logrus.Fatal(errors.New("unsupported sasl auth mechanism")) } } - var tokenProvider apis.TokenProvider + var gatewayTokenProvider apis.TokenProvider if c.Auth.Gateway.Client.Enable { var err error factory, ok := registry.GetComponent(new(apis.TokenProviderFactory), c.Auth.Gateway.Client.Command).(apis.TokenProviderFactory) if ok { - logrus.Infof("Using built-in TokenProvider") - tokenProvider, err = factory.New(c.Auth.Gateway.Client.Parameters) + logrus.Infof("Using built-in '%s' TokenProvider for Gateway Client", c.Auth.Gateway.Client.Command) + gatewayTokenProvider, err = factory.New(c.Auth.Gateway.Client.Parameters) if err != nil { logrus.Fatal(err) } } else { - client := NewPluginClient(gatewayclient.Handshake, gatewayclient.PluginMap, c.Auth.Gateway.Client.LogLevel, c.Auth.Gateway.Client.Command, c.Auth.Gateway.Client.Parameters) + client := NewPluginClient(tokenprovider.Handshake, tokenprovider.PluginMap, c.Auth.Gateway.Client.LogLevel, c.Auth.Gateway.Client.Command, c.Auth.Gateway.Client.Parameters) defer client.Kill() rpcClient, err := client.Client() @@ -223,26 +300,26 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } - tokenProvider, ok = raw.(apis.TokenProvider) + gatewayTokenProvider, ok = raw.(apis.TokenProvider) if !ok { logrus.Fatal(errors.New("unsupported TokenProvider plugin type")) } } } - var tokenInfo apis.TokenInfo + var gatewayTokenInfo apis.TokenInfo if c.Auth.Gateway.Server.Enable { var err error factory, ok := registry.GetComponent(new(apis.TokenInfoFactory), c.Auth.Gateway.Server.Command).(apis.TokenInfoFactory) if ok { - logrus.Infof("Using built-in TokenInfo") + logrus.Infof("Using built-in '%s' TokenInfo for Gateway Server", c.Auth.Gateway.Server.Command) - tokenInfo, err = factory.New(c.Auth.Gateway.Server.Parameters) + gatewayTokenInfo, err = factory.New(c.Auth.Gateway.Server.Parameters) if err != nil { logrus.Fatal(err) } } else { - client := NewPluginClient(gatewayserver.Handshake, gatewayserver.PluginMap, c.Auth.Gateway.Server.LogLevel, c.Auth.Gateway.Server.Command, c.Auth.Gateway.Server.Parameters) + client := NewPluginClient(tokeninfo.Handshake, tokeninfo.PluginMap, c.Auth.Gateway.Server.LogLevel, c.Auth.Gateway.Server.Command, c.Auth.Gateway.Server.Parameters) defer client.Kill() rpcClient, err := client.Client() @@ -253,7 +330,7 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } - tokenInfo, ok = raw.(apis.TokenInfo) + gatewayTokenInfo, ok = raw.(apis.TokenInfo) if !ok { logrus.Fatal(errors.New("unsupported TokenInfo plugin type")) } @@ -273,7 +350,7 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } - proxyClient, err := proxy.NewClient(connset, c, listeners.GetNetAddressMapping, passwordAuthenticator, tokenProvider, tokenInfo) + proxyClient, err := proxy.NewClient(connset, c, listeners.GetNetAddressMapping, localPasswordAuthenticator, localTokenAuthenticator, saslTokenProvider, gatewayTokenProvider, gatewayTokenInfo) if err != nil { logrus.Fatal(err) } diff --git a/cmd/plugin-googleid-info/main.go b/cmd/plugin-googleid-info/main.go index 1282ec01..d57b2fa3 100644 --- a/cmd/plugin-googleid-info/main.go +++ b/cmd/plugin-googleid-info/main.go @@ -2,7 +2,7 @@ package main import ( "github.com/grepplabs/kafka-proxy/pkg/libs/googleid-info" - "github.com/grepplabs/kafka-proxy/plugin/gateway-server/shared" + "github.com/grepplabs/kafka-proxy/plugin/token-info/shared" "github.com/hashicorp/go-plugin" "github.com/sirupsen/logrus" "os" diff --git a/cmd/plugin-googleid-provider/main.go b/cmd/plugin-googleid-provider/main.go index c92a33cc..32941de6 100644 --- a/cmd/plugin-googleid-provider/main.go +++ b/cmd/plugin-googleid-provider/main.go @@ -2,7 +2,7 @@ package main import ( "github.com/grepplabs/kafka-proxy/pkg/libs/googleid-provider" - "github.com/grepplabs/kafka-proxy/plugin/gateway-client/shared" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" "github.com/hashicorp/go-plugin" "github.com/sirupsen/logrus" "os" diff --git a/cmd/plugin-unsecured-jwt-info/main.go b/cmd/plugin-unsecured-jwt-info/main.go new file mode 100644 index 00000000..adf53da1 --- /dev/null +++ b/cmd/plugin-unsecured-jwt-info/main.go @@ -0,0 +1,154 @@ +package main + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "flag" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/pkg/libs/util" + "github.com/grepplabs/kafka-proxy/plugin/token-info/shared" + "github.com/hashicorp/go-plugin" + "github.com/sirupsen/logrus" + "os" + "strings" + "time" +) + +const ( + StatusOK = 0 + StatusEmptyToken = 1 + StatusParseJWTFailed = 2 + StatusWrongAlgorithm = 3 + StatusUnauthorized = 4 + StatusNoIssueTimeInToken = 5 + StatusNoExpirationTimeInToken = 6 + StatusTokenTooEarly = 7 + StatusTokenExpired = 8 + + AlgorithmNone = "none" +) + +var ( + clockSkew = 1 * time.Minute +) + +type UnsecuredJWTVerifier struct { + claimSub map[string]struct{} +} + +type pluginMeta struct { + claimSub util.ArrayFlags +} + +func (f *pluginMeta) flagSet() *flag.FlagSet { + fs := flag.NewFlagSet("unsecured-jwt-info info settings", flag.ContinueOnError) + fs.Var(&f.claimSub, "claim-sub", "Allowed subject claim (user name)") + return fs +} + +// Implements apis.TokenInfo +func (v UnsecuredJWTVerifier) VerifyToken(ctx context.Context, request apis.VerifyRequest) (apis.VerifyResponse, error) { + if request.Token == "" { + return getVerifyResponseResponse(StatusEmptyToken) + } + + header, claimSet, err := Decode(request.Token) + if err != nil { + return getVerifyResponseResponse(StatusParseJWTFailed) + } + if header.Algorithm != AlgorithmNone { + return getVerifyResponseResponse(StatusWrongAlgorithm) + } + + if len(v.claimSub) != 0 { + if _, ok := v.claimSub[claimSet.Sub]; !ok { + return getVerifyResponseResponse(StatusUnauthorized) + } + } + if claimSet.Iat < 1 { + return getVerifyResponseResponse(StatusNoIssueTimeInToken) + } + if claimSet.Exp < 1 { + return getVerifyResponseResponse(StatusNoExpirationTimeInToken) + } + + earliest := int64(claimSet.Iat) - int64(clockSkew.Seconds()) + latest := int64(claimSet.Exp) + int64(clockSkew.Seconds()) + unix := time.Now().Unix() + + if unix < earliest { + return getVerifyResponseResponse(StatusTokenTooEarly) + } + if unix > latest { + return getVerifyResponseResponse(StatusTokenExpired) + } + return getVerifyResponseResponse(StatusOK) +} + +type Header struct { + Algorithm string `json:"alg"` +} + +// kafka client sends float instead of int +type ClaimSet struct { + Sub string `json:"sub,omitempty"` + Exp float64 `json:"exp"` + Iat float64 `json:"iat"` + OtherClaims map[string]interface{} `json:"-"` +} + +func Decode(token string) (*Header, *ClaimSet, error) { + args := strings.Split(token, ".") + if len(args) < 2 { + return nil, nil, errors.New("jws: invalid token received") + } + decodedHeader, err := base64.RawURLEncoding.DecodeString(args[0]) + if err != nil { + return nil, nil, err + } + decodedPayload, err := base64.RawURLEncoding.DecodeString(args[1]) + if err != nil { + return nil, nil, err + } + + header := &Header{} + err = json.NewDecoder(bytes.NewBuffer(decodedHeader)).Decode(header) + if err != nil { + return nil, nil, err + } + claimSet := &ClaimSet{} + err = json.NewDecoder(bytes.NewBuffer(decodedPayload)).Decode(claimSet) + if err != nil { + return nil, nil, err + } + return header, claimSet, nil +} + +func getVerifyResponseResponse(status int) (apis.VerifyResponse, error) { + success := status == StatusOK + return apis.VerifyResponse{Success: success, Status: int32(status)}, nil +} + +func main() { + pluginMeta := &pluginMeta{} + fs := pluginMeta.flagSet() + _ = fs.Parse(os.Args[1:]) + + logrus.Infof("Unsecured JWT sub claims: %v", pluginMeta.claimSub) + + unsecuredJWTVerifier := &UnsecuredJWTVerifier{ + claimSub: pluginMeta.claimSub.AsMap(), + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: shared.Handshake, + Plugins: map[string]plugin.Plugin{ + "unsecuredJWTInfo": &shared.TokenInfoPlugin{Impl: unsecuredJWTVerifier}, + }, + // A non-nil value here enables gRPC serving for this plugin... + GRPCServer: plugin.DefaultGRPCServer, + }) +} diff --git a/cmd/plugin-unsecured-jwt-provider/main.go b/cmd/plugin-unsecured-jwt-provider/main.go new file mode 100644 index 00000000..b91276e4 --- /dev/null +++ b/cmd/plugin-unsecured-jwt-provider/main.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "flag" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/shared" + "github.com/hashicorp/go-plugin" + "github.com/sirupsen/logrus" + "golang.org/x/oauth2/jws" + "os" +) + +const ( + StatusOK = 0 + StatusEncodeError = 1 + AlgorithmNone = "none" +) + +type UnsecuredJWTProvider struct { + claimSub string +} + +func (v UnsecuredJWTProvider) GetToken(ctx context.Context, request apis.TokenRequest) (apis.TokenResponse, error) { + token, err := v.encodeToken() + if err != nil { + return getGetTokenResponse(StatusEncodeError, "") + } + + return getGetTokenResponse(StatusOK, token) +} + +func getGetTokenResponse(status int, token string) (apis.TokenResponse, error) { + success := status == StatusOK + return apis.TokenResponse{Success: success, Status: int32(status), Token: token}, nil +} + +func (v UnsecuredJWTProvider) encodeToken() (string, error) { + header := &jws.Header{ + Algorithm: AlgorithmNone, + } + claims := &jws.ClaimSet{ + Sub: v.claimSub, + } + signer := func(data []byte) (sig []byte, err error) { + return []byte{}, nil + } + return jws.EncodeWithSigner(header, claims, signer) +} + +type pluginMeta struct { + claimSub string +} + +func (f *pluginMeta) flagSet() *flag.FlagSet { + fs := flag.NewFlagSet("unsecured-jwt-info info settings", flag.ContinueOnError) + fs.StringVar(&f.claimSub, "claim-sub", "", "subject claim") + return fs +} + +func main() { + pluginMeta := &pluginMeta{} + flags := pluginMeta.flagSet() + _ = flags.Parse(os.Args[1:]) + + if pluginMeta.claimSub == "" { + logrus.Errorf("parameter claim-sub is required") + os.Exit(1) + } + + unsecuredJWTProvider := &UnsecuredJWTProvider{ + claimSub: pluginMeta.claimSub, + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: shared.Handshake, + Plugins: map[string]plugin.Plugin{ + "unsecuredJWTProvider": &shared.TokenProviderPlugin{Impl: unsecuredJWTProvider}, + }, + // A non-nil value here enables gRPC serving for this plugin... + GRPCServer: plugin.DefaultGRPCServer, + }) +} diff --git a/config/config.go b/config/config.go index 51cb1b47..332020f3 100644 --- a/config/config.go +++ b/config/config.go @@ -66,6 +66,7 @@ type Config struct { Local struct { Enable bool Command string + Mechanism string Parameters []string LogLevel string Timeout time.Duration @@ -119,6 +120,14 @@ type Config struct { Username string Password string JaasConfigFile string + Plugin struct { + Enable bool + Command string + Mechanism string + Parameters []string + LogLevel string + Timeout time.Duration + } } } ForwardProxy struct { @@ -210,8 +219,26 @@ func NewConfig() *Config { } func (c *Config) Validate() error { - if c.Kafka.SASL.Enable && (c.Kafka.SASL.Username == "" || c.Kafka.SASL.Password == "") { - return errors.New("SASL.Username and SASL.Password are required when SASL is enabled") + if c.Kafka.SASL.Enable { + if c.Kafka.SASL.Plugin.Enable { + if c.Kafka.SASL.Plugin.Command == "" { + return errors.New("Command is required when Kafka.SASL.Plugin.Enable is enabled") + } + if c.Kafka.SASL.Plugin.Timeout <= 0 { + return errors.New("Kafka.SASL.Plugin.Timeout must be greater than 0") + } + if c.Kafka.SASL.Plugin.Mechanism != "OAUTHBEARER" { + return errors.New("Mechanism OAUTHBEARER is required when Kafka.SASL.Plugin.Enable is enabled") + } + } else { + if c.Kafka.SASL.Username == "" || c.Kafka.SASL.Password == "" { + return errors.New("SASL.Username and SASL.Password are required when SASL is enabled and plugin is not used") + } + } + } else { + if c.Kafka.SASL.Plugin.Enable { + return errors.New("Kafka.SASL.Plugin.Enable must be disabled, when SASL is disabled") + } } if c.Kafka.KeepAlive < 0 { return errors.New("KeepAlive must be greater or equal 0") @@ -254,6 +281,9 @@ func (c *Config) Validate() error { if c.Auth.Local.Enable && c.Auth.Local.Command == "" { return errors.New("Command is required when Auth.Local.Enable is enabled") } + if c.Auth.Local.Enable && (c.Auth.Local.Mechanism != "PLAIN" && c.Auth.Local.Mechanism != "OAUTHBEARER") { + return errors.New("Mechanism PLAIN or OAUTHBEARER is required when Auth.Local.Enable is enabled") + } if c.Auth.Local.Enable && c.Auth.Local.Timeout <= 0 { return errors.New("Auth.Local.Timeout must be greater than 0") } diff --git a/pkg/apis/gateway.go b/pkg/apis/token.go similarity index 100% rename from pkg/apis/gateway.go rename to pkg/apis/token.go diff --git a/pkg/libs/googleid-info/factory.go b/pkg/libs/googleid-info/factory.go index 93a8b13b..4890bdaa 100644 --- a/pkg/libs/googleid-info/factory.go +++ b/pkg/libs/googleid-info/factory.go @@ -2,8 +2,8 @@ package googleidinfo import ( "flag" - "fmt" "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/pkg/libs/util" "github.com/grepplabs/kafka-proxy/pkg/registry" ) @@ -20,27 +20,8 @@ func (f *pluginMeta) flagSet() *flag.FlagSet { type pluginMeta struct { timeout int certsRefreshInterval int - audience arrayFlags - emailsRegex arrayFlags -} - -type arrayFlags []string - -func (i *arrayFlags) String() string { - return fmt.Sprintf("%v", *i) -} - -func (i *arrayFlags) Set(value string) error { - *i = append(*i, value) - return nil -} - -func (i *arrayFlags) asMap() map[string]struct{} { - result := make(map[string]struct{}) - for _, elem := range *i { - result[elem] = struct{}{} - } - return result + audience util.ArrayFlags + emailsRegex util.ArrayFlags } type Factory struct { diff --git a/pkg/libs/util/flags.go b/pkg/libs/util/flags.go new file mode 100644 index 00000000..91def714 --- /dev/null +++ b/pkg/libs/util/flags.go @@ -0,0 +1,22 @@ +package util + +import "fmt" + +type ArrayFlags []string + +func (i *ArrayFlags) String() string { + return fmt.Sprintf("%v", *i) +} + +func (i *ArrayFlags) Set(value string) error { + *i = append(*i, value) + return nil +} + +func (i *ArrayFlags) AsMap() map[string]struct{} { + result := make(map[string]struct{}) + for _, elem := range *i { + result[elem] = struct{}{} + } + return result +} diff --git a/plugin/gateway-server/proto/token-info.pb.go b/plugin/token-info/proto/token-info.pb.go similarity index 100% rename from plugin/gateway-server/proto/token-info.pb.go rename to plugin/token-info/proto/token-info.pb.go diff --git a/plugin/gateway-server/proto/token-info.proto b/plugin/token-info/proto/token-info.proto similarity index 100% rename from plugin/gateway-server/proto/token-info.proto rename to plugin/token-info/proto/token-info.proto diff --git a/plugin/gateway-server/shared/grpc.go b/plugin/token-info/shared/grpc.go similarity index 94% rename from plugin/gateway-server/shared/grpc.go rename to plugin/token-info/shared/grpc.go index 449b716e..98368b3f 100644 --- a/plugin/gateway-server/shared/grpc.go +++ b/plugin/token-info/shared/grpc.go @@ -2,7 +2,7 @@ package shared import ( "github.com/grepplabs/kafka-proxy/pkg/apis" - "github.com/grepplabs/kafka-proxy/plugin/gateway-server/proto" + "github.com/grepplabs/kafka-proxy/plugin/token-info/proto" "github.com/hashicorp/go-plugin" "golang.org/x/net/context" ) diff --git a/plugin/gateway-server/shared/interface.go b/plugin/token-info/shared/interface.go similarity index 95% rename from plugin/gateway-server/shared/interface.go rename to plugin/token-info/shared/interface.go index f0cb5681..0e09abfe 100644 --- a/plugin/gateway-server/shared/interface.go +++ b/plugin/token-info/shared/interface.go @@ -6,7 +6,7 @@ import ( "google.golang.org/grpc" "github.com/grepplabs/kafka-proxy/pkg/apis" - "github.com/grepplabs/kafka-proxy/plugin/gateway-server/proto" + "github.com/grepplabs/kafka-proxy/plugin/token-info/proto" "github.com/hashicorp/go-plugin" "net/rpc" ) diff --git a/plugin/gateway-server/shared/rpc.go b/plugin/token-info/shared/rpc.go similarity index 100% rename from plugin/gateway-server/shared/rpc.go rename to plugin/token-info/shared/rpc.go diff --git a/plugin/gateway-client/proto/token-provider.pb.go b/plugin/token-provider/proto/token-provider.pb.go similarity index 100% rename from plugin/gateway-client/proto/token-provider.pb.go rename to plugin/token-provider/proto/token-provider.pb.go diff --git a/plugin/gateway-client/proto/token-provider.proto b/plugin/token-provider/proto/token-provider.proto similarity index 100% rename from plugin/gateway-client/proto/token-provider.proto rename to plugin/token-provider/proto/token-provider.proto diff --git a/plugin/gateway-client/shared/grpc.go b/plugin/token-provider/shared/grpc.go similarity index 94% rename from plugin/gateway-client/shared/grpc.go rename to plugin/token-provider/shared/grpc.go index 45e5be74..188d3573 100644 --- a/plugin/gateway-client/shared/grpc.go +++ b/plugin/token-provider/shared/grpc.go @@ -2,7 +2,7 @@ package shared import ( "github.com/grepplabs/kafka-proxy/pkg/apis" - "github.com/grepplabs/kafka-proxy/plugin/gateway-client/proto" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/proto" "github.com/hashicorp/go-plugin" "golang.org/x/net/context" ) diff --git a/plugin/gateway-client/shared/interface.go b/plugin/token-provider/shared/interface.go similarity index 95% rename from plugin/gateway-client/shared/interface.go rename to plugin/token-provider/shared/interface.go index 1b21a796..4154a2b2 100644 --- a/plugin/gateway-client/shared/interface.go +++ b/plugin/token-provider/shared/interface.go @@ -6,7 +6,7 @@ import ( "google.golang.org/grpc" "github.com/grepplabs/kafka-proxy/pkg/apis" - "github.com/grepplabs/kafka-proxy/plugin/gateway-client/proto" + "github.com/grepplabs/kafka-proxy/plugin/token-provider/proto" "github.com/hashicorp/go-plugin" "net/rpc" ) diff --git a/plugin/gateway-client/shared/rpc.go b/plugin/token-provider/shared/rpc.go similarity index 100% rename from plugin/gateway-client/shared/rpc.go rename to plugin/token-provider/shared/rpc.go diff --git a/proxy/auth.go b/proxy/auth.go index b4f3055f..1a26a88d 100644 --- a/proxy/auth.go +++ b/proxy/auth.go @@ -118,7 +118,7 @@ func (b *AuthServer) receiveAndSendGatewayAuth(conn DeadlineReaderWriter) error return err } if !resp.Success { - return fmt.Errorf("verify token failed with status: %d", resp.Status) + return fmt.Errorf("gateway server verify token failed with status: %d", resp.Status) } logrus.Debugf("gateway handshake payload: %s", data) diff --git a/proxy/client.go b/proxy/client.go index cc676ec1..c3594849 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -34,11 +34,11 @@ type Client struct { stopRun chan struct{} stopOnce sync.Once - saslPlainAuth *SASLPlainAuth - authClient *AuthClient + saslAuthByProxy SASLAuthByProxy + authClient *AuthClient } -func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.NetAddressMappingFunc, passwordAuthenticator apis.PasswordAuthenticator, tokenProvider apis.TokenProvider, tokenInfo apis.TokenInfo) (*Client, error) { +func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.NetAddressMappingFunc, localPasswordAuthenticator apis.PasswordAuthenticator, localTokenAuthenticator apis.TokenInfo, saslTokenProvider apis.TokenProvider, gatewayTokenProvider apis.TokenProvider, gatewayTokenInfo apis.TokenInfo) (*Client, error) { tlsConfig, err := newTLSClientConfig(c) if err != nil { return nil, err @@ -60,31 +60,47 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne forbiddenApiKeys[int16(apiKey)] = struct{}{} } } - if c.Auth.Local.Enable && passwordAuthenticator == nil { - return nil, errors.New("Auth.Local.Enable is enabled but passwordAuthenticator is nil") + if c.Auth.Local.Enable && (localPasswordAuthenticator == nil && localTokenAuthenticator == nil) { + return nil, errors.New("Auth.Local.Enable is enabled but passwordAuthenticator and localTokenAuthenticator are nil") } - if c.Auth.Gateway.Client.Enable && tokenProvider == nil { + if c.Auth.Gateway.Client.Enable && gatewayTokenProvider == nil { return nil, errors.New("Auth.Gateway.Client.Enable is enabled but tokenProvider is nil") } - if c.Auth.Gateway.Server.Enable && tokenInfo == nil { + if c.Auth.Gateway.Server.Enable && gatewayTokenInfo == nil { return nil, errors.New("Auth.Gateway.Server.Enable is enabled but tokenInfo is nil") } + var saslAuthByProxy SASLAuthByProxy + if c.Kafka.SASL.Plugin.Enable { + if c.Kafka.SASL.Plugin.Mechanism == SASLOAuthBearer && saslTokenProvider != nil { + saslAuthByProxy = &SASLOAuthBearerAuth{ + clientID: c.Kafka.ClientID, + writeTimeout: c.Kafka.WriteTimeout, + readTimeout: c.Kafka.ReadTimeout, + tokenProvider: saslTokenProvider, + } + } else { + return nil, errors.Errorf("SASLAuthByProxy plugin unsupported or plugin misconfiguration for mechanism '%s' ", c.Kafka.SASL.Plugin.Mechanism) + } - return &Client{conns: conns, config: c, dialer: dialer, tcpConnOptions: tcpConnOptions, stopRun: make(chan struct{}, 1), - saslPlainAuth: &SASLPlainAuth{ + } else { + saslAuthByProxy = &SASLPlainAuth{ clientID: c.Kafka.ClientID, writeTimeout: c.Kafka.WriteTimeout, readTimeout: c.Kafka.ReadTimeout, username: c.Kafka.SASL.Username, password: c.Kafka.SASL.Password, - }, + } + } + + return &Client{conns: conns, config: c, dialer: dialer, tcpConnOptions: tcpConnOptions, stopRun: make(chan struct{}, 1), + saslAuthByProxy: saslAuthByProxy, authClient: &AuthClient{ enabled: c.Auth.Gateway.Client.Enable, magic: c.Auth.Gateway.Client.Magic, method: c.Auth.Gateway.Client.Method, timeout: c.Auth.Gateway.Client.Timeout, - tokenProvider: tokenProvider, + tokenProvider: gatewayTokenProvider, }, processorConfig: ProcessorConfig{ MaxOpenRequests: c.Kafka.MaxOpenRequests, @@ -93,16 +109,18 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne ResponseBufferSize: c.Proxy.ResponseBufferSize, ReadTimeout: c.Kafka.ReadTimeout, WriteTimeout: c.Kafka.WriteTimeout, - LocalSasl: &LocalSasl{ - enabled: c.Auth.Local.Enable, - timeout: c.Auth.Local.Timeout, - localAuthenticator: passwordAuthenticator}, + LocalSasl: NewLocalSasl(LocalSaslParams{ + enabled: c.Auth.Local.Enable, + timeout: c.Auth.Local.Timeout, + passwordAuthenticator: localPasswordAuthenticator, + tokenAuthenticator: localTokenAuthenticator, + }), AuthServer: &AuthServer{ enabled: c.Auth.Gateway.Server.Enable, magic: c.Auth.Gateway.Server.Magic, method: c.Auth.Gateway.Server.Method, timeout: c.Auth.Gateway.Server.Timeout, - tokenInfo: tokenInfo, + tokenInfo: gatewayTokenInfo, }, ForbiddenApiKeys: forbiddenApiKeys, }}, nil @@ -191,7 +209,7 @@ func (c *Client) handleConn(conn Conn) { server, err := c.DialAndAuth(conn.BrokerAddress) if err != nil { logrus.Infof("couldn't connect to %s: %v", conn.BrokerAddress, err) - conn.LocalConnection.Close() + _ = conn.LocalConnection.Close() return } if tcpConn, ok := server.(*net.TCPConn); ok { @@ -213,7 +231,7 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) { return nil, err } if err := conn.SetDeadline(time.Time{}); err != nil { - conn.Close() + _ = conn.Close() return nil, err } err = c.auth(conn) @@ -226,22 +244,22 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) { func (c *Client) auth(conn net.Conn) error { if c.config.Auth.Gateway.Client.Enable { if err := c.authClient.sendAndReceiveGatewayAuth(conn); err != nil { - conn.Close() + _ = conn.Close() return err } if err := conn.SetDeadline(time.Time{}); err != nil { - conn.Close() + _ = conn.Close() return err } } if c.config.Kafka.SASL.Enable { - err := c.saslPlainAuth.sendAndReceiveSASLPlainAuth(conn) + err := c.saslAuthByProxy.sendAndReceiveSASLAuth(conn) if err != nil { - conn.Close() + _ = conn.Close() return err } if err := conn.SetDeadline(time.Time{}); err != nil { - conn.Close() + _ = conn.Close() return err } } diff --git a/proxy/processor_default.go b/proxy/processor_default.go index 756b4600..5bfc4838 100644 --- a/proxy/processor_default.go +++ b/proxy/processor_default.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/grepplabs/kafka-proxy/proxy/protocol" + "github.com/sirupsen/logrus" "io" "strconv" "time" @@ -32,7 +33,7 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil { return true, err } - //logrus.Printf("Kafka request length %v, key %v, version %v", requestKeyVersion.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion) + logrus.Debugf("Kafka request key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, requestKeyVersion.Length) if requestKeyVersion.ApiKey < minRequestApiKey || requestKeyVersion.ApiKey > maxRequestApiKey { return true, fmt.Errorf("api key %d is invalid", requestKeyVersion.ApiKey) @@ -55,11 +56,11 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead case apiKeySaslHandshake: switch requestKeyVersion.ApiVersion { case 0: - if err = ctx.localSasl.receiveAndSendSASLPlainAuthV0(src, keyVersionBuf); err != nil { + if err = ctx.localSasl.receiveAndSendSASLAuthV0(src, keyVersionBuf); err != nil { return true, err } case 1: - if err = ctx.localSasl.receiveAndSendSASLPlainAuthV1(src, keyVersionBuf); err != nil { + if err = ctx.localSasl.receiveAndSendSASLAuthV1(src, keyVersionBuf); err != nil { return true, err } default: @@ -132,7 +133,7 @@ func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src De return true, err } proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4)) - //logrus.Printf("Kafka response lenght %v for key %v, version %v", responseHeader.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion) + logrus.Debugf("Kafka response key %v, version %v, length %v", requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, responseHeader.Length) responseDeadline := time.Now().Add(ctx.timeout) err = dst.SetWriteDeadline(responseDeadline) diff --git a/proxy/sasl.go b/proxy/sasl.go deleted file mode 100644 index 6b5768b1..00000000 --- a/proxy/sasl.go +++ /dev/null @@ -1,129 +0,0 @@ -package proxy - -import ( - "bytes" - "encoding/binary" - "fmt" - "github.com/grepplabs/kafka-proxy/proxy/protocol" - "github.com/pkg/errors" - "io" - "time" -) - -const ( - SASLPlain = "PLAIN" -) - -type SASLPlainAuth struct { - clientID string - - writeTimeout time.Duration - readTimeout time.Duration - - username string - password string -} - -// In SASL Plain, Kafka expects the auth header to be in the following format -// Message format (from https://tools.ietf.org/html/rfc4616): -// -// message = [authzid] UTF8NUL authcid UTF8NUL passwd -// authcid = 1*SAFE ; MUST accept up to 255 octets -// authzid = 1*SAFE ; MUST accept up to 255 octets -// passwd = 1*SAFE ; MUST accept up to 255 octets -// UTF8NUL = %x00 ; UTF-8 encoded NUL character -// -// SAFE = UTF1 / UTF2 / UTF3 / UTF4 -// ;; any UTF-8 encoded Unicode character except NUL -// -// When credentials are valid, Kafka returns a 4 byte array of null characters. -// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way -// of responding to bad credentials but thats how its being done today. -func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) error { - - handshakeErr := b.sendAndReceiveSASLPlainHandshake(conn) - if handshakeErr != nil { - return handshakeErr - } - length := 1 + len(b.username) + 1 + len(b.password) - authBytes := make([]byte, length+4) //4 byte length header + auth data - binary.BigEndian.PutUint32(authBytes, uint32(length)) - copy(authBytes[4:], []byte("\x00"+b.username+"\x00"+b.password)) - - err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) - if err != nil { - return err - } - _, err = conn.Write(authBytes) - if err != nil { - return errors.Wrap(err, "Failed to write SASL auth header") - } - - err = conn.SetReadDeadline(time.Now().Add(b.readTimeout)) - if err != nil { - return err - } - - header := make([]byte, 4) - _, err = io.ReadFull(conn, header) - // If the credentials are valid, we would get a 4 byte response filled with null characters. - // Otherwise, the broker closes the connection and we get an EOF - if err != nil { - if err == io.EOF { - return fmt.Errorf("SASL/PLAIN auth for user %s failed", b.username) - } - return errors.Wrap(err, "Failed to read response while authenticating with SASL") - } - return nil -} - -func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWriter) error { - - req := &protocol.Request{ - ClientID: b.clientID, - Body: &protocol.SaslHandshakeRequestV0orV1{Version: 0, Mechanism: SASLPlain}, - } - reqBuf, err := protocol.Encode(req) - if err != nil { - return err - } - sizeBuf := make([]byte, 4) - binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf))) - - err = conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) - if err != nil { - return err - } - - _, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)) - if err != nil { - return errors.Wrap(err, "Failed to send SASL handshake") - } - - err = conn.SetReadDeadline(time.Now().Add(b.readTimeout)) - if err != nil { - return err - } - - //wait for the response - header := make([]byte, 8) // response header - _, err = io.ReadFull(conn, header) - if err != nil { - return errors.Wrap(err, "Failed to read SASL handshake header") - } - length := binary.BigEndian.Uint32(header[:4]) - payload := make([]byte, length-4) - _, err = io.ReadFull(conn, payload) - if err != nil { - return errors.Wrap(err, "Failed to read SASL handshake payload") - } - res := &protocol.SaslHandshakeResponseV0orV1{} - err = protocol.Decode(payload, res) - if err != nil { - return errors.Wrap(err, "Failed to parse SASL handshake") - } - if res.Err != protocol.ErrNoError { - return errors.Wrap(res.Err, "Invalid SASL Mechanism") - } - return nil -} diff --git a/proxy/sasl_by_proxy.go b/proxy/sasl_by_proxy.go new file mode 100644 index 00000000..8516215f --- /dev/null +++ b/proxy/sasl_by_proxy.go @@ -0,0 +1,258 @@ +package proxy + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/proxy/protocol" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "io" + "time" +) + +const ( + SASLPlain = "PLAIN" + SASLOAuthBearer = "OAUTHBEARER" +) + +type SASLHandshake struct { + clientID string + version int16 + mechanism string + + writeTimeout time.Duration + readTimeout time.Duration +} + +type SASLOAuthBearerAuth struct { + clientID string + + writeTimeout time.Duration + readTimeout time.Duration + + tokenProvider apis.TokenProvider +} + +type SASLPlainAuth struct { + clientID string + + writeTimeout time.Duration + readTimeout time.Duration + + username string + password string +} + +type SASLAuthByProxy interface { + sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error +} + +// In SASL Plain, Kafka expects the auth header to be in the following format +// Message format (from https://tools.ietf.org/html/rfc4616): +// +// message = [authzid] UTF8NUL authcid UTF8NUL passwd +// authcid = 1*SAFE ; MUST accept up to 255 octets +// authzid = 1*SAFE ; MUST accept up to 255 octets +// passwd = 1*SAFE ; MUST accept up to 255 octets +// UTF8NUL = %x00 ; UTF-8 encoded NUL character +// +// SAFE = UTF1 / UTF2 / UTF3 / UTF4 +// ;; any UTF-8 encoded Unicode character except NUL +// +// When credentials are valid, Kafka returns a 4 byte array of null characters. +// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way +// of responding to bad credentials but thats how its being done today. +func (b *SASLPlainAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error { + + saslHandshake := &SASLHandshake{ + clientID: b.clientID, + version: 0, + mechanism: SASLPlain, + writeTimeout: b.writeTimeout, + readTimeout: b.readTimeout, + } + handshakeErr := saslHandshake.sendAndReceiveHandshake(conn) + if handshakeErr != nil { + return handshakeErr + } + return b.sendSaslAuthenticateRequest(conn) +} + +func (b *SASLPlainAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter) error { + logrus.Debugf("Sending authentication opaque packets, mechanism PLAIN") + + length := 1 + len(b.username) + 1 + len(b.password) + authBytes := make([]byte, length+4) //4 byte length header + auth data + binary.BigEndian.PutUint32(authBytes, uint32(length)) + copy(authBytes[4:], []byte("\x00"+b.username+"\x00"+b.password)) + + err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err != nil { + return err + } + _, err = conn.Write(authBytes) + if err != nil { + return errors.Wrap(err, "Failed to write SASL auth header") + } + + err = conn.SetReadDeadline(time.Now().Add(b.readTimeout)) + if err != nil { + return err + } + + header := make([]byte, 4) + _, err = io.ReadFull(conn, header) + // If the credentials are valid, we would get a 4 byte response filled with null characters. + // Otherwise, the broker closes the connection and we get an EOF + if err != nil { + if err == io.EOF { + return fmt.Errorf("SASL/PLAIN auth for user %s failed", b.username) + } + return errors.Wrap(err, "Failed to read response while authenticating with SASL") + } + return nil +} + +func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error { + logrus.Debugf("Sending SaslHandshakeRequest") + + req := &protocol.Request{ + ClientID: b.clientID, + Body: &protocol.SaslHandshakeRequestV0orV1{Version: b.version, Mechanism: b.mechanism}, + } + reqBuf, err := protocol.Encode(req) + if err != nil { + return err + } + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf))) + + err = conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err != nil { + return err + } + + _, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)) + if err != nil { + return errors.Wrap(err, "Failed to send SASL handshake") + } + + err = conn.SetReadDeadline(time.Now().Add(b.readTimeout)) + if err != nil { + return err + } + + //wait for the response + header := make([]byte, 8) // response header + _, err = io.ReadFull(conn, header) + if err != nil { + return errors.Wrap(err, "Failed to read SASL handshake header") + } + length := binary.BigEndian.Uint32(header[:4]) + payload := make([]byte, length-4) + _, err = io.ReadFull(conn, payload) + if err != nil { + return errors.Wrap(err, "Failed to read SASL handshake payload") + } + res := &protocol.SaslHandshakeResponseV0orV1{} + err = protocol.Decode(payload, res) + if err != nil { + return errors.Wrap(err, "Failed to parse SASL handshake") + } + if res.Err != protocol.ErrNoError { + return errors.Wrap(res.Err, "Invalid SASL Mechanism") + } + return nil +} + +func (b *SASLOAuthBearerAuth) getOAuthBearerToken() (string, error) { + resp, err := b.tokenProvider.GetToken(context.Background(), apis.TokenRequest{}) + if err != nil { + return "", err + } + if !resp.Success { + return "", fmt.Errorf("get sasl token failed with status: %d", resp.Status) + } + if resp.Token == "" { + return "", errors.New("get sasl token returned empty token") + } + return resp.Token, nil +} + +func (b *SASLOAuthBearerAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error { + + token, err := b.getOAuthBearerToken() + if err != nil { + return err + } + saslHandshake := &SASLHandshake{ + clientID: b.clientID, + version: 1, + mechanism: SASLOAuthBearer, + writeTimeout: b.writeTimeout, + readTimeout: b.readTimeout, + } + handshakeErr := saslHandshake.sendAndReceiveHandshake(conn) + if handshakeErr != nil { + return handshakeErr + } + return b.sendSaslAuthenticateRequest(token, conn) +} + +func (b *SASLOAuthBearerAuth) sendSaslAuthenticateRequest(token string, conn DeadlineReaderWriter) error { + logrus.Debugf("Sending SaslAuthenticateRequest, mechanism OAUTHBEARER") + + saslAuthReqV0 := protocol.SaslAuthenticateRequestV0{SaslAuthBytes: SaslOAuthBearer{}.ToBytes(token, "", make(map[string]string, 0))} + + req := &protocol.Request{ + ClientID: b.clientID, + Body: &saslAuthReqV0, + } + reqBuf, err := protocol.Encode(req) + if err != nil { + return err + } + sizeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf))) + + err = conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)) + if err != nil { + return err + } + + _, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil)) + if err != nil { + return errors.Wrap(err, "Failed to send SASL auth request") + } + + err = conn.SetReadDeadline(time.Now().Add(b.readTimeout)) + if err != nil { + return err + } + + //wait for the response + header := make([]byte, 8) // response header + _, err = io.ReadFull(conn, header) + if err != nil { + return errors.Wrap(err, "Failed to read SASL auth header") + } + length := binary.BigEndian.Uint32(header[:4]) + payload := make([]byte, length-4) + _, err = io.ReadFull(conn, payload) + if err != nil { + return errors.Wrap(err, "Failed to read SASL auth payload") + } + + res := &protocol.SaslAuthenticateResponseV0{} + err = protocol.Decode(payload, res) + if err != nil { + return errors.Wrap(err, "Failed to parse SASL auth response") + } + if res.Err != protocol.ErrNoError { + return errors.Wrapf(res.Err, "SASL authentication failed, error message is '%v'", res.ErrMsg) + } + return nil +} diff --git a/proxy/sasl_local.go b/proxy/sasl_local.go index e10c2445..6387972c 100644 --- a/proxy/sasl_local.go +++ b/proxy/sasl_local.go @@ -8,98 +8,126 @@ import ( "github.com/grepplabs/kafka-proxy/pkg/apis" "github.com/grepplabs/kafka-proxy/proxy/protocol" "io" - "strconv" - "strings" "time" ) type LocalSasl struct { - enabled bool - timeout time.Duration - localAuthenticator apis.PasswordAuthenticator + enabled bool + timeout time.Duration + localAuthenticators map[string]LocalSaslAuth } -func (p *LocalSasl) receiveAndSendSASLPlainAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { - if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil { +type LocalSaslParams struct { + enabled bool + timeout time.Duration + passwordAuthenticator apis.PasswordAuthenticator + tokenAuthenticator apis.TokenInfo +} + +func NewLocalSasl(params LocalSaslParams) *LocalSasl { + localAuthenticators := make(map[string]LocalSaslAuth) + if params.passwordAuthenticator != nil { + localAuthenticators[SASLPlain] = NewLocalSaslPlain(params.passwordAuthenticator) + } + + if params.tokenAuthenticator != nil { + localAuthenticators[SASLOAuthBearer] = NewLocalSaslOauth(params.tokenAuthenticator) + } + return &LocalSasl{ + enabled: params.enabled, + timeout: params.timeout, + localAuthenticators: localAuthenticators, + } +} + +func (p *LocalSasl) receiveAndSendSASLAuthV1(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { + var localSaslAuth LocalSaslAuth + if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 1); err != nil { return err } - if err = p.receiveAndSendAuthV1(conn); err != nil { + if err = p.receiveAndSendAuthV1(conn, localSaslAuth); err != nil { return err } return nil } -func (p *LocalSasl) receiveAndSendSASLPlainAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { - if err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil { +func (p *LocalSasl) receiveAndSendSASLAuthV0(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) { + var localSaslAuth LocalSaslAuth + if localSaslAuth, err = p.receiveAndSendSaslV0orV1(conn, readKeyVersionBuf, 0); err != nil { return err } - if err = p.receiveAndSendAuthV0(conn); err != nil { + if err = p.receiveAndSendAuthV0(conn, localSaslAuth); err != nil { return err } return nil } -func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (err error) { +func (p *LocalSasl) receiveAndSendSaslV0orV1(conn DeadlineReaderWriter, keyVersionBuf []byte, version int16) (localSaslAuth LocalSaslAuth, err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { - return err + return nil, err } if len(keyVersionBuf) != 8 { - return errors.New("length of keyVersionBuf should be 8") + return nil, errors.New("length of keyVersionBuf should be 8") } // keyVersionBuf has already been read from connection requestKeyVersion := &protocol.RequestKeyVersion{} if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil { - return err + return nil, err } if !(requestKeyVersion.ApiKey == 17 && requestKeyVersion.ApiVersion == version) { - return fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion) + return nil, fmt.Errorf("SaslHandshake version %d is expected, but got %d", version, requestKeyVersion.ApiVersion) } if int32(requestKeyVersion.Length) > protocol.MaxRequestSize { - return protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)} + return nil, protocol.PacketDecodingError{Info: fmt.Sprintf("sasl handshake message of length %d too large", requestKeyVersion.Length)} } resp := make([]byte, int(requestKeyVersion.Length-4)) if _, err = io.ReadFull(conn, resp); err != nil { - return err + return nil, err } payload := bytes.Join([][]byte{keyVersionBuf[4:], resp}, nil) saslReqV0orV1 := &protocol.SaslHandshakeRequestV0orV1{Version: version} req := &protocol.Request{Body: saslReqV0orV1} if err = protocol.Decode(payload, req); err != nil { - return err + return nil, err } var saslResult error saslErr := protocol.ErrNoError - if saslReqV0orV1.Mechanism != SASLPlain { - saslResult = fmt.Errorf("PLAIN mechanism expected, but got %s", saslReqV0orV1.Mechanism) + localSaslAuth = p.localAuthenticators[saslReqV0orV1.Mechanism] + if localSaslAuth == nil { + mechanisms := make([]string, 0) + for mechanism := range p.localAuthenticators { + mechanisms = append(mechanisms, mechanism) + } + saslResult = fmt.Errorf("PLAIN or OAUTHBEARER mechanism expected, %v are configured, but got %s", mechanisms, saslReqV0orV1.Mechanism) saslErr = protocol.ErrUnsupportedSASLMechanism } - saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{SASLPlain}} + saslResV0 := &protocol.SaslHandshakeResponseV0orV1{Err: saslErr, EnabledMechanisms: []string{saslReqV0orV1.Mechanism}} newResponseBuf, err := protocol.Encode(saslResV0) if err != nil { - return err + return nil, err } newHeaderBuf, err := protocol.Encode(&protocol.ResponseHeader{Length: int32(len(newResponseBuf) + 4), CorrelationID: req.CorrelationID}) if err != nil { - return err + return nil, err } if _, err := conn.Write(newHeaderBuf); err != nil { - return err + return nil, err } if _, err := conn.Write(newResponseBuf); err != nil { - return err + return nil, err } - return saslResult + return localSaslAuth, saslResult } -func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) { +func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { @@ -134,14 +162,15 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) return err } - authErr := p.doLocalAuth(saslAuthReqV0.SaslAuthBytes) + authErr := localSaslAuth.doLocalAuth(saslAuthReqV0.SaslAuthBytes) var saslAuthResV0 *protocol.SaslAuthenticateResponseV0 if authErr == nil { - saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 4)} + // Length of SaslAuthBytes !=0 for OAUTHBEARER causes that java SaslClientAuthenticator in INTERMEDIATE state will sent SaslAuthenticate(36) second time + saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrNoError, SaslAuthBytes: make([]byte, 0)} } else { errMsg := authErr.Error() - saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 4)} + saslAuthResV0 = &protocol.SaslAuthenticateResponseV0{Err: protocol.ErrSASLAuthenticationFailed, ErrMsg: &errMsg, SaslAuthBytes: make([]byte, 0)} } newResponseBuf, err := protocol.Encode(saslAuthResV0) @@ -162,7 +191,7 @@ func (p *LocalSasl) receiveAndSendAuthV1(conn DeadlineReaderWriter) (err error) } -func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) { +func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter, localSaslAuth LocalSaslAuth) (err error) { requestDeadline := time.Now().Add(p.timeout) err = conn.SetDeadline(requestDeadline) if err != nil { @@ -185,7 +214,11 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) return err } - if err = p.doLocalAuth(saslAuthBytes); err != nil { + if localSaslAuth == nil { + return errors.New("localSaslAuth is nil") + } + + if err = localSaslAuth.doLocalAuth(saslAuthBytes); err != nil { return err } // If the credentials are valid, we would write a 4 byte response filled with null characters. @@ -196,26 +229,3 @@ func (p *LocalSasl) receiveAndSendAuthV0(conn DeadlineReaderWriter) (err error) } return nil } - -func (p *LocalSasl) doLocalAuth(saslAuthBytes []byte) (err error) { - tokens := strings.Split(string(saslAuthBytes), "\x00") - if len(tokens) != 3 { - return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens)) - } - if p.localAuthenticator == nil { - return protocol.PacketDecodingError{Info: "Listener authenticator is not set"} - } - - // logrus.Infof("user: %s , password: %s", tokens[1], tokens[2]) - ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2]) - if err != nil { - proxyLocalAuthTotal.WithLabelValues("error", "1").Inc() - return err - } - proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc() - - if !ok { - return fmt.Errorf("user %s authentication failed", tokens[1]) - } - return nil -} diff --git a/proxy/sasl_local_auth.go b/proxy/sasl_local_auth.go new file mode 100644 index 00000000..0b2eb49f --- /dev/null +++ b/proxy/sasl_local_auth.go @@ -0,0 +1,76 @@ +package proxy + +import ( + "context" + "fmt" + "github.com/grepplabs/kafka-proxy/pkg/apis" + "github.com/grepplabs/kafka-proxy/proxy/protocol" + "strconv" + "strings" +) + +type LocalSaslAuth interface { + doLocalAuth(saslAuthBytes []byte) (err error) +} + +type LocalSaslPlain struct { + localAuthenticator apis.PasswordAuthenticator +} + +func NewLocalSaslPlain(localAuthenticator apis.PasswordAuthenticator) *LocalSaslPlain { + return &LocalSaslPlain{ + localAuthenticator: localAuthenticator, + } +} + +// implements LocalSaslAuth +func (p *LocalSaslPlain) doLocalAuth(saslAuthBytes []byte) (err error) { + tokens := strings.Split(string(saslAuthBytes), "\x00") + if len(tokens) != 3 { + return fmt.Errorf("invalid SASL/PLAIN request: expected 3 tokens, got %d", len(tokens)) + } + if p.localAuthenticator == nil { + return protocol.PacketDecodingError{Info: "Listener authenticator is not set"} + } + + // logrus.Infof("user: %s , password: %s", tokens[1], tokens[2]) + ok, status, err := p.localAuthenticator.Authenticate(tokens[1], tokens[2]) + if err != nil { + proxyLocalAuthTotal.WithLabelValues("error", "1").Inc() + return err + } + proxyLocalAuthTotal.WithLabelValues(strconv.FormatBool(ok), strconv.Itoa(int(status))).Inc() + + if !ok { + return fmt.Errorf("user %s authentication failed", tokens[1]) + } + return nil +} + +type LocalSaslOauth struct { + saslOAuthBearer SaslOAuthBearer + tokenAuthenticator apis.TokenInfo +} + +func NewLocalSaslOauth(tokenAuthenticator apis.TokenInfo) *LocalSaslOauth { + return &LocalSaslOauth{ + saslOAuthBearer: SaslOAuthBearer{}, + tokenAuthenticator: tokenAuthenticator, + } +} + +// implements LocalSaslAuth +func (p *LocalSaslOauth) doLocalAuth(saslAuthBytes []byte) (err error) { + token, _, _, err := p.saslOAuthBearer.GetClientInitialResponse(saslAuthBytes) + if err != nil { + return err + } + resp, err := p.tokenAuthenticator.VerifyToken(context.Background(), apis.VerifyRequest{Token: token}) + if err != nil { + return err + } + if !resp.Success { + return fmt.Errorf("local oauth verify token failed with status: %d", resp.Status) + } + return nil +} diff --git a/proxy/sasl_oauthbearer.go b/proxy/sasl_oauthbearer.go new file mode 100644 index 00000000..26492b05 --- /dev/null +++ b/proxy/sasl_oauthbearer.go @@ -0,0 +1,115 @@ +package proxy + +import ( + "fmt" + "github.com/pkg/errors" + "regexp" + "strings" +) + +// https://tools.ietf.org/html/rfc7628#section-3.1 +// https://tools.ietf.org/html/rfc5801#section-4 +// https://tools.ietf.org/html/rfc5801 (UTF8-1-safe) +const ( + saslOauthSeparator = "\u0001" + saslOauthSaslName = "(?:[\x01-\x2b]|[\x2d-\x3c]|[\x3e-\x7F]|=2C|=3D)+" + saslOauthKey = "[A-Za-z]+" + saslOauthValue = "[\\x21-\\x7E \t\r\n]+" + saslOauthAuthKey = "auth" +) + +var ( + saslOauthKVPairs = fmt.Sprintf("(%s=%s%s)*", saslOauthKey, saslOauthValue, saslOauthSeparator) + saslOauthAuthPattern = regexp.MustCompile("(?P[\\w]+)[ ]+(?P[-_.a-zA-Z0-9]+)") + saslOauthClientInitialResponsePattern = regexp.MustCompile(fmt.Sprintf("n,(a=(?P%s))?,%s(?P%s)%s", saslOauthSaslName, saslOauthSeparator, saslOauthKVPairs, saslOauthSeparator)) +) + +type SaslOAuthBearer struct{} + +func (p SaslOAuthBearer) GetClientInitialResponse(saslAuthBytes []byte) (token string, authzid string, extensions map[string]string, err error) { + match := saslOauthClientInitialResponsePattern.FindSubmatch(saslAuthBytes) + if len(match) == 0 { + return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'saslAuthBytes' parse error") + } + + result := make(map[string][]byte) + for i, name := range saslOauthClientInitialResponsePattern.SubexpNames() { + if i != 0 && name != "" { + if i >= len(match) { + return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'SubexpNames' range error") + } + result[name] = match[i] + } + } + + authzid = string(result["authzid"]) + kvpairs := result["kvpairs"] + properties := p.parseMap(string(kvpairs), "=", saslOauthSeparator) + + token, err = p.parseToken(properties[saslOauthAuthKey]) + if err != nil { + return "", "", nil, err + } + delete(properties, saslOauthAuthKey) + return token, authzid, properties, nil +} + +func (SaslOAuthBearer) parseToken(auth string) (string, error) { + if auth == "" { + return "", errors.New("invalid OAUTHBEARER initial client response: 'auth' not specified") + } + match := saslOauthAuthPattern.FindStringSubmatch(auth) + result := make(map[string]string) + for i, name := range saslOauthAuthPattern.SubexpNames() { + if i != 0 && name != "" { + result[name] = match[i] + } + } + if !strings.EqualFold(result["scheme"], "bearer") { + return "", fmt.Errorf("invalid scheme in OAUTHBEARER initial client response: %s", result["scheme"]) + } + token := result["token"] + if token == "" { + return "", errors.New("invalid OAUTHBEARER initial client response: 'token' is missing") + } + return token, nil +} + +func (SaslOAuthBearer) parseMap(mapStr string, keyValueSeparator string, elementSeparator string) map[string]string { + result := make(map[string]string) + if mapStr == "" { + return result + } + for _, attrval := range strings.Split(mapStr, elementSeparator) { + kv := strings.SplitN(attrval, keyValueSeparator, 2) + if len(kv) == 2 { + result[kv[0]] = kv[1] + } + } + return result +} + +func (SaslOAuthBearer) mkString(mapValues map[string]string, keyValueSeparator string, elementSeparator string) string { + if len(mapValues) == 0 { + return "" + } + elements := make([]string, 0, len(mapValues)) + for k, v := range mapValues { + elements = append(elements, strings.Join([]string{k, v}, keyValueSeparator)) + } + return strings.Join(elements, elementSeparator) +} + +func (p SaslOAuthBearer) ToBytes(tokenValue string, authorizationId string, saslExtensions map[string]string) []byte { + authzid := authorizationId + if authzid != "" { + authzid = "a=" + authorizationId + } + extensions := p.mkString(saslExtensions, "=", saslOauthSeparator) + if extensions != "" { + extensions = saslOauthSeparator + extensions + } + message := fmt.Sprintf("n,%s,%sauth=Bearer %s%s%s%s", authzid, + saslOauthSeparator, tokenValue, extensions, saslOauthSeparator, saslOauthSeparator) + return []byte(message) +} diff --git a/proxy/sasl_oauthbearer_test.go b/proxy/sasl_oauthbearer_test.go new file mode 100644 index 00000000..9f09ee97 --- /dev/null +++ b/proxy/sasl_oauthbearer_test.go @@ -0,0 +1,72 @@ +package proxy + +import ( + "encoding/hex" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSaslOauthParseToken(t *testing.T) { + a := assert.New(t) + + saslBytes := "6e2c2c01617574683d4265617265722065794a68624763694f694a756232356c496e302e65794a6c654841694f6a45754e544d354e5445324e6a6b304e44453452546b73496d6c68644349364d5334314d7a6b314d544d774f5451304d5468464f53776963335669496a6f695957787059325579496e302e0101" + saslAuthBytes, err := hex.DecodeString(saslBytes) + a.Nil(err) + + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse(saslAuthBytes) + a.Nil(err) + a.Empty(authzid) + a.Empty(extensions) + a.Equal("eyJhbGciOiJub25lIn0.eyJleHAiOjEuNTM5NTE2Njk0NDE4RTksImlhdCI6MS41Mzk1MTMwOTQ0MThFOSwic3ViIjoiYWxpY2UyIn0.", token) + + a.Equal(saslAuthBytes, SaslOAuthBearer{}.ToBytes(token, authzid, extensions)) + +} +func TestSaslOAuthBearerToBytes(t *testing.T) { + a := assert.New(t) + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001")) + a.Nil(err) + a.Equal("123.345.567", token) + a.Empty(authzid) + a.Equal(map[string]string{"nineteen": "42"}, extensions) +} + +func TestSaslOAuthBearerAuthorizationId(t *testing.T) { + a := assert.New(t) + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,a=myuser,\u0001auth=Bearer 345\u0001\u0001")) + a.Nil(err) + a.Equal("345", token) + a.Equal("myuser", authzid) + a.Empty(extensions) +} + +func TestSaslOAuthBearerExtensions(t *testing.T) { + a := assert.New(t) + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001")) + a.Nil(err) + a.Equal("567", token) + a.Empty(authzid) + a.Equal(map[string]string{"propA": "valueA1, valueA2", "propB": "valueB"}, extensions) +} + +func TestSaslOAuthBearerRfc7688Example(t *testing.T) { + a := assert.New(t) + message := "n,a=user@example.com,\u0001host=server.example.com\u0001port=143\u0001" + + "auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001" + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message)) + a.Nil(err) + a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token) + a.Equal("user@example.com", authzid) + a.Equal(map[string]string{"host": "server.example.com", "port": "143"}, extensions) +} + +func TestSaslOAuthBearerNoExtensionsFromByteArray(t *testing.T) { + a := assert.New(t) + message := "n,a=user@example.com,\u0001" + + "auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001" + token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message)) + a.Nil(err) + a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token) + a.Equal("user@example.com", authzid) + a.Empty(extensions) +}