From adc17760e1a515962e02a2e3a47affdbdf9757a7 Mon Sep 17 00:00:00 2001 From: Matt Reiferson Date: Sat, 6 Apr 2013 23:50:21 -0400 Subject: [PATCH] rewrite for go 1.1: * cleanup API, no longer need FinishRequest * implement priority queue that uses the new CancelRequest API * improve travis config * update docs --- .travis.yml | 10 ++ README.md | 112 +++--------- httpclient.go | 400 ++++++++++++++++-------------------------- httpclient_test.go | 221 ++++------------------- pqueue/pqueue.go | 75 ++++++++ pqueue/pqueue_test.go | 68 +++++++ 6 files changed, 369 insertions(+), 517 deletions(-) create mode 100644 pqueue/pqueue.go create mode 100644 pqueue/pqueue_test.go diff --git a/.travis.yml b/.travis.yml index 4f2ee4d..ba1b6b7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1 +1,11 @@ language: go +go: + - 1.1 +install: + - go get github.com/bmizerany/assert +script: + - pushd $TRAVIS_BUILD_DIR + - go test + - popd +notifications: + email: false diff --git a/README.md b/README.md index a63354a..f719067 100644 --- a/README.md +++ b/README.md @@ -1,96 +1,40 @@ -## HttpClient +## go-httpclient -HttpClient wraps Go's built in HTTP client providing an API to: +NOTE: **Requires Go 1.1+** - as of `0.4.0` the API has been completely re-written for Go 1.1 (for Go 1.0.x compatible release see the +[go1](https://github.com/mreiferson/go-httpclient/tree/go1) tag) - * set timeouts - * separate connect timeout - * *request* based timeout (*not* just read/write deadline) - * easy access to the connection object for a given request +[![Build +Status](https://secure.travis-ci.org/mreiferson/go-httpclient.png)](http://travis-ci.org/mreiferson/go-httpclient) -```go -package httpclient - -type HttpClient struct { - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - MaxConnsPerHost int - RedirectPolicy func(*http.Request, []*http.Request) error - TLSClientConfig *tls.Config -} - -func New() *HttpClient - create a new HttpClient all options should be set on the instance - returned - -func (h *HttpClient) Do(req *http.Request) (*http.Response, error) - perform the specified request +Provides an HTTP Transport that implements the `RoundTripper` interface and +can be used as a built in replacement for the standard library's, providing: -func (h *HttpClient) FinishRequest(req *http.Request) error - perform final cleanup for the specified request *must* be called for - every request performed after processing is finished and after which - GetConn will no longer return successfully + * connection timeouts + * request timeouts -func (h *HttpClient) GetConn(req *http.Request) (net.Conn, error) - returns the connection associated with the specified request cannot be - called after FinishRequest +Internally, it uses a priority queue maintained in a single goroutine +(per *client* instance), leveraging the Go 1.1+ `CancelRequest()` API. -func (h *HttpClient) RoundTrip(req *http.Request) (*http.Response, error) - satisfies the RoundTripper interface and handles checking the connection - cache or dialing (with ConnectTimeout) -func DefaultRedirectPolicy(req *http.Request, via []*http.Request) error - default redirect policy which fails after 3 redirects. - -func Version() string - returns the current version -``` - -#### Example +### Example ```go -package main - -import ( - "httpclient" - "io/ioutil" - "log" - "net/http" - "time" -) - -func main() { - httpClient := httpclient.New() - httpClient.ConnectTimeout = time.Second - httpClient.ReadWriteTimeout = time.Second - - // Allow insecure HTTPS connections. Note: the TLSClientConfig pointer can't change - // places, so you can only modify the existing tls.Config object - httpClient.TLSClientConfig.InsecureSkipVerify = true - - // Make a custom redirect policy to keep track of the number of redirects we've followed - var numRedirects int - httpClient.RedirectPolicy = func(r *http.Request, v []*http.Request) error { - numRedirects += 1 - return DefaultRedirectPolicy(r, v) - } - - req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil) - - resp, err := httpClient.Do(req) - if err != nil { - log.Fatalf("request failed - %s", err.Error()) - } - defer resp.Body.Close() - - conn, err := httpClient.GetConn(req) - if err != nil { - log.Fatalf("failed to get conn for req") - } - // do something with conn - - body, err := ioutil.ReadAll(resp.Body) - log.Printf("%s", body) +transport := &httpclient.Transport{ + ConnectTimeout: 1*time.Second, + RequestTimeout: 10*time.Second, + ResponseHeaderTimeout: 5*time.Second, +} +defer transport.Close() - httpClient.FinishRequest(req) +client := &http.Client{Transport: transport} +req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil) +resp, err := client.Do(req) +if err != nil { + return err } +defer resp.Body.Close() ``` + +### Reference Docs + +For API docs see [godoc](http://godoc.org/github.com/mreiferson/go-httpclient). diff --git a/httpclient.go b/httpclient.go index a0a8686..2f475e1 100644 --- a/httpclient.go +++ b/httpclient.go @@ -1,289 +1,193 @@ +/* +Provides an HTTP Transport that implements the `RoundTripper` interface and +can be used as a built in replacement for the standard library's, providing: + + * connection timeouts + * request timeouts + +Internally, it uses a priority queue maintained in a single goroutine +(per *client* instance), leveraging the Go 1.1+ `CancelRequest()` API. +*/ package httpclient import ( - "bufio" - "container/list" + "container/heap" "crypto/tls" - "errors" - "log" + "github.com/mreiferson/go-httpclient/pqueue" + "io" "net" "net/http" - "strings" + "net/url" "sync" "time" ) -// returns the current version +// returns the current version of the package func Version() string { - return "0.3.9" + return "0.4.0" } -type cachedConn struct { - net.Conn - shouldClose bool -} - -type connCache struct { - dl *list.List - outstanding int -} - -// HttpClient wraps Go's built in HTTP client providing an API to: -// * set connect timeout -// * set read/write timeout -// * easy access to the connection object for a given request +// Transport implements the RoundTripper interface and can be used as a replacement +// for Go's built in http.Transport implementing end-to-end request timeouts. // -type HttpClient struct { - sync.RWMutex - client *http.Client - cachedConns map[string]*connCache - connMap map[*http.Request]*cachedConn - ConnectTimeout time.Duration - ReadWriteTimeout time.Duration - MaxConnsPerHost int - RedirectPolicy func(*http.Request, []*http.Request) error - TLSClientConfig *tls.Config - Verbose bool -} - -// create a new HttpClient -// all options should be set on the instance returned -func New() *HttpClient { - client := &http.Client{} - h := &HttpClient{ - client: client, - cachedConns: make(map[string]*connCache), - connMap: make(map[*http.Request]*cachedConn), - ConnectTimeout: 5 * time.Second, - ReadWriteTimeout: 5 * time.Second, - MaxConnsPerHost: 5, - RedirectPolicy: DefaultRedirectPolicy, - TLSClientConfig: &tls.Config{}, - } - - redirFunc := func(r *http.Request, v []*http.Request) error { - lastRequest := v[len(v)-1] - if strings.HasPrefix(lastRequest.URL.Scheme, "hc_") { - lastRequest.URL.Scheme = lastRequest.URL.Scheme[3:] - } - if strings.HasPrefix(r.URL.Scheme, "hc_") { - r.URL.Scheme = r.URL.Scheme[3:] - } - resp := h.RedirectPolicy(r, v) - r.URL.Scheme = "hc_" + r.URL.Scheme - return resp - } - - transport := &http.Transport{ - TLSClientConfig: h.TLSClientConfig, - } - transport.RegisterProtocol("hc_http", h) - transport.RegisterProtocol("hc_https", h) - - client.CheckRedirect = redirFunc - client.Transport = transport - - return h +// transport := &httpclient.Transport{ +// ConnectTimeout: 1*time.Second, +// ResponseHeaderTimeout: 5*time.Second, +// RequestTimeout: 10*time.Second, +// } +// defer transport.Close() +// +// client := &http.Client{Transport: transport} +// req, _ := http.NewRequest("GET", "http://127.0.0.1/test", nil) +// resp, err := client.Do(req) +// if err != nil { +// return err +// } +// defer resp.Body.Close() +// +type Transport struct { + sync.Mutex + + // Proxy specifies a function to return a proxy for a given + // *http.Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *url.URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + + // DisableKeepAlives, if true, prevents re-use of TCP connections + // between different HTTP requests. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. + DisableCompression bool + + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle + // (keep-alive) to keep per-host. If zero, + // http.DefaultMaxIdleConnsPerHost is used. + MaxIdleConnsPerHost int + + // ConnectTimeout, if non-zero, is the maximum amount of time a dial will wait for + // a connect to complete. + ConnectTimeout time.Duration + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // RequestTimeout, if non-zero, specifies the amount of time for the entire + // request to complete (including all of the above timeouts + entire response body). + // This should never be less than the sum total of the above two timeouts. + RequestTimeout time.Duration + + starter sync.Once + transport *http.Transport + requests pqueue.PriorityQueue + exitChan chan int } -func DefaultRedirectPolicy(req *http.Request, via []*http.Request) error { - if len(via) > 3 { - return errors.New("Stopped after 3 redirects") +// Close cleans up the Transport, making sure its goroutine has exited +func (t *Transport) Close() error { + if t.exitChan != nil { + t.exitChan <- 1 + <-t.exitChan } return nil } -// satisfies the RoundTripper interface and handles checking -// the connection cache or dialing (with ConnectTimeout) -func (h *HttpClient) RoundTrip(req *http.Request) (*http.Response, error) { - var c net.Conn - var err error - - addr := canonicalAddr(req.URL.Host, req.URL.Scheme) - - if h.Verbose { - log.Printf("DEBUG: checking cache for addr %s", addr) - } - c, err = h.checkConnCache(addr) - if err != nil { - return nil, err - } - - if c == nil { - if h.Verbose { - log.Printf("DEBUG: addr not in cache, connecting...") - } - c, err = net.DialTimeout("tcp", addr, h.ConnectTimeout) - if err != nil { - return nil, err - } - - if req.URL.Scheme == "hc_https" { - // Initiate TLS and check remote host name against certificate. - c = tls.Client(c, h.TLSClientConfig) - if err = c.(*tls.Conn).Handshake(); err != nil { - return nil, err - } - if h.TLSClientConfig == nil || !h.TLSClientConfig.InsecureSkipVerify { - hostname, _, _ := net.SplitHostPort(req.URL.Host) // Remove port from host - if err = c.(*tls.Conn).VerifyHostname(hostname); err != nil { - return nil, err - } - } - } - } - - h.Lock() - h.connMap[req] = &cachedConn{Conn: c} - h.Unlock() - - return h.exec(c, req) -} - -func (h *HttpClient) checkConnCache(addr string) (net.Conn, error) { - var c net.Conn - - h.Lock() - defer h.Unlock() - - cc, ok := h.cachedConns[addr] - if ok { - // address is in map, check the connection list - e := cc.dl.Front() - if e != nil { - cc.dl.Remove(e) - c = e.Value.(net.Conn) - } - } else { - // this client hasnt seen this address before - cc = &connCache{ - dl: list.New(), - } - h.cachedConns[addr] = cc +func (t *Transport) lazyStart() { + dialer := &net.Dialer{Timeout: t.ConnectTimeout} + t.transport = &http.Transport{ + Dial: dialer.Dial, + Proxy: t.Proxy, + TLSClientConfig: t.TLSClientConfig, + DisableKeepAlives: t.DisableKeepAlives, + DisableCompression: t.DisableCompression, + MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, + } + t.requests = pqueue.New(16) + if t.RequestTimeout > 0 { + t.exitChan = make(chan int) + go t.worker() } - - // TODO: implement accounting for outstanding connections - if cc.outstanding > h.MaxConnsPerHost { - return nil, errors.New("too many outstanding conns on this addr") - } - - return c, nil } -func (h *HttpClient) cacheConn(addr string, conn net.Conn) error { - h.Lock() - defer h.Unlock() - - cc, ok := h.cachedConns[addr] - if !ok { - return errors.New("addr %s not in cache map") - } - cc.dl.PushBack(conn) +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + var item *pqueue.Item - return nil -} + t.starter.Do(t.lazyStart) -func (h *HttpClient) exec(conn net.Conn, req *http.Request) (*http.Response, error) { - deadline := time.Now().Add(h.ReadWriteTimeout) - conn.SetDeadline(deadline) + absTs := time.Now().Add(t.RequestTimeout).UnixNano() + item = &pqueue.Item{Value: req, Priority: absTs} + t.Lock() + heap.Push(&t.requests, item) + t.Unlock() - bw := bufio.NewWriter(conn) - br := bufio.NewReader(conn) - - err := req.Write(bw) + resp, err := t.transport.RoundTrip(req) if err != nil { + t.Lock() + if item.Index != -1 { + heap.Remove(&t.requests, item.Index) + } + t.Unlock() return nil, err } - bw.Flush() - - resp, err := http.ReadResponse(br, req) - if err != nil { - h.Lock() - delete(h.connMap, req) - h.Unlock() - conn.Close() - } - return resp, err -} - -// returns the connection associated with the specified request -// cannot be called after FinishRequest -func (h *HttpClient) GetConn(req *http.Request) (net.Conn, error) { - h.RLock() - defer h.RUnlock() - - conn, ok := h.connMap[req] - if !ok { - return nil, errors.New("connection not in map") - } + resp.Body = &bodyCloseInterceptor{ReadCloser: resp.Body, item: item, t: t} - return conn, nil + return resp, nil } -// perform the specified request -func (h *HttpClient) Do(req *http.Request) (*http.Response, error) { - // h@x0r Go's http client to use our RoundTripper - if !strings.HasPrefix(req.URL.Scheme, "hc_") { - req.URL.Scheme = "hc_" + req.URL.Scheme - } - - resp, err := h.client.Do(req) - if err == nil && (resp.Close || req.Close) { - conn, _ := h.GetConn(req) - conn.(*cachedConn).shouldClose = true - if h.Verbose { - log.Printf("DEBUG: setting close on %s, err: %s, resp.Close: %v, req.Close: %v", conn.RemoteAddr(), err, resp.Close, req.Close) +func (t *Transport) worker() { + ticker := time.NewTicker(25 * time.Millisecond) + for { + select { + case <-ticker.C: + case <-t.exitChan: + goto exit } - } - if resp != nil { - if strings.HasPrefix(resp.Request.URL.Scheme, "hc_") { - resp.Request.URL.Scheme = resp.Request.URL.Scheme[3:] - } - } - return resp, err -} - -// perform final cleanup for the specified request -// *must* be called for every request performed after processing -// is finished and after which GetConn will no longer return -// successfully -func (h *HttpClient) FinishRequest(req *http.Request) error { - conn, err := h.GetConn(req) - if err != nil { - return err - } - - h.Lock() - delete(h.connMap, req) - h.Unlock() + now := time.Now().UnixNano() + for { + t.Lock() + item, _ := t.requests.PeekAndShift(now) + t.Unlock() + + if item == nil { + break + } - if conn.(*cachedConn).shouldClose { - if h.Verbose { - log.Printf("DEBUG: conn %s shouldClose, closing...", conn.RemoteAddr()) + req := item.Value.(*http.Request) + t.transport.CancelRequest(req) } - conn.Close() - return nil } +exit: + ticker.Stop() + close(t.exitChan) +} - addr := canonicalAddr(req.URL.Host, req.URL.Scheme) - if h.Verbose { - log.Printf("DEBUG: caching conn %s as %s", conn.RemoteAddr(), addr) - } - return h.cacheConn(addr, conn.(*cachedConn).Conn) +type bodyCloseInterceptor struct { + io.ReadCloser + item *pqueue.Item + t *Transport } -func canonicalAddr(s string, scheme string) string { - if !hasPort(s) { - switch scheme { - case "http", "hc_http": - s = s + ":80" - case "https", "hc_https": - s = s + ":443" - } +func (bci *bodyCloseInterceptor) Close() error { + err := bci.ReadCloser.Close() + bci.t.Lock() + if bci.item.Index != -1 { + heap.Remove(&bci.t.requests, bci.item.Index) } - return s + bci.t.Unlock() + return err } - -// Given a string of the form "host", "host:port", or "[ipv6::address]:port", -// return true if the string includes a port. -func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } diff --git a/httpclient_test.go b/httpclient_test.go index 374e44c..998dac0 100644 --- a/httpclient_test.go +++ b/httpclient_test.go @@ -1,12 +1,11 @@ package httpclient import ( - "bytes" + "crypto/tls" "io" "io/ioutil" "net" "net/http" - "strings" "sync" "testing" "time" @@ -64,91 +63,45 @@ func setupMockServer(t *testing.T) { } func TestHttpsConnection(t *testing.T) { - httpClient := New() - httpClient.TLSClientConfig.InsecureSkipVerify = true + transport := &Transport{ + ConnectTimeout: 1 * time.Second, + RequestTimeout: 2 * time.Second, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + defer transport.Close() + client := &http.Client{Transport: transport} req, _ := http.NewRequest("GET", "https://httpbin.org/ip", nil) - - resp, err := httpClient.Do(req) + resp, err := client.Do(req) if err != nil { t.Fatalf("1st request failed - %s", err.Error()) } - defer resp.Body.Close() _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatalf("1st failed to read body - %s", err.Error()) } - httpClient.FinishRequest(req) + resp.Body.Close() - httpClient.ReadWriteTimeout = 20 * time.Millisecond req2, _ := http.NewRequest("GET", "https://httpbin.org/delay/5", nil) - - _, err = httpClient.Do(req) + _, err = client.Do(req2) if err == nil { t.Fatalf("HTTPS request should have timed out") } - httpClient.FinishRequest(req2) } -func TestCustomRedirectPolicy(t *testing.T) { +func TestHttpClient(t *testing.T) { starter.Do(func() { setupMockServer(t) }) - httpClient := New() - redirects := make(chan string, 3) - httpClient.RedirectPolicy = func(r *http.Request, v []*http.Request) error { - if strings.HasPrefix(r.URL.Scheme, "hc_") { - t.Errorf("Stray hc_ in URL") - } - for _, i := range v { - if strings.HasPrefix(i.URL.Scheme, "hc_") { - t.Errorf("Stray hc_ in URL") - } - } - redirects <- v[len(v)-1].URL.String() - return DefaultRedirectPolicy(r, v) - } - - req, _ := http.NewRequest("GET", "http://"+addr.String()+"/redirect2", nil) - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatalf("1st request failed - %s", err.Error()) + transport := &Transport{ + ConnectTimeout: 1 * time.Second, + RequestTimeout: 5 * time.Second, } + client := &http.Client{Transport: transport} - urls := make([]string, 0, 3) - close(redirects) - for url := range redirects { - urls = append(urls, url) - } - urls = append(urls, resp.Request.URL.String()) - t.Logf("%s", urls) - for _, url := range urls { - if strings.HasPrefix(url, "hc_") { - t.Errorf("Stray hc_ in URL") - } - } - - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("1st failed to read body - %s", err.Error()) - } - httpClient.FinishRequest(req) - - if len(urls) != 3 { - t.Fatalf("Did not correctly redirect with custom redirect policy", err.Error()) - } - - t.Logf("%s", body) -} - -func TestClose(t *testing.T) { - starter.Do(func() { setupMockServer(t) }) - - httpClient := New() - req, _ := http.NewRequest("GET", "http://"+addr.String()+"/close", nil) - - resp, err := httpClient.Do(req) + req, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil) + resp, err := client.Do(req) if err != nil { t.Fatalf("1st request failed - %s", err.Error()) } @@ -157,134 +110,32 @@ func TestClose(t *testing.T) { t.Fatalf("1st failed to read body - %s", err.Error()) } resp.Body.Close() - httpClient.FinishRequest(req) - - req, _ = http.NewRequest("GET", "http://"+addr.String()+"/close", nil) - - resp, err = httpClient.Do(req) - if err != nil { - t.Fatalf("2nd request failed - %s", err.Error()) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("2nd failed to read body - %s", err.Error()) - } - resp.Body.Close() - httpClient.FinishRequest(req) -} - -func TestHttpClient(t *testing.T) { - starter.Do(func() { setupMockServer(t) }) - - httpClient := New() - if httpClient == nil { - t.Fatalf("failed to instantiate HttpClient") - } - - req, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil) - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatalf("1st request failed - %s", err.Error()) - } - defer resp.Body.Close() - - if strings.HasPrefix(resp.Request.URL.Scheme, "hc_") { - t.Errorf("Stray hc_ in response") - } + transport.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("1st failed to read body - %s", err.Error()) + transport = &Transport{ + ConnectTimeout: 25 * time.Millisecond, + RequestTimeout: 50 * time.Millisecond, } - t.Logf("%s", body) - httpClient.FinishRequest(req) + client = &http.Client{Transport: transport} - httpClient.ReadWriteTimeout = 50 * time.Millisecond - resp, err = httpClient.Do(req) + req2, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil) + resp, err = client.Do(req2) if err == nil { t.Fatalf("2nd request should have timed out") } - httpClient.FinishRequest(req) - - httpClient.ReadWriteTimeout = 250 * time.Millisecond - resp, err = httpClient.Do(req) - if err != nil { - t.Fatalf("3nd request should not have timed out") - } - httpClient.FinishRequest(req) -} - -func TestManyPosts(t *testing.T) { - starter.Do(func() { setupMockServer(t) }) - - httpClient := New() - if httpClient == nil { - t.Fatalf("failed to instantiate HttpClient") - } - - data := "" - for i := 0; i < 100; i++ { - data = data + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" - } - data = data + "\n" - - for i := 0; i < 10000; i++ { - buffer := bytes.NewBuffer([]byte(data)) - req, _ := http.NewRequest("POST", "http://"+addr.String()+"/post", buffer) - - resp, err := httpClient.Do(req) - if err != nil { - t.Fatalf("%d post request failed - %s", i, err.Error()) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("%d failed to read body - %s", i, err.Error()) - } - resp.Body.Close() - httpClient.FinishRequest(req) - } -} - -func TestConnectionCache(t *testing.T) { - starter.Do(func() { setupMockServer(t) }) + transport.Close() - httpClient := New() - if httpClient == nil { - t.Fatalf("failed to instantiate HttpClient") + transport = &Transport{ + ConnectTimeout: 25 * time.Millisecond, + RequestTimeout: 250 * time.Millisecond, } + client = &http.Client{Transport: transport} - req, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil) - - resp, err := httpClient.Do(req) + req3, _ := http.NewRequest("GET", "http://"+addr.String()+"/test", nil) + resp, err = client.Do(req3) if err != nil { - t.Fatalf("1st request failed - %s", err.Error()) - } - resp.Body.Close() - - if len(httpClient.connMap) != 1 { - t.Fatalf("connMap != 1") - } - - httpClient.FinishRequest(req) - - if len(httpClient.connMap) != 0 { - t.Fatalf("connMap != 0") - } - - if httpClient.cachedConns[addr.String()].dl.Len() != 1 { - t.Fatalf("cachedConns != 1") - } - - resp, err = httpClient.Do(req) - if err != nil { - t.Fatalf("1st request failed - %s", err.Error()) + t.Fatalf("3nd request should not have timed out") } resp.Body.Close() - - if httpClient.cachedConns[addr.String()].dl.Len() != 0 { - t.Fatalf("cachedConns != 0") - } - - httpClient.FinishRequest(req) + transport.Close() } diff --git a/pqueue/pqueue.go b/pqueue/pqueue.go new file mode 100644 index 0000000..1e28ed8 --- /dev/null +++ b/pqueue/pqueue.go @@ -0,0 +1,75 @@ +package pqueue + +import ( + "container/heap" +) + +type Item struct { + Value interface{} + Priority int64 + Index int +} + +// this is a priority queue as implemented by a min heap +// ie. the 0th element is the *lowest* value +type PriorityQueue []*Item + +func New(capacity int) PriorityQueue { + return make(PriorityQueue, 0, capacity) +} + +func (pq PriorityQueue) Len() int { + return len(pq) +} + +func (pq PriorityQueue) Less(i, j int) bool { + return pq[i].Priority < pq[j].Priority +} + +func (pq PriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].Index = i + pq[j].Index = j +} + +func (pq *PriorityQueue) Push(x interface{}) { + n := len(*pq) + c := cap(*pq) + if n+1 > c { + npq := make(PriorityQueue, n, c*2) + copy(npq, *pq) + *pq = npq + } + *pq = (*pq)[0 : n+1] + item := x.(*Item) + item.Index = n + (*pq)[n] = item +} + +func (pq *PriorityQueue) Pop() interface{} { + n := len(*pq) + c := cap(*pq) + if n < (c/2) && c > 25 { + npq := make(PriorityQueue, n, c/2) + copy(npq, *pq) + *pq = npq + } + item := (*pq)[n-1] + item.Index = -1 + *pq = (*pq)[0 : n-1] + return item +} + +func (pq *PriorityQueue) PeekAndShift(max int64) (*Item, int64) { + if pq.Len() == 0 { + return nil, 0 + } + + item := (*pq)[0] + if item.Priority > max { + return nil, item.Priority - max + } + heap.Remove(pq, 0) + + return item, 0 +} diff --git a/pqueue/pqueue_test.go b/pqueue/pqueue_test.go new file mode 100644 index 0000000..bbc35d6 --- /dev/null +++ b/pqueue/pqueue_test.go @@ -0,0 +1,68 @@ +package pqueue + +import ( + "container/heap" + "github.com/bmizerany/assert" + "math/rand" + "sort" + "testing" +) + +func TestPriorityQueue(t *testing.T) { + c := 100 + pq := New(c) + + for i := 0; i < c+1; i++ { + heap.Push(&pq, &Item{Value: i, Priority: int64(i)}) + } + assert.Equal(t, pq.Len(), c+1) + assert.Equal(t, cap(pq), c*2) + + for i := 0; i < c+1; i++ { + item := heap.Pop(&pq) + assert.Equal(t, item.(*Item).Value.(int), i) + } + assert.Equal(t, cap(pq), c/4) +} + +func TestUnsortedInsert(t *testing.T) { + c := 100 + pq := New(c) + ints := make([]int, 0, c) + + for i := 0; i < c; i++ { + v := rand.Int() + ints = append(ints, v) + heap.Push(&pq, &Item{Value: i, Priority: int64(v)}) + } + assert.Equal(t, pq.Len(), c) + assert.Equal(t, cap(pq), c) + + sort.Sort(sort.IntSlice(ints)) + + for i := 0; i < c; i++ { + item, _ := pq.PeekAndShift(int64(ints[len(ints)-1])) + assert.Equal(t, item.Priority, int64(ints[i])) + } +} + +func TestRemove(t *testing.T) { + c := 100 + pq := New(c) + + for i := 0; i < c; i++ { + v := rand.Int() + heap.Push(&pq, &Item{Value: "test", Priority: int64(v)}) + } + + for i := 0; i < 10; i++ { + heap.Remove(&pq, rand.Intn((c-1)-i)) + } + + lastPriority := heap.Pop(&pq).(*Item).Priority + for i := 0; i < (c - 10 - 1); i++ { + item := heap.Pop(&pq) + assert.Equal(t, lastPriority < item.(*Item).Priority, true) + lastPriority = item.(*Item).Priority + } +}