From fa1ab6179711942021c0e21b02886620a66da739 Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Fri, 19 Sep 2025 11:09:30 -0600 Subject: [PATCH 1/5] Query EPP and proxy AI traffic Problem: We need to connect NGINX to the Golang shim that talks to the EndpointPicker, and then pass client traffic to the proper inference workload. Solution: Write an NJS module that will query the local Go server to get the AI endpoint to route traffic to. Then redirect the original client request to an internal location that proxies the traffic to the chosen endpoint. The location building gets a bit complicated especially when using both HTTP matching conditions and inference workloads. It requires 2 layers of internal redirects. I added lots of comments to hopefully clear up how we build these locations to perform all the routing steps. --- cmd/gateway/endpoint_picker.go | 24 +- cmd/gateway/endpoint_picker_test.go | 42 ++- deploy/inference-nginx-plus/deploy.yaml | 1 + .../controller/nginx/config/http/config.go | 56 ++- internal/controller/nginx/config/maps.go | 40 +++ internal/controller/nginx/config/maps_test.go | 65 +++- internal/controller/nginx/config/servers.go | 322 +++++++++++++++--- .../nginx/config/servers_template.go | 12 +- .../controller/nginx/config/servers_test.go | 182 +++++++++- internal/controller/nginx/modules/src/epp.js | 74 ++-- .../controller/nginx/modules/test/epp.test.js | 116 ++++--- .../state/dataplane/configuration.go | 27 +- .../state/dataplane/configuration_test.go | 89 ++++- internal/controller/state/dataplane/types.go | 6 + .../controller/state/graph/backend_refs.go | 90 +++-- .../state/graph/backend_refs_test.go | 8 +- internal/controller/state/graph/graph_test.go | 10 +- internal/controller/state/graph/httproute.go | 21 +- .../controller/state/graph/httproute_test.go | 61 ++++ .../controller/state/graph/route_common.go | 3 + internal/framework/types/types.go | 11 + 21 files changed, 1035 insertions(+), 225 deletions(-) diff --git a/cmd/gateway/endpoint_picker.go b/cmd/gateway/endpoint_picker.go index 7c67a83671..acf9bdfbb6 100644 --- a/cmd/gateway/endpoint_picker.go +++ b/cmd/gateway/endpoint_picker.go @@ -14,16 +14,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" -) -const ( - // defaultPort is the default port for this server to listen on. If collisions become a problem, - // we can make this configurable via the NginxProxy resource. - defaultPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100. - // eppEndpointHostHeader is the HTTP header used to specify the EPP endpoint host, set by the NJS module caller. - eppEndpointHostHeader = "X-EPP-Host" - // eppEndpointPortHeader is the HTTP header used to specify the EPP endpoint port, set by the NJS module caller. - eppEndpointPortHeader = "X-EPP-Port" + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types" ) // extProcClientFactory creates a new ExternalProcessorClient and returns a close function. @@ -32,7 +24,7 @@ type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient // endpointPickerServer starts an HTTP server on the given port with the provided handler. func endpointPickerServer(handler http.Handler) error { server := &http.Server{ - Addr: fmt.Sprintf("127.0.0.1:%d", defaultPort), + Addr: fmt.Sprintf("127.0.0.1:%d", types.GoShimPort), Handler: handler, ReadHeaderTimeout: 10 * time.Second, } @@ -54,13 +46,13 @@ func realExtProcClientFactory() extProcClientFactory { // createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker. func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := r.Header.Get(eppEndpointHostHeader) - port := r.Header.Get(eppEndpointPortHeader) + host := r.Header.Get(types.EPPEndpointHostHeader) + port := r.Header.Get(types.EPPEndpointPortHeader) if host == "" || port == "" { msg := fmt.Sprintf( "missing at least one of required headers: %s and %s", - eppEndpointHostHeader, - eppEndpointPortHeader, + types.EPPEndpointHostHeader, + types.EPPEndpointPortHeader, ) logger.Error(errors.New(msg), "error contacting EndpointPicker") http.Error(w, msg, http.StatusBadRequest) @@ -174,6 +166,10 @@ func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest { } func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) { + if r.ContentLength == 0 { + return nil, errors.New("request body is empty") + } + body, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("error reading request body: %w", err) diff --git a/cmd/gateway/endpoint_picker_test.go b/cmd/gateway/endpoint_picker_test.go index 99808348fc..99fd95aa90 100644 --- a/cmd/gateway/endpoint_picker_test.go +++ b/cmd/gateway/endpoint_picker_test.go @@ -17,6 +17,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/metadata" eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" + + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types" ) type mockExtProcClient struct { @@ -122,8 +124,8 @@ func TestEndpointPickerHandler_Success(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() @@ -165,8 +167,8 @@ func TestEndpointPickerHandler_ImmediateResponse(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") w := httptest.NewRecorder() h.ServeHTTP(w, req) @@ -190,8 +192,8 @@ func TestEndpointPickerHandler_Errors(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) if setHeaders { - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") } w := httptest.NewRecorder() h.ServeHTTP(w, req) @@ -236,7 +238,33 @@ func TestEndpointPickerHandler_Errors(t *testing.T) { } runErrorTestCase(factory, true, http.StatusBadGateway, "error sending headers") - // 4. Error sending body + // 4a. Error building body request (content length 0) + client = &mockProcessClient{ + SendFunc: func(*extprocv3.ProcessingRequest) error { + return nil + }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF }, + } + extProcClient = &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + h := createEndpointPickerHandler(factory, logr.Discard()) + req := httptest.NewRequest(http.MethodPost, "/", nil) // nil body, ContentLength = 0 + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + resp := w.Result() + g.Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError)) + body, _ := io.ReadAll(resp.Body) + g.Expect(string(body)).To(ContainSubstring("request body is empty")) + + // 4b. Error sending body client = &mockProcessClient{ SendFunc: func(req *extprocv3.ProcessingRequest) error { if req.GetRequestBody() != nil { diff --git a/deploy/inference-nginx-plus/deploy.yaml b/deploy/inference-nginx-plus/deploy.yaml index 77ee4da544..025cfeb410 100644 --- a/deploy/inference-nginx-plus/deploy.yaml +++ b/deploy/inference-nginx-plus/deploy.yaml @@ -281,6 +281,7 @@ spec: - --nginx-docker-secret=nginx-plus-registry-secret - --nginx-plus - --usage-report-secret=nplus-license + - --usage-report-enforce-initial-report=true - --metrics-port=9113 - --health-port=8081 - --leader-election-lock-name=nginx-gateway-leader-election diff --git a/internal/controller/nginx/config/http/config.go b/internal/controller/nginx/config/http/config.go index 3a76ab30b4..dedfd04349 100644 --- a/internal/controller/nginx/config/http/config.go +++ b/internal/controller/nginx/config/http/config.go @@ -26,26 +26,58 @@ type Server struct { type LocationType string const ( + // InternalLocationType defines an internal location that is only accessible within NGINX. InternalLocationType LocationType = "internal" + // ExternalLocationType defines a normal external location that is accessible by clients. ExternalLocationType LocationType = "external" + // RedirectLocationType defines an external location that redirects to an internal location + // based on HTTP matching conditions. RedirectLocationType LocationType = "redirect" + // InferenceExternalLocationType defines an external location that is used for calling NJS + // to get the inference workload endpoint and redirects to the internal location that will proxy_pass + // to that endpoint. + InferenceExternalLocationType LocationType = "inference-external" + // InferenceInternalLocationType defines an internal location that is used for calling NJS + // to get the inference workload endpoint and redirects to the internal location that will proxy_pass + // to that endpoint. This is used when an HTTP redirect location is also defined that redirects + // to this internal inference location. + InferenceInternalLocationType LocationType = "inference-internal" ) // Location holds all configuration for an HTTP location. type Location struct { - Path string - ProxyPass string - HTTPMatchKey string + // Return specifies a return directive (e.g., HTTP status or redirect) for this location block. + Return *Return + // ProxySSLVerify controls SSL verification for upstreams when proxying requests. + ProxySSLVerify *ProxySSLVerify + // ProxyPass is the upstream backend (URL or name) to which requests are proxied. + ProxyPass string + // HTTPMatchKey is the key for associating HTTP match rules, used for routing and NJS module logic. + HTTPMatchKey string + // MirrorSplitClientsVariableName is the variable name for split_clients, used in traffic mirroring scenarios. MirrorSplitClientsVariableName string - Type LocationType - ProxySetHeaders []Header - ProxySSLVerify *ProxySSLVerify - Return *Return - ResponseHeaders ResponseHeaders - Rewrites []string - MirrorPaths []string - Includes []shared.Include - GRPC bool + // EPPInternalPath is the internal path for the inference NJS module to redirect to. + EPPInternalPath string + // EPPHost is the host for the EndpointPicker, used for inference routing. + EPPHost string + // Type indicates the type of location (external, internal, redirect, etc). + Type LocationType + // Path is the NGINX location path. + Path string + // ResponseHeaders are custom response headers to be sent. + ResponseHeaders ResponseHeaders + // ProxySetHeaders are headers to set when proxying requests upstream. + ProxySetHeaders []Header + // Rewrites are rewrite rules for modifying request paths. + Rewrites []string + // MirrorPaths are paths to which requests are mirrored. + MirrorPaths []string + // Includes are additional NGINX config snippets or policies to include in this location. + Includes []shared.Include + // EPPPort is the port for the EndpointPicker, used for inference routing. + EPPPort int + // GRPC indicates if this location proxies gRPC traffic. + GRPC bool } // Header defines an HTTP header to be passed to the proxied server. diff --git a/internal/controller/nginx/config/maps.go b/internal/controller/nginx/config/maps.go index 5a5e5ff189..0b7ac00551 100644 --- a/internal/controller/nginx/config/maps.go +++ b/internal/controller/nginx/config/maps.go @@ -1,9 +1,12 @@ package config import ( + "fmt" "strings" gotemplate "text/template" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane" "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" @@ -26,6 +29,8 @@ const ( func executeMaps(conf dataplane.Configuration) []executeResult { maps := buildAddHeaderMaps(append(conf.HTTPServers, conf.SSLServers...)) + maps = append(maps, buildInferenceMaps(conf.BackendGroups)...) + result := executeResult{ dest: httpConfigFile, data: helpers.MustExecuteTemplate(mapsTemplate, maps), @@ -177,3 +182,38 @@ func createAddHeadersMap(name string) shared.Map { Parameters: params, } } + +// buildInferenceMaps creates maps for InferencePool Backends. +func buildInferenceMaps(groups []dataplane.BackendGroup) []shared.Map { + inferenceMaps := make([]shared.Map, 0) + for _, group := range groups { + for _, backend := range group.Backends { + if backend.EndpointPickerConfig != nil { + var defaultResult string + switch backend.EndpointPickerConfig.FailureMode { + case inference.EndpointPickerFailClose: + defaultResult = invalidBackendRef + case inference.EndpointPickerFailOpen: + defaultResult = backend.UpstreamName + } + params := []shared.MapParameter{ + { + Value: "~.+", + Result: "$inference_workload_endpoint", + }, + { + Value: "default", + Result: defaultResult, + }, + } + backendVarName := strings.ReplaceAll(backend.UpstreamName, "-", "_") + inferenceMaps = append(inferenceMaps, shared.Map{ + Source: "$inference_workload_endpoint", + Variable: fmt.Sprintf("$inference_backend_%s", backendVarName), + Parameters: params, + }) + } + } + } + return inferenceMaps +} diff --git a/internal/controller/nginx/config/maps_test.go b/internal/controller/nginx/config/maps_test.go index d133882d7b..736d7808ec 100644 --- a/internal/controller/nginx/config/maps_test.go +++ b/internal/controller/nginx/config/maps_test.go @@ -5,6 +5,7 @@ import ( "testing" . "github.com/onsi/gomega" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane" @@ -59,22 +60,24 @@ func TestExecuteMaps(t *testing.T) { conf := dataplane.Configuration{ HTTPServers: []dataplane.VirtualServer{ - { - PathRules: pathRules, - }, - { - PathRules: pathRules, - }, - { - IsDefault: true, - }, + {PathRules: pathRules}, + {PathRules: pathRules}, + {IsDefault: true}, }, SSLServers: []dataplane.VirtualServer{ + {PathRules: pathRules}, + {IsDefault: true}, + }, + BackendGroups: []dataplane.BackendGroup{ { - PathRules: pathRules, - }, - { - IsDefault: true, + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream1", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailClose, + }, + }, + }, }, }, } @@ -86,6 +89,9 @@ func TestExecuteMaps(t *testing.T) { "map ${http_my_second_add_header} $my_second_add_header_header_var {": 1, "~.* ${http_my_second_add_header},;": 1, "map ${http_my_set_header} $my_set_header_header_var {": 0, + "$inference_workload_endpoint": 2, + "$inference_backend": 1, + "invalid-backend-ref": 1, } mapResult := executeMaps(conf) @@ -385,3 +391,36 @@ func TestCreateStreamMapsWithEmpty(t *testing.T) { g.Expect(maps).To(BeNil()) } + +func TestBuildInferenceMaps(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + group := dataplane.BackendGroup{ + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream1", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailClose, + }, + }, + { + UpstreamName: "upstream2", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailOpen, + }, + }, + { + UpstreamName: "upstream3", + EndpointPickerConfig: nil, + }, + }, + } + + maps := buildInferenceMaps([]dataplane.BackendGroup{group}) + g.Expect(maps).To(HaveLen(2)) + g.Expect(maps[0].Source).To(Equal("$inference_workload_endpoint")) + g.Expect(maps[0].Variable).To(Equal("$inference_backend_upstream1")) + g.Expect(maps[0].Parameters[1].Result).To(Equal("invalid-backend-ref")) + g.Expect(maps[1].Parameters[1].Result).To(Equal("upstream2")) +} diff --git a/internal/controller/nginx/config/servers.go b/internal/controller/nginx/config/servers.go index 88ba4fa8ea..24ce473305 100644 --- a/internal/controller/nginx/config/servers.go +++ b/internal/controller/nginx/config/servers.go @@ -16,7 +16,13 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" ) -var serversTemplate = gotemplate.Must(gotemplate.New("servers").Parse(serversTemplateText)) +var serversTemplate = gotemplate.Must( + gotemplate.New("servers").Funcs(gotemplate.FuncMap{ + "contains": func(str http.LocationType, substr string) bool { + return strings.Contains(string(str), substr) + }, + }).Parse(serversTemplateText), +) const ( // HeaderMatchSeparator is the separator for constructing header-based match for NJS. @@ -252,6 +258,73 @@ func extractMirrorTargetsWithPercentages(pathRules []dataplane.PathRule) map[str return mirrorTargets } +/* +There are several different flows of location blocks, depending on the user configuration. +The following describes them, with basic location examples. + +--------------- +Base case, no HTTP matching conditions or inference extension. + +External location proxies straight to backend. + +location /coffee { + proxy_pass http://backend; +} +--------------- +HTTP matching conditions. + +External location calls httpmatch NJS module. The module determines the HTTP request conditions that exist +and which backend to use, then redirects to the appropriate internal location. +The internal location proxies to the backend. + +location /coffee { + js_content httpmatches.match; // chooses backend1 or backend2, and redirects to appropriate internal location +} +location /_ngf-internal-rule0-route0 { + internal; + proxy_pass http://backend1; +} +location /_ngf-internal-rule1-route0 { + internal; + proxy_pass http://backend2; +} +--------------- +Inference extension, no HTTP matching conditions. + +External location calls inference NJS module. The module gets the AI endpoint to proxy to, +then redirects to the internal inference location that proxies to the backend. + +location /coffee { + set $epp_internal_path /_ngf-internal-rule0-route0-inference; + js_content epp.getEndpoint; // gets endpoint and redirects to /_ngf-internal-rule0-route0-inference +} +location /_ngf-internal-rule0-route0-inference { + internal; + proxy_pass http://$inference-backend; +} +--------------- +Inference extension with HTTP matching conditions. + +External location calls httpmatch NJS module. The module determines the HTTP request conditions that exist +and which backend to use, then redirects to the internal inference location. The internal inference +location calls the inference NJS module to get the AI endpoint to proxy to, then redirects to the +internal location that proxies to the backend. + +location /coffee { + js_content httpmatches.match; // chooses backend and redirects to appropriate internal inference location +} +location /_ngf-internal-rule0-route0-inference { + internal; + + set $epp_internal_path /_ngf-internal-rule0-route0; + js_content epp.getEndpoint; // redirects to /_ngf-internal-rule0-route0 +} +location /_ngf-internal-rule0-route0 { + internal; + proxy_pass http://$inference-backend; +} +*/ + type httpMatchPairs map[string][]routeMatch func createLocations( @@ -270,8 +343,6 @@ func createLocations( mirrorPathToPercentage := extractMirrorTargetsWithPercentages(server.PathRules) for pathRuleIdx, rule := range server.PathRules { - matches := make([]routeMatch, 0, len(rule.MatchRules)) - if rule.Path == rootPath { rootPathExists = true } @@ -281,7 +352,6 @@ func createLocations( } mirrorPercentage := mirrorPathToPercentage[rule.Path] - extLocations := initializeExternalLocations(rule, pathsAndTypes) for i := range extLocations { extLocations[i].Includes = createIncludesFromPolicyGenerateResult( @@ -289,54 +359,42 @@ func createLocations( ) } - if !needsInternalLocations(rule) { - for _, r := range rule.MatchRules { - extLocations = updateLocations( - r, - rule, - extLocations, - server.Port, - keepAliveCheck, - mirrorPercentage, - ) - } - - locs = append(locs, extLocations...) - continue - } - - internalLocations := make([]http.Location, 0, len(rule.MatchRules)) - - for matchRuleIdx, r := range rule.MatchRules { - intLocation, match := initializeInternalLocation(pathRuleIdx, matchRuleIdx, r.Match, rule.GRPC) - intLocation.Includes = createIncludesFromPolicyGenerateResult( - generator.GenerateForInternalLocation(rule.Policies), + switch { + case !needsInternalLocationsForMatches(rule) && !rule.HasInferenceBackends: + locs = append(locs, updateExternalLocationsForRule( + rule, + extLocations, + server.Port, + keepAliveCheck, + mirrorPercentage)..., ) - - intLocation = updateLocation( - r, + case needsInternalLocationsForMatches(rule): + internalLocations, matches := createInternalLocationsForRule( + pathRuleIdx, rule, - intLocation, + generator, server.Port, keepAliveCheck, mirrorPercentage, ) - - internalLocations = append(internalLocations, intLocation) - matches = append(matches, match) - } - - httpMatchKey := serverID + "_" + strconv.Itoa(pathRuleIdx) - for i := range extLocations { - // FIXME(sberman): De-dupe matches and associated locations - // so we don't need nginx/njs to perform unnecessary matching. - // https://github.com/nginx/nginx-gateway-fabric/issues/662 - extLocations[i].HTTPMatchKey = httpMatchKey - matchPairs[extLocations[i].HTTPMatchKey] = matches + httpMatchKey := serverID + "_" + strconv.Itoa(pathRuleIdx) + for i := range extLocations { + extLocations[i].HTTPMatchKey = httpMatchKey + matchPairs[extLocations[i].HTTPMatchKey] = matches + } + locs = append(locs, extLocations...) + locs = append(locs, internalLocations...) + case rule.HasInferenceBackends: + locs = append(locs, createInferenceLocationsForRule( + pathRuleIdx, + rule, + extLocations, + generator, + server.Port, + keepAliveCheck, + mirrorPercentage)..., + ) } - - locs = append(locs, extLocations...) - locs = append(locs, internalLocations...) } if !rootPathExists { @@ -346,7 +404,117 @@ func createLocations( return locs, matchPairs, grpcServer } -func needsInternalLocations(rule dataplane.PathRule) bool { +func updateExternalLocationsForRule( + rule dataplane.PathRule, + extLocations []http.Location, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) []http.Location { + for _, r := range rule.MatchRules { + extLocations = updateLocations( + r, + rule, + extLocations, + port, + keepAliveCheck, + mirrorPercentage, + ) + } + return extLocations +} + +func createInternalLocationsForRule( + pathRuleIdx int, + rule dataplane.PathRule, + generator policies.Generator, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) ([]http.Location, []routeMatch) { + internalLocations := make([]http.Location, 0, len(rule.MatchRules)) + matches := make([]routeMatch, 0, len(rule.MatchRules)) + for matchRuleIdx, r := range rule.MatchRules { + var intLocation http.Location + var match routeMatch + if !rule.HasInferenceBackends { + intLocation, match = initializeInternalMatchLocation(pathRuleIdx, matchRuleIdx, r.Match, rule.GRPC) + } else { + intLocation, match = initializeInternalMatchLocationWithInference(pathRuleIdx, matchRuleIdx, r.Match) + intInfLocation := initializeInternalInferenceRedirectLocation(pathRuleIdx, matchRuleIdx) + for _, b := range r.BackendGroup.Backends { + if b.EndpointPickerConfig != nil { + var portNum int + if b.EndpointPickerConfig.Port != nil { + portNum = int(b.EndpointPickerConfig.Port.Number) + } + intInfLocation.EPPInternalPath = intLocation.Path + intInfLocation.EPPHost = string(b.EndpointPickerConfig.Name) + intInfLocation.EPPPort = portNum + } + } + internalLocations = append(internalLocations, intInfLocation) + } + intLocation.Includes = createIncludesFromPolicyGenerateResult( + generator.GenerateForInternalLocation(rule.Policies), + ) + intLocation = updateLocation( + r, + rule, + intLocation, + port, + keepAliveCheck, + mirrorPercentage, + ) + internalLocations = append(internalLocations, intLocation) + matches = append(matches, match) + } + return internalLocations, matches +} + +func createInferenceLocationsForRule( + pathRuleIdx int, + rule dataplane.PathRule, + extLocations []http.Location, + generator policies.Generator, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) []http.Location { + locs := make([]http.Location, 0, len(rule.MatchRules)+len(extLocations)) + for matchRuleIdx, r := range rule.MatchRules { + intLocation := initializeInternalInferenceLocation(pathRuleIdx, matchRuleIdx) + intLocation.Includes = createIncludesFromPolicyGenerateResult( + generator.GenerateForInternalLocation(rule.Policies), + ) + intLocation = updateLocation( + r, + rule, + intLocation, + port, + keepAliveCheck, + mirrorPercentage, + ) + for _, b := range r.BackendGroup.Backends { + if b.EndpointPickerConfig != nil { + for i := range extLocations { + var portNum int + if b.EndpointPickerConfig.Port != nil { + portNum = int(b.EndpointPickerConfig.Port.Number) + } + extLocations[i].EPPInternalPath = intLocation.Path + extLocations[i].EPPHost = string(b.EndpointPickerConfig.Name) + extLocations[i].EPPPort = portNum + } + } + } + locs = append(locs, intLocation) + } + locs = append(locs, extLocations...) + return locs +} + +func needsInternalLocationsForMatches(rule dataplane.PathRule) bool { if len(rule.MatchRules) > 1 { return true } @@ -362,12 +530,13 @@ type pathAndTypeMap map[string]map[dataplane.PathType]struct{} // 2. Each path rule may have an additional location if it contains non-path-only matches. // 3. Each prefix path rule may have an additional location if it doesn't contain trailing slash. // 4. There may be an additional location for the default root path. +// 5. There may be an additional location for the inference extension. // We also return a map of all paths and their types. func getMaxLocationCountAndPathMap(pathRules []dataplane.PathRule) (int, pathAndTypeMap) { maxLocs := 1 pathsAndTypes := make(pathAndTypeMap) for _, rule := range pathRules { - maxLocs += len(rule.MatchRules) + 2 + maxLocs += len(rule.MatchRules) + 3 if pathsAndTypes[rule.Path] == nil { pathsAndTypes[rule.Path] = map[dataplane.PathType]struct{}{ rule.PathType: {}, @@ -431,14 +600,20 @@ func initializeExternalLocations( } func getLocationTypeForPathRule(rule dataplane.PathRule) http.LocationType { - if needsInternalLocations(rule) { + if needsInternalLocationsForMatches(rule) { return http.RedirectLocationType } + if rule.HasInferenceBackends { + return http.InferenceExternalLocationType + } + return http.ExternalLocationType } -func initializeInternalLocation( +// initializes the internal location that is redirected to by an external location HTTP matching decision. +// This location will proxy_pass to the backend. +func initializeInternalMatchLocation( pathruleIdx, matchRuleIdx int, match dataplane.Match, @@ -448,6 +623,43 @@ func initializeInternalLocation( return createMatchLocation(path, grpc), createRouteMatch(match, path) } +// initializes the internal inference location that is redirected to by +// an external HTTP matching location. This location then redirects to the final proxy_pass location. +func initializeInternalInferenceRedirectLocation(pathruleIdx, matchRuleIdx int) http.Location { + return http.Location{ + Path: inferencePath(pathruleIdx, matchRuleIdx), + Type: http.InferenceInternalLocationType, + } +} + +// initializes the internal location that is redirected to by an internal inference location, which was +// redirected to by the external HTTP matching location. This location will proxy_pass to the backend. +// The routeMatch is created with the inference internal location path, so that the HTTP match in the external +// location can redirect to the proper inference location, which then redirects to this location. +func initializeInternalMatchLocationWithInference( + pathruleIdx, + matchRuleIdx int, + match dataplane.Match, +) (http.Location, routeMatch) { + path := fmt.Sprintf("%s-rule%d-route%d", http.InternalRoutePathPrefix, pathruleIdx, matchRuleIdx) + grpc := false + + return createMatchLocation(path, grpc), createRouteMatch(match, inferencePath(pathruleIdx, matchRuleIdx)) +} + +// initializes the internal inference location that does the final proxy_pass to the inference backend. +// This is used when the external location redirects directly here, without any HTTP matching. +func initializeInternalInferenceLocation(pathruleIdx, matchRuleIdx int) http.Location { + return http.Location{ + Path: inferencePath(pathruleIdx, matchRuleIdx), + Type: http.InternalLocationType, + } +} + +func inferencePath(pathruleIdx int, matchRuleIdx int) string { + return fmt.Sprintf("%s-rule%d-route%d-inference", http.InternalRoutePathPrefix, pathruleIdx, matchRuleIdx) +} + // updateLocation updates a location with any relevant configurations, like proxy_pass, filters, tls settings, etc. func updateLocation( matchRule dataplane.MatchRule, @@ -460,6 +672,7 @@ func updateLocation( filters := matchRule.Filters path := pathRule.Path grpc := pathRule.GRPC + inferenceBackend := pathRule.HasInferenceBackends if filters.InvalidFilter != nil { location.Return = &http.Return{Code: http.StatusInternalServerError} @@ -475,7 +688,7 @@ func updateLocation( location = updateLocationRewriteFilter(location, filters.RequestURLRewrite, path) location = updateLocationMirrorFilters(location, filters.RequestMirrors, path, mirrorPercentage) - location = updateLocationProxySettings(location, matchRule, grpc, keepAliveCheck) + location = updateLocationProxySettings(location, matchRule, grpc, inferenceBackend, keepAliveCheck) return location } @@ -555,6 +768,7 @@ func updateLocationProxySettings( location http.Location, matchRule dataplane.MatchRule, grpc bool, + inferenceBackend bool, keepAliveCheck keepAliveChecker, ) http.Location { extraHeaders := make([]http.Header, 0, 3) @@ -575,6 +789,7 @@ func updateLocationProxySettings( matchRule.Filters.RequestURLRewrite, generateProtocolString(location.ProxySSLVerify, grpc), grpc, + inferenceBackend, ) location.ResponseHeaders = responseHeaders @@ -853,6 +1068,7 @@ func createProxyPass( filter *dataplane.HTTPURLRewriteFilter, protocol string, grpc bool, + inferenceBackend bool, ) string { var requestURI string if !grpc { @@ -862,6 +1078,12 @@ func createProxyPass( } backendName := backendGroupName(backendGroup) + + if inferenceBackend { + backendVarName := strings.ReplaceAll(backendName, "-", "_") + return "http://$inference_backend_" + backendVarName + requestURI + } + if backendGroupNeedsSplit(backendGroup) { return protocol + "://$" + convertStringToSafeVariableName(backendName) + requestURI } diff --git a/internal/controller/nginx/config/servers_template.go b/internal/controller/nginx/config/servers_template.go index 224e189a6e..9575b77480 100644 --- a/internal/controller/nginx/config/servers_template.go +++ b/internal/controller/nginx/config/servers_template.go @@ -92,7 +92,7 @@ server { {{ range $l := $s.Locations }} location {{ $l.Path }} { - {{ if eq $l.Type "internal" -}} + {{ if contains $l.Type "internal" -}} internal; {{ end }} @@ -118,11 +118,19 @@ server { return {{ $l.Return.Code }} "{{ $l.Return.Body }}"; {{- end }} - {{- if eq $l.Type "redirect" }} + {{- if eq $l.Type "redirect" -}} set $match_key {{ $l.HTTPMatchKey }}; js_content httpmatches.redirect; {{- end }} + {{- if contains $l.Type "inference" -}} + js_var $inference_workload_endpoint; + set $epp_internal_path {{ $l.EPPInternalPath }}; + set $epp_host {{ $l.EPPHost }}; + set $epp_port {{ $l.EPPPort }}; + js_content epp.getEndpoint; + {{- end }} + {{ $proxyOrGRPC := "proxy" }}{{ if $l.GRPC }}{{ $proxyOrGRPC = "grpc" }}{{ end }} {{- if $l.GRPC }} diff --git a/internal/controller/nginx/config/servers_test.go b/internal/controller/nginx/config/servers_test.go index 6b604d7bec..ab4fad31a5 100644 --- a/internal/controller/nginx/config/servers_test.go +++ b/internal/controller/nginx/config/servers_test.go @@ -9,6 +9,7 @@ import ( . "github.com/onsi/gomega" "github.com/onsi/gomega/format" "k8s.io/apimachinery/pkg/types" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" @@ -1239,7 +1240,7 @@ func TestCreateServers(t *testing.T) { Filters: dataplane.HTTPFilters{ RequestRedirect: &dataplane.HTTPRequestRedirectFilter{ Hostname: helpers.GetPointer("redirect.example.com"), - StatusCode: helpers.GetPointer[int](301), + StatusCode: helpers.GetPointer(301), Port: helpers.GetPointer[int32](8080), Path: &dataplane.HTTPPathModifier{ Type: dataplane.ReplaceFullPath, @@ -2443,6 +2444,154 @@ func TestCreateLocations_Includes(t *testing.T) { } } +func TestCreateLocations_InferenceBackends(t *testing.T) { + t.Parallel() + + hrNsName := types.NamespacedName{Namespace: "test", Name: "route1"} + + fooGroup := dataplane.BackendGroup{ + Source: hrNsName, + RuleIdx: 0, + Backends: []dataplane.Backend{ + { + UpstreamName: "test_foo_80", + Valid: true, + Weight: 1, + EndpointPickerConfig: &inference.EndpointPickerRef{ + Name: "test-epp", + Port: &inference.Port{ + Number: 80, + }, + }, + }, + }, + } + + pathRuleInferenceOnly := dataplane.PathRule{ + Path: "/inference", + PathType: dataplane.PathTypeExact, + HasInferenceBackends: true, + MatchRules: []dataplane.MatchRule{ + { + Match: dataplane.Match{}, + BackendGroup: fooGroup, + }, + }, + } + + pathRuleInferenceWithMatch := dataplane.PathRule{ + Path: "/inference-match", + PathType: dataplane.PathTypeExact, + HasInferenceBackends: true, + MatchRules: []dataplane.MatchRule{ + { + Match: dataplane.Match{ + Method: helpers.GetPointer("POST"), + }, + BackendGroup: fooGroup, + }, + }, + } + + tests := []struct { + expMatches httpMatchPairs + name string + pathRules []dataplane.PathRule + expLocs []http.Location + }{ + { + name: "inference only, no internal locations for matches", + pathRules: []dataplane.PathRule{pathRuleInferenceOnly}, + expLocs: []http.Location{ + { + Path: "/_ngf-internal-rule0-route0-inference", + Type: http.InternalLocationType, + ProxyPass: "http://$inference_backend_test_foo_80$request_uri", + ProxySetHeaders: []http.Header{ + {Name: "Host", Value: "$gw_api_compliant_host"}, + {Name: "X-Forwarded-For", Value: "$proxy_add_x_forwarded_for"}, + {Name: "X-Real-IP", Value: "$remote_addr"}, + {Name: "X-Forwarded-Proto", Value: "$scheme"}, + {Name: "X-Forwarded-Host", Value: "$host"}, + {Name: "X-Forwarded-Port", Value: "$server_port"}, + {Name: "Upgrade", Value: "$http_upgrade"}, + {Name: "Connection", Value: "$connection_upgrade"}, + }, + }, + { + Path: "= /inference", + Type: http.InferenceExternalLocationType, + EPPInternalPath: "/_ngf-internal-rule0-route0-inference", + EPPHost: "test-epp", + EPPPort: 80, + }, + createDefaultRootLocation(), + }, + expMatches: httpMatchPairs{}, + }, + { + name: "inference with match, needs internal locations for matches", + pathRules: []dataplane.PathRule{pathRuleInferenceWithMatch}, + expLocs: []http.Location{ + { + Path: "= /inference-match", + Type: http.RedirectLocationType, + HTTPMatchKey: "1_0", + }, + { + Path: "/_ngf-internal-rule0-route0-inference", + Type: http.InferenceInternalLocationType, + EPPInternalPath: "/_ngf-internal-rule0-route0", + EPPHost: "test-epp", + EPPPort: 80, + }, + { + Path: "/_ngf-internal-rule0-route0", + Type: http.InternalLocationType, + ProxyPass: "http://$inference_backend_test_foo_80$request_uri", + ProxySetHeaders: []http.Header{ + {Name: "Host", Value: "$gw_api_compliant_host"}, + {Name: "X-Forwarded-For", Value: "$proxy_add_x_forwarded_for"}, + {Name: "X-Real-IP", Value: "$remote_addr"}, + {Name: "X-Forwarded-Proto", Value: "$scheme"}, + {Name: "X-Forwarded-Host", Value: "$host"}, + {Name: "X-Forwarded-Port", Value: "$server_port"}, + {Name: "Upgrade", Value: "$http_upgrade"}, + {Name: "Connection", Value: "$connection_upgrade"}, + }, + }, + createDefaultRootLocation(), + }, + expMatches: httpMatchPairs{ + "1_0": { + {Method: "POST", RedirectPath: "/_ngf-internal-rule0-route0-inference"}, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + locs, matches, _ := createLocations( + &dataplane.VirtualServer{ + Hostname: "example.com", + PathRules: tc.pathRules, + Port: 80, + }, + "1", + &policiesfakes.FakeGenerator{}, + alwaysFalseKeepAliveChecker, + ) + + g.Expect(helpers.Diff(tc.expLocs, locs)).To(BeEmpty()) + g.Expect(matches).To(Equal(tc.expMatches)) + }) + } +} + func TestCreateLocationsRootPath(t *testing.T) { t.Parallel() hrNsName := types.NamespacedName{Namespace: "test", Name: "route1"} @@ -3332,10 +3481,11 @@ func TestCreateProxyPass(t *testing.T) { t.Parallel() tests := []struct { - rewrite *dataplane.HTTPURLRewriteFilter - expected string - grp dataplane.BackendGroup - GRPC bool + rewrite *dataplane.HTTPURLRewriteFilter + expected string + grp dataplane.BackendGroup + GRPC bool + inferenceBackend bool }{ { expected: "http://10.0.0.1:80$request_uri", @@ -3349,6 +3499,20 @@ func TestCreateProxyPass(t *testing.T) { }, }, }, + // Inference case + { + expected: "http://$inference_backend_upstream_inference$request_uri", + grp: dataplane.BackendGroup{ + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream-inference", + Valid: true, + Weight: 1, + }, + }, + }, + inferenceBackend: true, + }, { expected: "http://$group_ns1__bg_rule0$request_uri", grp: dataplane.BackendGroup{ @@ -3401,7 +3565,13 @@ func TestCreateProxyPass(t *testing.T) { t.Run(tc.expected, func(t *testing.T) { t.Parallel() g := NewWithT(t) - result := createProxyPass(tc.grp, tc.rewrite, generateProtocolString(nil, tc.GRPC), tc.GRPC) + result := createProxyPass( + tc.grp, + tc.rewrite, + generateProtocolString(nil, tc.GRPC), + tc.GRPC, + tc.inferenceBackend, + ) g.Expect(result).To(Equal(tc.expected)) }) } diff --git a/internal/controller/nginx/modules/src/epp.js b/internal/controller/nginx/modules/src/epp.js index d4beeb9e15..88de40062b 100644 --- a/internal/controller/nginx/modules/src/epp.js +++ b/internal/controller/nginx/modules/src/epp.js @@ -1,29 +1,59 @@ -// This file contains the methods to get an AI workload endpoint from the EndpointPicker (EPP). +import qs from 'querystring'; -// TODO(sberman): this module will need to be enhanced to include the following: -// - function that sends the subrequest to the Go middleware application (to get the endpoint from EPP) -// - if a user has specified an Exact matching condition for a model name, extract the model name from -// the request body, and if it matches that condition, set the proper value in the X-Gateway-Model-Name header -// (based on if we do a redirect or traffic split (see design doc)) in the subrequest. If the client request -// already has this header set, then I don't think we need to extract the model from the body, just pass -// through the existing header. -// I believe we have to use js_content to call the NJS functionality. Because this takes over -// the request, we will likely have to finish the NJS functionality with an internalRedirect to an internal -// location that proxy_passes to the chosen endpoint. +const EPP_HOST_HEADER_VAR = 'epp_host'; +const EPP_PORT_HEADER_VAR = 'epp_port'; +const EPP_HOST_HEADER = 'X-EPP-Host'; +const EPP_PORT_HEADER = 'X-EPP-Port'; +const ENDPOINT_HEADER = 'X-Gateway-Destination-Endpoint'; +const EPP_INTERNAL_PATH_VAR = 'epp_internal_path'; +const WORKLOAD_ENDPOINT_VAR = 'inference_workload_endpoint'; +const SHIM_URI = 'http://127.0.0.1:54800'; + +async function getEndpoint(r) { + if (!r.variables[EPP_HOST_HEADER_VAR] || !r.variables[EPP_PORT_HEADER_VAR]) { + throw Error( + `Missing required variables: ${EPP_HOST_HEADER_VAR} and/or ${EPP_PORT_HEADER_VAR}`, + ); + } + if (!r.variables[EPP_INTERNAL_PATH_VAR]) { + throw Error(`Missing required variable: ${EPP_INTERNAL_PATH_VAR}`); + } + + let headers = Object.assign({}, r.headersIn); + headers[EPP_HOST_HEADER] = r.variables[EPP_HOST_HEADER_VAR]; + headers[EPP_PORT_HEADER] = r.variables[EPP_PORT_HEADER_VAR]; -// extractModel extracts the model name from the request body. -function extractModel(r) { try { - var body = JSON.parse(r.requestText); - if (body && body.model !== undefined) { - return String(body.model); + const response = await ngx.fetch(SHIM_URI, { + method: r.method, + headers: headers, + body: r.requestText, + }); + const endpointHeader = response.headers.get(ENDPOINT_HEADER); + if (response.status === 200 && endpointHeader) { + r.variables[WORKLOAD_ENDPOINT_VAR] = endpointHeader; + r.log( + `found inference endpoint from EndpointPicker: ${r.variables[WORKLOAD_ENDPOINT_VAR]}`, + ); + } else { + const body = await response.text(); + r.error( + `could not get specific inference endpoint from EndpointPicker; ` + + `status: ${response.status}; body: ${body}`, + ); } - } catch (e) { - r.error(`error parsing request body for model name: ${e.message}`); - return ''; + } catch (err) { + r.error(`Error in ngx.fetch: ${err}`); } - r.error('request body does not contain model parameter'); - return ''; + + // If performing a rewrite, $request_uri won't be used, + // so we have to preserve args in the internal redirect. + let args = qs.stringify(r.args); + if (args) { + args = '?' + args; + } + + r.internalRedirect(r.variables[EPP_INTERNAL_PATH_VAR] + args); } -export default { extractModel }; +export default { getEndpoint }; diff --git a/internal/controller/nginx/modules/test/epp.test.js b/internal/controller/nginx/modules/test/epp.test.js index 6994423e7a..288104520b 100644 --- a/internal/controller/nginx/modules/test/epp.test.js +++ b/internal/controller/nginx/modules/test/epp.test.js @@ -1,52 +1,82 @@ import { default as epp } from '../src/epp.js'; -import { expect, describe, it } from 'vitest'; +import { expect, describe, it, beforeEach, afterEach, vi } from 'vitest'; -function makeRequest(body) { - let r = { - // Test mocks - error(msg) { - r.variables.error = msg; - }, - requestText: body, - variables: {}, +function makeRequest({ method = 'POST', headersIn = {}, requestText = '', variables = {} } = {}) { + return { + method, + headersIn, + requestText, + variables, + error: vi.fn(), + log: vi.fn(), + internalRedirect: vi.fn(), }; - - return r; } -describe('extractModel', () => { - const tests = [ - { - name: 'returns the model value', - body: '{"model":"gpt-4"}', - model: 'gpt-4', - error: undefined, - }, - { - name: 'returns empty string if model is missing', - body: '{"foo":1}', - model: '', - error: 'request body does not contain model parameter', - }, - { - name: 'returns empty string for invalid JSON', - body: 'not-json', - model: '', - error: `error parsing request body for model name: Unexpected token 'o', "not-json" is not valid JSON`, - }, - { - name: 'empty request body', - body: '', - model: '', - error: 'error parsing request body for model name: Unexpected end of JSON input', - }, - ]; +describe('getEndpoint', () => { + let originalNgx; + beforeEach(() => { + originalNgx = globalThis.ngx; + }); + afterEach(() => { + globalThis.ngx = originalNgx; + }); + + it('throws if host or port is missing', async () => { + const r = makeRequest({ variables: { epp_internal_path: '/foo' } }); + await expect(epp.getEndpoint(r)).rejects.toThrow(/Missing required variables/); + }); + + it('throws if internal path is missing', async () => { + const r = makeRequest({ variables: { epp_host: 'host', epp_port: '1234' } }); + await expect(epp.getEndpoint(r)).rejects.toThrow(/Missing required variable/); + }); + + it('sets endpoint and logs on 200 with endpoint header', async () => { + const endpoint = 'http://endpoint'; + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 200, + headers: { get: () => endpoint }, + text: vi.fn(), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + }); + await epp.getEndpoint(r); + expect(r.variables.inference_workload_endpoint).toBe(endpoint); + expect(r.log).toHaveBeenCalledWith(expect.stringContaining(endpoint)); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); + }); + + it('calls error if response is not 200 or endpoint header missing', async () => { + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 404, + headers: { get: () => null }, + text: vi.fn().mockResolvedValue('fail'), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + }); + await epp.getEndpoint(r); + expect(r.error).toHaveBeenCalledWith( + expect.stringContaining('could not get specific inference endpoint'), + ); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); + }); - tests.forEach((test) => { - it(test.name, () => { - let r = makeRequest(test.body); - expect(epp.extractModel(r)).to.equal(test.model); - expect(r.variables.error).to.equal(test.error); + it('calls error if fetch throws', async () => { + globalThis.ngx = { + fetch: vi.fn().mockRejectedValue(new Error('network fail')), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, }); + await epp.getEndpoint(r); + expect(r.error).toHaveBeenCalledWith(expect.stringContaining('Error in ngx.fetch')); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); }); }); diff --git a/internal/controller/state/dataplane/configuration.go b/internal/controller/state/dataplane/configuration.go index 52306f4e0b..59030c0ca7 100644 --- a/internal/controller/state/dataplane/configuration.go +++ b/internal/controller/state/dataplane/configuration.go @@ -374,12 +374,13 @@ func newBackendGroup( gatewayName types.NamespacedName, sourceNsName types.NamespacedName, ruleIdx int, -) BackendGroup { +) (BackendGroup, bool) { var backends []Backend if len(refs) > 0 { backends = make([]Backend, 0, len(refs)) } + var inferencePoolBackendExists bool for _, ref := range refs { if ref.IsMirrorBackend { @@ -391,11 +392,14 @@ func newBackendGroup( valid = false } + inferencePoolBackendExists = inferencePoolBackendExists || ref.IsInferencePool + backends = append(backends, Backend{ - UpstreamName: ref.ServicePortReference(), - Weight: ref.Weight, - Valid: valid, - VerifyTLS: convertBackendTLS(ref.BackendTLSPolicy, gatewayName), + UpstreamName: ref.ServicePortReference(), + Weight: ref.Weight, + Valid: valid, + VerifyTLS: convertBackendTLS(ref.BackendTLSPolicy, gatewayName), + EndpointPickerConfig: ref.EndpointPickerConfig, }) } @@ -403,7 +407,7 @@ func newBackendGroup( Backends: backends, Source: sourceNsName, RuleIdx: ruleIdx, - } + }, inferencePoolBackendExists } func convertBackendTLS(btp *graph.BackendTLSPolicy, gwNsName types.NamespacedName) *VerifyTLS { @@ -595,10 +599,19 @@ func (hpr *hostPathRules) upsertRoute( } hostRule.GRPC = GRPC + backendGroup, inferencePoolBackendExists := newBackendGroup( + rule.BackendRefs, + listener.GatewayName, + routeNsName, + idx, + ) + if inferencePoolBackendExists { + hostRule.HasInferenceBackends = true + } hostRule.MatchRules = append(hostRule.MatchRules, MatchRule{ Source: objectSrc, - BackendGroup: newBackendGroup(rule.BackendRefs, listener.GatewayName, routeNsName, idx), + BackendGroup: backendGroup, Filters: filters, Match: convertMatch(m), }) diff --git a/internal/controller/state/dataplane/configuration_test.go b/internal/controller/state/dataplane/configuration_test.go index b329b9d46a..3e1697590d 100644 --- a/internal/controller/state/dataplane/configuration_test.go +++ b/internal/controller/state/dataplane/configuration_test.go @@ -2777,6 +2777,93 @@ func TestBuildConfiguration_Plus(t *testing.T) { } } +func TestUpsertRoute_PathRuleHasInferenceBackend(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + // Setup minimal route with one BackendRef marked as IsInferencePool + backendRef := graph.BackendRef{ + SvcNsName: types.NamespacedName{Name: "svc", Namespace: "test"}, + ServicePort: apiv1.ServicePort{Port: 80}, + Valid: true, + IsInferencePool: true, + } + + listenerName := "listener-80" + gwName := types.NamespacedName{Namespace: "test", Name: "gw"} + + route := &graph.L7Route{ + RouteType: graph.RouteTypeHTTP, + Source: &v1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "hr", + Namespace: "test", + }, + }, + Spec: graph.L7RouteSpec{ + Rules: []graph.RouteRule{ + { + ValidMatches: true, + Filters: graph.RouteRuleFilters{Valid: true}, + BackendRefs: []graph.BackendRef{backendRef}, + Matches: []v1.HTTPRouteMatch{ + { + Path: &v1.HTTPPathMatch{ + Type: helpers.GetPointer(v1.PathMatchPathPrefix), + Value: helpers.GetPointer("/infer"), + }, + }, + }, + }, + }, + }, + ParentRefs: []graph.ParentRef{ + { + Attachment: &graph.ParentRefAttachmentStatus{ + AcceptedHostnames: map[string][]string{ + graph.CreateGatewayListenerKey(gwName, listenerName): {"*"}, + }, + }, + }, + }, + Valid: true, + } + + listener := &graph.Listener{ + Name: listenerName, + GatewayName: gwName, + Valid: true, + Routes: map[graph.RouteKey]*graph.L7Route{ + graph.CreateRouteKey(route.Source): route, + }, + } + + gateway := &graph.Gateway{ + Source: &v1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gw", + Namespace: "test", + }, + }, + Listeners: []*graph.Listener{listener}, + } + + hpr := newHostPathRules() + hpr.upsertRoute(route, listener, gateway) + + // Find the PathRule for "/infer" + found := false + for _, rules := range hpr.rulesPerHost { + for _, pr := range rules { + if pr.Path == "/infer" { + found = true + g.Expect(pr.HasInferenceBackends).To(BeTrue()) + } + } + } + g.Expect(found).To(BeTrue(), "PathRule for '/infer' not found") +} + func TestNewBackendGroup_Mirror(t *testing.T) { t.Parallel() g := NewWithT(t) @@ -2788,7 +2875,7 @@ func TestNewBackendGroup_Mirror(t *testing.T) { IsMirrorBackend: true, } - group := newBackendGroup([]graph.BackendRef{backendRef}, types.NamespacedName{}, types.NamespacedName{}, 0) + group, _ := newBackendGroup([]graph.BackendRef{backendRef}, types.NamespacedName{}, types.NamespacedName{}, 0) g.Expect(group.Backends).To(BeEmpty()) } diff --git a/internal/controller/state/dataplane/types.go b/internal/controller/state/dataplane/types.go index 08e7e0867b..1637c1f408 100644 --- a/internal/controller/state/dataplane/types.go +++ b/internal/controller/state/dataplane/types.go @@ -5,6 +5,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/graph" @@ -137,6 +138,8 @@ type PathRule struct { Policies []policies.Policy // GRPC indicates if this is a gRPC rule GRPC bool + // HasInferenceBackends indicates whether the PathRule contains a backend for an inference workload. + HasInferenceBackends bool } // InvalidHTTPFilter is a special filter for handling the case when configured filters are invalid. @@ -323,6 +326,9 @@ func (bg *BackendGroup) Name() string { type Backend struct { // VerifyTLS holds the backend TLS verification configuration. VerifyTLS *VerifyTLS + // EndpointPickerConfig holds the configuration for the EndpointPicker for this backend. + // This is set if this backend is for an inference workload. + EndpointPickerConfig *inference.EndpointPickerRef // UpstreamName is the name of the upstream for this backend. UpstreamName string // Weight is the weight of the BackendRef. diff --git a/internal/controller/state/graph/backend_refs.go b/internal/controller/state/graph/backend_refs.go index e14d0fb0fa..95ce6df0b9 100644 --- a/internal/controller/state/graph/backend_refs.go +++ b/internal/controller/state/graph/backend_refs.go @@ -9,6 +9,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/validation/field" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" "sigs.k8s.io/gateway-api/apis/v1alpha3" @@ -30,6 +31,8 @@ const ( type BackendRef struct { // BackendTLSPolicy is the BackendTLSPolicy of the Service which is referenced by the backendRef. BackendTLSPolicy *BackendTLSPolicy + // EndpointPickerConfig is the configuration for the EndpointPicker, if this backendRef is for an InferencePool. + EndpointPickerConfig *inference.EndpointPickerRef // InvalidForGateways is a map of Gateways for which this BackendRef is invalid for, with the corresponding // condition. Certain NginxProxy configurations may result in a backend not being valid for some Gateways, // but not others. @@ -45,6 +48,8 @@ type BackendRef struct { Valid bool // IsMirrorBackend indicates whether the BackendGroup is for a mirrored backend. IsMirrorBackend bool + // IsInferencePool indicates whether the BackendRef is for an InferencePool. + IsInferencePool bool } // ServicePortReference returns a string representation for the service and port that is referenced by the BackendRef. @@ -118,6 +123,7 @@ func addBackendRefsToRules( if pool, exists := referencedInferencePools[poolName]; exists { port := gatewayv1.PortNumber(pool.Source.Spec.TargetPorts[0].Number) ref.Port = helpers.GetPointer(port) + ref.EndpointPickerConfig = &pool.Source.Spec.EndpointPickerRef } } @@ -181,10 +187,12 @@ func createBackendRef( if !valid { backendRef := BackendRef{ - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, []conditions.Condition{cond} @@ -198,12 +206,14 @@ func createBackendRef( svcIPFamily, svcPort, err := getIPFamilyAndPortFromRef(ref.BackendRef, svcNsName, services, refPath) if err != nil { backendRef := BackendRef{ - Weight: weight, - Valid: false, - SvcNsName: svcNsName, - ServicePort: v1.ServicePort{}, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + Weight: weight, + Valid: false, + SvcNsName: svcNsName, + ServicePort: v1.ServicePort{}, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, []conditions.Condition{conditions.NewRouteBackendRefRefBackendNotFound(err.Error())} @@ -220,12 +230,14 @@ func createBackendRef( // Check if externalName field is empty or whitespace-only if strings.TrimSpace(svc.Spec.ExternalName) == "" { backendRef := BackendRef{ - SvcNsName: svcNsName, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedValue( @@ -249,12 +261,14 @@ func createBackendRef( ) if err != nil { backendRef := BackendRef{ - SvcNsName: svcNsName, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedValue(err.Error())) @@ -264,13 +278,15 @@ func createBackendRef( err = validateRouteBackendRefAppProtocol(route.RouteType, *svcPort.AppProtocol, backendTLSPolicy) if err != nil { backendRef := BackendRef{ - SvcNsName: svcNsName, - BackendTLSPolicy: backendTLSPolicy, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + BackendTLSPolicy: backendTLSPolicy, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedProtocol(err.Error())) @@ -278,13 +294,15 @@ func createBackendRef( } backendRef := BackendRef{ - SvcNsName: svcNsName, - BackendTLSPolicy: backendTLSPolicy, - ServicePort: svcPort, - Valid: true, - Weight: weight, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + BackendTLSPolicy: backendTLSPolicy, + ServicePort: svcPort, + Valid: true, + Weight: weight, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, conds diff --git a/internal/controller/state/graph/backend_refs_test.go b/internal/controller/state/graph/backend_refs_test.go index 3f05f793a6..b786daed9b 100644 --- a/internal/controller/state/graph/backend_refs_test.go +++ b/internal/controller/state/graph/backend_refs_test.go @@ -1231,9 +1231,11 @@ func TestAddBackendRefsToRules(t *testing.T) { ServicePort: v1.ServicePort{ Port: 80, }, - Valid: true, - Weight: 1, - InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + Valid: true, + Weight: 1, + InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + IsInferencePool: true, + EndpointPickerConfig: &inference.EndpointPickerRef{}, }, }, expectedConditions: nil, diff --git a/internal/controller/state/graph/graph_test.go b/internal/controller/state/graph/graph_test.go index da0ca04d47..1a367e5977 100644 --- a/internal/controller/state/graph/graph_test.go +++ b/internal/controller/state/graph/graph_test.go @@ -223,10 +223,12 @@ func TestBuildGraph(t *testing.T) { Namespace: testNs, Name: controller.CreateInferencePoolServiceName("ipool"), }, - ServicePort: v1.ServicePort{Port: 80}, - Valid: true, - Weight: 1, - InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + ServicePort: v1.ServicePort{Port: 80}, + Valid: true, + Weight: 1, + InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + IsInferencePool: true, + EndpointPickerConfig: &inference.EndpointPickerRef{}, }, } rbrs := []RouteBackendRef{ diff --git a/internal/controller/state/graph/httproute.go b/internal/controller/state/graph/httproute.go index de7a85370d..ed8d46a664 100644 --- a/internal/controller/state/graph/httproute.go +++ b/internal/controller/state/graph/httproute.go @@ -210,11 +210,26 @@ func processHTTPRouteRule( } } - var rbr RouteBackendRef + rbr := RouteBackendRef{ + BackendRef: b.BackendRef, + } + // If route specifies an InferencePool backend, we need to convert it to its associated // headless Service backend (that we created), so nginx config can be built properly. // Only do this if the InferencePool actually exists. if inferencePoolBackend(b, routeNamespace, inferencePools) { + // We don't support traffic splitting at the Route level for + // InferencePool backends, so if there's more than one backendRef, and one of them + // is an InferencePool, we mark the rule as invalid. + if len(specRule.BackendRefs) > 1 { + err := field.Forbidden( + rulePath.Child("backendRefs"), + "cannot use InferencePool backend when multiple backendRefs are specified in a single rule", + ) + errors.invalid = append(errors.invalid, err) + break + } + svcName := controller.CreateInferencePoolServiceName(string(b.Name)) rbr = RouteBackendRef{ IsInferencePool: true, @@ -228,10 +243,6 @@ func processHTTPRouteRule( Weight: b.Weight, }, } - } else { - rbr = RouteBackendRef{ - BackendRef: b.BackendRef, - } } rbr.Filters = interfaceFilters diff --git a/internal/controller/state/graph/httproute_test.go b/internal/controller/state/graph/httproute_test.go index 0e06e5bf7e..d6d77c7296 100644 --- a/internal/controller/state/graph/httproute_test.go +++ b/internal/controller/state/graph/httproute_test.go @@ -1213,6 +1213,67 @@ func TestBuildHTTPRouteWithMirrorRoutes(t *testing.T) { g.Expect(helpers.Diff(expectedMirrorRoute, routes[mirrorRouteKey])).To(BeEmpty()) } +func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + validator := &validationfakes.FakeHTTPFieldsValidator{} + inferencePoolName := "ipool" + routeNamespace := "test" + inferencePools := map[types.NamespacedName]*inference.InferencePool{ + {Namespace: routeNamespace, Name: inferencePoolName}: {}, + } + + // BackendRef 1: InferencePool + backendRef1 := gatewayv1.HTTPBackendRef{ + BackendRef: gatewayv1.BackendRef{ + BackendObjectReference: gatewayv1.BackendObjectReference{ + Group: helpers.GetPointer[gatewayv1.Group](inferenceAPIGroup), + Kind: helpers.GetPointer[gatewayv1.Kind](kinds.InferencePool), + Name: gatewayv1.ObjectName(inferencePoolName), + Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNamespace)), + }, + }, + } + // BackendRef 2: Service + backendRef2 := gatewayv1.HTTPBackendRef{ + BackendRef: gatewayv1.BackendRef{ + BackendObjectReference: gatewayv1.BackendObjectReference{ + Kind: helpers.GetPointer[gatewayv1.Kind](kinds.Service), + Name: "backend", + }, + }, + } + + specRule := gatewayv1.HTTPRouteRule{ + Matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/"), + }, + }, + }, + BackendRefs: []gatewayv1.HTTPBackendRef{backendRef1, backendRef2}, + } + + rulePath := field.NewPath("spec").Child("rules").Index(0) + + routeRule, errs := processHTTPRouteRule( + specRule, + routeNamespace, + rulePath, + validator, + nil, + inferencePools, + ) + + g.Expect(routeRule.RouteBackendRefs).To(BeEmpty()) + g.Expect(errs.invalid).To(HaveLen(1)) + errMsg := "cannot use InferencePool backend when multiple backendRefs are specified in a single rule" + g.Expect(errs.invalid[0].Error()).To(ContainSubstring(errMsg)) +} + func TestValidateMatch(t *testing.T) { t.Parallel() createAllValidValidator := func() *validationfakes.FakeHTTPFieldsValidator { diff --git a/internal/controller/state/graph/route_common.go b/internal/controller/state/graph/route_common.go index f3d3b04e4a..22067c6d44 100644 --- a/internal/controller/state/graph/route_common.go +++ b/internal/controller/state/graph/route_common.go @@ -166,6 +166,9 @@ type RouteBackendRef struct { // If this backend is defined in a RequestMirror filter, this value will indicate the filter's index. MirrorBackendIdx *int + // EndpointPickerConfig is the configuration for the EndpointPicker, if this backendRef is for an InferencePool. + EndpointPickerConfig *inference.EndpointPickerRef + Filters []any // IsInferencePool indicates if this backend is an InferencePool disguised as a Service. diff --git a/internal/framework/types/types.go b/internal/framework/types/types.go index bf61bd23d7..0aeccd008d 100644 --- a/internal/framework/types/types.go +++ b/internal/framework/types/types.go @@ -5,3 +5,14 @@ import "sigs.k8s.io/controller-runtime/pkg/client" // ObjectType is used when we only care about the type of client.Object. // The fields of the client.Object may be empty. type ObjectType client.Object + +// Fields used for communication with the EndpointPicker service when using the Inference Extension. +const ( + // EPPEndpointHostHeader is the HTTP header used to specify the EPP endpoint host. + EPPEndpointHostHeader = "X-EPP-Host" + // EPPEndpointPortHeader is the HTTP header used to specify the EPP endpoint port. + EPPEndpointPortHeader = "X-EPP-Port" + // GoShimPort is the default port for the Go EPP shim server to listen on. If collisions become a problem, + // we can make this configurable via the NginxProxy resource. + GoShimPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100. +) From fb8446976fd3a04976d3d980f63c625aff081201 Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Mon, 22 Sep 2025 13:10:45 -0600 Subject: [PATCH 2/5] Adjust maxLoc logic, add newlines --- internal/controller/nginx/config/servers.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/controller/nginx/config/servers.go b/internal/controller/nginx/config/servers.go index 24ce473305..b8686b775c 100644 --- a/internal/controller/nginx/config/servers.go +++ b/internal/controller/nginx/config/servers.go @@ -421,6 +421,7 @@ func updateExternalLocationsForRule( mirrorPercentage, ) } + return extLocations } @@ -469,6 +470,7 @@ func createInternalLocationsForRule( internalLocations = append(internalLocations, intLocation) matches = append(matches, match) } + return internalLocations, matches } @@ -511,6 +513,7 @@ func createInferenceLocationsForRule( locs = append(locs, intLocation) } locs = append(locs, extLocations...) + return locs } @@ -518,6 +521,7 @@ func needsInternalLocationsForMatches(rule dataplane.PathRule) bool { if len(rule.MatchRules) > 1 { return true } + return len(rule.MatchRules) == 1 && !isPathOnlyMatch(rule.MatchRules[0].Match) } @@ -530,13 +534,13 @@ type pathAndTypeMap map[string]map[dataplane.PathType]struct{} // 2. Each path rule may have an additional location if it contains non-path-only matches. // 3. Each prefix path rule may have an additional location if it doesn't contain trailing slash. // 4. There may be an additional location for the default root path. -// 5. There may be an additional location for the inference extension. +// 5. There may be an additional location per parent location for the inference extension. // We also return a map of all paths and their types. func getMaxLocationCountAndPathMap(pathRules []dataplane.PathRule) (int, pathAndTypeMap) { maxLocs := 1 pathsAndTypes := make(pathAndTypeMap) for _, rule := range pathRules { - maxLocs += len(rule.MatchRules) + 3 + maxLocs += (len(rule.MatchRules) * 2) + 2 if pathsAndTypes[rule.Path] == nil { pathsAndTypes[rule.Path] = map[dataplane.PathType]struct{}{ rule.PathType: {}, From c7603ba9530c8f0b615b8a0a1f27cf03b780c96a Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Mon, 22 Sep 2025 13:31:01 -0600 Subject: [PATCH 3/5] Add comments --- internal/controller/nginx/config/servers.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/controller/nginx/config/servers.go b/internal/controller/nginx/config/servers.go index b8686b775c..d190e05def 100644 --- a/internal/controller/nginx/config/servers.go +++ b/internal/controller/nginx/config/servers.go @@ -379,6 +379,9 @@ func createLocations( ) httpMatchKey := serverID + "_" + strconv.Itoa(pathRuleIdx) for i := range extLocations { + // FIXME(sberman): De-dupe matches and associated locations + // so we don't need nginx/njs to perform unnecessary matching. + // https://github.com/nginx/nginx-gateway-fabric/issues/662 extLocations[i].HTTPMatchKey = httpMatchKey matchPairs[extLocations[i].HTTPMatchKey] = matches } @@ -615,8 +618,8 @@ func getLocationTypeForPathRule(rule dataplane.PathRule) http.LocationType { return http.ExternalLocationType } -// initializes the internal location that is redirected to by an external location HTTP matching decision. -// This location will proxy_pass to the backend. +// initializeInternalMatchLocation initializes the internal location that is redirected to by an +// external location HTTP matching decision. This location will proxy_pass to the backend. func initializeInternalMatchLocation( pathruleIdx, matchRuleIdx int, @@ -627,7 +630,7 @@ func initializeInternalMatchLocation( return createMatchLocation(path, grpc), createRouteMatch(match, path) } -// initializes the internal inference location that is redirected to by +// initializeInternalInferenceRedirectLocation initializes the internal inference location that is redirected to by // an external HTTP matching location. This location then redirects to the final proxy_pass location. func initializeInternalInferenceRedirectLocation(pathruleIdx, matchRuleIdx int) http.Location { return http.Location{ @@ -636,8 +639,9 @@ func initializeInternalInferenceRedirectLocation(pathruleIdx, matchRuleIdx int) } } -// initializes the internal location that is redirected to by an internal inference location, which was -// redirected to by the external HTTP matching location. This location will proxy_pass to the backend. +// initializeInternalMatchLocationWithInference initializes the internal location that is redirected to by +// an internal inference location, which was redirected to by the external HTTP matching location. +// This location will proxy_pass to the backend. // The routeMatch is created with the inference internal location path, so that the HTTP match in the external // location can redirect to the proper inference location, which then redirects to this location. func initializeInternalMatchLocationWithInference( @@ -651,7 +655,8 @@ func initializeInternalMatchLocationWithInference( return createMatchLocation(path, grpc), createRouteMatch(match, inferencePath(pathruleIdx, matchRuleIdx)) } -// initializes the internal inference location that does the final proxy_pass to the inference backend. +// initializeInternalInferenceLocation initializes the internal inference location that does the final +// proxy_pass to the inference backend. // This is used when the external location redirects directly here, without any HTTP matching. func initializeInternalInferenceLocation(pathruleIdx, matchRuleIdx int) http.Location { return http.Location{ From 067f96135375de12c05749ad46997e1e042d7c15 Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Mon, 22 Sep 2025 14:09:11 -0600 Subject: [PATCH 4/5] Clarify comments --- internal/controller/nginx/config/servers.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/internal/controller/nginx/config/servers.go b/internal/controller/nginx/config/servers.go index d190e05def..9664396c2e 100644 --- a/internal/controller/nginx/config/servers.go +++ b/internal/controller/nginx/config/servers.go @@ -310,6 +310,11 @@ and which backend to use, then redirects to the internal inference location. The location calls the inference NJS module to get the AI endpoint to proxy to, then redirects to the internal location that proxies to the backend. +Note that the location path naming here is a little different than the previous example. +The final location that proxy_passes has the non-inference name to avoid too much refactoring +in the code, and the intermediate location has -inference in the name, whereas in the previous example +it was the final location that had -inference in the name. + location /coffee { js_content httpmatches.match; // chooses backend and redirects to appropriate internal inference location } From 6d7444f259273038b09e0dc1ec8bce457b466ddb Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Tue, 23 Sep 2025 08:32:26 -0600 Subject: [PATCH 5/5] add some comments and unit test --- internal/controller/nginx/config/maps.go | 6 ++++- .../controller/nginx/modules/test/epp.test.js | 26 ++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/internal/controller/nginx/config/maps.go b/internal/controller/nginx/config/maps.go index 0b7ac00551..e0f9ee98d5 100644 --- a/internal/controller/nginx/config/maps.go +++ b/internal/controller/nginx/config/maps.go @@ -185,14 +185,18 @@ func createAddHeadersMap(name string) shared.Map { // buildInferenceMaps creates maps for InferencePool Backends. func buildInferenceMaps(groups []dataplane.BackendGroup) []shared.Map { - inferenceMaps := make([]shared.Map, 0) + inferenceMaps := make([]shared.Map, 0, len(groups)) for _, group := range groups { for _, backend := range group.Backends { if backend.EndpointPickerConfig != nil { var defaultResult string switch backend.EndpointPickerConfig.FailureMode { + // in FailClose mode, if the EPP is unavailable or returns an error, + // we return an invalid backend to ensure the request fails case inference.EndpointPickerFailClose: defaultResult = invalidBackendRef + // in FailOpen mode, if the EPP is unavailable or returns an error, + // we fall back to the upstream case inference.EndpointPickerFailOpen: defaultResult = backend.UpstreamName } diff --git a/internal/controller/nginx/modules/test/epp.test.js b/internal/controller/nginx/modules/test/epp.test.js index 288104520b..c2a4528694 100644 --- a/internal/controller/nginx/modules/test/epp.test.js +++ b/internal/controller/nginx/modules/test/epp.test.js @@ -1,12 +1,19 @@ import { default as epp } from '../src/epp.js'; import { expect, describe, it, beforeEach, afterEach, vi } from 'vitest'; -function makeRequest({ method = 'POST', headersIn = {}, requestText = '', variables = {} } = {}) { +function makeRequest({ + method = 'POST', + headersIn = {}, + args = {}, + requestText = '', + variables = {}, +} = {}) { return { method, headersIn, requestText, variables, + args, error: vi.fn(), log: vi.fn(), internalRedirect: vi.fn(), @@ -79,4 +86,21 @@ describe('getEndpoint', () => { expect(r.error).toHaveBeenCalledWith(expect.stringContaining('Error in ngx.fetch')); expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); }); + + it('preserves args in internal redirect when args are present', async () => { + const endpoint = 'http://endpoint'; + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 200, + headers: { get: () => endpoint }, + text: vi.fn(), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + args: { a: '1', b: '2' }, + }); + await epp.getEndpoint(r); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo?a=1&b=2'); + }); });