From 78d77b66beba656bd0d7c0aaff4dc2537f4d101f Mon Sep 17 00:00:00 2001 From: Umputun Date: Sun, 4 Apr 2021 02:35:58 -0500 Subject: [PATCH] check gz headers and allow custom headers --- go.mod | 2 +- go.sum | 2 ++ gzip.go | 65 +++++++++++++++++++++++++++++++++++++++--------- gzip_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 124 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index aab6b51..7898921 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.15 require ( github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index 7034f5b..0caa82c 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/gzip.go b/gzip.go index 4e8e005..e1d3d1a 100644 --- a/gzip.go +++ b/gzip.go @@ -9,6 +9,17 @@ import ( "sync" ) +var gzDefaultContentTypes = []string{ + "text/css", + "text/javascript", + "text/xml", + "text/html", + "text/plain", + "application/javascript", + "application/x-javascript", + "application/json", +} + var gzPool = sync.Pool{ New: func() interface{} { return gzip.NewWriter(ioutil.Discard) }, } @@ -28,21 +39,51 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { } // Gzip is a middleware compressing response -func Gzip(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - next.ServeHTTP(w, r) - return +func Gzip(contentTypes ...string) func(http.Handler) http.Handler { + + gzCts := gzDefaultContentTypes + if len(contentTypes) > 0 { + gzCts = contentTypes + } + + contentType := func(r *http.Request) string { + result := r.Header.Get("Content-type") + if result == "" { + return "application/octet-stream" } + return result + } + + f := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + + var gzOk bool + ctype := contentType(r) + for _, c := range gzCts { + if strings.EqualFold(ctype, c) { + gzOk = true + break + } + } - w.Header().Set("Content-Encoding", "gzip") + if !gzOk { + next.ServeHTTP(w, r) + return + } - gz := gzPool.Get().(*gzip.Writer) - defer gzPool.Put(gz) + w.Header().Set("Content-Encoding", "gzip") + gz := gzPool.Get().(*gzip.Writer) + defer gzPool.Put(gz) - gz.Reset(w) - defer gz.Close() + gz.Reset(w) + defer gz.Close() - next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) - }) + next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r) + }) + } + return f } diff --git a/gzip_test.go b/gzip_test.go index 2d869ae..0792eae 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestGzip(t *testing.T) { +func TestGzipCustom(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("Lorem Ipsum is simply dummy text of the printing and typesetting industry. " + "Lorem Ipsum has been the industry’s standard dummy text ever since the 1500s, when an unknown printer took " + @@ -23,7 +23,7 @@ func TestGzip(t *testing.T) { "and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.")) require.NoError(t, err) }) - ts := httptest.NewServer(Gzip(handler)) + ts := httptest.NewServer(Gzip("text/plain", "text/html")(handler)) defer ts.Close() client := http.Client{} @@ -32,6 +32,7 @@ func TestGzip(t *testing.T) { req, err := http.NewRequest("GET", ts.URL+"/something", nil) require.NoError(t, err) req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Content-Type", "text/plain") resp, err := client.Do(req) require.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) @@ -46,6 +47,58 @@ func TestGzip(t *testing.T) { require.NoError(t, err) assert.True(t, strings.HasPrefix(string(b), "Lorem Ipsum"), string(b)) } + + { + req, err := http.NewRequest("GET", ts.URL+"/something", nil) + require.NoError(t, err) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Content-Type", "something") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 576, len(b), "uncompressed size") + } + +} + +func TestGzipDefault(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Lorem Ipsum is simply dummy text of the printing and typesetting industry. " + + "Lorem Ipsum has been the industry’s standard dummy text ever since the 1500s, when an unknown printer took " + + "a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries," + + " but also the leap into electronic typesetting, remaining essentially unchanged. It was popularized" + + " in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, " + + "and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum.")) + require.NoError(t, err) + }) + ts := httptest.NewServer(Gzip()(handler)) + defer ts.Close() + + client := http.Client{} + + { + req, err := http.NewRequest("GET", ts.URL+"/something", nil) + require.NoError(t, err) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Content-Type", "text/plain") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 357, len(b), "compressed size") + + gzr, err := gzip.NewReader(bytes.NewBuffer(b)) + require.NoError(t, err) + b, err = ioutil.ReadAll(gzr) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(string(b), "Lorem Ipsum"), string(b)) + } + { req, err := http.NewRequest("GET", ts.URL+"/something", nil) require.NoError(t, err) @@ -56,7 +109,20 @@ func TestGzip(t *testing.T) { b, err := ioutil.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, 576, len(b), "uncompressed size") + } + { + req, err := http.NewRequest("GET", ts.URL+"/something", nil) + require.NoError(t, err) + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("Content-Type", "something") + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 576, len(b), "uncompressed size") } }