From 0ee521dd450ce312d491e4c67caa012e835ef981 Mon Sep 17 00:00:00 2001 From: Alexander Nicke Date: Fri, 29 Aug 2025 12:13:37 +0200 Subject: [PATCH 1/4] Implement hash-based routing (#505) This commit provides the basic implementation for hash-based routing. It does not consider the balance factor yet. Co-authored-by: Clemens Hoffmann Co-authored-by: Tamara Boehm Co-authored-by: Soha Alboghdady --- docs/03-how-to-add-new-route-option.md | 5 + .../round_tripper/proxy_round_tripper.go | 14 + .../round_tripper/proxy_round_tripper_test.go | 162 +++++++++++ .../gorouter/route/hash_based.go | 141 ++++++++++ .../gorouter/route/hash_based_test.go | 155 +++++++++++ .../gorouter/route/maglev.go | 212 +++++++++++++++ .../gorouter/route/maglev_test.go | 257 ++++++++++++++++++ .../gorouter/route/pool.go | 79 +++++- .../gorouter/route/pool_test.go | 40 +++ 9 files changed, 1061 insertions(+), 4 deletions(-) create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based_test.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev_test.go diff --git a/docs/03-how-to-add-new-route-option.md b/docs/03-how-to-add-new-route-option.md index 39ce25447..308213fcc 100644 --- a/docs/03-how-to-add-new-route-option.md +++ b/docs/03-how-to-add-new-route-option.md @@ -22,6 +22,11 @@ applications: - route: example2.com options: loadbalancing: least-connection + - route: example3.com + options: + loadbalancing: hash + hash_header: tenant-id + hash_balance: 1.25 ``` **NOTE**: In the implementation, the `options` property of a route represents per-route features. diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index 88cfd20a5..84252262f 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -127,6 +127,20 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + if reqInfo.RoutePool.HashRoutingProperties == nil { + rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host())) + + } else { + headerName := reqInfo.RoutePool.HashRoutingProperties.Header + headerValue := request.Header.Get(headerName) + if headerValue != "" { + iter.(*route.HashBased).HeaderValue = headerValue + } else { + iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + } + } + } // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. 
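As a rough usage sketch of the route option documented in docs/03-how-to-add-new-route-option.md above: assuming an application mapped to example3.com with `loadbalancing: hash` and `hash_header: tenant-id` (the tenant values below are only illustrative), requests carrying the same header value should keep landing on the same instance, while a request without the header falls back to the platform's default load-balancing algorithm, as wired up in the round tripper change above.

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	client := &http.Client{}

	// Two requests with the same tenant-id value: the hash of "tenant-42"
	// selects the same lookup-table slot, so both hit the same instance.
	for i := 0; i < 2; i++ {
		req, _ := http.NewRequest(http.MethodGet, "https://example3.com/", nil)
		req.Header.Set("tenant-id", "tenant-42")
		if resp, err := client.Do(req); err == nil {
			fmt.Println("hash-routed:", resp.Status)
			resp.Body.Close()
		}
	}

	// No tenant-id header: gorouter falls back to the default
	// load-balancing algorithm for this request.
	req, _ := http.NewRequest(http.MethodGet, "https://example3.com/", nil)
	if resp, err := client.Do(req); err == nil {
		fmt.Println("fallback:", resp.Status)
		resp.Body.Close()
	}
}
```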
In diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go index 9d270867c..6abe4d218 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "net/http" "net/http/httptest" @@ -1700,6 +1701,167 @@ var _ = Describe("ProxyRoundTripper", func() { }) }) + Context("when load-balancing strategy is set to hash-based routing", func() { + JustBeforeEach(func() { + for i := 1; i <= 3; i++ { + endpoint = route.NewEndpoint(&route.EndpointOpts{ + AppId: fmt.Sprintf("appID%d", i), + Host: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i), + Port: 9090, + PrivateInstanceId: fmt.Sprintf("instanceID%d", i), + PrivateInstanceIndex: fmt.Sprintf("%d", i), + AvailabilityZone: AZ, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashHeaderName: "X-Hash", + }) + + _ = routePool.Put(endpoint) + Expect(routePool.HashLookupTable).ToNot(BeNil()) + + } + }) + + It("routes requests with same hash header value to the same endpoint", func() { + req.Header.Set("X-Hash", "value") + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + var selectedEndpoints []*route.Endpoint + + // Make multiple requests with the same hash value + for i := 0; i < 5; i++ { + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + selectedEndpoints = append(selectedEndpoints, reqInfo.RouteEndpoint) + } + + // All requests should go to the same endpoint + firstEndpoint := selectedEndpoints[0] + for _, ep := range selectedEndpoints[1:] { + Expect(ep.PrivateInstanceId).To(Equal(firstEndpoint.PrivateInstanceId)) + } + }) + + It("routes requests with different hash header values to potentially different endpoints", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + endpointDistribution := make(map[string]int) + + // Make requests with different hash values + for i := 0; i < 10; i++ { + req.Header.Set("X-Hash", fmt.Sprintf("value-%d", i)) + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + endpointDistribution[reqInfo.RouteEndpoint.PrivateInstanceId]++ + } + + // Should distribute across multiple endpoints (not all to one) + Expect(len(endpointDistribution)).To(BeNumerically(">", 1)) + }) + + It("falls back to default load balancing algorithm when hash header is missing", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + + reqInfo.RoutePool = routePool + + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + + infoLogs := logger.Lines(zap.InfoLevel) + count := 0 + for i := 0; i < len(infoLogs); i++ { + if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") { + count++ + } + } + Expect(count).To(Equal(1)) + // Verify it still selects an endpoint + Expect(reqInfo.RouteEndpoint).ToNot(BeNil()) + }) + + Context("when sticky session cookies (JSESSIONID and VCAP_ID) are on the request", func() { + var ( + sessionCookie *http.Cookie + cookies []*http.Cookie + ) + + JustBeforeEach(func() { + sessionCookie = &http.Cookie{ + Name: StickyCookieKey, //JSESSIONID + } + transport.RoundTripStub = func(req *http.Request) (*http.Response, error) { + resp := 
&http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)} + //Attach the same JSESSIONID on to the response if it exists on the request + + if len(req.Cookies()) > 0 { + for _, cookie := range req.Cookies() { + if cookie.Name == StickyCookieKey { + resp.Header.Add(round_tripper.CookieHeader, cookie.String()) + return resp, nil + } + } + } + + sessionCookie.Value, _ = uuid.GenerateUUID() + resp.Header.Add(round_tripper.CookieHeader, sessionCookie.String()) + return resp, nil + } + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + cookies = resp.Cookies() + Expect(cookies).To(HaveLen(2)) + + }) + + Context("when there is a JSESSIONID and __VCAP_ID__ set on the request", func() { + It("will always route to the instance specified with the __VCAP_ID__ cookie", func() { + + // Generate 20 random values for the hash header, so chance that all go to instanceID1 + // by accident is 0.33^20 + for i := 0; i < 20; i++ { + randomStr := make([]byte, 8) + for j := range randomStr { + randomStr[j] = byte('a' + rand.Intn(26)) + } + + req.Header.Set("X-Hash", string(randomStr)) + reqInfo, err := handlers.ContextRequestInfo(req) + req.AddCookie(&http.Cookie{Name: round_tripper.VcapCookieId, Value: "instanceID1"}) + req.AddCookie(&http.Cookie{Name: StickyCookieKey, Value: "abc"}) + + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + new_cookies := resp.Cookies() + Expect(new_cookies).To(HaveLen(2)) + + for _, cookie := range new_cookies { + Expect(cookie.Name).To(SatisfyAny( + Equal(StickyCookieKey), + Equal(round_tripper.VcapCookieId), + )) + if cookie.Name == StickyCookieKey { + Expect(cookie.Value).To(Equal("abc")) + } else { + Expect(cookie.Value).To(Equal("instanceID1")) + } + } + + } + + }) + }) + }) + }) + Context("when endpoint timeout is not 0", func() { var reqCh chan *http.Request BeforeEach(func() { diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go new file mode 100644 index 000000000..39d551eb8 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -0,0 +1,141 @@ +package route + +import ( + "context" + "errors" + "log/slog" + "sync" + + log "code.cloudfoundry.org/gorouter/logger" +) + +// HashBased load balancing algorithm distributes requests based on a hash of a specific header value. +// The sticky session cookie has precedence over hash-based routing and the request should be routed to the instance stored in the cookie. +// If requests do not contain the hash-related header set configured for the hash-based route option, use the default load-balancing algorithm. +type HashBased struct { + lock *sync.Mutex + + logger *slog.Logger + pool *EndpointPool + lastEndpoint *Endpoint + + stickyEndpointID string + mustBeSticky bool + + HeaderValue string +} + +// NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value. +// The global properties locallyOptimistic and localAvailabilityZone will be ignored when using Hash-Based Routing. 
+func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { + return &HashBased{ + logger: logger, + pool: p, + lock: &sync.Mutex{}, + stickyEndpointID: initial, + mustBeSticky: mustBeSticky, + } +} + +// Next selects the next endpoint based on the hash of the header value. +// If a sticky session endpoint is available and not overloaded, it will be returned. +// If the request must be sticky and the sticky endpoint is unavailable or overloaded, nil will be returned. +// If no sticky session is present, the endpoint will be selected based on the hash of the header value. +// It returns the same endpoint for the same header value consistently. +// If the hash lookup fails or the endpoint is not found, nil will be returned. +func (h *HashBased) Next(attempt int) *Endpoint { + h.lock.Lock() + defer h.lock.Unlock() + + e := h.findEndpointIfStickySession() + if e == nil && h.mustBeSticky { + return nil + } + + if e != nil { + h.lastEndpoint = e + return e + } + + if h.pool.HashLookupTable == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) + return nil + } + + id, err := h.pool.HashLookupTable.Get(h.HeaderValue) + + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + h.logger.Debug( + "hash-based-routing", + slog.String("hash header value", h.HeaderValue), + slog.String("endpoint-id", id), + ) + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + return nil + } + + return endpointElem.endpoint +} + +// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. +// If the sticky session endpoint is overloaded, returns nil. +func (h *HashBased) findEndpointIfStickySession() *Endpoint { + var e *endpointElem + if h.stickyEndpointID != "" { + e = h.pool.findById(h.stickyEndpointID) + if e != nil && e.isOverloaded() { + if h.mustBeSticky { + if h.logger.Enabled(context.Background(), slog.LevelDebug) { + h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) + } + return nil + } + e = nil + } + + if e == nil && h.mustBeSticky { + h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) + return nil + } + + if !h.mustBeSticky { + h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) + h.stickyEndpointID = "" + } + } + + if e != nil { + e.RLock() + defer e.RUnlock() + return e.endpoint + } + return nil +} + +// EndpointFailed notifies the endpoint pool that the last selected endpoint has failed. +func (h *HashBased) EndpointFailed(err error) { + if h.lastEndpoint != nil { + h.pool.EndpointFailed(h.lastEndpoint, err) + } +} + +// PreRequest increments the in-flight request count for the selected endpoint from current Gorouter. +func (h *HashBased) PreRequest(e *Endpoint) { + e.Stats.NumberConnections.Increment() +} + +// PostRequest decrements the in-flight request count for the selected endpoint from current Gorouter. 
+func (h *HashBased) PostRequest(e *Endpoint) { + e.Stats.NumberConnections.Decrement() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go new file mode 100644 index 000000000..1caaed19c --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -0,0 +1,155 @@ +package route_test + +import ( + "code.cloudfoundry.org/gorouter/config" + _ "errors" + "time" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HashBased", func() { + var ( + pool *route.EndpointPool + logger *test_util.TestLogger + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + pool = route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + RetryAfterFailure: 2 * time.Minute, + Host: "", + ContextPath: "", + MaxConnsPerBackend: 0, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + }) + }) + + Describe("Next", func() { + + Context("when pool is empty", func() { + It("does not select an endpoint", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when pool has endpoints", func() { + var ( + endpoints []*route.Endpoint + ) + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + + }) + It("It returns the same endpoint for the same header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + second := iter.Next(0) + Expect(first).NotTo(BeNil()) + Expect(second).NotTo(BeNil()) + Expect(first).To(Equal(second)) + }) + + It("It selects another instance for other hash header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "example.com" + Expect(iter.Next(0)).NotTo(BeNil()) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + }) + }) + + Context("when using sticky sessions", func() { + var ( + endpoints []*route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 := route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + }) + + Context("when mustBeSticky is true", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("returns nil 
when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when mustBeSticky is false", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", false, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("falls back to hash-based routing when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, false, "") + hashIter := iter.(*route.HashBased) + hashIter.HeaderValue = "some-value" + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + }) + }) + }) + }) + + Context("when testing PreRequest and PostRequest", func() { + var ( + endpoint *route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + endpoint = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + pool.Put(endpoint) + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "") + }) + + It("increments connection count on PreRequest", func() { + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PreRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount + 1)) + }) + + It("decrements connection count on PostRequest", func() { + iter.PreRequest(endpoint) + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PostRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) + }) + }) + +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go new file mode 100644 index 000000000..9b70aaa09 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -0,0 +1,212 @@ +package route + +import ( + "errors" + "fmt" + "hash/fnv" + "log/slog" + "sort" + "strconv" + "strings" + "sync" +) + +const ( + // lookupTableSize is prime number for the size of the maglev lookup table, which should be approximately 100x + // the number of expected endpoints + lookupTableSize uint64 = 1801 +) + +// Maglev implementation of consistent hashing algorithm described in "Maglev: A Fast and Reliable Software Network +// Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf) +type Maglev struct { + logger *slog.Logger + permutationTable [][]uint64 + lookupTable []int + endpointList []string + lock *sync.RWMutex +} + +// NewMaglev initializes an empty maglev lookupTable table +func NewMaglev(logger *slog.Logger) *Maglev { + return &Maglev{ + lock: &sync.RWMutex{}, + lookupTable: make([]int, lookupTableSize), + endpointList: make([]string, 0, 2), + permutationTable: make([][]uint64, 0, 2), + logger: logger, + } +} + +// Add a new endpoint to lookupTable if it's not already contained. 
+func (m *Maglev) Add(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + if lookupTableSize == uint64(len(m.endpointList)) { + m.logger.Warn("maglev-add-lookuptable-capacity-exceeded", slog.String("endpoint-id", endpoint)) + return + } + + index := sort.SearchStrings(m.endpointList, endpoint) + if index < len(m.endpointList) && m.endpointList[index] == endpoint { + m.logger.Debug("maglev-add-lookuptable-endpoint-exists", slog.String("endpoint-id", endpoint), slog.Int("current-endpoints", len(m.endpointList))) + return + } + + m.endpointList = append(m.endpointList, "") + copy(m.endpointList[index+1:], m.endpointList[index:]) + m.endpointList[index] = endpoint + + m.generatePermutation(endpoint) + m.fillLookupTable() +} + +// Remove an endpoint from lookupTable if it's contained. +func (m *Maglev) Remove(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + index := sort.SearchStrings(m.endpointList, endpoint) + if index >= len(m.endpointList) || m.endpointList[index] != endpoint { + m.logger.Debug("maglev-remove-endpoint-not-found", slog.String("endpoint-id", endpoint)) + return + } + + m.endpointList = append(m.endpointList[:index], m.endpointList[index+1:]...) + m.permutationTable = append(m.permutationTable[:index], m.permutationTable[index+1:]...) + + m.fillLookupTable() +} + +// Get endpoint by specified request header value +// Todo: Overload scenario: Get should return an index rather than an instance, +// so that we can iterate to the next endpoint in case it is overloaded (e.g. via another +// helper function that resolves the endpoint via the index) +func (m *Maglev) Get(headerValue string) (string, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.endpointList) == 0 { + return "", errors.New("maglev-get-endpoint-no-endpoints") + } + key := m.hashKey(headerValue) + return m.endpointList[m.lookupTable[key%lookupTableSize]], nil +} + +func (m *Maglev) hashKey(headerValue string) uint64 { + return m.calculateFNVHash64(headerValue) +} + +// generatePermutation creates a permutationTable of the lookup table for each endpoint +func (m *Maglev) generatePermutation(endpoint string) { + pos := sort.SearchStrings(m.endpointList, endpoint) + if pos == len(m.endpointList) { + m.logger.Debug("maglev-permutation-no-endpoints") + return + } + + endpointHash := m.calculateFNVHash64(endpoint) + offset := endpointHash % lookupTableSize + skip := (endpointHash % (lookupTableSize - 1)) + 1 + + permutationForEndpoint := make([]uint64, lookupTableSize) + for j := uint64(0); j < lookupTableSize; j++ { + permutationForEndpoint[j] = (offset + j*skip) % lookupTableSize + } + + // insert permutationForEndpoint at position pos, shifting the rest to the right + m.permutationTable = append(m.permutationTable, nil) + copy(m.permutationTable[pos+1:], m.permutationTable[pos:]) + m.permutationTable[pos] = permutationForEndpoint + +} + +func (m *Maglev) fillLookupTable() { + if len(m.endpointList) == 0 { + return + } + + numberOfEndpoints := len(m.endpointList) + next := make([]int, numberOfEndpoints) + entry := make([]int, lookupTableSize) + for j := range entry { + entry[j] = -1 + } + + for n := uint64(0); n <= lookupTableSize; { + for i := 0; i < numberOfEndpoints; i++ { + candidate := m.findNextAvailableSlot(i, next, entry) + entry[candidate] = int(i) + next[i] = next[i] + 1 + n++ + + if n == lookupTableSize { + m.lookupTable = entry + return + } + } + } +} + +func (m *Maglev) findNextAvailableSlot(i int, next []int, entry []int) uint64 { + candidate := 
m.permutationTable[i][next[i]] + for entry[candidate] >= 0 { + next[i]++ + if next[i] >= len(m.permutationTable[i]) { + // This should not happen in a properly functioning Maglev algorithm, + // but we add this safety check to prevent panic + m.logger.Error("maglev-permutation-table-exhausted", + slog.Int("endpoint-index", i), + slog.Int("next-value", next[i]), + slog.Int("table-size", len(m.permutationTable[i]))) + // Reset to beginning of permutation table as fallback + next[i] = 0 + } + candidate = m.permutationTable[i][next[i]] + } + return candidate +} + +// Getters for unit tests +func (m *Maglev) GetEndpointList() []string { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]string(nil), m.endpointList...) +} + +func (m *Maglev) GetLookupTable() []int { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]int(nil), m.lookupTable...) +} + +func (m *Maglev) GetPermutationTable() [][]uint64 { + m.lock.RLock() + defer m.lock.RUnlock() + copied := make([][]uint64, len(m.permutationTable)) + for i, v := range m.permutationTable { + copied[i] = append([]uint64(nil), v...) + } + return copied +} + +func (m *Maglev) GetLookupTableSize() uint64 { + return lookupTableSize +} + +// TODO: Remove in final version +func (m *Maglev) PrintLookupTable() string { + strArr := make([]string, len(m.lookupTable)) + for i, value := range m.lookupTable { + strArr[i] = strconv.Itoa(value) + } + return fmt.Sprintf("[%s]", strings.Join(strArr, ", ")) +} + +// calculateFNVHash64 computes a hash using the non-cryptographic FNV hash algorithm. +func (m *Maglev) calculateFNVHash64(key string) uint64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + return h.Sum64() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go new file mode 100644 index 000000000..ae8af9d07 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -0,0 +1,257 @@ +package route_test + +import ( + "fmt" + "strconv" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" +) + +var _ = Describe("Maglev", func() { + var ( + logger *test_util.TestLogger + maglev *route.Maglev + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + + maglev = route.NewMaglev(logger.Logger) + }) + + Describe("NewMaglev", func() { + It("should create a new Maglev instance", func() { + Expect(maglev).NotTo(BeNil()) + }) + }) + + Describe("Add", func() { + Context("when adding a new backend", func() { + It("should add the backend successfully", func() { + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding a backend twice", func() { + It("should skip adding subsequent adds", func() { + maglev.Add("backend1") + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding multiple backends", func() { + It("should make all backends reachable", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + Expect(maglev.GetEndpointList()).To(HaveLen(3)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(len(maglev.GetEndpointList()))) + for i := range len(maglev.GetEndpointList()) { + Expect(maglev.GetPermutationTable()[i]).To(HaveLen(int(maglev.GetLookupTableSize()))) + } + + backends := make(map[string]bool) + for i := 0; i < 1000; i++ { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + backends[result] = true + } + + Expect(backends["backend1"]).To(BeTrue()) + Expect(backends["backend2"]).To(BeTrue()) + Expect(backends["backend3"]).To(BeTrue()) + }) + }) + }) + + Describe("Remove", func() { + Context("when removing an existing backend", func() { + It("should remove the backend successfully", func() { + maglev.Add("backend1") + maglev.Add("backend2") + + maglev.Remove("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + }) + }) + + Context("when removing a non-existent backend", func() { + It("should handle gracefully without error", func() { + maglev.Add("backend1") + + Expect(func() { maglev.Remove("non-existent") }).NotTo(Panic()) + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + }) + }) + }) + + Describe("Get", func() { + Context("when no backends were added", func() { + It("should return an 
error", func() { + _, err := maglev.Get("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[string]int) + var result1 string + var err error + for _ = range 100 { + result1, err = maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[result1]++ + } + + Expect(counter[result1]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + distribution := make(map[string]int) + for i := range 1000 { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + distribution[result]++ + } + + Expect(distribution["backend1"]).To(BeNumerically(">", 0)) + Expect(distribution["backend2"]).To(BeNumerically(">", 0)) + Expect(distribution["backend3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Remove("backend1") + }) + + It("should not return the removed backend", func() { + for _ = range 100 { + endpoint, err := maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + Expect(endpoint).To(Equal("backend2")) + } + }) + }) + }) + + Describe("Consistency", func() { + // We test that at most half the keys are reassigned to new backends, when one backend is added. + // This ensures a minimal level of consistency. + It("should minimize disruption when adding backends", func() { + for i := range 10 { + maglev.Add(fmt.Sprintf("backend%d", i+1)) + } + keys := make([]string, 1000) + for i := range keys { + keys[i] = fmt.Sprintf("key%d", i+1) + } + + initialMappings := make(map[string]string) + + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + initialMappings[key] = backend + } + + maglev.Add("newbackend") + + changedMappings := 0 + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + if initialMappings[key] != backend { + changedMappings++ + } + } + + Expect(changedMappings).To(BeNumerically("<=", len(keys)/2)) + }) + }) + + Describe("Concurrency", func() { + It("should handle concurrent reads safely", func() { + maglev.Add("backend1") + + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + _, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + }) + It("should handle concurrent endpoint registrations safely", func() { + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + Expect(func() { maglev.Add("endpoint" + strconv.Itoa(j)) }).NotTo(Panic()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + Expect(len(maglev.GetEndpointList())).To(Equal(100)) + }) + + }) +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index f089fc15b..b9f491798 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -74,6 +74,21 @@ type ProxyRoundTripper interface { CancelRequest(*http.Request) } +type HashRoutingProperties 
struct { + Header string + BalanceFactor float64 +} + +func (hrp *HashRoutingProperties) Equal(hrp2 *HashRoutingProperties) bool { + if hrp == nil && hrp2 == nil { + return true + } + if hrp == nil || hrp2 == nil { + return false + } + return hrp.Header == hrp2.Header && hrp.BalanceFactor == hrp2.BalanceFactor +} + type Endpoint struct { ApplicationId string AvailabilityZone string @@ -186,6 +201,8 @@ type EndpointPool struct { logger *slog.Logger updatedAt time.Time LoadBalancingAlgorithm string + HashRoutingProperties *HashRoutingProperties + HashLookupTable *Maglev } type EndpointOpts struct { @@ -248,10 +265,12 @@ type PoolOpts struct { MaxConnsPerBackend int64 Logger *slog.Logger LoadBalancingAlgorithm string + HashHeader string + HashBalanceFactor float64 } func NewPool(opts *PoolOpts) *EndpointPool { - return &EndpointPool{ + pool := &EndpointPool{ endpoints: make([]*endpointElem, 0, 1), index: make(map[string]*endpointElem), retryAfterFailure: opts.RetryAfterFailure, @@ -264,6 +283,14 @@ func NewPool(opts *PoolOpts) *EndpointPool { updatedAt: time.Now(), LoadBalancingAlgorithm: opts.LoadBalancingAlgorithm, } + if pool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + pool.HashLookupTable = NewMaglev(opts.Logger) + pool.HashRoutingProperties = &HashRoutingProperties{ + Header: opts.HashHeader, + BalanceFactor: opts.HashBalanceFactor, + } + } + return pool } func PoolsMatch(p1, p2 *EndpointPool) bool { @@ -320,7 +347,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { // new one. e.Lock() defer e.Unlock() - oldEndpoint := e.endpoint e.endpoint = endpoint @@ -336,6 +362,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) e.updated = time.Now() + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointUpdated @@ -348,7 +377,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { updated: time.Now(), maxConnsPerBackend: p.maxConnsPerBackend, } - p.endpoints = append(p.endpoints, e) p.index[endpoint.CanonicalAddr()] = e @@ -356,6 +384,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointAdded @@ -433,6 +464,11 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { delete(p.index, e.endpoint.CanonicalAddr()) delete(p.index, e.endpoint.PrivateInstanceId) p.Update() + + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Remove(e.endpoint.PrivateInstanceId) + } + } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { @@ -443,6 +479,9 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic case config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_HB: + logger.Debug("endpoint-iterator-with-hash-based-lb-algo") + return NewHashBased(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), @@ -452,6 +491,23 @@ 
func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic } } +func (p *EndpointPool) FallBackToDefaultLoadBalancing(defaultLBAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { + logger.Info("hash-based-routing-header-not-found", + slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + + switch defaultLBAlgo { + case config.LOAD_BALANCE_LC: + logger.Debug("endpoint-iterator-with-least-connection-lb-algo") + return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_RR: + logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + } + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) +} + func (p *EndpointPool) NumEndpoints() int { p.Lock() defer p.Unlock() @@ -561,12 +617,13 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { - if len(endpoint.LoadBalancingAlgorithm) > 0 && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { + if endpoint.LoadBalancingAlgorithm != "" && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) + p.prepareHashBasedRouting(endpoint) } else { p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), @@ -575,6 +632,20 @@ func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { } } +func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { + if p.LoadBalancingAlgorithm != config.LOAD_BALANCE_HB { + return + } + if p.HashLookupTable == nil { + p.HashLookupTable = NewMaglev(p.logger) + } + p.HashRoutingProperties = &HashRoutingProperties{ + Header: endpoint.HashHeaderName, + BalanceFactor: endpoint.HashBalanceFactor, + } + +} + func (e *endpointElem) failed() { t := time.Now() e.failedAt = &t diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 31da6c8d7..7709a1d8b 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -428,6 +428,46 @@ var _ = Describe("EndpointPool", func() { Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) }) }) + + Context("When switching to hash-based routing", func() { + It("will create the maglev table and add the endpoint", func() { + pool := route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + }) + + endpointOpts := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + } + + initalEndpoint := route.NewEndpoint(&endpointOpts) + + pool.Put(initalEndpoint) + 
Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) + + endpointOptsHash := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashBalanceFactor: 1.25, + HashHeaderName: "X-Tenant", + } + + hashEndpoint := route.NewEndpoint(&endpointOptsHash) + + pool.Put(hashEndpoint) + Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_HB)) + Expect(pool.HashLookupTable).ToNot(BeNil()) + Expect(pool.HashLookupTable.GetEndpointList()).To(HaveLen(1)) + Expect(pool.HashLookupTable.GetEndpointList()[0]).To(Equal(hashEndpoint.PrivateInstanceId)) + }) + + }) + }) Context("RouteServiceUrl", func() { From 010c4e0b762e11b347ae2149f6495527457f669f Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Wed, 29 Oct 2025 14:32:37 +0100 Subject: [PATCH 2/4] Add LICENSE information for maglev.go --- src/code.cloudfoundry.org/gorouter/route/maglev.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go index 9b70aaa09..6085c7d0f 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -1,5 +1,15 @@ package route +/****************************************************************************** + * Original github.com/kkdai/maglev/maglev.go + * + * Copyright (c) 2019 Evan Lin (github.com/kkdai) + * + * This program and the accompanying materials are made available under + * the terms of the Apache License, Version 2.0 which is available at + * http://www.apache.org/licenses/LICENSE-2.0. + ******************************************************************************/ + import ( "errors" "fmt" From 2d02a51b6e76b929483c5673ca47323c9718852e Mon Sep 17 00:00:00 2001 From: Tamara Boehm Date: Wed, 22 Oct 2025 08:54:13 +0200 Subject: [PATCH 3/4] Implement overflow traffic --- .../gorouter/route/hash_based.go | 128 +++++++-- .../gorouter/route/hash_based_test.go | 261 +++++++++++++++++- .../gorouter/route/maglev.go | 58 +++- .../gorouter/route/maglev_test.go | 128 +++++++-- .../gorouter/route/pool.go | 17 +- 5 files changed, 525 insertions(+), 67 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index 39d551eb8..ff98f072a 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -15,9 +15,10 @@ import ( type HashBased struct { lock *sync.Mutex - logger *slog.Logger - pool *EndpointPool - lastEndpoint *Endpoint + logger *slog.Logger + pool *EndpointPool + lastEndpoint *Endpoint + lastLookupTableIndex uint64 stickyEndpointID string mustBeSticky bool @@ -47,14 +48,14 @@ func (h *HashBased) Next(attempt int) *Endpoint { h.lock.Lock() defer h.lock.Unlock() - e := h.findEndpointIfStickySession() - if e == nil && h.mustBeSticky { + endpoint := h.findEndpointIfStickySession() + if endpoint == nil && h.mustBeSticky { return nil } - if e != nil { - h.lastEndpoint = e - return e + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint } if h.pool.HashLookupTable == nil { @@ -62,30 +63,92 @@ func (h *HashBased) Next(attempt int) *Endpoint { return nil } - id, err := h.pool.HashLookupTable.Get(h.HeaderValue) + if attempt == 0 || h.lastLookupTableIndex == 0 { + initialLookupTableIndex, _, err := h.pool.HashLookupTable.GetInstanceForHashHeader(h.HeaderValue) - if err != nil { - 
h.logger.Error( - "hash-based-routing-failed", - slog.String("host", h.pool.host), - log.ErrAttr(err), - ) - return nil + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + endpoint = h.findEndpoint(initialLookupTableIndex, attempt) + } else { + // On retries, start looking from the next index in the lookup table + nextIndex := (h.lastLookupTableIndex + 1) % h.pool.HashLookupTable.GetLookupTableSize() + endpoint = h.findEndpoint(nextIndex, attempt) } - h.logger.Debug( - "hash-based-routing", - slog.String("hash header value", h.HeaderValue), - slog.String("endpoint-id", id), - ) + if endpoint != nil { + h.lastEndpoint = endpoint + } + return endpoint +} - endpointElem := h.pool.findById(id) - if endpointElem == nil { - h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) +func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { + maxIterations := len(h.pool.endpoints) + if maxIterations == 0 { return nil } - return endpointElem.endpoint + // Ensure we don't exceed the lookup table size + lookupTableSize := h.pool.HashLookupTable.GetLookupTableSize() + + // Normalize index + currentIndex := index % lookupTableSize + // Keep track of endpoints already visited, to avoid visiting them twice + visitedEndpoints := make(map[string]bool) + + numberOfEndpoints := len(h.pool.HashLookupTable.GetEndpointList()) + + lastEndpointPrivateId := "" + if attempt > 0 && h.lastEndpoint != nil { + lastEndpointPrivateId = h.lastEndpoint.PrivateInstanceId + } + + // abort when we have visited all available endpoints unsuccessfully + for len(visitedEndpoints) < numberOfEndpoints { + id := h.pool.HashLookupTable.GetEndpointId(currentIndex) + + if visitedEndpoints[id] || id == lastEndpointPrivateId { + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + visitedEndpoints[id] = true + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + currentIndex = (currentIndex + 1) % lookupTableSize + continue + } + + lastEndpointPrivateId = id + + e := endpointElem.endpoint + if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isOverloaded(e) { + h.lastLookupTableIndex = currentIndex + return e + } + + currentIndex = (currentIndex + 1) % lookupTableSize + } + // All endpoints checked and overloaded or not found + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("All endpoints are overloaded"))) + return nil +} + +func (h *HashBased) isOverloaded(e *Endpoint) bool { + avgLoad := h.CalculateAverageLoad() + balanceFactor := h.pool.HashRoutingProperties.BalanceFactor + if float64(e.Stats.NumberConnections.Count())/avgLoad > balanceFactor { + h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", e.PrivateInstanceId), slog.Int64("endpoint-connections", e.Stats.NumberConnections.Count()), slog.Float64("average-load", avgLoad)) + return true + } + return false } // findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. 
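The overload check introduced here boils down to a load ratio: an endpoint counts as overloaded once its in-flight connection count exceeds the pool-wide average multiplied by the configured balance factor, in which case the iterator walks on to the next entry in the Maglev lookup table. A minimal, standalone sketch of that arithmetic (the connection counts and the 1.25 factor are only illustrative):

```go
package main

import "fmt"

// isOverloaded mirrors the balance-factor test from hash_based.go: an
// endpoint is overloaded when its in-flight requests exceed the average
// load across the pool times the hash_balance factor. The avgLoad == 0
// guard is a simplification for this sketch.
func isOverloaded(connections int64, avgLoad, balanceFactor float64) bool {
	if balanceFactor <= 0 || avgLoad == 0 {
		return false
	}
	return float64(connections)/avgLoad > balanceFactor
}

func main() {
	// Three endpoints with 6, 2 and 1 in-flight requests: average load is 3.
	avgLoad := float64(6+2+1) / 3

	// With hash_balance: 1.25, the first endpoint has 6/3 = 2.0 > 1.25 and
	// is skipped; the others stay eligible.
	fmt.Println(isOverloaded(6, avgLoad, 1.25)) // true
	fmt.Println(isOverloaded(2, avgLoad, 1.25)) // false
	fmt.Println(isOverloaded(1, avgLoad, 1.25)) // false
}
```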
@@ -139,3 +202,18 @@ func (h *HashBased) PreRequest(e *Endpoint) { func (h *HashBased) PostRequest(e *Endpoint) { e.Stats.NumberConnections.Decrement() } + +func (h *HashBased) CalculateAverageLoad() float64 { + if len(h.pool.endpoints) == 0 { + return 0 + } + + var currentInFlightRequestCount int64 + for _, endpointElem := range h.pool.endpoints { + endpointElem.RLock() + currentInFlightRequestCount += endpointElem.endpoint.Stats.NumberConnections.Count() + endpointElem.RUnlock() + } + + return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints)) +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index 1caaed19c..26df2c5b4 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -1,10 +1,12 @@ package route_test import ( - "code.cloudfoundry.org/gorouter/config" _ "errors" + "hash/fnv" "time" + "code.cloudfoundry.org/gorouter/config" + "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" @@ -24,8 +26,9 @@ var _ = Describe("HashBased", func() { RetryAfterFailure: 2 * time.Minute, Host: "", ContextPath: "", - MaxConnsPerBackend: 0, + MaxConnsPerBackend: 500, LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashHeader: "tenant-id", }) }) @@ -60,14 +63,110 @@ var _ = Describe("HashBased", func() { Expect(second).NotTo(BeNil()) Expect(first).To(Equal(second)) }) + }) + + Context("when endpoint overloaded", func() { + var ( + endpoints []*route.Endpoint + e1 *route.Endpoint + e2 *route.Endpoint + e3 *route.Endpoint + ) + It("It returns the next endpoint for the same header value when balancer factor set", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + Expect(iter.Next(0)).To(Equal(first)) + for i := 0; i < 6; i++ { + iter.PreRequest(first) + } + second := iter.Next(0) + Expect(second).NotTo(Equal(first)) + }) + It("It returns the same overloaded endpoint for the same header value when balancer factor not set", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 0, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + iter := route.NewHashBased(logger.Logger, pool, 
"", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + Expect(iter.Next(0)).To(Equal(first)) + for i := 0; i < 6; i++ { + iter.PreRequest(first) + } + second := iter.Next(0) + Expect(second).To(Equal(first)) + }) + + }) + + Context("with retries", func() { + var ( + endpoints []*route.Endpoint + e1 *route.Endpoint + e2 *route.Endpoint + e3 *route.Endpoint + e4 *route.Endpoint + MaglevLookupTable = []int{2, 2, 1, 0, 1, 0, 0, 0, 2, 0, 1, 3, 1, 0, 1, 0, 3, 0, 3, 0, 0, 0, 1, 0, 1, 2, 2, 0, 3, 2, 3, 0, 1, 0, 1, 0, 3, 3, 2, 0, 3, 1, 2, 0, 3, 0, 1, 0, 2, 3, 2, 3, 2, 0, 1, 2, 1, 0, 3, 2, 2, 1, 1, 2, 1, 3, 1, 2, 2, 0, 3, 2, 3, 1, 1, 3, 1, 3, 1, 0, 2, 1, 3, 1, 2, 2, 1, 3, 2, 2, 2, 3, 3, 1, 3, 0, 3, 2, 3, 3, 0} + ) + It("It returns next endpoint from maglev lookup table", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID4"}) + + endpoints = []*route.Endpoint{e1, e2, e3, e4} + endpointIDList := make([]string, 0, 4) + for _, e := range endpoints { + pool.Put(e) + endpointIDList = append(endpointIDList, e.PrivateInstanceId) + } + maglevMock := NewMockHashLookupTable(MaglevLookupTable, endpointIDList) + pool.HashLookupTable = maglevMock + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + // The returned endpoint has always ID3 according to the Maglev lookup table + first := iter.Next(0) + Expect(first).To(Equal(e4)) + second := iter.Next(1) + Expect(second).To(Equal(e1)) + third := iter.Next(2) + Expect(third).To(Equal(e4)) + }) + It("It returns the next not overloaded endpoint for the second attempt", func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + e4 = route.NewEndpoint(&route.EndpointOpts{Host: "4.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) - It("It selects another instance for other hash header value", func() { + endpoints = []*route.Endpoint{e1, e2, e3, e4} + for _, e := range endpoints { + pool.Put(e) + } iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") - iter.(*route.HashBased).HeaderValue = "example.com" - Expect(iter.Next(0)).NotTo(BeNil()) - Expect(iter.Next(0)).To(Equal(endpoints[1])) - Expect(iter.Next(0)).To(Equal(endpoints[1])) - Expect(iter.Next(0)).To(Equal(endpoints[1])) + iter.(*route.HashBased).HeaderValue = "tenant-1" + firstAttemptResult := iter.Next(0) + 
Expect(iter.Next(0)).To(Equal(firstAttemptResult)) + for i := 0; i < 6; i++ { + // Simulate requests to overload the endpoints + iter.PreRequest(e1) + iter.PreRequest(e2) + } + secondAttemptResult := iter.Next(1) + Expect(secondAttemptResult).NotTo(Equal(firstAttemptResult)) + Expect(secondAttemptResult).NotTo(Equal(e1)) + Expect(secondAttemptResult).NotTo(Equal(e2)) }) }) @@ -102,6 +201,13 @@ var _ = Describe("HashBased", func() { iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") Expect(iter.Next(0)).To(BeNil()) }) + It("returns nil when sticky endpoint is overloaded and mustBeSticky is true", func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + for i := 0; i < 1000; i++ { + iter.PreRequest(endpoints[0]) + } + Expect(iter.Next(0)).To(BeNil()) + }) }) Context("when mustBeSticky is false", func() { @@ -151,5 +257,144 @@ var _ = Describe("HashBased", func() { Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) }) }) + Describe("CalculateAverageLoad", func() { + var iter *route.HashBased + var endpoints []*route.Endpoint + + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + }) + + Context("when there are no endpoints", func() { + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when all endpoints have zero connections", func() { + BeforeEach(func() { + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"})) + pool.Put(route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"})) + }) + It("returns 0", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(0))) + }) + }) + + Context("when endpoints have varying connection counts", func() { + var e1, e2, e3 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 2; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 4; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 6; i++ { + iter.PreRequest(e3) + } + }) + It("returns the correct average", func() { + // in general 12 in flight requests + Expect(iter.CalculateAverageLoad()).To(Equal(float64(4))) + }) + }) + + Context("when one endpoint has many connections", func() { + var e1, e2 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + for i := 0; i < 10; i++ { + iter.PreRequest(e1) + } + }) + It("returns the correct average", func() { + Expect(iter.CalculateAverageLoad()).To(Equal(float64(5))) + }) + }) + }) }) + +// MockHashLookupTable provides a simple mock implementation of MaglevLookup interface for testing. 
+type MockHashLookupTable struct {
+	lookupTable  []int
+	endpointList []string
+}
+
+// NewMockHashLookupTable creates a new mock lookup table with predefined mappings
+func NewMockHashLookupTable(lookupTable []int, endpointList []string) *MockHashLookupTable {
+
+	return &MockHashLookupTable{
+		lookupTable: lookupTable,
+		endpointList: endpointList,
+	}
+}
+
+func (m *MockHashLookupTable) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) {
+	if len(m.endpointList) == 0 {
+		return 0, "", nil
+	}
+	h := fnv.New64a()
+	_, _ = h.Write([]byte(hashHeaderValue))
+	key := h.Sum64()
+	index := key % m.GetLookupTableSize()
+	return index, m.endpointList[m.lookupTable[index]], nil
+
+}
+
+func (m *MockHashLookupTable) GetLookupTableSize() uint64 {
+	return uint64(len(m.lookupTable))
+}
+
+func (m *MockHashLookupTable) GetEndpointId(lookupTableIndex uint64) string {
+	return m.endpointList[m.lookupTable[lookupTableIndex]]
+}
+
+func (m *MockHashLookupTable) Add(endpoint string) {
+	// Check if endpoint already exists
+	for _, existing := range m.endpointList {
+		if existing == endpoint {
+			return
+		}
+	}
+	m.endpointList = append(m.endpointList, endpoint)
+}
+
+func (m *MockHashLookupTable) Remove(endpoint string) {
+	for i, existing := range m.endpointList {
+		if existing == endpoint {
+			m.endpointList = append(m.endpointList[:i], m.endpointList[i+1:]...)
+			return
+		}
+	}
+}
+
+func (m *MockHashLookupTable) GetEndpointList() []string {
+	return append([]string(nil), m.endpointList...) // return a copy
+}
+
+// GetLookupTable returns a copy of the current lookup table (for testing)
+func (m *MockHashLookupTable) GetLookupTable() []int {
+	return append([]int(nil), m.lookupTable...) // return a copy
+}
+
+// GetPermutationTable returns a copy of the current permutation table (for testing)
+func (m *MockHashLookupTable) GetPermutationTable() [][]uint64 {
+	return nil // not implemented in mock
+}
+
+// Compile-time check to ensure MockHashLookupTable implements MaglevLookup interface
+var _ route.MaglevLookup = (*MockHashLookupTable)(nil)
diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go
index 6085c7d0f..23d46210f 100644
--- a/src/code.cloudfoundry.org/gorouter/route/maglev.go
+++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go
@@ -27,6 +27,35 @@ const (
 	lookupTableSize uint64 = 1801
 )
 
+// MaglevLookup defines the interface for consistent hashing lookup table implementations.
+// This interface allows for different implementations of the Maglev algorithm and
+// enables easy testing with mock implementations.
+type MaglevLookup interface { + // Add a new endpoint to the lookup table + Add(endpoint string) + + // Remove an endpoint from the lookup table + Remove(endpoint string) + + // GetInstanceForHashHeader endpoint by specified request header value + GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) + + // GetEndpointId returns the endpoint ID by specified lookup table index + GetEndpointId(lookupTableIndex uint64) string + + // GetLookupTableSize returns the size of the lookup table + GetLookupTableSize() uint64 + + // GetEndpointList returns a copy of the current endpoint list (for testing) + GetEndpointList() []string + + // GetLookupTable returns a copy of the current lookup table (for testing) + GetLookupTable() []int + + // GetPermutationTable returns a copy of the current permutation table (for testing) + GetPermutationTable() [][]uint64 +} + // Maglev implementation of consistent hashing algorithm described in "Maglev: A Fast and Reliable Software Network // Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf) type Maglev struct { @@ -89,23 +118,29 @@ func (m *Maglev) Remove(endpoint string) { m.fillLookupTable() } -// Get endpoint by specified request header value -// Todo: Overload scenario: Get should return an index rather than an instance, -// so that we can iterate to the next endpoint in case it is overloaded (e.g. via another -// helper function that resolves the endpoint via the index) -func (m *Maglev) Get(headerValue string) (string, error) { +func (m *Maglev) hashKey(headerValue string) uint64 { + return m.calculateFNVHash64(headerValue) +} + +// GetInstanceForHashHeader lookup table index and private instance ID for the specified request header value +func (m *Maglev) GetInstanceForHashHeader(hashHeaderValue string) (uint64, string, error) { m.lock.RLock() defer m.lock.RUnlock() if len(m.endpointList) == 0 { - return "", errors.New("maglev-get-endpoint-no-endpoints") + return 0, "", errors.New("no endpoint available") } - key := m.hashKey(headerValue) - return m.endpointList[m.lookupTable[key%lookupTableSize]], nil + key := m.hashKey(hashHeaderValue) + index := key % lookupTableSize + return index, m.endpointList[m.lookupTable[key%lookupTableSize]], nil } -func (m *Maglev) hashKey(headerValue string) uint64 { - return m.calculateFNVHash64(headerValue) +// GetEndpointId by specified lookup table index +func (m *Maglev) GetEndpointId(lookupTableIndex uint64) string { + m.lock.RLock() + defer m.lock.RUnlock() + + return m.endpointList[m.lookupTable[lookupTableIndex]] } // generatePermutation creates a permutationTable of the lookup table for each endpoint @@ -220,3 +255,6 @@ func (m *Maglev) calculateFNVHash64(key string) uint64 { _, _ = h.Write([]byte(key)) return h.Sum64() } + +// Compile-time check to ensure Maglev implements MaglevLookup interface +var _ MaglevLookup = (*Maglev)(nil) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go index ae8af9d07..b72d13b3d 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -39,9 +39,9 @@ var _ = Describe("Maglev", func() { Expect(maglev.GetPermutationTable()).To(HaveLen(1)) Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) - result, err := maglev.Get("test-key") + _, backend, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) - 
Expect(result).To(Equal("backend1")) + Expect(backend).To(Equal("backend1")) }) }) @@ -55,9 +55,9 @@ var _ = Describe("Maglev", func() { Expect(maglev.GetPermutationTable()).To(HaveLen(1)) Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) - result, err := maglev.Get("test-key") + _, backend, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) - Expect(result).To(Equal("backend1")) + Expect(backend).To(Equal("backend1")) }) }) @@ -76,9 +76,9 @@ var _ = Describe("Maglev", func() { backends := make(map[string]bool) for i := 0; i < 1000; i++ { - result, err := maglev.Get(string(rune(i))) + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) Expect(err).NotTo(HaveOccurred()) - backends[result] = true + backends[backend] = true } Expect(backends["backend1"]).To(BeTrue()) @@ -121,7 +121,7 @@ var _ = Describe("Maglev", func() { Describe("Get", func() { Context("when no backends were added", func() { It("should return an error", func() { - _, err := maglev.Get("test-key") + _, _, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).To(HaveOccurred()) }) }) @@ -134,15 +134,15 @@ var _ = Describe("Maglev", func() { It("should return consistent results for the same key", func() { var counter = make(map[string]int) - var result1 string + var result string var err error - for _ = range 100 { - result1, err = maglev.Get("consistent-key") + for range 100 { + _, result, err = maglev.GetInstanceForHashHeader("consistent-key") Expect(err).NotTo(HaveOccurred()) - counter[result1]++ + counter[result]++ } - Expect(counter[result1]).To(Equal(100)) + Expect(counter[result]).To(Equal(100)) }) It("should distribute keys across backends", func() { @@ -152,9 +152,9 @@ var _ = Describe("Maglev", func() { distribution := make(map[string]int) for i := range 1000 { - result, err := maglev.Get(string(rune(i))) + _, backend, err := maglev.GetInstanceForHashHeader(string(rune(i))) Expect(err).NotTo(HaveOccurred()) - distribution[result]++ + distribution[backend]++ } Expect(distribution["backend1"]).To(BeNumerically(">", 0)) @@ -171,10 +171,98 @@ var _ = Describe("Maglev", func() { }) It("should not return the removed backend", func() { - for _ = range 100 { - endpoint, err := maglev.Get("consistent-key") + for range 100 { + _, backend, err := maglev.GetInstanceForHashHeader("consistent-key") Expect(err).NotTo(HaveOccurred()) - Expect(endpoint).To(Equal("backend2")) + Expect(backend).To(Equal("backend2")) + } + }) + }) + }) + + Describe("GetInstanceForHashHeader", func() { + Context("when no backends were added", func() { + It("should return an error", func() { + _, _, err := maglev.GetInstanceForHashHeader("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[uint64]int) + var lookupTableIndex uint64 + var err error + for range 100 { + lookupTableIndex, _, err = maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[lookupTableIndex]++ + } + + Expect(counter[lookupTableIndex]).To(Equal(100)) + }) + }) + }) + + Describe("GetEndpointId", func() { + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + }) + + It("should return consistent results for the same key", func() { + var counter = 
make(map[string]int) + var endpointID string + for range 100 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader("consistent-key") + Expect(err).NotTo(HaveOccurred()) + endpointID = maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + counter[endpointID]++ + } + + Expect(counter[endpointID]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Add("app_instance_3") + + distribution := make(map[string]int) + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(err).NotTo(HaveOccurred()) + distribution[endpointID]++ + } + + Expect(distribution["app_instance_1"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_2"]).To(BeNumerically(">", 0)) + Expect(distribution["app_instance_3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("app_instance_1") + maglev.Add("app_instance_2") + maglev.Remove("app_instance_1") + }) + + It("should not return the removed backend", func() { + for i := range 1000 { + lookupTableIndex, _, err := maglev.GetInstanceForHashHeader(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + endpointID := maglev.GetEndpointId(lookupTableIndex) + Expect(endpointID).To(Equal("app_instance_2")) } }) }) @@ -195,7 +283,7 @@ var _ = Describe("Maglev", func() { initialMappings := make(map[string]string) for _, key := range keys { - backend, err := maglev.Get(key) + _, backend, err := maglev.GetInstanceForHashHeader(key) Expect(err).NotTo(HaveOccurred()) initialMappings[key] = backend } @@ -204,7 +292,7 @@ var _ = Describe("Maglev", func() { changedMappings := 0 for _, key := range keys { - backend, err := maglev.Get(key) + _, backend, err := maglev.GetInstanceForHashHeader(key) Expect(err).NotTo(HaveOccurred()) if initialMappings[key] != backend { changedMappings++ @@ -224,7 +312,7 @@ var _ = Describe("Maglev", func() { go func() { defer GinkgoRecover() for j := 0; j < 100; j++ { - _, err := maglev.Get("test-key") + _, _, err := maglev.GetInstanceForHashHeader("test-key") Expect(err).NotTo(HaveOccurred()) } done <- true diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index b9f491798..8217aefca 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -202,7 +202,7 @@ type EndpointPool struct { updatedAt time.Time LoadBalancingAlgorithm string HashRoutingProperties *HashRoutingProperties - HashLookupTable *Maglev + HashLookupTable MaglevLookup } type EndpointOpts struct { @@ -617,19 +617,24 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. 
func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { - if endpoint.LoadBalancingAlgorithm != "" && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { + if endpoint.LoadBalancingAlgorithm == "" { + return + } + + if endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) - p.prepareHashBasedRouting(endpoint) + } else { p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) } } + p.prepareHashBasedRouting(endpoint) } func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { @@ -639,11 +644,15 @@ func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { if p.HashLookupTable == nil { p.HashLookupTable = NewMaglev(p.logger) } - p.HashRoutingProperties = &HashRoutingProperties{ + + newProps := &HashRoutingProperties{ Header: endpoint.HashHeaderName, BalanceFactor: endpoint.HashBalanceFactor, } + if p.HashRoutingProperties == nil || !p.HashRoutingProperties.Equal(newProps) { + p.HashRoutingProperties = newProps + } } func (e *endpointElem) failed() { From 91f5968b0064b05cc55710e58f6d2403e6231e35 Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Wed, 12 Nov 2025 16:14:35 +0100 Subject: [PATCH 4/4] * Minor improvements and refactoring --- .../gorouter/route/hash_based.go | 55 ++++++++++--- .../gorouter/route/hash_based_test.go | 81 ++++++++++++++++++- 2 files changed, 121 insertions(+), 15 deletions(-) diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go index ff98f072a..db7aaac1a 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -58,6 +58,17 @@ func (h *HashBased) Next(attempt int) *Endpoint { return endpoint } + if len(h.pool.endpoints) == 0 { + h.logger.Warn("hash-based-routing-pool-empty", slog.String("host", h.pool.host)) + return nil + } + + endpoint = h.getSingleEndpoint() + if endpoint != nil { + h.lastEndpoint = endpoint + return endpoint + } + if h.pool.HashLookupTable == nil { h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) return nil @@ -89,11 +100,6 @@ func (h *HashBased) Next(attempt int) *Endpoint { } func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { - maxIterations := len(h.pool.endpoints) - if maxIterations == 0 { - return nil - } - // Ensure we don't exceed the lookup table size lookupTableSize := h.pool.HashLookupTable.GetLookupTableSize() @@ -128,10 +134,9 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) *Endpoint { lastEndpointPrivateId = id - e := endpointElem.endpoint - if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isOverloaded(e) { + if h.pool.HashRoutingProperties.BalanceFactor <= 0 || !h.isImbalancedOrOverloaded(endpointElem) { h.lastLookupTableIndex = currentIndex - return e + return endpointElem.endpoint } currentIndex = (currentIndex + 1) % lookupTableSize @@ -141,11 +146,24 @@ func (h *HashBased) findEndpoint(index uint64, attempt int) 
*Endpoint { return nil } -func (h *HashBased) isOverloaded(e *Endpoint) bool { - avgLoad := h.CalculateAverageLoad() +func (h *HashBased) isImbalancedOrOverloaded(e *endpointElem) bool { + endpoint := e.endpoint + return h.IsImbalancedOrOverloaded(endpoint, e.isOverloaded()) +} + +func (h *HashBased) IsImbalancedOrOverloaded(endpoint *Endpoint, isEndpointOverloaded bool) bool { + avgNumberOfInFlightRequests := h.CalculateAverageLoad() + currentInFlightRequestCount := endpoint.Stats.NumberConnections.Count() balanceFactor := h.pool.HashRoutingProperties.BalanceFactor - if float64(e.Stats.NumberConnections.Count())/avgLoad > balanceFactor { - h.logger.Info("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", e.PrivateInstanceId), slog.Int64("endpoint-connections", e.Stats.NumberConnections.Count()), slog.Float64("average-load", avgLoad)) + + if isEndpointOverloaded { + h.logger.Debug("hash-based-routing-endpoint-overloaded", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", currentInFlightRequestCount)) + return true + } + + // Check if avgNumberOfInFlightRequests is 0 to avoid division by 0 + if avgNumberOfInFlightRequests == 0 || float64(currentInFlightRequestCount)/avgNumberOfInFlightRequests > balanceFactor { + h.logger.Debug("hash-based-routing-endpoint-imbalanced", slog.String("host", h.pool.host), slog.String("endpoint-id", endpoint.PrivateInstanceId), slog.Int64("endpoint-connections", endpoint.Stats.NumberConnections.Count()), slog.Float64("average-load", avgNumberOfInFlightRequests)) return true } return false @@ -203,6 +221,7 @@ func (h *HashBased) PostRequest(e *Endpoint) { e.Stats.NumberConnections.Decrement() } +// CalculateAverageLoad computes the average number of in-flight requests across all endpoints in the pool. func (h *HashBased) CalculateAverageLoad() float64 { if len(h.pool.endpoints) == 0 { return 0 @@ -217,3 +236,15 @@ func (h *HashBased) CalculateAverageLoad() float64 { return float64(currentInFlightRequestCount) / float64(len(h.pool.endpoints)) } + +func (h *HashBased) getSingleEndpoint() *Endpoint { + if len(h.pool.endpoints) == 1 { + e := h.pool.endpoints[0] + if e.isOverloaded() { + return nil + } + + return e.endpoint + } + return nil +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go index 26df2c5b4..c7fcd79a3 100644 --- a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -6,11 +6,11 @@ import ( "time" "code.cloudfoundry.org/gorouter/config" - "code.cloudfoundry.org/gorouter/route" "code.cloudfoundry.org/gorouter/test_util" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" ) var _ = Describe("HashBased", func() { @@ -257,7 +257,83 @@ var _ = Describe("HashBased", func() { Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) }) }) - Describe("CalculateAverageLoad", func() { + Describe("IsImbalancedOrOverloaded", func() { + var iter *route.HashBased + var endpoints []*route.Endpoint + + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "").(*route.HashBased) + }) + + Context("when endpoints have a lot of in-flight requests", func() { + var e1, e2, e3 *route.Endpoint + BeforeEach(func() { + e1 = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID1"}) + e2 = route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID2"}) + e3 = route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", HashBalanceFactor: 1.2, PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + + }) + It("mark the endpoint as overloaded", func() { + for i := 0; i < 500; i++ { + iter.PreRequest(e1) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) + }) + It("do not mark as imbalanced if every endpoint has 499 in-flight requests", func() { + for i := 0; i < 498; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 498; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 498; i++ { + iter.PreRequest(e3) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeFalse()) + }) + + It("mark endpoint as overloaded if every endpoint has 500 in-flight requests", func() { + for i := 0; i < 499; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 499; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 499; i++ { + iter.PreRequest(e3) + } + // in general 500 in flight requests counted by e1 + Expect(iter.IsImbalancedOrOverloaded(e1, true)).To(BeTrue()) + Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-overloaded")) + Expect(iter.IsImbalancedOrOverloaded(e2, true)).To(BeTrue()) + Expect(iter.IsImbalancedOrOverloaded(e3, true)).To(BeTrue()) + + }) + It("mark as imbalanced if it has more in-flight requests", func() { + for i := 0; i < 300; i++ { + iter.PreRequest(e1) + } + for i := 0; i < 200; i++ { + iter.PreRequest(e2) + } + for i := 0; i < 200; i++ { + iter.PreRequest(e3) + } + Expect(iter.IsImbalancedOrOverloaded(e1, false)).To(BeTrue()) + Eventually(logger).Should(gbytes.Say("hash-based-routing-endpoint-imbalanced")) + Expect(iter.IsImbalancedOrOverloaded(e2, false)).To(BeFalse()) + Expect(iter.IsImbalancedOrOverloaded(e3, false)).To(BeFalse()) + }) + }) + }) + + Describe("CalculateAverageNumberOfConnections", func() { var iter *route.HashBased var endpoints []*route.Endpoint @@ -336,7 +412,6 @@ type MockHashLookupTable struct { // NewMockHashLookupTable creates a new mock lookup table with predefined mappings func NewMockHashLookupTable(lookupTable []int, endpointList []string) *MockHashLookupTable { - return &MockHashLookupTable{ lookupTable: lookupTable, endpointList: endpointList,