diff --git a/go.mod b/go.mod index b42085d07..eb30f352b 100644 --- a/go.mod +++ b/go.mod @@ -14,9 +14,9 @@ require ( github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible github.com/flyteorg/flyteidl v1.2.5 - github.com/flyteorg/flyteplugins v1.0.18 - github.com/flyteorg/flytepropeller v1.1.47 - github.com/flyteorg/flytestdlib v1.0.12 + github.com/flyteorg/flyteplugins v1.0.20 + github.com/flyteorg/flytepropeller v1.1.51 + github.com/flyteorg/flytestdlib v1.0.14 github.com/flyteorg/stow v0.3.6 github.com/ghodss/yaml v1.0.0 github.com/go-gormigrate/gormigrate/v2 v2.0.0 diff --git a/go.sum b/go.sum index 4bd8bb2f9..5d9e68acf 100644 --- a/go.sum +++ b/go.sum @@ -354,13 +354,13 @@ github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8S github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flyteorg/flyteidl v1.2.5 h1:oPs0PX9opR9JtWjP5ZH2YMChkbGGL45PIy+90FlaxYc= github.com/flyteorg/flyteidl v1.2.5/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU= -github.com/flyteorg/flyteplugins v1.0.18 h1:DOyxAFaS4luv7H9XRKUpHbO09imsG4LP8Du515FGXyM= -github.com/flyteorg/flyteplugins v1.0.18/go.mod h1:ZbZVBxEWh8Icj1AgfNKg0uPzHHGd9twa4eWcY2Yt6xE= -github.com/flyteorg/flytepropeller v1.1.47 h1:k+moR+YGOyKJnYHDZjBBXvwnuZJ7IhK/PRv/9Ak/QIs= -github.com/flyteorg/flytepropeller v1.1.47/go.mod h1:vZlQTBOsddrNGxmA0To+B2ld3VFg6sRWwcC4KU7+g9A= +github.com/flyteorg/flyteplugins v1.0.20 h1:8ZGN2c0iaZa3d/UmN2VYozLBRhthAIO48aD5g8Wly7s= +github.com/flyteorg/flyteplugins v1.0.20/go.mod h1:ZbZVBxEWh8Icj1AgfNKg0uPzHHGd9twa4eWcY2Yt6xE= +github.com/flyteorg/flytepropeller v1.1.51 h1:ITPH2Fqx+/1hKBFnfb6Rawws3VbEJ3tQ/1tQXSIXvcQ= +github.com/flyteorg/flytepropeller v1.1.51/go.mod h1:zstMUz30mIskZB4uMkObzOj3CjsGfXIV/+nVxlOmI7I= github.com/flyteorg/flytestdlib v1.0.0/go.mod h1:QSVN5wIM1lM9d60eAEbX7NwweQXW96t5x4jbyftn89c= -github.com/flyteorg/flytestdlib v1.0.12 h1:A+yN5TX/SezjCjzv/JV29SzlBAyKGeLDOfAiYqzrKcw= -github.com/flyteorg/flytestdlib v1.0.12/go.mod h1:nIBmBHtjTJvhZEn3e/EwVC/iMkR2tUX8hEiXjRBpH/s= +github.com/flyteorg/flytestdlib v1.0.14 h1:P6hy9yVrIEUxp4JaxV7/KwTSTYjHGizQu1fKXYkq9Y8= +github.com/flyteorg/flytestdlib v1.0.14/go.mod h1:nIBmBHtjTJvhZEn3e/EwVC/iMkR2tUX8hEiXjRBpH/s= github.com/flyteorg/stow v0.3.3/go.mod h1:HBld7ud0i4khMHwJjkO8v+NSP7ddKa/ruhf4I8fliaA= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= github.com/flyteorg/stow v0.3.6/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo= diff --git a/pkg/common/entity.go b/pkg/common/entity.go index 10e64decd..65d3161cf 100644 --- a/pkg/common/entity.go +++ b/pkg/common/entity.go @@ -17,6 +17,7 @@ const ( NamedEntity = "nen" NamedEntityMetadata = "nem" Project = "p" + Signal = "s" ) // ResourceTypeToEntity maps a resource type to an entity suitable for use with Database filters diff --git a/pkg/manager/impl/signal_manager.go b/pkg/manager/impl/signal_manager.go new file mode 100644 index 000000000..df2fbcc7b --- /dev/null +++ b/pkg/manager/impl/signal_manager.go @@ -0,0 +1,160 @@ +package impl + +import ( + "context" + "strconv" + + "github.com/flyteorg/flytestdlib/contextutils" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/validation" + "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + repoInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + + "google.golang.org/grpc/codes" +) + +type signalMetrics struct { + Scope promutils.Scope + Set labeled.Counter +} + +type SignalManager struct { + db repoInterfaces.Repository + metrics signalMetrics +} + +func getSignalContext(ctx context.Context, identifier *core.SignalIdentifier) context.Context { + ctx = contextutils.WithProjectDomain(ctx, identifier.ExecutionId.Project, identifier.ExecutionId.Domain) + ctx = contextutils.WithWorkflowID(ctx, identifier.ExecutionId.Name) + return contextutils.WithSignalID(ctx, identifier.SignalId) +} + +func (s *SignalManager) GetOrCreateSignal(ctx context.Context, request admin.SignalGetOrCreateRequest) (*admin.Signal, error) { + if err := validation.ValidateSignalGetOrCreateRequest(ctx, request); err != nil { + logger.Debugf(ctx, "invalid request [%+v]: %v", request, err) + return nil, err + } + ctx = getSignalContext(ctx, request.Id) + + signalModel, err := transformers.CreateSignalModel(request.Id, request.Type, nil) + if err != nil { + logger.Errorf(ctx, "Failed to transform signal with id [%+v] and type [+%v] with err: %v", request.Id, request.Type, err) + return nil, err + } + + err = s.db.SignalRepo().GetOrCreate(ctx, &signalModel) + if err != nil { + return nil, err + } + + signal, err := transformers.FromSignalModel(signalModel) + if err != nil { + logger.Errorf(ctx, "Failed to transform signal model [%+v] with err: %v", signalModel, err) + return nil, err + } + + return &signal, nil +} + +func (s *SignalManager) ListSignals(ctx context.Context, request admin.SignalListRequest) (*admin.SignalList, error) { + if err := validation.ValidateSignalListRequest(ctx, request); err != nil { + logger.Debugf(ctx, "ListSignals request [%+v] is invalid: %v", request, err) + return nil, err + } + ctx = getExecutionContext(ctx, request.WorkflowExecutionId) + + identifierFilters, err := util.GetWorkflowExecutionIdentifierFilters(ctx, *request.WorkflowExecutionId) + if err != nil { + return nil, err + } + + filters, err := util.AddRequestFilters(request.Filters, common.Signal, identifierFilters) + if err != nil { + return nil, err + } + var sortParameter common.SortParameter + if request.SortBy != nil { + sortParameter, err = common.NewSortParameter(*request.SortBy) + if err != nil { + return nil, err + } + } + + offset, err := validation.ValidateToken(request.Token) + if err != nil { + return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "invalid pagination token %s for ListSignals", request.Token) + } + + signalModelList, err := s.db.SignalRepo().List(ctx, repoInterfaces.ListResourceInput{ + InlineFilters: filters, + Offset: offset, + Limit: int(request.Limit), + SortParameter: sortParameter, + }) + if err != nil { + logger.Debugf(ctx, "Failed to list signals with request [%+v] with err %v", + request, err) + return nil, err + } + + signalList, err := transformers.FromSignalModels(signalModelList) + if err != nil { + logger.Debugf(ctx, "failed to transform signal models for request [%+v] with err: %v", request, err) + return nil, err + } + var token string + if len(signalList) == int(request.Limit) { + token = strconv.Itoa(offset + len(signalList)) + } + return &admin.SignalList{ + Signals: signalList, + Token: token, + }, nil +} + +func (s *SignalManager) SetSignal(ctx context.Context, request admin.SignalSetRequest) (*admin.SignalSetResponse, error) { + if err := validation.ValidateSignalSetRequest(ctx, s.db, request); err != nil { + return nil, err + } + ctx = getSignalContext(ctx, request.Id) + + signalModel, err := transformers.CreateSignalModel(request.Id, nil, request.Value) + if err != nil { + logger.Errorf(ctx, "Failed to transform signal with id [%+v] and value [+%v] with err: %v", request.Id, request.Value, err) + return nil, err + } + + err = s.db.SignalRepo().Update(ctx, signalModel.SignalKey, signalModel.Value) + if err != nil { + return nil, err + } + + s.metrics.Set.Inc(ctx) + return &admin.SignalSetResponse{}, nil +} + +func NewSignalManager( + db repoInterfaces.Repository, + scope promutils.Scope) interfaces.SignalInterface { + metrics := signalMetrics{ + Scope: scope, + Set: labeled.NewCounter("num_set", "count of set signals", scope), + } + + return &SignalManager{ + db: db, + metrics: metrics, + } +} diff --git a/pkg/manager/impl/signal_manager_test.go b/pkg/manager/impl/signal_manager_test.go new file mode 100644 index 000000000..cc01f07fe --- /dev/null +++ b/pkg/manager/impl/signal_manager_test.go @@ -0,0 +1,241 @@ +package impl + +import ( + "context" + "errors" + "testing" + + repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + mockScope "github.com/flyteorg/flytestdlib/promutils" + + "github.com/golang/protobuf/proto" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + signalID = &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + } + + signalType = &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + } + + signalValue = &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + } +) + +func TestGetOrCreateSignal(t *testing.T) { + t.Run("Happy", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface).OnGetOrCreateMatch(mock.Anything, mock.Anything).Return(nil) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalGetOrCreateRequest{ + Id: signalID, + Type: signalType, + } + + response, err := signalManager.GetOrCreateSignal(context.Background(), request) + assert.NoError(t, err) + + assert.True(t, proto.Equal(&admin.Signal{ + Id: signalID, + Type: signalType, + }, response)) + }) + + t.Run("ValidationError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalGetOrCreateRequest{ + Type: signalType, + } + + _, err := signalManager.GetOrCreateSignal(context.Background(), request) + assert.Error(t, err) + }) + + t.Run("DBError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface).OnGetOrCreateMatch(mock.Anything, mock.Anything).Return(errors.New("foo")) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalGetOrCreateRequest{ + Id: signalID, + Type: signalType, + } + + _, err := signalManager.GetOrCreateSignal(context.Background(), request) + assert.Error(t, err) + }) +} + +func TestListSignals(t *testing.T) { + signalModel, err := transformers.CreateSignalModel(signalID, signalType, nil) + assert.NoError(t, err) + + t.Run("Happy", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnListMatch(mock.Anything, mock.Anything).Return( + []models.Signal{signalModel}, + nil, + ) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Limit: 20, + } + + response, err := signalManager.ListSignals(context.Background(), request) + assert.NoError(t, err) + + assert.True(t, proto.Equal( + &admin.SignalList{ + Signals: []*admin.Signal{ + &admin.Signal{ + Id: signalID, + Type: signalType, + }, + }, + }, + response, + )) + }) + + t.Run("ValidationError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + } + + _, err := signalManager.ListSignals(context.Background(), request) + assert.Error(t, err) + }) + + t.Run("DBError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnListMatch(mock.Anything, mock.Anything).Return(nil, errors.New("foo")) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Limit: 20, + } + + _, err := signalManager.ListSignals(context.Background(), request) + assert.Error(t, err) + }) +} + +func TestSetSignal(t *testing.T) { + signalModel, err := transformers.CreateSignalModel(signalID, signalType, nil) + assert.NoError(t, err) + + t.Run("Happy", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything, mock.Anything).Return(signalModel, nil) + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnUpdateMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalSetRequest{ + Id: signalID, + Value: signalValue, + } + + response, err := signalManager.SetSignal(context.Background(), request) + assert.NoError(t, err) + + assert.True(t, proto.Equal(&admin.SignalSetResponse{}, response)) + }) + + t.Run("ValidationError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalSetRequest{ + Value: signalValue, + } + + _, err := signalManager.SetSignal(context.Background(), request) + assert.Error(t, err) + }) + + t.Run("DBGetError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything).Return( + models.Signal{}, + errors.New("foo"), + ) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalSetRequest{ + Id: signalID, + Value: signalValue, + } + + _, err := signalManager.SetSignal(context.Background(), request) + assert.Error(t, err) + }) + + t.Run("DBUpdateError", func(t *testing.T) { + mockRepository := repositoryMocks.NewMockRepository() + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything).Return(signalModel, nil) + mockRepository.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnUpdateMatch(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("foo")) + + signalManager := NewSignalManager(mockRepository, mockScope.NewTestScope()) + request := admin.SignalSetRequest{ + Id: signalID, + Value: signalValue, + } + + _, err := signalManager.SetSignal(context.Background(), request) + assert.Error(t, err) + }) +} diff --git a/pkg/manager/impl/util/filters.go b/pkg/manager/impl/util/filters.go index a21404960..e52bfb8b1 100644 --- a/pkg/manager/impl/util/filters.go +++ b/pkg/manager/impl/util/filters.go @@ -61,6 +61,7 @@ var filterFieldEntityPrefix = map[string]common.Entity{ "entities": common.NamedEntity, "named_entity_metadata": common.NamedEntityMetadata, "project": common.Project, + "signal": common.Signal, } func parseField(field string, primaryEntity common.Entity) (common.Entity, string) { diff --git a/pkg/manager/impl/validation/signal_validator.go b/pkg/manager/impl/validation/signal_validator.go new file mode 100644 index 000000000..11c5b335d --- /dev/null +++ b/pkg/manager/impl/validation/signal_validator.go @@ -0,0 +1,88 @@ +package validation + +import ( + "context" + + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" + repositoryInterfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/transformers" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + propellervalidators "github.com/flyteorg/flytepropeller/pkg/compiler/validators" + + "google.golang.org/grpc/codes" +) + +func ValidateSignalGetOrCreateRequest(ctx context.Context, request admin.SignalGetOrCreateRequest) error { + if request.Id == nil { + return shared.GetMissingArgumentError("id") + } + if err := ValidateSignalIdentifier(*request.Id); err != nil { + return err + } + if request.Type == nil { + return shared.GetMissingArgumentError("type") + } + + return nil +} + +func ValidateSignalIdentifier(identifier core.SignalIdentifier) error { + if identifier.ExecutionId == nil { + return shared.GetMissingArgumentError(shared.ExecutionID) + } + if identifier.SignalId == "" { + return shared.GetMissingArgumentError("signal_id") + } + + return ValidateWorkflowExecutionIdentifier(identifier.ExecutionId) +} + +func ValidateSignalListRequest(ctx context.Context, request admin.SignalListRequest) error { + if err := ValidateWorkflowExecutionIdentifier(request.WorkflowExecutionId); err != nil { + return shared.GetMissingArgumentError(shared.ExecutionID) + } + if err := ValidateLimit(request.Limit); err != nil { + return err + } + return nil +} + +func ValidateSignalSetRequest(ctx context.Context, db repositoryInterfaces.Repository, request admin.SignalSetRequest) error { + if request.Id == nil { + return shared.GetMissingArgumentError("id") + } + if err := ValidateSignalIdentifier(*request.Id); err != nil { + return err + } + if request.Value == nil { + return shared.GetMissingArgumentError("value") + } + + // validate that signal value matches type of existing signal + signalModel, err := transformers.CreateSignalModel(request.Id, nil, nil) + if err != nil { + return nil + } + lookupSignalModel, err := db.SignalRepo().Get(ctx, signalModel.SignalKey) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "failed to validate that signal [%v] exists, err: [%+v]", + signalModel.SignalKey, err) + } + valueType := propellervalidators.LiteralTypeForLiteral(request.Value) + lookupSignal, err := transformers.FromSignalModel(lookupSignalModel) + if err != nil { + return err + } + if !propellervalidators.AreTypesCastable(lookupSignal.Type, valueType) { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "requested signal value [%v] is not castable to existing signal type [%v]", + request.Value, lookupSignalModel.Type) + } + + return nil +} diff --git a/pkg/manager/impl/validation/signal_validator_test.go b/pkg/manager/impl/validation/signal_validator_test.go new file mode 100644 index 000000000..331da688c --- /dev/null +++ b/pkg/manager/impl/validation/signal_validator_test.go @@ -0,0 +1,287 @@ +package validation + +import ( + "context" + "errors" + "testing" + + repositoryMocks "github.com/flyteorg/flyteadmin/pkg/repositories/mocks" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/golang/protobuf/proto" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestValidateSignalGetOrCreateRequest(t *testing.T) { + ctx := context.TODO() + + t.Run("Happy", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + assert.NoError(t, ValidateSignalGetOrCreateRequest(ctx, request)) + }) + + t.Run("MissingSignalIdentifier", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + assert.EqualError(t, ValidateSignalGetOrCreateRequest(ctx, request), "missing id") + }) + + t.Run("InvalidSignalIdentifier", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + assert.EqualError(t, ValidateSignalGetOrCreateRequest(ctx, request), "missing signal_id") + }) + + t.Run("MissingExecutionIdentifier", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + SignalId: "signal", + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + assert.EqualError(t, ValidateSignalGetOrCreateRequest(ctx, request), "missing execution_id") + }) + + t.Run("InvalidExecutionIdentifier", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + assert.EqualError(t, ValidateSignalGetOrCreateRequest(ctx, request), "missing project") + }) + + t.Run("MissingType", func(t *testing.T) { + request := admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + } + assert.EqualError(t, ValidateSignalGetOrCreateRequest(ctx, request), "missing type") + }) +} + +func TestValidateSignalListrequest(t *testing.T) { + ctx := context.TODO() + + t.Run("Happy", func(t *testing.T) { + request := admin.SignalListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Limit: 20, + } + assert.NoError(t, ValidateSignalListRequest(ctx, request)) + }) + + t.Run("MissingWorkflowExecutionIdentifier", func(t *testing.T) { + request := admin.SignalListRequest{ + Limit: 20, + } + assert.EqualError(t, ValidateSignalListRequest(ctx, request), "missing execution_id") + }) + + t.Run("MissingLimit", func(t *testing.T) { + request := admin.SignalListRequest{ + WorkflowExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + } + assert.EqualError(t, ValidateSignalListRequest(ctx, request), "invalid value for limit") + }) +} + +func TestValidateSignalUpdateRequest(t *testing.T) { + ctx := context.TODO() + + booleanType := &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + } + typeBytes, _ := proto.Marshal(booleanType) + + repo := repositoryMocks.NewMockRepository() + repo.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything).Return( + models.Signal{ + Type: typeBytes, + }, + nil, + ) + + t.Run("Happy", func(t *testing.T) { + request := admin.SignalSetRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + }, + } + assert.NoError(t, ValidateSignalSetRequest(ctx, repo, request)) + }) + + t.Run("MissingValue", func(t *testing.T) { + request := admin.SignalSetRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + } + assert.EqualError(t, ValidateSignalSetRequest(ctx, repo, request), "missing value") + }) + + t.Run("MissingSignal", func(t *testing.T) { + repo := repositoryMocks.NewMockRepository() + repo.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything).Return(models.Signal{}, errors.New("foo")) + + request := admin.SignalSetRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + }, + } + assert.EqualError(t, ValidateSignalSetRequest(ctx, repo, request), + "failed to validate that signal [{{project domain name} signal}] exists, err: [foo]") + }) + + t.Run("InvalidType", func(t *testing.T) { + integerType := &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + } + typeBytes, _ := proto.Marshal(integerType) + + repo := repositoryMocks.NewMockRepository() + repo.SignalRepo().(*repositoryMocks.SignalRepoInterface). + OnGetMatch(mock.Anything, mock.Anything).Return( + models.Signal{ + Type: typeBytes, + }, + nil, + ) + + request := admin.SignalSetRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + }, + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + }, + } + assert.EqualError(t, ValidateSignalSetRequest(ctx, repo, request), + "requested signal value [scalar: > ] is not castable to existing signal type [[8 1]]") + }) +} diff --git a/pkg/manager/interfaces/signal.go b/pkg/manager/interfaces/signal.go new file mode 100644 index 000000000..0547e439d --- /dev/null +++ b/pkg/manager/interfaces/signal.go @@ -0,0 +1,16 @@ +package interfaces + +import ( + "context" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +) + +//go:generate mockery -name=SignalInterface -output=../mocks -case=underscore + +// Interface for managing Flyte Signals +type SignalInterface interface { + GetOrCreateSignal(ctx context.Context, request admin.SignalGetOrCreateRequest) (*admin.Signal, error) + ListSignals(ctx context.Context, request admin.SignalListRequest) (*admin.SignalList, error) + SetSignal(ctx context.Context, request admin.SignalSetRequest) (*admin.SignalSetResponse, error) +} diff --git a/pkg/manager/mocks/signal_interface.go b/pkg/manager/mocks/signal_interface.go new file mode 100644 index 000000000..51e8b6636 --- /dev/null +++ b/pkg/manager/mocks/signal_interface.go @@ -0,0 +1,139 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + mock "github.com/stretchr/testify/mock" +) + +// SignalInterface is an autogenerated mock type for the SignalInterface type +type SignalInterface struct { + mock.Mock +} + +type SignalInterface_GetOrCreateSignal struct { + *mock.Call +} + +func (_m SignalInterface_GetOrCreateSignal) Return(_a0 *admin.Signal, _a1 error) *SignalInterface_GetOrCreateSignal { + return &SignalInterface_GetOrCreateSignal{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalInterface) OnGetOrCreateSignal(ctx context.Context, request admin.SignalGetOrCreateRequest) *SignalInterface_GetOrCreateSignal { + c_call := _m.On("GetOrCreateSignal", ctx, request) + return &SignalInterface_GetOrCreateSignal{Call: c_call} +} + +func (_m *SignalInterface) OnGetOrCreateSignalMatch(matchers ...interface{}) *SignalInterface_GetOrCreateSignal { + c_call := _m.On("GetOrCreateSignal", matchers...) + return &SignalInterface_GetOrCreateSignal{Call: c_call} +} + +// GetOrCreateSignal provides a mock function with given fields: ctx, request +func (_m *SignalInterface) GetOrCreateSignal(ctx context.Context, request admin.SignalGetOrCreateRequest) (*admin.Signal, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.Signal + if rf, ok := ret.Get(0).(func(context.Context, admin.SignalGetOrCreateRequest) *admin.Signal); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.Signal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, admin.SignalGetOrCreateRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalInterface_ListSignals struct { + *mock.Call +} + +func (_m SignalInterface_ListSignals) Return(_a0 *admin.SignalList, _a1 error) *SignalInterface_ListSignals { + return &SignalInterface_ListSignals{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalInterface) OnListSignals(ctx context.Context, request admin.SignalListRequest) *SignalInterface_ListSignals { + c_call := _m.On("ListSignals", ctx, request) + return &SignalInterface_ListSignals{Call: c_call} +} + +func (_m *SignalInterface) OnListSignalsMatch(matchers ...interface{}) *SignalInterface_ListSignals { + c_call := _m.On("ListSignals", matchers...) + return &SignalInterface_ListSignals{Call: c_call} +} + +// ListSignals provides a mock function with given fields: ctx, request +func (_m *SignalInterface) ListSignals(ctx context.Context, request admin.SignalListRequest) (*admin.SignalList, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.SignalList + if rf, ok := ret.Get(0).(func(context.Context, admin.SignalListRequest) *admin.SignalList); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.SignalList) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, admin.SignalListRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalInterface_SetSignal struct { + *mock.Call +} + +func (_m SignalInterface_SetSignal) Return(_a0 *admin.SignalSetResponse, _a1 error) *SignalInterface_SetSignal { + return &SignalInterface_SetSignal{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalInterface) OnSetSignal(ctx context.Context, request admin.SignalSetRequest) *SignalInterface_SetSignal { + c_call := _m.On("SetSignal", ctx, request) + return &SignalInterface_SetSignal{Call: c_call} +} + +func (_m *SignalInterface) OnSetSignalMatch(matchers ...interface{}) *SignalInterface_SetSignal { + c_call := _m.On("SetSignal", matchers...) + return &SignalInterface_SetSignal{Call: c_call} +} + +// SetSignal provides a mock function with given fields: ctx, request +func (_m *SignalInterface) SetSignal(ctx context.Context, request admin.SignalSetRequest) (*admin.SignalSetResponse, error) { + ret := _m.Called(ctx, request) + + var r0 *admin.SignalSetResponse + if rf, ok := ret.Get(0).(func(context.Context, admin.SignalSetRequest) *admin.SignalSetResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.SignalSetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, admin.SignalSetRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/repositories/config/migrations.go b/pkg/repositories/config/migrations.go index 301fe5a22..e5ad84ba2 100644 --- a/pkg/repositories/config/migrations.go +++ b/pkg/repositories/config/migrations.go @@ -390,6 +390,16 @@ var Migrations = []*gormigrate.Migration{ return tx.Model(&models.Execution{}).Migrator().DropIndex(&models.Execution{}, "idx_executions_created_at") }, }, + // Create signals table. + { + ID: "2022-04-11-signals", + Migrate: func(tx *gorm.DB) error { + return tx.AutoMigrate(&models.Signal{}) + }, + Rollback: func(tx *gorm.DB) error { + return tx.Migrator().DropTable("signals") + }, + }, // Add the launch_type resource to the execution model { ID: "2022-12-09-execution-launch-type", diff --git a/pkg/repositories/gorm_repo.go b/pkg/repositories/gorm_repo.go index a23852446..c5af46980 100644 --- a/pkg/repositories/gorm_repo.go +++ b/pkg/repositories/gorm_repo.go @@ -25,6 +25,7 @@ type GormRepo struct { resourceRepo interfaces.ResourceRepoInterface schedulableEntityRepo schedulerInterfaces.SchedulableEntityRepoInterface scheduleEntitiesSnapshotRepo schedulerInterfaces.ScheduleEntitiesSnapShotRepoInterface + signalRepo interfaces.SignalRepoInterface } func (r *GormRepo) ExecutionRepo() interfaces.ExecutionRepoInterface { @@ -79,6 +80,10 @@ func (r *GormRepo) ScheduleEntitiesSnapshotRepo() schedulerInterfaces.ScheduleEn return r.scheduleEntitiesSnapshotRepo } +func (r *GormRepo) SignalRepo() interfaces.SignalRepoInterface { + return r.signalRepo +} + func (r *GormRepo) GetGormDB() *gorm.DB { return r.db } @@ -99,5 +104,6 @@ func NewGormRepo(db *gorm.DB, errorTransformer errors.ErrorTransformer, scope pr resourceRepo: gormimpl.NewResourceRepo(db, errorTransformer, scope.NewSubScope("resources")), schedulableEntityRepo: schedulerGormImpl.NewSchedulableEntityRepo(db, errorTransformer, scope.NewSubScope("schedulable_entity")), scheduleEntitiesSnapshotRepo: schedulerGormImpl.NewScheduleEntitiesSnapshotRepo(db, errorTransformer, scope.NewSubScope("schedule_entities_snapshot")), + signalRepo: gormimpl.NewSignalRepo(db, errorTransformer, scope.NewSubScope("signals")), } } diff --git a/pkg/repositories/gormimpl/common.go b/pkg/repositories/gormimpl/common.go index 6a9d09703..1e016cf1d 100644 --- a/pkg/repositories/gormimpl/common.go +++ b/pkg/repositories/gormimpl/common.go @@ -41,6 +41,7 @@ var entityToTableName = map[common.Entity]string{ common.Workflow: "workflows", common.NamedEntity: "entities", common.NamedEntityMetadata: "named_entity_metadata", + common.Signal: "signals", } var innerJoinExecToNodeExec = fmt.Sprintf( diff --git a/pkg/repositories/gormimpl/signal_repo.go b/pkg/repositories/gormimpl/signal_repo.go new file mode 100644 index 000000000..b87f70316 --- /dev/null +++ b/pkg/repositories/gormimpl/signal_repo.go @@ -0,0 +1,110 @@ +package gormimpl + +import ( + "context" + "errors" + + adminerrors "github.com/flyteorg/flyteadmin/pkg/errors" + flyteAdminDbErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + + "github.com/flyteorg/flytestdlib/promutils" + + "google.golang.org/grpc/codes" + + "gorm.io/gorm" +) + +// SignalRepo is an implementation of SignalRepoInterface. +type SignalRepo struct { + db *gorm.DB + errorTransformer flyteAdminDbErrors.ErrorTransformer + metrics gormMetrics +} + +// Get retrieves a signal model from the database store. +func (s *SignalRepo) Get(ctx context.Context, input models.SignalKey) (models.Signal, error) { + var signal models.Signal + timer := s.metrics.GetDuration.Start() + tx := s.db.Where(&models.Signal{ + SignalKey: input, + }).Take(&signal) + timer.Stop() + if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + return models.Signal{}, adminerrors.NewFlyteAdminError(codes.NotFound, "signal does not exist") + } + if tx.Error != nil { + return models.Signal{}, s.errorTransformer.ToFlyteAdminError(tx.Error) + } + return signal, nil +} + +// GetOrCreate returns a signal if it already exists, if not it creates a new one given the input +func (s *SignalRepo) GetOrCreate(ctx context.Context, input *models.Signal) error { + timer := s.metrics.CreateDuration.Start() + tx := s.db.FirstOrCreate(&input, input) + timer.Stop() + if tx.Error != nil { + return s.errorTransformer.ToFlyteAdminError(tx.Error) + } + return nil +} + +// List fetches all signals that match the provided input +func (s *SignalRepo) List(ctx context.Context, input interfaces.ListResourceInput) ([]models.Signal, error) { + // First validate input. + if err := ValidateListInput(input); err != nil { + return nil, err + } + var signals []models.Signal + tx := s.db.Limit(input.Limit).Offset(input.Offset) + + // Apply filters + tx, err := applyFilters(tx, input.InlineFilters, input.MapFilters) + if err != nil { + return nil, err + } + // Apply sort ordering. + if input.SortParameter != nil { + tx = tx.Order(input.SortParameter.GetGormOrderExpr()) + } + timer := s.metrics.ListDuration.Start() + tx.Find(&signals) + timer.Stop() + if tx.Error != nil { + return nil, s.errorTransformer.ToFlyteAdminError(tx.Error) + } + + return signals, nil +} + +// Update sets the value field on the specified signal model +func (s *SignalRepo) Update(ctx context.Context, input models.SignalKey, value []byte) error { + signal := models.Signal{ + SignalKey: input, + Value: value, + } + + timer := s.metrics.GetDuration.Start() + tx := s.db.Model(&signal).Select("value").Updates(signal) + timer.Stop() + if tx.Error != nil { + return s.errorTransformer.ToFlyteAdminError(tx.Error) + } + if tx.RowsAffected == 0 { + return adminerrors.NewFlyteAdminError(codes.NotFound, "signal does not exist") + } + return nil +} + +// Returns an instance of SignalRepoInterface +func NewSignalRepo( + db *gorm.DB, errorTransformer flyteAdminDbErrors.ErrorTransformer, scope promutils.Scope) interfaces.SignalRepoInterface { + metrics := newMetrics(scope) + return &SignalRepo{ + db: db, + errorTransformer: errorTransformer, + metrics: metrics, + } +} diff --git a/pkg/repositories/gormimpl/signal_repo_test.go b/pkg/repositories/gormimpl/signal_repo_test.go new file mode 100644 index 000000000..157af6112 --- /dev/null +++ b/pkg/repositories/gormimpl/signal_repo_test.go @@ -0,0 +1,177 @@ +package gormimpl + +import ( + "context" + "reflect" + "testing" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + + mockScope "github.com/flyteorg/flytestdlib/promutils" + + mocket "github.com/Selvatico/go-mocket" + + "github.com/stretchr/testify/assert" +) + +var ( + signalModel = &models.Signal{ + SignalKey: models.SignalKey{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalID: "signal", + }, + Type: []byte{1, 2}, + Value: []byte{3, 4}, + } +) + +func toSignalMap(signalModel models.Signal) map[string]interface{} { + signal := make(map[string]interface{}) + signal["created_at"] = signalModel.CreatedAt + signal["updated_at"] = signalModel.UpdatedAt + signal["execution_project"] = signalModel.Project + signal["execution_domain"] = signalModel.Domain + signal["execution_name"] = signalModel.Name + signal["signal_id"] = signalModel.SignalID + if signalModel.Type != nil { + signal["type"] = signalModel.Type + } + if signalModel.Value != nil { + signal["value"] = signalModel.Value + } + + return signal +} + +func TestGetSignal(t *testing.T) { + ctx := context.Background() + + signalRepo := NewSignalRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + mockSelectQuery := GlobalMock.NewMock() + mockSelectQuery.WithQuery( + `SELECT * FROM "signals" WHERE "signals"."execution_project" = $1 AND "signals"."execution_domain" = $2 AND "signals"."execution_name" = $3 AND "signals"."signal_id" = $4 LIMIT 1`) + + // retrieve non-existent signalModel + lookupSignalModel, err := signalRepo.Get(ctx, signalModel.SignalKey) + assert.Error(t, err) + assert.Empty(t, lookupSignalModel) + + assert.True(t, mockSelectQuery.Triggered) + mockSelectQuery.Triggered = false // reset to false for second call + + // retrieve existent signalModel + signalModels := []map[string]interface{}{toSignalMap(*signalModel)} + mockSelectQuery.WithReply(signalModels) + + lookupSignalModel, err = signalRepo.Get(ctx, signalModel.SignalKey) + assert.NoError(t, err) + assert.True(t, reflect.DeepEqual(*signalModel, lookupSignalModel)) + + assert.True(t, mockSelectQuery.Triggered) +} + +func TestGetOrCreateSignal(t *testing.T) { + ctx := context.Background() + + signalRepo := NewSignalRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // create initial signalModel + mockInsertQuery := GlobalMock.NewMock() + mockInsertQuery.WithQuery( + `INSERT INTO "signals" ("id","created_at","updated_at","deleted_at","execution_project","execution_domain","execution_name","signal_id","type","value") VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)`) + + err := signalRepo.GetOrCreate(ctx, signalModel) + assert.NoError(t, err) + + assert.True(t, mockInsertQuery.Triggered) + mockInsertQuery.Triggered = false // reset to false for second call + + // initialize query mocks + signalModels := []map[string]interface{}{toSignalMap(*signalModel)} + mockSelectQuery := GlobalMock.NewMock() + mockSelectQuery.WithQuery( + `SELECT * FROM "signals" WHERE "signals"."created_at" = $1 AND "signals"."updated_at" = $2 AND "signals"."execution_project" = $3 AND "signals"."execution_domain" = $4 AND "signals"."execution_name" = $5 AND "signals"."signal_id" = $6 AND "signals"."execution_project" = $7 AND "signals"."execution_domain" = $8 AND "signals"."execution_name" = $9 AND "signals"."signal_id" = $10 ORDER BY "signals"."id" LIMIT 1`).WithReply(signalModels) + + // retrieve existing signalModel + lookupSignalModel := &models.Signal{} + *lookupSignalModel = *signalModel + lookupSignalModel.Type = nil + lookupSignalModel.Value = nil + + err = signalRepo.GetOrCreate(ctx, lookupSignalModel) + assert.NoError(t, err) + assert.True(t, reflect.DeepEqual(signalModel, lookupSignalModel)) + + assert.True(t, mockSelectQuery.Triggered) + assert.False(t, mockInsertQuery.Triggered) +} + +func TestListSignals(t *testing.T) { + ctx := context.Background() + + signalRepo := NewSignalRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // read all signal models + signalModels := []map[string]interface{}{toSignalMap(*signalModel)} + mockSelectQuery := GlobalMock.NewMock() + mockSelectQuery.WithQuery( + `SELECT * FROM "signals" WHERE project = $1 AND domain = $2 AND name = $3 LIMIT 20`).WithReply(signalModels) + + signals, err := signalRepo.List(ctx, interfaces.ListResourceInput{ + InlineFilters: []common.InlineFilter{ + getEqualityFilter(common.Signal, "project", project), + getEqualityFilter(common.Signal, "domain", domain), + getEqualityFilter(common.Signal, "name", name), + }, + Limit: 20, + }) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual([]models.Signal{*signalModel}, signals)) + assert.True(t, mockSelectQuery.Triggered) +} + +func TestUpdateSignal(t *testing.T) { + ctx := context.Background() + + signalRepo := NewSignalRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + // update signalModel does not exits + mockUpdateQuery := GlobalMock.NewMock() + mockUpdateQuery.WithQuery( + `UPDATE "signals" SET "updated_at"=$1,"value"=$2 WHERE "execution_project" = $3 AND "execution_domain" = $4 AND "execution_name" = $5 AND "signal_id" = $6`).WithRowsNum(0) + + err := signalRepo.Update(ctx, signalModel.SignalKey, signalModel.Value) + assert.Error(t, err) + + assert.True(t, mockUpdateQuery.Triggered) + mockUpdateQuery.Triggered = false // reset to false for second call + + // update signalModel exists + mockUpdateQuery.WithRowsNum(1) + + err = signalRepo.Update(ctx, signalModel.SignalKey, signalModel.Value) + assert.NoError(t, err) + + assert.True(t, mockUpdateQuery.Triggered) +} diff --git a/pkg/repositories/interfaces/repository.go b/pkg/repositories/interfaces/repository.go index 3dcaffc93..eb81607c5 100644 --- a/pkg/repositories/interfaces/repository.go +++ b/pkg/repositories/interfaces/repository.go @@ -22,6 +22,7 @@ type Repository interface { NamedEntityRepo() NamedEntityRepoInterface SchedulableEntityRepo() schedulerInterfaces.SchedulableEntityRepoInterface ScheduleEntitiesSnapshotRepo() schedulerInterfaces.ScheduleEntitiesSnapShotRepoInterface + SignalRepo() SignalRepoInterface GetGormDB() *gorm.DB } diff --git a/pkg/repositories/interfaces/signal_repo.go b/pkg/repositories/interfaces/signal_repo.go new file mode 100644 index 000000000..f26ca065c --- /dev/null +++ b/pkg/repositories/interfaces/signal_repo.go @@ -0,0 +1,26 @@ +package interfaces + +import ( + "context" + + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" +) + +//go:generate mockery -name=SignalRepoInterface -output=../mocks -case=underscore + +// Defines the interface for interacting with signal models. +type SignalRepoInterface interface { + // Get retrieves a signal model from the database store. + Get(ctx context.Context, input models.SignalKey) (models.Signal, error) + // GetOrCreate inserts a signal model into the database store or returns one if it already exists. + GetOrCreate(ctx context.Context, input *models.Signal) error + // List all signals that match the input values. + List(ctx context.Context, input ListResourceInput) ([]models.Signal, error) + // Update sets the value on a signal in the database store. + Update(ctx context.Context, input models.SignalKey, value []byte) error +} + +type GetSignalInput struct { + SignalID core.SignalIdentifier +} diff --git a/pkg/repositories/mocks/repository.go b/pkg/repositories/mocks/repository.go index 27fcc17f7..92333ce03 100644 --- a/pkg/repositories/mocks/repository.go +++ b/pkg/repositories/mocks/repository.go @@ -21,6 +21,7 @@ type MockRepository struct { namedEntityRepo interfaces.NamedEntityRepoInterface schedulableEntityRepo sIface.SchedulableEntityRepoInterface schedulableEntitySnapshotRepo sIface.ScheduleEntitiesSnapShotRepoInterface + signalRepo interfaces.SignalRepoInterface } func (r *MockRepository) GetGormDB() *gorm.DB { @@ -79,6 +80,10 @@ func (r *MockRepository) NamedEntityRepo() interfaces.NamedEntityRepoInterface { return r.namedEntityRepo } +func (r *MockRepository) SignalRepo() interfaces.SignalRepoInterface { + return r.signalRepo +} + func NewMockRepository() interfaces.Repository { return &MockRepository{ taskRepo: NewMockTaskRepo(), @@ -94,5 +99,6 @@ func NewMockRepository() interfaces.Repository { NodeExecutionEventRepoIface: &NodeExecutionEventRepoInterface{}, schedulableEntityRepo: &sMocks.SchedulableEntityRepoInterface{}, schedulableEntitySnapshotRepo: &sMocks.ScheduleEntitiesSnapShotRepoInterface{}, + signalRepo: &SignalRepoInterface{}, } } diff --git a/pkg/repositories/mocks/signal_repo_interface.go b/pkg/repositories/mocks/signal_repo_interface.go new file mode 100644 index 000000000..f60307911 --- /dev/null +++ b/pkg/repositories/mocks/signal_repo_interface.go @@ -0,0 +1,161 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + interfaces "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" + mock "github.com/stretchr/testify/mock" + + models "github.com/flyteorg/flyteadmin/pkg/repositories/models" +) + +// SignalRepoInterface is an autogenerated mock type for the SignalRepoInterface type +type SignalRepoInterface struct { + mock.Mock +} + +type SignalRepoInterface_Get struct { + *mock.Call +} + +func (_m SignalRepoInterface_Get) Return(_a0 models.Signal, _a1 error) *SignalRepoInterface_Get { + return &SignalRepoInterface_Get{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalRepoInterface) OnGet(ctx context.Context, input models.SignalKey) *SignalRepoInterface_Get { + c_call := _m.On("Get", ctx, input) + return &SignalRepoInterface_Get{Call: c_call} +} + +func (_m *SignalRepoInterface) OnGetMatch(matchers ...interface{}) *SignalRepoInterface_Get { + c_call := _m.On("Get", matchers...) + return &SignalRepoInterface_Get{Call: c_call} +} + +// Get provides a mock function with given fields: ctx, input +func (_m *SignalRepoInterface) Get(ctx context.Context, input models.SignalKey) (models.Signal, error) { + ret := _m.Called(ctx, input) + + var r0 models.Signal + if rf, ok := ret.Get(0).(func(context.Context, models.SignalKey) models.Signal); ok { + r0 = rf(ctx, input) + } else { + r0 = ret.Get(0).(models.Signal) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, models.SignalKey) error); ok { + r1 = rf(ctx, input) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalRepoInterface_GetOrCreate struct { + *mock.Call +} + +func (_m SignalRepoInterface_GetOrCreate) Return(_a0 error) *SignalRepoInterface_GetOrCreate { + return &SignalRepoInterface_GetOrCreate{Call: _m.Call.Return(_a0)} +} + +func (_m *SignalRepoInterface) OnGetOrCreate(ctx context.Context, input *models.Signal) *SignalRepoInterface_GetOrCreate { + c_call := _m.On("GetOrCreate", ctx, input) + return &SignalRepoInterface_GetOrCreate{Call: c_call} +} + +func (_m *SignalRepoInterface) OnGetOrCreateMatch(matchers ...interface{}) *SignalRepoInterface_GetOrCreate { + c_call := _m.On("GetOrCreate", matchers...) + return &SignalRepoInterface_GetOrCreate{Call: c_call} +} + +// GetOrCreate provides a mock function with given fields: ctx, input +func (_m *SignalRepoInterface) GetOrCreate(ctx context.Context, input *models.Signal) error { + ret := _m.Called(ctx, input) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *models.Signal) error); ok { + r0 = rf(ctx, input) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type SignalRepoInterface_List struct { + *mock.Call +} + +func (_m SignalRepoInterface_List) Return(_a0 []models.Signal, _a1 error) *SignalRepoInterface_List { + return &SignalRepoInterface_List{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalRepoInterface) OnList(ctx context.Context, input interfaces.ListResourceInput) *SignalRepoInterface_List { + c_call := _m.On("List", ctx, input) + return &SignalRepoInterface_List{Call: c_call} +} + +func (_m *SignalRepoInterface) OnListMatch(matchers ...interface{}) *SignalRepoInterface_List { + c_call := _m.On("List", matchers...) + return &SignalRepoInterface_List{Call: c_call} +} + +// List provides a mock function with given fields: ctx, input +func (_m *SignalRepoInterface) List(ctx context.Context, input interfaces.ListResourceInput) ([]models.Signal, error) { + ret := _m.Called(ctx, input) + + var r0 []models.Signal + if rf, ok := ret.Get(0).(func(context.Context, interfaces.ListResourceInput) []models.Signal); ok { + r0 = rf(ctx, input) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.Signal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, interfaces.ListResourceInput) error); ok { + r1 = rf(ctx, input) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalRepoInterface_Update struct { + *mock.Call +} + +func (_m SignalRepoInterface_Update) Return(_a0 error) *SignalRepoInterface_Update { + return &SignalRepoInterface_Update{Call: _m.Call.Return(_a0)} +} + +func (_m *SignalRepoInterface) OnUpdate(ctx context.Context, input models.SignalKey, value []byte) *SignalRepoInterface_Update { + c_call := _m.On("Update", ctx, input, value) + return &SignalRepoInterface_Update{Call: c_call} +} + +func (_m *SignalRepoInterface) OnUpdateMatch(matchers ...interface{}) *SignalRepoInterface_Update { + c_call := _m.On("Update", matchers...) + return &SignalRepoInterface_Update{Call: c_call} +} + +// Update provides a mock function with given fields: ctx, input, value +func (_m *SignalRepoInterface) Update(ctx context.Context, input models.SignalKey, value []byte) error { + ret := _m.Called(ctx, input, value) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.SignalKey, []byte) error); ok { + r0 = rf(ctx, input, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/repositories/models/signal.go b/pkg/repositories/models/signal.go new file mode 100644 index 000000000..8a7fac693 --- /dev/null +++ b/pkg/repositories/models/signal.go @@ -0,0 +1,15 @@ +package models + +// Signal primary key +type SignalKey struct { + ExecutionKey + SignalID string `gorm:"primary_key;index" valid:"length(0|255)"` +} + +// Database model to encapsulate a signal. +type Signal struct { + BaseModel + SignalKey + Type []byte `gorm:"not null"` + Value []byte +} diff --git a/pkg/repositories/transformers/signal.go b/pkg/repositories/transformers/signal.go new file mode 100644 index 000000000..69bb7af01 --- /dev/null +++ b/pkg/repositories/transformers/signal.go @@ -0,0 +1,134 @@ +package transformers + +import ( + "github.com/flyteorg/flyteadmin/pkg/errors" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/golang/protobuf/proto" + + "google.golang.org/grpc/codes" +) + +func CreateSignalModel(signalID *core.SignalIdentifier, signalType *core.LiteralType, signalValue *core.Literal) (models.Signal, error) { + signalModel := models.Signal{} + if signalID != nil { + signalKey := &signalModel.SignalKey + if signalID.ExecutionId != nil { + executionKey := &signalKey.ExecutionKey + if len(signalID.ExecutionId.Project) > 0 { + executionKey.Project = signalID.ExecutionId.Project + } + if len(signalID.ExecutionId.Domain) > 0 { + executionKey.Domain = signalID.ExecutionId.Domain + } + if len(signalID.ExecutionId.Name) > 0 { + executionKey.Name = signalID.ExecutionId.Name + } + } + + if len(signalID.SignalId) > 0 { + signalKey.SignalID = signalID.SignalId + } + } + + if signalType != nil { + typeBytes, err := proto.Marshal(signalType) + if err != nil { + return models.Signal{}, errors.NewFlyteAdminError(codes.Internal, "Failed to serialize signal type") + } + + signalModel.Type = typeBytes + } + + if signalValue != nil { + valueBytes, err := proto.Marshal(signalValue) + if err != nil { + return models.Signal{}, errors.NewFlyteAdminError(codes.Internal, "Failed to serialize signal value") + } + + signalModel.Value = valueBytes + } + + return signalModel, nil +} + +func initSignalIdentifier(id *core.SignalIdentifier) *core.SignalIdentifier { + if id == nil { + id = &core.SignalIdentifier{} + } + return id +} + +func initWorkflowExecutionIdentifier(id *core.WorkflowExecutionIdentifier) *core.WorkflowExecutionIdentifier { + if id == nil { + return &core.WorkflowExecutionIdentifier{} + } + return id +} + +func FromSignalModel(signalModel models.Signal) (admin.Signal, error) { + signal := admin.Signal{} + + var executionID *core.WorkflowExecutionIdentifier + if len(signalModel.SignalKey.ExecutionKey.Project) > 0 { + executionID = initWorkflowExecutionIdentifier(executionID) + executionID.Project = signalModel.SignalKey.ExecutionKey.Project + } + if len(signalModel.SignalKey.ExecutionKey.Domain) > 0 { + executionID = initWorkflowExecutionIdentifier(executionID) + executionID.Domain = signalModel.SignalKey.ExecutionKey.Domain + } + if len(signalModel.SignalKey.ExecutionKey.Name) > 0 { + executionID = initWorkflowExecutionIdentifier(executionID) + executionID.Name = signalModel.SignalKey.ExecutionKey.Name + } + + var signalID *core.SignalIdentifier + if executionID != nil { + signalID = initSignalIdentifier(signalID) + signalID.ExecutionId = executionID + } + if len(signalModel.SignalKey.SignalID) > 0 { + signalID = initSignalIdentifier(signalID) + signalID.SignalId = signalModel.SignalKey.SignalID + } + + if signalID != nil { + signal.Id = signalID + } + + if len(signalModel.Type) > 0 { + var typeDeserialized core.LiteralType + err := proto.Unmarshal(signalModel.Type, &typeDeserialized) + if err != nil { + return admin.Signal{}, errors.NewFlyteAdminError(codes.Internal, "failed to unmarshal signal type") + } + signal.Type = &typeDeserialized + } + + if len(signalModel.Value) > 0 { + var valueDeserialized core.Literal + err := proto.Unmarshal(signalModel.Value, &valueDeserialized) + if err != nil { + return admin.Signal{}, errors.NewFlyteAdminError(codes.Internal, "failed to unmarshal signal value") + } + signal.Value = &valueDeserialized + } + + return signal, nil +} + +func FromSignalModels(signalModels []models.Signal) ([]*admin.Signal, error) { + signals := make([]*admin.Signal, len(signalModels)) + for idx, signalModel := range signalModels { + signal, err := FromSignalModel(signalModel) + if err != nil { + return nil, err + } + signals[idx] = &signal + } + return signals, nil +} diff --git a/pkg/repositories/transformers/signal_test.go b/pkg/repositories/transformers/signal_test.go new file mode 100644 index 000000000..c43ed0bb6 --- /dev/null +++ b/pkg/repositories/transformers/signal_test.go @@ -0,0 +1,163 @@ +package transformers + +import ( + "testing" + + "github.com/flyteorg/flyteadmin/pkg/repositories/models" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/stretchr/testify/assert" + + "github.com/golang/protobuf/proto" +) + +var ( + booleanType = core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + } + + booleanValue = core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + } + + signalKey = models.SignalKey{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalID: "signal", + } + + signalID = core.SignalIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + SignalId: "signal", + } +) + +func TestCreateSignalModel(t *testing.T) { + booleanTypeBytes, _ := proto.Marshal(&booleanType) + booleanValueBytes, _ := proto.Marshal(&booleanValue) + + tests := []struct { + name string + model models.Signal + proto admin.Signal + }{ + { + name: "Empty", + model: models.Signal{}, + proto: admin.Signal{}, + }, + { + name: "Full", + model: models.Signal{ + SignalKey: signalKey, + Type: booleanTypeBytes, + Value: booleanValueBytes, + }, + proto: admin.Signal{ + Id: &signalID, + Type: &booleanType, + Value: &booleanValue, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + signalModel, err := CreateSignalModel(test.proto.Id, test.proto.Type, test.proto.Value) + assert.NoError(t, err) + + assert.Equal(t, test.model, signalModel) + }) + } +} + +func TestFromSignalModel(t *testing.T) { + booleanTypeBytes, _ := proto.Marshal(&booleanType) + booleanValueBytes, _ := proto.Marshal(&booleanValue) + + tests := []struct { + name string + model models.Signal + proto admin.Signal + }{ + { + name: "Empty", + model: models.Signal{}, + proto: admin.Signal{}, + }, + { + name: "Full", + model: models.Signal{ + SignalKey: signalKey, + Type: booleanTypeBytes, + Value: booleanValueBytes, + }, + proto: admin.Signal{ + Id: &signalID, + Type: &booleanType, + Value: &booleanValue, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + signal, err := FromSignalModel(test.model) + assert.NoError(t, err) + + assert.True(t, proto.Equal(&test.proto, &signal)) + }) + } +} + +func TestFromSignalModels(t *testing.T) { + booleanTypeBytes, _ := proto.Marshal(&booleanType) + booleanValueBytes, _ := proto.Marshal(&booleanValue) + + signalModels := []models.Signal{ + models.Signal{}, + models.Signal{ + SignalKey: signalKey, + Type: booleanTypeBytes, + Value: booleanValueBytes, + }, + } + + signals := []*admin.Signal{ + &admin.Signal{}, + &admin.Signal{ + Id: &signalID, + Type: &booleanType, + Value: &booleanValue, + }, + } + + s, err := FromSignalModels(signalModels) + assert.NoError(t, err) + + assert.Len(t, s, len(signals)) + for idx, signal := range signals { + assert.True(t, proto.Equal(signal, s[idx])) + } +} diff --git a/pkg/rpc/signal_service.go b/pkg/rpc/signal_service.go new file mode 100644 index 000000000..2487003d9 --- /dev/null +++ b/pkg/rpc/signal_service.go @@ -0,0 +1,147 @@ +package rpc + +import ( + "context" + "fmt" + "runtime/debug" + + manager "github.com/flyteorg/flyteadmin/pkg/manager/impl" + "github.com/flyteorg/flyteadmin/pkg/manager/interfaces" + "github.com/flyteorg/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice/util" + runtimeIfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/golang/protobuf/proto" + + "github.com/prometheus/client_golang/prometheus" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type SignalMetrics struct { + scope promutils.Scope + panicCounter prometheus.Counter + + create util.RequestMetrics + get util.RequestMetrics +} + +func NewSignalMetrics(scope promutils.Scope) SignalMetrics { + return SignalMetrics{ + scope: scope, + panicCounter: scope.MustNewCounter("handler_panic", + "panics encountered while handling requests to the admin service"), + create: util.NewRequestMetrics(scope, "create_signal"), + get: util.NewRequestMetrics(scope, "get_signal"), + } +} + +type SignalService struct { + service.UnimplementedSignalServiceServer + signalManager interfaces.SignalInterface + metrics SignalMetrics +} + +func NewSignalServer(ctx context.Context, configuration runtimeIfaces.Configuration, adminScope promutils.Scope) *SignalService { + panicCounter := adminScope.MustNewCounter("initialization_panic", + "panics encountered initializing the signal service") + + defer func() { + if err := recover(); err != nil { + panicCounter.Inc() + logger.Fatalf(ctx, fmt.Sprintf("caught panic: %v [%+v]", err, string(debug.Stack()))) + } + }() + + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, err := repositories.GetDB(ctx, databaseConfig, logConfig) + if err != nil { + logger.Fatal(ctx, err) + } + dbScope := adminScope.NewSubScope("database") + repo := repositories.NewGormRepo( + db, errors.NewPostgresErrorTransformer(adminScope.NewSubScope("errors")), dbScope) + + signalManager := manager.NewSignalManager(repo, adminScope.NewSubScope("signal_manager")) + + logger.Info(ctx, "Initializing a new SignalService") + return &SignalService{ + signalManager: signalManager, + metrics: NewSignalMetrics(adminScope), + } +} + +// Intercepts all admin requests to handle panics during execution. +func (s *SignalService) interceptPanic(ctx context.Context, request proto.Message) { + err := recover() + if err == nil { + return + } + + s.metrics.panicCounter.Inc() + logger.Fatalf(ctx, "panic-ed for request: [%+v] with err: %v with Stack: %v", request, err, string(debug.Stack())) +} + +func (s *SignalService) GetOrCreateSignal( + ctx context.Context, request *admin.SignalGetOrCreateRequest) (*admin.Signal, error) { + defer s.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.Signal + var err error + s.metrics.create.Time(func() { + response, err = s.signalManager.GetOrCreateSignal(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &s.metrics.create) + } + s.metrics.create.Success() + return response, nil +} + +func (s *SignalService) ListSignals( + ctx context.Context, request *admin.SignalListRequest) (*admin.SignalList, error) { + defer s.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.SignalList + var err error + s.metrics.get.Time(func() { + response, err = s.signalManager.ListSignals(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &s.metrics.get) + } + s.metrics.get.Success() + return response, nil +} + +func (s *SignalService) SetSignal( + ctx context.Context, request *admin.SignalSetRequest) (*admin.SignalSetResponse, error) { + defer s.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.SignalSetResponse + var err error + s.metrics.get.Time(func() { + response, err = s.signalManager.SetSignal(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &s.metrics.get) + } + s.metrics.get.Success() + return response, nil +} diff --git a/pkg/rpc/signal_service_test.go b/pkg/rpc/signal_service_test.go new file mode 100644 index 000000000..b987edfdd --- /dev/null +++ b/pkg/rpc/signal_service_test.go @@ -0,0 +1,148 @@ +package rpc + +import ( + "context" + "errors" + "testing" + + "github.com/flyteorg/flyteadmin/pkg/manager/mocks" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + mockScope "github.com/flyteorg/flytestdlib/promutils" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetOrCreateSignal(t *testing.T) { + ctx := context.Background() + + t.Run("Happy", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{}, nil) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.GetOrCreateSignal(ctx, &admin.SignalGetOrCreateRequest{}) + assert.NoError(t, err) + }) + + t.Run("NilRequestError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.GetOrCreateSignal(ctx, nil) + assert.Error(t, err) + }) + + t.Run("ManagerError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(nil, errors.New("foo")) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.GetOrCreateSignal(ctx, &admin.SignalGetOrCreateRequest{}) + assert.Error(t, err) + }) +} + +func TestListSignals(t *testing.T) { + ctx := context.Background() + + t.Run("Happy", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnListSignalsMatch(mock.Anything, mock.Anything).Return(&admin.SignalList{}, nil) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.ListSignals(ctx, &admin.SignalListRequest{}) + assert.NoError(t, err) + }) + + t.Run("NilRequestError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.ListSignals(ctx, nil) + assert.Error(t, err) + }) + + t.Run("ManagerError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnListSignalsMatch(mock.Anything, mock.Anything).Return(nil, errors.New("foo")) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.ListSignals(ctx, &admin.SignalListRequest{}) + assert.Error(t, err) + }) +} + +func TestSetSignal(t *testing.T) { + ctx := context.Background() + + t.Run("Happy", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnSetSignalMatch(mock.Anything, mock.Anything).Return(&admin.SignalSetResponse{}, nil) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.SetSignal(ctx, &admin.SignalSetRequest{}) + assert.NoError(t, err) + }) + + t.Run("NilRequestError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.SetSignal(ctx, nil) + assert.Error(t, err) + }) + + t.Run("ManagerError", func(t *testing.T) { + signalManager := mocks.SignalInterface{} + signalManager.OnSetSignalMatch(mock.Anything, mock.Anything).Return(nil, errors.New("foo")) + + testScope := mockScope.NewTestScope() + mockServer := &SignalService{ + signalManager: &signalManager, + metrics: NewSignalMetrics(testScope), + } + + _, err := mockServer.SetSignal(ctx, &admin.SignalSetRequest{}) + assert.Error(t, err) + }) +} diff --git a/pkg/server/service.go b/pkg/server/service.go index 7e45fbb03..345072123 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -25,6 +25,7 @@ import ( "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/config" + "github.com/flyteorg/flyteadmin/pkg/rpc" "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" runtimeIfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" @@ -126,6 +127,8 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c pluginRegistry.RegisterDefault(plugins.PluginIDDataProxy, dataProxySvc) service.RegisterDataProxyServiceServer(grpcServer, plugins.Get[service.DataProxyServiceServer](pluginRegistry, plugins.PluginIDDataProxy)) + service.RegisterSignalServiceServer(grpcServer, rpc.NewSignalServer(ctx, configuration, scope.NewSubScope("signal"))) + healthServer := health.NewServer() healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)