Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions cmd/gateway/endpoint_picker.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
)

const (
// defaultPort is the default port for this server to listen on. If collisions become a problem,
// we can make this configurable via the NginxProxy resource.
defaultPort = 54800 // why 54800? Sum "nginx" in ASCII and multiply by 100.
// eppEndpointHostHeader is the HTTP header used to specify the EPP endpoint host, set by the NJS module caller.
eppEndpointHostHeader = "X-EPP-Host"
// eppEndpointPortHeader is the HTTP header used to specify the EPP endpoint port, set by the NJS module caller.
eppEndpointPortHeader = "X-EPP-Port"
"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types"
)

// extProcClientFactory creates a new ExternalProcessorClient and returns a close function.
Expand All @@ -32,7 +24,7 @@ type extProcClientFactory func(target string) (extprocv3.ExternalProcessorClient
// endpointPickerServer starts an HTTP server on the given port with the provided handler.
func endpointPickerServer(handler http.Handler) error {
server := &http.Server{
Addr: fmt.Sprintf("127.0.0.1:%d", defaultPort),
Addr: fmt.Sprintf("127.0.0.1:%d", types.GoShimPort),
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
}
Expand All @@ -54,13 +46,13 @@ func realExtProcClientFactory() extProcClientFactory {
// createEndpointPickerHandler returns an http.Handler that forwards requests to the EndpointPicker.
func createEndpointPickerHandler(factory extProcClientFactory, logger logr.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := r.Header.Get(eppEndpointHostHeader)
port := r.Header.Get(eppEndpointPortHeader)
host := r.Header.Get(types.EPPEndpointHostHeader)
port := r.Header.Get(types.EPPEndpointPortHeader)
if host == "" || port == "" {
msg := fmt.Sprintf(
"missing at least one of required headers: %s and %s",
eppEndpointHostHeader,
eppEndpointPortHeader,
types.EPPEndpointHostHeader,
types.EPPEndpointPortHeader,
)
logger.Error(errors.New(msg), "error contacting EndpointPicker")
http.Error(w, msg, http.StatusBadRequest)
Expand Down Expand Up @@ -174,6 +166,10 @@ func buildHeaderRequest(r *http.Request) *extprocv3.ProcessingRequest {
}

func buildBodyRequest(r *http.Request) (*extprocv3.ProcessingRequest, error) {
if r.ContentLength == 0 {
return nil, errors.New("request body is empty")
}

body, err := io.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("error reading request body: %w", err)
Expand Down
42 changes: 35 additions & 7 deletions cmd/gateway/endpoint_picker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
eppMetadata "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"

"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/types"
)

type mockExtProcClient struct {
Expand Down Expand Up @@ -122,8 +124,8 @@ func TestEndpointPickerHandler_Success(t *testing.T) {

h := createEndpointPickerHandler(factory, logr.Discard())
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
req.Header.Set(eppEndpointHostHeader, "test-host")
req.Header.Set(eppEndpointPortHeader, "1234")
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
req.Header.Set(types.EPPEndpointPortHeader, "1234")
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

Expand Down Expand Up @@ -165,8 +167,8 @@ func TestEndpointPickerHandler_ImmediateResponse(t *testing.T) {

h := createEndpointPickerHandler(factory, logr.Discard())
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
req.Header.Set(eppEndpointHostHeader, "test-host")
req.Header.Set(eppEndpointPortHeader, "1234")
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
req.Header.Set(types.EPPEndpointPortHeader, "1234")
w := httptest.NewRecorder()

h.ServeHTTP(w, req)
Expand All @@ -190,8 +192,8 @@ func TestEndpointPickerHandler_Errors(t *testing.T) {
h := createEndpointPickerHandler(factory, logr.Discard())
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test body"))
if setHeaders {
req.Header.Set(eppEndpointHostHeader, "test-host")
req.Header.Set(eppEndpointPortHeader, "1234")
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
req.Header.Set(types.EPPEndpointPortHeader, "1234")
}
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
Expand Down Expand Up @@ -236,7 +238,33 @@ func TestEndpointPickerHandler_Errors(t *testing.T) {
}
runErrorTestCase(factory, true, http.StatusBadGateway, "error sending headers")

// 4. Error sending body
// 4a. Error building body request (content length 0)
client = &mockProcessClient{
SendFunc: func(*extprocv3.ProcessingRequest) error {
return nil
},
RecvFunc: func() (*extprocv3.ProcessingResponse, error) { return nil, io.EOF },
}
extProcClient = &mockExtProcClient{
ProcessFunc: func(context.Context, ...grpc.CallOption) (extprocv3.ExternalProcessor_ProcessClient, error) {
return client, nil
},
}
factory = func(string) (extprocv3.ExternalProcessorClient, func() error, error) {
return extProcClient, func() error { return nil }, nil
}
h := createEndpointPickerHandler(factory, logr.Discard())
req := httptest.NewRequest(http.MethodPost, "/", nil) // nil body, ContentLength = 0
req.Header.Set(types.EPPEndpointHostHeader, "test-host")
req.Header.Set(types.EPPEndpointPortHeader, "1234")
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
resp := w.Result()
g.Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError))
body, _ := io.ReadAll(resp.Body)
g.Expect(string(body)).To(ContainSubstring("request body is empty"))

// 4b. Error sending body
client = &mockProcessClient{
SendFunc: func(req *extprocv3.ProcessingRequest) error {
if req.GetRequestBody() != nil {
Expand Down
1 change: 1 addition & 0 deletions deploy/inference-nginx-plus/deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ spec:
- --nginx-docker-secret=nginx-plus-registry-secret
- --nginx-plus
- --usage-report-secret=nplus-license
- --usage-report-enforce-initial-report=true
- --metrics-port=9113
- --health-port=8081
- --leader-election-lock-name=nginx-gateway-leader-election
Expand Down
56 changes: 44 additions & 12 deletions internal/controller/nginx/config/http/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,58 @@ type Server struct {
type LocationType string

// Supported LocationType values.
const (
// InternalLocationType defines an internal location that is only accessible within NGINX.
InternalLocationType LocationType = "internal"
// ExternalLocationType defines a normal external location that is accessible by clients.
ExternalLocationType LocationType = "external"
// RedirectLocationType defines an external location that redirects to an internal location
// based on HTTP matching conditions.
RedirectLocationType LocationType = "redirect"
// InferenceExternalLocationType defines an external location that calls NJS to get the
// inference workload endpoint, and then redirects to the internal location that will
// proxy_pass to that endpoint.
InferenceExternalLocationType LocationType = "inference-external"
// InferenceInternalLocationType defines an internal location that calls NJS to get the
// inference workload endpoint, and then redirects to the internal location that will
// proxy_pass to that endpoint. This is used when an HTTP redirect location is also defined
// that redirects to this internal inference location.
InferenceInternalLocationType LocationType = "inference-internal"
)

// Location holds all configuration for an HTTP location.
//
// Each field is declared exactly once; Go rejects a struct that declares the same
// field name twice.
type Location struct {
	// Return specifies a return directive (e.g., HTTP status or redirect) for this location block.
	Return *Return
	// ProxySSLVerify controls SSL verification for upstreams when proxying requests.
	ProxySSLVerify *ProxySSLVerify
	// ProxyPass is the upstream backend (URL or name) to which requests are proxied.
	ProxyPass string
	// HTTPMatchKey is the key for associating HTTP match rules, used for routing and NJS module logic.
	HTTPMatchKey string
	// MirrorSplitClientsVariableName is the variable name for split_clients, used in traffic mirroring scenarios.
	MirrorSplitClientsVariableName string
	// EPPInternalPath is the internal path for the inference NJS module to redirect to.
	EPPInternalPath string
	// EPPHost is the host for the EndpointPicker, used for inference routing.
	EPPHost string
	// Type indicates the type of location (external, internal, redirect, etc).
	Type LocationType
	// Path is the NGINX location path.
	Path string
	// ResponseHeaders are custom response headers to be sent.
	ResponseHeaders ResponseHeaders
	// ProxySetHeaders are headers to set when proxying requests upstream.
	ProxySetHeaders []Header
	// Rewrites are rewrite rules for modifying request paths.
	Rewrites []string
	// MirrorPaths are paths to which requests are mirrored.
	MirrorPaths []string
	// Includes are additional NGINX config snippets or policies to include in this location.
	Includes []shared.Include
	// EPPPort is the port for the EndpointPicker, used for inference routing.
	EPPPort int
	// GRPC indicates if this location proxies gRPC traffic.
	GRPC bool
}

// Header defines an HTTP header to be passed to the proxied server.
Expand Down
44 changes: 44 additions & 0 deletions internal/controller/nginx/config/maps.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package config

import (
"fmt"
"strings"
gotemplate "text/template"

inference "sigs.k8s.io/gateway-api-inference-extension/api/v1"

"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared"
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane"
"github.com/nginx/nginx-gateway-fabric/v2/internal/framework/helpers"
Expand All @@ -26,6 +29,8 @@ const (

func executeMaps(conf dataplane.Configuration) []executeResult {
maps := buildAddHeaderMaps(append(conf.HTTPServers, conf.SSLServers...))
maps = append(maps, buildInferenceMaps(conf.BackendGroups)...)

result := executeResult{
dest: httpConfigFile,
data: helpers.MustExecuteTemplate(mapsTemplate, maps),
Expand Down Expand Up @@ -177,3 +182,42 @@ func createAddHeadersMap(name string) shared.Map {
Parameters: params,
}
}

// buildInferenceMaps creates one NGINX map for every backend that has an EndpointPicker
// configured. Backends without an EndpointPickerConfig are skipped.
//
// Each map reads $inference_workload_endpoint and defines the variable
// $inference_backend_<upstream>, where <upstream> is the backend's upstream name with
// dashes replaced by underscores. The map has two entries:
//   - any non-empty $inference_workload_endpoint value is passed through unchanged;
//   - the default entry depends on the EndpointPicker failure mode: the upstream name
//     for FailOpen, and invalidBackendRef for FailClose.
//
// An unset or unrecognized failure mode is treated as FailClose, so the default entry
// never has an empty result.
func buildInferenceMaps(groups []dataplane.BackendGroup) []shared.Map {
	inferenceMaps := make([]shared.Map, 0, len(groups))
	for _, group := range groups {
		for _, backend := range group.Backends {
			if backend.EndpointPickerConfig == nil {
				continue
			}

			var defaultResult string
			switch backend.EndpointPickerConfig.FailureMode {
			// in FailOpen mode, if the EPP is unavailable or returns an error,
			// we fall back to the upstream
			case inference.EndpointPickerFailOpen:
				defaultResult = backend.UpstreamName
			// in FailClose mode, if the EPP is unavailable or returns an error,
			// we return an invalid backend to ensure the request fails
			case inference.EndpointPickerFailClose:
				defaultResult = invalidBackendRef
			// an unset or unknown mode must not leave the default result empty,
			// so we take the safe option and fail closed
			default:
				defaultResult = invalidBackendRef
			}

			params := []shared.MapParameter{
				{Value: "~.+", Result: "$inference_workload_endpoint"},
				{Value: "default", Result: defaultResult},
			}

			// upstream names may contain dashes; the variable name uses underscores instead
			backendVarName := strings.ReplaceAll(backend.UpstreamName, "-", "_")
			inferenceMaps = append(inferenceMaps, shared.Map{
				Source:     "$inference_workload_endpoint",
				Variable:   fmt.Sprintf("$inference_backend_%s", backendVarName),
				Parameters: params,
			})
		}
	}

	return inferenceMaps
}
65 changes: 52 additions & 13 deletions internal/controller/nginx/config/maps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

. "github.com/onsi/gomega"
inference "sigs.k8s.io/gateway-api-inference-extension/api/v1"

"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/nginx/config/shared"
"github.com/nginx/nginx-gateway-fabric/v2/internal/controller/state/dataplane"
Expand Down Expand Up @@ -59,22 +60,24 @@ func TestExecuteMaps(t *testing.T) {

conf := dataplane.Configuration{
HTTPServers: []dataplane.VirtualServer{
{
PathRules: pathRules,
},
{
PathRules: pathRules,
},
{
IsDefault: true,
},
{PathRules: pathRules},
{PathRules: pathRules},
{IsDefault: true},
},
SSLServers: []dataplane.VirtualServer{
{PathRules: pathRules},
{IsDefault: true},
},
BackendGroups: []dataplane.BackendGroup{
{
PathRules: pathRules,
},
{
IsDefault: true,
Backends: []dataplane.Backend{
{
UpstreamName: "upstream1",
EndpointPickerConfig: &inference.EndpointPickerRef{
FailureMode: inference.EndpointPickerFailClose,
},
},
},
},
},
}
Expand All @@ -86,6 +89,9 @@ func TestExecuteMaps(t *testing.T) {
"map ${http_my_second_add_header} $my_second_add_header_header_var {": 1,
"~.* ${http_my_second_add_header},;": 1,
"map ${http_my_set_header} $my_set_header_header_var {": 0,
"$inference_workload_endpoint": 2,
"$inference_backend": 1,
"invalid-backend-ref": 1,
}

mapResult := executeMaps(conf)
Expand Down Expand Up @@ -385,3 +391,36 @@ func TestCreateStreamMapsWithEmpty(t *testing.T) {

g.Expect(maps).To(BeNil())
}

// TestBuildInferenceMaps checks that buildInferenceMaps emits a map only for backends that
// have an EndpointPicker configured, and that the second (default) map parameter follows the
// configured failure mode.
func TestBuildInferenceMaps(t *testing.T) {
	t.Parallel()
	g := NewWithT(t)

	failCloseBackend := dataplane.Backend{
		UpstreamName: "upstream1",
		EndpointPickerConfig: &inference.EndpointPickerRef{
			FailureMode: inference.EndpointPickerFailClose,
		},
	}
	failOpenBackend := dataplane.Backend{
		UpstreamName: "upstream2",
		EndpointPickerConfig: &inference.EndpointPickerRef{
			FailureMode: inference.EndpointPickerFailOpen,
		},
	}
	// no EndpointPicker configured, so no map is expected for this backend
	plainBackend := dataplane.Backend{
		UpstreamName:         "upstream3",
		EndpointPickerConfig: nil,
	}

	groups := []dataplane.BackendGroup{
		{
			Backends: []dataplane.Backend{failCloseBackend, failOpenBackend, plainBackend},
		},
	}

	result := buildInferenceMaps(groups)
	g.Expect(result).To(HaveLen(2))

	failCloseMap := result[0]
	g.Expect(failCloseMap.Source).To(Equal("$inference_workload_endpoint"))
	g.Expect(failCloseMap.Variable).To(Equal("$inference_backend_upstream1"))
	g.Expect(failCloseMap.Parameters[1].Result).To(Equal("invalid-backend-ref"))

	failOpenMap := result[1]
	g.Expect(failOpenMap.Parameters[1].Result).To(Equal("upstream2"))
}
Loading
Loading