Skip to content

Commit

Permalink
decoder: sniff Content-Encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
axw committed Mar 29, 2022
1 parent 826bc5a commit 6e6cc52
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 22 deletions.
97 changes: 75 additions & 22 deletions decoder/req_decoder.go → decoder/requestreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import (
"io"
"net/http"

"github.com/pkg/errors"

"github.com/elastic/beats/v7/libbeat/monitoring"
)

Expand All @@ -40,42 +38,80 @@ var (
readerCounter = monitoring.NewInt(decoderMetrics, "reader.count")
)

// CompressedRequestReader returns a reader that will decompress
// the body according to the supplied Content-Encoding header in the request
func CompressedRequestReader(req *http.Request) (io.ReadCloser, error) {
reader := req.Body
if reader == nil {
return nil, errors.New("no content")
}
// Content-Encoding values resolved from the request header, or sniffed
// from the body's leading magic bytes when no header is present.
const (
	unspecifiedContentEncoding = iota // no header seen and body not yet sniffed
	deflateContentEncoding            // zlib/deflate-compressed body
	gzipContentEncoding               // gzip-compressed body
	uncompressedContentEncoding       // plain, uncompressed body (sniff fallback)
)

// CompressedRequestReader returns a reader that will decompress the body
// according to the supplied Content-Encoding request header, or by sniffing
// the body contents if no header is supplied by looking for magic byte
// headers.
//
// Content-Encoding sniffing is implemented to support the RUM agent sending
// compressed payloads using the Beacon API (https://w3c.github.io/beacon/),
// which does not support specifying request headers.
func CompressedRequestReader(req *http.Request) (io.ReadCloser, error) {
cLen := req.ContentLength
knownCLen := cLen > -1
if !knownCLen {
missingContentLengthCounter.Inc()
}

var reader io.ReadCloser
var err error
contentEncoding := unspecifiedContentEncoding
switch req.Header.Get("Content-Encoding") {
case "deflate":
contentEncoding = deflateContentEncoding
reader, err = zlib.NewReader(req.Body)
case "gzip":
contentEncoding = gzipContentEncoding
reader, err = gzip.NewReader(req.Body)
default:
// Sniff encoding from payload by looking at the first two bytes.
// This produces much less garbage than opportunistically calling
// gzip.NewReader, zlib.NewReader, etc.
//
// Portions of code based on compress/zlib and compress/gzip.
const (
zlibDeflate = 8
gzipID1 = 0x1f
gzipID2 = 0x8b
)
rc := &compressedRequestReadCloser{reader: req.Body, Closer: req.Body}
if _, err := io.ReadFull(req.Body, rc.magic[:]); err != nil {
return nil, err
}
if rc.magic[0] == gzipID1 && rc.magic[1] == gzipID2 {
contentEncoding = gzipContentEncoding
reader, err = gzip.NewReader(rc)
} else if rc.magic[0]&0x0f == zlibDeflate {
contentEncoding = deflateContentEncoding
reader, err = zlib.NewReader(rc)
} else {
contentEncoding = uncompressedContentEncoding
reader = rc
}
}
if err != nil {
return nil, err
}

switch contentEncoding {
case deflateContentEncoding:
if knownCLen {
deflateLengthAccumulator.Add(cLen)
deflateCounter.Inc()
}
var err error
reader, err = zlib.NewReader(reader)
if err != nil {
return nil, err
}

case "gzip":
case gzipContentEncoding:
if knownCLen {
gzipLengthAccumulator.Add(cLen)
gzipCounter.Inc()
}
var err error
reader, err = gzip.NewReader(reader)
if err != nil {
return nil, err
}
default:
case uncompressedContentEncoding:
if knownCLen {
uncompressedLengthAccumulator.Add(cLen)
uncompressedCounter.Inc()
Expand All @@ -84,3 +120,20 @@ func CompressedRequestReader(req *http.Request) (io.ReadCloser, error) {
readerCounter.Inc()
return reader, nil
}

type compressedRequestReadCloser struct {
magic [2]byte
magicRead int
reader io.Reader
io.Closer
}

func (r *compressedRequestReadCloser) Read(p []byte) (int, error) {
var nmagic int
if r.magicRead < 2 {
nmagic = copy(p[:], r.magic[r.magicRead:])
r.magicRead += nmagic
}
n, err := r.reader.Read(p[nmagic:])
return n + nmagic, err
}
145 changes: 145 additions & 0 deletions decoder/requestreader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package decoder_test

import (
"bytes"
"compress/gzip"
"compress/zlib"
"io"
"io/ioutil"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/apm-server/decoder"
)

func TestCompressedRequestReader(t *testing.T) {
	const payload = "uncompressed input"
	deflated := zlibCompressString(payload)
	gzipped := gzipCompressString(payload)

	// newReader builds a request carrying body (and optionally a
	// Content-Encoding header) and hands it to the decoder.
	newReader := func(encoding string, body []byte) (io.ReadCloser, error) {
		req := httptest.NewRequest("GET", "/", bytes.NewReader(body))
		if encoding != "" {
			req.Header.Set("Content-Encoding", encoding)
		}
		return decoder.CompressedRequestReader(req)
	}

	for _, tc := range []struct {
		body     []byte
		encoding string
	}{
		{[]byte(payload), ""}, // sniffed: plain
		{deflated, ""},        // sniffed: deflate
		{gzipped, ""},         // sniffed: gzip
		{deflated, "deflate"},
		{gzipped, "gzip"},
	} {
		r, err := newReader(tc.encoding, tc.body)
		require.NoError(t, err)
		assertReaderContents(t, payload, r)
	}

	// A Content-Encoding header that contradicts the payload surfaces the
	// decompressor's header error.
	_, err := newReader("deflate", gzipped)
	assert.Equal(t, zlib.ErrHeader, err)

	_, err = newReader("gzip", deflated)
	assert.Equal(t, gzip.ErrHeader, err)
}

func BenchmarkCompressedRequestReader(b *testing.B) {
	// run returns a sub-benchmark that repeatedly decodes input, resetting
	// the request body each iteration so every call sees a fresh stream.
	run := func(input []byte, encoding string) func(*testing.B) {
		return func(b *testing.B) {
			req := httptest.NewRequest("GET", "/", bytes.NewReader(input))
			if encoding != "" {
				req.Header.Set("Content-Encoding", encoding)
			}
			for i := 0; i < b.N; i++ {
				req.Body = ioutil.NopCloser(bytes.NewReader(input))
				if _, err := decoder.CompressedRequestReader(req); err != nil {
					b.Fatal(err)
				}
			}
		}
	}

	b.Run("uncompressed", run([]byte("uncompressed"), ""))
	b.Run("gzip_content_encoding", run(gzipCompressString("uncompressed"), "gzip"))
	b.Run("gzip_sniff", run(gzipCompressString("uncompressed"), ""))
	b.Run("deflate_content_encoding", run(zlibCompressString("uncompressed"), "deflate"))
	b.Run("deflate_sniff", run(zlibCompressString("uncompressed"), ""))
}

// assertReaderContents asserts that reading r to EOF yields exactly expected.
func assertReaderContents(t *testing.T, expected string, r io.Reader) {
	t.Helper()
	got, err := ioutil.ReadAll(r)
	require.NoError(t, err)
	assert.Equal(t, expected, string(got))
}

func zlibCompressString(s string) []byte {
var buf bytes.Buffer
w, err := zlib.NewWriterLevel(&buf, zlib.BestSpeed)
if err != nil {
panic(err)
}
compressString(s, w)
return buf.Bytes()
}

func gzipCompressString(s string) []byte {
var buf bytes.Buffer
w, err := gzip.NewWriterLevel(&buf, gzip.BestSpeed)
if err != nil {
panic(err)
}
compressString(s, w)
return buf.Bytes()
}

func compressString(s string, w io.WriteCloser) {
if _, err := w.Write([]byte(s)); err != nil {
panic(err)
}
if err := w.Close(); err != nil {
panic(err)
}
}

0 comments on commit 6e6cc52

Please sign in to comment.