/
dial.go
37 lines (33 loc) · 1.02 KB
/
dial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package client
import (
"context"
"crypto/tls"
"net"
"github.com/canonical/go-dqlite/internal/protocol"
)
// DefaultDialFunc is the dial function used when no other one is configured.
// It delegates to protocol.Dial, which can handle plain TCP and Unix socket
// endpoints. Use WithDialFunc() to supply a different dial function.
func DefaultDialFunc(ctx context.Context, address string) (net.Conn, error) {
	conn, err := protocol.Dial(ctx, address)
	return conn, err
}
// DialFuncWithTLS returns a dial function that uses TLS encryption.
//
// The given dial function will be used to establish the network connection,
// and the given TLS config will be used for encryption. The config is cloned
// on every dial, so the caller's value is never modified. A nil config is
// treated as an empty one.
//
// If the config does not set ServerName, the host part of the dialed address
// is used instead. In that case the address must be in "host:port" form;
// otherwise the error from net.SplitHostPort is returned and the underlying
// dial function is not invoked.
//
// The returned connection is a TLS client wrapping the dialed connection. No
// handshake is started by this function.
func DialFuncWithTLS(dial DialFunc, config *tls.Config) DialFunc {
	return func(ctx context.Context, addr string) (net.Conn, error) {
		// Work on a private copy: ServerName may be filled in below and the
		// caller's config must stay untouched.
		clonedConfig := config.Clone()
		if clonedConfig == nil {
			// Clone returns nil for a nil receiver; fall back to an empty
			// config instead of dereferencing nil below.
			clonedConfig = &tls.Config{}
		}

		if clonedConfig.ServerName == "" {
			host, _, err := net.SplitHostPort(addr)
			if err != nil {
				return nil, err
			}
			clonedConfig.ServerName = host
		}

		conn, err := dial(ctx, addr)
		if err != nil {
			return nil, err
		}

		return tls.Client(conn, clonedConfig), nil
	}
}