-
Notifications
You must be signed in to change notification settings - Fork 9
/
protocol.go
177 lines (143 loc) · 4.46 KB
/
protocol.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
// Copyright 2022 Namespace Labs Inc; All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
package provisioning
import (
"context"
"flag"
"io"
"log"
"os"
"strings"
"time"
"github.com/klauspost/compress/zstd"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"namespacelabs.dev/foundation/internal/compression"
"namespacelabs.dev/foundation/internal/fnerrors"
)
// maximumWallclockTime is how long HandleInvocation lets the tool run before
// its kill-switch aborts the whole process.
const maximumWallclockTime = 120 * time.Second
// Command-line flags consumed by HandleInvocation and its request/response codec.
var (
	// inlineInvocation selects the method to run, written as "ServiceName/MethodName".
	inlineInvocation = flag.String("inline_invocation", "", "If set, calls the method inline on stdin/stdout.")

	// inlineInvocationInput / inlineInvocationOutput redirect the request source
	// and the response destination away from stdin / stdout.
	inlineInvocationInput  = flag.String("inline_invocation_input", "", "If set, reads the request from the path instead of os.Stdin.")
	inlineInvocationOutput = flag.String("inline_invocation_output", "", "If set, writes the result to the path instead of os.Stdout.")

	// Wire-format switches for both the request and the response.
	inlineInvocationJson       = flag.Bool("inline_invocation_json", false, "If set, the input and output are json.")
	inlineInvocationCompressed = flag.Bool("inline_invocation_compressed", false, "If set, the input and output are compressed.")

	// debug enables extra log output.
	debug = flag.Bool("debug", false, "If set to true, emits debugging information.")
)
func HandleInvocation(ctx context.Context, register func(grpc.ServiceRegistrar)) error {
go func() {
if *debug {
log.Printf("Setup kill-switch: %v", maximumWallclockTime)
}
time.Sleep(maximumWallclockTime)
log.Fatalf("aborting tool after %v", maximumWallclockTime)
}()
flag.Parse()
if *inlineInvocation == "" {
log.Fatal("--inline_invocation is missing")
}
var reg inlineRegistrar
register(®)
m := strings.Split(*inlineInvocation, "/")
if len(m) != 2 {
log.Fatal("bad invocation specification")
}
for _, sr := range reg.registrations {
if sr.ServiceDesc.ServiceName != m[0] {
continue
}
for _, x := range sr.ServiceDesc.Methods {
if x.MethodName == m[1] {
result, err := x.Handler(sr.Impl, context.Background(), decodeInput, nil)
if err != nil {
log.Fatal(err)
}
if err := marshalOutput(result); err != nil {
log.Fatal(err)
}
os.Exit(0)
}
}
}
log.Fatalf("%s: don't know how to handle this method", *inlineInvocation)
return nil
}
func decodeInput(target interface{}) error {
var err error
var bytes []byte
if *inlineInvocationInput != "" {
bytes, err = os.ReadFile(*inlineInvocationInput)
} else {
bytes, err = io.ReadAll(os.Stdin)
}
if err != nil {
return fnerrors.InternalError("failed to read input: %w", err)
}
if *inlineInvocationCompressed {
payload, err := compression.DecompressZstd(bytes)
if err != nil {
return fnerrors.InternalError("failed to decompress payload: %w", err)
}
return proto.Unmarshal(payload, target.(proto.Message))
}
if *inlineInvocationJson {
return protojson.Unmarshal(bytes, target.(proto.Message))
}
return proto.Unmarshal(bytes, target.(proto.Message))
}
func marshalOutput(out interface{}) error {
w := os.Stdout
if *inlineInvocationOutput != "" {
f, err := os.Create(*inlineInvocationOutput)
if err != nil {
return fnerrors.InternalError("failed to create output: %w", err)
}
w = f
}
var bytes []byte
var err error
if *inlineInvocationJson {
bytes, err = protojson.Marshal(out.(proto.Message))
} else {
bytes, err = proto.Marshal(out.(proto.Message))
}
if err != nil {
return fnerrors.InternalError("failed to serialize output: %w", err)
}
if *inlineInvocationCompressed {
w, err := zstd.NewWriter(w)
if err != nil {
return fnerrors.InternalError("failed to prepare output: %w", err)
}
if _, err := w.Write(bytes); err != nil {
return fnerrors.InternalError("failed to compress output: %w", err)
}
if err := w.Close(); err != nil {
return fnerrors.InternalError("failed to finalize compression: %w", err)
}
return nil
}
if _, err := w.Write(bytes); err != nil {
return fnerrors.InternalError("failed to write output: %w", err)
}
if *inlineInvocationOutput != "" {
if err := w.Close(); err != nil {
return fnerrors.InternalError("failed to close output: %w", err)
}
}
return nil
}
// inlineRegistrar is a grpc.ServiceRegistrar that only records what was
// registered, so HandleInvocation can look up and call method handlers directly.
type inlineRegistrar struct {
	registrations []inlineRegistration
}

// inlineRegistration pairs a service descriptor with the implementation that
// was registered for it.
type inlineRegistration struct {
	ServiceDesc *grpc.ServiceDesc
	Impl        interface{}
}

// Compile-time check: *inlineRegistrar is usable wherever a
// grpc.ServiceRegistrar is expected (HandleInvocation relies on this).
var _ grpc.ServiceRegistrar = (*inlineRegistrar)(nil)

// RegisterService appends desc and impl to the recorded registrations, keeping
// them in call order.
func (r *inlineRegistrar) RegisterService(desc *grpc.ServiceDesc, impl interface{}) {
	entry := inlineRegistration{ServiceDesc: desc, Impl: impl}
	r.registrations = append(r.registrations, entry)
}