Skip to content

Commit

Permalink
Add badger store for Extension
Browse files Browse the repository at this point in the history
Implement ExtensionStore interface using Badger as storage.

(cherry picked from commit a0b11db)
  • Loading branch information
flaviodsr committed Sep 14, 2021
1 parent 4fb4a0d commit d04c506
Show file tree
Hide file tree
Showing 4 changed files with 1,227 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cmd/fuseml_core/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ var storeSet = wire.NewSet(
wire.Bind(new(domain.RunnableStore), new(*core.RunnableStore)),
badger.NewWorkflowStore,
wire.Bind(new(domain.WorkflowStore), new(*badger.WorkflowStore)),
core.NewExtensionStore,
wire.Bind(new(domain.ExtensionStore), new(*core.ExtensionStore)),
badger.NewExtensionStore,
wire.Bind(new(domain.ExtensionStore), new(*badger.ExtensionStore)),
)

var managerSet = wire.NewSet(
Expand Down
7 changes: 4 additions & 3 deletions cmd/fuseml_core/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

340 changes: 340 additions & 0 deletions pkg/core/store/badger/extension.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,340 @@
package badger

import (
"context"
"time"

"github.com/fuseml/fuseml-core/pkg/domain"
"github.com/timshannon/badgerhold/v3"
)

// ExtensionStore is a wrapper around a badgerhold.Store that implements the domain.ExtensionStore interface.
type ExtensionStore struct {
store *badgerhold.Store
}

// NewExtensionStore creates a new ExtensionStore.
func NewExtensionStore(store *badgerhold.Store) *ExtensionStore {
return &ExtensionStore{store: store}
}

// AddExtension adds a new extension to the store.
func (es *ExtensionStore) AddExtension(ctx context.Context, extension *domain.Extension) (*domain.Extension, error) {
extension.EnsureID(ctx, es)
extension.SetCreated(ctx)

err := es.store.Insert(extension.ID, extension)
if err != nil {
return nil, domain.NewErrExtensionExists(extension.ID)
}
return extension, nil
}

// GetExtension retrieves an extension by its ID.
func (es *ExtensionStore) GetExtension(ctx context.Context, extensionID string) (*domain.Extension, error) {
extension := &domain.Extension{}
err := es.store.Get(extensionID, extension)
if err != nil {
return nil, domain.NewErrExtensionNotFound(extensionID)
}
return extension, nil
}

// ListExtensions retrieves all stored extensions.
func (es *ExtensionStore) ListExtensions(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.Extension) {
result = []*domain.Extension{}

// TODO: Replace with a badgerhold query.
if query != nil {
if query.ExtensionID != "" {
fullExtension, err := es.GetExtension(ctx, query.ExtensionID)
if err == nil {
matchingExtension := fullExtension.GetExtensionIfMatch(query)
if matchingExtension != nil {
result = append(result, matchingExtension)
}
}
return
}

allExtensions := []*domain.Extension{}
es.store.Find(&allExtensions, nil)

for _, extension := range allExtensions {
matchingExtension := extension.GetExtensionIfMatch(query)
if matchingExtension != nil {
result = append(result, matchingExtension)
}
}
return
}

es.store.Find(&result, nil)
return
}

// UpdateExtension updates an existing extension.
func (es *ExtensionStore) UpdateExtension(ctx context.Context, newExtension *domain.Extension) error {
extension, err := es.GetExtension(ctx, newExtension.ID)
if err != nil {
return err
}
newExtension.Created = extension.Created
newExtension.Updated = time.Now()

for _, newExtService := range newExtension.ListServices() {
_, err := extension.GetService(newExtService.ID)
if err != nil {
// If the service is new, set the creation time
newExtService.SetCreated(newExtension.Updated)
}
}

err = es.store.Update(newExtension.ID, newExtension)
if err != nil {
return domain.NewErrExtensionNotFound(newExtension.ID)
}
return nil
}

// DeleteExtension deletes an extension from the store.
func (es *ExtensionStore) DeleteExtension(ctx context.Context, extensionID string) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
return es.store.Delete(extension.ID, extension)
}

// AddExtensionService adds a new extension service to an extension.
func (es *ExtensionStore) AddExtensionService(ctx context.Context, extensionID string, service *domain.ExtensionService) (*domain.ExtensionService, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.AddService(service)
if err != nil {
return nil, err
}
err = es.UpdateExtension(ctx, extension)
if err != nil {
return nil, err
}
return svc, nil
}

// GetExtensionService retrieves an extension service by its ID.
func (es *ExtensionStore) GetExtensionService(ctx context.Context, extensionID string, serviceID string) (*domain.ExtensionService, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
return extension.GetService(serviceID)
}

// ListExtensionServices retrieves all services belonging to an extension.
func (es *ExtensionStore) ListExtensionServices(ctx context.Context, extensionID string) ([]*domain.ExtensionService, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
return extension.ListServices(), nil
}

// UpdateExtensionService updates a service belonging to an extension.
func (es *ExtensionStore) UpdateExtensionService(ctx context.Context, extensionID string, newService *domain.ExtensionService) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
err = extension.UpdateService(newService)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// DeleteExtensionService deletes an extension service from an extension.
func (es *ExtensionStore) DeleteExtensionService(ctx context.Context, extensionID string, serviceID string) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
err = extension.DeleteService(serviceID)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// AddExtensionServiceEndpoint adds a new endpoint to an extension service.
func (es *ExtensionStore) AddExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpoint *domain.ExtensionServiceEndpoint) (*domain.ExtensionServiceEndpoint, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
endpoint, err = svc.AddEndpoint(endpoint)
if err != nil {
return nil, err
}
err = es.UpdateExtension(ctx, extension)
if err != nil {
return nil, err
}
return endpoint, nil
}

// GetExtensionServiceEndpoint retrieves an extension endpoint by its ID.
func (es *ExtensionStore) GetExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) (*domain.ExtensionServiceEndpoint, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
return svc.GetEndpoint(endpointID)
}

// ListExtensionServiceEndpoints retrieves all endpoints belonging to an extension service.
func (es *ExtensionStore) ListExtensionServiceEndpoints(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceEndpoint, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
return svc.ListEndpoints(), nil
}

// UpdateExtensionServiceEndpoint updates an endpoint belonging to an extension service.
func (es *ExtensionStore) UpdateExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, newEndpoint *domain.ExtensionServiceEndpoint) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return err
}
err = svc.UpdateEndpoint(newEndpoint)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// DeleteExtensionServiceEndpoint deletes an extension endpoint from an extension service.
func (es *ExtensionStore) DeleteExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return err
}
err = svc.DeleteEndpoint(endpointID)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// AddExtensionServiceCredentials adds a new credential to an extension service.
func (es *ExtensionStore) AddExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentials *domain.ExtensionServiceCredentials) (*domain.ExtensionServiceCredentials, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
credentials, err = svc.AddCredentials(credentials)
if err != nil {
return nil, err
}
err = es.UpdateExtension(ctx, extension)
if err != nil {
return nil, err
}
return credentials, nil
}

// GetExtensionServiceCredentials retrieves an extension credential by its ID.
func (es *ExtensionStore) GetExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) (*domain.ExtensionServiceCredentials, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
return svc.GetCredentials(credentialsID)
}

// ListExtensionServiceCredentials retrieves all credentials belonging to an extension service.
func (es *ExtensionStore) ListExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceCredentials, error) {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return nil, err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return nil, err
}
return svc.ListCredentials(), nil
}

// UpdateExtensionServiceCredentials updates an extension credential.
func (es *ExtensionStore) UpdateExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, newCredentials *domain.ExtensionServiceCredentials) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return err
}
err = svc.UpdateCredentials(newCredentials)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// DeleteExtensionServiceCredentials deletes an extension credential from an extension service.
func (es *ExtensionStore) DeleteExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) error {
extension, err := es.GetExtension(ctx, extensionID)
if err != nil {
return err
}
svc, err := extension.GetService(serviceID)
if err != nil {
return err
}
err = svc.DeleteCredentials(credentialsID)
if err != nil {
return err
}
return es.UpdateExtension(ctx, extension)
}

// GetExtensionAccessDescriptors retrieves access descriptors belonging to an extension that matches the query.
func (es *ExtensionStore) GetExtensionAccessDescriptors(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.ExtensionAccessDescriptor, err error) {
result = make([]*domain.ExtensionAccessDescriptor, 0)

for _, extension := range es.ListExtensions(ctx, query) {
result = append(result, extension.GetAccessDescriptors()...)
}
return result, nil
}

0 comments on commit d04c506

Please sign in to comment.