diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index e15f1e4a2..c11e5e76d 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -6,6 +6,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "testing" "time" @@ -18,18 +19,17 @@ import ( "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/test_util" "github.com/go-mysql-org/go-mysql/test_util/test_keys" - "github.com/go-mysql-org/go-mysql/utils" ) -var delay = 50 - // test caching for 'caching_sha2_password' // NOTE the idea here is to plugin a throttled credential provider so that the first connection (cache miss) will take longer time // than the second connection (cache hit). Remember to set the password for MySQL user otherwise it won't cache empty password. func TestCachingSha2Cache(t *testing.T) { log.SetLevel(log.LevelDebug) - remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider := &RemoteThrottleProvider{ + InMemoryProvider: NewInMemoryProvider(), + } remoteProvider.AddUser(*testUser, *testPassword) cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) @@ -44,7 +44,9 @@ func TestCachingSha2Cache(t *testing.T) { func TestCachingSha2CacheTLS(t *testing.T) { log.SetLevel(log.LevelDebug) - remoteProvider := &RemoteThrottleProvider{NewInMemoryProvider(), delay + 50} + remoteProvider := &RemoteThrottleProvider{ + InMemoryProvider: NewInMemoryProvider(), + } remoteProvider.AddUser(*testUser, *testPassword) cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) @@ -58,11 +60,11 @@ func TestCachingSha2CacheTLS(t *testing.T) { type RemoteThrottleProvider struct { *InMemoryProvider - delay int // in milliseconds + getCredCallCount atomic.Int64 } func (m *RemoteThrottleProvider) GetCredential(username string) (password string, found bool, err error) { - time.Sleep(time.Millisecond * time.Duration(m.delay)) + m.getCredCallCount.Add(1) return m.InMemoryProvider.GetCredential(username) } @@ -132,35 +134,26 @@ func (s *cacheTestSuite) runSelect() { func (s *cacheTestSuite) TestCache() { // first connection - t1 := utils.Now() var err error s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, s.serverAddr, *testDB, s.tlsPara)) require.NoError(s.T(), err) s.db.SetMaxIdleConns(4) s.runSelect() - t2 := utils.Now() - - d1 := int(t2.Sub(t1).Nanoseconds() / 1e6) - // log.Debugf("first connection took %d milliseconds", d1) - - require.GreaterOrEqual(s.T(), d1, delay) + got := s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() + require.Equal(s.T(), int64(1), got) if s.db != nil { s.db.Close() } // second connection - t3 := utils.Now() s.db, err = sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s", *testUser, *testPassword, s.serverAddr, *testDB, s.tlsPara)) require.NoError(s.T(), err) s.db.SetMaxIdleConns(4) s.runSelect() - t4 := utils.Now() - - d2 := int(t4.Sub(t3).Nanoseconds() / 1e6) - // log.Debugf("second connection took %d milliseconds", d2) + got = s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() + require.Equal(s.T(), int64(1), got) - require.Less(s.T(), d2, delay) if s.db != nil { s.db.Close() }