Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-wrote the middleware compress library. #401

Merged
merged 1 commit into from Feb 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
277 changes: 194 additions & 83 deletions middleware/compress.go
Expand Up @@ -6,25 +6,76 @@ import (
"compress/gzip"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
)

var encoders = map[string]EncoderFunc{}
var defaultCompressibleContentTypes = []string{
"text/html",
"text/css",
"text/plain",
"text/javascript",
"application/javascript",
"application/x-javascript",
"application/json",
"application/atom+xml",
"application/rss+xml",
"image/svg+xml",
}

// A default compressor that allows for the old API to use the new code.
// DEPRECATED
var defaultCompressor *Compressor

// Compressor represents a set of encoding configurations.
type Compressor struct {
level int // The compression level.
// The mapping of encoder names to encoder functions.
encoders map[string]EncoderFunc
// The mapping of pooled encoders to pools.
pooledEncoders map[string]*sync.Pool
// The set of content types allowed to be compressed.
allowedTypes map[string]bool
// The list of encoders in order of decreasing precedence.
encodingPrecedence []string
}

var encodingPrecedence = []string{"br", "gzip", "deflate"}
// NewCompressor creates a new Compressor that will handle encoding responses.
//
// The level should be one of the ones defined in the flate package.
// The types are the content types that are allowed to be compressed.
func NewCompressor(level int, types ...string) *Compressor {
// If types are provided, set those as the allowed types. If none are
// provided, use the default list.
allowedTypes := make(map[string]bool)
if len(types) > 0 {
for _, t := range types {
allowedTypes[t] = true
}
} else {
for _, t := range defaultCompressibleContentTypes {
allowedTypes[t] = true
}
}

func init() {
c := &Compressor{
level: level,
encoders: make(map[string]EncoderFunc),
pooledEncoders: make(map[string]*sync.Pool),
allowedTypes: allowedTypes,
}
// Set the default encoders. The precedence order uses the reverse
// ordering that the encoders were added. This means adding new encoders
// will move them to the front of the order.
//
// TODO:
// lzma: Opera.
// sdch: Chrome, Android. Gzip output + dictionary header.
// br: Brotli, see https://github.com/go-chi/chi/pull/326

// TODO: Exception for old MSIE browsers that can't handle non-HTML?
// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
SetEncoder("gzip", encoderGzip)

// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
// wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32
// checksum compared to CRC-32 used in "gzip" and thus is faster.
Expand All @@ -41,21 +92,20 @@ func init() {
//
// That's why we prefer gzip over deflate. It's just more reliable
// and not significantly slower than gzip.
SetEncoder("deflate", encoderDeflate)
c.SetEncoder("deflate", encoderDeflate)

// TODO: Exception for old MSIE browsers that can't handle non-HTML?
// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
c.SetEncoder("gzip", encoderGzip)

// NOTE: Not implemented, intentionally:
// case "compress": // LZW. Deprecated.
// case "bzip2": // Too slow on-the-fly.
// case "zopfli": // Too slow on-the-fly.
// case "xz": // Too slow on-the-fly.
return c
}

// An EncoderFunc is a function that wraps the provided ResponseWriter with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w http.ResponseWriter, level int) io.Writer

// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
Expand All @@ -65,126 +115,190 @@ type EncoderFunc func(w http.ResponseWriter, level int) io.Writer
//
// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
// middleware.SetEncoder("br", func(w http.ResponseWriter, level int) io.Writer {
// compressor := middleware.NewCompressor(5, "text/html")
// compressor.SetEncoder("br", func(w http.ResponseWriter, level int) io.Writer {
// params := brotli_enc.NewBrotliParams()
// params.SetQuality(level)
// return brotli_enc.NewBrotliWriter(params, w)
// })
func SetEncoder(encoding string, fn EncoderFunc) {
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
encoding = strings.ToLower(encoding)
if encoding == "" {
panic("the encoding can not be empty")
}
if fn == nil {
panic("attempted to set a nil encoder function")
}
encoders[encoding] = fn

var e string
for _, v := range encodingPrecedence {
if v == encoding {
e = v
}
// If we are adding a new encoder that is already registered, we have to
// clear that one out first.
if _, ok := c.pooledEncoders[encoding]; ok {
delete(c.pooledEncoders, encoding)
}

if e == "" {
encodingPrecedence = append([]string{e}, encodingPrecedence...)
if _, ok := c.encoders[encoding]; ok {
delete(c.encoders, encoding)
}
}

var defaultContentTypes = map[string]struct{}{
"text/html": {},
"text/css": {},
"text/plain": {},
"text/javascript": {},
"application/javascript": {},
"application/x-javascript": {},
"application/json": {},
"application/atom+xml": {},
"application/rss+xml": {},
"image/svg+xml": {},
}

// DefaultCompress is a middleware that compresses response
// body of predefined content types to a data format based
// on Accept-Encoding request header. It uses a default
// compression level.
func DefaultCompress(next http.Handler) http.Handler {
return Compress(flate.DefaultCompression)(next)
}
// If the encoder supports Resetting (IoReseterWriter), then it can be pooled.
encoder := fn(ioutil.Discard, c.level)
if encoder != nil {
if _, ok := encoder.(ioResetterWriter); ok {
pool := &sync.Pool{
New: func() interface{} {
return fn(ioutil.Discard, c.level)
},
}
c.pooledEncoders[encoding] = pool
}
}
// If the encoder is not in the pooledEncoders, add it to the normal encoders.
if _, ok := c.pooledEncoders[encoding]; !ok {
c.encoders[encoding] = fn
}

// Compress is a middleware that compresses response
// body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given
// compression level.
//
// NOTE: make sure to set the Content-Type header on your response
// otherwise this middleware will not compress the response body. For ex, in
// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
contentTypes := defaultContentTypes
if len(types) > 0 {
contentTypes = make(map[string]struct{}, len(types))
for _, t := range types {
contentTypes[t] = struct{}{}
for i, v := range c.encodingPrecedence {
if v == encoding {
c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
}
}

c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
}

// Handler returns a new middleware that will compress the response based on the
// current Compressor.
func (c *Compressor) Handler() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
encoder, encoding := selectEncoder(r.Header)
encoder, encoding, cleanup := c.selectEncoder(r.Header, w)

cw := &compressResponseWriter{
ResponseWriter: w,
w: w,
contentTypes: contentTypes,
encoder: encoder,
contentTypes: c.allowedTypes,
encoding: encoding,
level: level,
}
if encoder != nil {
cw.w = encoder
}
// Re-add the encoder to the pool if applicable.
defer cleanup()
defer cw.Close()

next.ServeHTTP(cw, r)
}

return http.HandlerFunc(fn)
}

}

func selectEncoder(h http.Header) (EncoderFunc, string) {
// selectEncoder returns the encoder, the name of the encoder, and a closer function.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
header := h.Get("Accept-Encoding")

// Parse the names of all accepted algorithms from the header.
accepted := strings.Split(strings.ToLower(header), ",")

// Find supported encoder by accepted list by precedence
for _, name := range encodingPrecedence {
if fn, ok := encoders[name]; ok && matchAcceptEncoding(accepted, name) {
return fn, name
for _, name := range c.encodingPrecedence {
if matchAcceptEncoding(accepted, name) {
if pool, ok := c.pooledEncoders[name]; ok {
encoder := pool.Get().(ioResetterWriter)
cleanup := func() {
pool.Put(encoder)
}
encoder.Reset(w)
return encoder, name, cleanup

}
if fn, ok := c.encoders[name]; ok {
return fn(w, c.level), name, func() {}
}
}

}

// No encoder found to match the accepted encoding
return nil, ""
return nil, "", func() {}
}

func matchAcceptEncoding(accepted []string, encoding string) bool {
for _, v := range accepted {
if strings.Index(v, encoding) >= 0 {
if strings.Contains(v, encoding) {
return true
}
}
return false
}

// An EncoderFunc is a function that wraps the provided io.Writer with a
// streaming compression algorithm and returns it.
//
// In case of failure, the function should return nil.
type EncoderFunc func(w io.Writer, level int) io.Writer

// Interface for types that allow resetting io.Writers.
type ioResetterWriter interface {
io.Writer
Reset(w io.Writer)
}

// SetEncoder can be used to set the implementation of a compression algorithm.
//
// The encoding should be a standardised identifier. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
//
// For example, add the Brotli algortithm:
//
// import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
//
// middleware.SetEncoder("br", func(w http.ResponseWriter, level int) io.Writer {
// params := brotli_enc.NewBrotliParams()
// params.SetQuality(level)
// return brotli_enc.NewBrotliWriter(params, w)
// })
//
// DEPRECATED
func SetEncoder(encoding string, fn EncoderFunc) {
if defaultCompressor == nil {
panic("no compressor to set encoders on. Call Compress() first")
}
defaultCompressor.SetEncoder(encoding, fn)
}

// DefaultCompress is a middleware that compresses response
// body of predefined content types to a data format based
// on Accept-Encoding request header. It uses a default
// compression level.
// DEPRECATED
func DefaultCompress(next http.Handler) http.Handler {
return Compress(flate.DefaultCompression)(next)
}

// Compress is a middleware that compresses response
// body of a given content types to a data format based
// on Accept-Encoding request header. It uses a given
// compression level.
//
// NOTE: make sure to set the Content-Type header on your response
// otherwise this middleware will not compress the response body. For ex, in
// your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
// or set it manually.
//
// DEPRECATED
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
defaultCompressor = NewCompressor(level, types...)
return defaultCompressor.Handler()
}

type compressResponseWriter struct {
http.ResponseWriter
// The streaming encoder writer to be used if there is one. Otherwise,
// this is just the normal writer.
w io.Writer
encoder EncoderFunc
encoding string
contentTypes map[string]struct{}
level int
contentTypes map[string]bool
wroteHeader bool
}

Expand Down Expand Up @@ -212,14 +326,11 @@ func (cw *compressResponseWriter) WriteHeader(code int) {
return
}

if cw.encoder != nil && cw.encoding != "" {
if wr := cw.encoder(cw.ResponseWriter, cw.level); wr != nil {
cw.w = wr
cw.Header().Set("Content-Encoding", cw.encoding)
if cw.encoding != "" {
cw.Header().Set("Content-Encoding", cw.encoding)

// The content-length after compression is unknown
cw.Header().Del("Content-Length")
}
// The content-length after compression is unknown
cw.Header().Del("Content-Length")
}
}

Expand Down Expand Up @@ -258,15 +369,15 @@ func (cw *compressResponseWriter) Close() error {
return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
}

func encoderGzip(w http.ResponseWriter, level int) io.Writer {
func encoderGzip(w io.Writer, level int) io.Writer {
gw, err := gzip.NewWriterLevel(w, level)
if err != nil {
return nil
}
return gw
}

func encoderDeflate(w http.ResponseWriter, level int) io.Writer {
func encoderDeflate(w io.Writer, level int) io.Writer {
dw, err := flate.NewWriter(w, level)
if err != nil {
return nil
Expand Down