From d8fb2d9262818c012acc3c3b5d0a2531520d27dc Mon Sep 17 00:00:00 2001 From: Will Beason Date: Tue, 8 Feb 2022 08:23:43 -0800 Subject: [PATCH 1/2] Refactor out backend code Having the backend intermediate type for instantiating and using Client adds complexity, code length, and unnecessary error paths. This commit fully removes backend without breaking any existing funcitonality. Signed-off-by: Will Beason --- constraint/pkg/client/backend.go | 88 --------------- constraint/pkg/client/backend_test.go | 98 ----------------- constraint/pkg/client/client.go | 34 +++--- .../client/client_addtemplate_bench_test.go | 19 +--- constraint/pkg/client/client_opts.go | 16 +++ constraint/pkg/client/client_test.go | 102 +++--------------- constraint/pkg/client/clienttest/client.go | 16 ++- constraint/pkg/client/e2e_test.go | 7 +- constraint/pkg/client/new_client.go | 50 +++++++++ constraint/pkg/client/new_client_test.go | 41 +++++++ 10 files changed, 150 insertions(+), 321 deletions(-) delete mode 100644 constraint/pkg/client/backend.go delete mode 100644 constraint/pkg/client/backend_test.go create mode 100644 constraint/pkg/client/new_client.go create mode 100644 constraint/pkg/client/new_client_test.go diff --git a/constraint/pkg/client/backend.go b/constraint/pkg/client/backend.go deleted file mode 100644 index 62851420a..000000000 --- a/constraint/pkg/client/backend.go +++ /dev/null @@ -1,88 +0,0 @@ -package client - -import ( - "fmt" - - "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" - "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" - "k8s.io/apimachinery/pkg/runtime/schema" -) - -type Backend struct { - driver drivers.Driver - hasClient bool -} - -type BackendOpt func(*Backend) - -func Driver(d drivers.Driver) BackendOpt { - return func(b *Backend) { - b.driver = d - } -} - -// NewBackend creates a new backend. A backend could be a connection to a remote -// server or a new local OPA instance. -// -// A BackendOpt setting driver, such as Driver() must be passed. -func NewBackend(opts ...BackendOpt) (*Backend, error) { - b := &Backend{} - for _, opt := range opts { - opt(b) - } - - if b.driver == nil { - return nil, fmt.Errorf("%w: no driver supplied", ErrCreatingBackend) - } - - return b, nil -} - -// NewClient creates a new client for the supplied backend. -func (b *Backend) NewClient(opts ...Opt) (*Client, error) { - if b.hasClient { - return nil, fmt.Errorf("%w: only one client per backend is allowed", - ErrCreatingClient) - } - - var fields []string - for k := range validDataFields { - fields = append(fields, k) - } - - c := &Client{ - backend: b, - constraints: make(map[schema.GroupKind]map[string]*unstructured.Unstructured), - templates: make(map[templateKey]*templateEntry), - AllowedDataFields: fields, - } - - for _, opt := range opts { - if err := opt(c); err != nil { - return nil, err - } - } - - for _, field := range c.AllowedDataFields { - if !validDataFields[field] { - return nil, fmt.Errorf("%w: invalid data field %q; allowed fields are: %v", - ErrCreatingClient, field, validDataFields) - } - } - - if len(c.targets) == 0 { - return nil, fmt.Errorf("%w: must specify at least one target with client.Targets", - ErrCreatingClient) - } - - if err := b.driver.Init(); err != nil { - return nil, err - } - - if err := c.init(); err != nil { - return nil, err - } - - b.hasClient = true - return c, nil -} diff --git a/constraint/pkg/client/backend_test.go b/constraint/pkg/client/backend_test.go deleted file mode 100644 index 6fafbb3b3..000000000 --- a/constraint/pkg/client/backend_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package client_test - -import ( - "errors" - "testing" - - "github.com/open-policy-agent/frameworks/constraint/pkg/client" - "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" - "github.com/open-policy-agent/frameworks/constraint/pkg/handler/handlertest" -) - -func TestNewBackend(t *testing.T) { - testCases := []struct { - name string - opts []client.BackendOpt - wantError error - }{ - { - name: "no args", - opts: nil, - wantError: client.ErrCreatingBackend, - }, - { - name: "good", - opts: []client.BackendOpt{client.Driver(local.New())}, - wantError: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, gotErr := client.NewBackend(tc.opts...) - - if !errors.Is(gotErr, tc.wantError) { - t.Fatalf("got NewBackent() error = %v, want %v", - gotErr, tc.wantError) - } - }) - } -} - -func TestBackend_NewClient(t *testing.T) { - testCases := []struct { - name string - backendOpts []client.BackendOpt - clientOpts []client.Opt - wantError error - }{ - { - name: "no opts", - backendOpts: nil, - clientOpts: nil, - wantError: client.ErrCreatingClient, - }, - { - name: "with handler", - backendOpts: nil, - clientOpts: []client.Opt{client.Targets(&handlertest.Handler{})}, - wantError: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - opts := []client.BackendOpt{client.Driver(local.New())} - opts = append(opts, tc.backendOpts...) - - backend, err := client.NewBackend(opts...) - if err != nil { - t.Fatal(err) - } - - _, err = backend.NewClient(tc.clientOpts...) - if !errors.Is(err, tc.wantError) { - t.Fatalf("got NewClient() eror = %v, want %v", - err, tc.wantError) - } - }) - } -} - -func TestBackend_NewClient_MultipleClients(t *testing.T) { - backend, err := client.NewBackend(client.Driver(local.New())) - if err != nil { - t.Fatal(err) - } - - _, err = backend.NewClient(client.Targets(&handlertest.Handler{})) - if err != nil { - t.Fatal(err) - } - - _, err = backend.NewClient(client.Targets(&handlertest.Handler{})) - if !errors.Is(err, client.ErrCreatingClient) { - t.Fatalf("got NewClient() err = %v, want %v", - err, client.ErrCreatingClient) - } -} diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index b954c713e..a7e7f494e 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -33,7 +33,7 @@ type templateEntry struct { } type Client struct { - backend *Backend + driver drivers.Driver targets map[string]handler.TargetHandler // mtx guards access to both templates and constraints. @@ -69,7 +69,7 @@ func (c *Client) AddData(ctx context.Context, data interface{}) (*types.Response if !handled { continue } - if err := c.backend.driver.PutData(ctx, createDataPath(target, relPath), processedData); err != nil { + if err := c.driver.PutData(ctx, createDataPath(target, relPath), processedData); err != nil { errMap[target] = err continue } @@ -96,7 +96,7 @@ func (c *Client) RemoveData(ctx context.Context, data interface{}) (*types.Respo if !handled { continue } - if _, err := c.backend.driver.DeleteData(ctx, createDataPath(target, relPath)); err != nil { + if _, err := c.driver.DeleteData(ctx, createDataPath(target, relPath)); err != nil { errMap[target] = err continue } @@ -268,11 +268,11 @@ func (c *Client) ValidateConstraintTemplate(templ *templates.ConstraintTemplate) if _, _, err := c.ValidateConstraintTemplateBasic(templ); err != nil { return err } - if dr, ok := c.backend.driver.(*local.Driver); ok { + if dr, ok := c.driver.(*local.Driver); ok { _, _, err := dr.ValidateConstraintTemplate(templ) return err } - return fmt.Errorf("driver %T is not supported", c.backend.driver) + return fmt.Errorf("driver %T is not supported", c.driver) } // AddTemplate adds the template source code to OPA and registers the CRD with the client for @@ -295,7 +295,7 @@ func (c *Client) AddTemplate(templ *templates.ConstraintTemplate) (*types.Respon c.mtx.Lock() defer c.mtx.Unlock() - if err = c.backend.driver.AddTemplate(templ); err != nil { + if err = c.driver.AddTemplate(templ); err != nil { return resp, err } cpy := templ.DeepCopy() @@ -341,7 +341,7 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint return resp, err } - if err := c.backend.driver.RemoveTemplate(templ); err != nil { + if err := c.driver.RemoveTemplate(templ); err != nil { return resp, err } @@ -353,7 +353,7 @@ func (c *Client) RemoveTemplate(ctx context.Context, templ *templates.Constraint delete(c.constraints, artifacts.gk) // Also clean up root path to avoid memory leaks constraintRoot := createConstraintGKPath(artifacts.targetHandler.GetName(), artifacts.gk) - if _, err := c.backend.driver.DeleteData(ctx, constraintRoot); err != nil { + if _, err := c.driver.DeleteData(ctx, constraintRoot); err != nil { return resp, err } delete(c.templates, artifacts.Key()) @@ -484,7 +484,7 @@ func (c *Client) AddConstraint(ctx context.Context, constraint *unstructured.Uns if err := c.validateConstraint(constraint, false); err != nil { return resp, err } - if err := c.backend.driver.AddConstraint(ctx, constraint); err != nil { + if err := c.driver.AddConstraint(ctx, constraint); err != nil { return resp, err } for _, target := range entry.Targets { @@ -513,7 +513,7 @@ func (c *Client) removeConstraintNoLock(ctx context.Context, constraint *unstruc if err != nil { return resp, err } - if err := c.backend.driver.RemoveConstraint(ctx, constraint); err != nil { + if err := c.driver.RemoveConstraint(ctx, constraint); err != nil { return resp, err } for _, target := range entry.Targets { @@ -584,7 +584,7 @@ func (c *Client) init() error { } builtinPath := fmt.Sprintf("%s.hooks_builtin", hooks) - err := c.backend.driver.PutModule(builtinPath, libBuiltin.String()) + err := c.driver.PutModule(builtinPath, libBuiltin.String()) if err != nil { return err } @@ -633,20 +633,20 @@ func (c *Client) init() error { ErrCreatingClient, err) } - err = c.backend.driver.PutModule(modulePath, string(src)) + err = c.driver.PutModule(modulePath, string(src)) if err != nil { return fmt.Errorf("%w: error %s from compiled source:\n%s", ErrCreatingClient, err, src) } } - if d, ok := c.backend.driver.(*local.Driver); ok { + if d, ok := c.driver.(*local.Driver); ok { var externs []string for _, field := range c.AllowedDataFields { externs = append(externs, fmt.Sprintf("data.%s", field)) } d.SetExterns(externs) } else { - return fmt.Errorf("%w: driver %T is not supported", ErrCreatingClient, c.backend.driver) + return fmt.Errorf("%w: driver %T is not supported", ErrCreatingClient, c.driver) } return nil } @@ -673,7 +673,7 @@ TargetLoop: continue } input := map[string]interface{}{"review": review} - resp, err := c.backend.driver.Query(ctx, fmt.Sprintf(`hooks["%s"].violation`, name), input, drivers.Tracing(cfg.enableTracing)) + resp, err := c.driver.Query(ctx, fmt.Sprintf(`hooks["%s"].violation`, name), input, drivers.Tracing(cfg.enableTracing)) if err != nil { errMap[name] = err continue @@ -706,7 +706,7 @@ func (c *Client) Audit(ctx context.Context, opts ...QueryOpt) (*types.Responses, TargetLoop: for name, target := range c.targets { // Short-circuiting question applies here as well - resp, err := c.backend.driver.Query(ctx, fmt.Sprintf(`hooks["%s"].audit`, name), nil, drivers.Tracing(cfg.enableTracing)) + resp, err := c.driver.Query(ctx, fmt.Sprintf(`hooks["%s"].audit`, name), nil, drivers.Tracing(cfg.enableTracing)) if err != nil { errMap[name] = err continue @@ -728,7 +728,7 @@ TargetLoop: // Dump dumps the state of OPA to aid in debugging. func (c *Client) Dump(ctx context.Context) (string, error) { - return c.backend.driver.Dump(ctx) + return c.driver.Dump(ctx) } // knownTargets returns a sorted list of currently-known target names. diff --git a/constraint/pkg/client/client_addtemplate_bench_test.go b/constraint/pkg/client/client_addtemplate_bench_test.go index 9403f01dc..e446104c4 100644 --- a/constraint/pkg/client/client_addtemplate_bench_test.go +++ b/constraint/pkg/client/client_addtemplate_bench_test.go @@ -4,10 +4,8 @@ import ( "fmt" "testing" - "github.com/open-policy-agent/frameworks/constraint/pkg/client" - "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" + "github.com/open-policy-agent/frameworks/constraint/pkg/client/clienttest" "github.com/open-policy-agent/frameworks/constraint/pkg/core/templates" - "github.com/open-policy-agent/frameworks/constraint/pkg/handler/handlertest" ) var modules = []struct { @@ -80,24 +78,13 @@ func BenchmarkClient_AddTemplate(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() - targets := client.Targets(&handlertest.Handler{}) - d := local.New() - - backend, err := client.NewBackend(client.Driver(d)) - if err != nil { - b.Fatal(err) - } - - c, err := backend.NewClient(targets) - if err != nil { - b.Fatal(err) - } + c := clienttest.New(b) b.StartTimer() for _, ct := range cts { - _, err = c.AddTemplate(ct) + _, err := c.AddTemplate(ct) if err != nil { b.Fatal(err) } diff --git a/constraint/pkg/client/client_opts.go b/constraint/pkg/client/client_opts.go index bb592f3ec..b93e807f2 100644 --- a/constraint/pkg/client/client_opts.go +++ b/constraint/pkg/client/client_opts.go @@ -5,6 +5,7 @@ import ( "regexp" "sort" + "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers" "github.com/open-policy-agent/frameworks/constraint/pkg/handler" ) @@ -55,7 +56,22 @@ func validateTargetNames(ts []handler.TargetHandler) []string { // the system can be enabled. func AllowedDataFields(fields ...string) Opt { return func(c *Client) error { + for _, field := range fields { + if !validDataFields[field] { + return fmt.Errorf("%w: invalid data field %q; allowed fields are: %v", + ErrCreatingClient, field, validDataFields) + } + } + c.AllowedDataFields = fields return nil } } + +// Driver defines the Rego execution environment. +func Driver(d drivers.Driver) Opt { + return func(client *Client) error { + client.driver = d + return nil + } +} diff --git a/constraint/pkg/client/client_test.go b/constraint/pkg/client/client_test.go index aebef9f9c..4693750a3 100644 --- a/constraint/pkg/client/client_test.go +++ b/constraint/pkg/client/client_test.go @@ -55,12 +55,7 @@ func TestBackend_NewClient_InvalidTargetName(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - _, err = b.NewClient(client.Targets(tc.handler)) + _, err := client.NewClient(client.Targets(tc.handler), client.Driver(d)) if !errors.Is(err, tc.wantError) { t.Errorf("got NewClient() error = %v, want %v", err, tc.wantError) @@ -139,12 +134,7 @@ func TestClient_AddData(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler1, tc.handler2)) + c, err := client.NewClient(client.Targets(tc.handler1, tc.handler2), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -246,12 +236,7 @@ func TestClient_RemoveData(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler1, tc.handler2)) + c, err := client.NewClient(client.Targets(tc.handler1, tc.handler2), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -352,12 +337,7 @@ r = 5 t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler)) + c, err := client.NewClient(client.Targets(tc.handler), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -443,12 +423,7 @@ func TestClient_RemoveTemplate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler)) + c, err := client.NewClient(client.Targets(tc.handler), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -506,12 +481,7 @@ func TestClient_RemoveTemplate_ByNameOnly(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler)) + c, err := client.NewClient(client.Targets(tc.handler), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -572,12 +542,7 @@ func TestClient_GetTemplate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler)) + c, err := client.NewClient(client.Targets(tc.handler), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -638,14 +603,7 @@ func TestClient_GetTemplate_ByNameOnly(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - d := local.New() - - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(tc.handler)) + c, err := client.NewClient(client.Driver(local.New()), client.Targets(tc.handler)) if err != nil { t.Fatal(err) } @@ -680,12 +638,7 @@ func TestClient_RemoveTemplate_CascadingDelete(t *testing.T) { h := &handlertest.Handler{} d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - c, err := b.NewClient(client.Targets(h)) + c, err := client.NewClient(client.Targets(h), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -843,12 +796,7 @@ func TestClient_AddConstraint(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatal(err) - } - - c, err := b.NewClient(client.Targets(&handlertest.Handler{})) + c, err := client.NewClient(client.Targets(&handlertest.Handler{}), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -961,13 +909,8 @@ func TestClient_RemoveConstraint(t *testing.T) { ctx := context.Background() d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - h := &handlertest.Handler{} - c, err := b.NewClient(client.Targets(h)) + c, err := client.NewClient(client.Targets(h), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -1055,12 +998,7 @@ violation[{"msg": "msg"}] { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatal(err) - } - - c, err := b.NewClient(client.Targets(tc.handler), client.AllowedDataFields(tc.allowedFields...)) + c, err := client.NewClient(client.Targets(tc.handler), client.AllowedDataFields(tc.allowedFields...), client.Driver(d)) if err != nil { t.Fatal(err) } @@ -1114,17 +1052,12 @@ func TestClient_AllowedDataFields_Intersection(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatalf("Could not create backend: %s", err) - } - - opts := []client.Opt{client.Targets(&handlertest.Handler{})} + opts := []client.Opt{client.Targets(&handlertest.Handler{}), client.Driver(d)} if tc.allowed != nil { opts = append(opts, tc.allowed) } - c, err := b.NewClient(opts...) + c, err := client.NewClient(opts...) if !errors.Is(err, tc.wantError) { t.Fatalf("got NewClient() error = %v, want %v", err, tc.wantError) @@ -1338,12 +1271,7 @@ violation[msg] {msg := "always"}`, t.Run(tc.name, func(t *testing.T) { d := local.New() - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatal(err) - } - - c, err := b.NewClient(client.Targets(tc.targets...)) + c, err := client.NewClient(client.Targets(tc.targets...), client.Driver(d)) if err != nil { t.Fatal(err) } diff --git a/constraint/pkg/client/clienttest/client.go b/constraint/pkg/client/clienttest/client.go index 4997d87d6..a68714aef 100644 --- a/constraint/pkg/client/clienttest/client.go +++ b/constraint/pkg/client/clienttest/client.go @@ -8,8 +8,11 @@ import ( "github.com/open-policy-agent/frameworks/constraint/pkg/handler/handlertest" ) -var defaults = []client.Opt{ - client.Targets(&handlertest.Handler{}), +func defaults() []client.Opt { + return []client.Opt{ + client.Driver(local.New()), + client.Targets(&handlertest.Handler{}), + } } // New constructs a new Client for testing with a default-constructed local driver @@ -17,14 +20,9 @@ var defaults = []client.Opt{ func New(t testing.TB, opts ...client.Opt) *client.Client { t.Helper() - backend, err := client.NewBackend(client.Driver(local.New())) - if err != nil { - t.Fatal(err) - } - - opts = append(defaults, opts...) + opts = append(defaults(), opts...) - c, err := backend.NewClient(opts...) + c, err := client.NewClient(opts...) if err != nil { t.Fatal(err) } diff --git a/constraint/pkg/client/e2e_test.go b/constraint/pkg/client/e2e_test.go index 60f153630..8d6b70833 100644 --- a/constraint/pkg/client/e2e_test.go +++ b/constraint/pkg/client/e2e_test.go @@ -516,12 +516,7 @@ func TestClient_Review_Print(t *testing.T) { printHook := appendingPrintHook{printed: &printed} d := local.New(local.PrintEnabled(tc.printEnabled), local.PrintHook(printHook)) - b, err := client.NewBackend(client.Driver(d)) - if err != nil { - t.Fatal(err) - } - - c, err := b.NewClient(client.Targets(&handlertest.Handler{})) + c, err := client.NewClient(client.Targets(&handlertest.Handler{}), client.Driver(d)) if err != nil { t.Fatal(err) } diff --git a/constraint/pkg/client/new_client.go b/constraint/pkg/client/new_client.go new file mode 100644 index 000000000..d47eb8fe8 --- /dev/null +++ b/constraint/pkg/client/new_client.go @@ -0,0 +1,50 @@ +package client + +import ( + "fmt" + + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" +) + +// NewClient creates a new client. +func NewClient(opts ...Opt) (*Client, error) { + var fields []string + for k := range validDataFields { + fields = append(fields, k) + } + + c := &Client{ + constraints: make(map[schema.GroupKind]map[string]*unstructured.Unstructured), + templates: make(map[templateKey]*templateEntry), + AllowedDataFields: fields, + } + + for _, opt := range opts { + if err := opt(c); err != nil { + return nil, err + } + } + + for _, field := range c.AllowedDataFields { + if !validDataFields[field] { + return nil, fmt.Errorf("%w: invalid data field %q; allowed fields are: %v", + ErrCreatingClient, field, validDataFields) + } + } + + if len(c.targets) == 0 { + return nil, fmt.Errorf("%w: must specify at least one target with client.Targets", + ErrCreatingClient) + } + + if err := c.driver.Init(); err != nil { + return nil, err + } + + if err := c.init(); err != nil { + return nil, err + } + + return c, nil +} diff --git a/constraint/pkg/client/new_client_test.go b/constraint/pkg/client/new_client_test.go new file mode 100644 index 000000000..71b769286 --- /dev/null +++ b/constraint/pkg/client/new_client_test.go @@ -0,0 +1,41 @@ +package client_test + +import ( + "errors" + "testing" + + "github.com/open-policy-agent/frameworks/constraint/pkg/client" + "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/local" + "github.com/open-policy-agent/frameworks/constraint/pkg/handler/handlertest" +) + +func TestNewClient(t *testing.T) { + testCases := []struct { + name string + clientOpts []client.Opt + wantError error + }{ + { + name: "no opts", + clientOpts: nil, + wantError: client.ErrCreatingClient, + }, + { + name: "with handler", + clientOpts: []client.Opt{client.Targets(&handlertest.Handler{})}, + wantError: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + opts := append(tc.clientOpts, client.Driver(local.New())) + + _, err := client.NewClient(opts...) + if !errors.Is(err, tc.wantError) { + t.Fatalf("got NewClient() eror = %v, want %v", + err, tc.wantError) + } + }) + } +} From 24e707ef36552cefe083a9577dd068cd1e263bf5 Mon Sep 17 00:00:00 2001 From: Will Beason Date: Tue, 8 Feb 2022 08:27:31 -0800 Subject: [PATCH 2/2] Fix lint errors Signed-off-by: Will Beason --- constraint/pkg/client/new_client_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/constraint/pkg/client/new_client_test.go b/constraint/pkg/client/new_client_test.go index 71b769286..2cd45cf4c 100644 --- a/constraint/pkg/client/new_client_test.go +++ b/constraint/pkg/client/new_client_test.go @@ -29,7 +29,8 @@ func TestNewClient(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - opts := append(tc.clientOpts, client.Driver(local.New())) + opts := tc.clientOpts + opts = append(opts, client.Driver(local.New())) _, err := client.NewClient(opts...) if !errors.Is(err, tc.wantError) {