This repository has been archived by the owner on Oct 29, 2022. It is now read-only.
forked from argoproj/argo-workflows
-
Notifications
You must be signed in to change notification settings - Fork 0
/
grpc.go
151 lines (140 loc) · 4.79 KB
/
grpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package grpc
import (
"crypto/tls"
"net"
"runtime/debug"
"strings"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
)
// PanicLoggerUnaryServerInterceptor builds a unary server interceptor that stops a panic
// raised by the wrapped handler from propagating. The panic value and the stack trace are
// logged to log at error level, and the RPC instead fails with a codes.Internal status
// error whose message is the panic value; the response is nil in that case.
func PanicLoggerUnaryServerInterceptor(log *logrus.Entry) grpc.UnaryServerInterceptor {
	interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack())
			err = status.Errorf(codes.Internal, "%s", r)
		}()
		resp, err = handler(ctx, req)
		return resp, err
	}
	return interceptor
}
// PanicLoggerStreamServerInterceptor builds a streaming server interceptor that stops a
// panic raised by the wrapped handler from propagating. The panic value and the stack
// trace are logged to log at error level, and the stream instead ends with a
// codes.Internal status error whose message is the panic value.
func PanicLoggerStreamServerInterceptor(log *logrus.Entry) grpc.StreamServerInterceptor {
	interceptor := func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			log.Errorf("Recovered from panic: %+v\n%s", r, debug.Stack())
			err = status.Errorf(codes.Internal, "%s", r)
		}()
		err = handler(srv, stream)
		return err
	}
	return interceptor
}
// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
//
// network is passed to net.Dialer (e.g. "tcp"). The caller's opts are applied first; the
// options this function relies on (blocking dial, fail-fast on non-temporary dial errors,
// the custom dialer, and grpc.WithInsecure because TLS is handled by the custom dialer)
// are appended after them.
//
// The first error reported by the custom dialer (TCP dial or TLS handshake failure) is
// returned as soon as it happens. Otherwise the call returns when grpc.DialContext returns
// or when ctx is done. On every error path the in-flight dial is cancelled, and any
// connection it still manages to produce is closed rather than leaked.
//
// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	// grpc.Dial doesn't provide any information on permanent connection errors (like
	// TLS handshake failures). So in order to provide good error messages, we need a
	// custom dialer that can provide that info. That means we manage the TLS handshake.
	dialerErr := make(chan error, 1)
	reportDialerErr := func(err error) {
		// non-blocking write: we only need the first error
		select {
		case dialerErr <- err:
		default:
		}
	}

	dialer := func(addr string, timeout time.Duration) (net.Conn, error) {
		// Derived from the caller's ctx, not from dialCtx below. grpc may call this dialer
		// again to reconnect after BlockingDial has returned, and by then dialCtx is cancelled.
		attemptCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		rawConn, err := (&net.Dialer{}).DialContext(attemptCtx, network, addr)
		if err != nil {
			reportDialerErr(err)
			return nil, err
		}
		if creds == nil {
			return rawConn, nil
		}
		conn, _, err := creds.ClientHandshake(attemptCtx, addr, rawConn)
		if err != nil {
			// Not every TransportCredentials implementation closes the raw connection
			// when the handshake fails. A second Close is harmless, so close it here.
			_ = rawConn.Close()
			reportDialerErr(err)
			return nil, err
		}
		return conn, nil
	}

	// Build a fresh option slice. Appending straight to opts could write into the
	// caller's backing array when it has spare capacity.
	dialOpts := make([]grpc.DialOption, 0, len(opts)+4)
	dialOpts = append(dialOpts, opts...)
	dialOpts = append(dialOpts,
		grpc.WithBlock(),
		grpc.FailOnNonTempDialError(true),
		grpc.WithDialer(dialer),
		grpc.WithInsecure(), // we are handling TLS, so tell grpc not to
	)

	// dialCtx bounds grpc.DialContext. Cancelling it on return stops a dial we are
	// abandoning. Per the grpc docs, cancelling it after DialContext has returned a
	// connection is a no-op for that connection.
	dialCtx, cancelDial := context.WithCancel(ctx)
	defer cancelDial()

	type dialOutcome struct {
		conn *grpc.ClientConn
		err  error
	}
	// The goroutine below sends exactly once. The buffer means it never blocks, even if
	// nobody is left to receive.
	outcome := make(chan dialOutcome, 1)

	// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
	// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
	// know when we're done. So we run it in a goroutine and use the dialer error
	// channel to fail fast.
	go func() {
		conn, err := grpc.DialContext(dialCtx, address, dialOpts...)
		outcome <- dialOutcome{conn: conn, err: err}
	}()

	// abandonDial closes any connection that DialContext still produces after we have
	// given up on it. The deferred cancelDial makes DialContext return promptly, so this
	// goroutine always terminates.
	abandonDial := func() {
		go func() {
			if res := <-outcome; res.conn != nil {
				_ = res.conn.Close()
			}
		}()
	}

	select {
	case res := <-outcome:
		if res.err == nil {
			return res.conn, nil
		}
		// Prefer the dialer's own error (TCP or handshake failure) when one was
		// recorded; it is more specific than the error from DialContext.
		select {
		case err := <-dialerErr:
			return nil, err
		default:
			return nil, res.err
		}
	case err := <-dialerErr:
		abandonDial()
		return nil, err
	case <-ctx.Done():
		abandonDial()
		return nil, ctx.Err()
	}
}
// TLSTestResult describes what TestTLS found out about a server.
type TLSTestResult struct {
	// TLS is true when the server accepted a TLS connection with certificate
	// verification disabled. It is false when the server only accepted a
	// plain-text connection.
	TLS bool
	// InsecureErr is the error from a second TLS dial made with certificate
	// verification enabled. It is set only when TLS is true and that verified
	// dial failed. That most likely means the server's certificate is not
	// trusted or does not match (e.g. self-signed); confirm from the error itself.
	InsecureErr error
}
// TestTLS probes the gRPC server at address to find out whether it accepts TLS.
//
// If address has no port, ":443" is assumed. IPv6 literals such as "::1" or "[::1]" are
// bracketed correctly.
//
// The probe proceeds in up to three dials:
//   - A TLS dial with certificate verification disabled. If it succeeds, the result has
//     TLS set, and a second dial follows.
//   - A TLS dial with verification enabled. Its failure, if any, is recorded in
//     InsecureErr.
//   - If the first TLS dial fails, a plain-text dial. If it succeeds, the result has TLS
//     unset. If it also fails, its error is returned and the TLS error is dropped.
//
// All dials use context.Background(), so TestTLS has no timeout of its own.
func TestTLS(address string) (*TLSTestResult, error) {
	if _, _, err := net.SplitHostPort(address); err != nil {
		// If port is unspecified, assume the most likely port. SplitHostPort also rejects
		// bare IPv6 literals ("::1", "[::1]"), so strip any brackets before JoinHostPort
		// adds them back.
		host := strings.TrimSuffix(strings.TrimPrefix(address, "["), "]")
		address = net.JoinHostPort(host, "443")
	}

	// Verification is deliberately skipped: this dial only asks "does the server speak
	// TLS at all?". Certificate validity is checked separately below.
	insecureCreds := credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})
	if probeConn, probeErr := BlockingDial(context.Background(), "tcp", address, insecureCreds); probeErr == nil {
		// The connection was only a probe, so a Close error is irrelevant.
		_ = probeConn.Close()
		result := &TLSTestResult{TLS: true}

		verifiedCreds := credentials.NewTLS(&tls.Config{})
		verifiedConn, verifyErr := BlockingDial(context.Background(), "tcp", address, verifiedCreds)
		if verifyErr != nil {
			// If connection was successful with InsecureSkipVerify true, but unsuccessful with
			// InsecureSkipVerify false, it means server is not configured securely
			result.InsecureErr = verifyErr
			return result, nil
		}
		_ = verifiedConn.Close()
		return result, nil
	}

	// If we get here, we were unable to connect via TLS (even with InsecureSkipVerify: true)
	// It may be because server is running without TLS, or because of real issues (e.g. connection
	// refused). Test if server accepts plain-text connections
	plainConn, err := BlockingDial(context.Background(), "tcp", address, nil)
	if err != nil {
		return nil, err
	}
	_ = plainConn.Close()
	return &TLSTestResult{TLS: false}, nil
}