-
Notifications
You must be signed in to change notification settings - Fork 22
/
my.go
209 lines (195 loc) · 6.15 KB
/
my.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
// Copyright 2023 The Cockroach Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0
// Package stdpool creates standardized database connection pools.
package stdpool
import (
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/cockroachdb/cdc-sink/internal/types"
"github.com/cockroachdb/cdc-sink/internal/util/secure"
"github.com/cockroachdb/cdc-sink/internal/util/stopper"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
// See also:
// https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html
func myErrCode(err error) (string, bool) {
if myErr := (*mysql.MySQLError)(nil); errors.As(err, &myErr) {
return strconv.Itoa(int(myErr.Number)), true
}
return "", false
}
// myErrDeferrable reports whether err is MySQL server error 1452
// ("Cannot add or update a child row: a foreign key constraint fails").
// It is installed as PoolInfo.IsDeferrable for MySQL target pools.
func myErrDeferrable(err error) bool {
	code, isMySQL := myErrCode(err)
	return isMySQL && code == "1452"
}
// myErrRetryable reports whether err is a MySQL deadlock error, for which
// retrying the transaction is the documented remedy. It is installed as
// PoolInfo.ShouldRetry for MySQL target pools.
//
// myErrCode returns the numeric server error number, not the SQLSTATE, so
// the deadlock must be matched by its error number (1213, ER_LOCK_DEADLOCK).
// Its SQLSTATE, 40001, is never produced by myErrCode and cannot be used
// here.
func myErrRetryable(err error) bool {
	code, ok := myErrCode(err)
	if !ok {
		return false
	}
	// ER_LOCK_DEADLOCK: "Deadlock found when trying to get lock; try
	// restarting transaction" (SQLSTATE 40001).
	return code == "1213"
}
// tlsConfigNames generates the names under which TLS configuration objects
// are registered with the MySQL driver.
var tlsConfigNames onomastic

// onomastic hands out unique names by appending an increasing counter to a
// caller-supplied prefix. The counter is atomic, so names may be requested
// concurrently.
type onomastic struct {
	counter atomic.Uint64
}

// newName returns prefix followed by "_" and the next counter value,
// starting at 1 (e.g. "mysql_driver_1").
func (n *onomastic) newName(prefix string) string {
	next := n.counter.Add(1)
	return fmt.Sprintf("%s_%d", prefix, next)
}
// OpenMySQLAsTarget opens a connection pool to a MySQL-compatible server
// (MySQL or MariaDB) and returns it as a TargetPool.
//
// secure.ParseTLSOptions expands url into a list of candidate TLS
// configurations. They are tried in order, and the first one whose pool
// answers a ping is used. A nil candidate produces a DSN without a tls
// parameter. Once a candidate has been selected, any later failure (e.g.
// querying the server version) is returned immediately rather than moving
// on to the next candidate. If every candidate fails, the error from the
// last attempt is returned.
//
// When the TestControls option enables WaitForStartup, a ping that fails
// with a startup error (see isMySQLStartupError) is retried every 10 seconds
// until it succeeds or ctx is stopped.
//
// The returned pool is closed when ctx stops. If this function fails, the
// most recently opened pool is closed and the TLS configuration it
// registered with the driver is removed before returning.
func OpenMySQLAsTarget(
	ctx *stopper.Context, connectString string, url *url.URL, options ...Option,
) (*types.TargetPool, error) {
	var tc TestControls
	if err := attachOptions(ctx, &tc, options); err != nil {
		return nil, err
	}
	// Use a unique name for each call of OpenMySQLAsTarget, since the
	// driver keeps registered TLS configurations in a process-wide registry.
	tlsConfigName := tlsConfigNames.newName("mysql_driver")
	tlsConfigs, err := secure.ParseTLSOptions(url)
	if err != nil {
		return nil, err
	}

	var ret *types.TargetPool
	var transportError error

	// On any failure below, release resources now instead of holding them
	// until ctx stops. Closing a pool that was already closed is harmless.
	succeeded := false
	defer func() {
		if succeeded {
			return
		}
		if ret != nil {
			_ = ret.Close()
		}
		mysql.DeregisterTLSConfig(tlsConfigName)
	}()

	// Try all possible transport options.
	// The first one that works is the one we will use.
	for _, tlsConfig := range tlsConfigs {
		// Drop any configuration registered by a previous attempt.
		mysql.DeregisterTLSConfig(tlsConfigName)
		mySQLString, err := getConnString(url, tlsConfigName, tlsConfig)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		if err := mysql.RegisterTLSConfig(tlsConfigName, tlsConfig); err != nil {
			return nil, errors.WithStack(err)
		}
		driver := mysql.MySQLDriver{}
		connector, err := driver.OpenConnector(mySQLString)
		if err != nil {
			log.WithError(err).Trace("failed to connect to database server")
			transportError = err
			// Try a different option.
			continue
		}
		ret = &types.TargetPool{
			DB: sql.OpenDB(connector),
			PoolInfo: types.PoolInfo{
				ConnectionString: connectString,
				Product:          types.ProductMySQL,
				ErrCode:          myErrCode,
				IsDeferrable:     myErrDeferrable,
				ShouldRetry:      myErrRetryable,
			},
		}
		// Bind the cleanup to this specific pool, not to the ret variable,
		// which is reassigned by later iterations.
		pool := ret
		ctx.Defer(func() { _ = pool.Close() })

		// Ping the server, optionally waiting for it to finish starting up.
		pingErr := ret.Ping()
		for pingErr != nil && tc.WaitForStartup && isMySQLStartupError(pingErr) {
			log.WithError(pingErr).Info("waiting for database to become ready")
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(10 * time.Second):
			}
			pingErr = ret.Ping()
		}
		if pingErr != nil {
			transportError = pingErr
			_ = ret.Close()
			// Try a different option.
			continue
		}

		// Testing that connection is usable.
		if err := ret.QueryRow("SELECT VERSION();").Scan(&ret.Version); err != nil {
			return nil, errors.Wrap(err, "could not query version")
		}
		log.Infof("Version %s.", ret.Version)
		if strings.Contains(ret.Version, "MariaDB") {
			ret.PoolInfo.Product = types.ProductMariaDB
		}
		if err := setTableHint(ret.Info()); err != nil {
			return nil, err
		}
		// If debug is enabled we print sql mode and ssl info.
		if log.IsLevelEnabled(log.DebugLevel) {
			var mode string
			if err := ret.QueryRow("SELECT @@sql_mode").Scan(&mode); err != nil {
				log.Errorf("could not query sql mode %s", err.Error())
			}
			var varName, cipher string
			if err := ret.QueryRow("SHOW STATUS LIKE 'Ssl_cipher';").Scan(&varName, &cipher); err != nil {
				log.Errorf("could not query ssl info %s", err.Error())
			}
			log.Debugf("Mode %s. %s %s", mode, varName, cipher)
			ret.Version = fmt.Sprintf("%s cipher[%s]", ret.Version, cipher)
		}
		if err := attachOptions(ctx, ret.DB, options); err != nil {
			return nil, err
		}
		if err := attachOptions(ctx, &ret.PoolInfo, options); err != nil {
			return nil, err
		}
		// The connection meets the client/server requirements,
		// no need to try other transport options.
		succeeded = true
		return ret, nil
	}
	// Every continue above records a non-nil transportError. It is only
	// nil when there were no options to try at all. Never return (nil, nil).
	if transportError == nil {
		return nil, errors.New("no transport options available to connect to the database")
	}
	// All the options have been exhausted, returning the last error.
	return nil, transportError
}
// isMySQLStartupError reports whether err suggests the server is not ready
// yet, so that a ping is worth retrying while waiting for startup.
//
// errors.Is is used so that the sentinel is recognized even when it has
// been wrapped by an intermediate layer.
// TODO (silvano): verify error codes.
func isMySQLStartupError(err error) bool {
	return errors.Is(err, sqldriver.ErrBadConn)
}
// getConnString returns a driver specific connection strings
// The TLS configuration must be already extracted from the URL parameters
// to determine the list of possible transport connections that the client wants to try.
// This function is only concerned about user, host and path section of the URL.
func getConnString(url *url.URL, tlsConfigName string, config *tls.Config) (string, error) {
path := "/"
if url.Path != "" {
path = url.Path
}
baseSQLString := fmt.Sprintf("%s@tcp(%s)%s?%s", url.User.String(), url.Host,
path, "sql_mode=ansi")
if config == nil {
return baseSQLString, nil
}
return fmt.Sprintf("%s&tls=%s", baseSQLString, tlsConfigName), nil
}