6 changes: 3 additions & 3 deletions access_log/access_log_record.go
Expand Up @@ -12,7 +12,7 @@ import (

type AccessLogRecord struct {
Request *http.Request
Response *http.Response
StatusCode int
RouteEndpoint *route.Endpoint
StartedAt time.Time
FirstByteAt time.Time
Expand Down Expand Up @@ -42,10 +42,10 @@ func (r *AccessLogRecord) makeRecord() *bytes.Buffer {
fmt.Fprintf(b, `[%s] `, r.FormatStartedAt())
fmt.Fprintf(b, `"%s %s %s" `, r.Request.Method, r.Request.URL.RequestURI(), r.Request.Proto)

if r.Response == nil {
if r.StatusCode == 0 {
fmt.Fprintf(b, "MissingResponseStatusCode ")
} else {
fmt.Fprintf(b, `%d `, r.Response.StatusCode)
fmt.Fprintf(b, `%d `, r.StatusCode)
}

fmt.Fprintf(b, `%d `, r.BodyBytesSent)
Expand Down
4 changes: 1 addition & 3 deletions access_log/access_log_record_test.go
Expand Up @@ -92,9 +92,7 @@ func CompleteAccessLogRecord() AccessLogRecord {
RemoteAddr: "FakeRemoteAddr",
},
BodyBytesSent: 23,
Response: &http.Response{
StatusCode: 200,
},
StatusCode: 200,
RouteEndpoint: &route.Endpoint{
ApplicationId: "FakeApplicationId",
},
Expand Down
14 changes: 3 additions & 11 deletions access_log/file_and_loggregator_access_logger_test.go
Expand Up @@ -62,11 +62,7 @@ var _ = Describe("AccessLog", func() {
testEmitter := NewMockEmitter()
accessLogger := NewFileAndLoggregatorAccessLogger(nil, testEmitter)

routeEndpoint := &route.Endpoint{
ApplicationId: "",
Host: "127.0.0.1",
Port: 4567,
}
routeEndpoint := route.NewEndpoint("", "127.0.0.1", 4567, "", nil)

accessLogRecord := CreateAccessLogRecord()
accessLogRecord.RouteEndpoint = routeEndpoint
Expand Down Expand Up @@ -165,15 +161,11 @@ func CreateAccessLogRecord() *AccessLogRecord {
StatusCode: http.StatusOK,
}

b := &route.Endpoint{
ApplicationId: "my_awesome_id",
Host: "127.0.0.1",
Port: 4567,
}
b := route.NewEndpoint("my_awesome_id", "127.0.0.1", 4567, "", nil)

r := AccessLogRecord{
Request: req,
Response: res,
StatusCode: res.StatusCode,
RouteEndpoint: b,
StartedAt: time.Unix(10, 100000000),
FirstByteAt: time.Unix(10, 200000000),
Expand Down
8 changes: 4 additions & 4 deletions common/component.go
Expand Up @@ -4,13 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
. "github.com/cloudfoundry/gorouter/common/http"
steno "github.com/cloudfoundry/gosteno"
"github.com/cloudfoundry/yagnats"
"net"
"net/http"
"runtime"
"time"
. "github.com/cloudfoundry/gorouter/common/http"
steno "github.com/cloudfoundry/gosteno"
"github.com/cloudfoundry/yagnats"
)

var procStat *ProcessStatus
Expand Down Expand Up @@ -81,7 +81,7 @@ func (c *VcapComponent) Start() error {
return err
}

c.Host = fmt.Sprintf("%s:%s", host, port)
c.Host = fmt.Sprintf("%s:%d", host, port)
}

if c.Credentials == nil || len(c.Credentials) != 2 {
Expand Down
4 changes: 2 additions & 2 deletions main.go
Expand Up @@ -63,7 +63,7 @@ func main() {
logger.Fatalf("Error connecting to NATS: %s\n", err)
}

registry := rregistry.NewCFRegistry(c, natsClient)
registry := rregistry.NewRouteRegistry(c, natsClient)

varz := rvarz.NewVarz(registry)

Expand Down Expand Up @@ -98,7 +98,7 @@ func main() {
select {
case err := <-errChan:
if err != nil {
logger.Errorf("Error occurred:", err.Error())
logger.Errorf("Error occurred: %s", err.Error())
os.Exit(1)
}
case sig := <-signals:
Expand Down
2 changes: 2 additions & 0 deletions main_test.go
Expand Up @@ -115,6 +115,7 @@ var _ = Describe("Router Integration", func() {

It("waits for all requests to finish", func() {
mbusClient, err := newMessageBus(config)
Ω(err).ShouldNot(HaveOccurred())

blocker := make(chan bool)
longApp := test.NewTestApp([]route.Uri{"longapp.vcap.me"}, proxyPort, mbusClient, nil)
Expand Down Expand Up @@ -148,6 +149,7 @@ var _ = Describe("Router Integration", func() {

It("will timeout if requests take too long", func() {
mbusClient, err := newMessageBus(config)
Ω(err).ShouldNot(HaveOccurred())

blocker := make(chan bool)
resultCh := make(chan error, 1)
Expand Down
7 changes: 2 additions & 5 deletions perf_test.go
Expand Up @@ -19,7 +19,7 @@ var _ = Describe("AccessLogRecord", func() {
Measure("Register", func(b Benchmarker) {
c := config.DefaultConfig()
mbus := fakeyagnats.New()
r := registry.NewCFRegistry(c, mbus)
r := registry.NewRouteRegistry(c, mbus)

accesslog, err := access_log.CreateRunningAccessLogger(c)
Ω(err).ToNot(HaveOccurred())
Expand All @@ -38,10 +38,7 @@ var _ = Describe("AccessLogRecord", func() {
str := strconv.Itoa(i)
r.Register(
route.Uri("bench.vcap.me."+str),
&route.Endpoint{
Host: "localhost",
Port: uint16(i),
},
route.NewEndpoint("", "localhost", uint16(i), "", nil),
)
}
})
Expand Down
210 changes: 118 additions & 92 deletions proxy/proxy.go
@@ -1,6 +1,7 @@
package proxy

import (
"errors"
"net"
"net/http"
"net/http/httputil"
Expand All @@ -18,13 +19,17 @@ import (
const (
VcapCookieId = "__VCAP_ID__"
StickyCookieKey = "JSESSIONID"
retries = 3
)

var noEndpointsAvailable = errors.New("No endpoints available")

type LookupRegistry interface {
Lookup(uri route.Uri) (*route.Endpoint, bool)
LookupByPrivateInstanceId(uri route.Uri, p string) (*route.Endpoint, bool)
Lookup(uri route.Uri) *route.Pool
}

type AfterRoundTrip func(rsp *http.Response, endpoint *route.Endpoint, err error)

type ProxyReporter interface {
CaptureBadRequest(req *http.Request)
CaptureBadGateway(req *http.Request)
Expand Down Expand Up @@ -100,32 +105,32 @@ func (p *proxy) Wait() {
p.waitgroup.Wait()
}

func (p *proxy) lookup(request *http.Request) (*route.Endpoint, bool) {
uri := route.Uri(hostWithoutPort(request))

func (p *proxy) getStickySession(request *http.Request) string {
// Try choosing a backend using sticky session
if _, err := request.Cookie(StickyCookieKey); err == nil {
if sticky, err := request.Cookie(VcapCookieId); err == nil {
routeEndpoint, ok := p.registry.LookupByPrivateInstanceId(uri, sticky.Value)
if ok {
return routeEndpoint, ok
}
return sticky.Value
}
}
return ""
}

func (p *proxy) lookup(request *http.Request) *route.Pool {
uri := route.Uri(hostWithoutPort(request))
// Choose backend using host alone
return p.registry.Lookup(uri)
}

func (p *proxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
startedAt := time.Now()
handler := NewRequestHandler(request, responseWriter)

accessLog := access_log.AccessLogRecord{
Request: request,
StartedAt: startedAt,
}

handler := NewRequestHandler(request, responseWriter, p.reporter, &accessLog)

p.waitgroup.Add(1)

defer func() {
Expand All @@ -143,167 +148,188 @@ func (p *proxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Requ
return
}

routeEndpoint, found := p.lookup(request)
if !found {
routePool := p.lookup(request)
if routePool == nil {
p.reporter.CaptureBadRequest(request)
handler.HandleMissingRoute()
return
}

handler.logger.Set("RouteEndpoint", routeEndpoint.ToLogData())
stickyEndpointId := p.getStickySession(request)
iter := &wrappedIterator{
nested: routePool.Endpoints(stickyEndpointId),

accessLog.RouteEndpoint = routeEndpoint

p.reporter.CaptureRoutingRequest(routeEndpoint, handler.request)
afterNext: func(endpoint *route.Endpoint) {
if endpoint != nil {
handler.logger.Set("RouteEndpoint", endpoint.ToLogData())
accessLog.RouteEndpoint = endpoint
p.reporter.CaptureRoutingRequest(endpoint, request)
}
},
}

if isTcpUpgrade(request) {
handler.HandleTcpRequest(routeEndpoint)
handler.HandleTcpRequest(iter)
return
}

if isWebSocketUpgrade(request) {
handler.HandleWebSocketRequest(routeEndpoint)
handler.HandleWebSocketRequest(iter)
return
}

proxyWriter := newProxyResponseWriter(responseWriter)
roundTripper := &proxyRoundTripper{
transport: p.transport,
after: func(rsp *http.Response, err error) {
iter: iter,
handler: &handler,

after: func(rsp *http.Response, endpoint *route.Endpoint, err error) {
accessLog.FirstByteAt = time.Now()
accessLog.Response = rsp
if rsp != nil {
accessLog.StatusCode = rsp.StatusCode
}

// disable keep-alives -- not needed with Go 1.3
responseWriter.Header().Set("Connection", "close")

if p.traceKey != "" && request.Header.Get(router_http.VcapTraceHeader) == p.traceKey {
setTraceHeaders(responseWriter, p.ip, routeEndpoint.CanonicalAddr())
setTraceHeaders(responseWriter, p.ip, endpoint.CanonicalAddr())
}

latency := time.Since(startedAt)

p.reporter.CaptureRoutingResponse(routeEndpoint, rsp, startedAt, latency)
p.reporter.CaptureRoutingResponse(endpoint, rsp, startedAt, latency)

if err != nil {
p.reporter.CaptureBadGateway(request)
handler.HandleBadGateway(err)
proxyWriter.Done()
return
}

if routeEndpoint.PrivateInstanceId != "" {
setupStickySession(responseWriter, rsp, routeEndpoint)
if endpoint.PrivateInstanceId != "" {
setupStickySession(responseWriter, rsp, endpoint)
}
},
}
proxyTransport := autowire.InstrumentedRoundTripper(roundTripper)

proxyWriter := newProxyResponseWriter(responseWriter)
p.newReverseProxy(proxyTransport, routeEndpoint, request).ServeHTTP(proxyWriter, request)
p.newReverseProxy(proxyTransport, request).ServeHTTP(proxyWriter, request)

accessLog.FinishedAt = time.Now()
accessLog.BodyBytesSent = int64(proxyWriter.Size())
}

func (p *proxy) newReverseProxy(proxyTransport http.RoundTripper, endpoint *route.Endpoint, req *http.Request) http.Handler {
func (p *proxy) newReverseProxy(proxyTransport http.RoundTripper, req *http.Request) http.Handler {
rproxy := &httputil.ReverseProxy{
Director: func(request *http.Request) {
request.URL.Scheme = "http"
request.URL.Host = endpoint.CanonicalAddr()
request.URL.Host = req.Host
request.URL.Opaque = req.URL.Opaque
request.URL.RawQuery = req.URL.RawQuery
request.Header.Set("X-CF-ApplicationID", endpoint.ApplicationId)
setRequestXCfInstanceId(req, endpoint)

setRequestXRequestStart(req)
setRequestXVcapRequestId(req, nil)
},
Transport: proxyTransport,
FlushInterval: 50 * time.Millisecond,
}

rproxy.Transport = proxyTransport
rproxy.FlushInterval = 50 * time.Millisecond

return rproxy
}

func setupStickySession(responseWriter http.ResponseWriter, response *http.Response, endpoint *route.Endpoint) {
for _, v := range response.Cookies() {
if v.Name == StickyCookieKey {
cookie := &http.Cookie{
Name: VcapCookieId,
Value: endpoint.PrivateInstanceId,
Path: "/",
}

http.SetCookie(responseWriter, cookie)
return
}
}
}

type proxyRoundTripper struct {
transport http.RoundTripper
after func(response *http.Response, err error)
response *http.Response
err error
after AfterRoundTrip
iter route.EndpointIterator
handler *RequestHandler

response *http.Response
err error
}

func (p *proxyRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
p.response, p.err = p.transport.RoundTrip(request)
if p.after != nil {
p.after(p.response, p.err)
}
var err error
var res *http.Response
var endpoint *route.Endpoint
retry := 0
for {
endpoint = p.iter.Next()

if endpoint == nil {
p.handler.reporter.CaptureBadGateway(request)
err = noEndpointsAvailable
p.handler.HandleBadGateway(err)
return nil, err
}

return p.response, p.err
}
request.URL.Host = endpoint.CanonicalAddr()
request.Header.Set("X-CF-ApplicationID", endpoint.ApplicationId)
setRequestXCfInstanceId(request, endpoint)

type proxyResponseWriter struct {
w http.ResponseWriter
status int
size int
res, err = p.transport.RoundTrip(request)
if err == nil {
break
}

flusher http.Flusher
}
if ne, netErr := err.(*net.OpError); !netErr || ne.Op != "dial" {
break
}

func newProxyResponseWriter(w http.ResponseWriter) *proxyResponseWriter {
proxyWriter := &proxyResponseWriter{
w: w,
flusher: w.(http.Flusher),
}
p.iter.EndpointFailed()

return proxyWriter
}
p.handler.Logger().Set("Error", err.Error())
p.handler.Logger().Warnf("proxy.endpoint.failed")

func (p *proxyResponseWriter) Header() http.Header {
return p.w.Header()
}
retry++
if retry == retries {
break
}
}

func (p *proxyResponseWriter) Write(b []byte) (int, error) {
if p.status == 0 {
p.WriteHeader(http.StatusOK)
if p.after != nil {
p.after(res, endpoint, err)
}
size, err := p.w.Write(b)
p.size += size
return size, err
}

func (p *proxyResponseWriter) WriteHeader(s int) {
p.w.WriteHeader(s)
p.response = res
p.err = err

if p.status == 0 {
p.status = s
}
return res, err
}
func (p *proxyResponseWriter) Flush() {
if p.flusher != nil {
p.flusher.Flush()

type wrappedIterator struct {
nested route.EndpointIterator
afterNext func(*route.Endpoint)
}

func (i *wrappedIterator) Next() *route.Endpoint {
e := i.nested.Next()
if i.afterNext != nil {
i.afterNext(e)
}
return e
}

func (p *proxyResponseWriter) Status() int {
return p.status
func (i *wrappedIterator) EndpointFailed() {
i.nested.EndpointFailed()
}

func (p *proxyResponseWriter) Size() int {
return p.size
func setupStickySession(responseWriter http.ResponseWriter, response *http.Response, endpoint *route.Endpoint) {
for _, v := range response.Cookies() {
if v.Name == StickyCookieKey {
cookie := &http.Cookie{
Name: VcapCookieId,
Value: endpoint.PrivateInstanceId,
Path: "/",

HttpOnly: true,
}

http.SetCookie(responseWriter, cookie)
return
}
}
}

func isProtocolSupported(request *http.Request) bool {
Expand Down
58 changes: 39 additions & 19 deletions proxy/proxy_test.go
Expand Up @@ -33,14 +33,14 @@ type nullVarz struct{}

func (_ nullVarz) MarshalJSON() ([]byte, error) { return json.Marshal(nil) }
func (_ nullVarz) ActiveApps() *stats.ActiveApps { return stats.NewActiveApps() }
func (_ nullVarz) CaptureBadRequest(req *http.Request) {}
func (_ nullVarz) CaptureBadGateway(req *http.Request) {}
func (_ nullVarz) CaptureBadRequest(*http.Request) {}
func (_ nullVarz) CaptureBadGateway(*http.Request) {}
func (_ nullVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) {}
func (_ nullVarz) CaptureRoutingResponse(b *route.Endpoint, res *http.Response, t time.Time, d time.Duration) {
}

var _ = Describe("Proxy", func() {
var r *registry.CFRegistry
var r *registry.RouteRegistry
var p Proxy
var conf *config.Config
var proxyServer net.Listener
Expand All @@ -54,7 +54,7 @@ var _ = Describe("Proxy", func() {

mbus := fakeyagnats.New()

r = registry.NewCFRegistry(conf, mbus)
r = registry.NewRouteRegistry(conf, mbus)

accessLogFile = new(test_util.FakeFile)
accessLog = access_log.NewFileAndLoggregatorAccessLogger(accessLogFile, nil)
Expand Down Expand Up @@ -112,6 +112,7 @@ var _ = Describe("Proxy", func() {
"HTTP/1.1 200 OK",
"Content-Length: 0",
})

})
defer ln.Close()

Expand All @@ -125,9 +126,10 @@ var _ = Describe("Proxy", func() {
x.CheckLine("HTTP/1.0 200 OK")

var payload []byte
n, e := accessLogFile.Read(&payload)
Ω(e).ShouldNot(HaveOccurred())
Ω(n).ShouldNot(BeZero())
Eventually(func() int {
accessLogFile.Read(&payload)
return len(payload)
}).ShouldNot(BeZero())
Ω(string(payload)).To(MatchRegexp("^test.*\n"))
//make sure the record includes all the data
//since the building of the log record happens throughout the life of the request
Expand Down Expand Up @@ -878,6 +880,31 @@ var _ = Describe("Proxy", func() {
Ω(err).Should(HaveOccurred())
})

It("retries when failed endpoints exist", func() {
ln := registerHandler(r, "retries", func(x *test_util.HttpConn) {
x.CheckLine("GET / HTTP/1.1")
resp := test_util.NewResponse(http.StatusOK)
x.WriteResponse(resp)
x.Close()
})
defer ln.Close()

ip, err := net.ResolveTCPAddr("tcp", "localhost:81")
Ω(err).Should(BeNil())
registerAddr(r, "retries", ip, "instanceId")

for i := 0; i < 5; i++ {
x := dialProxy(proxyServer)

req := x.NewRequest("GET", "/", nil)
req.Host = "retries"
x.WriteRequest(req)
resp, _ := x.ReadResponse()

Ω(resp.StatusCode).To(Equal(http.StatusOK))
}
})

Context("Wait", func() {
It("waits for requests to finish", func() {
blocker := make(chan bool)
Expand Down Expand Up @@ -918,28 +945,21 @@ var _ = Describe("Proxy", func() {
})
})

func registerAddr(r *registry.CFRegistry, u string, a net.Addr, instanceId string) {
func registerAddr(r *registry.RouteRegistry, u string, a net.Addr, instanceId string) {
h, p, err := net.SplitHostPort(a.String())
Ω(err).NotTo(HaveOccurred())

x, err := strconv.Atoi(p)
Ω(err).NotTo(HaveOccurred())

r.Register(
route.Uri(u),
&route.Endpoint{
Host: h,
Port: uint16(x),
PrivateInstanceId: instanceId,
},
)
r.Register(route.Uri(u), route.NewEndpoint("", h, uint16(x), instanceId, nil))
}

func registerHandler(r *registry.CFRegistry, u string, h connHandler) net.Listener {
func registerHandler(r *registry.RouteRegistry, u string, h connHandler) net.Listener {
return registerHandlerWithInstanceId(r, u, h, "")
}

func registerHandlerWithInstanceId(r *registry.CFRegistry, u string, h connHandler, instanceId string) net.Listener {
func registerHandlerWithInstanceId(r *registry.RouteRegistry, u string, h connHandler, instanceId string) net.Listener {
ln, err := net.Listen("tcp", "127.0.0.1:0")
Ω(err).NotTo(HaveOccurred())

Expand All @@ -957,7 +977,7 @@ func registerHandlerWithInstanceId(r *registry.CFRegistry, u string, h connHandl
if max := 1 * time.Second; tempDelay > max {
tempDelay = max
}
println("http: Accept error: %v; retrying in %v", err, tempDelay)
fmt.Printf("http: Accept error: %v; retrying in %v\n", err, tempDelay)
time.Sleep(tempDelay)
continue
}
Expand Down
125 changes: 91 additions & 34 deletions proxy/request_handler.go
Expand Up @@ -11,22 +11,28 @@ import (
"strings"
"time"

"github.com/cloudfoundry/gorouter/access_log"
"github.com/cloudfoundry/gorouter/common"
router_http "github.com/cloudfoundry/gorouter/common/http"
"github.com/cloudfoundry/gorouter/route"
steno "github.com/cloudfoundry/gosteno"
)

type RequestHandler struct {
logger *steno.Logger
logger *steno.Logger
reporter ProxyReporter
logrecord *access_log.AccessLogRecord

request *http.Request
response http.ResponseWriter
}

func NewRequestHandler(request *http.Request, response http.ResponseWriter) RequestHandler {
func NewRequestHandler(request *http.Request, response http.ResponseWriter, r ProxyReporter,
alr *access_log.AccessLogRecord) RequestHandler {
return RequestHandler{
logger: createLogger(request),
logger: createLogger(request),
reporter: r,
logrecord: alr,

request: request,
response: response,
Expand All @@ -45,7 +51,12 @@ func createLogger(request *http.Request) *steno.Logger {
return logger
}

func (h *RequestHandler) Logger() *steno.Logger {
return h.logger
}

func (h *RequestHandler) HandleHeartbeat() {
h.logrecord.StatusCode = http.StatusOK
h.response.WriteHeader(http.StatusOK)
h.response.Write([]byte("ok\n"))
h.request.Close = true
Expand All @@ -59,6 +70,7 @@ func (h *RequestHandler) HandleUnsupportedProtocol() {
return
}

h.logrecord.StatusCode = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.0 400 Bad Request\r\n\r\n")
buf.Flush()
conn.Close()
Expand All @@ -80,28 +92,20 @@ func (h *RequestHandler) HandleBadGateway(err error) {
h.writeStatus(http.StatusBadGateway, "Registered endpoint failed to handle the request.")
}

func (h *RequestHandler) HandleTcpRequest(endpoint *route.Endpoint) {
func (h *RequestHandler) HandleTcpRequest(iter route.EndpointIterator) {
h.logger.Set("Upgrade", "tcp")

err := h.serveTcp(endpoint)
err := h.serveTcp(iter)
if err != nil {
h.logger.Set("Error", err.Error())
h.logger.Warn("proxy.tcp.failed")

h.writeStatus(http.StatusBadRequest, "TCP forwarding to endpoint failed.")
}
}

func (h *RequestHandler) HandleWebSocketRequest(endpoint *route.Endpoint) {
h.setupRequest(endpoint)

func (h *RequestHandler) HandleWebSocketRequest(iter route.EndpointIterator) {
h.logger.Set("Upgrade", "websocket")

err := h.serveWebSocket(endpoint)
err := h.serveWebSocket(iter)
if err != nil {
h.logger.Set("Error", err.Error())
h.logger.Warn("proxy.websocket.failed")

h.writeStatus(http.StatusBadRequest, "WebSocket request to endpoint failed.")
}
}
Expand All @@ -110,61 +114,114 @@ func (h *RequestHandler) writeStatus(code int, message string) {
body := fmt.Sprintf("%d %s: %s", code, http.StatusText(code), message)

h.logger.Warn(body)
h.logrecord.StatusCode = code

http.Error(h.response, body, code)
if code > 299 {
h.response.Header().Del("Connection")
}
}

func (h *RequestHandler) serveTcp(endpoint *route.Endpoint) error {
func (h *RequestHandler) serveTcp(iter route.EndpointIterator) error {
var err error
var connection net.Conn

client, _, err := h.hijack()
if err != nil {
return err
}

connection, err := net.DialTimeout("tcp", endpoint.CanonicalAddr(), 5*time.Second)
if err != nil {
return err
}

defer func() {
client.Close()
connection.Close()
if connection != nil {
connection.Close()
}
}()

forwardIO(client, connection)
retry := 0
for {
endpoint := iter.Next()
if endpoint == nil {
h.reporter.CaptureBadGateway(h.request)
err = noEndpointsAvailable
h.HandleBadGateway(err)
return err
}

connection, err = net.DialTimeout("tcp", endpoint.CanonicalAddr(), 5*time.Second)
if err == nil {
break
}

iter.EndpointFailed()

h.logger.Set("Error", err.Error())
h.logger.Warn("proxy.tcp.failed")

retry++
if retry == retries {
return err
}
}

if connection != nil {
forwardIO(client, connection)
}

return nil
}

func (h *RequestHandler) serveWebSocket(endpoint *route.Endpoint) error {
func (h *RequestHandler) serveWebSocket(iter route.EndpointIterator) error {
var err error
var connection net.Conn

client, _, err := h.hijack()
if err != nil {
return err
}

connection, err := net.DialTimeout("tcp", endpoint.CanonicalAddr(), 5*time.Second)
if err != nil {
return err
}

defer func() {
client.Close()
connection.Close()
if connection != nil {
connection.Close()
}
}()

err = h.request.Write(connection)
if err != nil {
return err
retry := 0
for {
endpoint := iter.Next()
if endpoint == nil {
h.reporter.CaptureBadGateway(h.request)
err = noEndpointsAvailable
h.HandleBadGateway(err)
return err
}

connection, err = net.DialTimeout("tcp", endpoint.CanonicalAddr(), 5*time.Second)
if err == nil {
h.setupRequest(endpoint)
break
}

iter.EndpointFailed()

h.logger.Set("Error", err.Error())
h.logger.Warn("proxy.websocket.failed")

retry++
if retry == retries {
return err
}
}

forwardIO(client, connection)
if connection != nil {
err = h.request.Write(connection)
if err != nil {
return err
}

forwardIO(client, connection)
}
return nil
}

Expand Down
70 changes: 70 additions & 0 deletions proxy/responsewriter.go
@@ -0,0 +1,70 @@
package proxy

import (
"net/http"
)

type proxyResponseWriter struct {
w http.ResponseWriter
status int
size int

flusher http.Flusher
done bool
}

func newProxyResponseWriter(w http.ResponseWriter) *proxyResponseWriter {
proxyWriter := &proxyResponseWriter{
w: w,
flusher: w.(http.Flusher),
}

return proxyWriter
}

func (p *proxyResponseWriter) Header() http.Header {
return p.w.Header()
}

func (p *proxyResponseWriter) Write(b []byte) (int, error) {
if p.done {
return 0, nil
}

if p.status == 0 {
p.WriteHeader(http.StatusOK)
}
size, err := p.w.Write(b)
p.size += size
return size, err
}

func (p *proxyResponseWriter) WriteHeader(s int) {
if p.done {
return
}

p.w.WriteHeader(s)

if p.status == 0 {
p.status = s
}
}

func (p *proxyResponseWriter) Done() {
p.done = true
}

func (p *proxyResponseWriter) Flush() {
if p.flusher != nil {
p.flusher.Flush()
}
}

func (p *proxyResponseWriter) Status() int {
return p.status
}

func (p *proxyResponseWriter) Size() int {
return p.size
}
232 changes: 86 additions & 146 deletions registry/registry.go
Expand Up @@ -12,42 +12,29 @@ import (
"github.com/cloudfoundry/gorouter/route"
)

type CFRegistry struct {
type RouteRegistry struct {
sync.RWMutex

logger *steno.Logger

byUri map[route.Uri]*route.Pool

table map[tableKey]*tableEntry

pruneStaleDropletsInterval time.Duration
dropletStaleThreshold time.Duration

messageBus yagnats.NATSClient

ticker *time.Ticker
timeOfLastUpdate time.Time
}

type tableKey struct {
addr string
uri route.Uri
}

type tableEntry struct {
endpoint *route.Endpoint
updatedAt time.Time
}

func NewCFRegistry(c *config.Config, mbus yagnats.NATSClient) *CFRegistry {
r := &CFRegistry{}
func NewRouteRegistry(c *config.Config, mbus yagnats.NATSClient) *RouteRegistry {
r := &RouteRegistry{}

r.logger = steno.NewLogger("router.registry")

r.byUri = make(map[route.Uri]*route.Pool)

r.table = make(map[tableKey]*tableEntry)

r.pruneStaleDropletsInterval = c.PruneStaleDropletsInterval
r.dropletStaleThreshold = c.DropletStaleThreshold

Expand All @@ -56,192 +43,145 @@ func NewCFRegistry(c *config.Config, mbus yagnats.NATSClient) *CFRegistry {
return r
}

func (registry *CFRegistry) Register(uri route.Uri, endpoint *route.Endpoint) {
registry.Lock()
defer registry.Unlock()
func (r *RouteRegistry) Register(uri route.Uri, endpoint *route.Endpoint) {
t := time.Now()
r.Lock()

uri = uri.ToLower()

key := tableKey{
addr: endpoint.CanonicalAddr(),
uri: uri,
}

var endpointToRegister *route.Endpoint

entry, found := registry.table[key]
if found {
endpointToRegister = entry.endpoint
} else {
endpointToRegister = endpoint
entry = &tableEntry{endpoint: endpoint}

registry.table[key] = entry
}

pool, found := registry.byUri[uri]
pool, found := r.byUri[uri]
if !found {
pool = route.NewPool()
registry.byUri[uri] = pool
pool = route.NewPool(r.dropletStaleThreshold / 4)
r.byUri[uri] = pool
}

pool.Add(endpointToRegister)

entry.updatedAt = time.Now()
pool.Put(endpoint)

registry.timeOfLastUpdate = time.Now()
r.timeOfLastUpdate = t
r.Unlock()
}

func (registry *CFRegistry) Unregister(uri route.Uri, endpoint *route.Endpoint) {
registry.Lock()
defer registry.Unlock()
func (r *RouteRegistry) Unregister(uri route.Uri, endpoint *route.Endpoint) {
r.Lock()

uri = uri.ToLower()

key := tableKey{
addr: endpoint.CanonicalAddr(),
uri: uri,
pool, found := r.byUri[uri]
if found {
pool.Remove(endpoint)

if pool.IsEmpty() {
delete(r.byUri, uri)
}
}

registry.unregisterUri(key)
r.Unlock()
}

func (r *CFRegistry) Lookup(uri route.Uri) (*route.Endpoint, bool) {
func (r *RouteRegistry) Lookup(uri route.Uri) *route.Pool {
r.RLock()
defer r.RUnlock()

pool, ok := r.lookupByUri(uri)
if !ok {
return nil, false
}
uri = uri.ToLower()
pool := r.byUri[uri]

return pool.Sample()
r.RUnlock()

return pool
}

func (r *CFRegistry) LookupByPrivateInstanceId(uri route.Uri, p string) (*route.Endpoint, bool) {
r.RLock()
defer r.RUnlock()
func (r *RouteRegistry) StartPruningCycle() {
if r.pruneStaleDropletsInterval > 0 {
r.Lock()
r.ticker = time.NewTicker(r.pruneStaleDropletsInterval)
r.Unlock()

pool, ok := r.lookupByUri(uri)
if !ok {
return nil, false
}
go func() {
for {
select {
case <-r.ticker.C:
r.logger.Debug("Start to check and prune stale droplets")
if r.isStateStale() {
r.logger.Info("State is stale; NOT pruning")
r.pauseStaleTracker()
break
}

return pool.FindByPrivateInstanceId(p)
}
r.pruneStaleDroplets()

func (r *CFRegistry) lookupByUri(uri route.Uri) (*route.Pool, bool) {
uri = uri.ToLower()
pool, ok := r.byUri[uri]
return pool, ok
}

func (r *CFRegistry) StartPruningCycle() {
go r.checkAndPrune()
}
}
}()
}
}

func (r *CFRegistry) PruneStaleDroplets() {
if r.isStateStale() {
r.logger.Info("State is stale; NOT pruning")
r.pauseStaleTracker()
return
func (r *RouteRegistry) StopPruningCycle() {
r.Lock()
if r.ticker != nil {
r.ticker.Stop()
}

r.pruneStaleDroplets()
r.Unlock()
}

func (registry *CFRegistry) NumUris() int {
func (registry *RouteRegistry) NumUris() int {
registry.RLock()
defer registry.RUnlock()
uriCount := len(registry.byUri)
registry.RUnlock()

return len(registry.byUri)
return uriCount
}

func (r *CFRegistry) TimeOfLastUpdate() time.Time {
func (r *RouteRegistry) TimeOfLastUpdate() time.Time {
r.RLock()
defer r.RUnlock()
return r.timeOfLastUpdate
t := r.timeOfLastUpdate
r.RUnlock()

return t
}

func (r *CFRegistry) NumEndpoints() int {
func (r *RouteRegistry) NumEndpoints() int {
r.RLock()
defer r.RUnlock()

mapForSize := make(map[string]bool)
for _, entry := range r.table {
mapForSize[entry.endpoint.CanonicalAddr()] = true
uris := make(map[string]struct{})
f := func(endpoint *route.Endpoint) {
uris[endpoint.CanonicalAddr()] = struct{}{}
}
for _, pool := range r.byUri {
pool.Each(f)
}
r.RUnlock()

return len(mapForSize)
return len(uris)
}

func (r *CFRegistry) MarshalJSON() ([]byte, error) {
func (r *RouteRegistry) MarshalJSON() ([]byte, error) {
r.RLock()
defer r.RUnlock()

return json.Marshal(r.byUri)
}

func (r *CFRegistry) isStateStale() bool {
func (r *RouteRegistry) isStateStale() bool {
return !r.messageBus.Ping()
}

func (r *CFRegistry) pruneStaleDroplets() {
func (r *RouteRegistry) pruneStaleDroplets() {
r.Lock()
defer r.Unlock()

for key, entry := range r.table {
if !r.isEntryStale(entry) {
continue
pruneTime := time.Now().Add(-r.dropletStaleThreshold)
for k, pool := range r.byUri {
pool.PruneBefore(pruneTime)
if pool.IsEmpty() {
delete(r.byUri, k)
}

r.logger.Infof("Pruning stale droplet: %v, uri: %s", entry, key.uri)
r.unregisterUri(key)
}
r.Unlock()
}

func (r *CFRegistry) isEntryStale(entry *tableEntry) bool {
return entry.updatedAt.Add(r.dropletStaleThreshold).Before(time.Now())
}

func (r *CFRegistry) pauseStaleTracker() {
func (r *RouteRegistry) pauseStaleTracker() {
r.Lock()
defer r.Unlock()

for _, entry := range r.table {
entry.updatedAt = time.Now()
}
}

func (r *CFRegistry) checkAndPrune() {
if r.pruneStaleDropletsInterval == 0 {
return
}

tick := time.Tick(r.pruneStaleDropletsInterval)
for {
select {
case <-tick:
r.logger.Debug("Start to check and prune stale droplets")
r.PruneStaleDroplets()
}
}
}

func (r *CFRegistry) unregisterUri(key tableKey) {
entry, found := r.table[key]
if !found {
return
}
t := time.Now()

endpoints, found := r.byUri[key.uri]
if found {
endpoints.Remove(entry.endpoint)

if endpoints.IsEmpty() {
delete(r.byUri, key.uri)
}
for _, pool := range r.byUri {
pool.MarkUpdated(t)
}

delete(r.table, key)
r.Unlock()
}
206 changes: 93 additions & 113 deletions registry/registry_test.go
Expand Up @@ -13,52 +13,39 @@ import (
"time"
)

var _ = Describe("Registry", func() {
var r *CFRegistry
var _ = Describe("RouteRegistry", func() {
var r *RouteRegistry
var messageBus *fakeyagnats.FakeYagnats

var fooEndpoint, barEndpoint, bar2Endpoint *route.Endpoint
var configObj *config.Config

BeforeEach(func() {
configObj = config.DefaultConfig()
configObj.PruneStaleDropletsInterval = 50 * time.Millisecond
configObj.DropletStaleThreshold = 10 * time.Millisecond

messageBus = fakeyagnats.New()
r = NewCFRegistry(configObj, messageBus)
fooEndpoint = &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,

ApplicationId: "12345",
Tags: map[string]string{
r = NewRouteRegistry(configObj, messageBus)
fooEndpoint = route.NewEndpoint("12345", "192.168.1.1", 1234,
"id1", map[string]string{
"runtime": "ruby18",
"framework": "sinatra",
},
}

barEndpoint = &route.Endpoint{
Host: "192.168.1.2",
Port: 4321,
})

ApplicationId: "54321",
Tags: map[string]string{
barEndpoint = route.NewEndpoint("54321", "192.168.1.2", 4321,
"id2", map[string]string{
"runtime": "javascript",
"framework": "node",
},
}

bar2Endpoint = &route.Endpoint{
Host: "192.168.1.3",
Port: 1234,
})

ApplicationId: "54321",
Tags: map[string]string{
bar2Endpoint = route.NewEndpoint("54321", "192.168.1.3", 1234,
"id3", map[string]string{
"runtime": "javascript",
"framework": "node",
},
}
})
})

Context("Register", func() {
It("records and tracks time of last update", func() {
r.Register("foo", fooEndpoint)
Expand Down Expand Up @@ -89,15 +76,8 @@ var _ = Describe("Registry", func() {
})

It("ignores case", func() {
m1 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}

m2 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1235,
}
m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
m2 := route.NewEndpoint("", "192.168.1.1", 1235, "", nil)

r.Register("foo", m1)
r.Register("FOO", m2)
Expand All @@ -106,15 +86,8 @@ var _ = Describe("Registry", func() {
})

It("allows multiple uris for the same endpoint", func() {
m1 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}

m2 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)

r.Register("foo", m1)
r.Register("bar", m2)
Expand Down Expand Up @@ -148,15 +121,8 @@ var _ = Describe("Registry", func() {
})

It("ignores uri case and matches endpoint", func() {
m1 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}

m2 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)

r.Register("foo", m1)
r.Unregister("FOO", m2)
Expand All @@ -165,15 +131,8 @@ var _ = Describe("Registry", func() {
})

It("removes the specific url/endpoint combo", func() {
m1 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}

m2 := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
m2 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)

r.Register("foo", m1)
r.Register("bar", m1)
Expand All @@ -186,32 +145,21 @@ var _ = Describe("Registry", func() {

Context("Lookup", func() {
It("case insensitive lookup", func() {
m := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
m := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)

r.Register("foo", m)

b, ok := r.Lookup("foo")
Ω(ok).To(BeTrue())
Ω(b.CanonicalAddr()).To(Equal("192.168.1.1:1234"))
p1 := r.Lookup("foo")
p2 := r.Lookup("FOO")
Ω(p1).To(Equal(p2))

b, ok = r.Lookup("FOO")
Ω(ok).To(BeTrue())
Ω(b.CanonicalAddr()).To(Equal("192.168.1.1:1234"))
iter := p1.Endpoints("")
Ω(iter.Next().CanonicalAddr()).To(Equal("192.168.1.1:1234"))
})

It("selects one of the routes", func() {
m1 := &route.Endpoint{
Host: "192.168.1.2",
Port: 1234,
}

m2 := &route.Endpoint{
Host: "192.168.1.2",
Port: 1235,
}
m1 := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
m2 := route.NewEndpoint("", "192.168.1.1", 1235, "", nil)

r.Register("bar", m1)
r.Register("barr", m1)
Expand All @@ -222,13 +170,19 @@ var _ = Describe("Registry", func() {
Ω(r.NumUris()).To(Equal(2))
Ω(r.NumEndpoints()).To(Equal(2))

b, ok := r.Lookup("bar")
Ω(ok).To(BeTrue())
Ω(b.Host).To(Equal("192.168.1.2"))
Ω(b.Port == m1.Port || b.Port == m2.Port).To(BeTrue())
p := r.Lookup("bar")
Ω(p).ShouldNot(BeNil())
e := p.Endpoints("").Next()
Ω(e).ShouldNot(BeNil())
Ω(e.CanonicalAddr()).To(MatchRegexp("192.168.1.1:123[4|5]"))
})
})
Context("PruneStaleDropelts", func() {
Context("Prunes Stale Droplets", func() {

AfterEach(func() {
r.StopPruningCycle()
})

It("removes stale droplets", func() {
r.Register("foo", fooEndpoint)
r.Register("fooo", fooEndpoint)
Expand All @@ -239,18 +193,15 @@ var _ = Describe("Registry", func() {
Ω(r.NumUris()).To(Equal(4))
Ω(r.NumEndpoints()).To(Equal(2))

time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond)
r.PruneStaleDroplets()
r.StartPruningCycle()
time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond)

Ω(r.NumUris()).To(Equal(0))
Ω(r.NumEndpoints()).To(Equal(0))
})

It("skips fresh droplets", func() {
endpoint := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
endpoint := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)

r.Register("foo", endpoint)
r.Register("bar", endpoint)
Expand All @@ -260,21 +211,21 @@ var _ = Describe("Registry", func() {
Ω(r.NumUris()).To(Equal(2))
Ω(r.NumEndpoints()).To(Equal(1))

time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond)
r.StartPruningCycle()
time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond)

r.Register("foo", endpoint)

r.PruneStaleDroplets()

r.StopPruningCycle()
Ω(r.NumUris()).To(Equal(1))
Ω(r.NumEndpoints()).To(Equal(1))

foundEndpoint, found := r.Lookup("foo")
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpoint))
p := r.Lookup("foo")
Ω(p).ShouldNot(BeNil())
Ω(p.Endpoints("").Next()).To(Equal(endpoint))

_, found = r.Lookup("bar")
Ω(found).To(BeFalse())
p = r.Lookup("bar")
Ω(p).Should(BeNil())
})

It("disables pruning when NATS is unavailable", func() {
Expand All @@ -287,10 +238,9 @@ var _ = Describe("Registry", func() {
Ω(r.NumUris()).To(Equal(4))
Ω(r.NumEndpoints()).To(Equal(2))

time.Sleep(configObj.DropletStaleThreshold + 1*time.Millisecond)

messageBus.OnPing(func() bool { return false })
r.PruneStaleDroplets()
r.StartPruningCycle()
time.Sleep(configObj.PruneStaleDropletsInterval + 10*time.Millisecond)

Ω(r.NumUris()).To(Equal(4))
Ω(r.NumEndpoints()).To(Equal(2))
Expand All @@ -313,25 +263,55 @@ var _ = Describe("Registry", func() {
return false
})

go r.PruneStaleDroplets()
r.StartPruningCycle()
<-barrier

_, ok := r.Lookup("foo")
p := r.Lookup("foo")
barrier <- struct{}{}
Ω(ok).To(BeTrue())
Ω(p).ShouldNot(BeNil())
})
})

It("marshals", func() {
m := &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
}
Context("Varz data", func() {
It("NumUris", func() {
r.Register("bar", barEndpoint)
r.Register("baar", barEndpoint)

Ω(r.NumUris()).To(Equal(2))

r.Register("foo", fooEndpoint)

Ω(r.NumUris()).To(Equal(3))
})

It("NumEndpoints", func() {
r.Register("bar", barEndpoint)
r.Register("baar", barEndpoint)

Ω(r.NumEndpoints()).To(Equal(1))

r.Register("foo", fooEndpoint)

Ω(r.NumEndpoints()).To(Equal(2))
})

It("TimeOfLastUpdate", func() {
start := time.Now()
r.Register("bar", barEndpoint)
t := r.TimeOfLastUpdate()
end := time.Now()

Ω(start.Before(t)).Should(BeTrue())
Ω(end.After(t)).Should(BeTrue())
})
})

It("marshals", func() {
m := route.NewEndpoint("", "192.168.1.1", 1234, "", nil)
r.Register("foo", m)

marshalled, err := json.Marshal(r)
Ω(err).NotTo(HaveOccurred())

Ω(string(marshalled)).To(Equal(`{"foo":["192.168.1.1:1234"]}`))
})
})
26 changes: 15 additions & 11 deletions route/endpoint.go
Expand Up @@ -3,37 +3,41 @@ package route
import (
"encoding/json"
"fmt"
"sync"
)

type Endpoint struct {
sync.Mutex
func NewEndpoint(appId, host string, port uint16, privateInstanceId string,
tags map[string]string) *Endpoint {
return &Endpoint{
ApplicationId: appId,
addr: fmt.Sprintf("%s:%d", host, port),
Tags: tags,
PrivateInstanceId: privateInstanceId,
}
}

type Endpoint struct {
ApplicationId string
Host string
Port uint16
addr string
Tags map[string]string
PrivateInstanceId string
}

func (e *Endpoint) MarshalJSON() ([]byte, error) {
return json.Marshal(e.CanonicalAddr())
return json.Marshal(e.addr)
}

func (e *Endpoint) CanonicalAddr() string {
return fmt.Sprintf("%s:%d", e.Host, e.Port)
return e.addr
}

func (e *Endpoint) ToLogData() interface{} {
return struct {
ApplicationId string
Host string
Port uint16
Addr string
Tags map[string]string
}{
e.ApplicationId,
e.Host,
e.Port,
e.addr,
e.Tags,
}
}
198 changes: 198 additions & 0 deletions route/endpoint_iterator_test.go
@@ -0,0 +1,198 @@
package route_test

import (
"time"
. "github.com/cloudfoundry/gorouter/route"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("EndpointIterator", func() {
var pool *Pool

BeforeEach(func() {
pool = NewPool(2 * time.Minute)
})

Describe("Next", func() {
It("performs round-robin through the endpoints", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
e3 := NewEndpoint("", "1.2.7.8", 1234, "", nil)
endpoints := []*Endpoint{e1, e2, e3}

for _, e := range endpoints {
pool.Put(e)
}

counts := make([]int, len(endpoints))

iter := pool.Endpoints("")

loops := 50
for i := 0; i < len(endpoints)*loops; i += 1 {
n := iter.Next()
for j, e := range endpoints {
if e == n {
counts[j]++
break
}
}
}

for i := 0; i < len(endpoints); i++ {
Ω(counts[i]).To(Equal(loops))
}
})

It("returns nil when no endpoints exist", func() {
iter := pool.Endpoints("")
e := iter.Next()
Ω(e).Should(BeNil())
})

It("finds the initial endpoint by private id", func() {
b := NewEndpoint("", "1.2.3.4", 1235, "b", nil)
pool.Put(NewEndpoint("", "1.2.3.4", 1234, "a", nil))
pool.Put(b)
pool.Put(NewEndpoint("", "1.2.3.4", 1236, "c", nil))
pool.Put(NewEndpoint("", "1.2.3.4", 1237, "d", nil))

for i := 0; i < 10; i++ {
iter := pool.Endpoints(b.PrivateInstanceId)
e := iter.Next()
Ω(e).ShouldNot(BeNil())
Ω(e.PrivateInstanceId).To(Equal(b.PrivateInstanceId))
}
})

It("finds the initial endpoint by canonical addr", func() {
b := NewEndpoint("", "1.2.3.4", 1235, "b", nil)
pool.Put(NewEndpoint("", "1.2.3.4", 1234, "a", nil))
pool.Put(b)
pool.Put(NewEndpoint("", "1.2.3.4", 1236, "c", nil))
pool.Put(NewEndpoint("", "1.2.3.4", 1237, "d", nil))

for i := 0; i < 10; i++ {
iter := pool.Endpoints(b.CanonicalAddr())
e := iter.Next()
Ω(e).ShouldNot(BeNil())
Ω(e.CanonicalAddr()).To(Equal(b.CanonicalAddr()))
}
})

It("finds when there are multiple private ids", func() {
endpointFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil)
endpointBar := NewEndpoint("", "5.6.7.8", 5678, "bar", nil)

pool.Put(endpointFoo)
pool.Put(endpointBar)

iter := pool.Endpoints(endpointFoo.PrivateInstanceId)
foundEndpoint := iter.Next()
Ω(foundEndpoint).ToNot(BeNil())
Ω(foundEndpoint).To(Equal(endpointFoo))

iter = pool.Endpoints(endpointBar.PrivateInstanceId)
foundEndpoint = iter.Next()
Ω(foundEndpoint).ToNot(BeNil())
Ω(foundEndpoint).To(Equal(endpointBar))
})

It("returns the next available endpoint when the initial is not found", func() {
eFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil)
pool.Put(eFoo)

iter := pool.Endpoints("bogus")
e := iter.Next()
Ω(e).ShouldNot(BeNil())
Ω(e).Should(Equal(eFoo))
})

It("finds the correct endpoint when private ids change", func() {
endpointFoo := NewEndpoint("", "1.2.3.4", 1234, "foo", nil)
pool.Put(endpointFoo)

iter := pool.Endpoints(endpointFoo.PrivateInstanceId)
foundEndpoint := iter.Next()
Ω(foundEndpoint).ShouldNot(BeNil())
Ω(foundEndpoint).Should(Equal(endpointFoo))

endpointBar := NewEndpoint("", "1.2.3.4", 1234, "bar", nil)
pool.Put(endpointBar)

iter = pool.Endpoints("foo")
foundEndpoint = iter.Next()
Ω(foundEndpoint).ShouldNot(Equal(endpointFoo))

iter = pool.Endpoints("bar")
Ω(foundEndpoint).Should(Equal(endpointBar))
})
})

Describe("Failed", func() {
It("skips failed endpoints", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

iter := pool.Endpoints("")
n := iter.Next()
Ω(n).ShouldNot(BeNil())

iter.EndpointFailed()

nn1 := iter.Next()
nn2 := iter.Next()
Ω(nn1).ShouldNot(BeNil())
Ω(nn2).ShouldNot(BeNil())
Ω(nn1).ShouldNot(Equal(n))
Ω(nn1).Should(Equal(nn2))
})

It("resets when all endpoints are failed", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

iter := pool.Endpoints("")
n1 := iter.Next()
iter.EndpointFailed()
n2 := iter.Next()
iter.EndpointFailed()
Ω(n1).ShouldNot(Equal(n2))

n1 = iter.Next()
n2 = iter.Next()
Ω(n1).ShouldNot(Equal(n2))
})

It("resets failed endpoints after exceeding failure duration", func() {
pool = NewPool(50 * time.Millisecond)

e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

iter := pool.Endpoints("")
n1 := iter.Next()
n2 := iter.Next()
Ω(n1).ShouldNot(Equal(n2))

iter.EndpointFailed()

n1 = iter.Next()
n2 = iter.Next()
Ω(n1).Should(Equal(n2))

time.Sleep(50 * time.Millisecond)

n1 = iter.Next()
n2 = iter.Next()
Ω(n1).ShouldNot(Equal(n2))
})
})
})
257 changes: 229 additions & 28 deletions route/pool.go
Expand Up @@ -3,65 +3,266 @@ package route
import (
"encoding/json"
"math/rand"
"sync"
"time"
)

type EndpointIterator interface {
Next() *Endpoint
EndpointFailed()
}

type endpointIterator struct {
pool *Pool

initialEndpoint string
lastEndpoint *Endpoint
}

type endpointElem struct {
endpoint *Endpoint
index int
updated time.Time
failedAt *time.Time
}

type Pool struct {
endpoints map[string]*Endpoint
lock sync.Mutex
endpoints []*endpointElem
index map[string]*endpointElem

retryAfterFailure time.Duration
nextIdx int
}

func NewPool() *Pool {
func NewPool(retryAfterFailure time.Duration) *Pool {
return &Pool{
endpoints: make(map[string]*Endpoint),
endpoints: make([]*endpointElem, 0, 1),
index: make(map[string]*endpointElem),
retryAfterFailure: retryAfterFailure,
nextIdx: -1,
}
}

func (p *Pool) Add(endpoint *Endpoint) {
p.endpoints[endpoint.CanonicalAddr()] = endpoint
func (p *Pool) Put(endpoint *Endpoint) bool {
p.lock.Lock()
defer p.lock.Unlock()

e, found := p.index[endpoint.CanonicalAddr()]
if found {
if e.endpoint == endpoint {
return false
}

oldEndpoint := e.endpoint
e.endpoint = endpoint

if oldEndpoint.PrivateInstanceId != endpoint.PrivateInstanceId {
delete(p.index, oldEndpoint.PrivateInstanceId)
p.index[endpoint.PrivateInstanceId] = e
}
} else {
e = &endpointElem{
endpoint: endpoint,
index: len(p.endpoints),
}

p.endpoints = append(p.endpoints, e)

p.index[endpoint.CanonicalAddr()] = e
p.index[endpoint.PrivateInstanceId] = e
}

e.updated = time.Now()

return !found
}

func (p *Pool) Remove(endpoint *Endpoint) {
delete(p.endpoints, endpoint.CanonicalAddr())
func (p *Pool) Remove(endpoint *Endpoint) bool {
var e *endpointElem

p.lock.Lock()
l := len(p.endpoints)
if l > 0 {
e = p.index[endpoint.CanonicalAddr()]
if e != nil {
p.removeEndpoint(e)
}
}
p.lock.Unlock()

return e != nil
}

func (p *Pool) Sample() (*Endpoint, bool) {
if len(p.endpoints) == 0 {
return nil, false
func (p *Pool) removeEndpoint(e *endpointElem) {
i := e.index
es := p.endpoints
last := len(es)
// re-ordering delete
es[last-1], es[i], es = nil, es[last-1], es[:last-1]
if i < last-1 {
es[i].index = i
}
p.endpoints = es

delete(p.index, e.endpoint.CanonicalAddr())
delete(p.index, e.endpoint.PrivateInstanceId)
}

func (p *Pool) Endpoints(initial string) EndpointIterator {
return newEndpointIterator(p, initial)
}

index := rand.Intn(len(p.endpoints))
func (p *Pool) next() *Endpoint {
p.lock.Lock()
defer p.lock.Unlock()

ticker := 0
for _, endpoint := range p.endpoints {
if ticker == index {
return endpoint, true
last := len(p.endpoints)
if last == 0 {
return nil
}

if p.nextIdx == -1 {
p.nextIdx = rand.Intn(last)
} else if p.nextIdx >= last {
p.nextIdx = 0
}

startIdx := p.nextIdx
curIdx := startIdx
for {
e := p.endpoints[curIdx]

curIdx++
if curIdx == last {
curIdx = 0
}

if e.failedAt != nil {
curTime := time.Now()
if curTime.Sub(*e.failedAt) > p.retryAfterFailure {
// exipired failure window
e.failedAt = nil
}
}

ticker += 1
if e.failedAt == nil {
p.nextIdx = curIdx
return e.endpoint
}

if curIdx == startIdx {
// all endpoints are marked failed so reset everything to available
for _, e2 := range p.endpoints {
e2.failedAt = nil
}
}
}
}

func (p *Pool) findById(id string) *Endpoint {
var endpoint *Endpoint
p.lock.Lock()
e := p.index[id]
if e != nil {
endpoint = e.endpoint
}
p.lock.Unlock()

return endpoint
}

func (p *Pool) IsEmpty() bool {
p.lock.Lock()
l := len(p.endpoints)
p.lock.Unlock()

panic("unreachable")
return l == 0
}

func (p *Pool) FindByPrivateInstanceId(id string) (*Endpoint, bool) {
for _, endpoint := range p.endpoints {
if endpoint.PrivateInstanceId == id {
return endpoint, true
func (p *Pool) PruneBefore(t time.Time) {
p.lock.Lock()

last := len(p.endpoints)
for i := 0; i < last; {
e := p.endpoints[i]
if e.updated.Before(t) {
p.removeEndpoint(e)
last--
} else {
i++
}
}

return nil, false
p.lock.Unlock()
}

func (p *Pool) IsEmpty() bool {
return len(p.endpoints) == 0
func (p *Pool) MarkUpdated(t time.Time) {
p.lock.Lock()
for _, e := range p.endpoints {
e.updated = t
}
p.lock.Unlock()
}

func (p *Pool) MarshalJSON() ([]byte, error) {
addresses := []string{}
func (p *Pool) endpointFailed(endpoint *Endpoint) {
p.lock.Lock()
e := p.index[endpoint.CanonicalAddr()]
if e != nil {
e.failed()
}
p.lock.Unlock()
}

func (p *Pool) Each(f func(endpoint *Endpoint)) {
p.lock.Lock()
for _, e := range p.endpoints {
f(e.endpoint)
}
p.lock.Unlock()
}

for addr, _ := range p.endpoints {
addresses = append(addresses, addr)
func (p *Pool) MarshalJSON() ([]byte, error) {
p.lock.Lock()
addresses := make([]string, 0, len(p.endpoints))
for _, e := range p.endpoints {
addresses = append(addresses, e.endpoint.addr)
}
p.lock.Unlock()

return json.Marshal(addresses)
}

func newEndpointIterator(p *Pool, initial string) EndpointIterator {
return &endpointIterator{
pool: p,
initialEndpoint: initial,
}
}

func (i *endpointIterator) Next() *Endpoint {
var e *Endpoint
if i.initialEndpoint != "" {
e = i.pool.findById(i.initialEndpoint)
i.initialEndpoint = ""
}

if e == nil {
e = i.pool.next()
}

i.lastEndpoint = e

return e
}

func (i *endpointIterator) EndpointFailed() {
if i.lastEndpoint != nil {
i.pool.endpointFailed(i.lastEndpoint)
}
}

func (e *endpointElem) failed() {
t := time.Now()
e.failedAt = &t
}
181 changes: 91 additions & 90 deletions route/pool_test.go
Expand Up @@ -4,147 +4,148 @@ import (
. "github.com/cloudfoundry/gorouter/route"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"time"
)

var _ = Describe("Route", func() {
Context("Add", func() {
var _ = Describe("Pool", func() {
var pool *Pool

BeforeEach(func() {
pool = NewPool(2 * time.Minute)
})

Context("Put", func() {
It("adds endpoints", func() {
pool := NewPool()
endpoint := &Endpoint{}

pool.Add(endpoint)
foundEndpoint, found := pool.Sample()
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpoint))
b := pool.Put(endpoint)
Ω(b).Should(BeTrue())
})

It("handles duplicate endpoints", func() {
pool := NewPool()

endpoint := &Endpoint{}

pool.Add(endpoint)
pool.Add(endpoint)

foundEndpoint, found := pool.Sample()
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpoint))

pool.Remove(endpoint)

_, found = pool.Sample()
Ω(found).To(BeFalse())
pool.Put(endpoint)
b := pool.Put(endpoint)
Ω(b).Should(BeFalse())
})

It("handles equivalent (duplicate) endpoints", func() {
pool := NewPool()

endpoint1 := &Endpoint{Host: "1.2.3.4", Port: 5678}
endpoint2 := &Endpoint{Host: "1.2.3.4", Port: 5678}
endpoint1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
endpoint2 := NewEndpoint("", "1.2.3.4", 5678, "", nil)

pool.Add(endpoint1)
pool.Add(endpoint2)

_, found := pool.Sample()
Ω(found).To(BeTrue())

pool.Remove(endpoint1)

_, found = pool.Sample()
Ω(found).To(BeFalse())
pool.Put(endpoint1)
Ω(pool.Put(endpoint2)).Should(BeFalse())
})
})

Context("Remove", func() {
It("removes endpoints", func() {
pool := NewPool()

endpoint := &Endpoint{}
pool.Put(endpoint)

pool.Add(endpoint)

foundEndpoint, found := pool.Sample()
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpoint))
b := pool.Remove(endpoint)
Ω(b).Should(BeTrue())
Ω(pool.IsEmpty()).Should(BeTrue())
})

pool.Remove(endpoint)
It("fails to remove an endpoint that doesn't exist", func() {
endpoint := &Endpoint{}

_, found = pool.Sample()
Ω(found).To(BeFalse())
b := pool.Remove(endpoint)
Ω(b).Should(BeFalse())
})

})

Context("IsEmpty", func() {
It("starts empty", func() {
Ω(NewPool().IsEmpty()).To(BeTrue())
Ω(pool.IsEmpty()).To(BeTrue())
})

It("empty after removing everything", func() {
pool := NewPool()

It("not empty after adding an endpoint", func() {
endpoint := &Endpoint{}
pool.Put(endpoint)

pool.Add(endpoint)

Ω(pool.IsEmpty()).To(BeFalse())
Ω(pool.IsEmpty()).Should(BeFalse())
})

It("is empty after removing everything", func() {
endpoint := &Endpoint{}
pool.Put(endpoint)
pool.Remove(endpoint)

Ω(pool.IsEmpty()).To(BeTrue())
})
})

It("finds by private instance id", func() {
pool := NewPool()
Context("PruneBefore", func() {
It("prunes endpoints that haven't been updated", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

endpointFoo := &Endpoint{Host: "1.2.3.4", Port: 1234, PrivateInstanceId: "foo"}
endpointBar := &Endpoint{Host: "5.6.7.8", Port: 5678, PrivateInstanceId: "bar"}

pool.Add(endpointFoo)
pool.Add(endpointBar)
t := time.Now().Add(1 * time.Second)
pool.PruneBefore(t)
Ω(pool.IsEmpty()).Should(BeTrue())
})

foundEndpoint, found := pool.FindByPrivateInstanceId("foo")
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpointFoo))
It("does not prune updated endpoints", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

foundEndpoint, found = pool.FindByPrivateInstanceId("bar")
Ω(found).To(BeTrue())
Ω(foundEndpoint).To(Equal(endpointBar))
t := time.Now().Add(-1 * time.Second)
pool.PruneBefore(t)
Ω(pool.IsEmpty()).Should(BeFalse())

_, found = pool.FindByPrivateInstanceId("quux")
Ω(found).To(BeFalse())
iter := pool.Endpoints("")
n1 := iter.Next()
n2 := iter.Next()
Ω(n1).ShouldNot(Equal(n2))
})
})

It("Sample is randomish", func() {
pool := NewPool()

endpoint1 := &Endpoint{Host: "1.2.3.4", Port: 5678}
endpoint2 := &Endpoint{Host: "5.6.7.8", Port: 1234}
Context("MarkUpdated", func() {
It("updates all endpoints", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)

pool.Add(endpoint1)
pool.Add(endpoint2)
pool.Put(e1)

var occurrences1, occurrences2 int
t := time.Time{}.Add(1 * time.Second)
pool.PruneBefore(t)
Ω(pool.IsEmpty()).Should(BeFalse())

for i := 0; i < 200; i += 1 {
foundEndpoint, _ := pool.Sample()
if foundEndpoint == endpoint1 {
occurrences1 += 1
} else {
occurrences2 += 1
}
}
pool.MarkUpdated(t)
pool.PruneBefore(t)
Ω(pool.IsEmpty()).Should(BeFalse())

Ω(occurrences1).ToNot(BeZero())
Ω(occurrences2).ToNot(BeZero())
pool.PruneBefore(t.Add(1 * time.Microsecond))
Ω(pool.IsEmpty()).Should(BeTrue())
})
})

// they should be arbitrarily close
Ω(occurrences1 - occurrences2).To(BeNumerically("~", 0, 50))
Context("Each", func() {
It("applies a function to each endpoint", func() {
e1 := NewEndpoint("", "1.2.3.4", 5678, "", nil)
e2 := NewEndpoint("", "5.6.7.8", 1234, "", nil)
pool.Put(e1)
pool.Put(e2)

endpoints := make(map[string]*Endpoint)
pool.Each(func(e *Endpoint) {
endpoints[e.CanonicalAddr()] = e
})
Ω(endpoints).Should(HaveLen(2))
Ω(endpoints[e1.CanonicalAddr()]).Should(Equal(e1))
Ω(endpoints[e2.CanonicalAddr()]).Should(Equal(e2))
})
})

It("marshals json", func() {
pool := NewPool()

pool.Add(&Endpoint{Host: "1.2.3.4", Port: 5678})
e := NewEndpoint("", "1.2.3.4", 5678, "", nil)
pool.Put(e)

json, err := pool.MarshalJSON()
Ω(err).ToNot(HaveOccurred())
Expand Down
18 changes: 12 additions & 6 deletions router/helper_test.go
Expand Up @@ -8,33 +8,39 @@ import (
"time"
)

func waitMsgReceived(registry *registry.CFRegistry, app *test.TestApp, expectedToBeFound bool, timeout time.Duration) bool {
func waitMsgReceived(registry *registry.RouteRegistry, app *test.TestApp, expectedToBeFound bool, timeout time.Duration) bool {
interval := time.Millisecond * 50
repetitions := int(timeout / interval)

for j := 0; j < repetitions; j++ {
if j > 0 {
time.Sleep(interval)
}

received := true
for _, url := range app.Urls() {
_, ok := registry.Lookup(url)
if ok != expectedToBeFound {
pool := registry.Lookup(url)
if expectedToBeFound && pool == nil {
received = false
break
} else if !expectedToBeFound && pool != nil {
received = false
break
}
}
if received {
return true
}
time.Sleep(interval)
}

return false
}

func waitAppRegistered(registry *registry.CFRegistry, app *test.TestApp, timeout time.Duration) bool {
func waitAppRegistered(registry *registry.RouteRegistry, app *test.TestApp, timeout time.Duration) bool {
return waitMsgReceived(registry, app, true, timeout)
}

func waitAppUnregistered(registry *registry.CFRegistry, app *test.TestApp, timeout time.Duration) bool {
func waitAppUnregistered(registry *registry.RouteRegistry, app *test.TestApp, timeout time.Duration) bool {
return waitMsgReceived(registry, app, false, timeout)
}

Expand Down
10 changes: 2 additions & 8 deletions router/registry_message.go
Expand Up @@ -14,12 +14,6 @@ type registryMessage struct {
PrivateInstanceId string `json:"private_instance_id"`
}

func (registryMessage *registryMessage) makeEndpoint() *route.Endpoint {
return &route.Endpoint{
Host: registryMessage.Host,
Port: registryMessage.Port,
ApplicationId: registryMessage.App,
Tags: registryMessage.Tags,
PrivateInstanceId: registryMessage.PrivateInstanceId,
}
func (rm *registryMessage) makeEndpoint() *route.Endpoint {
return route.NewEndpoint(rm.App, rm.Host, rm.Port, rm.PrivateInstanceId, rm.Tags)
}
10 changes: 5 additions & 5 deletions router/router.go
Expand Up @@ -26,7 +26,7 @@ type Router struct {
config *config.Config
proxy proxy.Proxy
mbusClient *yagnats.Client
registry *registry.CFRegistry
registry *registry.RouteRegistry
varz varz.Varz
component *vcap.VcapComponent

Expand All @@ -35,7 +35,7 @@ type Router struct {
logger *steno.Logger
}

func NewRouter(cfg *config.Config, p proxy.Proxy, mbusClient *yagnats.Client, r *registry.CFRegistry, v varz.Varz,
func NewRouter(cfg *config.Config, p proxy.Proxy, mbusClient *yagnats.Client, r *registry.RouteRegistry, v varz.Varz,
logCounter *vcap.LogCounter) (*Router, error) {

var host string
Expand Down Expand Up @@ -257,9 +257,9 @@ func (r *Router) greetMessage() ([]byte, error) {
}

d := vcap.RouterStart{
uuid,
[]string{host},
r.config.StartResponseDelayIntervalInSeconds,
Id: uuid,
Hosts: []string{host},
MinimumRegisterIntervalInSeconds: r.config.StartResponseDelayIntervalInSeconds,
}

return json.Marshal(d)
Expand Down
4 changes: 2 additions & 2 deletions router/router_drain_test.go
Expand Up @@ -25,7 +25,7 @@ var _ = Describe("Router", func() {
var config *cfg.Config

var mbusClient *yagnats.Client
var registry *rregistry.CFRegistry
var registry *rregistry.RouteRegistry
var varz vvarz.Varz
var router *Router
var natsPort uint16
Expand All @@ -42,7 +42,7 @@ var _ = Describe("Router", func() {
config.EndpointTimeout = 5 * time.Second

mbusClient = natsRunner.MessageBus.(*yagnats.Client)
registry = rregistry.NewCFRegistry(config, mbusClient)
registry = rregistry.NewRouteRegistry(config, mbusClient)
varz = vvarz.NewVarz(registry)
logcounter := vcap.NewLogCounter()
proxy := proxy.NewProxy(proxy.ProxyArgs{
Expand Down
11 changes: 5 additions & 6 deletions router/router_test.go
Expand Up @@ -33,7 +33,7 @@ var _ = Describe("Router", func() {
var config *cfg.Config

var mbusClient *yagnats.Client
var registry *rregistry.CFRegistry
var registry *rregistry.RouteRegistry
var varz vvarz.Varz
var router *Router

Expand All @@ -48,7 +48,7 @@ var _ = Describe("Router", func() {
config = test_util.SpecConfig(natsPort, statusPort, proxyPort)

mbusClient = natsRunner.MessageBus.(*yagnats.Client)
registry = rregistry.NewCFRegistry(config, mbusClient)
registry = rregistry.NewRouteRegistry(config, mbusClient)
varz = vvarz.NewVarz(registry)
logcounter := vcap.NewLogCounter()
proxy := proxy.NewProxy(proxy.ProxyArgs{
Expand Down Expand Up @@ -157,15 +157,14 @@ var _ = Describe("Router", func() {
app1.Listen()
Ω(waitAppRegistered(registry, app1, time.Second*1)).To(BeTrue())

time.Sleep(2 * time.Second)
time.Sleep(100 * time.Millisecond)
initialUpdateTime := fetchRecursively(readVarz(varz), "ms_since_last_registry_update").(float64)
// initialUpdateTime should be roughly 2 seconds.

app2 := test.NewGreetApp([]route.Uri{"test2.vcap.me"}, config.Port, mbusClient, nil)
app2.Listen()
Ω(waitAppRegistered(registry, app2, time.Second*1)).To(BeTrue())

// updateTime should be roughly 0 seconds
// updateTime should be after initial update time
updateTime := fetchRecursively(readVarz(varz), "ms_since_last_registry_update").(float64)
Ω(updateTime).To(BeNumerically("<", initialUpdateTime))
})
Expand Down Expand Up @@ -473,7 +472,7 @@ func fetchRecursively(x interface{}, s ...string) interface{} {
return x
}

func verify_health_z(host string, r *rregistry.CFRegistry) {
func verify_health_z(host string, r *rregistry.RouteRegistry) {
var req *http.Request
path := "/healthz"

Expand Down
11 changes: 11 additions & 0 deletions scripts/test
Expand Up @@ -2,8 +2,19 @@

set -e -x -u


export DROPSONDE_ORIGIN="gorouter_test/0"

function printStatus {
if [ $? -eq 0 ]; then
echo -e "\nSWEET SUITE SUCCESS"
else
echo -e "\nSUITE FAILURE"
fi
}

trap printStatus EXIT

. $(dirname $0)/gorequired

#Download & Install gnatsd into GOPATH (or use pre-installed version)
Expand Down
20 changes: 10 additions & 10 deletions varz/varz.go
Expand Up @@ -170,13 +170,13 @@ type Varz interface {

type RealVarz struct {
sync.Mutex
r *registry.CFRegistry
r *registry.RouteRegistry
activeApps *stats.ActiveApps
topApps *stats.TopApps
varz
}

func NewVarz(r *registry.CFRegistry) Varz {
func NewVarz(r *registry.RouteRegistry) Varz {
x := &RealVarz{r: r}

x.activeApps = stats.NewActiveApps()
Expand Down Expand Up @@ -227,18 +227,16 @@ func (x *RealVarz) ActiveApps() *stats.ActiveApps {
return x.activeApps
}

func (x *RealVarz) CaptureBadRequest(req *http.Request) {
func (x *RealVarz) CaptureBadRequest(*http.Request) {
x.Lock()
defer x.Unlock()

x.BadRequests++
x.Unlock()
}

func (x *RealVarz) CaptureBadGateway(req *http.Request) {
func (x *RealVarz) CaptureBadGateway(*http.Request) {
x.Lock()
defer x.Unlock()

x.BadGateways++
x.Unlock()
}

func (x *RealVarz) CaptureAppStats(b *route.Endpoint, t time.Time) {
Expand All @@ -250,7 +248,6 @@ func (x *RealVarz) CaptureAppStats(b *route.Endpoint, t time.Time) {

func (x *RealVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) {
x.Lock()
defer x.Unlock()

var t string
var ok bool
Expand All @@ -261,11 +258,12 @@ func (x *RealVarz) CaptureRoutingRequest(b *route.Endpoint, req *http.Request) {
}

x.varz.All.CaptureRequest()

x.Unlock()
}

func (x *RealVarz) CaptureRoutingResponse(endpoint *route.Endpoint, response *http.Response, startedAt time.Time, duration time.Duration) {
x.Lock()
defer x.Unlock()

var tags string
var ok bool
Expand All @@ -277,6 +275,8 @@ func (x *RealVarz) CaptureRoutingResponse(endpoint *route.Endpoint, response *ht

x.CaptureAppStats(endpoint, startedAt)
x.varz.All.CaptureResponse(response, duration)

x.Unlock()
}

func transform(x interface{}, y map[string]interface{}) error {
Expand Down
12 changes: 3 additions & 9 deletions varz/varz_test.go
Expand Up @@ -17,10 +17,10 @@ import (

var _ = Describe("Varz", func() {
var Varz Varz
var Registry *registry.CFRegistry
var Registry *registry.RouteRegistry

BeforeEach(func() {
Registry = registry.NewCFRegistry(config.DefaultConfig(), fakeyagnats.New())
Registry = registry.NewRouteRegistry(config.DefaultConfig(), fakeyagnats.New())
Varz = NewVarz(Registry)
})

Expand Down Expand Up @@ -72,13 +72,7 @@ var _ = Describe("Varz", func() {
It("has urls", func() {
Ω(findValue(Varz, "urls")).To(Equal(float64(0)))

var fooReg = &route.Endpoint{
Host: "192.168.1.1",
Port: 1234,
Tags: map[string]string{},

ApplicationId: "12345",
}
var fooReg = route.NewEndpoint("12345", "192.168.1.1", 1234, "", map[string]string{})

// Add a route
Registry.Register("foo.vcap.me", fooReg)
Expand Down