diff --git a/cmd/gateway/endpoint_picker.go b/cmd/gateway/endpoint_picker.go index 7c67a83671..acf9bdfbb6 100644 --- a/cmd/gateway/endpoint_picker.go +++ b/cmd/gateway/endpoint_picker.go @@ -14,16 +14,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" -) -const ( - // defaultPort is the default port for this server to listen on. If collisions become a problem, - // we can make this configurable via the NginxProxy resource. - defaultPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100. - // eppEndpointHostHeader is the HTTP header used to specify the EPP endpoint host, set by the NJS module caller. - eppEndpointHostHeader = "X-EPP-Host" - // eppEndpointPortHeader is the HTTP header used to specify the EPP endpoint port, set by the NJS module caller. - eppEndpointPortHeader = "X-EPP-Port" + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types" ) // extProcClientFactory creates a new ExternalProcessorClient and returns a close function. @@ -32,7 +24,7 @@ type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient // endpointPickerServer starts an HTTP server on the given port with the provided handler. func endpointPickerServer(handler http.Handler) error { server := &http.Server{ - Addr: fmt.Sprintf("127.0.0.1:%d", defaultPort), + Addr: fmt.Sprintf("127.0.0.1:%d", types.GoShimPort), Handler: handler, ReadHeaderTimeout: 10 * time.Second, } @@ -54,13 +46,13 @@ func realExtProcClientFactory() extProcClientFactory { // createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker. 
func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := r.Header.Get(eppEndpointHostHeader) - port := r.Header.Get(eppEndpointPortHeader) + host := r.Header.Get(types.EPPEndpointHostHeader) + port := r.Header.Get(types.EPPEndpointPortHeader) if host == "" || port == "" { msg := fmt.Sprintf( "missing at least one of required headers: %s and %s", - eppEndpointHostHeader, - eppEndpointPortHeader, + types.EPPEndpointHostHeader, + types.EPPEndpointPortHeader, ) logger.Error(errors.New(msg), "error contacting EndpointPicker") http.Error(w, msg, http.StatusBadRequest) @@ -174,6 +166,10 @@ func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest { } func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) { + if r.ContentLength == 0 { + return nil, errors.New("request body is empty") + } + body, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("error reading request body: %w", err) diff --git a/cmd/gateway/endpoint_picker_test.go b/cmd/gateway/endpoint_picker_test.go index 99808348fc..99fd95aa90 100644 --- a/cmd/gateway/endpoint_picker_test.go +++ b/cmd/gateway/endpoint_picker_test.go @@ -17,6 +17,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/metadata" eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata" + + "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types" ) type mockExtProcClient struct { @@ -122,8 +124,8 @@ func TestEndpointPickerHandler_Success(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") req.Header.Set("Content-Type", 
"application/json") w := httptest.NewRecorder() @@ -165,8 +167,8 @@ func TestEndpointPickerHandler_ImmediateResponse(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") w := httptest.NewRecorder() h.ServeHTTP(w, req) @@ -190,8 +192,8 @@ func TestEndpointPickerHandler_Errors(t *testing.T) { h := createEndpointPickerHandler(factory, logr.Discard()) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body")) if setHeaders { - req.Header.Set(eppEndpointHostHeader, "test-host") - req.Header.Set(eppEndpointPortHeader, "1234") + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") } w := httptest.NewRecorder() h.ServeHTTP(w, req) @@ -236,7 +238,33 @@ func TestEndpointPickerHandler_Errors(t *testing.T) { } runErrorTestCase(factory, true, http.StatusBadGateway, "error sending headers") - // 4. Error sending body + // 4a. 
Error building body request (content length 0) + client = &mockProcessClient{ + SendFunc: func(*extprocv3.ProcessingRequest) error { + return nil + }, + RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF }, + } + extProcClient = &mockExtProcClient{ + ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) { + return client, nil + }, + } + factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) { + return extProcClient, func() error { return nil }, nil + } + h := createEndpointPickerHandler(factory, logr.Discard()) + req := httptest.NewRequest(http.MethodPost, "/", nil) // nil body, ContentLength = 0 + req.Header.Set(types.EPPEndpointHostHeader, "test-host") + req.Header.Set(types.EPPEndpointPortHeader, "1234") + w := httptest.NewRecorder() + h.ServeHTTP(w, req) + resp := w.Result() + g.Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError)) + body, _ := io.ReadAll(resp.Body) + g.Expect(string(body)).To(ContainSubstring("request body is empty")) + + // 4b. 
Error sending body client = &mockProcessClient{ SendFunc: func(req *extprocv3.ProcessingRequest) error { if req.GetRequestBody() != nil { diff --git a/deploy/inference-nginx-plus/deploy.yaml b/deploy/inference-nginx-plus/deploy.yaml index 77ee4da544..025cfeb410 100644 --- a/deploy/inference-nginx-plus/deploy.yaml +++ b/deploy/inference-nginx-plus/deploy.yaml @@ -281,6 +281,7 @@ spec: - --nginx-docker-secret=nginx-plus-registry-secret - --nginx-plus - --usage-report-secret=nplus-license + - --usage-report-enforce-initial-report=true - --metrics-port=9113 - --health-port=8081 - --leader-election-lock-name=nginx-gateway-leader-election diff --git a/internal/controller/nginx/config/http/config.go b/internal/controller/nginx/config/http/config.go index 3a76ab30b4..dedfd04349 100644 --- a/internal/controller/nginx/config/http/config.go +++ b/internal/controller/nginx/config/http/config.go @@ -26,26 +26,58 @@ type Server struct { type LocationType string const ( + // InternalLocationType defines an internal location that is only accessible within NGINX. InternalLocationType LocationType = "internal" + // ExternalLocationType defines a normal external location that is accessible by clients. ExternalLocationType LocationType = "external" + // RedirectLocationType defines an external location that redirects to an internal location + // based on HTTP matching conditions. RedirectLocationType LocationType = "redirect" + // InferenceExternalLocationType defines an external location that is used for calling NJS + // to get the inference workload endpoint and redirects to the internal location that will proxy_pass + // to that endpoint. + InferenceExternalLocationType LocationType = "inference-external" + // InferenceInternalLocationType defines an internal location that is used for calling NJS + // to get the inference workload endpoint and redirects to the internal location that will proxy_pass + // to that endpoint. 
This is used when an HTTP redirect location is also defined that redirects + // to this internal inference location. + InferenceInternalLocationType LocationType = "inference-internal" ) // Location holds all configuration for an HTTP location. type Location struct { - Path string - ProxyPass string - HTTPMatchKey string + // Return specifies a return directive (e.g., HTTP status or redirect) for this location block. + Return *Return + // ProxySSLVerify controls SSL verification for upstreams when proxying requests. + ProxySSLVerify *ProxySSLVerify + // ProxyPass is the upstream backend (URL or name) to which requests are proxied. + ProxyPass string + // HTTPMatchKey is the key for associating HTTP match rules, used for routing and NJS module logic. + HTTPMatchKey string + // MirrorSplitClientsVariableName is the variable name for split_clients, used in traffic mirroring scenarios. MirrorSplitClientsVariableName string - Type LocationType - ProxySetHeaders []Header - ProxySSLVerify *ProxySSLVerify - Return *Return - ResponseHeaders ResponseHeaders - Rewrites []string - MirrorPaths []string - Includes []shared.Include - GRPC bool + // EPPInternalPath is the internal path for the inference NJS module to redirect to. + EPPInternalPath string + // EPPHost is the host for the EndpointPicker, used for inference routing. + EPPHost string + // Type indicates the type of location (external, internal, redirect, etc). + Type LocationType + // Path is the NGINX location path. + Path string + // ResponseHeaders are custom response headers to be sent. + ResponseHeaders ResponseHeaders + // ProxySetHeaders are headers to set when proxying requests upstream. + ProxySetHeaders []Header + // Rewrites are rewrite rules for modifying request paths. + Rewrites []string + // MirrorPaths are paths to which requests are mirrored. + MirrorPaths []string + // Includes are additional NGINX config snippets or policies to include in this location. 
+ Includes []shared.Include + // EPPPort is the port for the EndpointPicker, used for inference routing. + EPPPort int + // GRPC indicates if this location proxies gRPC traffic. + GRPC bool } // Header defines an HTTP header to be passed to the proxied server. diff --git a/internal/controller/nginx/config/maps.go b/internal/controller/nginx/config/maps.go index 5a5e5ff189..e0f9ee98d5 100644 --- a/internal/controller/nginx/config/maps.go +++ b/internal/controller/nginx/config/maps.go @@ -1,9 +1,12 @@ package config import ( + "fmt" "strings" gotemplate "text/template" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" + "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane" "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" @@ -26,6 +29,8 @@ const ( func executeMaps(conf dataplane.Configuration) []executeResult { maps := buildAddHeaderMaps(append(conf.HTTPServers, conf.SSLServers...)) + maps = append(maps, buildInferenceMaps(conf.BackendGroups)...) + result := executeResult{ dest: httpConfigFile, data: helpers.MustExecuteTemplate(mapsTemplate, maps), @@ -177,3 +182,42 @@ func createAddHeadersMap(name string) shared.Map { Parameters: params, } } + +// buildInferenceMaps creates maps for InferencePool Backends. 
+func buildInferenceMaps(groups []dataplane.BackendGroup) []shared.Map { + inferenceMaps := make([]shared.Map, 0, len(groups)) + for _, group := range groups { + for _, backend := range group.Backends { + if backend.EndpointPickerConfig != nil { + var defaultResult string + switch backend.EndpointPickerConfig.FailureMode { + // in FailClose mode, if the EPP is unavailable or returns an error, + // we return an invalid backend to ensure the request fails + case inference.EndpointPickerFailClose: + defaultResult = invalidBackendRef + // in FailOpen mode, if the EPP is unavailable or returns an error, + // we fall back to the upstream + case inference.EndpointPickerFailOpen: + defaultResult = backend.UpstreamName + } + params := []shared.MapParameter{ + { + Value: "~.+", + Result: "$inference_workload_endpoint", + }, + { + Value: "default", + Result: defaultResult, + }, + } + backendVarName := strings.ReplaceAll(backend.UpstreamName, "-", "_") + inferenceMaps = append(inferenceMaps, shared.Map{ + Source: "$inference_workload_endpoint", + Variable: fmt.Sprintf("$inference_backend_%s", backendVarName), + Parameters: params, + }) + } + } + } + return inferenceMaps +} diff --git a/internal/controller/nginx/config/maps_test.go b/internal/controller/nginx/config/maps_test.go index d133882d7b..736d7808ec 100644 --- a/internal/controller/nginx/config/maps_test.go +++ b/internal/controller/nginx/config/maps_test.go @@ -5,6 +5,7 @@ import ( "testing" . 
"github.com/onsi/gomega" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane" @@ -59,22 +60,24 @@ func TestExecuteMaps(t *testing.T) { conf := dataplane.Configuration{ HTTPServers: []dataplane.VirtualServer{ - { - PathRules: pathRules, - }, - { - PathRules: pathRules, - }, - { - IsDefault: true, - }, + {PathRules: pathRules}, + {PathRules: pathRules}, + {IsDefault: true}, }, SSLServers: []dataplane.VirtualServer{ + {PathRules: pathRules}, + {IsDefault: true}, + }, + BackendGroups: []dataplane.BackendGroup{ { - PathRules: pathRules, - }, - { - IsDefault: true, + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream1", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailClose, + }, + }, + }, }, }, } @@ -86,6 +89,9 @@ func TestExecuteMaps(t *testing.T) { "map ${http_my_second_add_header} $my_second_add_header_header_var {": 1, "~.* ${http_my_second_add_header},;": 1, "map ${http_my_set_header} $my_set_header_header_var {": 0, + "$inference_workload_endpoint": 2, + "$inference_backend": 1, + "invalid-backend-ref": 1, } mapResult := executeMaps(conf) @@ -385,3 +391,36 @@ func TestCreateStreamMapsWithEmpty(t *testing.T) { g.Expect(maps).To(BeNil()) } + +func TestBuildInferenceMaps(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + group := dataplane.BackendGroup{ + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream1", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailClose, + }, + }, + { + UpstreamName: "upstream2", + EndpointPickerConfig: &inference.EndpointPickerRef{ + FailureMode: inference.EndpointPickerFailOpen, + }, + }, + { + UpstreamName: "upstream3", + EndpointPickerConfig: nil, + }, + }, + } + + maps := buildInferenceMaps([]dataplane.BackendGroup{group}) + 
g.Expect(maps).To(HaveLen(2)) + g.Expect(maps[0].Source).To(Equal("$inference_workload_endpoint")) + g.Expect(maps[0].Variable).To(Equal("$inference_backend_upstream1")) + g.Expect(maps[0].Parameters[1].Result).To(Equal("invalid-backend-ref")) + g.Expect(maps[1].Parameters[1].Result).To(Equal("upstream2")) +} diff --git a/internal/controller/nginx/config/servers.go b/internal/controller/nginx/config/servers.go index 88ba4fa8ea..9664396c2e 100644 --- a/internal/controller/nginx/config/servers.go +++ b/internal/controller/nginx/config/servers.go @@ -16,7 +16,13 @@ import ( "github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers" ) -var serversTemplate = gotemplate.Must(gotemplate.New("servers").Parse(serversTemplateText)) +var serversTemplate = gotemplate.Must( + gotemplate.New("servers").Funcs(gotemplate.FuncMap{ + "contains": func(str http.LocationType, substr string) bool { + return strings.Contains(string(str), substr) + }, + }).Parse(serversTemplateText), +) const ( // HeaderMatchSeparator is the separator for constructing header-based match for NJS. @@ -252,6 +258,78 @@ func extractMirrorTargetsWithPercentages(pathRules []dataplane.PathRule) map[str return mirrorTargets } +/* +There are several different flows of location blocks, depending on the user configuration. +The following describes them, with basic location examples. + +--------------- +Base case, no HTTP matching conditions or inference extension. + +External location proxies straight to backend. + +location /coffee { + proxy_pass http://backend; +} +--------------- +HTTP matching conditions. + +External location calls httpmatch NJS module. The module determines the HTTP request conditions that exist +and which backend to use, then redirects to the appropriate internal location. +The internal location proxies to the backend. 
+ +location /coffee { + js_content httpmatches.redirect; // chooses backend1 or backend2, and redirects to appropriate internal location +} +location /_ngf-internal-rule0-route0 { + internal; + proxy_pass http://backend1; +} +location /_ngf-internal-rule1-route0 { + internal; + proxy_pass http://backend2; +} +--------------- +Inference extension, no HTTP matching conditions. + +External location calls inference NJS module. The module gets the AI endpoint to proxy to, +then redirects to the internal inference location that proxies to the backend. + +location /coffee { + set $epp_internal_path /_ngf-internal-rule0-route0-inference; + js_content epp.getEndpoint; // gets endpoint and redirects to /_ngf-internal-rule0-route0-inference +} +location /_ngf-internal-rule0-route0-inference { + internal; + proxy_pass http://$inference_backend_<upstream>; +} +--------------- +Inference extension with HTTP matching conditions. + +External location calls httpmatch NJS module. The module determines the HTTP request conditions that exist +and which backend to use, then redirects to the internal inference location. The internal inference +location calls the inference NJS module to get the AI endpoint to proxy to, then redirects to the +internal location that proxies to the backend. + +Note that the location path naming here is a little different than the previous example. +The final location that proxy_passes has the non-inference name to avoid too much refactoring +in the code, and the intermediate location has -inference in the name, whereas in the previous example +it was the final location that had -inference in the name. 
+ +location /coffee { + js_content httpmatches.redirect; // chooses backend and redirects to appropriate internal inference location +} +location /_ngf-internal-rule0-route0-inference { + internal; + + set $epp_internal_path /_ngf-internal-rule0-route0; + js_content epp.getEndpoint; // redirects to /_ngf-internal-rule0-route0 +} +location /_ngf-internal-rule0-route0 { + internal; + proxy_pass http://$inference_backend_<upstream>; +} +*/ + type httpMatchPairs map[string][]routeMatch func createLocations( @@ -270,8 +348,6 @@ func createLocations( mirrorPathToPercentage := extractMirrorTargetsWithPercentages(server.PathRules) for pathRuleIdx, rule := range server.PathRules { - matches := make([]routeMatch, 0, len(rule.MatchRules)) - if rule.Path == rootPath { rootPathExists = true } @@ -281,7 +357,6 @@ func createLocations( } mirrorPercentage := mirrorPathToPercentage[rule.Path] - extLocations := initializeExternalLocations(rule, pathsAndTypes) for i := range extLocations { extLocations[i].Includes = createIncludesFromPolicyGenerateResult( @@ -289,54 +364,45 @@ func createLocations( ) } - if !needsInternalLocations(rule) { - for _, r := range rule.MatchRules { - extLocations = updateLocations( - r, - rule, - extLocations, - server.Port, - keepAliveCheck, - mirrorPercentage, - ) - } - - locs = append(locs, extLocations...) 
- continue - } - - internalLocations := make([]http.Location, 0, len(rule.MatchRules)) - - for matchRuleIdx, r := range rule.MatchRules { - intLocation, match := initializeInternalLocation(pathRuleIdx, matchRuleIdx, r.Match, rule.GRPC) - intLocation.Includes = createIncludesFromPolicyGenerateResult( - generator.GenerateForInternalLocation(rule.Policies), + switch { + case !needsInternalLocationsForMatches(rule) && !rule.HasInferenceBackends: + locs = append(locs, updateExternalLocationsForRule( + rule, + extLocations, + server.Port, + keepAliveCheck, + mirrorPercentage)..., ) - - intLocation = updateLocation( - r, + case needsInternalLocationsForMatches(rule): + internalLocations, matches := createInternalLocationsForRule( + pathRuleIdx, rule, - intLocation, + generator, server.Port, keepAliveCheck, mirrorPercentage, ) - - internalLocations = append(internalLocations, intLocation) - matches = append(matches, match) - } - - httpMatchKey := serverID + "_" + strconv.Itoa(pathRuleIdx) - for i := range extLocations { - // FIXME(sberman): De-dupe matches and associated locations - // so we don't need nginx/njs to perform unnecessary matching. - // https://github.com/nginx/nginx-gateway-fabric/issues/662 - extLocations[i].HTTPMatchKey = httpMatchKey - matchPairs[extLocations[i].HTTPMatchKey] = matches + httpMatchKey := serverID + "_" + strconv.Itoa(pathRuleIdx) + for i := range extLocations { + // FIXME(sberman): De-dupe matches and associated locations + // so we don't need nginx/njs to perform unnecessary matching. + // https://github.com/nginx/nginx-gateway-fabric/issues/662 + extLocations[i].HTTPMatchKey = httpMatchKey + matchPairs[extLocations[i].HTTPMatchKey] = matches + } + locs = append(locs, extLocations...) + locs = append(locs, internalLocations...) 
+ case rule.HasInferenceBackends: + locs = append(locs, createInferenceLocationsForRule( + pathRuleIdx, + rule, + extLocations, + generator, + server.Port, + keepAliveCheck, + mirrorPercentage)..., + ) } - - locs = append(locs, extLocations...) - locs = append(locs, internalLocations...) } if !rootPathExists { @@ -346,10 +412,124 @@ func createLocations( return locs, matchPairs, grpcServer } -func needsInternalLocations(rule dataplane.PathRule) bool { +func updateExternalLocationsForRule( + rule dataplane.PathRule, + extLocations []http.Location, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) []http.Location { + for _, r := range rule.MatchRules { + extLocations = updateLocations( + r, + rule, + extLocations, + port, + keepAliveCheck, + mirrorPercentage, + ) + } + + return extLocations +} + +func createInternalLocationsForRule( + pathRuleIdx int, + rule dataplane.PathRule, + generator policies.Generator, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) ([]http.Location, []routeMatch) { + internalLocations := make([]http.Location, 0, len(rule.MatchRules)) + matches := make([]routeMatch, 0, len(rule.MatchRules)) + for matchRuleIdx, r := range rule.MatchRules { + var intLocation http.Location + var match routeMatch + if !rule.HasInferenceBackends { + intLocation, match = initializeInternalMatchLocation(pathRuleIdx, matchRuleIdx, r.Match, rule.GRPC) + } else { + intLocation, match = initializeInternalMatchLocationWithInference(pathRuleIdx, matchRuleIdx, r.Match) + intInfLocation := initializeInternalInferenceRedirectLocation(pathRuleIdx, matchRuleIdx) + for _, b := range r.BackendGroup.Backends { + if b.EndpointPickerConfig != nil { + var portNum int + if b.EndpointPickerConfig.Port != nil { + portNum = int(b.EndpointPickerConfig.Port.Number) + } + intInfLocation.EPPInternalPath = intLocation.Path + intInfLocation.EPPHost = string(b.EndpointPickerConfig.Name) + intInfLocation.EPPPort = portNum + } + } + 
internalLocations = append(internalLocations, intInfLocation) + } + intLocation.Includes = createIncludesFromPolicyGenerateResult( + generator.GenerateForInternalLocation(rule.Policies), + ) + intLocation = updateLocation( + r, + rule, + intLocation, + port, + keepAliveCheck, + mirrorPercentage, + ) + internalLocations = append(internalLocations, intLocation) + matches = append(matches, match) + } + + return internalLocations, matches +} + +func createInferenceLocationsForRule( + pathRuleIdx int, + rule dataplane.PathRule, + extLocations []http.Location, + generator policies.Generator, + port int32, + keepAliveCheck keepAliveChecker, + mirrorPercentage *float64, +) []http.Location { + locs := make([]http.Location, 0, len(rule.MatchRules)+len(extLocations)) + for matchRuleIdx, r := range rule.MatchRules { + intLocation := initializeInternalInferenceLocation(pathRuleIdx, matchRuleIdx) + intLocation.Includes = createIncludesFromPolicyGenerateResult( + generator.GenerateForInternalLocation(rule.Policies), + ) + intLocation = updateLocation( + r, + rule, + intLocation, + port, + keepAliveCheck, + mirrorPercentage, + ) + for _, b := range r.BackendGroup.Backends { + if b.EndpointPickerConfig != nil { + for i := range extLocations { + var portNum int + if b.EndpointPickerConfig.Port != nil { + portNum = int(b.EndpointPickerConfig.Port.Number) + } + extLocations[i].EPPInternalPath = intLocation.Path + extLocations[i].EPPHost = string(b.EndpointPickerConfig.Name) + extLocations[i].EPPPort = portNum + } + } + } + locs = append(locs, intLocation) + } + locs = append(locs, extLocations...) + + return locs +} + +func needsInternalLocationsForMatches(rule dataplane.PathRule) bool { if len(rule.MatchRules) > 1 { return true } + return len(rule.MatchRules) == 1 && !isPathOnlyMatch(rule.MatchRules[0].Match) } @@ -362,12 +542,13 @@ type pathAndTypeMap map[string]map[dataplane.PathType]struct{} // 2. Each path rule may have an additional location if it contains non-path-only matches. 
// 3. Each prefix path rule may have an additional location if it doesn't contain trailing slash. // 4. There may be an additional location for the default root path. +// 5. There may be an additional location per parent location for the inference extension. // We also return a map of all paths and their types. func getMaxLocationCountAndPathMap(pathRules []dataplane.PathRule) (int, pathAndTypeMap) { maxLocs := 1 pathsAndTypes := make(pathAndTypeMap) for _, rule := range pathRules { - maxLocs += len(rule.MatchRules) + 2 + maxLocs += (len(rule.MatchRules) * 2) + 2 if pathsAndTypes[rule.Path] == nil { pathsAndTypes[rule.Path] = map[dataplane.PathType]struct{}{ rule.PathType: {}, @@ -431,14 +612,20 @@ func initializeExternalLocations( } func getLocationTypeForPathRule(rule dataplane.PathRule) http.LocationType { - if needsInternalLocations(rule) { + if needsInternalLocationsForMatches(rule) { return http.RedirectLocationType } + if rule.HasInferenceBackends { + return http.InferenceExternalLocationType + } + return http.ExternalLocationType } -func initializeInternalLocation( +// initializeInternalMatchLocation initializes the internal location that is redirected to by an +// external location HTTP matching decision. This location will proxy_pass to the backend. +func initializeInternalMatchLocation( pathruleIdx, matchRuleIdx int, match dataplane.Match, @@ -448,6 +635,45 @@ func initializeInternalLocation( return createMatchLocation(path, grpc), createRouteMatch(match, path) } +// initializeInternalInferenceRedirectLocation initializes the internal inference location that is redirected to by +// an external HTTP matching location. This location then redirects to the final proxy_pass location. 
+func initializeInternalInferenceRedirectLocation(pathruleIdx, matchRuleIdx int) http.Location { + return http.Location{ + Path: inferencePath(pathruleIdx, matchRuleIdx), + Type: http.InferenceInternalLocationType, + } +} + +// initializeInternalMatchLocationWithInference initializes the internal location that is redirected to by +// an internal inference location, which was redirected to by the external HTTP matching location. +// This location will proxy_pass to the backend. +// The routeMatch is created with the inference internal location path, so that the HTTP match in the external +// location can redirect to the proper inference location, which then redirects to this location. +func initializeInternalMatchLocationWithInference( + pathruleIdx, + matchRuleIdx int, + match dataplane.Match, +) (http.Location, routeMatch) { + path := fmt.Sprintf("%s-rule%d-route%d", http.InternalRoutePathPrefix, pathruleIdx, matchRuleIdx) + grpc := false + + return createMatchLocation(path, grpc), createRouteMatch(match, inferencePath(pathruleIdx, matchRuleIdx)) +} + +// initializeInternalInferenceLocation initializes the internal inference location that does the final +// proxy_pass to the inference backend. +// This is used when the external location redirects directly here, without any HTTP matching. +func initializeInternalInferenceLocation(pathruleIdx, matchRuleIdx int) http.Location { + return http.Location{ + Path: inferencePath(pathruleIdx, matchRuleIdx), + Type: http.InternalLocationType, + } +} + +func inferencePath(pathruleIdx int, matchRuleIdx int) string { + return fmt.Sprintf("%s-rule%d-route%d-inference", http.InternalRoutePathPrefix, pathruleIdx, matchRuleIdx) +} + // updateLocation updates a location with any relevant configurations, like proxy_pass, filters, tls settings, etc. 
func updateLocation( matchRule dataplane.MatchRule, @@ -460,6 +686,7 @@ func updateLocation( filters := matchRule.Filters path := pathRule.Path grpc := pathRule.GRPC + inferenceBackend := pathRule.HasInferenceBackends if filters.InvalidFilter != nil { location.Return = &http.Return{Code: http.StatusInternalServerError} @@ -475,7 +702,7 @@ func updateLocation( location = updateLocationRewriteFilter(location, filters.RequestURLRewrite, path) location = updateLocationMirrorFilters(location, filters.RequestMirrors, path, mirrorPercentage) - location = updateLocationProxySettings(location, matchRule, grpc, keepAliveCheck) + location = updateLocationProxySettings(location, matchRule, grpc, inferenceBackend, keepAliveCheck) return location } @@ -555,6 +782,7 @@ func updateLocationProxySettings( location http.Location, matchRule dataplane.MatchRule, grpc bool, + inferenceBackend bool, keepAliveCheck keepAliveChecker, ) http.Location { extraHeaders := make([]http.Header, 0, 3) @@ -575,6 +803,7 @@ func updateLocationProxySettings( matchRule.Filters.RequestURLRewrite, generateProtocolString(location.ProxySSLVerify, grpc), grpc, + inferenceBackend, ) location.ResponseHeaders = responseHeaders @@ -853,6 +1082,7 @@ func createProxyPass( filter *dataplane.HTTPURLRewriteFilter, protocol string, grpc bool, + inferenceBackend bool, ) string { var requestURI string if !grpc { @@ -862,6 +1092,12 @@ func createProxyPass( } backendName := backendGroupName(backendGroup) + + if inferenceBackend { + backendVarName := strings.ReplaceAll(backendName, "-", "_") + return "http://$inference_backend_" + backendVarName + requestURI + } + if backendGroupNeedsSplit(backendGroup) { return protocol + "://$" + convertStringToSafeVariableName(backendName) + requestURI } diff --git a/internal/controller/nginx/config/servers_template.go b/internal/controller/nginx/config/servers_template.go index 224e189a6e..9575b77480 100644 --- a/internal/controller/nginx/config/servers_template.go +++ 
b/internal/controller/nginx/config/servers_template.go @@ -92,7 +92,7 @@ server { {{ range $l := $s.Locations }} location {{ $l.Path }} { - {{ if eq $l.Type "internal" -}} + {{ if contains $l.Type "internal" -}} internal; {{ end }} @@ -118,11 +118,19 @@ server { return {{ $l.Return.Code }} "{{ $l.Return.Body }}"; {{- end }} - {{- if eq $l.Type "redirect" }} + {{- if eq $l.Type "redirect" -}} set $match_key {{ $l.HTTPMatchKey }}; js_content httpmatches.redirect; {{- end }} + {{- if contains $l.Type "inference" -}} + js_var $inference_workload_endpoint; + set $epp_internal_path {{ $l.EPPInternalPath }}; + set $epp_host {{ $l.EPPHost }}; + set $epp_port {{ $l.EPPPort }}; + js_content epp.getEndpoint; + {{- end }} + {{ $proxyOrGRPC := "proxy" }}{{ if $l.GRPC }}{{ $proxyOrGRPC = "grpc" }}{{ end }} {{- if $l.GRPC }} diff --git a/internal/controller/nginx/config/servers_test.go b/internal/controller/nginx/config/servers_test.go index 6b604d7bec..ab4fad31a5 100644 --- a/internal/controller/nginx/config/servers_test.go +++ b/internal/controller/nginx/config/servers_test.go @@ -9,6 +9,7 @@ import ( . 
"github.com/onsi/gomega" "github.com/onsi/gomega/format" "k8s.io/apimachinery/pkg/types" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/http" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" @@ -1239,7 +1240,7 @@ func TestCreateServers(t *testing.T) { Filters: dataplane.HTTPFilters{ RequestRedirect: &dataplane.HTTPRequestRedirectFilter{ Hostname: helpers.GetPointer("redirect.example.com"), - StatusCode: helpers.GetPointer[int](301), + StatusCode: helpers.GetPointer(301), Port: helpers.GetPointer[int32](8080), Path: &dataplane.HTTPPathModifier{ Type: dataplane.ReplaceFullPath, @@ -2443,6 +2444,154 @@ func TestCreateLocations_Includes(t *testing.T) { } } +func TestCreateLocations_InferenceBackends(t *testing.T) { + t.Parallel() + + hrNsName := types.NamespacedName{Namespace: "test", Name: "route1"} + + fooGroup := dataplane.BackendGroup{ + Source: hrNsName, + RuleIdx: 0, + Backends: []dataplane.Backend{ + { + UpstreamName: "test_foo_80", + Valid: true, + Weight: 1, + EndpointPickerConfig: &inference.EndpointPickerRef{ + Name: "test-epp", + Port: &inference.Port{ + Number: 80, + }, + }, + }, + }, + } + + pathRuleInferenceOnly := dataplane.PathRule{ + Path: "/inference", + PathType: dataplane.PathTypeExact, + HasInferenceBackends: true, + MatchRules: []dataplane.MatchRule{ + { + Match: dataplane.Match{}, + BackendGroup: fooGroup, + }, + }, + } + + pathRuleInferenceWithMatch := dataplane.PathRule{ + Path: "/inference-match", + PathType: dataplane.PathTypeExact, + HasInferenceBackends: true, + MatchRules: []dataplane.MatchRule{ + { + Match: dataplane.Match{ + Method: helpers.GetPointer("POST"), + }, + BackendGroup: fooGroup, + }, + }, + } + + tests := []struct { + expMatches httpMatchPairs + name string + pathRules []dataplane.PathRule + expLocs []http.Location + }{ + { + name: "inference only, no internal locations for matches", + pathRules: 
[]dataplane.PathRule{pathRuleInferenceOnly}, + expLocs: []http.Location{ + { + Path: "/_ngf-internal-rule0-route0-inference", + Type: http.InternalLocationType, + ProxyPass: "http://$inference_backend_test_foo_80$request_uri", + ProxySetHeaders: []http.Header{ + {Name: "Host", Value: "$gw_api_compliant_host"}, + {Name: "X-Forwarded-For", Value: "$proxy_add_x_forwarded_for"}, + {Name: "X-Real-IP", Value: "$remote_addr"}, + {Name: "X-Forwarded-Proto", Value: "$scheme"}, + {Name: "X-Forwarded-Host", Value: "$host"}, + {Name: "X-Forwarded-Port", Value: "$server_port"}, + {Name: "Upgrade", Value: "$http_upgrade"}, + {Name: "Connection", Value: "$connection_upgrade"}, + }, + }, + { + Path: "= /inference", + Type: http.InferenceExternalLocationType, + EPPInternalPath: "/_ngf-internal-rule0-route0-inference", + EPPHost: "test-epp", + EPPPort: 80, + }, + createDefaultRootLocation(), + }, + expMatches: httpMatchPairs{}, + }, + { + name: "inference with match, needs internal locations for matches", + pathRules: []dataplane.PathRule{pathRuleInferenceWithMatch}, + expLocs: []http.Location{ + { + Path: "= /inference-match", + Type: http.RedirectLocationType, + HTTPMatchKey: "1_0", + }, + { + Path: "/_ngf-internal-rule0-route0-inference", + Type: http.InferenceInternalLocationType, + EPPInternalPath: "/_ngf-internal-rule0-route0", + EPPHost: "test-epp", + EPPPort: 80, + }, + { + Path: "/_ngf-internal-rule0-route0", + Type: http.InternalLocationType, + ProxyPass: "http://$inference_backend_test_foo_80$request_uri", + ProxySetHeaders: []http.Header{ + {Name: "Host", Value: "$gw_api_compliant_host"}, + {Name: "X-Forwarded-For", Value: "$proxy_add_x_forwarded_for"}, + {Name: "X-Real-IP", Value: "$remote_addr"}, + {Name: "X-Forwarded-Proto", Value: "$scheme"}, + {Name: "X-Forwarded-Host", Value: "$host"}, + {Name: "X-Forwarded-Port", Value: "$server_port"}, + {Name: "Upgrade", Value: "$http_upgrade"}, + {Name: "Connection", Value: "$connection_upgrade"}, + }, + }, + 
createDefaultRootLocation(), + }, + expMatches: httpMatchPairs{ + "1_0": { + {Method: "POST", RedirectPath: "/_ngf-internal-rule0-route0-inference"}, + }, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + locs, matches, _ := createLocations( + &dataplane.VirtualServer{ + Hostname: "example.com", + PathRules: tc.pathRules, + Port: 80, + }, + "1", + &policiesfakes.FakeGenerator{}, + alwaysFalseKeepAliveChecker, + ) + + g.Expect(helpers.Diff(tc.expLocs, locs)).To(BeEmpty()) + g.Expect(matches).To(Equal(tc.expMatches)) + }) + } +} + func TestCreateLocationsRootPath(t *testing.T) { t.Parallel() hrNsName := types.NamespacedName{Namespace: "test", Name: "route1"} @@ -3332,10 +3481,11 @@ func TestCreateProxyPass(t *testing.T) { t.Parallel() tests := []struct { - rewrite *dataplane.HTTPURLRewriteFilter - expected string - grp dataplane.BackendGroup - GRPC bool + rewrite *dataplane.HTTPURLRewriteFilter + expected string + grp dataplane.BackendGroup + GRPC bool + inferenceBackend bool }{ { expected: "http://10.0.0.1:80$request_uri", @@ -3349,6 +3499,20 @@ func TestCreateProxyPass(t *testing.T) { }, }, }, + // Inference case + { + expected: "http://$inference_backend_upstream_inference$request_uri", + grp: dataplane.BackendGroup{ + Backends: []dataplane.Backend{ + { + UpstreamName: "upstream-inference", + Valid: true, + Weight: 1, + }, + }, + }, + inferenceBackend: true, + }, { expected: "http://$group_ns1__bg_rule0$request_uri", grp: dataplane.BackendGroup{ @@ -3401,7 +3565,13 @@ func TestCreateProxyPass(t *testing.T) { t.Run(tc.expected, func(t *testing.T) { t.Parallel() g := NewWithT(t) - result := createProxyPass(tc.grp, tc.rewrite, generateProtocolString(nil, tc.GRPC), tc.GRPC) + result := createProxyPass( + tc.grp, + tc.rewrite, + generateProtocolString(nil, tc.GRPC), + tc.GRPC, + tc.inferenceBackend, + ) g.Expect(result).To(Equal(tc.expected)) }) } diff --git 
a/internal/controller/nginx/modules/src/epp.js b/internal/controller/nginx/modules/src/epp.js index d4beeb9e15..88de40062b 100644 --- a/internal/controller/nginx/modules/src/epp.js +++ b/internal/controller/nginx/modules/src/epp.js @@ -1,29 +1,59 @@ -// This file contains the methods to get an AI workload endpoint from the EndpointPicker (EPP). +import qs from 'querystring'; -// TODO(sberman): this module will need to be enhanced to include the following: -// - function that sends the subrequest to the Go middleware application (to get the endpoint from EPP) -// - if a user has specified an Exact matching condition for a model name, extract the model name from -// the request body, and if it matches that condition, set the proper value in the X-Gateway-Model-Name header -// (based on if we do a redirect or traffic split (see design doc)) in the subrequest. If the client request -// already has this header set, then I don't think we need to extract the model from the body, just pass -// through the existing header. -// I believe we have to use js_content to call the NJS functionality. Because this takes over -// the request, we will likely have to finish the NJS functionality with an internalRedirect to an internal -// location that proxy_passes to the chosen endpoint. 
+const EPP_HOST_HEADER_VAR = 'epp_host'; +const EPP_PORT_HEADER_VAR = 'epp_port'; +const EPP_HOST_HEADER = 'X-EPP-Host'; +const EPP_PORT_HEADER = 'X-EPP-Port'; +const ENDPOINT_HEADER = 'X-Gateway-Destination-Endpoint'; +const EPP_INTERNAL_PATH_VAR = 'epp_internal_path'; +const WORKLOAD_ENDPOINT_VAR = 'inference_workload_endpoint'; +const SHIM_URI = 'http://127.0.0.1:54800'; + +async function getEndpoint(r) { + if (!r.variables[EPP_HOST_HEADER_VAR] || !r.variables[EPP_PORT_HEADER_VAR]) { + throw Error( + `Missing required variables: ${EPP_HOST_HEADER_VAR} and/or ${EPP_PORT_HEADER_VAR}`, + ); + } + if (!r.variables[EPP_INTERNAL_PATH_VAR]) { + throw Error(`Missing required variable: ${EPP_INTERNAL_PATH_VAR}`); + } + + let headers = Object.assign({}, r.headersIn); + headers[EPP_HOST_HEADER] = r.variables[EPP_HOST_HEADER_VAR]; + headers[EPP_PORT_HEADER] = r.variables[EPP_PORT_HEADER_VAR]; -// extractModel extracts the model name from the request body. -function extractModel(r) { try { - var body = JSON.parse(r.requestText); - if (body && body.model !== undefined) { - return String(body.model); + const response = await ngx.fetch(SHIM_URI, { + method: r.method, + headers: headers, + body: r.requestText, + }); + const endpointHeader = response.headers.get(ENDPOINT_HEADER); + if (response.status === 200 && endpointHeader) { + r.variables[WORKLOAD_ENDPOINT_VAR] = endpointHeader; + r.log( + `found inference endpoint from EndpointPicker: ${r.variables[WORKLOAD_ENDPOINT_VAR]}`, + ); + } else { + const body = await response.text(); + r.error( + `could not get specific inference endpoint from EndpointPicker; ` + + `status: ${response.status}; body: ${body}`, + ); } - } catch (e) { - r.error(`error parsing request body for model name: ${e.message}`); - return ''; + } catch (err) { + r.error(`Error in ngx.fetch: ${err}`); } - r.error('request body does not contain model parameter'); - return ''; + + // If performing a rewrite, $request_uri won't be used, + // so we have to preserve 
args in the internal redirect. + let args = qs.stringify(r.args); + if (args) { + args = '?' + args; + } + + r.internalRedirect(r.variables[EPP_INTERNAL_PATH_VAR] + args); } -export default { extractModel }; +export default { getEndpoint }; diff --git a/internal/controller/nginx/modules/test/epp.test.js b/internal/controller/nginx/modules/test/epp.test.js index 6994423e7a..c2a4528694 100644 --- a/internal/controller/nginx/modules/test/epp.test.js +++ b/internal/controller/nginx/modules/test/epp.test.js @@ -1,52 +1,106 @@ import { default as epp } from '../src/epp.js'; -import { expect, describe, it } from 'vitest'; - -function makeRequest(body) { - let r = { - // Test mocks - error(msg) { - r.variables.error = msg; - }, - requestText: body, - variables: {}, - }; +import { expect, describe, it, beforeEach, afterEach, vi } from 'vitest'; - return r; +function makeRequest({ + method = 'POST', + headersIn = {}, + args = {}, + requestText = '', + variables = {}, +} = {}) { + return { + method, + headersIn, + requestText, + variables, + args, + error: vi.fn(), + log: vi.fn(), + internalRedirect: vi.fn(), + }; } -describe('extractModel', () => { - const tests = [ - { - name: 'returns the model value', - body: '{"model":"gpt-4"}', - model: 'gpt-4', - error: undefined, - }, - { - name: 'returns empty string if model is missing', - body: '{"foo":1}', - model: '', - error: 'request body does not contain model parameter', - }, - { - name: 'returns empty string for invalid JSON', - body: 'not-json', - model: '', - error: `error parsing request body for model name: Unexpected token 'o', "not-json" is not valid JSON`, - }, - { - name: 'empty request body', - body: '', - model: '', - error: 'error parsing request body for model name: Unexpected end of JSON input', - }, - ]; - - tests.forEach((test) => { - it(test.name, () => { - let r = makeRequest(test.body); - expect(epp.extractModel(r)).to.equal(test.model); - expect(r.variables.error).to.equal(test.error); 
+describe('getEndpoint', () => { + let originalNgx; + beforeEach(() => { + originalNgx = globalThis.ngx; + }); + afterEach(() => { + globalThis.ngx = originalNgx; + }); + + it('throws if host or port is missing', async () => { + const r = makeRequest({ variables: { epp_internal_path: '/foo' } }); + await expect(epp.getEndpoint(r)).rejects.toThrow(/Missing required variables/); + }); + + it('throws if internal path is missing', async () => { + const r = makeRequest({ variables: { epp_host: 'host', epp_port: '1234' } }); + await expect(epp.getEndpoint(r)).rejects.toThrow(/Missing required variable/); + }); + + it('sets endpoint and logs on 200 with endpoint header', async () => { + const endpoint = 'http://endpoint'; + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 200, + headers: { get: () => endpoint }, + text: vi.fn(), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + }); + await epp.getEndpoint(r); + expect(r.variables.inference_workload_endpoint).toBe(endpoint); + expect(r.log).toHaveBeenCalledWith(expect.stringContaining(endpoint)); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); + }); + + it('calls error if response is not 200 or endpoint header missing', async () => { + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 404, + headers: { get: () => null }, + text: vi.fn().mockResolvedValue('fail'), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + }); + await epp.getEndpoint(r); + expect(r.error).toHaveBeenCalledWith( + expect.stringContaining('could not get specific inference endpoint'), + ); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); + }); + + it('calls error if fetch throws', async () => { + globalThis.ngx = { + fetch: vi.fn().mockRejectedValue(new Error('network fail')), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', 
epp_internal_path: '/foo' }, + }); + await epp.getEndpoint(r); + expect(r.error).toHaveBeenCalledWith(expect.stringContaining('Error in ngx.fetch')); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo'); + }); + + it('preserves args in internal redirect when args are present', async () => { + const endpoint = 'http://endpoint'; + globalThis.ngx = { + fetch: vi.fn().mockResolvedValue({ + status: 200, + headers: { get: () => endpoint }, + text: vi.fn(), + }), + }; + const r = makeRequest({ + variables: { epp_host: 'host', epp_port: '1234', epp_internal_path: '/foo' }, + args: { a: '1', b: '2' }, }); + await epp.getEndpoint(r); + expect(r.internalRedirect).toHaveBeenCalledWith('/foo?a=1&b=2'); }); }); diff --git a/internal/controller/state/dataplane/configuration.go b/internal/controller/state/dataplane/configuration.go index 52306f4e0b..59030c0ca7 100644 --- a/internal/controller/state/dataplane/configuration.go +++ b/internal/controller/state/dataplane/configuration.go @@ -374,12 +374,13 @@ func newBackendGroup( gatewayName types.NamespacedName, sourceNsName types.NamespacedName, ruleIdx int, -) BackendGroup { +) (BackendGroup, bool) { var backends []Backend if len(refs) > 0 { backends = make([]Backend, 0, len(refs)) } + var inferencePoolBackendExists bool for _, ref := range refs { if ref.IsMirrorBackend { @@ -391,11 +392,14 @@ func newBackendGroup( valid = false } + inferencePoolBackendExists = inferencePoolBackendExists || ref.IsInferencePool + backends = append(backends, Backend{ - UpstreamName: ref.ServicePortReference(), - Weight: ref.Weight, - Valid: valid, - VerifyTLS: convertBackendTLS(ref.BackendTLSPolicy, gatewayName), + UpstreamName: ref.ServicePortReference(), + Weight: ref.Weight, + Valid: valid, + VerifyTLS: convertBackendTLS(ref.BackendTLSPolicy, gatewayName), + EndpointPickerConfig: ref.EndpointPickerConfig, }) } @@ -403,7 +407,7 @@ func newBackendGroup( Backends: backends, Source: sourceNsName, RuleIdx: ruleIdx, - } + }, 
inferencePoolBackendExists } func convertBackendTLS(btp *graph.BackendTLSPolicy, gwNsName types.NamespacedName) *VerifyTLS { @@ -595,10 +599,19 @@ func (hpr *hostPathRules) upsertRoute( } hostRule.GRPC = GRPC + backendGroup, inferencePoolBackendExists := newBackendGroup( + rule.BackendRefs, + listener.GatewayName, + routeNsName, + idx, + ) + if inferencePoolBackendExists { + hostRule.HasInferenceBackends = true + } hostRule.MatchRules = append(hostRule.MatchRules, MatchRule{ Source: objectSrc, - BackendGroup: newBackendGroup(rule.BackendRefs, listener.GatewayName, routeNsName, idx), + BackendGroup: backendGroup, Filters: filters, Match: convertMatch(m), }) diff --git a/internal/controller/state/dataplane/configuration_test.go b/internal/controller/state/dataplane/configuration_test.go index b329b9d46a..3e1697590d 100644 --- a/internal/controller/state/dataplane/configuration_test.go +++ b/internal/controller/state/dataplane/configuration_test.go @@ -2777,6 +2777,93 @@ func TestBuildConfiguration_Plus(t *testing.T) { } } +func TestUpsertRoute_PathRuleHasInferenceBackend(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + // Setup minimal route with one BackendRef marked as IsInferencePool + backendRef := graph.BackendRef{ + SvcNsName: types.NamespacedName{Name: "svc", Namespace: "test"}, + ServicePort: apiv1.ServicePort{Port: 80}, + Valid: true, + IsInferencePool: true, + } + + listenerName := "listener-80" + gwName := types.NamespacedName{Namespace: "test", Name: "gw"} + + route := &graph.L7Route{ + RouteType: graph.RouteTypeHTTP, + Source: &v1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "hr", + Namespace: "test", + }, + }, + Spec: graph.L7RouteSpec{ + Rules: []graph.RouteRule{ + { + ValidMatches: true, + Filters: graph.RouteRuleFilters{Valid: true}, + BackendRefs: []graph.BackendRef{backendRef}, + Matches: []v1.HTTPRouteMatch{ + { + Path: &v1.HTTPPathMatch{ + Type: helpers.GetPointer(v1.PathMatchPathPrefix), + Value: helpers.GetPointer("/infer"), + }, + 
}, + }, + }, + }, + }, + ParentRefs: []graph.ParentRef{ + { + Attachment: &graph.ParentRefAttachmentStatus{ + AcceptedHostnames: map[string][]string{ + graph.CreateGatewayListenerKey(gwName, listenerName): {"*"}, + }, + }, + }, + }, + Valid: true, + } + + listener := &graph.Listener{ + Name: listenerName, + GatewayName: gwName, + Valid: true, + Routes: map[graph.RouteKey]*graph.L7Route{ + graph.CreateRouteKey(route.Source): route, + }, + } + + gateway := &graph.Gateway{ + Source: &v1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "gw", + Namespace: "test", + }, + }, + Listeners: []*graph.Listener{listener}, + } + + hpr := newHostPathRules() + hpr.upsertRoute(route, listener, gateway) + + // Find the PathRule for "/infer" + found := false + for _, rules := range hpr.rulesPerHost { + for _, pr := range rules { + if pr.Path == "/infer" { + found = true + g.Expect(pr.HasInferenceBackends).To(BeTrue()) + } + } + } + g.Expect(found).To(BeTrue(), "PathRule for '/infer' not found") +} + func TestNewBackendGroup_Mirror(t *testing.T) { t.Parallel() g := NewWithT(t) @@ -2788,7 +2875,7 @@ func TestNewBackendGroup_Mirror(t *testing.T) { IsMirrorBackend: true, } - group := newBackendGroup([]graph.BackendRef{backendRef}, types.NamespacedName{}, types.NamespacedName{}, 0) + group, _ := newBackendGroup([]graph.BackendRef{backendRef}, types.NamespacedName{}, types.NamespacedName{}, 0) g.Expect(group.Backends).To(BeEmpty()) } diff --git a/internal/controller/state/dataplane/types.go b/internal/controller/state/dataplane/types.go index 08e7e0867b..1637c1f408 100644 --- a/internal/controller/state/dataplane/types.go +++ b/internal/controller/state/dataplane/types.go @@ -5,6 +5,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" "github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/policies" 
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/graph" @@ -137,6 +138,8 @@ type PathRule struct { Policies []policies.Policy // GRPC indicates if this is a gRPC rule GRPC bool + // HasInferenceBackends indicates whether the PathRule contains a backend for an inference workload. + HasInferenceBackends bool } // InvalidHTTPFilter is a special filter for handling the case when configured filters are invalid. @@ -323,6 +326,9 @@ func (bg *BackendGroup) Name() string { type Backend struct { // VerifyTLS holds the backend TLS verification configuration. VerifyTLS *VerifyTLS + // EndpointPickerConfig holds the configuration for the EndpointPicker for this backend. + // This is set if this backend is for an inference workload. + EndpointPickerConfig *inference.EndpointPickerRef // UpstreamName is the name of the upstream for this backend. UpstreamName string // Weight is the weight of the BackendRef. diff --git a/internal/controller/state/graph/backend_refs.go b/internal/controller/state/graph/backend_refs.go index e14d0fb0fa..95ce6df0b9 100644 --- a/internal/controller/state/graph/backend_refs.go +++ b/internal/controller/state/graph/backend_refs.go @@ -9,6 +9,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/validation/field" + inference "sigs.k8s.io/gateway-api-inference-extension/api/v1" gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" "sigs.k8s.io/gateway-api/apis/v1alpha3" @@ -30,6 +31,8 @@ const ( type BackendRef struct { // BackendTLSPolicy is the BackendTLSPolicy of the Service which is referenced by the backendRef. BackendTLSPolicy *BackendTLSPolicy + // EndpointPickerConfig is the configuration for the EndpointPicker, if this backendRef is for an InferencePool. + EndpointPickerConfig *inference.EndpointPickerRef // InvalidForGateways is a map of Gateways for which this BackendRef is invalid for, with the corresponding // condition. 
Certain NginxProxy configurations may result in a backend not being valid for some Gateways, // but not others. @@ -45,6 +48,8 @@ type BackendRef struct { Valid bool // IsMirrorBackend indicates whether the BackendGroup is for a mirrored backend. IsMirrorBackend bool + // IsInferencePool indicates whether the BackendRef is for an InferencePool. + IsInferencePool bool } // ServicePortReference returns a string representation for the service and port that is referenced by the BackendRef. @@ -118,6 +123,7 @@ func addBackendRefsToRules( if pool, exists := referencedInferencePools[poolName]; exists { port := gatewayv1.PortNumber(pool.Source.Spec.TargetPorts[0].Number) ref.Port = helpers.GetPointer(port) + ref.EndpointPickerConfig = &pool.Source.Spec.EndpointPickerRef } } @@ -181,10 +187,12 @@ func createBackendRef( if !valid { backendRef := BackendRef{ - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, []conditions.Condition{cond} @@ -198,12 +206,14 @@ func createBackendRef( svcIPFamily, svcPort, err := getIPFamilyAndPortFromRef(ref.BackendRef, svcNsName, services, refPath) if err != nil { backendRef := BackendRef{ - Weight: weight, - Valid: false, - SvcNsName: svcNsName, - ServicePort: v1.ServicePort{}, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + Weight: weight, + Valid: false, + SvcNsName: svcNsName, + ServicePort: v1.ServicePort{}, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: make(map[types.NamespacedName]conditions.Condition), + EndpointPickerConfig: 
ref.EndpointPickerConfig, } return backendRef, []conditions.Condition{conditions.NewRouteBackendRefRefBackendNotFound(err.Error())} @@ -220,12 +230,14 @@ func createBackendRef( // Check if externalName field is empty or whitespace-only if strings.TrimSpace(svc.Spec.ExternalName) == "" { backendRef := BackendRef{ - SvcNsName: svcNsName, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedValue( @@ -249,12 +261,14 @@ func createBackendRef( ) if err != nil { backendRef := BackendRef{ - SvcNsName: svcNsName, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedValue(err.Error())) @@ -264,13 +278,15 @@ func createBackendRef( err = validateRouteBackendRefAppProtocol(route.RouteType, *svcPort.AppProtocol, backendTLSPolicy) if err != nil { backendRef := BackendRef{ - SvcNsName: svcNsName, - BackendTLSPolicy: backendTLSPolicy, - ServicePort: svcPort, - Weight: weight, - Valid: false, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + BackendTLSPolicy: backendTLSPolicy, + ServicePort: svcPort, + Weight: weight, + Valid: false, + IsMirrorBackend: 
ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, append(conds, conditions.NewRouteBackendRefUnsupportedProtocol(err.Error())) @@ -278,13 +294,15 @@ func createBackendRef( } backendRef := BackendRef{ - SvcNsName: svcNsName, - BackendTLSPolicy: backendTLSPolicy, - ServicePort: svcPort, - Valid: true, - Weight: weight, - IsMirrorBackend: ref.MirrorBackendIdx != nil, - InvalidForGateways: invalidForGateways, + SvcNsName: svcNsName, + BackendTLSPolicy: backendTLSPolicy, + ServicePort: svcPort, + Valid: true, + Weight: weight, + IsMirrorBackend: ref.MirrorBackendIdx != nil, + IsInferencePool: ref.IsInferencePool, + InvalidForGateways: invalidForGateways, + EndpointPickerConfig: ref.EndpointPickerConfig, } return backendRef, conds diff --git a/internal/controller/state/graph/backend_refs_test.go b/internal/controller/state/graph/backend_refs_test.go index 3f05f793a6..b786daed9b 100644 --- a/internal/controller/state/graph/backend_refs_test.go +++ b/internal/controller/state/graph/backend_refs_test.go @@ -1231,9 +1231,11 @@ func TestAddBackendRefsToRules(t *testing.T) { ServicePort: v1.ServicePort{ Port: 80, }, - Valid: true, - Weight: 1, - InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + Valid: true, + Weight: 1, + InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + IsInferencePool: true, + EndpointPickerConfig: &inference.EndpointPickerRef{}, }, }, expectedConditions: nil, diff --git a/internal/controller/state/graph/graph_test.go b/internal/controller/state/graph/graph_test.go index da0ca04d47..1a367e5977 100644 --- a/internal/controller/state/graph/graph_test.go +++ b/internal/controller/state/graph/graph_test.go @@ -223,10 +223,12 @@ func TestBuildGraph(t *testing.T) { Namespace: testNs, Name: controller.CreateInferencePoolServiceName("ipool"), }, - ServicePort: v1.ServicePort{Port: 80}, - Valid: 
true, - Weight: 1, - InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + ServicePort: v1.ServicePort{Port: 80}, + Valid: true, + Weight: 1, + InvalidForGateways: map[types.NamespacedName]conditions.Condition{}, + IsInferencePool: true, + EndpointPickerConfig: &inference.EndpointPickerRef{}, }, } rbrs := []RouteBackendRef{ diff --git a/internal/controller/state/graph/httproute.go b/internal/controller/state/graph/httproute.go index de7a85370d..ed8d46a664 100644 --- a/internal/controller/state/graph/httproute.go +++ b/internal/controller/state/graph/httproute.go @@ -210,11 +210,26 @@ func processHTTPRouteRule( } } - var rbr RouteBackendRef + rbr := RouteBackendRef{ + BackendRef: b.BackendRef, + } + // If route specifies an InferencePool backend, we need to convert it to its associated // headless Service backend (that we created), so nginx config can be built properly. // Only do this if the InferencePool actually exists. if inferencePoolBackend(b, routeNamespace, inferencePools) { + // Traffic splitting at the Route level is not supported for InferencePool backends, + // so a rule that has more than one backendRef, one of which is an InferencePool, is marked invalid. + // NOTE(review): the break below presumably exits the backendRefs loop; confirm refs collected before it are discarded. 
+ if len(specRule.BackendRefs) > 1 { + err := field.Forbidden( + rulePath.Child("backendRefs"), + "cannot use InferencePool backend when multiple backendRefs are specified in a single rule", + ) + errors.invalid = append(errors.invalid, err) + break + } + svcName := controller.CreateInferencePoolServiceName(string(b.Name)) rbr = RouteBackendRef{ IsInferencePool: true, @@ -228,10 +243,6 @@ func processHTTPRouteRule( Weight: b.Weight, }, } - } else { - rbr = RouteBackendRef{ - BackendRef: b.BackendRef, - } } rbr.Filters = interfaceFilters diff --git a/internal/controller/state/graph/httproute_test.go b/internal/controller/state/graph/httproute_test.go index 0e06e5bf7e..d6d77c7296 100644 --- a/internal/controller/state/graph/httproute_test.go +++ b/internal/controller/state/graph/httproute_test.go @@ -1213,6 +1213,67 @@ func TestBuildHTTPRouteWithMirrorRoutes(t *testing.T) { g.Expect(helpers.Diff(expectedMirrorRoute, routes[mirrorRouteKey])).To(BeEmpty()) } +func TestProcessHTTPRouteRule_InferencePoolWithMultipleBackendRefs(t *testing.T) { + t.Parallel() + g := NewWithT(t) + + validator := &validationfakes.FakeHTTPFieldsValidator{} + inferencePoolName := "ipool" + routeNamespace := "test" + inferencePools := map[types.NamespacedName]*inference.InferencePool{ + {Namespace: routeNamespace, Name: inferencePoolName}: {}, + } + + // BackendRef 1: an InferencePool that is present in the inferencePools map above. + backendRef1 := gatewayv1.HTTPBackendRef{ + BackendRef: gatewayv1.BackendRef{ + BackendObjectReference: gatewayv1.BackendObjectReference{ + Group: helpers.GetPointer[gatewayv1.Group](inferenceAPIGroup), + Kind: helpers.GetPointer[gatewayv1.Kind](kinds.InferencePool), + Name: gatewayv1.ObjectName(inferencePoolName), + Namespace: helpers.GetPointer(gatewayv1.Namespace(routeNamespace)), + }, + }, + } + // BackendRef 2: a plain Service, so the rule has more than one backendRef. 
+ backendRef2 := gatewayv1.HTTPBackendRef{ + BackendRef: gatewayv1.BackendRef{ + BackendObjectReference: gatewayv1.BackendObjectReference{ + Kind: 
helpers.GetPointer[gatewayv1.Kind](kinds.Service), + Name: "backend", + }, + }, + } + + specRule := gatewayv1.HTTPRouteRule{ + Matches: []gatewayv1.HTTPRouteMatch{ + { + Path: &gatewayv1.HTTPPathMatch{ + Type: helpers.GetPointer(gatewayv1.PathMatchPathPrefix), + Value: helpers.GetPointer("/"), + }, + }, + }, + BackendRefs: []gatewayv1.HTTPBackendRef{backendRef1, backendRef2}, + } + + rulePath := field.NewPath("spec").Child("rules").Index(0) + + routeRule, errs := processHTTPRouteRule( + specRule, + routeNamespace, + rulePath, + validator, + nil, + inferencePools, + ) + + g.Expect(routeRule.RouteBackendRefs).To(BeEmpty()) + g.Expect(errs.invalid).To(HaveLen(1)) + errMsg := "cannot use InferencePool backend when multiple backendRefs are specified in a single rule" + g.Expect(errs.invalid[0].Error()).To(ContainSubstring(errMsg)) +} + func TestValidateMatch(t *testing.T) { t.Parallel() createAllValidValidator := func() *validationfakes.FakeHTTPFieldsValidator { diff --git a/internal/controller/state/graph/route_common.go b/internal/controller/state/graph/route_common.go index f3d3b04e4a..22067c6d44 100644 --- a/internal/controller/state/graph/route_common.go +++ b/internal/controller/state/graph/route_common.go @@ -166,6 +166,9 @@ type RouteBackendRef struct { // If this backend is defined in a RequestMirror filter, this value will indicate the filter's index. MirrorBackendIdx *int + // EndpointPickerConfig is the configuration for the EndpointPicker, if this backendRef is for an InferencePool. + EndpointPickerConfig *inference.EndpointPickerRef + Filters []any // IsInferencePool indicates if this backend is an InferencePool disguised as a Service. 
diff --git a/internal/framework/types/types.go b/internal/framework/types/types.go index bf61bd23d7..0aeccd008d 100644 --- a/internal/framework/types/types.go +++ b/internal/framework/types/types.go @@ -5,3 +5,14 @@ import "sigs.k8s.io/controller-runtime/pkg/client" // ObjectType is used when we only care about the type of client.Object. // The fields of the client.Object may be empty. type ObjectType client.Object + +// Constants used for communication with the EndpointPicker service when using the Inference Extension. +const ( + // EPPEndpointHostHeader is the HTTP header used to specify the EPP endpoint host. + EPPEndpointHostHeader = "X-EPP-Host" + // EPPEndpointPortHeader is the HTTP header used to specify the EPP endpoint port. + EPPEndpointPortHeader = "X-EPP-Port" + // GoShimPort is the default port for the Go EPP shim server to listen on. If collisions become a problem, + // we can make this configurable via the NginxProxy resource. + GoShimPort = 54800 // Why 54800? The ASCII codes of "nginx" sum to 548 (110+103+105+110+120); multiplied by 100. +)