diff --git a/pkg/async/notifications/factory.go b/pkg/async/notifications/factory.go index 836c1ff2f..a2fd4c357 100644 --- a/pkg/async/notifications/factory.go +++ b/pkg/async/notifications/factory.go @@ -3,6 +3,7 @@ package notifications import ( "context" "fmt" + "sync" "time" "github.com/flyteorg/flyteadmin/pkg/async" @@ -27,6 +28,9 @@ const maxRetries = 3 var enable64decoding = false +var msgChan chan []byte +var once sync.Once + type PublisherConfig struct { TopicName string } @@ -41,6 +45,13 @@ type EmailerConfig struct { BaseURL string } +// For sandbox only +func CreateMsgChan() { + once.Do(func() { + msgChan = make(chan []byte) + }) +} + func GetEmailer(config runtimeInterfaces.NotificationsConfig, scope promutils.Scope) interfaces.Emailer { // If an external email service is specified use that instead. // TODO: Handling of this is messy, see https://github.com/flyteorg/flyte/issues/1063 @@ -120,6 +131,9 @@ func NewNotificationsProcessor(config runtimeInterfaces.NotificationsConfig, sco } emailer = GetEmailer(config, scope) return implementations.NewGcpProcessor(sub, emailer, scope) + case common.Sandbox: + emailer = GetEmailer(config, scope) + return implementations.NewSandboxProcessor(msgChan, emailer) case common.Local: fallthrough default: @@ -171,6 +185,9 @@ func NewNotificationsPublisher(config runtimeInterfaces.NotificationsConfig, sco panic(err) } return implementations.NewPublisher(publisher, scope) + case common.Sandbox: + CreateMsgChan() + return implementations.NewSandboxPublisher(msgChan) case common.Local: fallthrough default: diff --git a/pkg/async/notifications/factory_test.go b/pkg/async/notifications/factory_test.go index 280f383f2..c10bab230 100644 --- a/pkg/async/notifications/factory_test.go +++ b/pkg/async/notifications/factory_test.go @@ -1,13 +1,32 @@ package notifications import ( + "context" "testing" + "github.com/flyteorg/flyteadmin/pkg/async/notifications/implementations" runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" ) +var ( + scope = promutils.NewScope("test_sandbox_processor") + notificationsConfig = runtimeInterfaces.NotificationsConfig{ + Type: "sandbox", + } + testEmail = admin.EmailMessage{ + RecipientsEmail: []string{ + "a@example.com", + "b@example.com", + }, + SenderEmail: "no-reply@example.com", + SubjectLine: "Test email", + Body: "This is a sample email.", + } +) + func TestGetEmailer(t *testing.T) { defer func() { r := recover(); assert.NotNil(t, r) }() cfg := runtimeInterfaces.NotificationsConfig{ @@ -23,3 +42,18 @@ func TestGetEmailer(t *testing.T) { // shouldn't reach here t.Errorf("did not panic") } + +func TestNewNotificationPublisherAndProcessor(t *testing.T) { + testSandboxPublisher := NewNotificationsPublisher(notificationsConfig, scope) + assert.IsType(t, testSandboxPublisher, &implementations.SandboxPublisher{}) + testSandboxProcessor := NewNotificationsProcessor(notificationsConfig, scope) + assert.IsType(t, testSandboxProcessor, &implementations.SandboxProcessor{}) + + go func() { + testSandboxProcessor.StartProcessing() + }() + + assert.Nil(t, testSandboxPublisher.Publish(context.Background(), "TEST_NOTIFICATION", &testEmail)) + + assert.Nil(t, testSandboxProcessor.StopProcessing()) +} diff --git a/pkg/async/notifications/implementations/sandbox_processor.go b/pkg/async/notifications/implementations/sandbox_processor.go new file mode 100644 index 000000000..4b89f277a --- /dev/null +++ b/pkg/async/notifications/implementations/sandbox_processor.go @@ -0,0 +1,62 @@ +package implementations + +import ( + "context" + "time" + + "github.com/flyteorg/flyteadmin/pkg/async" + "github.com/flyteorg/flyteadmin/pkg/async/notifications/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flytestdlib/logger" + "github.com/golang/protobuf/proto" +) + +type SandboxProcessor struct { + email interfaces.Emailer + subChan <-chan []byte +} + +func (p *SandboxProcessor) StartProcessing() { + for { + logger.Warningf(context.Background(), "Starting SandBox notifications processor") + err := p.run() + logger.Errorf(context.Background(), "error with running processor err: [%v] ", err) + time.Sleep(async.RetryDelay) + } +} + +func (p *SandboxProcessor) run() error { + var emailMessage admin.EmailMessage + + for { + select { + case msg := <-p.subChan: + err := proto.Unmarshal(msg, &emailMessage) + if err != nil { + logger.Errorf(context.Background(), "error with unmarshalling message [%v]", err) + return err + } + + err = p.email.SendEmail(context.Background(), emailMessage) + if err != nil { + logger.Errorf(context.Background(), "Error sending an email message for message [%s] with emailM with err: %v", emailMessage.String(), err) + return err + } + default: + logger.Debugf(context.Background(), "no message to process") + return nil + } + } +} + +func (p *SandboxProcessor) StopProcessing() error { + logger.Debug(context.Background(), "call to sandbox stop processing.") + return nil +} + +func NewSandboxProcessor(subChan <-chan []byte, emailer interfaces.Emailer) interfaces.Processor { + return &SandboxProcessor{ + subChan: subChan, + email: emailer, + } +} diff --git a/pkg/async/notifications/implementations/sandbox_processor_test.go b/pkg/async/notifications/implementations/sandbox_processor_test.go new file mode 100644 index 000000000..6bbee8cf5 --- /dev/null +++ b/pkg/async/notifications/implementations/sandbox_processor_test.go @@ -0,0 +1,84 @@ +package implementations + +import ( + "context" + "testing" + "time" + + "github.com/flyteorg/flyteadmin/pkg/async/notifications/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +var mockSandboxEmailer mocks.MockEmailer + +func TestSandboxProcessor_StartProcessingSuccess(t *testing.T) { + msgChan := make(chan []byte, 1) + msgChan <- msg + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + + sendEmailValidationFunc := func(ctx context.Context, email admin.EmailMessage) error { + assert.Equal(t, testEmail.Body, email.Body) + assert.Equal(t, testEmail.RecipientsEmail, email.RecipientsEmail) + assert.Equal(t, testEmail.SubjectLine, email.SubjectLine) + assert.Equal(t, testEmail.SenderEmail, email.SenderEmail) + return nil + } + + mockSandboxEmailer.SetSendEmailFunc(sendEmailValidationFunc) + assert.Nil(t, testSandboxProcessor.(*SandboxProcessor).run()) +} + +func TestSandboxProcessor_StartProcessingNoMessage(t *testing.T) { + msgChan := make(chan []byte, 1) + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + go testSandboxProcessor.StartProcessing() + time.Sleep(1 * time.Second) +} + +func TestSandboxProcessor_StartProcessingError(t *testing.T) { + msgChan := make(chan []byte, 1) + msgChan <- msg + + emailError := errors.New("error running processor") + sendEmailValidationFunc := func(ctx context.Context, email admin.EmailMessage) error { + return emailError + } + mockSandboxEmailer.SetSendEmailFunc(sendEmailValidationFunc) + + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + go testSandboxProcessor.StartProcessing() + + // give time to receive the err in StartProcessing + time.Sleep(1 * time.Second) + assert.Zero(t, len(msgChan)) +} + +func TestSandboxProcessor_StartProcessingMessageError(t *testing.T) { + msgChan := make(chan []byte, 1) + invalidProtoMessage := []byte("invalid message") + msgChan <- invalidProtoMessage + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + assert.NotNil(t, testSandboxProcessor.(*SandboxProcessor).run()) +} + +func TestSandboxProcessor_StartProcessingEmailError(t *testing.T) { + msgChan := make(chan []byte, 1) + msgChan <- msg + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + + emailError := errors.New("error sending email") + sendEmailValidationFunc := func(ctx context.Context, email admin.EmailMessage) error { + return emailError + } + + mockSandboxEmailer.SetSendEmailFunc(sendEmailValidationFunc) + assert.NotNil(t, testSandboxProcessor.(*SandboxProcessor).run()) +} + +func TestSandboxProcessor_StopProcessing(t *testing.T) { + msgChan := make(chan []byte, 1) + testSandboxProcessor := NewSandboxProcessor(msgChan, &mockSandboxEmailer) + assert.Nil(t, testSandboxProcessor.StopProcessing()) +} diff --git a/pkg/async/notifications/implementations/sandbox_publisher.go b/pkg/async/notifications/implementations/sandbox_publisher.go new file mode 100644 index 000000000..ab94b1ee6 --- /dev/null +++ b/pkg/async/notifications/implementations/sandbox_publisher.go @@ -0,0 +1,33 @@ +package implementations + +import ( + "context" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/golang/protobuf/proto" +) + +type SandboxPublisher struct { + pubChan chan<- []byte +} + +func (p *SandboxPublisher) Publish(ctx context.Context, notificationType string, msg proto.Message) error { + logger.Debugf(ctx, "Publishing the following message [%s]", msg.String()) + + data, err := proto.Marshal(msg) + + if err != nil { + logger.Errorf(ctx, "Failed to publish a message with key [%s] and message [%s] and error: %v", notificationType, msg.String(), err) + return err + } + + p.pubChan <- data + + return nil +} + +func NewSandboxPublisher(pubChan chan<- []byte) *SandboxPublisher { + return &SandboxPublisher{ + pubChan: pubChan, + } +} diff --git a/pkg/async/notifications/implementations/sandbox_publisher_test.go b/pkg/async/notifications/implementations/sandbox_publisher_test.go new file mode 100644 index 000000000..5a73186c1 --- /dev/null +++ b/pkg/async/notifications/implementations/sandbox_publisher_test.go @@ -0,0 +1,36 @@ +package implementations + +import ( + "context" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +// mockMessage is a dummy proto message that will always fail to marshal +type mockMessage struct{} + +func (m *mockMessage) Reset() {} +func (m *mockMessage) String() string { return "mockMessage" } +func (m *mockMessage) ProtoMessage() {} +func (m *mockMessage) Marshal() ([]byte, error) { return nil, errors.New("forced marshal error") } + +func TestSandboxPublisher_Publish(t *testing.T) { + msgChan := make(chan []byte, 1) + publisher := NewSandboxPublisher(msgChan) + + err := publisher.Publish(context.Background(), "NOTIFICATION_TYPE", &testEmail) + + assert.NotZero(t, len(msgChan)) + assert.Nil(t, err) +} + +func TestSandboxPublisher_PublishMarshalError(t *testing.T) { + msgChan := make(chan []byte, 1) + publisher := NewSandboxPublisher(msgChan) + + err := publisher.Publish(context.Background(), "testMarshallError", &mockMessage{}) + assert.Error(t, err) + assert.Equal(t, "forced marshal error", err.Error()) +} diff --git a/pkg/common/cloud.go b/pkg/common/cloud.go index cba0a6879..93f3669a5 100644 --- a/pkg/common/cloud.go +++ b/pkg/common/cloud.go @@ -5,8 +5,9 @@ package common type CloudProvider = string const ( - AWS CloudProvider = "aws" - GCP CloudProvider = "gcp" - Local CloudProvider = "local" - None CloudProvider = "none" + AWS CloudProvider = "aws" + GCP CloudProvider = "gcp" + Sandbox CloudProvider = "sandbox" + Local CloudProvider = "local" + None CloudProvider = "none" )