diff --git a/AUTHORS b/AUTHORS index ad5989800..c144c46ab 100644 --- a/AUTHORS +++ b/AUTHORS @@ -103,3 +103,4 @@ Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. +Zendesk Inc. diff --git a/README.md b/README.md index 2d15ffda3..52aeaf82d 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +#### `credentialProvider` + +``` +Type: string +Valid Values: +Default: "" +``` + +If set, this must refer to a credential provider name registerd with `RegisterCredentialProvider`. When this is set, the username and password in the DSN will be ignored; instead, each time a conneciton is to be opened, the named credential provider function will be called to obtain a username/password to connect with. This is useful when using, for example, IAM database auth in Amazon AWS, where "passwords" are actually temporary tokens that expire. + ##### `interpolateParams` ``` diff --git a/auth.go b/auth.go index fec7040d4..514da77e9 100644 --- a/auth.go +++ b/auth.go @@ -15,13 +15,16 @@ import ( "crypto/sha256" "crypto/x509" "encoding/pem" + "fmt" "sync" ) // server pub keys registry var ( - serverPubKeyLock sync.RWMutex - serverPubKeyRegistry map[string]*rsa.PublicKey + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey + credentialProviderLock sync.RWMutex + credentialProviderRetistry map[string]CredentialProviderFunc ) // RegisterServerPubKey registers a server RSA public key which can be used to @@ -81,6 +84,44 @@ func getServerPubKey(name string) (pubKey *rsa.PublicKey) { return } +// CredentialProviderFunc is a function which can be used to fetch a username/password +// pair for use when opening a new MySQL connection. The first return value is the username +// and the second the password. +type CredentialProviderFunc func() (string, string, error) + +// RegisterCredentialProvider registers a function to be called on every connection open to +// get the username and password to call +func RegisterCredentialProvider(name string, providerFunc CredentialProviderFunc) { + credentialProviderLock.Lock() + if credentialProviderRetistry == nil { + credentialProviderRetistry = make(map[string]CredentialProviderFunc) + } + credentialProviderRetistry[name] = providerFunc + credentialProviderLock.Unlock() +} + +// DeregisterCredentialProvider removes a function registered with RegisterCredentialProvider +func DeregisterCredentialProvider(name string) { + credentialProviderLock.Lock() + if credentialProviderRetistry != nil { + delete(credentialProviderRetistry, name) + } + credentialProviderLock.Unlock() +} + +func getCredentialsFromConfig(cfg *Config) (string, string, error) { + if cfg.CredentialProvider != "" { + credentialProviderLock.RLock() + defer credentialProviderLock.RUnlock() + cpFunc, ok := credentialProviderRetistry[cfg.CredentialProvider] + if !ok { + return "", "", fmt.Errorf("credential provider %s not registered", cfg.CredentialProvider) + } + return cpFunc() + } + return cfg.User, cfg.Passwd, nil +} + // Hash password using pre 4.1 (old password) method // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c type myRnd struct { @@ -237,10 +278,10 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string, password string) ([]byte, error) { switch plugin { case "caching_sha2_password": - authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + authResp := scrambleSHA256Password(authData, password) return authResp, nil case "mysql_old_password": @@ -250,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + authResp := append(scrambleOldPassword(authData[:8], password), 0) return authResp, nil case "mysql_clear_password": @@ -259,7 +300,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { @@ -267,16 +308,16 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. - authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + authResp := scramblePassword(authData[:20], password) return authResp, nil case "sha256_password": - if len(mc.cfg.Passwd) == 0 { + if len(password) == 0 { return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return append([]byte(mc.cfg.Passwd), 0), nil + return append([]byte(password), 0), nil } pubKey := mc.cfg.pubKey @@ -286,7 +327,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // encrypted password - enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + enc, err := encryptPassword(password, authData, pubKey) return enc, err default: @@ -295,7 +336,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } } -func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string, password string) error { // Read Result Packet authData, newPlugin, err := mc.readAuthResult() if err != nil { @@ -315,7 +356,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { return err } @@ -352,7 +393,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + err = mc.writeAuthSwitchPacket(append([]byte(password), 0)) if err != nil { return err } diff --git a/auth_test.go b/auth_test.go index 1920ef39f..c1d08a454 100644 --- a/auth_test.go +++ b/auth_test.go @@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -157,7 +157,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -208,7 +208,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -261,7 +261,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -317,7 +317,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { conn.maxReads = 3 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -380,7 +380,7 @@ func TestAuthFastCleartextPassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -423,7 +423,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestAuthFastNativePassword(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -525,7 +525,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -569,7 +569,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -588,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -617,7 +617,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { conn.maxReads = 2 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -637,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -651,7 +651,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } } @@ -670,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, mc.cfg.Passwd) if err != nil { t.Fatal(err) } @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin, mc.cfg.User) if err != nil { t.Fatal(err) } @@ -699,7 +699,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { conn.maxReads = 1 // Handle response to auth packet - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -728,7 +728,7 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -761,7 +761,7 @@ func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -797,7 +797,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -842,7 +842,7 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -885,7 +885,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -912,7 +912,7 @@ func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -935,7 +935,7 @@ func TestAuthSwitchCleartextPassword(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -962,7 +962,7 @@ func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -984,7 +984,7 @@ func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -1009,7 +1009,7 @@ func TestAuthSwitchNativePassword(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1039,7 +1039,7 @@ func TestAuthSwitchNativePasswordEmpty(t *testing.T) { 48, 31, 89, 39, 55, 31} plugin := "caching_sha2_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1059,7 +1059,7 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1075,7 +1075,7 @@ func TestOldAuthSwitchNotAllowed(t *testing.T) { authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - err := mc.handleAuthResult(authData, plugin) + err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd) if err != ErrOldPassword { t.Errorf("expected ErrOldPassword, got %v", err) } @@ -1099,7 +1099,7 @@ func TestAuthSwitchOldPassword(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1126,7 +1126,7 @@ func TestOldAuthSwitch(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1153,7 +1153,7 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1180,7 +1180,7 @@ func TestOldAuthSwitchPasswordEmpty(t *testing.T) { 84, 96, 101, 92, 123, 121, 107} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1209,7 +1209,7 @@ func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1244,7 +1244,7 @@ func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1280,7 +1280,7 @@ func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } @@ -1316,7 +1316,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { 47, 43, 9, 41, 112, 67, 110} plugin := "mysql_native_password" - if err := mc.handleAuthResult(authData, plugin); err != nil { + if err := mc.handleAuthResult(authData, plugin, mc.cfg.Passwd); err != nil { t.Errorf("got error: %v", err) } diff --git a/connector.go b/connector.go index d567b4e4f..38715e9a5 100644 --- a/connector.go +++ b/connector.go @@ -88,25 +88,32 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { plugin = defaultAuthPlugin } + user, password, err := getCredentialsFromConfig(c.cfg) + if err != nil { + mc.cleanup() + return nil, err + } + // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin, password) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin, password) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + + if err = mc.writeHandshakeResponsePacket(authResp, plugin, user); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { + if err = mc.handleAuthResult(authData, plugin, password); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. diff --git a/driver_test.go b/driver_test.go index ace083dfc..63614c04b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3163,3 +3163,75 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestCredentialProviderFunc(t *testing.T) { + // Our test provider func should return a valid password, then an invalid one, then a valid one + // to test that it really is having an effect. + shouldFailCreds := false + shouldFailError := false + RegisterCredentialProvider("TestCredentialProviderFunc", func() (string, string, error) { + if shouldFailCreds { + return "fail", "fail", nil + } + if shouldFailError { + return "", "", fmt.Errorf("credential_error") + } + return user, pass, nil + }) + defer DeregisterCredentialProvider("TestCredentialProviderFunc") + dsn := fmt.Sprintf("%s/%s?timeout=30s&credentialProvider=TestCredentialProviderFunc", netAddr, dbname) + runTests(t, dsn, func(dbt *DBTest) { + ctx := context.Background() + c1, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c1.Close() + + rows, err := c1.QueryContext(ctx, "SELECT USER()") + if err != nil { + dbt.Fatalf("error running SELECT USER(): %s", err) + } + connUserAndHost := "" + for rows.Next() { + err := rows.Scan(&connUserAndHost) + if err != nil { + dbt.Fatalf("error running query: %s", err) + } + } + parts := strings.Split(connUserAndHost, "@") + connUser := strings.Join(parts[:len(parts)-1], "@") + if connUser != user { + dbt.Errorf("USER() and credentials don't match: %s != %s", connUser, user) + } + + // open one that should fail (wrong creds) + shouldFailCreds = true + _, err = dbt.db.Conn(ctx) + shouldFailCreds = false + if err == nil { + dbt.Errorf("expected second open to fail") + } + + // open one that should fail (with an error) + shouldFailError = true + _, err = dbt.db.Conn(ctx) + shouldFailError = false + if err == nil { + dbt.Errorf("expected third open to fail") + } + if !strings.Contains(err.Error(), "credential_error") { + dbt.Errorf("expected third open to fail with credential_error") + } + + c4, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatalf("error opening conn: %s", err) + } + defer c4.Close() + err = c4.PingContext(ctx) + if err != nil { + dbt.Errorf("error running PingContext: %s", err) + } + }) +} diff --git a/dsn.go b/dsn.go index 1d9b4ab0a..0c3a58fe3 100644 --- a/dsn.go +++ b/dsn.go @@ -34,22 +34,23 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout + User string // Username + Passwd string // Password (requires User) + CredentialProvider string // Credential provider name registered with RegisterCredentialProvider + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -347,6 +348,16 @@ func (cfg *Config) FormatDSN() string { } + if cfg.CredentialProvider != "" { + if hasParam { + buf.WriteString("&credentialProvider=") + } else { + hasParam = true + buf.WriteString("?credentialProvider=") + } + buf.WriteString(cfg.CredentialProvider) + } + // other params if cfg.Params != nil { var params []string @@ -613,6 +624,8 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "credentialProvider": + cfg.CredentialProvider = value default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index 50dc2932c..82194b52e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,8 +71,10 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, -}, -} +}, { + "tcp(localhost)/dbname?credentialProvider=foobar", + &Config{Net: "tcp", Addr: "localhost:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CredentialProvider: "foobar"}, +}} func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { diff --git a/packets.go b/packets.go index 30b3352c2..18b0d3731 100644 --- a/packets.go +++ b/packets.go @@ -276,7 +276,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string, user string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -310,7 +310,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientFlags |= clientPluginAuthLenEncClientData } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(user) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -373,8 +373,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // User [null terminated string] - if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + if len(user) > 0 { + pos += copy(data[pos:], user) } data[pos] = 0x00 pos++