diff --git a/constraint/pkg/client/drivers/local/args.go b/constraint/pkg/client/drivers/local/args.go index 594aeb2ca..e5cb76d9e 100644 --- a/constraint/pkg/client/drivers/local/args.go +++ b/constraint/pkg/client/drivers/local/args.go @@ -43,6 +43,10 @@ func Defaults() Arg { d.compilers.capabilities.Builtins = append(d.compilers.capabilities.Builtins, newBuiltin) } + if d.sendRequestToProvider == nil { + d.sendRequestToProvider = externaldata.DefaultSendRequestToProvider + } + return nil } } diff --git a/constraint/pkg/client/drivers/local/builtin.go b/constraint/pkg/client/drivers/local/builtin.go index 72c27e5b6..bdc712dc4 100644 --- a/constraint/pkg/client/drivers/local/builtin.go +++ b/constraint/pkg/client/drivers/local/builtin.go @@ -1,12 +1,7 @@ package local import ( - "bytes" - "context" - "encoding/json" - "io/ioutil" "net/http" - "time" "github.com/open-policy-agent/frameworks/constraint/pkg/externaldata" "github.com/open-policy-agent/opa/ast" @@ -25,41 +20,12 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest * return externaldata.HandleError(http.StatusBadRequest, err) } - externaldataRequest := externaldata.NewProviderRequest(regoReq.Keys) - reqBody, err := json.Marshal(externaldataRequest) + externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, regoReq.Keys) if err != nil { - return externaldata.HandleError(http.StatusInternalServerError, err) + return externaldata.HandleError(statusCode, err) } - req, err := http.NewRequest("POST", provider.Spec.URL, bytes.NewBuffer(reqBody)) - if err != nil { - return externaldata.HandleError(http.StatusInternalServerError, err) - } - req.Header.Set("Content-Type", "application/json") - - ctx, cancel := context.WithDeadline(bctx.Context, time.Now().Add(time.Duration(provider.Spec.Timeout)*time.Second)) - defer cancel() - - resp, err := http.DefaultClient.Do(req.WithContext(ctx)) - if err != nil { - return externaldata.HandleError(http.StatusInternalServerError, err) - } - - defer func() { - _ = resp.Body.Close() - }() - - respBody, err := ioutil.ReadAll(resp.Body) - if err != nil { - return externaldata.HandleError(http.StatusInternalServerError, err) - } - - var externaldataResponse externaldata.ProviderResponse - if err := json.Unmarshal(respBody, &externaldataResponse); err != nil { - return externaldata.HandleError(http.StatusInternalServerError, err) - } - - regoResponse := externaldata.NewRegoResponse(resp.StatusCode, &externaldataResponse) + regoResponse := externaldata.NewRegoResponse(statusCode, externaldataResponse) return externaldata.PrepareRegoResponse(regoResponse) } } diff --git a/constraint/pkg/client/drivers/local/driver.go b/constraint/pkg/client/drivers/local/driver.go index fc769c87d..79e8c75f3 100644 --- a/constraint/pkg/client/drivers/local/driver.go +++ b/constraint/pkg/client/drivers/local/driver.go @@ -58,6 +58,9 @@ type Driver struct { // providerCache allows Rego to read from external_data in Rego queries. providerCache *externaldata.ProviderCache + + // sendRequestToProvider allows Rego to send requests to the provider specified in external_data. + sendRequestToProvider externaldata.SendRequestToProvider } // AddTemplate adds templ to Driver. Normalizes modules into usable forms for diff --git a/constraint/pkg/externaldata/request.go b/constraint/pkg/externaldata/request.go index 3bc32234b..53c507d6b 100644 --- a/constraint/pkg/externaldata/request.go +++ b/constraint/pkg/externaldata/request.go @@ -1,5 +1,17 @@ package externaldata +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "time" + + "github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/v1alpha1" +) + // RegoRequest is the request for external_data rego function. type RegoRequest struct { // ProviderName is the name of the external data provider. @@ -35,6 +47,45 @@ func NewProviderRequest(keys []string) *ProviderRequest { } } +// SendRequestToProvider is a function that sends a request to the external data provider. +type SendRequestToProvider func(ctx context.Context, provider *v1alpha1.Provider, keys []string) (*ProviderResponse, int, error) + +// DefaultSendRequestToProvider is the default function to send the request to the external data provider. +func DefaultSendRequestToProvider(ctx context.Context, provider *v1alpha1.Provider, keys []string) (*ProviderResponse, int, error) { + externaldataRequest := NewProviderRequest(keys) + body, err := json.Marshal(externaldataRequest) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal external data request: %w", err) + } + + req, err := http.NewRequest(http.MethodPost, provider.Spec.URL, bytes.NewBuffer(body)) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to create external data request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + ctxWithDeadline, cancel := context.WithDeadline(ctx, time.Now().Add(time.Duration(provider.Spec.Timeout)*time.Second)) + defer cancel() + + resp, err := http.DefaultClient.Do(req.WithContext(ctxWithDeadline)) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to send external data request: %w", err) + } + defer resp.Body.Close() + + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to read external data response: %w", err) + } + + var externaldataResponse ProviderResponse + if err := json.Unmarshal(respBody, &externaldataResponse); err != nil { + return nil, http.StatusInternalServerError, fmt.Errorf("failed to unmarshal external data response: %w", err) + } + + return &externaldataResponse, resp.StatusCode, nil +} + // ProviderKind strings are special string constants for Providers. // +kubebuilder:validation:Enum=ProviderRequestKind;ProviderResponseKind type ProviderKind string