diff --git a/apm-lambda-extension/extension/apm_server.go b/apm-lambda-extension/extension/apm_server.go index 2fb68d5c..1ca2822c 100644 --- a/apm-lambda-extension/extension/apm_server.go +++ b/apm-lambda-extension/extension/apm_server.go @@ -21,6 +21,7 @@ import ( "bytes" "compress/gzip" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -28,28 +29,45 @@ import ( // todo: can this be a streaming or streaming style call that keeps the // connection open across invocations? -func PostToApmServer(postBody []byte, config *extensionConfig) error { +func PostToApmServer(agentData AgentData, config *extensionConfig) error { endpointUri := "intake/v2/events" - var compressedBytes bytes.Buffer - w := gzip.NewWriter(&compressedBytes) - w.Write(postBody) - w.Write([]byte{10}) - w.Close() + var req *http.Request + var err error - client := &http.Client{} + if agentData.ContentEncoding == "" { + pr, pw := io.Pipe() + gw, _ := gzip.NewWriterLevel(pw, gzip.BestSpeed) - req, err := http.NewRequest("POST", config.apmServerUrl+endpointUri, bytes.NewReader(compressedBytes.Bytes())) - if err != nil { - return fmt.Errorf("failed to create a new request when posting to APM server: %v", err) + go func() { + _, err = io.Copy(gw, bytes.NewReader(agentData.Data)) + gw.Close() + pw.Close() + if err != nil { + log.Printf("Failed to compress data: %v", err) + } + }() + + req, err = http.NewRequest("POST", config.apmServerUrl+endpointUri, pr) + if err != nil { + return fmt.Errorf("failed to create a new request when posting to APM server: %v", err) + } + req.Header.Add("Content-Encoding", "gzip") + } else { + req, err = http.NewRequest("POST", config.apmServerUrl+endpointUri, bytes.NewReader(agentData.Data)) + if err != nil { + return fmt.Errorf("failed to create a new request when posting to APM server: %v", err) + } + req.Header.Add("Content-Encoding", agentData.ContentEncoding) } - req.Header.Add("Content-Type", "application/x-ndjson") - req.Header.Add("Content-Encoding", "gzip") + req.Header.Add("Content-Type", "application/x-ndjson") if config.apmServerApiKey != "" { req.Header.Add("Authorization", "ApiKey "+config.apmServerApiKey) } else if config.apmServerSecretToken != "" { req.Header.Add("Authorization", "Bearer "+config.apmServerSecretToken) } + + client := &http.Client{} resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to post to APM server: %v", err) diff --git a/apm-lambda-extension/extension/apm_server_test.go b/apm-lambda-extension/extension/apm_server_test.go new file mode 100644 index 00000000..4cfd19d5 --- /dev/null +++ b/apm-lambda-extension/extension/apm_server_test.go @@ -0,0 +1,95 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package extension + +import ( + "compress/gzip" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "gotest.tools/assert" +) + +func TestPostToApmServerDataCompressed(t *testing.T) { + s := "A long time ago in a galaxy far, far away..." + + // Compress the data + pr, pw := io.Pipe() + gw, _ := gzip.NewWriterLevel(pw, gzip.BestSpeed) + go func() { + gw.Write([]byte(s)) + gw.Close() + pw.Close() + }() + + // Create AgentData struct with compressed data + data, _ := ioutil.ReadAll(pr) + agentData := AgentData{Data: data, ContentEncoding: "gzip"} + + // Create apm server and handler + apmServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bytes, _ := ioutil.ReadAll(r.Body) + assert.Equal(t, string(data), string(bytes)) + assert.Equal(t, "gzip", r.Header.Get("Content-Encoding")) + w.Write([]byte(`{"foo": "bar"}`)) + })) + defer apmServer.Close() + + config := extensionConfig{ + apmServerUrl: apmServer.URL + "/", + } + + err := PostToApmServer(agentData, &config) + assert.Equal(t, nil, err) +} + +func TestPostToApmServerDataNotCompressed(t *testing.T) { + s := "A long time ago in a galaxy far, far away..." + body := []byte(s) + agentData := AgentData{Data: body, ContentEncoding: ""} + + // Compress the data, so it can be compared with what + // the apm server receives + pr, pw := io.Pipe() + gw, _ := gzip.NewWriterLevel(pw, gzip.BestSpeed) + go func() { + gw.Write(body) + gw.Close() + pw.Close() + }() + + // Create apm server and handler + apmServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + request_bytes, _ := ioutil.ReadAll(r.Body) + compressed_bytes, _ := ioutil.ReadAll(pr) + assert.Equal(t, string(compressed_bytes), string(request_bytes)) + assert.Equal(t, "gzip", r.Header.Get("Content-Encoding")) + w.Write([]byte(`{"foo": "bar"}`)) + })) + defer apmServer.Close() + + config := extensionConfig{ + apmServerUrl: apmServer.URL + "/", + } + + err := PostToApmServer(agentData, &config) + assert.Equal(t, nil, err) +} diff --git a/apm-lambda-extension/extension/http_server.go b/apm-lambda-extension/extension/http_server.go index 1bd6395a..30bd45a8 100644 --- a/apm-lambda-extension/extension/http_server.go +++ b/apm-lambda-extension/extension/http_server.go @@ -18,18 +18,13 @@ package extension import ( - "bytes" - "compress/gzip" - "compress/zlib" - "fmt" - "io/ioutil" "net" "net/http" "time" ) type serverHandler struct { - data chan []byte + data chan AgentData config *extensionConfig } @@ -50,7 +45,7 @@ func (handler *serverHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) } -func NewHttpServer(dataChannel chan []byte, config *extensionConfig) *http.Server { +func NewHttpServer(dataChannel chan AgentData, config *extensionConfig) *http.Server { var handler = serverHandler{data: dataChannel, config: config} timeout := time.Duration(config.dataReceiverTimeoutSeconds) * time.Second s := &http.Server{ @@ -70,37 +65,3 @@ func NewHttpServer(dataChannel chan []byte, config *extensionConfig) *http.Serve return s } - -func getDecompressedBytesFromRequest(req *http.Request) ([]byte, error) { - var rawBytes []byte - if req.Body != nil { - rawBytes, _ = ioutil.ReadAll(req.Body) - } - - switch req.Header.Get("Content-Encoding") { - case "deflate": - reader := bytes.NewReader([]byte(rawBytes)) - zlibreader, err := zlib.NewReader(reader) - if err != nil { - return nil, fmt.Errorf("could not create zlib.NewReader: %v", err) - } - bodyBytes, err := ioutil.ReadAll(zlibreader) - if err != nil { - return nil, fmt.Errorf("could not read from zlib reader using ioutil.ReadAll: %v", err) - } - return bodyBytes, nil - case "gzip": - reader := bytes.NewReader([]byte(rawBytes)) - zlibreader, err := gzip.NewReader(reader) - if err != nil { - return nil, fmt.Errorf("could not create gzip.NewReader: %v", err) - } - bodyBytes, err := ioutil.ReadAll(zlibreader) - if err != nil { - return nil, fmt.Errorf("could not read from gzip reader using ioutil.ReadAll: %v", err) - } - return bodyBytes, nil - default: - return rawBytes, nil - } -} diff --git a/apm-lambda-extension/extension/http_server_test.go b/apm-lambda-extension/extension/http_server_test.go index dd36c24d..f765d875 100644 --- a/apm-lambda-extension/extension/http_server_test.go +++ b/apm-lambda-extension/extension/http_server_test.go @@ -18,130 +18,14 @@ package extension import ( - "bytes" - "compress/gzip" - "compress/zlib" "io/ioutil" "net/http" "net/http/httptest" - "strings" "testing" "gotest.tools/assert" ) -func Test_getDecompressedBytesFromRequestUncompressed(t *testing.T) { - s := "A long time ago in a galaxy far, far away..." - body := strings.NewReader(s) - - // Create the request - req, err := http.NewRequest(http.MethodPost, "example.com", body) - if err != nil { - t.Errorf("Error creating new request: %v", err) - t.Fail() - } - - // Decompress the request's body - got, err1 := getDecompressedBytesFromRequest(req) - if err1 != nil { - t.Errorf("Error decompressing request body: %v", err1) - t.Fail() - } - - if s != string(got) { - t.Errorf("Original string and decompressed data do not match") - t.Fail() - } -} - -func Test_getDecompressedBytesFromRequestGzip(t *testing.T) { - s := "A long time ago in a galaxy far, far away..." - var b bytes.Buffer - - // Compress the data - w := gzip.NewWriter(&b) - w.Write([]byte(s)) - w.Close() - - // Create a reader reading from the bytes on the buffer - body := bytes.NewReader(b.Bytes()) - - // Create the request - req, err := http.NewRequest(http.MethodPost, "example.com", body) - if err != nil { - t.Errorf("Error creating new request: %v", err) - t.Fail() - } - - // Set the encoding to gzip - req.Header.Set("Content-Encoding", "gzip") - - // Decompress the request's body - got, err1 := getDecompressedBytesFromRequest(req) - if err1 != nil { - t.Errorf("Error decompressing request body: %v", err1) - t.Fail() - } - - if s != string(got) { - t.Errorf("Original string and decompressed data do not match") - t.Fail() - } -} - -func Test_getDecompressedBytesFromRequestDeflate(t *testing.T) { - s := "A long time ago in a galaxy far, far away..." - var b bytes.Buffer - - // Compress the data - w := zlib.NewWriter(&b) - w.Write([]byte(s)) - w.Close() - - // Create a reader reading from the bytes on the buffer - body := bytes.NewReader(b.Bytes()) - - // Create the request - req, err := http.NewRequest(http.MethodPost, "example.com", body) - if err != nil { - t.Errorf("Error creating new request: %v", err) - t.Fail() - } - - // Set the encoding to deflate - req.Header.Set("Content-Encoding", "deflate") - - // Decompress the request's body - got, err1 := getDecompressedBytesFromRequest(req) - if err1 != nil { - t.Errorf("Error decompressing request body: %v", err1) - t.Fail() - } - - if s != string(got) { - t.Errorf("Original string and decompressed data do not match") - t.Fail() - } -} - -func Test_getDecompressedBytesFromRequestEmptyBody(t *testing.T) { - // Create the request - req, err := http.NewRequest(http.MethodPost, "example.com", nil) - if err != nil { - t.Errorf("Error creating new request: %v", err) - } - - got, err := getDecompressedBytesFromRequest(req) - if err != nil { - t.Errorf("Error decompressing request body: %v", err) - } - - if len(got) != 0 { - t.Errorf("A non-empty byte slice was returned") - t.Fail() - } -} - func TestInfoProxy(t *testing.T) { headers := map[string]string{"Authorization": "test-value"} wantResp := "{\"foo\": \"bar\"}" @@ -157,7 +41,7 @@ func TestInfoProxy(t *testing.T) { defer apmServer.Close() // Create extension config and start the server - dataChannel := make(chan []byte, 100) + dataChannel := make(chan AgentData, 100) config := extensionConfig{ apmServerUrl: apmServer.URL, apmServerSecretToken: "foo", diff --git a/apm-lambda-extension/extension/process_events.go b/apm-lambda-extension/extension/process_events.go index 4ec5a53e..54f479e3 100644 --- a/apm-lambda-extension/extension/process_events.go +++ b/apm-lambda-extension/extension/process_events.go @@ -27,7 +27,7 @@ func ProcessShutdown() { log.Println("Exiting") } -func FlushAPMData(dataChannel chan []byte, config *extensionConfig) { +func FlushAPMData(dataChannel chan AgentData, config *extensionConfig) { log.Println("Checking for agent data") for { select { diff --git a/apm-lambda-extension/extension/route_handlers.go b/apm-lambda-extension/extension/route_handlers.go index a6471f5e..60d4373c 100644 --- a/apm-lambda-extension/extension/route_handlers.go +++ b/apm-lambda-extension/extension/route_handlers.go @@ -23,6 +23,11 @@ import ( "net/http" ) +type AgentData struct { + Data []byte + ContentEncoding string +} + // URL: http://server/ func handleInfoRequest(handler *serverHandler, w http.ResponseWriter, r *http.Request) { client := &http.Client{} @@ -68,13 +73,24 @@ func handleInfoRequest(handler *serverHandler, w http.ResponseWriter, r *http.Re // URL: http://server/intake/v2/events func handleIntakeV2Events(handler *serverHandler, w http.ResponseWriter, r *http.Request) { - bodyBytes, err := getDecompressedBytesFromRequest(r) - if nil != err { - log.Printf("could not get decompressed bytes from request body: %v", err) - } else { - log.Println("Adding agent data to buffer to be sent to apm server") - handler.data <- bodyBytes - } w.WriteHeader(http.StatusAccepted) w.Write([]byte("ok")) + + if r.Body == nil { + log.Println("Could not get bytes from agent request body") + return + } + + rawBytes, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println("Could not read bytes from agent request body") + return + } + + agentData := AgentData{ + Data: rawBytes, + ContentEncoding: r.Header.Get("Content-Encoding"), + } + log.Println("Adding agent data to buffer to be sent to apm server") + handler.data <- agentData } diff --git a/apm-lambda-extension/main.go b/apm-lambda-extension/main.go index b9ec5eb0..c0bbf5f5 100644 --- a/apm-lambda-extension/main.go +++ b/apm-lambda-extension/main.go @@ -61,9 +61,9 @@ func main() { // setup http server to receive data from agent // and get a channel to listen for that data - dataChannel := make(chan []byte, 100) + agentDataChannel := make(chan extension.AgentData, 100) - extension.NewHttpServer(dataChannel, config) + extension.NewHttpServer(agentDataChannel, config) // Make channel for collecting logs and create a HTTP server to listen for them logsChannel := make(chan logsapi.LogEvent) @@ -112,7 +112,7 @@ func main() { // Flush any APM data, in case waiting for the runtimeDone event timed out, // the agent data wasn't available yet, and we got to the next event - extension.FlushAPMData(dataChannel, config) + extension.FlushAPMData(agentDataChannel, config) // Make a channel for signaling that a runtimeDone event has been received runtimeDone := make(chan struct{}) @@ -129,7 +129,7 @@ func main() { case <-funcInvocDone: log.Println("Function invocation is complete, not receiving any more agent data") return - case agentData := <-dataChannel: + case agentData := <-agentDataChannel: err := extension.PostToApmServer(agentData, config) if err != nil { log.Printf("Error sending to APM server, skipping: %v", err) @@ -179,7 +179,7 @@ func main() { } // Flush APM data now that the function invocation has completed - extension.FlushAPMData(dataChannel, config) + extension.FlushAPMData(agentDataChannel, config) // Signal that the function invocation has completed close(funcInvocDone)