From d04c5061b76dddf1970b56ad14022abcf4613a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Ramalho?= Date: Tue, 7 Sep 2021 13:56:39 +0200 Subject: [PATCH] Add badger store for Extension Implement ExtensionStore interface using Badger as storage. (cherry picked from commit a0b11dba2ac26900749a6666c3fb2da6a0b95c44) --- cmd/fuseml_core/wire.go | 4 +- cmd/fuseml_core/wire_gen.go | 7 +- pkg/core/store/badger/extension.go | 340 +++++++++ pkg/core/store/badger/extension_test.go | 881 ++++++++++++++++++++++++ 4 files changed, 1227 insertions(+), 5 deletions(-) create mode 100644 pkg/core/store/badger/extension.go create mode 100644 pkg/core/store/badger/extension_test.go diff --git a/cmd/fuseml_core/wire.go b/cmd/fuseml_core/wire.go index e0ce2904..2f0b55ef 100644 --- a/cmd/fuseml_core/wire.go +++ b/cmd/fuseml_core/wire.go @@ -39,8 +39,8 @@ var storeSet = wire.NewSet( wire.Bind(new(domain.RunnableStore), new(*core.RunnableStore)), badger.NewWorkflowStore, wire.Bind(new(domain.WorkflowStore), new(*badger.WorkflowStore)), - core.NewExtensionStore, - wire.Bind(new(domain.ExtensionStore), new(*core.ExtensionStore)), + badger.NewExtensionStore, + wire.Bind(new(domain.ExtensionStore), new(*badger.ExtensionStore)), ) var managerSet = wire.NewSet( diff --git a/cmd/fuseml_core/wire_gen.go b/cmd/fuseml_core/wire_gen.go index 842f4ed8..66f56f72 100644 --- a/cmd/fuseml_core/wire_gen.go +++ b/cmd/fuseml_core/wire_gen.go @@ -1,7 +1,8 @@ // Code generated by Wire. DO NOT EDIT. //go:generate go run github.com/google/wire/cmd/wire -//+build !wireinject +//go:build !wireinject +// +build !wireinject package main @@ -55,7 +56,7 @@ func InitializeCore(logger *log.Logger, storeOptions badgerhold.Options, fuseMLN return nil, err } workflowStore := badger.NewWorkflowStore(store) - extensionStore := core.NewExtensionStore() + extensionStore := badger.NewExtensionStore(store) extensionRegistry := manager.NewExtensionRegistry(extensionStore) workflowManager := manager.NewWorkflowManager(workflowBackend, workflowStore, gitCodesetStore, extensionRegistry) workflowService := svc.NewWorkflowService(logger, workflowManager) @@ -80,7 +81,7 @@ func InitializeCore(logger *log.Logger, storeOptions badgerhold.Options, fuseMLN // wire.go: -var storeSet = wire.NewSet(badgerhold.Open, badger.NewApplicationStore, wire.Bind(new(domain.ApplicationStore), new(*badger.ApplicationStore)), gitea.NewAdminClient, wire.Bind(new(domain.GitAdminClient), new(*gitea.AdminClient)), core.NewGitCodesetStore, wire.Bind(new(domain.CodesetStore), new(*core.GitCodesetStore)), core.NewGitProjectStore, wire.Bind(new(domain.ProjectStore), new(*core.GitProjectStore)), core.NewRunnableStore, wire.Bind(new(domain.RunnableStore), new(*core.RunnableStore)), badger.NewWorkflowStore, wire.Bind(new(domain.WorkflowStore), new(*badger.WorkflowStore)), core.NewExtensionStore, wire.Bind(new(domain.ExtensionStore), new(*core.ExtensionStore))) +var storeSet = wire.NewSet(badgerhold.Open, badger.NewApplicationStore, wire.Bind(new(domain.ApplicationStore), new(*badger.ApplicationStore)), gitea.NewAdminClient, wire.Bind(new(domain.GitAdminClient), new(*gitea.AdminClient)), core.NewGitCodesetStore, wire.Bind(new(domain.CodesetStore), new(*core.GitCodesetStore)), core.NewGitProjectStore, wire.Bind(new(domain.ProjectStore), new(*core.GitProjectStore)), core.NewRunnableStore, wire.Bind(new(domain.RunnableStore), new(*core.RunnableStore)), badger.NewWorkflowStore, wire.Bind(new(domain.WorkflowStore), new(*badger.WorkflowStore)), badger.NewExtensionStore, wire.Bind(new(domain.ExtensionStore), new(*badger.ExtensionStore))) var managerSet = wire.NewSet(manager.NewWorkflowManager, wire.Bind(new(domain.WorkflowManager), new(*manager.WorkflowManager)), manager.NewExtensionRegistry, wire.Bind(new(domain.ExtensionRegistry), new(*manager.ExtensionRegistry))) diff --git a/pkg/core/store/badger/extension.go b/pkg/core/store/badger/extension.go new file mode 100644 index 00000000..0c06d806 --- /dev/null +++ b/pkg/core/store/badger/extension.go @@ -0,0 +1,340 @@ +package badger + +import ( + "context" + "time" + + "github.com/fuseml/fuseml-core/pkg/domain" + "github.com/timshannon/badgerhold/v3" +) + +// ExtensionStore is a wrapper around a badgerhold.Store that implements the domain.ExtensionStore interface. +type ExtensionStore struct { + store *badgerhold.Store +} + +// NewExtensionStore creates a new ExtensionStore. +func NewExtensionStore(store *badgerhold.Store) *ExtensionStore { + return &ExtensionStore{store: store} +} + +// AddExtension adds a new extension to the store. +func (es *ExtensionStore) AddExtension(ctx context.Context, extension *domain.Extension) (*domain.Extension, error) { + extension.EnsureID(ctx, es) + extension.SetCreated(ctx) + + err := es.store.Insert(extension.ID, extension) + if err != nil { + return nil, domain.NewErrExtensionExists(extension.ID) + } + return extension, nil +} + +// GetExtension retrieves an extension by its ID. +func (es *ExtensionStore) GetExtension(ctx context.Context, extensionID string) (*domain.Extension, error) { + extension := &domain.Extension{} + err := es.store.Get(extensionID, extension) + if err != nil { + return nil, domain.NewErrExtensionNotFound(extensionID) + } + return extension, nil +} + +// ListExtensions retrieves all stored extensions. +func (es *ExtensionStore) ListExtensions(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.Extension) { + result = []*domain.Extension{} + + // TODO: Replace with a badgerhold query. + if query != nil { + if query.ExtensionID != "" { + fullExtension, err := es.GetExtension(ctx, query.ExtensionID) + if err == nil { + matchingExtension := fullExtension.GetExtensionIfMatch(query) + if matchingExtension != nil { + result = append(result, matchingExtension) + } + } + return + } + + allExtensions := []*domain.Extension{} + es.store.Find(&allExtensions, nil) + + for _, extension := range allExtensions { + matchingExtension := extension.GetExtensionIfMatch(query) + if matchingExtension != nil { + result = append(result, matchingExtension) + } + } + return + } + + es.store.Find(&result, nil) + return +} + +// UpdateExtension updates an existing extension. +func (es *ExtensionStore) UpdateExtension(ctx context.Context, newExtension *domain.Extension) error { + extension, err := es.GetExtension(ctx, newExtension.ID) + if err != nil { + return err + } + newExtension.Created = extension.Created + newExtension.Updated = time.Now() + + for _, newExtService := range newExtension.ListServices() { + _, err := extension.GetService(newExtService.ID) + if err != nil { + // If the service is new, set the creation time + newExtService.SetCreated(newExtension.Updated) + } + } + + err = es.store.Update(newExtension.ID, newExtension) + if err != nil { + return domain.NewErrExtensionNotFound(newExtension.ID) + } + return nil +} + +// DeleteExtension deletes an extension from the store. +func (es *ExtensionStore) DeleteExtension(ctx context.Context, extensionID string) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + return es.store.Delete(extension.ID, extension) +} + +// AddExtensionService adds a new extension service to an extension. +func (es *ExtensionStore) AddExtensionService(ctx context.Context, extensionID string, service *domain.ExtensionService) (*domain.ExtensionService, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.AddService(service) + if err != nil { + return nil, err + } + err = es.UpdateExtension(ctx, extension) + if err != nil { + return nil, err + } + return svc, nil +} + +// GetExtensionService retrieves an extension service by its ID. +func (es *ExtensionStore) GetExtensionService(ctx context.Context, extensionID string, serviceID string) (*domain.ExtensionService, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + return extension.GetService(serviceID) +} + +// ListExtensionServices retrieves all services belonging to an extension. +func (es *ExtensionStore) ListExtensionServices(ctx context.Context, extensionID string) ([]*domain.ExtensionService, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + return extension.ListServices(), nil +} + +// UpdateExtensionService updates a service belonging to an extension. +func (es *ExtensionStore) UpdateExtensionService(ctx context.Context, extensionID string, newService *domain.ExtensionService) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + err = extension.UpdateService(newService) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// DeleteExtensionService deletes an extension service from an extension. +func (es *ExtensionStore) DeleteExtensionService(ctx context.Context, extensionID string, serviceID string) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + err = extension.DeleteService(serviceID) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// AddExtensionServiceEndpoint adds a new endpoint to an extension service. +func (es *ExtensionStore) AddExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpoint *domain.ExtensionServiceEndpoint) (*domain.ExtensionServiceEndpoint, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + endpoint, err = svc.AddEndpoint(endpoint) + if err != nil { + return nil, err + } + err = es.UpdateExtension(ctx, extension) + if err != nil { + return nil, err + } + return endpoint, nil +} + +// GetExtensionServiceEndpoint retrieves an extension endpoint by its ID. +func (es *ExtensionStore) GetExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) (*domain.ExtensionServiceEndpoint, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + return svc.GetEndpoint(endpointID) +} + +// ListExtensionServiceEndpoints retrieves all endpoints belonging to an extension service. +func (es *ExtensionStore) ListExtensionServiceEndpoints(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceEndpoint, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + return svc.ListEndpoints(), nil +} + +// UpdateExtensionServiceEndpoint updates an endpoint belonging to an extension service. +func (es *ExtensionStore) UpdateExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, newEndpoint *domain.ExtensionServiceEndpoint) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return err + } + err = svc.UpdateEndpoint(newEndpoint) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// DeleteExtensionServiceEndpoint deletes an extension endpoint from an extension service. +func (es *ExtensionStore) DeleteExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return err + } + err = svc.DeleteEndpoint(endpointID) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// AddExtensionServiceCredentials adds a new credential to an extension service. +func (es *ExtensionStore) AddExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentials *domain.ExtensionServiceCredentials) (*domain.ExtensionServiceCredentials, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + credentials, err = svc.AddCredentials(credentials) + if err != nil { + return nil, err + } + err = es.UpdateExtension(ctx, extension) + if err != nil { + return nil, err + } + return credentials, nil +} + +// GetExtensionServiceCredentials retrieves an extension credential by its ID. +func (es *ExtensionStore) GetExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) (*domain.ExtensionServiceCredentials, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + return svc.GetCredentials(credentialsID) +} + +// ListExtensionServiceCredentials retrieves all credentials belonging to an extension service. +func (es *ExtensionStore) ListExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceCredentials, error) { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return nil, err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return nil, err + } + return svc.ListCredentials(), nil +} + +// UpdateExtensionServiceCredentials updates an extension credential. +func (es *ExtensionStore) UpdateExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, newCredentials *domain.ExtensionServiceCredentials) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return err + } + err = svc.UpdateCredentials(newCredentials) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// DeleteExtensionServiceCredentials deletes an extension credential from an extension service. +func (es *ExtensionStore) DeleteExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) error { + extension, err := es.GetExtension(ctx, extensionID) + if err != nil { + return err + } + svc, err := extension.GetService(serviceID) + if err != nil { + return err + } + err = svc.DeleteCredentials(credentialsID) + if err != nil { + return err + } + return es.UpdateExtension(ctx, extension) +} + +// GetExtensionAccessDescriptors retrieves access descriptors belonging to an extension that matches the query. +func (es *ExtensionStore) GetExtensionAccessDescriptors(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.ExtensionAccessDescriptor, err error) { + result = make([]*domain.ExtensionAccessDescriptor, 0) + + for _, extension := range es.ListExtensions(ctx, query) { + result = append(result, extension.GetAccessDescriptors()...) + } + return result, nil +} diff --git a/pkg/core/store/badger/extension_test.go b/pkg/core/store/badger/extension_test.go new file mode 100644 index 00000000..c6e0a852 --- /dev/null +++ b/pkg/core/store/badger/extension_test.go @@ -0,0 +1,881 @@ +package badger + +import ( + "context" + "os" + "testing" + + "github.com/fuseml/fuseml-core/pkg/domain" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/tektoncd/pipeline/test/diff" + "github.com/timshannon/badgerhold/v3" +) + +func TestAddExtension(t *testing.T) { + t.Run("new", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + got, err := store.GetExtension(ctx, ext.ID) + assertNoError(t, err) + + if d := cmp.Diff(ext, got); d != "" { + t.Errorf("Unexpected Extension: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("existing", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + _, err = store.AddExtension(ctx, ext) + assertErrorMessage(t, domain.NewErrExtensionExists(ext.ID), err) + }) +} + +func TestGetExtension(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + got, err := store.GetExtension(ctx, ext.ID) + assertNoError(t, err) + + if d := cmp.Diff(ext, got); d != "" { + t.Errorf("Unexpected Extension: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + _, err := store.GetExtension(ctx, "not found") + assertErrorMessage(t, domain.NewErrExtensionNotFound("not found"), err) + }) +} + +func TestGetExtensions(t *testing.T) { + t.Run("all", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + exts := []*domain.Extension{ + {ID: "1"}, + {ID: "2"}, + {ID: "3"}, + } + + for _, ext := range exts { + _, err := store.AddExtension(ctx, ext) + assertNoError(t, err) + } + + got := store.ListExtensions(ctx, nil) + sortExtensionSlices := cmpopts.SortSlices(func(x, y *domain.Extension) bool { return x.ID < y.ID }) + if d := cmp.Diff(exts, got, sortExtensionSlices); d != "" { + t.Errorf("Unexpected Extensions: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("with query", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + exts := []*domain.Extension{ + {ID: "1", Product: "p1"}, + {ID: "2", Product: "p1"}, + {ID: "3", Product: "p2"}, + } + + for _, ext := range exts { + _, err := store.AddExtension(ctx, ext) + assertNoError(t, err) + } + + // by ID + got := store.ListExtensions(ctx, &domain.ExtensionQuery{ExtensionID: "3"}) + if d := cmp.Diff(exts[2:], got); d != "" { + t.Errorf("Unexpected Extensions: %s", diff.PrintWantGot(d)) + } + + // by product + got = store.ListExtensions(ctx, &domain.ExtensionQuery{Product: "p1"}) + if d := cmp.Diff(exts[:2], got); d != "" { + t.Errorf("Unexpected Extensions: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("empty", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + got := store.ListExtensions(ctx, nil) + if d := cmp.Diff([]*domain.Extension{}, got); d != "" { + t.Errorf("Unexpected Extensions: %s", diff.PrintWantGot(d)) + } + + got = store.ListExtensions(ctx, &domain.ExtensionQuery{Product: "p1"}) + if d := cmp.Diff([]*domain.Extension{}, got); d != "" { + t.Errorf("Unexpected Extensions: %s", diff.PrintWantGot(d)) + } + }) +} + +func TestUpdateExtension(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + newExt := &domain.Extension{ + ID: ext.ID, + Product: "p1", + Version: "v1", + Services: map[string]*domain.ExtensionService{"test": {ID: "test"}}, + } + err = store.UpdateExtension(ctx, newExt) + assertNoError(t, err) + + got, err := store.GetExtension(ctx, ext.ID) + assertNoError(t, err) + + if d := cmp.Diff(newExt, got); d != "" { + t.Errorf("Unexpected Extension: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + err := store.UpdateExtension(ctx, &domain.Extension{}) + assertErrorMessage(t, domain.NewErrExtensionNotFound(""), err) + + err = store.UpdateExtension(ctx, &domain.Extension{ID: "not found"}) + assertErrorMessage(t, domain.NewErrExtensionNotFound("not found"), err) + }) +} + +func TestDeleteExtension(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + err = store.DeleteExtension(ctx, ext.ID) + assertNoError(t, err) + + _, err = store.GetExtension(ctx, ext.ID) + assertErrorMessage(t, domain.NewErrExtensionNotFound(ext.ID), err) + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + err := store.DeleteExtension(ctx, "not found") + assertErrorMessage(t, domain.NewErrExtensionNotFound("not found"), err) + }) +} + +func TestAddExtensionService(t *testing.T) { + t.Run("new", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + svc, err := store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{}) + assertNoError(t, err) + + got, err := store.GetExtensionService(ctx, ext.ID, svc.ID) + assertNoError(t, err) + + if d := cmp.Diff(svc, got); d != "" { + t.Errorf("Unexpected ExtensionService: %s", diff.PrintWantGot(d)) + } + + svc, err = store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{ID: "test"}) + assertNoError(t, err) + + got, err = store.GetExtensionService(ctx, ext.ID, svc.ID) + assertNoError(t, err) + + if d := cmp.Diff(svc, got); d != "" { + t.Errorf("Unexpected ExtensionService: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("existing", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + _, err = store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{ID: "test"}) + assertNoError(t, err) + + _, err = store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{ID: "test"}) + assertErrorMessage(t, domain.NewErrExtensionServiceExists(ext.ID, "test"), err) + }) +} +func TestGetExtensionService(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + svc, err := store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{}) + assertNoError(t, err) + + got, err := store.GetExtensionService(ctx, ext.ID, svc.ID) + assertNoError(t, err) + + if d := cmp.Diff(svc, got); d != "" { + t.Errorf("Unexpected ExtensionService: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + _, err = store.GetExtensionService(ctx, ext.ID, "not found") + assertErrorMessage(t, domain.NewErrExtensionServiceNotFound(ext.ID, "not found"), err) + }) +} + +func TestGetExtensionServices(t *testing.T) { + t.Run("all", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + svcs := []*domain.ExtensionService{ + {ID: "test1"}, + {ID: "test2"}, + } + for _, svc := range svcs { + _, err = store.AddExtensionService(ctx, ext.ID, svc) + assertNoError(t, err) + } + + got, err := store.ListExtensionServices(ctx, ext.ID) + assertNoError(t, err) + + sortServiceSlices := cmpopts.SortSlices(func(x, y *domain.ExtensionService) bool { return x.ID < y.ID }) + if d := cmp.Diff(svcs, got, sortServiceSlices); d != "" { + t.Errorf("Unexpected ExtensionServices: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("empty", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + got, err := store.ListExtensionServices(ctx, ext.ID) + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionService{}, got); d != "" { + t.Errorf("Unexpected ExtensionServices: %s", diff.PrintWantGot(d)) + } + }) +} + +func TestUpdateExtensionService(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + svc, err := store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{Resource: "test"}) + assertNoError(t, err) + + newSvc := &domain.ExtensionService{ + ID: svc.ID, + Resource: "test-updated"} + err = store.UpdateExtensionService(ctx, ext.ID, newSvc) + assertNoError(t, err) + + got, err := store.GetExtensionService(ctx, ext.ID, svc.ID) + assertNoError(t, err) + + if d := cmp.Diff(newSvc, got); d != "" { + t.Errorf("Unexpected ExtensionService: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + err = store.UpdateExtensionService(ctx, ext.ID, &domain.ExtensionService{ID: "not found"}) + assertErrorMessage(t, domain.NewErrExtensionServiceNotFound(ext.ID, "not found"), err) + }) +} + +func TestDeleteExtensionService(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + svc, err := store.AddExtensionService(ctx, ext.ID, &domain.ExtensionService{}) + assertNoError(t, err) + + err = store.DeleteExtensionService(ctx, ext.ID, svc.ID) + assertNoError(t, err) + + _, err = store.GetExtensionService(ctx, ext.ID, svc.ID) + assertErrorMessage(t, domain.NewErrExtensionServiceNotFound(ext.ID, svc.ID), err) + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{}) + assertNoError(t, err) + + err = store.DeleteExtensionService(ctx, ext.ID, "not found") + assertErrorMessage(t, domain.NewErrExtensionServiceNotFound(ext.ID, "not found"), err) + }) +} + +func TestAddExtensionServiceEndpoint(t *testing.T) { + t.Run("new", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + got, err := store.GetExtensionServiceEndpoint(ctx, ext.ID, "test-svc", endpoint.URL) + assertNoError(t, err) + + if d := cmp.Diff(endpoint, got); d != "" { + t.Errorf("Unexpected ExtensionServiceEndpoint: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("existing", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + _, err = store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertErrorMessage(t, domain.NewErrExtensionServiceEndpointExists("", "test-svc", endpoint.URL), err) + }) +} + +func TestGetExtensionServiceEndpoint(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + got, err := store.GetExtensionServiceEndpoint(ctx, ext.ID, "test-svc", endpoint.URL) + assertNoError(t, err) + + if d := cmp.Diff(endpoint, got); d != "" { + t.Errorf("Unexpected ExtensionServiceEndpoint: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + _, err = store.GetExtensionServiceEndpoint(ctx, ext.ID, "test-svc", "not found") + assertErrorMessage(t, domain.NewErrExtensionServiceEndpointNotFound("", "test-svc", "not found"), err) + }) +} + +func TestGetExtensionServiceEndpoints(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + got, err := store.ListExtensionServiceEndpoints(ctx, ext.ID, "test-svc") + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionServiceEndpoint{endpoint}, got); d != "" { + t.Errorf("Unexpected ExtensionServiceEndpoint: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + got, err := store.ListExtensionServiceEndpoints(ctx, ext.ID, "test-svc") + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionServiceEndpoint{}, got); d != "" { + t.Errorf("Unexpected ExtensionServiceEndpoint: %s", diff.PrintWantGot(d)) + } + }) +} + +func TestUpdateExtensionServiceEndpoint(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + newEndpoint := &domain.ExtensionServiceEndpoint{URL: "http://test", Type: "test-updated"} + + err = store.UpdateExtensionServiceEndpoint(ctx, ext.ID, "test-svc", newEndpoint) + assertNoError(t, err) + + got, err := store.GetExtensionServiceEndpoint(ctx, ext.ID, "test-svc", endpoint.URL) + assertNoError(t, err) + + if d := cmp.Diff(newEndpoint, got); d != "" { + t.Errorf("Unexpected ExtensionServiceEndpoint: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + err = store.UpdateExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertErrorMessage(t, domain.NewErrExtensionServiceEndpointNotFound("", "test-svc", "http://test"), err) + }) +} + +func TestDeleteExtensionServiceEndpoint(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + endpoint, err := store.AddExtensionServiceEndpoint(ctx, ext.ID, "test-svc", &domain.ExtensionServiceEndpoint{URL: "http://test"}) + assertNoError(t, err) + + err = store.DeleteExtensionServiceEndpoint(ctx, ext.ID, "test-svc", endpoint.URL) + assertNoError(t, err) + + _, err = store.GetExtensionServiceEndpoint(ctx, ext.ID, "test-svc", endpoint.URL) + assertErrorMessage(t, domain.NewErrExtensionServiceEndpointNotFound("", "test-svc", endpoint.URL), err) + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + err = store.DeleteExtensionServiceEndpoint(ctx, ext.ID, "test-svc", "http://test") + assertErrorMessage(t, domain.NewErrExtensionServiceEndpointNotFound("", "test-svc", "http://test"), err) + }) +} + +func TestAddExtensionServiceCredential(t *testing.T) { + t.Run("new", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + got, err := store.GetExtensionServiceCredentials(ctx, ext.ID, "test-svc", cred.ID) + assertNoError(t, err) + + if d := cmp.Diff(cred, got); d != "" { + t.Errorf("Unexpected ExtensionServiceCredential: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("existing", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + _, err = store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{ID: cred.ID}) + assertErrorMessage(t, domain.NewErrExtensionServiceCredentialsExists("", "test-svc", cred.ID), err) + + }) +} + +func TestGetExtensionServiceCredential(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + got, err := store.GetExtensionServiceCredentials(ctx, ext.ID, "test-svc", cred.ID) + assertNoError(t, err) + + if d := cmp.Diff(cred, got); d != "" { + t.Errorf("Unexpected ExtensionServiceCredential: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + _, err = store.GetExtensionServiceCredentials(ctx, ext.ID, "test-svc", "test") + assertErrorMessage(t, domain.NewErrExtensionServiceCredentialsNotFound("", "test-svc", "test"), err) + }) +} + +func TestGetExtensionServiceCredentials(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + got, err := store.ListExtensionServiceCredentials(ctx, ext.ID, "test-svc") + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionServiceCredentials{cred}, got); d != "" { + t.Errorf("Unexpected ExtensionServiceCredential: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + got, err := store.ListExtensionServiceCredentials(ctx, ext.ID, "test-svc") + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionServiceCredentials{}, got); d != "" { + t.Errorf("Unexpected ExtensionServiceCredential: %s", diff.PrintWantGot(d)) + } + }) +} + +func TestUpdateExtensionServiceCredential(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + newCred := &domain.ExtensionServiceCredentials{ + ID: cred.ID, + Scope: domain.ECSGlobal, + } + + err = store.UpdateExtensionServiceCredentials(ctx, ext.ID, "test-svc", newCred) + assertNoError(t, err) + + got, err := store.GetExtensionServiceCredentials(ctx, ext.ID, "test-svc", cred.ID) + assertNoError(t, err) + + if d := cmp.Diff(newCred, got); d != "" { + t.Errorf("Unexpected ExtensionServiceCredential: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + err = store.UpdateExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{ID: "not found"}) + assertErrorMessage(t, domain.NewErrExtensionServiceCredentialsNotFound("", "test-svc", "not found"), err) + }) +} + +func TestDeleteExtensionServiceCredential(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + cred, err := store.AddExtensionServiceCredentials(ctx, ext.ID, "test-svc", &domain.ExtensionServiceCredentials{}) + assertNoError(t, err) + + err = store.DeleteExtensionServiceCredentials(ctx, ext.ID, "test-svc", cred.ID) + assertNoError(t, err) + + _, err = store.GetExtensionServiceCredentials(ctx, ext.ID, "test-svc", cred.ID) + assertErrorMessage(t, domain.NewErrExtensionServiceCredentialsNotFound("", "test-svc", cred.ID), err) + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + err = store.DeleteExtensionServiceCredentials(ctx, ext.ID, "test-svc", "not found") + assertErrorMessage(t, domain.NewErrExtensionServiceCredentialsNotFound("", "test-svc", "not found"), err) + }) +} + +func TestGetExtensionAccessDescriptors(t *testing.T) { + t.Run("found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + ext, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{ + "test-svc1": { + ID: "test-svc1", + Endpoints: map[string]*domain.ExtensionServiceEndpoint{"http://test1": {URL: "http://test1"}}, + }, + "test-svc2": { + ID: "test-svc2", + Endpoints: map[string]*domain.ExtensionServiceEndpoint{"http://test2": {URL: "http://test2"}}, + }}}) + assertNoError(t, err) + + got, err := store.GetExtensionAccessDescriptors(ctx, &domain.ExtensionQuery{ServiceID: "test-svc1"}) + assertNoError(t, err) + + endpoint1 := domain.ExtensionServiceEndpoint{ + URL: "http://test1", + Created: ext.Services["test-svc1"].Endpoints["http://test1"].Created, + Updated: ext.Services["test-svc1"].Endpoints["http://test1"].Updated, + } + + svc1 := domain.ExtensionService{ + ID: "test-svc1", + Created: ext.Services["test-svc1"].Created, + Updated: ext.Services["test-svc1"].Updated, + Endpoints: map[string]*domain.ExtensionServiceEndpoint{ + endpoint1.URL: &endpoint1, + }, + Credentials: map[string]*domain.ExtensionServiceCredentials{}, + } + + want := []*domain.ExtensionAccessDescriptor{{ + Extension: domain.Extension{ + ID: ext.ID, + Created: ext.Created, + Updated: ext.Updated, + Services: map[string]*domain.ExtensionService{ + svc1.ID: &svc1, + }, + }, + Service: svc1, + Endpoint: endpoint1, + }} + + if d := cmp.Diff(want, got); d != "" { + t.Errorf("Unexpected ExtensionAccessDescriptor: %s", diff.PrintWantGot(d)) + } + }) + + t.Run("not found", func(t *testing.T) { + store, done := newExtensionStore(t) + defer done() + ctx := context.Background() + + _, err := store.AddExtension(ctx, &domain.Extension{ + Services: map[string]*domain.ExtensionService{"test-svc": {ID: "test-svc"}}}) + assertNoError(t, err) + + got, err := store.GetExtensionAccessDescriptors(ctx, &domain.ExtensionQuery{ServiceID: "not found"}) + assertNoError(t, err) + + if d := cmp.Diff([]*domain.ExtensionAccessDescriptor{}, got); d != "" { + t.Errorf("Unexpected ExtensionAccessDescriptor: %s", diff.PrintWantGot(d)) + } + }) +} + +func newExtensionStore(t *testing.T) (*ExtensionStore, func()) { + t.Helper() + + dir := tmpDir(t) + opt := badgerhold.DefaultOptions + opt.Logger = nil + opt.Dir = dir + opt.ValueDir = dir + + store, err := badgerhold.Open(opt) + if err != nil { + t.Fatalf("failed to open store: %v", err) + } + + workflowStore := NewExtensionStore(store) + + return workflowStore, func() { + store.Close() + os.RemoveAll(dir) + } +} + +func assertErrorMessage(t *testing.T, want error, got error) { + t.Helper() + + if got == nil { + t.Fatalf("expected error, got nil") + } + + if want.Error() != got.Error() { + t.Errorf("expected error, got %q but want %q", got, want) + } +}