-
Notifications
You must be signed in to change notification settings - Fork 5
/
mysql_connection_string_builder.go
102 lines (82 loc) · 2.63 KB
/
mysql_connection_string_builder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
package db
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/url"
"os"
"time"
"github.com/go-sql-driver/mysql"
)
//go:generate counterfeiter -o ../fakes/mysql_adapter.go --fake-name MySQLAdapter . mySQLAdapter

// mySQLAdapter is the seam between MySQLConnectionStringBuilder and the MySQL
// driver. A counterfeiter fake named MySQLAdapter is generated from it into
// ../fakes (see the go:generate directive above) so the builder can be tested
// without touching the driver's global state.
//
// NOTE(review): the method set mirrors the go-sql-driver/mysql package-level
// ParseDSN and RegisterTLSConfig functions; presumably the production
// implementation simply forwards to them — confirm at the construction site.
type mySQLAdapter interface {
	// ParseDSN parses a DSN string into a driver configuration.
	ParseDSN(dsn string) (cfg *mysql.Config, err error)
	// RegisterTLSConfig makes config available under key so that a DSN can
	// refer to it by name.
	RegisterTLSConfig(key string, config *tls.Config) error
}
// MySQLConnectionStringBuilder turns a Config into a MySQL DSN via Build,
// registering a TLS configuration first when SSL is required.
type MySQLConnectionStringBuilder struct {
	// MySQLAdapter parses DSNs and registers TLS configurations on the
	// builder's behalf. Build calls it unconditionally, so it must be set
	// (non-nil) before Build is used.
	MySQLAdapter mySQLAdapter
}
// Build assembles a MySQL DSN from config.
//
// The DSN always sets parseTime=true and a sql_mode parameter that asks for
// ANSI_QUOTES to be appended to the current @@sql_mode. config.Timeout is
// interpreted in seconds and applied to the dial, read and write timeouts.
//
// When config.RequireSSL is set, the CA certificate(s) in the PEM file at
// config.CACert become the only trusted roots. The resulting *tls.Config is
// registered through the adapter under the key "<DatabaseName>-tls", and that
// key is referenced from the DSN. If config.SkipHostnameValidation is also set,
// the chain is still verified against that CA pool but the server hostname is
// not checked.
//
// Errors from DSN parsing, reading the CA file and TLS registration are
// wrapped with %w so callers can inspect them with errors.Is / errors.As.
func (m *MySQLConnectionStringBuilder) Build(config Config) (string, error) {
	// sql_mode travels as a query parameter, so the expression must be URL-escaped.
	sqlMode := url.QueryEscape("(SELECT CONCAT(@@sql_mode,',ANSI_QUOTES'))")
	connString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&sql_mode=%s", config.User, config.Password, config.Host, config.Port, config.DatabaseName, sqlMode)

	// Parse the base DSN so the remaining settings can be set as typed fields
	// before it is re-serialised with FormatDSN.
	dbConfig, err := m.MySQLAdapter.ParseDSN(connString)
	if err != nil {
		return "", fmt.Errorf("parsing db connection string: %w", err)
	}

	timeoutDuration := time.Duration(config.Timeout) * time.Second
	dbConfig.Timeout = timeoutDuration
	dbConfig.ReadTimeout = timeoutDuration
	dbConfig.WriteTimeout = timeoutDuration

	if config.RequireSSL {
		// NOTE(review): the registry key depends only on DatabaseName, so two
		// builds for the same database name with different CA files presumably
		// overwrite each other's registration — confirm this cannot happen.
		dbConfig.TLSConfig = fmt.Sprintf("%s-tls", config.DatabaseName)

		certBytes, err := os.ReadFile(config.CACert)
		if err != nil {
			return "", fmt.Errorf("reading db ca cert file: %w", err)
		}

		caCertPool := x509.NewCertPool()
		if ok := caCertPool.AppendCertsFromPEM(certBytes); !ok {
			return "", fmt.Errorf("appending cert to pool from pem - invalid cert bytes")
		}

		tlsConfig := &tls.Config{
			InsecureSkipVerify: false,
			RootCAs:            caCertPool,
		}
		if config.SkipHostnameValidation {
			// InsecureSkipVerify turns off crypto/tls's built-in verification,
			// which includes the hostname check. Chain verification against the
			// same CA pool is then redone here. verifiedChains is always nil in
			// this mode, so it is ignored.
			tlsConfig.InsecureSkipVerify = true
			tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
				return VerifyCertificatesIgnoreHostname(rawCerts, caCertPool)
			}
		}

		if err := m.MySQLAdapter.RegisterTLSConfig(dbConfig.TLSConfig, tlsConfig); err != nil {
			return "", fmt.Errorf("registering mysql tls config: %w", err)
		}
	}

	return dbConfig.FormatDSN(), nil
}
// VerifyCertificatesIgnoreHostname verifies the peer certificate chain in
// rawCerts against caCertPool without checking the leaf against any hostname.
//
// rawCerts holds DER-encoded certificates, leaf first, in the form that
// tls.Config.VerifyPeerCertificate receives. The first certificate is the one
// verified; any remaining certificates are offered only as intermediates.
// x509.Certificate.Verify still enforces validity periods and the default
// (server-auth) key usage. Because VerifyOptions.DNSName is left empty, no
// hostname check is performed.
//
// It returns an error if rawCerts is empty, if any certificate fails to parse
// (the parse error is wrapped), or if no chain to a root in caCertPool can be
// built.
func VerifyCertificatesIgnoreHostname(rawCerts [][]byte, caCertPool *x509.CertPool) error {
	// Guard explicitly: indexing the leaf below would otherwise panic.
	if len(rawCerts) == 0 {
		return fmt.Errorf("tls: no certificates presented by server")
	}

	certs := make([]*x509.Certificate, len(rawCerts))
	for i, asn1Data := range rawCerts {
		cert, err := x509.ParseCertificate(asn1Data)
		if err != nil {
			return fmt.Errorf("tls: failed to parse certificate from server: %w", err)
		}
		certs[i] = cert
	}

	opts := x509.VerifyOptions{
		Roots:         caCertPool,
		CurrentTime:   time.Now(),
		Intermediates: x509.NewCertPool(),
	}
	// Everything after the leaf is a candidate intermediate.
	for _, cert := range certs[1:] {
		opts.Intermediates.AddCert(cert)
	}

	_, err := certs[0].Verify(opts)
	return err
}