/
main.go
249 lines (208 loc) · 6.72 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
"github.com/oxplot/starenv"
)
// lambdafyEnvPrefix is the name prefix of env vars addressed to lambdafy
// itself; run removes every such var before the user's command is started.
const lambdafyEnvPrefix = "LAMBDAFY_"

// version is the build version. These will be populated by go generate.
var version = "dev"

// Endpoint configuration; all three are assigned in init.
var (
	port        int    // base port for various endpoints
	appEndpoint string // end point that proxy will proxy requests to
	listen      string // listen address for our own HTTP server used for proxying to AWS services.
)

// Lambda runtime environment as exposed through AWS env vars.
var (
	functionName    = os.Getenv("AWS_LAMBDA_FUNCTION_NAME")
	functionVersion = os.Getenv("AWS_LAMBDA_FUNCTION_VERSION")

	// inLambda is true only when the function name, function version and
	// runtime API env vars are all non-empty.
	inLambda = functionName != "" && functionVersion != "" && os.Getenv("AWS_LAMBDA_RUNTIME_API") != ""
)

// client never reuses connections and hands redirect responses back to the
// caller instead of following them. It sets no overall timeout.
var client = &http.Client{
	Transport: &http.Transport{DisableKeepAlives: true},
	CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse
	},
}
// init chooses the base port for this process. The user's app is given port
// (via the PORT env var in run) and reached at appEndpoint, while the proxy's
// own HTTP server listens on port+1.
func init() {
	// NOTE(review): kept in case other files in this package rely on the seeded
	// global math/rand source; Go 1.20+ seeds it automatically and deprecates
	// rand.Seed.
	rand.Seed(time.Now().UnixNano())
	// Pick a random port number between 19000 and 19999.
	// This is to ensure the user program can't depend on hardcoded port numbers.
	// Use the seeded generator rather than the clock's low digits: where the wall
	// clock only has microsecond (or coarser) resolution, UnixNano()%1000 is
	// constant and the port would never vary.
	port = 19000 + rand.Intn(1000)
	appEndpoint = "127.0.0.1:" + strconv.Itoa(port)
	listen = "127.0.0.1:" + strconv.Itoa(port+1)
}
// handle is a generic handler for all Lambda events supported by this function.
//
// The event kind is detected from its top-level keys, checked in this order:
//   - "Records": an SQS event, passed to handleSQS.
//   - "rawQueryString": an API Gateway v2 HTTP event, passed to handleHTTP.
//   - "cron": a scheduled event whose "cron" string is passed to handleCron.
//
// Any other event, or one that fails to decode into its detected type, is
// rejected with an error.
func handle(ctx context.Context, e map[string]json.RawMessage) (any, error) {
	// Flush stdout and stderr before returning to ensure the logs are captured by
	// AWS.
	defer func() {
		os.Stdout.Sync()
		os.Stderr.Sync()
	}()

	has := func(key string) bool {
		_, ok := e[key]
		return ok
	}

	// Re-encode the raw event so it can be decoded into the concrete type below.
	b, err := json.Marshal(e)
	if err != nil {
		log.Printf("failed to re-encode the event: %v", err)
		return nil, err
	}

	switch {
	case has("Records"): // SQS event
		var sqsEvent events.SQSEvent
		if err := json.Unmarshal(b, &sqsEvent); err != nil {
			log.Printf("failed to unmarshal the SQS event: %v", err)
			return nil, err
		}
		return handleSQS(ctx, sqsEvent)

	case has("rawQueryString"): // API Gateway v2 HTTP event
		var httpEvent events.APIGatewayV2HTTPRequest
		if err := json.Unmarshal(b, &httpEvent); err != nil {
			log.Printf("failed to unmarshal the APIGatewayV2 event: %v", err)
			return nil, err
		}
		return handleHTTP(ctx, httpEvent)

	case has("cron"): // scheduled event
		var cronEvent struct {
			Cron string `json:"cron"`
		}
		if err := json.Unmarshal(b, &cronEvent); err != nil {
			log.Printf("failed to unmarshal the cron event: %v", err)
			// Don't invoke handleCron with an empty or partial expression.
			return nil, err
		}
		return nil, handleCron(ctx, cronEvent.Cron)
	}

	// Format the JSON text: %v on the map would print each RawMessage as a
	// slice of byte values.
	return nil, fmt.Errorf("event type %s not supported by this lambda function", b)
}
// run is the main entry point for the proxy. os.Args[1] names the command to
// run; any further arguments are passed on to it.
//
// Env vars prefixed with lambdafyEnvPrefix are removed and the remaining ones
// are loaded/dereferenced through starenv. Outside Lambda the command then
// replaces this process via syscall.Exec. Inside Lambda the command is started
// as a child process with PORT set, signals are forwarded to it, and the Lambda
// handler is started once an HTTP GET to appEndpoint succeeds. run then returns
// when the child exits.
//
// It returns the exit code the process should terminate with, plus an error
// when the proxy itself failed. 127 is returned for usage errors, when the
// command cannot be started, and when the child's exit code is -1 (e.g. it was
// terminated by a signal).
func run() (exitCode int, err error) {
	if len(os.Args) < 2 {
		return 127, fmt.Errorf("usage: %s command [arg [arg [...]]]", os.Args[0])
	}
	cmdName := os.Args[1]

	// Remove all env vars with lambdafy prefix to prevent child process from
	// depending on them.
	// IMPORTANT: This must come before starenv loading since none of the values
	// in the lambdafy prefixed env vars are meant to be dereferenced.
	for _, e := range os.Environ() {
		if name, _, _ := strings.Cut(e, "="); strings.HasPrefix(name, lambdafyEnvPrefix) {
			os.Unsetenv(name)
		}
	}

	// Load env vars/dereference them from various sources.
	envLoader := starenv.NewLoader()
	for t, n := range starenv.DefaultDerefers {
		envLoader.Register(t, &starenv.LazyDerefer{New: n})
	}
	envLoader.Register(sendSQSStarenvTag, sqsIDToQueueURL)
	if err := envLoader.Load(); len(err) > 0 {
		return 1, fmt.Errorf("error loading env vars: %s", err)
	}

	if !inLambda {
		path, err := exec.LookPath(cmdName)
		if err != nil {
			return 1, fmt.Errorf("cannot find command '%s': %w", cmdName, err)
		}
		// syscall.Exec requires the first argument to be the command name.
		err = syscall.Exec(path, os.Args[1:], os.Environ())
		// Exec only returns if it failed.
		return 1, fmt.Errorf("cannot exec command '%s': %w", path, err)
	}

	log.Printf("running in lambda, starting proxying traffic to %s", appEndpoint)
	args := os.Args[2:]

	// Start own AWS proxy endpoint (used for sending on SQS and other services).
	// ListenAndServe only returns on failure (e.g. the port is taken); log it
	// rather than losing the reason silently.
	http.HandleFunc("/sqs", handleSQSSend)
	go func() {
		if err := http.ListenAndServe(listen, nil); err != nil {
			log.Printf("error: AWS proxy endpoint on %s stopped: %s", listen, err)
		}
	}()

	// Set/override the PORT env var. The child inherits it; presumably the app
	// binds to $PORT, which is where appEndpoint points.
	os.Setenv("PORT", strconv.Itoa(port))

	// Run the command.
	cmd := exec.Command(cmdName, args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		return 127, fmt.Errorf("failed to run command: %w", err)
	}

	// Pass through signals to the child process. The channel is buffered
	// because package signal never blocks when delivering, so an unbuffered
	// channel can drop signals that arrive while the forwarder is busy.
	sigs := make(chan os.Signal, 16)
	go func() {
		for s := range sigs {
			// Notify with no signal list also relays SIGURG (used by the Go
			// runtime for goroutine preemption) and SIGCHLD (about our own
			// child). Neither is addressed to the child, so skip them.
			if s == syscall.SIGURG || s == syscall.SIGCHLD {
				continue
			}
			// Best effort: fails harmlessly once the child has exited.
			_ = cmd.Process.Signal(s)
		}
	}()
	signal.Notify(sigs)

	// Monitor child process for when it exits.
	processStopped := make(chan struct{})
	go func() {
		defer close(processStopped)
		if err := cmd.Wait(); err != nil {
			if err, ok := err.(*exec.ExitError); ok {
				log.Printf("command exited with code: %d", err.ExitCode())
			} else {
				log.Printf("error: waiting for command: %s", err)
			}
		}
		os.Stdout.Sync()
		os.Stderr.Sync()
	}()

	// Wait until the upstream is up and running.
	waitClient := &http.Client{
		Transport: &http.Transport{
			DisableKeepAlives: true,
		},
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	log.Printf("waiting for startup request to succeed")
	startupURL := "http://" + appEndpoint + "/"
StartupRequest:
	for {
		req, err := http.NewRequest(http.MethodGet, startupURL, nil)
		if err != nil {
			return 1, fmt.Errorf("failed to create startup request: %w", err)
		}
		if resp, err := waitClient.Do(req); err == nil {
			resp.Body.Close()
			log.Printf("startup request passed - proxying requests from now on")
			// We will only start accepting requests once the startup request to the
			// upstream has succeeded. This is to ensure that the upstream is up and
			// running before we take requests out of the queue and start sending them
			// to the upstream. Startup phase is limited to 10 seconds and it is in
			// addition to whatever timeout is set for the lambda.
			// If start fails, it rudely kills the process so no need to do anything
			// here. Inside a container, once we are killed, so will every other
			// process, so no need to do anything here to catch it.
			go lambda.StartWithOptions(handle, lambda.WithEnableSIGTERM())
			break
		}
		// Retry after a short pause, but stop waiting as soon as the child exits.
		// The reason we don't have our own timeout for this stage is that it'll be
		// redundant in presence of Lambda's own timeout.
		select {
		case <-processStopped:
			break StartupRequest
		case <-time.After(100 * time.Millisecond):
		}
	}

	// Wait for process/lambda to stop.
	<-processStopped
	// ExitCode reports -1 if the child was terminated by a signal.
	if code := cmd.ProcessState.ExitCode(); code != -1 {
		return code, nil
	}
	return 127, nil
}
// main runs the proxy and terminates the process with the exit code chosen by
// run. A proxy error is logged first, but it does not override that exit code
// (e.g. 127 for usage or start failures). A zero code paired with an error is
// raised to 1 so failures never look successful.
func main() {
	log.SetFlags(0)
	log.SetPrefix("lambdafy-proxy: ")
	exitCode, err := run()
	if err != nil {
		// Not log.Fatalf: that would always exit with status 1.
		log.Printf("error: %s", err)
		if exitCode == 0 {
			exitCode = 1
		}
	}
	os.Exit(exitCode)
}