diff --git a/cmd/httpcache/main.go b/cmd/httpcache/main.go index efa9de4..fa153af 100644 --- a/cmd/httpcache/main.go +++ b/cmd/httpcache/main.go @@ -52,16 +52,17 @@ func main() { } } - proxy := handler.NewProxy(c, logger.Println, *responseBodyContentLenghtLimit) stats := handler.NewStats(c, logger.Println) ping := handler.NewPing(logger.Println) + proxy := handler.NewProxy( + c, + logger.Println, + *responseBodyContentLenghtLimit, + ping, + stats, + ) - mux := http.NewServeMux() - mux.Handle("/stats", stats) - mux.Handle("/", proxy) - mux.Handle("/ping", ping) - - stack := middleware.NewPanic(mux, logger.Println) + stack := middleware.NewPanic(proxy, logger.Println) if *httpAddr != "" { listener, err := net.Listen("tcp", *httpAddr) diff --git a/internal/handler/ping.go b/internal/handler/ping.go index 6724c44..3f558ff 100644 --- a/internal/handler/ping.go +++ b/internal/handler/ping.go @@ -19,4 +19,3 @@ func (s *Ping) ServeHTTP(resp http.ResponseWriter, req *http.Request) { resp.WriteHeader(http.StatusOK) resp.Write([]byte("ok")) } - diff --git a/internal/handler/proxycache.go b/internal/handler/proxycache.go index 6f05d7d..ac295cc 100644 --- a/internal/handler/proxycache.go +++ b/internal/handler/proxycache.go @@ -4,13 +4,15 @@ import ( "fmt" "github.com/donutloop/httpcache/internal/cache" "github.com/donutloop/httpcache/internal/roundtripper" + "io" "io/ioutil" + "net" "net/http" "net/http/httputil" "strings" ) -func NewProxy(cache *cache.LRUCache, logger func(v ...interface{}), contentLength int64) *Proxy { +func NewProxy(cache *cache.LRUCache, logger func(v ...interface{}), contentLength int64, ping *Ping, stats *Stats) *Proxy { return &Proxy{ client: &http.Client{ Transport: &roundtripper.LoggedTransport{ @@ -24,17 +26,36 @@ func NewProxy(cache *cache.LRUCache, logger func(v ...interface{}), contentLengt Logger: logger, }}, logger: logger, + ping: ping, + stats: stats, } } type Proxy struct { client *http.Client logger func(v ...interface{}) + ping *Ping + stats *Stats } func (p *Proxy) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/ping" { + p.ping.ServeHTTP(resp, req) + return + } + + if req.URL.Path == "/stats" { + p.ping.ServeHTTP(resp, req) + return + } + req.RequestURI = "" + if req.Method == http.MethodConnect { + p.ProxyHTTPS(resp, req) + return + } + proxyResponse, err := p.client.Do(req) if err != nil { if strings.Contains(err.Error(), roundtripper.ResponseIsToLarge.Error()) { @@ -66,6 +87,42 @@ func (p *Proxy) ServeHTTP(resp http.ResponseWriter, req *http.Request) { resp.Write(body) } +func (p *Proxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) { + hij, ok := rw.(http.Hijacker) + if !ok { + p.logger("proxy https error: http server does not support hijacker") + return + } + + clientConn, _, err := hij.Hijack() + if err != nil { + p.logger("proxy https error: %v", err) + return + } + + proxyConn, err := net.Dial("tcp", req.URL.Host) + if err != nil { + p.logger("proxy https error: %v", err) + return + } + + _, err = clientConn.Write([]byte("HTTP/1.0 200 OK\r\n\r\n")) + if err != nil { + p.logger("proxy https error: %v", err) + return + } + + go func() { + io.Copy(clientConn, proxyConn) + clientConn.Close() + proxyConn.Close() + }() + + io.Copy(proxyConn, clientConn) + proxyConn.Close() + clientConn.Close() +} + type requestDump []byte type responseDump []byte diff --git a/tests/api_test.go b/tests/api_test.go index 072c2bb..f393dae 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -1,6 +1,7 @@ package tests import ( + "crypto/tls" "encoding/json" "fmt" "github.com/donutloop/httpcache/internal/cache" @@ -21,20 +22,30 @@ import ( ) var client *http.Client +var clientTls *http.Client var c *cache.LRUCache func TestMain(m *testing.M) { c = cache.NewLRUCache(100, 0) - proxy := handler.NewProxy(c, log.Println, 500*size.MB) stats := handler.NewStats(c, log.Println) + ping := handler.NewPing(log.Println) + proxy := handler.NewProxy( + c, + log.Println, + 500*size.MB, + ping, + stats, + ) mux := http.NewServeMux() mux.Handle("/stats", stats) mux.Handle("/", proxy) + mux.Handle("/ping", ping) stack := middleware.NewPanic(mux, log.Println) proxyServer := httptest.NewServer(stack) + proxyServerTLS := httptest.NewTLSServer(proxy) transport := &http.Transport{ Proxy: SetProxyURL(proxyServer.URL), @@ -53,6 +64,61 @@ func TestMain(m *testing.M) { Transport: transport, } + transportTls := &http.Transport{ + Proxy: SetProxyURL(proxyServerTLS.URL + "/"), + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + log.Println("tls proxy " + proxyServerTLS.URL) + + clientTls = &http.Client{ + Transport: transportTls, + } + + testtransport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + testclient := http.Client{ + Transport: testtransport, + } + + resp, err := testclient.Get(fmt.Sprintf("%v/ping", proxyServerTLS.URL)) + if err != nil { + log.Fatalln(fmt.Sprintf("tls proxy (%v)", err)) + } + + if resp.StatusCode != http.StatusOK { + log.Fatalln(fmt.Sprintf("status code is bad (%v)", resp.StatusCode)) + } + + resp, err = testclient.Get(fmt.Sprintf("%v/ping", proxyServer.URL)) + if err != nil { + log.Fatalln(fmt.Sprintf("proxy (%v)", err)) + } + + if resp.StatusCode != http.StatusOK { + log.Fatalln(fmt.Sprintf("status code is bad (%v)", resp.StatusCode)) + } + // call flag.Parse() here if TestMain uses flags os.Exit(m.Run()) } @@ -67,6 +133,36 @@ func SetProxyURL(proxy string) func(req *http.Request) (*url.URL, error) { } } +func TestProxyHTTPSHandler(t *testing.T) { + defer c.Reset() + + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"count": 10}`)) + return + } + + testHandler := httptest.NewTLSServer(http.HandlerFunc(handler)) + + req, err := http.NewRequest(http.MethodGet, testHandler.URL, nil) + if err != nil { + t.Fatal(err) + } + + resp, err := clientTls.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status code is bad (%v)", resp.StatusCode) + } + + if c.Length() != 0 { + t.Fatalf("cache length is bad, got=%d", c.Length()) + } +} + func TestProxyHandler(t *testing.T) { defer c.Reset() @@ -170,7 +266,16 @@ func TestProxyHandler_ResponseBodyContentLengthLimit(t *testing.T) { go func() { logger := log.New(os.Stderr, "", log.LstdFlags) - proxy := handler.NewProxy(c1, logger.Println, cl) + stats := handler.NewStats(c, logger.Println) + ping := handler.NewPing(logger.Println) + proxy := handler.NewProxy( + c, + logger.Println, + cl, + ping, + stats, + ) + mux := http.NewServeMux() mux.Handle("/", proxy) @@ -249,8 +354,16 @@ func TestProxyHandler_GC(t *testing.T) { go func() { logger := log.New(os.Stderr, "", log.LstdFlags) + stats := handler.NewStats(c, logger.Println) + ping := handler.NewPing(logger.Println) + proxy := handler.NewProxy( + c, + logger.Println, + 3*size.MB, + ping, + stats, + ) - proxy := handler.NewProxy(c1, logger.Println, 3*size.MB) mux := http.NewServeMux() mux.Handle("/", proxy) @@ -324,7 +437,15 @@ func TestProxyHttpServer(t *testing.T) { go func() { logger := log.New(os.Stderr, "", log.LstdFlags) - proxy := handler.NewProxy(c1, logger.Println, 5*size.MB) + stats := handler.NewStats(c, logger.Println) + ping := handler.NewPing(logger.Println) + proxy := handler.NewProxy( + c1, + logger.Println, + 5*size.MB, + ping, + stats, + ) mux := http.NewServeMux() mux.Handle("/", proxy)