From f0cc2fb61998c3d1972163e8c92f67d772121e7b Mon Sep 17 00:00:00 2001 From: Marcel Edmund Franke Date: Sat, 24 Nov 2018 16:11:22 +0100 Subject: [PATCH] Check if response is to large for cache --- cmd/httpcache/main.go | 28 ++--- internal/cache/cache.go | 7 +- internal/handler/proxycache.go | 17 ++- internal/handler/stats.go | 19 ++-- internal/middleware/panic.go | 6 +- internal/middleware/panic_test.go | 2 +- internal/roundtripper/cacher.go | 4 +- internal/roundtripper/reponse_body_limit.go | 26 +++++ internal/size/size.go | 8 ++ internal/xhttp/server.go | 3 +- tests/api_test.go | 114 +++++++++++++++++--- 11 files changed, 177 insertions(+), 57 deletions(-) create mode 100644 internal/roundtripper/reponse_body_limit.go create mode 100644 internal/size/size.go diff --git a/cmd/httpcache/main.go b/cmd/httpcache/main.go index 17e191d..a58a20e 100644 --- a/cmd/httpcache/main.go +++ b/cmd/httpcache/main.go @@ -6,6 +6,7 @@ import ( "github.com/donutloop/httpcache/internal/cache" "github.com/donutloop/httpcache/internal/handler" "github.com/donutloop/httpcache/internal/middleware" + "github.com/donutloop/httpcache/internal/size" "github.com/donutloop/httpcache/internal/xhttp" "log" "net" @@ -20,12 +21,13 @@ func main() { fs := flag.NewFlagSet("http-proxy", flag.ExitOnError) var ( - httpAddr = fs.String("http", ":80", "serve HTTP on this address (optional)") - tlsAddr = fs.String("tls", "", "serve TLS on this address (optional)") - cert = fs.String("cert", "server.crt", "TLS certificate") - key = fs.String("key", "server.key", "TLS key") - cap = fs.Int64("cap", 100, "capacity of cache") - expire = fs.Int64("expire", 5, "the items in the cache expire after or expire never") + httpAddr = fs.String("http", ":80", "serve HTTP on this address (optional)") + tlsAddr = fs.String("tls", "", "serve TLS on this address (optional)") + cert = fs.String("cert", "server.crt", "TLS certificate") + key = fs.String("key", "server.key", "TLS key") + cap = fs.Int64("cap", 100, "capacity of cache") + responseBodyContentLenghtLimit = fs.Int64("rbcl", 500*size.MB, "response size limit") + expire = fs.Int64("expire", 5, "the items in the cache expire after or expire never") ) fs.Usage = usageFor(fs, "httpcache [flags]") fs.Parse(os.Args[1:]) @@ -41,7 +43,7 @@ func main() { } } - proxy := handler.NewProxy(c, logger.Println) + proxy := handler.NewProxy(c, logger.Println, *responseBodyContentLenghtLimit) stats := handler.NewStats(c, logger.Println) mux := http.NewServeMux() @@ -57,9 +59,9 @@ func main() { } xserver := xhttp.Server{ - Server: &http.Server{Addr: *httpAddr, Handler: stack}, - Logger: logger, - Listener: listener, + Server: &http.Server{Addr: *httpAddr, Handler: stack}, + Logger: logger, + Listener: listener, ShutdownTimeout: 3 * time.Second, } if err := xserver.Start(); err != nil { @@ -77,9 +79,9 @@ func main() { } xserver := xhttp.Server{ - Server: &http.Server{Addr: *tlsAddr, Handler: stack}, - Logger: logger, - Listener: listener, + Server: &http.Server{Addr: *tlsAddr, Handler: stack}, + Logger: logger, + Listener: listener, ShutdownTimeout: 3 * time.Second, } if err := xserver.StartTLS(*cert, *key); err != nil { diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 8710427..bd381a6 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -57,14 +57,13 @@ type entry struct { // NewLRUCache creates a new empty cache with the given capacity. func NewLRUCache(capacity int64, expiry time.Duration) *LRUCache { - cache := &LRUCache{ + cache := &LRUCache{ list: list.New(), table: make(map[string]*list.Element), capacity: capacity, - expiry: expiry, + expiry: expiry, } - // We have expiry start the janitor routine. if expiry > 0 { // Initialize a new stop GC channel. @@ -246,4 +245,4 @@ func (c *LRUCache) StartGC() { } } }() -} \ No newline at end of file +} diff --git a/internal/handler/proxycache.go b/internal/handler/proxycache.go index 9b640ea..6f05d7d 100644 --- a/internal/handler/proxycache.go +++ b/internal/handler/proxycache.go @@ -7,15 +7,19 @@ import ( "io/ioutil" "net/http" "net/http/httputil" + "strings" ) -func NewProxy(cache *cache.LRUCache, logger func(v ...interface{})) *Proxy { +func NewProxy(cache *cache.LRUCache, logger func(v ...interface{}), contentLength int64) *Proxy { return &Proxy{ client: &http.Client{ Transport: &roundtripper.LoggedTransport{ Transport: &roundtripper.CacheTransport{ - Transport: http.DefaultTransport, - Cache: cache, + Transport: &roundtripper.ResponseBodyLimitRoundTripper{ + Transport: http.DefaultTransport, + Limit: contentLength, + }, + Cache: cache, }, Logger: logger, }}, @@ -33,7 +37,10 @@ func (p *Proxy) ServeHTTP(resp http.ResponseWriter, req *http.Request) { req.RequestURI = "" proxyResponse, err := p.client.Do(req) if err != nil { - p.logger(err.Error()) + if strings.Contains(err.Error(), roundtripper.ResponseIsToLarge.Error()) { + resp.WriteHeader(http.StatusRequestEntityTooLarge) + return + } resp.WriteHeader(http.StatusInternalServerError) return } @@ -73,4 +80,4 @@ func dump(request *http.Request, response *http.Response) (requestDump, response return nil, nil, err } return dumpedRequest, dumpedResponse, nil -} \ No newline at end of file +} diff --git a/internal/handler/stats.go b/internal/handler/stats.go index 85d1389..5dbe021 100644 --- a/internal/handler/stats.go +++ b/internal/handler/stats.go @@ -10,20 +10,20 @@ import ( func NewStats(c *cache.LRUCache, logger func(v ...interface{})) *Stats { return &Stats{ - c: c, + c: c, logger: logger, } } type StatsResponse struct { - Length int64 `json:"length"` - Size int64 `json:"size"` - Capacity int64 `json:"capacity"` - Oldest time.Time `json:"oldest"` + Length int64 `json:"length"` + Size int64 `json:"size"` + Capacity int64 `json:"capacity"` + Oldest time.Time `json:"oldest"` } type Stats struct { - c *cache.LRUCache + c *cache.LRUCache logger func(v ...interface{}) } @@ -47,12 +47,11 @@ func (s *Stats) Endpoint() *StatsResponse { length, size, capacity, oldest := s.c.Stats() resp := &StatsResponse{ - Length: length, - Size: size, + Length: length, + Size: size, Capacity: capacity, - Oldest: oldest, + Oldest: oldest, } return resp } - diff --git a/internal/middleware/panic.go b/internal/middleware/panic.go index f51d406..3ff3d7a 100644 --- a/internal/middleware/panic.go +++ b/internal/middleware/panic.go @@ -7,15 +7,15 @@ import ( ) func NewPanic(next http.Handler, loggerFunc func(v ...interface{})) *Panic { - return &Panic{ - Next: next, + return &Panic{ + Next: next, loggerFunc: loggerFunc, } } // Panic recovers from API panics and logs encountered panics type Panic struct { - Next http.Handler + Next http.Handler loggerFunc func(v ...interface{}) } diff --git a/internal/middleware/panic_test.go b/internal/middleware/panic_test.go index 58648bb..a1d69ed 100644 --- a/internal/middleware/panic_test.go +++ b/internal/middleware/panic_test.go @@ -8,7 +8,7 @@ import ( func TestPanic(t *testing.T) { crashedHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { - panic("hello world") + panic("hello world") }) middleware := NewPanic(crashedHandler, t.Log) diff --git a/internal/roundtripper/cacher.go b/internal/roundtripper/cacher.go index 0ea82c6..46b605d 100644 --- a/internal/roundtripper/cacher.go +++ b/internal/roundtripper/cacher.go @@ -3,8 +3,6 @@ package roundtripper import ( "crypto/md5" "encoding/hex" - "errors" - "fmt" "github.com/donutloop/httpcache/internal/cache" "net/http" "net/http/httputil" @@ -24,7 +22,7 @@ func (t *CacheTransport) RoundTrip(req *http.Request) (*http.Response, error) { if !ok { proxyResponse, err := t.Transport.RoundTrip(req) if err != nil { - return nil, errors.New(fmt.Sprintf("proxy couldn't forward request to destination server (%v)", err)) + return nil, err } cachedResponse = &cache.CachedResponse{Resp: proxyResponse} t.Cache.Set(clonedRequest, cachedResponse) diff --git a/internal/roundtripper/reponse_body_limit.go b/internal/roundtripper/reponse_body_limit.go new file mode 100644 index 0000000..425f6ba --- /dev/null +++ b/internal/roundtripper/reponse_body_limit.go @@ -0,0 +1,26 @@ +package roundtripper + +import ( + "errors" + "net/http" +) + +var ResponseIsToLarge = errors.New("response body is to large for the cache") + +type ResponseBodyLimitRoundTripper struct { + Limit int64 + Transport http.RoundTripper // underlying transport (or default if nil) +} + +func (t *ResponseBodyLimitRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + response, err := t.Transport.RoundTrip(req) + if err != nil { + return nil, err + } + + if response.ContentLength > t.Limit { + return nil, ResponseIsToLarge + } + + return response, nil +} diff --git a/internal/size/size.go b/internal/size/size.go new file mode 100644 index 0000000..babf836 --- /dev/null +++ b/internal/size/size.go @@ -0,0 +1,8 @@ +package size + +const ( + _ = iota // ignore first value by assigning to blank identifier + KB int64 = 1 << (10 * iota) + MB + GB +) diff --git a/internal/xhttp/server.go b/internal/xhttp/server.go index 85f1d65..c80d53f 100644 --- a/internal/xhttp/server.go +++ b/internal/xhttp/server.go @@ -15,7 +15,7 @@ type Server struct { ShutdownTimeout time.Duration - Logger *log.Logger + Logger *log.Logger } // Start starts the server and waits for it to return. @@ -50,4 +50,3 @@ func (s *Server) Stop() { s.Server.Close() } - diff --git a/tests/api_test.go b/tests/api_test.go index 386f60a..072c2bb 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -6,6 +6,7 @@ import ( "github.com/donutloop/httpcache/internal/cache" "github.com/donutloop/httpcache/internal/handler" "github.com/donutloop/httpcache/internal/middleware" + "github.com/donutloop/httpcache/internal/size" "github.com/donutloop/httpcache/internal/xhttp" "log" "math/rand" @@ -24,7 +25,7 @@ var c *cache.LRUCache func TestMain(m *testing.M) { c = cache.NewLRUCache(100, 0) - proxy := handler.NewProxy(c, log.Println) + proxy := handler.NewProxy(c, log.Println, 500*size.MB) stats := handler.NewStats(c, log.Println) mux := http.NewServeMux() @@ -32,7 +33,7 @@ func TestMain(m *testing.M) { mux.Handle("/", proxy) stack := middleware.NewPanic(mux, log.Println) - + proxyServer := httptest.NewServer(stack) transport := &http.Transport{ @@ -125,7 +126,7 @@ func TestStatsHandler(t *testing.T) { t.Fatalf("cache length is bad, got=%d", c.Length()) } - req, err = http.NewRequest(http.MethodGet, server.URL + "/stats", nil) + req, err = http.NewRequest(http.MethodGet, server.URL+"/stats", nil) if err != nil { t.Fatal(err) } @@ -142,7 +143,7 @@ func TestStatsHandler(t *testing.T) { statsResponse := &handler.StatsResponse{} if err := json.NewDecoder(resp.Body).Decode(statsResponse); err != nil { - b , err := httputil.DumpResponse(resp, true) + b, err := httputil.DumpResponse(resp, true) if err == nil { t.Log(string(b)) } @@ -156,8 +157,90 @@ func TestStatsHandler(t *testing.T) { t.Log(fmt.Sprintf("%#v", statsResponse)) } +func TestProxyHandler_ResponseBodyContentLengthLimit(t *testing.T) { + c1 := cache.NewLRUCache(100, 1*time.Second) + { + c1.OnEviction = func(key string) { + c1.Delete(key) + } + } + cl := 1 * size.KB + t.Log("size: ", cl) + + go func() { + logger := log.New(os.Stderr, "", log.LstdFlags) + + proxy := handler.NewProxy(c1, logger.Println, cl) + mux := http.NewServeMux() + mux.Handle("/", proxy) + + listener, err := net.Listen("tcp", "localhost:4528") + if err != nil { + logger.Fatal(err) + } + + xserver := xhttp.Server{ + Server: &http.Server{Addr: "localhost:4528", Handler: proxy}, + Logger: logger, + Listener: listener, + } + if err := xserver.Start(); err != nil { + xserver.Stop() + } + }() + + <-time.After(1 * time.Second) + + transport := &http.Transport{ + Proxy: SetProxyURL("http://localhost:4528"), + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + client := &http.Client{ + Transport: transport, + } + + testHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data := make([]byte, 2*size.KB, 2*size.KB) + w.Write(data) + return + } + + server := httptest.NewServer(http.HandlerFunc(testHandler)) + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatal(err) + } + + t.Log(req.URL) + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Fatalf("status code is bad (%v)", resp.StatusCode) + } + + <-time.After(2 * time.Second) + + if c1.Length() != 0 { + t.Fatalf("cache length is bad, got=%d", c1.Length()) + } +} + func TestProxyHandler_GC(t *testing.T) { - c1 := cache.NewLRUCache(100, 1 * time.Second ) + c1 := cache.NewLRUCache(100, 1*time.Second) { c1.OnEviction = func(key string) { c1.Delete(key) @@ -167,7 +250,7 @@ func TestProxyHandler_GC(t *testing.T) { go func() { logger := log.New(os.Stderr, "", log.LstdFlags) - proxy := handler.NewProxy(c1, logger.Println) + proxy := handler.NewProxy(c1, logger.Println, 3*size.MB) mux := http.NewServeMux() mux.Handle("/", proxy) @@ -177,8 +260,8 @@ func TestProxyHandler_GC(t *testing.T) { } xserver := xhttp.Server{ - Server: &http.Server{Addr: "localhost:4568", Handler: proxy}, - Logger: logger, + Server: &http.Server{Addr: "localhost:4568", Handler: proxy}, + Logger: logger, Listener: listener, } if err := xserver.Start(); err != nil { @@ -186,7 +269,7 @@ func TestProxyHandler_GC(t *testing.T) { } }() - <-time.After(1 *time.Second) + <-time.After(1 * time.Second) transport := &http.Transport{ Proxy: SetProxyURL("http://localhost:4568"), @@ -205,7 +288,6 @@ func TestProxyHandler_GC(t *testing.T) { Transport: transport, } - testHandler := func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"count": 10}`)) @@ -229,7 +311,7 @@ func TestProxyHandler_GC(t *testing.T) { t.Fatalf("status code is bad (%v)", resp.StatusCode) } - <-time.After(2 *time.Second) + <-time.After(2 * time.Second) if c1.Length() != 0 { t.Fatalf("cache length is bad, got=%d", c1.Length()) @@ -238,11 +320,11 @@ func TestProxyHandler_GC(t *testing.T) { func TestProxyHttpServer(t *testing.T) { - c1 := cache.NewLRUCache(100, 0 ) + c1 := cache.NewLRUCache(100, 0) go func() { logger := log.New(os.Stderr, "", log.LstdFlags) - proxy := handler.NewProxy(c1, logger.Println) + proxy := handler.NewProxy(c1, logger.Println, 5*size.MB) mux := http.NewServeMux() mux.Handle("/", proxy) @@ -252,8 +334,8 @@ func TestProxyHttpServer(t *testing.T) { } xserver := xhttp.Server{ - Server: &http.Server{Addr: "localhost:4567", Handler: proxy}, - Logger: logger, + Server: &http.Server{Addr: "localhost:4567", Handler: proxy}, + Logger: logger, Listener: listener, } if err := xserver.Start(); err != nil { @@ -261,7 +343,7 @@ func TestProxyHttpServer(t *testing.T) { } }() - <-time.After(1 *time.Second) + <-time.After(1 * time.Second) transport := &http.Transport{ Proxy: SetProxyURL("http://localhost:4567"),