/
mutual_tls_factory.go
136 lines (116 loc) · 3.53 KB
/
mutual_tls_factory.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package grpcpool
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"github.com/davidwartell/go-logger-facade/logger"
"github.com/pkg/errors"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/security/advancedtls"
)
// MutualTLSFactory creates gRPC client connections to a single server address
// using mutual-TLS transport credentials. Construct it with NewMutualTLSFactory.
// Its methods use value receivers, so copies of a factory share the same
// *Options value.
type MutualTLSFactory struct {
	// credentials are the mTLS transport credentials passed to
	// grpc.WithTransportCredentials when dialing.
	credentials credentials.TransportCredentials
	// dialAddr is the target address handed to grpc.DialContext.
	dialAddr string
	// options holds the settings populated by the Option values given to the
	// constructor (keepalive, Snappy compression, OTel tracing, ping function).
	options *Options
}
// NewMutualTLSFactory builds a connection factory that dials serverAddress
// over mutual TLS.
//
// caCertPEM is the PEM-encoded CA certificate(s) used to verify the server;
// clientCertPEM and clientKeyPEM are the PEM-encoded certificate and private
// key the client presents. opts are applied, in order, to a fresh Options
// value; nil entries are ignored.
//
// If the TLS material cannot be loaded, the error is logged and returned
// together with the zero-value MutualTLSFactory, which must not be used.
//
//goland:noinspection GoUnusedExportedFunction
func NewMutualTLSFactory(
	caCertPEM []byte,
	clientCertPEM []byte,
	clientKeyPEM []byte,
	serverAddress string,
	opts ...Option,
) (MutualTLSFactory, error) {
	options := new(Options)
	for _, opt := range opts {
		// Skip nil entries so a conditionally assembled option list cannot
		// trigger a nil-function panic here.
		if opt != nil {
			opt(options)
		}
	}

	creds, err := LoadTLSCredentials(caCertPEM, clientCertPEM, clientKeyPEM)
	if err != nil {
		logger.Instance().Error("error loading TLS credentials", logger.Error(err))
		// Return the zero value rather than a half-initialized factory.
		return MutualTLSFactory{}, err
	}

	return MutualTLSFactory{
		credentials: creds,
		dialAddr:    serverAddress,
		options:     options,
	}, nil
}
// NewConnection is NewConnectionWithDialOpts with no caller-supplied dial
// options: it connects to the factory's address using only the settings
// configured on the factory.
func (f MutualTLSFactory) NewConnection(ctx context.Context) (*grpc.ClientConn, error) {
	var noExtraOpts []grpc.DialOption
	return f.NewConnectionWithDialOpts(ctx, noExtraOpts...)
}
// NewConnectionWithDialOpts creates a client connection to the factory's
// address. If a ping function is configured, it verifies the connection with
// that function before returning it.
//
// Dial options are assembled in this order:
//   - the factory's mTLS transport credentials;
//   - keepalive, Snappy compression and the OpenTelemetry stats handler, when
//     each is enabled in Options;
//   - finally the caller's opts.
//
// On failure the connection, if one was created, is closed. The function
// returns a nil connection and an error that names the target address.
// A zero-value MutualTLSFactory (not built via NewMutualTLSFactory) yields an
// error instead of panicking.
func (f MutualTLSFactory) NewConnectionWithDialOpts(ctx context.Context, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	// Without this guard a zero-value factory dereferences nil f.options below.
	if f.options == nil || f.credentials == nil {
		return nil, errors.New("MutualTLSFactory is not initialized; use NewMutualTLSFactory")
	}

	// Room for credentials plus up to three optional settings plus caller opts.
	allOpts := make([]grpc.DialOption, 0, 4+len(opts))
	allOpts = append(allOpts, grpc.WithTransportCredentials(f.credentials))
	if f.options.keepalive != nil {
		allOpts = append(allOpts, grpc.WithKeepaliveParams(*f.options.keepalive))
	}
	if f.options.withSnappyCompression {
		allOpts = append(allOpts, grpc.WithDefaultCallOptions(grpc.UseCompressor(SnappyCompressor())))
	}
	if f.options.withOtelTracing {
		allOpts = append(allOpts, grpc.WithStatsHandler(otelgrpc.NewClientHandler()))
	}
	allOpts = append(allOpts, opts...)

	// NOTE(review): grpc.DialContext is deprecated in newer grpc-go releases in
	// favor of grpc.NewClient, which uses a different default name resolver;
	// confirm target format before migrating.
	conn, err := grpc.DialContext(ctx, f.dialAddr, allOpts...)
	if err != nil {
		if conn != nil {
			_ = conn.Close()
		}
		logger.Instance().Info("failed to dial", logger.String("dialAddr", f.dialAddr), logger.Error(err))
		return nil, errors.Wrapf(err, "failed to dial %s", f.dialAddr)
	}

	if err = f.ConnectionOk(ctx, conn); err != nil {
		// ConnectionOk already closes conn on failure; closing again is a
		// harmless no-op error and keeps this path correct on its own.
		_ = conn.Close()
		logger.Instance().Info("failed to ping", logger.String("dialAddr", f.dialAddr), logger.Error(err))
		return nil, errors.Wrapf(err, "failed to ping %s", f.dialAddr)
	}
	return conn, nil
}
// ConnectionOk probes conn with the factory's configured ping function.
//
// If no ping function is configured, it returns nil without sending anything.
// This includes a zero-value factory with no Options.
//
// Otherwise it calls the ping function with ctx and conn. If that call fails,
// conn is closed and the error is returned. Callers must therefore not reuse
// conn after a non-nil result.
func (f MutualTLSFactory) ConnectionOk(ctx context.Context, conn *grpc.ClientConn) error {
	if f.options == nil || f.options.pingFunc == nil {
		return nil
	}
	// We have to send a request to the server to see if we can actually write
	// to the socket; implementing a ping/pong RPC is useful for this.
	if _, err := f.options.pingFunc(ctx, conn); err != nil {
		if conn != nil {
			_ = conn.Close()
		}
		return err
	}
	return nil
}
// LoadTLSCredentials builds gRPC client transport credentials for mutual TLS.
//
// caCertPEM must contain at least one parseable PEM-encoded CA certificate.
// These certificates form the root pool used to verify the server's
// certificate chain.
//
// clientCertPEM and clientKeyPEM are the PEM-encoded certificate and private
// key the client presents to the server.
//
// It returns an error if no CA certificate can be parsed, if the client key
// pair is invalid, or if advancedtls rejects the options.
//
// NOTE(review): VType is advancedtls.CertVerification. Per the advancedtls
// documentation, this checks the certificate chain but NOT the server hostname.
// Without a custom verification function, any server certificate signed by
// this CA is accepted. Confirm that is intended; advancedtls.CertAndHostVerification
// also checks the hostname.
func LoadTLSCredentials(caCertPEM []byte, clientCertPEM []byte, clientKeyPEM []byte) (credentials.TransportCredentials, error) {
	// Trust only the CA(s) that signed the server's certificate.
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(caCertPEM) {
		return nil, fmt.Errorf("failed to add server CA's certificate")
	}

	// Client identity presented during the TLS handshake.
	clientCert, err := tls.X509KeyPair(clientCertPEM, clientKeyPEM)
	if err != nil {
		return nil, errors.Wrap(err, "failed to load client certificate and key")
	}

	clientOptions := &advancedtls.ClientOptions{
		IdentityOptions: advancedtls.IdentityCertificateOptions{
			Certificates: []tls.Certificate{clientCert},
		},
		RootOptions: advancedtls.RootCertificateOptions{
			RootCACerts: certPool,
		},
		VType: advancedtls.CertVerification,
	}
	return advancedtls.NewClientCreds(clientOptions)
}