diff --git a/docs/03-how-to-add-new-route-option.md b/docs/03-how-to-add-new-route-option.md index 39ce25447..308213fcc 100644 --- a/docs/03-how-to-add-new-route-option.md +++ b/docs/03-how-to-add-new-route-option.md @@ -22,6 +22,11 @@ applications: - route: example2.com options: loadbalancing: least-connection + - route: example3.com + options: + loadbalancing: hash + hash_header: tenant-id + hash_balance: 1.25 ``` **NOTE**: In the implementation, the `options` property of a route represents per-route features. diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index 88cfd20a5..84252262f 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -127,6 +127,20 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + if reqInfo.RoutePool.HashRoutingProperties == nil { + rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host())) + + } else { + headerName := reqInfo.RoutePool.HashRoutingProperties.Header + headerValue := request.Header.Get(headerName) + if headerValue != "" { + iter.(*route.HashBased).HeaderValue = headerValue + } else { + iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + } + } + } // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. 
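The hunk above is the only request-path change: when a pool is configured for hash-based routing, the configured header is read from the request, and a missing header demotes the request to the Gorouter-wide default algorithm instead of failing it. A minimal standalone sketch of that decision follows; the `Pool` struct and `selectIterator` function here are illustrative stand-ins, not the actual gorouter types:

```go
package main

import (
	"fmt"
	"net/http"
)

// Pool is an illustrative stand-in for gorouter's route.EndpointPool;
// only the two fields the decision needs are modeled here.
type Pool struct {
	Algorithm  string // per-route algorithm, e.g. "hash"
	HashHeader string // configured hash header, e.g. "tenant-id"
}

// selectIterator mirrors the control flow of the RoundTrip hunk above:
// use hash-based selection when the configured header is present on the
// request, otherwise fall back to the default load-balancing algorithm.
func selectIterator(p *Pool, req *http.Request, defaultAlgo string) string {
	if p.Algorithm == "hash" {
		if v := req.Header.Get(p.HashHeader); v != "" {
			return "hash-based iterator keyed on " + v
		}
		// Header missing: the request is still served, just load-balanced differently.
		return "fallback iterator using " + defaultAlgo
	}
	return "iterator for " + p.Algorithm
}

func main() {
	p := &Pool{Algorithm: "hash", HashHeader: "tenant-id"}
	req, _ := http.NewRequest("GET", "http://example3.com", nil)
	req.Header.Set("tenant-id", "tenant-42")
	fmt.Println(selectIterator(p, req, "round-robin")) // hash-based iterator keyed on tenant-42

	req.Header.Del("tenant-id")
	fmt.Println(selectIterator(p, req, "round-robin")) // fallback iterator using round-robin
}
```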
diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
index 9d270867c..6abe4d218 100644
--- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
+++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math/rand"
 	"net"
 	"net/http"
 	"net/http/httptest"
@@ -1700,6 +1701,167 @@ var _ = Describe("ProxyRoundTripper", func() {
 		})
 	})
 
+	Context("when load-balancing strategy is set to hash-based routing", func() {
+		JustBeforeEach(func() {
+			for i := 1; i <= 3; i++ {
+				endpoint = route.NewEndpoint(&route.EndpointOpts{
+					AppId:                  fmt.Sprintf("appID%d", i),
+					Host:                   fmt.Sprintf("%d.%d.%d.%d", i, i, i, i),
+					Port:                   9090,
+					PrivateInstanceId:      fmt.Sprintf("instanceID%d", i),
+					PrivateInstanceIndex:   fmt.Sprintf("%d", i),
+					AvailabilityZone:       AZ,
+					LoadBalancingAlgorithm: config.LOAD_BALANCE_HB,
+					HashHeaderName:         "X-Hash",
+				})
+
+				_ = routePool.Put(endpoint)
+				Expect(routePool.HashLookupTable).ToNot(BeNil())
+			}
+		})
+
+		It("routes requests with the same hash header value to the same endpoint", func() {
+			req.Header.Set("X-Hash", "value")
+			reqInfo, err := handlers.ContextRequestInfo(req)
+			Expect(err).ToNot(HaveOccurred())
+			reqInfo.RoutePool = routePool
+
+			var selectedEndpoints []*route.Endpoint
+
+			// Make multiple requests with the same hash value
+			for i := 0; i < 5; i++ {
+				_, err = proxyRoundTripper.RoundTrip(req)
+				Expect(err).NotTo(HaveOccurred())
+				selectedEndpoints = append(selectedEndpoints, reqInfo.RouteEndpoint)
+			}
+
+			// All requests should go to the same endpoint
+			firstEndpoint := selectedEndpoints[0]
+			for _, ep := range selectedEndpoints[1:] {
+				Expect(ep.PrivateInstanceId).To(Equal(firstEndpoint.PrivateInstanceId))
+			}
+		})
+
+		It("routes requests with different hash header values to potentially different endpoints", func() {
+			reqInfo, err := handlers.ContextRequestInfo(req)
+			Expect(err).ToNot(HaveOccurred())
+			reqInfo.RoutePool = routePool
+
+			endpointDistribution := make(map[string]int)
+
+			// Make requests with different hash values
+			for i := 0; i < 10; i++ {
+				req.Header.Set("X-Hash", fmt.Sprintf("value-%d", i))
+				_, err = proxyRoundTripper.RoundTrip(req)
+				Expect(err).NotTo(HaveOccurred())
+				endpointDistribution[reqInfo.RouteEndpoint.PrivateInstanceId]++
+			}
+
+			// Should distribute across multiple endpoints (not all to one)
+			Expect(len(endpointDistribution)).To(BeNumerically(">", 1))
+		})
+
+		It("falls back to the default load balancing algorithm when the hash header is missing", func() {
+			reqInfo, err := handlers.ContextRequestInfo(req)
+			Expect(err).ToNot(HaveOccurred())
+
+			reqInfo.RoutePool = routePool
+
+			_, err = proxyRoundTripper.RoundTrip(req)
+			Expect(err).NotTo(HaveOccurred())
+
+			infoLogs := logger.Lines(zap.InfoLevel)
+			count := 0
+			for i := 0; i < len(infoLogs); i++ {
+				if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") {
+					count++
+				}
+			}
+			Expect(count).To(Equal(1))
+			// Verify it still selects an endpoint
+			Expect(reqInfo.RouteEndpoint).ToNot(BeNil())
+		})
+
+		Context("when sticky session cookies (JSESSIONID and VCAP_ID) are on the request", func() {
+			var (
+				sessionCookie *http.Cookie
+				cookies       []*http.Cookie
+			)
+
+			JustBeforeEach(func() {
+				sessionCookie = &http.Cookie{
+					Name: StickyCookieKey, // JSESSIONID
+				}
+				transport.RoundTripStub = func(req *http.Request) (*http.Response, error) {
+					resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)}
+					// Attach the same JSESSIONID to the response if it exists on the request
+					if len(req.Cookies()) > 0 {
+						for _, cookie := range req.Cookies() {
+							if cookie.Name == StickyCookieKey {
+								resp.Header.Add(round_tripper.CookieHeader, cookie.String())
+								return resp, nil
+							}
+						}
+					}
+
+					sessionCookie.Value, _ = uuid.GenerateUUID()
+					resp.Header.Add(round_tripper.CookieHeader, sessionCookie.String())
+					return resp, nil
+				}
+				resp, err := proxyRoundTripper.RoundTrip(req)
+				Expect(err).ToNot(HaveOccurred())
+
+				cookies = resp.Cookies()
+				Expect(cookies).To(HaveLen(2))
+			})
+
+			Context("when there is a JSESSIONID and __VCAP_ID__ set on the request", func() {
+				It("will always route to the instance specified with the __VCAP_ID__ cookie", func() {
+					// Generate 20 random values for the hash header, so the chance that all
+					// land on instanceID1 by accident is (1/3)^20
+					for i := 0; i < 20; i++ {
+						randomStr := make([]byte, 8)
+						for j := range randomStr {
+							randomStr[j] = byte('a' + rand.Intn(26))
+						}
+
+						req.Header.Set("X-Hash", string(randomStr))
+						reqInfo, err := handlers.ContextRequestInfo(req)
+						req.AddCookie(&http.Cookie{Name: round_tripper.VcapCookieId, Value: "instanceID1"})
+						req.AddCookie(&http.Cookie{Name: StickyCookieKey, Value: "abc"})
+
+						Expect(err).ToNot(HaveOccurred())
+						reqInfo.RoutePool = routePool
+
+						resp, err := proxyRoundTripper.RoundTrip(req)
+						Expect(err).ToNot(HaveOccurred())
+
+						newCookies := resp.Cookies()
+						Expect(newCookies).To(HaveLen(2))
+
+						for _, cookie := range newCookies {
+							Expect(cookie.Name).To(SatisfyAny(
+								Equal(StickyCookieKey),
+								Equal(round_tripper.VcapCookieId),
+							))
+							if cookie.Name == StickyCookieKey {
+								Expect(cookie.Value).To(Equal("abc"))
+							} else {
+								Expect(cookie.Value).To(Equal("instanceID1"))
+							}
+						}
+					}
+				})
+			})
+		})
+	})
+
 		Context("when endpoint timeout is not 0", func() {
 			var reqCh chan *http.Request
 			BeforeEach(func() {
diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go
new file mode 100644
index 000000000..db7aaac1a
--- /dev/null
+++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go
@@ -0,0 +1,250 @@
+package route
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"sync"
+
+	log "code.cloudfoundry.org/gorouter/logger"
+)
+
+// HashBased is a load-balancing algorithm that distributes requests based on the hash of a specific header value.
+// A sticky session cookie takes precedence over hash-based routing: the request is routed to the instance stored in the cookie.
+// If a request does not contain the header configured for the hash-based route option, the default load-balancing algorithm is used.
+type HashBased struct {
+	lock *sync.Mutex
+
+	logger               *slog.Logger
+	pool                 *EndpointPool
+	lastEndpoint         *Endpoint
+	lastLookupTableIndex uint64
+
+	stickyEndpointID string
+	mustBeSticky     bool
+
+	HeaderValue string
+}
+
+// NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value.
+// The global properties locallyOptimistic and localAvailabilityZone are ignored when using hash-based routing.
+func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { + return &HashBased{ + logger: logger, + pool: p, + lock: &sync.Mutex{}, + stickyEndpointID: initial, + mustBeSticky: mustBeSticky, + } +} + +// Next selects the next endpoint based on the hash of the header value. +// If a sticky session endpoint is available and not overloaded, it will be returned. +// If the request must be sticky and the sticky endpoint is unavailable or overloaded, nil will be returned. +// If no sticky session is present, the endpoint will be selected based on the hash of the header value. +// It returns the same endpoint for the same header value consistently. +// If the hash lookup fails or the endpoint is not found, nil will be returned. +func (h *HashBased) Next(attempt int) *Endpoint { + h.lock.Lock() + defer h.lock.Unlock() + + endpoint := h.findEndpointIfStickySession() + if endpoint == nil && h.mustBeSticky { + return nil + } + + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint + } + + if len(h.pool.endpoints) == 0 { + h.logger.Warn("hash-based-routing-pool-empty", slog.String("host", h.pool.host)) + return nil + } + + endpoint = h.getSingleEndpoint() + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint + } + + if h.pool.HashLookupTable == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) + return nil + } + + if attempt == 0 || h.lastLookupTableIndex == 0 { + initialLookupTableIndex, _, err := h.pool.HashLookupTable.GetInstanceForHashHeader(h.HeaderValue) + + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + endpoint = h.findEndpoint(initialLookupTableIndex, attempt) + } else { + // On retries, start looking from the next index in the lookup table + nextIndex := (h.lastLookupTableIndex + 1) % h.pool.HashLookupTable.GetLookupTableSize() + endpoint = h.findEndpoint(nextIndex, attempt) + } + + if endpoint != nil { + h.lastEndpoint = endpoint + } + return endpoint +} + +func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { + // Ensure we don't exceed the lookup table size + lookupTableSize := h.pool.HashLookupTable.GetLookupTableSize() + + // Normalize index + currentIndex := index % lookupTableSize + // Keep track of endpoints already visited, to avoid visiting them twice + visitedEndpoints := make(map[string]bool) + + numberOfEndpoints := len(h.pool.HashLookupTable.GetEndpointList()) + + lastEndpointPrivateId := "" + if attempt > 0 && h.lastEndpoint != nil { + lastEndpointPrivateId = h.lastEndpoint.PrivateInstanceId + } + + // abort when we have visited all available endpoints unsuccessfully + for len(visitedEndpoints) < numberOfEndpoints { + id := h.pool.HashLookupTable.GetEndpointId(currentIndex) + + if visitedEndpoints[id] || id == lastEndpointPrivateId { + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + visitedEndpoints[id] = true + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + + lastEndpointPrivateId = id + + if h.pool.HashRoutingProperties.BalanceFactor <= 0 || 
!h.isImbalancedOrOverloaded(endpointElem) {
+			h.lastLookupTableIndex = currentIndex
+			return endpointElem.endpoint
+		}
+
+		currentIndex = (currentIndex + 1) % lookupTableSize
+	}
+	// All endpoints checked and overloaded or not found
+	h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("All endpoints are overloaded")))
+	return nil
+}
+
+func (h *HashBased) isImbalancedOrOverloaded(e *endpointElem) bool {
+	endpoint := e.endpoint
+	return h.IsImbalancedOrOverloaded(endpoint, e.isOverloaded())
+}
+
+func (h *HashBased) IsImbalancedOrOverloaded(endpoint *Endpoint, isEndpointOverloaded bool) bool {
+	avgNumberOfInFlightRequests := h.CalculateAverageLoad()
+	currentInFlightRequestCount := endpoint.Stats.NumberConnections.Count()
+	balanceFactor := h.pool.HashRoutingProperties.BalanceFactor
+
+	if isEndpointOverloaded {
+		h.logger.Debug("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount))
+		return true
+	}
+
+	// Guard against division by zero: with no in-flight requests anywhere, the endpoint cannot be imbalanced
+	if avgNumberOfInFlightRequests == 0 {
+		return false
+	}
+	if float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor {
+		h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount), slog.Float64("average-load", avgNumberOfInFlightRequests))
+		return true
+	}
+	return false
+}
+
+// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available.
+// If the sticky session endpoint is overloaded, returns nil.
+func (h *HashBased) findEndpointIfStickySession() *Endpoint {
+	var e *endpointElem
+	if h.stickyEndpointID != "" {
+		e = h.pool.findById(h.stickyEndpointID)
+		if e != nil && e.isOverloaded() {
+			if h.mustBeSticky {
+				if h.logger.Enabled(context.Background(), slog.LevelDebug) {
+					h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...)
+				}
+				return nil
+			}
+			e = nil
+		}
+
+		if e == nil && h.mustBeSticky {
+			h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID))
+			return nil
+		}
+
+		if !h.mustBeSticky {
+			h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID))
+			h.stickyEndpointID = ""
+		}
+	}
+
+	if e != nil {
+		e.RLock()
+		defer e.RUnlock()
+		return e.endpoint
+	}
+	return nil
+}
+
+// EndpointFailed notifies the endpoint pool that the last selected endpoint has failed.
+func (h *HashBased) EndpointFailed(err error) {
+	if h.lastEndpoint != nil {
+		h.pool.EndpointFailed(h.lastEndpoint, err)
+	}
+}
+
+// PreRequest increments the in-flight request count for the selected endpoint as seen by the current Gorouter.
+func (h *HashBased) PreRequest(e *Endpoint) {
+	e.Stats.NumberConnections.Increment()
+}
+
+// PostRequest decrements the in-flight request count for the selected endpoint as seen by the current Gorouter.
+func (h *HashBased) PostRequest(e *Endpoint) {
+	e.Stats.NumberConnections.Decrement()
+}
+
+// CalculateAverageLoad computes the average number of in-flight requests across all endpoints in the pool.
+func (h *HashBased) CalculateAverageLoad() float64 {
+	if len(h.pool.endpoints) == 0 {
+		return 0
+	}
+
+	var currentInFlightRequestCount int64
+	for _, endpointElem := range h.pool.endpoints {
+		endpointElem.RLock()
+		currentInFlightRequestCount += endpointElem.endpoint.Stats.NumberConnections.Count()
+		endpointElem.RUnlock()
+	}
+
+	return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints))
+}
+
+func (h *HashBased) getSingleEndpoint() *Endpoint {
+	if len(h.pool.endpoints) == 1 {
+		e := h.pool.endpoints[0]
+		if e.isOverloaded() {
+			return nil
+		}
+
+		return e.endpoint
+	}
+	return nil
+}
diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go
new file mode 100644
index 000000000..c7fcd79a3
--- /dev/null
+++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go
@@ -0,0 +1,475 @@
+package route_test
+
+import (
+	"hash/fnv"
+	"time"
+
+	"code.cloudfoundry.org/gorouter/config"
+	"code.cloudfoundry.org/gorouter/route"
+	"code.cloudfoundry.org/gorouter/test_util"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+	"github.com/onsi/gomega/gbytes"
+)
+
+var _ = Describe("HashBased", func() {
+	var (
+		pool   *route.EndpointPool
+		logger *test_util.TestLogger
+	)
+
+	BeforeEach(func() {
+		logger = test_util.NewTestLogger("test")
+		pool = route.NewPool(&route.PoolOpts{
+			Logger:                 logger.Logger,
+			RetryAfterFailure:      2 * time.Minute,
+			Host:                   "",
+			ContextPath:            "",
+			MaxConnsPerBackend:     500,
+			LoadBalancingAlgorithm: config.LOAD_BALANCE_HB,
+			HashHeader:             "tenant-id",
+		})
+	})
+
+	Describe("Next", func() {
+
+		Context("when pool is empty", func() {
+			It("does not select an endpoint", func() {
+				iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
+				Expect(iter.Next(0)).To(BeNil())
+			})
+		})
+
+		Context("when pool has endpoints", func() {
+			var (
+				endpoints []*route.Endpoint
+			)
+			BeforeEach(func() {
+				e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"})
+				e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"})
+				endpoints = []*route.Endpoint{e1, e2}
+				for _, e := range endpoints {
+					pool.Put(e)
+				}
+			})
+			It("returns the same endpoint for the same header value", func() {
+				iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
+				iter.(*route.HashBased).HeaderValue = "tenant-1"
+				first := iter.Next(0)
+				second := iter.Next(0)
+				Expect(first).NotTo(BeNil())
+				Expect(second).NotTo(BeNil())
+				Expect(first).To(Equal(second))
+			})
+		})
+
+		Context("when an endpoint is overloaded", func() {
+			var (
+				endpoints []*route.Endpoint
+				e1        *route.Endpoint
+				e2        *route.Endpoint
+				e3        *route.Endpoint
+			)
+			It("returns the next endpoint for the same header value when the balance factor is set", func() {
+				e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"})
+				e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"})
+				e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"})
+				endpoints = []*route.Endpoint{e1, e2, e3}
+				for _, e := range endpoints {
+					pool.Put(e)
+				}
+				iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
+				iter.(*route.HashBased).HeaderValue = "tenant-1"
+				first := iter.Next(0)
+				Expect(iter.Next(0)).To(Equal(first))
+				for i := 0; i < 6; i++ {
+					iter.PreRequest(first)
+				}
+				second := iter.Next(0)
+				Expect(second).NotTo(Equal(first))
+			})
+			It("returns the same overloaded endpoint for the same header value when the balance factor is not set", func() {
+				e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID1"})
+				e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID2"})
+				e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID3"})
+				endpoints = []*route.Endpoint{e1, e2, e3}
+				for _, e := range endpoints {
+					pool.Put(e)
+				}
+				iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
+				iter.(*route.HashBased).HeaderValue = "tenant-1"
+				first := iter.Next(0)
+				Expect(iter.Next(0)).To(Equal(first))
+				for i := 0; i < 6; i++ {
+					iter.PreRequest(first)
+				}
+				second := iter.Next(0)
+				Expect(second).To(Equal(first))
+			})
+		})
+
+		Context("with retries", func() {
+			var (
+				endpoints         []*route.Endpoint
+				e1                *route.Endpoint
+				e2                *route.Endpoint
+				e3                *route.Endpoint
+				e4                *route.Endpoint
+				MaglevLookupTable = []int{2, 2, 1, 0, 1, 0, 0, 0, 2, 0, 1, 3, 1, 0, 1, 0, 3, 0, 3, 0, 0, 0, 1, 0, 1, 2, 2, 0, 3, 2, 3, 0, 1, 0, 1, 0, 3, 3, 2, 0, 3, 1, 2, 0, 3, 0, 1, 0, 2, 3, 2, 3, 2, 0, 1, 2, 1, 0, 3, 2, 2, 1, 1, 2, 1, 3, 1, 2, 2, 0, 3, 2, 3, 1, 1, 3, 1, 3, 1, 0, 2, 1, 3, 1, 2, 2, 1, 3, 2, 2, 2, 3, 3, 1, 3, 0, 3, 2, 3, 3, 0}
+			)
+			It("returns the next endpoint from the Maglev lookup table on retries", func() {
+				e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"})
+				e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"})
+				e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"})
+				e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID4"})
+
+				endpoints = []*route.Endpoint{e1, e2, e3, e4}
+				endpointIDList := make([]string, 0, 4)
+				for _, e := range endpoints {
+					pool.Put(e)
+					endpointIDList = append(endpointIDList, e.PrivateInstanceId)
+				}
+				maglevMock := NewMockHashLookupTable(MaglevLookupTable, endpointIDList)
+				pool.HashLookupTable = maglevMock
+				iter := route.NewHashBased(logger.Logger, pool, "", false, false, "")
+				iter.(*route.HashBased).HeaderValue = "tenant-1"
+				// With the mocked lookup table, "tenant-1" always selects e4 on the first
+				// attempt; retries continue from the next index in the table
+				first := iter.Next(0)
+				Expect(first).To(Equal(e4))
+				second := iter.Next(1)
+				Expect(second).To(Equal(e1))
+				third := iter.Next(2)
+				Expect(third).To(Equal(e4))
+			})
+			It("returns the next endpoint that is not overloaded for the second attempt", func() {
+				e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"})
"hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + + endpoints = []*route.Endpoint{e1, e2, e3, e4} + for _, e := range endpoints { + pool.Put(e) + } + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + firstAttemptResult := iter.Next(0) + Expect(iter.Next(0)).To(Equal(firstAttemptResult)) + for i := 0; i < 6; i++ { + // Simulate requests to overload the endpoints + iter.PreRequest(e1) + iter.PreRequest(e2) + } + secondAttemptResult := iter.Next(1) + Expect(secondAttemptResult).NotTo(Equal(firstAttemptResult)) + Expect(secondAttemptResult).NotTo(Equal(e1)) + Expect(secondAttemptResult).NotTo(Equal(e2)) + }) + }) + + Context("when using sticky sessions", func() { + var ( + endpoints []*route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 := route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + }) + + Context("when mustBeSticky is true", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("returns nil when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + It("returns nil when sticky endpoint is overloaded and mustBeSticky is true", func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + for i := 0; i < 1000; i++ { + iter.PreRequest(endpoints[0]) + } + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when mustBeSticky is false", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", false, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("falls back to hash-based routing when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, false, "") + hashIter := iter.(*route.HashBased) + hashIter.HeaderValue = "some-value" + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + }) + }) + }) + }) + + Context("when testing PreRequest and PostRequest", func() { + var ( + endpoint *route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() 
+			endpoint = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"})
+			pool.Put(endpoint)
+			iter = route.NewHashBased(logger.Logger, pool, "", false, false, "")
+		})
+
+		It("increments connection count on PreRequest", func() {
+			initialCount := endpoint.Stats.NumberConnections.Count()
+			iter.PreRequest(endpoint)
+			Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount + 1))
+		})
+
+		It("decrements connection count on PostRequest", func() {
+			iter.PreRequest(endpoint)
+			initialCount := endpoint.Stats.NumberConnections.Count()
+			iter.PostRequest(endpoint)
+			Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1))
+		})
+	})
+	Describe("IsImbalancedOrOverloaded", func() {
+		var iter *route.HashBased
+		var endpoints []*route.Endpoint
+
+		BeforeEach(func() {
+			iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased)
+		})
+
+		Context("when endpoints have a lot of in-flight requests", func() {
+			var e1, e2, e3 *route.Endpoint
+			BeforeEach(func() {
+				e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"})
+				e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"})
+				e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"})
+				endpoints = []*route.Endpoint{e1, e2, e3}
+				for _, e := range endpoints {
+					pool.Put(e)
+				}
+			})
+			It("marks the endpoint as overloaded", func() {
+				for i := 0; i < 500; i++ {
+					iter.PreRequest(e1)
+				}
+				// e1 alone carries 500 in-flight requests, which matches MaxConnsPerBackend
+				Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue())
+			})
+			It("does not mark as imbalanced if every endpoint has 498 in-flight requests", func() {
+				for i := 0; i < 498; i++ {
+					iter.PreRequest(e1)
+				}
+				for i := 0; i < 498; i++ {
+					iter.PreRequest(e2)
+				}
+				for i := 0; i < 498; i++ {
+					iter.PreRequest(e3)
+				}
+				// the load is even across all endpoints, so e1 is not imbalanced
+				Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeFalse())
+			})
+
+			It("marks endpoints as overloaded if every endpoint has 499 in-flight requests and the overload flag is set", func() {
+				for i := 0; i < 499; i++ {
+					iter.PreRequest(e1)
+				}
+				for i := 0; i < 499; i++ {
+					iter.PreRequest(e2)
+				}
+				for i := 0; i < 499; i++ {
+					iter.PreRequest(e3)
+				}
+				// each endpoint carries 499 in-flight requests; the overload flag marks them all as overloaded
+				Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue())
+				Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-overloaded"))
+				Expect(iter.IsImbalancedOrOverloaded(e2, true)).To(BeTrue())
+				Expect(iter.IsImbalancedOrOverloaded(e3, true)).To(BeTrue())
+			})
+			It("marks as imbalanced if an endpoint has more in-flight requests than the average allows", func() {
+				for i := 0; i < 300; i++ {
+					iter.PreRequest(e1)
+				}
+				for i := 0; i < 200; i++ {
+					iter.PreRequest(e2)
+				}
+				for i := 0; i < 200; i++ {
+					iter.PreRequest(e3)
+				}
+				Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeTrue())
+				Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-imbalanced"))
+				Expect(iter.IsImbalancedOrOverloaded(e2, false)).To(BeFalse())
+				Expect(iter.IsImbalancedOrOverloaded(e3, false)).To(BeFalse())
+			})
+		})
+	})
+
+	Describe("CalculateAverageLoad", func() {
+		var iter
*route.HashBased + var endpoints []*route.Endpoint + + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + }) + + Context("when there are no endpoints", func() { + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when all endpoints have zero connections", func() { + BeforeEach(func() { + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"})) + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"})) + }) + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when endpoints have varying connection counts", func() { + var e1, e2, e3 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 2; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 4; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 6; i++ { + iter.PreRequest(e3) + } + }) + It("returns the correct average", func() { + // in general 12 in flight requests + Expect(iter.CalculateAverageLoad()).To(Equal(float64(4))) + }) + }) + + Context("when one endpoint has many connections", func() { + var e1, e2 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 10; i++ { + iter.PreRequest(e1) + } + }) + It("returns the correct average", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(5))) + }) + }) + }) + +}) + +// MockHashLookupTable provides a simple mock implementation of MaglevLookup interface for testing. 
+type MockHashLookupTable struct {
+	lookupTable  []int
+	endpointList []string
+}
+
+// NewMockHashLookupTable creates a new mock lookup table with predefined mappings
+func NewMockHashLookupTable(lookupTable []int, endpointList []string) *MockHashLookupTable {
+	return &MockHashLookupTable{
+		lookupTable:  lookupTable,
+		endpointList: endpointList,
+	}
+}
+
+func (m *MockHashLookupTable) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) {
+	if len(m.endpointList) == 0 {
+		return 0, "", nil
+	}
+	h := fnv.New64a()
+	_, _ = h.Write([]byte(hashHeaderValue))
+	key := h.Sum64()
+	index := key % m.GetLookupTableSize()
+	return index, m.endpointList[m.lookupTable[index]], nil
+}
+
+func (m *MockHashLookupTable) GetLookupTableSize() uint64 {
+	return uint64(len(m.lookupTable))
+}
+
+func (m *MockHashLookupTable) GetEndpointId(lookupTableIndex uint64) string {
+	return m.endpointList[m.lookupTable[lookupTableIndex]]
+}
+
+func (m *MockHashLookupTable) Add(endpoint string) {
+	// Check if endpoint already exists
+	for _, existing := range m.endpointList {
+		if existing == endpoint {
+			return
+		}
+	}
+	m.endpointList = append(m.endpointList, endpoint)
+}
+
+func (m *MockHashLookupTable) Remove(endpoint string) {
+	for i, existing := range m.endpointList {
+		if existing == endpoint {
+			m.endpointList = append(m.endpointList[:i], m.endpointList[i+1:]...)
+			return
+		}
+	}
+}
+
+func (m *MockHashLookupTable) GetEndpointList() []string {
+	return append([]string(nil), m.endpointList...) // return a copy
+}
+
+// GetLookupTable returns a copy of the current lookup table (for testing)
+func (m *MockHashLookupTable) GetLookupTable() []int {
+	return append([]int(nil), m.lookupTable...) // return a copy, as documented
+}
+
+// GetPermutationTable returns a copy of the current permutation table (for testing)
+func (m *MockHashLookupTable) GetPermutationTable() [][]uint64 {
+	return nil // not implemented in mock
+}
+
+// Compile-time check to ensure MockHashLookupTable implements MaglevLookup interface
+var _ route.MaglevLookup = (*MockHashLookupTable)(nil)
diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go
new file mode 100644
index 000000000..23d46210f
--- /dev/null
+++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go
@@ -0,0 +1,260 @@
+package route
+
+/******************************************************************************
+ * Original github.com/kkdai/maglev/maglev.go
+ *
+ * Copyright (c) 2019 Evan Lin (github.com/kkdai)
+ *
+ * This program and the accompanying materials are made available under
+ * the terms of the Apache License, Version 2.0 which is available at
+ * http://www.apache.org/licenses/LICENSE-2.0.
+ ******************************************************************************/
+
+import (
+	"errors"
+	"fmt"
+	"hash/fnv"
+	"log/slog"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+)
+
+const (
+	// lookupTableSize is a prime number used as the size of the Maglev lookup table; it should be
+	// approximately 100x the number of expected endpoints
+	lookupTableSize uint64 = 1801
+)
+
+// MaglevLookup defines the interface for consistent hashing lookup table implementations.
+// This interface allows for different implementations of the Maglev algorithm and
+// enables easy testing with mock implementations.
+type MaglevLookup interface {
+	// Add a new endpoint to the lookup table
+	Add(endpoint string)
+
+	// Remove an endpoint from the lookup table
+	Remove(endpoint string)
+
+	// GetInstanceForHashHeader returns the lookup table index and endpoint ID for the specified request header value
+	GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error)
+
+	// GetEndpointId returns the endpoint ID for the specified lookup table index
+	GetEndpointId(lookupTableIndex uint64) string
+
+	// GetLookupTableSize returns the size of the lookup table
+	GetLookupTableSize() uint64
+
+	// GetEndpointList returns a copy of the current endpoint list (for testing)
+	GetEndpointList() []string
+
+	// GetLookupTable returns a copy of the current lookup table (for testing)
+	GetLookupTable() []int
+
+	// GetPermutationTable returns a copy of the current permutation table (for testing)
+	GetPermutationTable() [][]uint64
+}
+
+// Maglev is an implementation of the consistent hashing algorithm described in "Maglev: A Fast and Reliable
+// Software Network Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf)
+type Maglev struct {
+	logger           *slog.Logger
+	permutationTable [][]uint64
+	lookupTable      []int
+	endpointList     []string
+	lock             *sync.RWMutex
+}
+
+// NewMaglev initializes an empty Maglev lookup table
+func NewMaglev(logger *slog.Logger) *Maglev {
+	return &Maglev{
+		lock:             &sync.RWMutex{},
+		lookupTable:      make([]int, lookupTableSize),
+		endpointList:     make([]string, 0, 2),
+		permutationTable: make([][]uint64, 0, 2),
+		logger:           logger,
+	}
+}
+
+// Add inserts a new endpoint into the lookup table if it is not already present.
+func (m *Maglev) Add(endpoint string) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	if lookupTableSize == uint64(len(m.endpointList)) {
+		m.logger.Warn("maglev-add-lookuptable-capacity-exceeded", slog.String("endpoint-id", endpoint))
+		return
+	}
+
+	index := sort.SearchStrings(m.endpointList, endpoint)
+	if index < len(m.endpointList) && m.endpointList[index] == endpoint {
+		m.logger.Debug("maglev-add-lookuptable-endpoint-exists", slog.String("endpoint-id", endpoint), slog.Int("current-endpoints", len(m.endpointList)))
+		return
+	}
+
+	m.endpointList = append(m.endpointList, "")
+	copy(m.endpointList[index+1:], m.endpointList[index:])
+	m.endpointList[index] = endpoint
+
+	m.generatePermutation(endpoint)
+	m.fillLookupTable()
+}
+
+// Remove deletes an endpoint from the lookup table if it is present.
+func (m *Maglev) Remove(endpoint string) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	index := sort.SearchStrings(m.endpointList, endpoint)
+	if index >= len(m.endpointList) || m.endpointList[index] != endpoint {
+		m.logger.Debug("maglev-remove-endpoint-not-found", slog.String("endpoint-id", endpoint))
+		return
+	}
+
+	m.endpointList = append(m.endpointList[:index], m.endpointList[index+1:]...)
+	m.permutationTable = append(m.permutationTable[:index], m.permutationTable[index+1:]...)
+
+	m.fillLookupTable()
+}
+
+func (m *Maglev) hashKey(headerValue string) uint64 {
+	return m.calculateFNVHash64(headerValue)
+}
+
+// GetInstanceForHashHeader returns the lookup table index and private instance ID for the specified request header value
+func (m *Maglev) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+
+	if len(m.endpointList) == 0 {
+		return 0, "", errors.New("no endpoint available")
+	}
+	key := m.hashKey(hashHeaderValue)
+	index := key % lookupTableSize
+	return index, m.endpointList[m.lookupTable[index]], nil
+}
+
+// GetEndpointId returns the endpoint ID for the specified lookup table index
+func (m *Maglev) GetEndpointId(lookupTableIndex uint64) string {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+
+	return m.endpointList[m.lookupTable[lookupTableIndex]]
+}
+
+// generatePermutation creates a permutation of the lookup table slots for each endpoint
+func (m *Maglev) generatePermutation(endpoint string) {
+	pos := sort.SearchStrings(m.endpointList, endpoint)
+	if pos == len(m.endpointList) {
+		m.logger.Debug("maglev-permutation-no-endpoints")
+		return
+	}
+
+	endpointHash := m.calculateFNVHash64(endpoint)
+	offset := endpointHash % lookupTableSize
+	skip := (endpointHash % (lookupTableSize - 1)) + 1
+
+	permutationForEndpoint := make([]uint64, lookupTableSize)
+	for j := uint64(0); j < lookupTableSize; j++ {
+		permutationForEndpoint[j] = (offset + j*skip) % lookupTableSize
+	}
+
+	// insert permutationForEndpoint at position pos, shifting the rest to the right
+	m.permutationTable = append(m.permutationTable, nil)
+	copy(m.permutationTable[pos+1:], m.permutationTable[pos:])
+	m.permutationTable[pos] = permutationForEndpoint
+}
+
+func (m *Maglev) fillLookupTable() {
+	if len(m.endpointList) == 0 {
+		return
+	}
+
+	numberOfEndpoints := len(m.endpointList)
+	next := make([]int, numberOfEndpoints)
+	entry := make([]int, lookupTableSize)
+	for j := range entry {
+		entry[j] = -1
+	}
+
+	for n := uint64(0); n < lookupTableSize; {
+		for i := 0; i < numberOfEndpoints; i++ {
+			candidate := m.findNextAvailableSlot(i, next, entry)
+			entry[candidate] = i
+			next[i] = next[i] + 1
+			n++
+
+			if n == lookupTableSize {
+				m.lookupTable = entry
+				return
+			}
+		}
+	}
+}
+
+func (m *Maglev) findNextAvailableSlot(i int, next []int, entry []int) uint64 {
+	candidate := m.permutationTable[i][next[i]]
+	for entry[candidate] >= 0 {
+		next[i]++
+		if next[i] >= len(m.permutationTable[i]) {
+			// This should not happen in a properly functioning Maglev algorithm,
+			// but we add this safety check to prevent a panic
+			m.logger.Error("maglev-permutation-table-exhausted",
+				slog.Int("endpoint-index", i),
+				slog.Int("next-value", next[i]),
+				slog.Int("table-size", len(m.permutationTable[i])))
+			// Reset to the beginning of the permutation table as a fallback
+			next[i] = 0
+		}
+		candidate = m.permutationTable[i][next[i]]
+	}
+	return candidate
+}
+
+// Getters for unit tests
+func (m *Maglev) GetEndpointList() []string {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	return append([]string(nil), m.endpointList...)
+}
+
+func (m *Maglev) GetLookupTable() []int {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	return append([]int(nil), m.lookupTable...)
+}
+
+func (m *Maglev) GetPermutationTable() [][]uint64 {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	copied := make([][]uint64, len(m.permutationTable))
+	for i, v := range m.permutationTable {
+		copied[i] = append([]uint64(nil), v...)
+ } + return copied +} + +func (m *Maglev) GetLookupTableSize() uint64 { + return lookupTableSize +} + +// TODO: Remove in final version +func (m *Maglev) PrintLookupTable() string { + strArr := make([]string, len(m.lookupTable)) + for i, value := range m.lookupTable { + strArr[i] = strconv.Itoa(value) + } + return fmt.Sprintf("[%s]", strings.Join(strArr, ", ")) +} + +// calculateFNVHash64 computes a hash using the non-cryptographic FNV hash algorithm. +func (m *Maglev) calculateFNVHash64(key string) uint64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + return h.Sum64() +} + +// Compile-time check to ensure Maglev implements MaglevLookup interface +var _ MaglevLookup = (*Maglev)(nil) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go new file mode 100644 index 000000000..b72d13b3d --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -0,0 +1,345 @@ +package route_test + +import ( + "fmt" + "strconv" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" +) + +var _ = Describe("Maglev", func() { + var ( + logger *test_util.TestLogger + maglev *route.Maglev + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + + maglev = route.NewMaglev(logger.Logger) + }) + + Describe("NewMaglev", func() { + It("should create a new Maglev instance", func() { + Expect(maglev).NotTo(BeNil()) + }) + }) + + Describe("Add", func() { + Context("when adding a new backend", func() { + It("should add the backend successfully", func() { + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + _, backend, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(backend).To(Equal("backend1")) + }) + }) + + Context("when adding a backend twice", func() { + It("should skip adding subsequent adds", func() { + maglev.Add("backend1") + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + _, backend, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(backend).To(Equal("backend1")) + }) + }) + + Context("when adding multiple backends", func() { + It("should make all backends reachable", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + Expect(maglev.GetEndpointList()).To(HaveLen(3)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(len(maglev.GetEndpointList()))) + for i := range len(maglev.GetEndpointList()) { + Expect(maglev.GetPermutationTable()[i]).To(HaveLen(int(maglev.GetLookupTableSize()))) + } + + backends := make(map[string]bool) + for i := 0; i < 1000; i++ { + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + backends[backend] = true + } + + Expect(backends["backend1"]).To(BeTrue()) + Expect(backends["backend2"]).To(BeTrue()) + 
Expect(backends["backend3"]).To(BeTrue()) + }) + }) + }) + + Describe("Remove", func() { + Context("when removing an existing backend", func() { + It("should remove the backend successfully", func() { + maglev.Add("backend1") + maglev.Add("backend2") + + maglev.Remove("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + }) + }) + + Context("when removing a non-existent backend", func() { + It("should handle gracefully without error", func() { + maglev.Add("backend1") + + Expect(func() { maglev.Remove("non-existent") }).NotTo(Panic()) + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + }) + }) + }) + + Describe("Get", func() { + Context("when no backends were added", func() { + It("should return an error", func() { + _, _, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[string]int) + var result string + var err error + for range 100 { + _, result, err = maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[result]++ + } + + Expect(counter[result]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + distribution := make(map[string]int) + for i := range 1000 { + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + distribution[backend]++ + } + + Expect(distribution["backend1"]).To(BeNumerically(">", 0)) + Expect(distribution["backend2"]).To(BeNumerically(">", 0)) + Expect(distribution["backend3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Remove("backend1") + }) + + It("should not return the removed backend", func() { + for range 100 { + _, backend, err := maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + Expect(backend).To(Equal("backend2")) + } + }) + }) + }) + + Describe("GetInstanceForHashHeader", func() { + Context("when no backends were added", func() { + It("should return an error", func() { + _, _, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[uint64]int) + var lookupTableIndex uint64 + var err error + for range 100 { + lookupTableIndex, _, err = maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[lookupTableIndex]++ + } + + Expect(counter[lookupTableIndex]).To(Equal(100)) + }) + }) + }) + + Describe("GetEndpointId", func() { + Context("when backends are added", func() { + BeforeEach(func() { + 
maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[string]int) + var endpointID string + for range 100 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + endpointID = maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + counter[endpointID]++ + } + + Expect(counter[endpointID]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Add("app_instance_3") + + distribution := make(map[string]int) + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + distribution[endpointID]++ + } + + Expect(distribution["app_instance_1"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_2"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Remove("app_instance_1") + }) + + It("should not return the removed backend", func() { + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(endpointID).To(Equal("app_instance_2")) + } + }) + }) + }) + + Describe("Consistency", func() { + // We test that at most half the keys are reassigned to new backends, when one backend is added. + // This ensures a minimal level of consistency. 
+ It("should minimize disruption when adding backends", func() { + for i := range 10 { + maglev.Add(fmt.Sprintf("backend%d", i+1)) + } + keys := make([]string, 1000) + for i := range keys { + keys[i] = fmt.Sprintf("key%d", i+1) + } + + initialMappings := make(map[string]string) + + for _, key := range keys { + _, backend, err := maglev.GetInstanceForHashHeader(key) + Expect(err).NotTo(HaveOccurred()) + initialMappings[key] = backend + } + + maglev.Add("newbackend") + + changedMappings := 0 + for _, key := range keys { + _, backend, err := maglev.GetInstanceForHashHeader(key) + Expect(err).NotTo(HaveOccurred()) + if initialMappings[key] != backend { + changedMappings++ + } + } + + Expect(changedMappings).To(BeNumerically("<=", len(keys)/2)) + }) + }) + + Describe("Concurrency", func() { + It("should handle concurrent reads safely", func() { + maglev.Add("backend1") + + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + _, _, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).NotTo(HaveOccurred()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + }) + It("should handle concurrent endpoint registrations safely", func() { + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + Expect(func() { maglev.Add("endpoint" + strconv.Itoa(j)) }).NotTo(Panic()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + Expect(len(maglev.GetEndpointList())).To(Equal(100)) + }) + + }) +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index f089fc15b..8217aefca 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -74,6 +74,21 @@ type ProxyRoundTripper interface { CancelRequest(*http.Request) } +type HashRoutingProperties struct { + Header string + BalanceFactor float64 +} + +func (hrp *HashRoutingProperties) Equal(hrp2 *HashRoutingProperties) bool { + if hrp == nil && hrp2 == nil { + return true + } + if hrp == nil || hrp2 == nil { + return false + } + return hrp.Header == hrp2.Header && hrp.BalanceFactor == hrp2.BalanceFactor +} + type Endpoint struct { ApplicationId string AvailabilityZone string @@ -186,6 +201,8 @@ type EndpointPool struct { logger *slog.Logger updatedAt time.Time LoadBalancingAlgorithm string + HashRoutingProperties *HashRoutingProperties + HashLookupTable MaglevLookup } type EndpointOpts struct { @@ -248,10 +265,12 @@ type PoolOpts struct { MaxConnsPerBackend int64 Logger *slog.Logger LoadBalancingAlgorithm string + HashHeader string + HashBalanceFactor float64 } func NewPool(opts *PoolOpts) *EndpointPool { - return &EndpointPool{ + pool := &EndpointPool{ endpoints: make([]*endpointElem, 0, 1), index: make(map[string]*endpointElem), retryAfterFailure: opts.RetryAfterFailure, @@ -264,6 +283,14 @@ func NewPool(opts *PoolOpts) *EndpointPool { updatedAt: time.Now(), LoadBalancingAlgorithm: opts.LoadBalancingAlgorithm, } + if pool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + pool.HashLookupTable = NewMaglev(opts.Logger) + pool.HashRoutingProperties = &HashRoutingProperties{ + Header: opts.HashHeader, + BalanceFactor: opts.HashBalanceFactor, + } + } + return pool } func PoolsMatch(p1, p2 *EndpointPool) bool { @@ -320,7 +347,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { // new one. 
e.Lock() defer e.Unlock() - oldEndpoint := e.endpoint e.endpoint = endpoint @@ -336,6 +362,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) e.updated = time.Now() + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointUpdated @@ -348,7 +377,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { updated: time.Now(), maxConnsPerBackend: p.maxConnsPerBackend, } - p.endpoints = append(p.endpoints, e) p.index[endpoint.CanonicalAddr()] = e @@ -356,6 +384,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointAdded @@ -433,6 +464,11 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { delete(p.index, e.endpoint.CanonicalAddr()) delete(p.index, e.endpoint.PrivateInstanceId) p.Update() + + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Remove(e.endpoint.PrivateInstanceId) + } + } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { @@ -443,6 +479,9 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic case config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_HB: + logger.Debug("endpoint-iterator-with-hash-based-lb-algo") + return NewHashBased(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), @@ -452,6 +491,23 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic } } +func (p *EndpointPool) FallBackToDefaultLoadBalancing(defaultLBAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { + logger.Info("hash-based-routing-header-not-found", + slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + + switch defaultLBAlgo { + case config.LOAD_BALANCE_LC: + logger.Debug("endpoint-iterator-with-least-connection-lb-algo") + return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_RR: + logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + } + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) +} + func (p *EndpointPool) NumEndpoints() int { p.Lock() defer p.Unlock() @@ -561,18 +617,42 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. 
 func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) {
-	if len(endpoint.LoadBalancingAlgorithm) > 0 && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm {
+	if endpoint.LoadBalancingAlgorithm == "" {
+		return
+	}
+
+	if endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm {
 		if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) {
 			p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm
 			p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint",
 				slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm),
 				slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm))
 		} else {
 			p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo",
 				slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm),
 				slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm))
 		}
 	}
+	p.prepareHashBasedRouting(endpoint)
+}
+
+func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) {
+	if p.LoadBalancingAlgorithm != config.LOAD_BALANCE_HB {
+		return
+	}
+	if p.HashLookupTable == nil {
+		p.HashLookupTable = NewMaglev(p.logger)
+	}
+
+	newProps := &HashRoutingProperties{
+		Header:        endpoint.HashHeaderName,
+		BalanceFactor: endpoint.HashBalanceFactor,
+	}
+
+	if p.HashRoutingProperties == nil || !p.HashRoutingProperties.Equal(newProps) {
+		p.HashRoutingProperties = newProps
+	}
 }
 
 func (e *endpointElem) failed() {
diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go
index 31da6c8d7..7709a1d8b 100644
--- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go
+++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go
@@ -428,6 +428,46 @@ var _ = Describe("EndpointPool", func() {
 			Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR))
 		})
 	})
+
+	Context("when switching to hash-based routing", func() {
+		It("will create the maglev table and add the endpoint", func() {
+			pool := route.NewPool(&route.PoolOpts{
+				Logger:                 logger.Logger,
+				LoadBalancingAlgorithm: config.LOAD_BALANCE_RR,
+			})
+
+			endpointOpts := route.EndpointOpts{
+				Host:                   "host-1",
+				Port:                   1234,
+				RouteServiceUrl:        "url",
+				LoadBalancingAlgorithm: config.LOAD_BALANCE_RR,
+			}
+
+			initialEndpoint := route.NewEndpoint(&endpointOpts)
+
+			pool.Put(initialEndpoint)
+			Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR))
+
+			endpointOptsHash := route.EndpointOpts{
+				Host:                   "host-1",
+				Port:                   1234,
+				RouteServiceUrl:        "url",
+				LoadBalancingAlgorithm: config.LOAD_BALANCE_HB,
+				HashBalanceFactor:      1.25,
+				HashHeaderName:         "X-Tenant",
+			}
+
+			hashEndpoint := route.NewEndpoint(&endpointOptsHash)
+
+			pool.Put(hashEndpoint)
+			Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_HB))
+			Expect(pool.HashLookupTable).ToNot(BeNil())
+			Expect(pool.HashLookupTable.GetEndpointList()).To(HaveLen(1))
+			Expect(pool.HashLookupTable.GetEndpointList()[0]).To(Equal(hashEndpoint.PrivateInstanceId))
+		})
+	})
+
 })
 
 Context("RouteServiceUrl", func() {
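For reference, the table-population step in `maglev.go` can be exercised in isolation. Below is a self-contained toy version of the `generatePermutation` and `fillLookupTable` logic, using a table size of 7 instead of the 1801 above so the result can be inspected by eye; the function names and the `main` driver are illustrative, not part of this PR:

```go
package main

import (
	"fmt"
	"hash/fnv"
)

const tableSize uint64 = 7 // toy prime; the PR uses 1801

func hash64(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

// permutation builds the preference list for one backend, mirroring
// generatePermutation: slot (offset + j*skip) mod tableSize for each j.
func permutation(backend string) []uint64 {
	h := hash64(backend)
	offset := h % tableSize
	skip := h%(tableSize-1) + 1 // coprime to a prime tableSize, so a full permutation
	p := make([]uint64, tableSize)
	for j := uint64(0); j < tableSize; j++ {
		p[j] = (offset + j*skip) % tableSize
	}
	return p
}

// fill populates the lookup table by letting backends claim slots in
// round-robin order of their preference lists (the fillLookupTable loop).
func fill(backends []string) []int {
	perms := make([][]uint64, len(backends))
	for i, b := range backends {
		perms[i] = permutation(b)
	}
	next := make([]int, len(backends))
	entry := make([]int, tableSize)
	for j := range entry {
		entry[j] = -1 // -1 marks a free slot
	}
	filled := uint64(0)
	for filled < tableSize {
		for i := range backends {
			c := perms[i][next[i]]
			for entry[c] >= 0 { // slot taken: advance to the next preference
				next[i]++
				c = perms[i][next[i]]
			}
			entry[c] = i
			next[i]++
			filled++
			if filled == tableSize {
				return entry
			}
		}
	}
	return entry
}

func main() {
	table := fill([]string{"ID1", "ID2", "ID3"})
	// Each slot holds a backend index; slot counts per backend differ by at most one,
	// which is the near-uniform distribution property Maglev aims for.
	fmt.Println(table)
}
```

Because each backend claims one slot per round, the resulting slot counts differ by at most one across backends, and removing a backend only reassigns the slots it owned plus a small number of displaced ones, which is the consistency property the `Describe("Consistency")` test above asserts.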