diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 343140fc1ed..2aef63dc196 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -397,13 +397,18 @@ func leastRequests(upstreams []*Upstream) *Upstream { return nil } var best []*Upstream - var bestReqs int + var bestReqs int = -1 for _, upstream := range upstreams { + if upstream == nil { + continue + } reqs := upstream.NumRequests() if reqs == 0 { return upstream } - if reqs <= bestReqs { + // If bestReqs was just initialized to -1 + // we need to append upstream also + if reqs <= bestReqs || bestReqs == -1 { bestReqs = reqs best = append(best, upstream) } diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index e9939d6d14b..49585da44e9 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -271,3 +271,54 @@ func TestURIHashPolicy(t *testing.T) { t.Error("Expected uri policy policy host to be nil.") } } + +func TestLeastRequests(t *testing.T) { + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].SetHealthy(true) + pool[1].SetHealthy(true) + pool[2].SetHealthy(true) + pool[0].CountRequest(10) + pool[1].CountRequest(20) + pool[2].CountRequest(30) + + result := leastRequests(pool) + + if result == nil { + t.Error("Least request should not return nil") + } + + if result != pool[0] { + t.Error("Least request should return pool[0]") + } +} + +func TestRandomChoicePolicy(t *testing.T) { + pool := testPool() + pool[0].Dial = "localhost:8080" + pool[1].Dial = "localhost:8081" + pool[2].Dial = "localhost:8082" + pool[0].SetHealthy(false) + pool[1].SetHealthy(true) + pool[2].SetHealthy(true) + pool[0].CountRequest(10) + pool[1].CountRequest(20) + pool[2].CountRequest(30) + + request := httptest.NewRequest(http.MethodGet, "/test", nil) + randomChoicePolicy := new(RandomChoiceSelection) + randomChoicePolicy.Choose = 2 + + h := randomChoicePolicy.Select(pool, request) + + if h == nil { + t.Error("RandomChoicePolicy should not return nil") + } + + if h == pool[0] { + t.Error("RandomChoicePolicy should not choose pool[0]") + } + +}