This repository has been archived by the owner on Jan 8, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 330
/
grpc_version.go
155 lines (132 loc) · 3.81 KB
/
grpc_version.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package server
import (
"context"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/hashicorp/waypoint/internal/protocolversion"
pb "github.com/hashicorp/waypoint/internal/server/gen"
)
// versionUnaryInterceptor builds a grpc.UnaryServerInterceptor that, for
// calls selected by versionType, negotiates the protocol version with the
// client via versionContext and hands the handler the resulting context.
// Calls that need no negotiation are passed through with their original
// context; if negotiation fails, the handler is not run and the error from
// versionContext is returned instead.
func versionUnaryInterceptor(serverInfo *pb.VersionInfo) grpc.UnaryServerInterceptor {
	return func(
		ctx context.Context,
		req interface{},
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (interface{}, error) {
		if typ, negotiate := versionType(info.FullMethod); negotiate {
			negotiated, err := versionContext(ctx, typ, serverInfo)
			if err != nil {
				return nil, err
			}
			ctx = negotiated
		}
		return handler(ctx, req)
	}
}
// versionStreamInterceptor returns a gRPC stream interceptor that negotiates
// the protocol version to use and sets it in the stream's context using
// protocolversion.WithContext (via versionContext).
//
// Methods for which versionType reports that no negotiation is needed are
// passed straight to the handler with the original stream. Otherwise, a
// negotiation failure is returned without running the handler. On success,
// the handler receives a wrapped stream whose Context method returns the
// context carrying the negotiated version.
func versionStreamInterceptor(serverInfo *pb.VersionInfo) grpc.StreamServerInterceptor {
	return func(
		srv interface{},
		ss grpc.ServerStream,
		info *grpc.StreamServerInfo,
		handler grpc.StreamHandler,
	) error {
		typ, ok := versionType(info.FullMethod)
		if !ok {
			return handler(srv, ss)
		}

		ctx, err := versionContext(ss.Context(), typ, serverInfo)
		if err != nil {
			return err
		}

		// Stream handlers obtain their context from the stream itself, so the
		// stream is wrapped to expose the context with the negotiated version.
		return handler(srv, &versionStream{
			ServerStream: ss,
			context:      ctx,
		})
	}
}
// versionType reports which kind of protocol version must be negotiated for
// the given fully-qualified gRPC method name (e.g. "/pkg.Service/Method").
//
// The boolean result is false, with protocolversion.Invalid, when no
// negotiation should happen: for methods outside the Waypoint service and
// for the GetVersionInfo call itself. Waypoint methods whose name starts
// with "Entrypoint" negotiate the entrypoint protocol; all others negotiate
// the API protocol.
func versionType(fullMethod string) (protocolversion.Type, bool) {
	const servicePrefix = "/hashicorp.waypoint.Waypoint/"

	// Anything not served by the Waypoint service is left alone.
	if !strings.HasPrefix(fullMethod, servicePrefix) {
		return protocolversion.Invalid, false
	}

	// Isolate the bare method name following the final slash.
	slash := strings.LastIndex(fullMethod, "/")
	if slash < 0 {
		return protocolversion.Invalid, false
	}
	name := fullMethod[slash+1:]

	switch {
	case name == "GetVersionInfo":
		// The version endpoint itself never negotiates a version.
		return protocolversion.Invalid, false
	case strings.HasPrefix(name, "Entrypoint"):
		return protocolversion.Entrypoint, true
	default:
		return protocolversion.Api, true
	}
}
// versionContext negotiates the protocol version for a single call and
// returns ctx wrapped by protocolversion.WithContext with the negotiated
// version.
//
// typ selects which client header is read from the incoming gRPC metadata
// and which of the server's supported version ranges in info (Api or
// Entrypoint) the client is negotiated against. The header must be present
// exactly once. Its value is parsed by protocolversion.ParseHeader into a
// (minimum, current) pair that is passed to protocolversion.Negotiate.
//
// All failures are returned as gRPC status errors: codes.Internal for an
// unrecognized typ, and codes.InvalidArgument when metadata is missing, the
// header is absent or repeated, the header cannot be parsed, or negotiation
// fails.
func versionContext(
	ctx context.Context,
	typ protocolversion.Type,
	info *pb.VersionInfo,
) (context.Context, error) {
	var header string
	var server *pb.VersionInfo_ProtocolVersion
	switch typ {
	case protocolversion.Api:
		header = protocolversion.HeaderClientApiProtocol
		server = info.Api
	case protocolversion.Entrypoint:
		header = protocolversion.HeaderClientEntrypointProtocol
		server = info.Entrypoint
	default:
		return nil, status.Errorf(codes.Internal, "invalid protocol type")
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
	}

	// MD.Get lowercases the key before the lookup, matching how gRPC stores
	// incoming metadata keys; direct map indexing would miss a header
	// constant that contained uppercase characters.
	vs := md.Get(header)
	if len(vs) != 1 {
		return nil, status.Errorf(codes.InvalidArgument,
			"required header %s is not set", header)
	}

	// Named "minimum" rather than "min" to avoid shadowing the builtin.
	minimum, current, err := protocolversion.ParseHeader(vs[0])
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument,
			"header %q: %s", header, err)
	}

	// Negotiate the client's advertised range against the server's.
	version, err := protocolversion.Negotiate(&pb.VersionInfo_ProtocolVersion{
		Current: current,
		Minimum: minimum,
	}, server)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument,
			"header %q: %s", header, err)
	}

	// Attach the negotiated version so downstream handlers can read it.
	return protocolversion.WithContext(ctx, version), nil
}
// versionStream wraps a grpc.ServerStream so that its Context method returns
// a replacement context (set by the stream interceptor) instead of the
// embedded stream's own context. All other stream methods are promoted from
// the embedded ServerStream unchanged.
type versionStream struct {
	grpc.ServerStream

	// context is returned by Context in place of ServerStream.Context().
	context context.Context
}

// Context returns the replacement context, shadowing the embedded
// ServerStream's Context method.
func (vs *versionStream) Context() context.Context {
	replacement := vs.context
	return replacement
}