diff --git a/AGENTS.md b/AGENTS.md index f28cdc5..4b7949d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -25,6 +25,19 @@ - Prefer Go naming conventions (CamelCase for exported, lowerCamel for unexported). - Keep package names short and domain-focused (e.g., `repository`, `service`). +## Multi-Tenancy & Data Isolation + +- Treat `tenant_id` as a security boundary, not a convenience filter. Tenant-owned data must never be read, enqueued, dispatched, cached, searched, embedded, or deleted across tenants. +- Exception: GDPR/right-to-erasure flows may intentionally delete all records for a data subject across tenants when that is the documented API contract. Make that all-tenant behavior explicit in the API docs, service/repository names or comments, logs, and tests; do not reuse it for normal tenant-owned workflows. +- When making a model, migration, API request, or repository change involving tenant-owned data, audit every downstream path that carries or derives from that data: handlers, services, repositories, message publishers, River job args, workers, webhook payloads, search, embeddings, bulk operations, logs, and metrics. +- Tenant access rules must be consistent across every path that can observe, mutate, derive from, or act on the same resource. If one API endpoint, repository method, search path, webhook dispatch path, worker, backfill, bulk operation, or export path requires tenant scope, every alternate path for that resource must enforce the same tenant boundary. +- Do not model `tenant_id` as an optional filter for tenant-owned resources. Prefer required tenant parameters in service/repository method signatures (`tenantID string`, not `*string`) unless the domain explicitly supports global resources and documents that behavior. +- Prefer tenant-aware repository/service methods for tenant-owned workflows. Avoid adding broad helpers that return all enabled/all matching resources when the caller is dispatching, processing, deriving, exporting, or exposing tenant data. +- Async jobs must carry the tenant boundary when the source data has one, and workers must re-check tenant scope before doing side effects. Do not rely only on enqueue-time filtering. +- Global resources may intentionally have `tenant_id = NULL` only when the domain explicitly documents them as non-tenant-owned. Webhooks are tenant-owned and must require a non-empty `tenant_id`; a missing tenant on event data must not match any webhook. +- When changing access rules for a tenant-owned model, search for all alternate access paths by resource name and by derived side effects: list, get, update, delete, bulk delete, webhook fan-out, River jobs, workers, embeddings, search, exports, cache invalidation, logs, and metrics. +- For any tenant-scoping change, include verification at the boundary where the leak could happen: database query behavior, service fan-out, worker execution, and API behavior when relevant. Include at least one alternate-path regression test proving that data allowed through the primary path cannot leak through async dispatch, bulk operations, derived indexes, exports, or background workers. Tests are the evidence; the invariant belongs in the architecture. + ## Testing Guidelines - Tests live under `tests/` and are run with `go test ./tests/...`. - Name test files `*_test.go` and test functions `TestXxx`. diff --git a/internal/api/handlers/feedback_records_handler.go b/internal/api/handlers/feedback_records_handler.go index 0524ec7..2979e82 100644 --- a/internal/api/handlers/feedback_records_handler.go +++ b/internal/api/handlers/feedback_records_handler.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "net/http" "github.com/google/uuid" @@ -24,7 +25,7 @@ type FeedbackRecordsService interface { ListFeedbackRecords(ctx context.Context, filters *models.ListFeedbackRecordsFilters) (*models.ListFeedbackRecordsResponse, error) UpdateFeedbackRecord(ctx context.Context, id uuid.UUID, req *models.UpdateFeedbackRecordRequest) (*models.FeedbackRecord, error) DeleteFeedbackRecord(ctx context.Context, id uuid.UUID) error - BulkDeleteFeedbackRecords(ctx context.Context, userID string, tenantID *string) (int, error) + BulkDeleteFeedbackRecords(ctx context.Context, filters *models.BulkDeleteFilters) (int, error) } // FeedbackRecordsHandler handles HTTP requests for feedback records. @@ -59,6 +60,12 @@ func (h *FeedbackRecordsHandler) Create(w http.ResponseWriter, r *http.Request) record, err := h.service.CreateFeedbackRecord(r.Context(), &req) if err != nil { + if errors.Is(err, huberrors.ErrValidation) { + validation.RespondValidationError(w, err) + + return + } + if errors.Is(err, huberrors.ErrNotFound) { response.RespondNotFound(w, "Feedback record not found") @@ -219,7 +226,7 @@ func (h *FeedbackRecordsHandler) Delete(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusNoContent) } -// BulkDelete handles DELETE /v1/feedback-records?user_id=. +// BulkDelete handles DELETE /v1/feedback-records?user_id=[&tenant_id=]. func (h *FeedbackRecordsHandler) BulkDelete(w http.ResponseWriter, r *http.Request) { filters := &models.BulkDeleteFilters{} @@ -230,8 +237,27 @@ func (h *FeedbackRecordsHandler) BulkDelete(w http.ResponseWriter, r *http.Reque return } - deletedCount, err := h.service.BulkDeleteFeedbackRecords(r.Context(), filters.UserID, filters.TenantID) + deletedCount, err := h.service.BulkDeleteFeedbackRecords(r.Context(), filters) if err != nil { + if errors.Is(err, huberrors.ErrValidation) { + validation.RespondValidationError(w, err) + + return + } + + var tenantID string + if filters.TenantID != nil { + tenantID = *filters.TenantID + } + + slog.Error("Failed to bulk delete feedback records", // #nosec G706 -- slog key-values + "method", r.Method, + "path", r.URL.Path, + "user_id", filters.UserID, + "tenant_id", tenantID, + "error", err, + ) + response.RespondInternalServerError(w, "An unexpected error occurred") return diff --git a/internal/api/handlers/feedback_records_handler_test.go b/internal/api/handlers/feedback_records_handler_test.go index 44aa1a7..ab05899 100644 --- a/internal/api/handlers/feedback_records_handler_test.go +++ b/internal/api/handlers/feedback_records_handler_test.go @@ -1,6 +1,7 @@ package handlers import ( + "bytes" "context" "encoding/json" "net/http" @@ -11,17 +12,23 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/formbricks/hub/internal/huberrors" "github.com/formbricks/hub/internal/models" ) // mockFeedbackRecordsService mocks FeedbackRecordsService for handler tests. type mockFeedbackRecordsService struct { - bulkDeleteFunc func(ctx context.Context, userID string, tenantID *string) (int, error) + createFunc func(ctx context.Context, req *models.CreateFeedbackRecordRequest) (*models.FeedbackRecord, error) + bulkDeleteFunc func(ctx context.Context, filters *models.BulkDeleteFilters) (int, error) } func (m *mockFeedbackRecordsService) CreateFeedbackRecord( - context.Context, *models.CreateFeedbackRecordRequest, + ctx context.Context, req *models.CreateFeedbackRecordRequest, ) (*models.FeedbackRecord, error) { + if m.createFunc != nil { + return m.createFunc(ctx, req) + } + return nil, nil } @@ -45,9 +52,11 @@ func (m *mockFeedbackRecordsService) DeleteFeedbackRecord(context.Context, uuid. return nil } -func (m *mockFeedbackRecordsService) BulkDeleteFeedbackRecords(ctx context.Context, userID string, tenantID *string) (int, error) { +func (m *mockFeedbackRecordsService) BulkDeleteFeedbackRecords( + ctx context.Context, filters *models.BulkDeleteFilters, +) (int, error) { if m.bulkDeleteFunc != nil { - return m.bulkDeleteFunc(ctx, userID, tenantID) + return m.bulkDeleteFunc(ctx, filters) } return 0, nil @@ -67,11 +76,101 @@ func TestFeedbackRecordsHandler_List(t *testing.T) { }) } +func TestFeedbackRecordsHandler_Create(t *testing.T) { + t.Run("success returns created record", func(t *testing.T) { + recordID := uuid.Must(uuid.NewV7()) + mock := &mockFeedbackRecordsService{ + createFunc: func(_ context.Context, req *models.CreateFeedbackRecordRequest) (*models.FeedbackRecord, error) { + assert.Equal(t, "org-123", req.TenantID) + + return &models.FeedbackRecord{ + ID: recordID, + SourceType: req.SourceType, + FieldID: req.FieldID, + FieldType: req.FieldType, + TenantID: req.TenantID, + SubmissionID: req.SubmissionID, + }, nil + }, + } + handler := NewFeedbackRecordsHandler(mock) + + req := httptest.NewRequestWithContext( + context.Background(), http.MethodPost, "http://test/v1/feedback-records", feedbackRecordCreateBody(t, "org-123"), + ) + rec := httptest.NewRecorder() + + handler.Create(rec, req) + + assert.Equal(t, http.StatusCreated, rec.Code) + + var got models.FeedbackRecord + + err := json.Unmarshal(rec.Body.Bytes(), &got) + require.NoError(t, err) + assert.Equal(t, recordID, got.ID) + assert.Equal(t, "org-123", got.TenantID) + }) + + t.Run("service validation error returns bad request", func(t *testing.T) { + mock := &mockFeedbackRecordsService{ + createFunc: func(_ context.Context, _ *models.CreateFeedbackRecordRequest) (*models.FeedbackRecord, error) { + return nil, huberrors.NewValidationError("tenant_id", "tenant_id is required and cannot be empty") + }, + } + handler := NewFeedbackRecordsHandler(mock) + + req := httptest.NewRequestWithContext( + context.Background(), http.MethodPost, "http://test/v1/feedback-records", feedbackRecordCreateBody(t, " "), + ) + rec := httptest.NewRecorder() + + handler.Create(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Header().Get("Content-Type"), "application/problem+json") + }) + + t.Run("service conflict returns conflict", func(t *testing.T) { + mock := &mockFeedbackRecordsService{ + createFunc: func(_ context.Context, _ *models.CreateFeedbackRecordRequest) (*models.FeedbackRecord, error) { + return nil, huberrors.NewConflictError("duplicate feedback record") + }, + } + handler := NewFeedbackRecordsHandler(mock) + + req := httptest.NewRequestWithContext( + context.Background(), http.MethodPost, "http://test/v1/feedback-records", feedbackRecordCreateBody(t, "org-123"), + ) + rec := httptest.NewRecorder() + + handler.Create(rec, req) + + assert.Equal(t, http.StatusConflict, rec.Code) + }) +} + +func feedbackRecordCreateBody(t *testing.T, tenantID string) *bytes.Reader { + t.Helper() + + body, err := json.Marshal(map[string]any{ + "source_type": "formbricks", + "submission_id": "submission-1", + "tenant_id": tenantID, + "field_id": "feedback", + "field_type": "text", + }) + require.NoError(t, err) + + return bytes.NewReader(body) +} + func TestFeedbackRecordsHandler_BulkDelete(t *testing.T) { t.Run("success returns 200 with deleted_count and message", func(t *testing.T) { mock := &mockFeedbackRecordsService{ - bulkDeleteFunc: func(_ context.Context, userID string, _ *string) (int, error) { - assert.Equal(t, "user-123", userID) + bulkDeleteFunc: func(_ context.Context, filters *models.BulkDeleteFilters) (int, error) { + assert.Equal(t, "user-123", filters.UserID) + assert.Nil(t, filters.TenantID) return 3, nil }, @@ -94,14 +193,12 @@ func TestFeedbackRecordsHandler_BulkDelete(t *testing.T) { assert.Equal(t, "Successfully deleted 3 feedback records", resp.Message) }) - t.Run("success with tenant_id passes tenant to service", func(t *testing.T) { - var capturedTenantID *string - + t.Run("optional tenant_id query parameter is passed to service", func(t *testing.T) { mock := &mockFeedbackRecordsService{ - bulkDeleteFunc: func(_ context.Context, userID string, tenantID *string) (int, error) { - assert.Equal(t, "user-456", userID) - - capturedTenantID = tenantID + bulkDeleteFunc: func(_ context.Context, filters *models.BulkDeleteFilters) (int, error) { + assert.Equal(t, "user-456", filters.UserID) + require.NotNil(t, filters.TenantID) + assert.Equal(t, "tenant-a", *filters.TenantID) return 1, nil }, @@ -115,8 +212,19 @@ func TestFeedbackRecordsHandler_BulkDelete(t *testing.T) { handler.BulkDelete(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - require.NotNil(t, capturedTenantID) - assert.Equal(t, "tenant-a", *capturedTenantID) + }) + + t.Run("empty tenant_id returns bad request", func(t *testing.T) { + mock := &mockFeedbackRecordsService{} + handler := NewFeedbackRecordsHandler(mock) + + req := httptest.NewRequestWithContext(context.Background(), + http.MethodDelete, "http://test/v1/feedback-records?user_id=user-123&tenant_id=", http.NoBody) + rec := httptest.NewRecorder() + + handler.BulkDelete(rec, req) + + assert.Equal(t, http.StatusBadRequest, rec.Code) }) t.Run("missing user_id returns bad request", func(t *testing.T) { @@ -146,7 +254,7 @@ func TestFeedbackRecordsHandler_BulkDelete(t *testing.T) { t.Run("service error returns 500", func(t *testing.T) { mock := &mockFeedbackRecordsService{ - bulkDeleteFunc: func(_ context.Context, _ string, _ *string) (int, error) { + bulkDeleteFunc: func(_ context.Context, _ *models.BulkDeleteFilters) (int, error) { return 0, assert.AnError }, } @@ -163,7 +271,7 @@ func TestFeedbackRecordsHandler_BulkDelete(t *testing.T) { t.Run("zero deleted returns 200 with deleted_count 0", func(t *testing.T) { mock := &mockFeedbackRecordsService{ - bulkDeleteFunc: func(_ context.Context, _ string, _ *string) (int, error) { + bulkDeleteFunc: func(_ context.Context, _ *models.BulkDeleteFilters) (int, error) { return 0, nil }, } diff --git a/internal/models/events.go b/internal/models/events.go new file mode 100644 index 0000000..4a9b8f2 --- /dev/null +++ b/internal/models/events.go @@ -0,0 +1,9 @@ +package models + +import "github.com/google/uuid" + +// DeletedIDsEventData is the tenant-aware payload for resource deletion events. +type DeletedIDsEventData struct { + TenantID string `json:"tenant_id"` + IDs []uuid.UUID `json:"ids"` +} diff --git a/internal/models/feedback_records.go b/internal/models/feedback_records.go index 39e601e..1274910 100644 --- a/internal/models/feedback_records.go +++ b/internal/models/feedback_records.go @@ -199,7 +199,7 @@ type ListFeedbackRecordsResponse struct { // BulkDeleteFilters represents query parameters for bulk delete operation. type BulkDeleteFilters struct { UserID string `form:"user_id" validate:"required,no_null_bytes,min=1"` - TenantID *string `form:"tenant_id" validate:"omitempty,no_null_bytes"` + TenantID *string `form:"tenant_id" validate:"omitempty,no_null_bytes,min=1"` } // BulkDeleteResponse represents the response for bulk delete operation. @@ -207,3 +207,9 @@ type BulkDeleteResponse struct { DeletedCount int64 `json:"deleted_count"` Message string `json:"message"` } + +// DeletedFeedbackRecordsByTenant groups deleted feedback record IDs by tenant. +type DeletedFeedbackRecordsByTenant struct { + TenantID string + IDs []uuid.UUID +} diff --git a/internal/models/webhooks.go b/internal/models/webhooks.go index d5e84bd..32f6829 100644 --- a/internal/models/webhooks.go +++ b/internal/models/webhooks.go @@ -34,6 +34,12 @@ type Webhook struct { DisabledAt *time.Time `json:"disabled_at,omitempty"` } +// DeletedWebhook is the minimal data returned after deleting a webhook. +type DeletedWebhook struct { + ID uuid.UUID + TenantID *string +} + // MarshalJSON converts []datatypes.EventType to JSON string array. func (w *Webhook) MarshalJSON() ([]byte, error) { type Alias Webhook @@ -183,7 +189,7 @@ type CreateWebhookRequest struct { URL string `json:"url" validate:"required,no_null_bytes,http_url,min=1,max=2048"` SigningKey string `json:"signing_key,omitempty" validate:"omitempty,max=255"` Enabled *bool `json:"enabled,omitempty"` - TenantID *string `json:"tenant_id,omitempty" validate:"omitempty,no_null_bytes,max=255"` + TenantID *string `json:"tenant_id" validate:"required,no_null_bytes,min=1,max=255"` EventTypes []datatypes.EventType `json:"event_types,omitempty"` } @@ -219,7 +225,7 @@ type UpdateWebhookRequest struct { URL *string `json:"url,omitempty" validate:"omitempty,no_null_bytes,http_url,min=1,max=2048"` SigningKey *string `json:"signing_key,omitempty" validate:"omitempty,no_null_bytes,min=1,max=255"` Enabled *bool `json:"enabled,omitempty"` - TenantID *string `json:"tenant_id,omitempty" validate:"omitempty,no_null_bytes,max=255"` + TenantID *string `json:"tenant_id,omitempty" validate:"omitempty,no_null_bytes,min=1,max=255"` EventTypes *[]datatypes.EventType `json:"event_types,omitempty"` DisabledReason *string `json:"-"` // read-only; set by system when disabling DisabledAt *time.Time `json:"-"` // read-only; set by system when disabling diff --git a/internal/observability/names.go b/internal/observability/names.go index 794b26e..cb16de8 100644 --- a/internal/observability/names.go +++ b/internal/observability/names.go @@ -43,8 +43,9 @@ func AllowedEventTypes() []string { // allowedProviderReasons for hub_webhook_provider_errors_total (bounded cardinality). var allowedProviderReasons = map[string]bool{ - "list_failed": true, - "enqueue_failed": true, + "list_failed": true, + "enqueue_failed": true, + "missing_tenant_id": true, } // allowedDeliveryStatuses for hub_webhook_deliveries_total and hub_webhook_delivery_duration_seconds. @@ -63,6 +64,8 @@ var allowedDisabledReasons = map[string]bool{ // allowedDispatchReasons for hub_webhook_dispatch_errors_total. var allowedDispatchReasons = map[string]bool{ "get_webhook_failed": true, + "missing_tenant_id": true, + "tenant_mismatch": true, } // allowedEmbeddingProviderReasons for hub_embedding_provider_errors_total. diff --git a/internal/observability/names_test.go b/internal/observability/names_test.go new file mode 100644 index 0000000..7300939 --- /dev/null +++ b/internal/observability/names_test.go @@ -0,0 +1,21 @@ +package observability + +import "testing" + +func TestAllowedDispatchReasonIncludesTenantBoundaryReasons(t *testing.T) { + reasons := []string{ + "get_webhook_failed", + "missing_tenant_id", + "tenant_mismatch", + } + + for _, reason := range reasons { + if !AllowedDispatchReason(reason) { + t.Errorf("AllowedDispatchReason(%q) = false, want true", reason) + } + + if got := NormalizeReason(reason, AllowedDispatchReason); got != reason { + t.Errorf("NormalizeReason(%q) = %q, want %q", reason, got, reason) + } + } +} diff --git a/internal/repository/feedback_records_repository.go b/internal/repository/feedback_records_repository.go index 678c122..5739813 100644 --- a/internal/repository/feedback_records_repository.go +++ b/internal/repository/feedback_records_repository.go @@ -394,22 +394,25 @@ func (r *FeedbackRecordsRepository) Delete(ctx context.Context, id uuid.UUID) er return nil } -// BulkDelete deletes all feedback records matching user_id and optional tenant_id. -// It returns the deleted IDs (via RETURNING id) so callers can e.g. publish events. -func (r *FeedbackRecordsRepository) BulkDelete(ctx context.Context, userID string, tenantID *string) ([]uuid.UUID, error) { +// BulkDelete deletes all feedback records matching user_id. +// When tenant_id is provided, deletion is restricted to that tenant; otherwise all user records are deleted. +// It returns deleted IDs grouped by tenant so callers can publish tenant-scoped side effects. +func (r *FeedbackRecordsRepository) BulkDelete( + ctx context.Context, filters *models.BulkDeleteFilters, +) ([]models.DeletedFeedbackRecordsByTenant, error) { query := ` DELETE FROM feedback_records WHERE user_id = $1` - args := []any{userID} - argCount := 2 + args := []any{filters.UserID} - if tenantID != nil { - query += fmt.Sprintf(" AND tenant_id = $%d", argCount) + if filters.TenantID != nil { + query += ` AND tenant_id = $2` - args = append(args, *tenantID) + args = append(args, *filters.TenantID) } - query += ` RETURNING id` + query += ` + RETURNING id, tenant_id` rows, err := r.db.Query(ctx, query, args...) if err != nil { @@ -417,22 +420,34 @@ func (r *FeedbackRecordsRepository) BulkDelete(ctx context.Context, userID strin } defer rows.Close() - var ids []uuid.UUID + groups := make([]models.DeletedFeedbackRecordsByTenant, 0) + groupIndexByTenant := make(map[string]int) for rows.Next() { - var id uuid.UUID - if err := rows.Scan(&id); err != nil { + var ( + id uuid.UUID + tenantID string + ) + + if err := rows.Scan(&id, &tenantID); err != nil { return nil, fmt.Errorf("failed to scan deleted feedback record id: %w", err) } - ids = append(ids, id) + groupIndex, ok := groupIndexByTenant[tenantID] + if !ok { + groupIndex = len(groups) + groupIndexByTenant[tenantID] = groupIndex + groups = append(groups, models.DeletedFeedbackRecordsByTenant{TenantID: tenantID}) + } + + groups[groupIndex].IDs = append(groups[groupIndex].IDs, id) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating bulk delete result: %w", err) } - return ids, nil + return groups, nil } // fetchFeedbackRecords executes the given query and scans rows into FeedbackRecord slices. diff --git a/internal/repository/feedback_records_repository_test.go b/internal/repository/feedback_records_repository_test.go index bec5f99..d54efe0 100644 --- a/internal/repository/feedback_records_repository_test.go +++ b/internal/repository/feedback_records_repository_test.go @@ -6,9 +6,9 @@ import ( // BulkDelete is tested by integration tests in tests/integration_test.go: // - TestFeedbackRecordsRepository_BulkDelete exercises the repository directly and asserts -// the returned slice of deleted records. +// the optional tenant filter and tenant-grouped return values. // - TestBulkDeleteFeedbackRecords exercises the full stack (handler, service, repo) including -// tenant_id filter and response shape. +// tenant-scoped deletion, GDPR user_id erasure across tenants, and response shape. func TestFeedbackRecordsRepository_Package(_ *testing.T) { // No DB in unit tests; BulkDelete coverage is in tests/. } diff --git a/internal/repository/webhooks_repository.go b/internal/repository/webhooks_repository.go index 8509153..2a32a13 100644 --- a/internal/repository/webhooks_repository.go +++ b/internal/repository/webhooks_repository.go @@ -262,16 +262,8 @@ func (r *WebhooksRepository) Update(ctx context.Context, id uuid.UUID, req *mode } if req.TenantID != nil { - // Empty string clears tenant_id (store as NULL) - var val any - if *req.TenantID == "" { - val = nil - } else { - val = *req.TenantID - } - updates = append(updates, fmt.Sprintf("tenant_id = $%d", argCount)) - args = append(args, val) + args = append(args, *req.TenantID) argCount++ } @@ -341,20 +333,26 @@ func (r *WebhooksRepository) Update(ctx context.Context, id uuid.UUID, req *mode return &webhook, nil } -// Delete removes a webhook. -func (r *WebhooksRepository) Delete(ctx context.Context, id uuid.UUID) error { - query := `DELETE FROM webhooks WHERE id = $1` +// Delete removes a webhook and returns the deleted tenant boundary for side effects. +func (r *WebhooksRepository) Delete(ctx context.Context, id uuid.UUID) (*models.DeletedWebhook, error) { + query := ` + DELETE FROM webhooks + WHERE id = $1 + RETURNING id, tenant_id + ` + + var webhook models.DeletedWebhook - result, err := r.db.Exec(ctx, query, id) + err := r.db.QueryRow(ctx, query, id).Scan(&webhook.ID, &webhook.TenantID) if err != nil { - return fmt.Errorf("failed to delete webhook: %w", err) - } + if errors.Is(err, pgx.ErrNoRows) { + return nil, huberrors.NewNotFoundError("webhook", "webhook not found") + } - if result.RowsAffected() == 0 { - return huberrors.NewNotFoundError("webhook", "webhook not found") + return nil, fmt.Errorf("failed to delete webhook: %w", err) } - return nil + return &webhook, nil } // parseDBEventTypes converts a DB string slice to []datatypes.EventType. Returns (nil, nil) for nil input. @@ -376,65 +374,43 @@ func parseDBEventTypes(ss []string) ([]datatypes.EventType, error) { return out, nil } -// ListEnabled retrieves all enabled webhooks (unbounded; used for delivery fan-out). -func (r *WebhooksRepository) ListEnabled(ctx context.Context) ([]models.Webhook, error) { - query := webhooksListSelect + ` WHERE enabled = true ORDER BY created_at DESC, id ASC` - - webhooks, err := r.fetchWebhooks(ctx, query) - if err != nil { - return nil, fmt.Errorf("list enabled webhooks: %w", err) - } - - return webhooks, nil -} - -// ListEnabledForEventType retrieves all enabled webhooks that should receive a specific event type. -// Order is deterministic (ORDER BY id) so delivery behavior is consistent. -func (r *WebhooksRepository) ListEnabledForEventType(ctx context.Context, eventType string) ([]models.Webhook, error) { - query := ` - SELECT id, url, signing_key, enabled, tenant_id, created_at, updated_at, event_types, disabled_reason, disabled_at - FROM webhooks +const listEnabledForEventTypeSelect = ` + SELECT id, url, signing_key, enabled, tenant_id, created_at, updated_at, event_types, disabled_reason, disabled_at + FROM webhooks WHERE enabled = true AND (event_types IS NULL OR event_types = '{}' OR event_types @> ARRAY[$1]::VARCHAR(64)[]) - ORDER BY id ` - rows, err := r.db.Query(ctx, query, eventType) +// ListEnabledForEventTypeAndTenant retrieves enabled webhooks for an event type and tenant boundary. +// Webhooks match only the same tenant. A missing tenantID matches nothing. +func (r *WebhooksRepository) ListEnabledForEventTypeAndTenant( + ctx context.Context, eventType string, tenantID *string, +) ([]models.Webhook, error) { + query, args := listEnabledForEventTypeAndTenantQuery(eventType, tenantID) + + webhooks, err := r.fetchWebhooks(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("failed to list enabled webhooks for event type: %w", err) + return nil, fmt.Errorf("list enabled webhooks for event type and tenant: %w", err) } - defer rows.Close() - - webhooks := []models.Webhook{} - for rows.Next() { - var ( - webhook models.Webhook - dbEventTypes []string - ) + return webhooks, nil +} - err := rows.Scan( - &webhook.ID, &webhook.URL, &webhook.SigningKey, &webhook.Enabled, - &webhook.TenantID, &webhook.CreatedAt, &webhook.UpdatedAt, &dbEventTypes, - &webhook.DisabledReason, &webhook.DisabledAt, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan webhook: %w", err) - } +func listEnabledForEventTypeAndTenantQuery(eventType string, tenantID *string) (string, []any) { + query := listEnabledForEventTypeSelect + args := []any{eventType} - webhook.EventTypes, err = parseDBEventTypes(dbEventTypes) - if err != nil { - return nil, err - } + if tenantID == nil { + query += ` AND FALSE` + } else { + query += ` AND tenant_id = $2` - webhooks = append(webhooks, webhook) + args = append(args, *tenantID) } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating webhooks: %w", err) - } + query += ` ORDER BY id` - return webhooks, nil + return query, args } // fetchWebhooks executes the given query and scans rows into Webhook slices. diff --git a/internal/repository/webhooks_repository_test.go b/internal/repository/webhooks_repository_test.go new file mode 100644 index 0000000..5ea7704 --- /dev/null +++ b/internal/repository/webhooks_repository_test.go @@ -0,0 +1,51 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/formbricks/hub/internal/datatypes" +) + +func TestListEnabledForEventTypeAndTenantQuery(t *testing.T) { + tenantID := "tenant-a" + + tests := []struct { + name string + tenantID *string + wantArgs []any + wantTenantClause string + rejectTenantClause string + }{ + { + name: "scoped event matches same tenant webhooks only", + tenantID: &tenantID, + wantArgs: []any{datatypes.FeedbackRecordCreated.String(), tenantID}, + wantTenantClause: "AND tenant_id = $2", + rejectTenantClause: "tenant_id IS NULL", + }, + { + name: "tenant-less event matches no webhooks", + tenantID: nil, + wantArgs: []any{datatypes.FeedbackRecordCreated.String()}, + wantTenantClause: "AND FALSE", + rejectTenantClause: "tenant_id = $2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query, args := listEnabledForEventTypeAndTenantQuery(datatypes.FeedbackRecordCreated.String(), tt.tenantID) + + require.Equal(t, tt.wantArgs, args) + assert.Contains(t, query, "WHERE enabled = true") + assert.Contains(t, query, "event_types IS NULL OR event_types = '{}' OR event_types @> ARRAY[$1]::VARCHAR(64)[]") + assert.Contains(t, query, tt.wantTenantClause) + assert.NotContains(t, query, tt.rejectTenantClause) + assert.True(t, strings.HasSuffix(strings.TrimSpace(query), "ORDER BY id")) + }) + } +} diff --git a/internal/service/feedback_records_service.go b/internal/service/feedback_records_service.go index 8b279e9..61b792e 100644 --- a/internal/service/feedback_records_service.go +++ b/internal/service/feedback_records_service.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" @@ -35,7 +36,7 @@ type FeedbackRecordsRepository interface { ) ([]models.FeedbackRecord, bool, error) Update(ctx context.Context, id uuid.UUID, req *models.UpdateFeedbackRecordRequest) (*models.FeedbackRecord, error) Delete(ctx context.Context, id uuid.UUID) error - BulkDelete(ctx context.Context, userID string, tenantID *string) ([]uuid.UUID, error) + BulkDelete(ctx context.Context, filters *models.BulkDeleteFilters) ([]models.DeletedFeedbackRecordsByTenant, error) } // EmbeddingsRepository defines the interface for embeddings table access. @@ -91,7 +92,15 @@ func (s *FeedbackRecordsService) SetEmbeddingInserter(inserter FeedbackEmbedding func (s *FeedbackRecordsService) CreateFeedbackRecord( ctx context.Context, req *models.CreateFeedbackRecordRequest, ) (*models.FeedbackRecord, error) { - record, err := s.repo.Create(ctx, req) + normalizedTenantID, err := normalizeRequiredTenantIDValue(req.TenantID) + if err != nil { + return nil, err + } + + normalizedReq := *req + normalizedReq.TenantID = normalizedTenantID + + record, err := s.repo.Create(ctx, &normalizedReq) if err != nil { return nil, fmt.Errorf("create feedback record: %w", err) } @@ -182,36 +191,72 @@ func (s *FeedbackRecordsService) UpdateFeedbackRecord( } // DeleteFeedbackRecord deletes a feedback record by ID. -// Publishes FeedbackRecordDeleted with data = [id] (array of deleted IDs) for consistency with bulk delete. +// Publishes FeedbackRecordDeleted with tenant-aware deleted IDs for webhook isolation. func (s *FeedbackRecordsService) DeleteFeedbackRecord(ctx context.Context, id uuid.UUID) error { + record, err := s.repo.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("get feedback record before delete: %w", err) + } + if err := s.repo.Delete(ctx, id); err != nil { return fmt.Errorf("delete feedback record: %w", err) } if s.publisher != nil { - s.publisher.PublishEvent(ctx, datatypes.FeedbackRecordDeleted, []uuid.UUID{id}) + s.publisher.PublishEvent(ctx, datatypes.FeedbackRecordDeleted, models.DeletedIDsEventData{ + TenantID: record.TenantID, + IDs: []uuid.UUID{id}, + }) } return nil } -// BulkDeleteFeedbackRecords deletes all feedback records matching user_id and optional tenant_id. -// Publishes a single FeedbackRecordDeleted event with data = [id1, id2, ...] (array of deleted IDs). -func (s *FeedbackRecordsService) BulkDeleteFeedbackRecords(ctx context.Context, userID string, tenantID *string) (int, error) { - if userID == "" { +// BulkDeleteFeedbackRecords deletes all feedback records matching user_id. +// When tenant_id is provided, deletion is restricted to that tenant; otherwise all user records are deleted. +// It publishes one tenant-aware FeedbackRecordDeleted event per tenant represented in the deleted rows. +func (s *FeedbackRecordsService) BulkDeleteFeedbackRecords(ctx context.Context, filters *models.BulkDeleteFilters) (int, error) { + if filters == nil || filters.UserID == "" { return 0, ErrUserIDRequired } - ids, err := s.repo.BulkDelete(ctx, userID, tenantID) + if filters.TenantID != nil { + normalizedTenantID, err := normalizeRequiredTenantID(filters.TenantID) + if err != nil { + return 0, err + } + + filters = &models.BulkDeleteFilters{ + UserID: filters.UserID, + TenantID: &normalizedTenantID, + } + } + + groups, err := s.repo.BulkDelete(ctx, filters) if err != nil { return 0, fmt.Errorf("bulk delete feedback records: %w", err) } - if len(ids) > 0 && s.publisher != nil { - s.publisher.PublishEvent(ctx, datatypes.FeedbackRecordDeleted, ids) + deletedCount := 0 + for _, group := range groups { + deletedCount += len(group.IDs) + + if len(group.IDs) == 0 || s.publisher == nil { + continue + } + + if group.TenantID == "" { + slog.Error("bulk delete feedback records: deleted rows missing tenant_id; skipping webhook event", + "deleted_count", len(group.IDs), + ) + + continue + } + + s.publisher.PublishEvent(ctx, datatypes.FeedbackRecordDeleted, models.DeletedIDsEventData(group)) } - return len(ids), nil + return deletedCount, nil } // SetEmbedding sets or clears the embedding for a feedback record and model (internal use by embeddings worker). diff --git a/internal/service/feedback_records_service_test.go b/internal/service/feedback_records_service_test.go new file mode 100644 index 0000000..13ad699 --- /dev/null +++ b/internal/service/feedback_records_service_test.go @@ -0,0 +1,292 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/formbricks/hub/internal/datatypes" + "github.com/formbricks/hub/internal/models" +) + +type mockFeedbackRecordsRepo struct { + record *models.FeedbackRecord + createReq *models.CreateFeedbackRecordRequest + bulkGroups []models.DeletedFeedbackRecordsByTenant + deletedID uuid.UUID + bulkDeleteFilters *models.BulkDeleteFilters +} + +func (m *mockFeedbackRecordsRepo) Create( + _ context.Context, req *models.CreateFeedbackRecordRequest, +) (*models.FeedbackRecord, error) { + reqCopy := *req + m.createReq = &reqCopy + + if m.record != nil { + return m.record, nil + } + + return &models.FeedbackRecord{TenantID: req.TenantID}, nil +} + +func (m *mockFeedbackRecordsRepo) GetByID(_ context.Context, _ uuid.UUID) (*models.FeedbackRecord, error) { + return m.record, nil +} + +func (m *mockFeedbackRecordsRepo) List( + _ context.Context, _ *models.ListFeedbackRecordsFilters, +) ([]models.FeedbackRecord, bool, error) { + return nil, false, errors.New("not implemented") +} + +func (m *mockFeedbackRecordsRepo) ListAfterCursor( + _ context.Context, _ *models.ListFeedbackRecordsFilters, _ time.Time, _ uuid.UUID, +) ([]models.FeedbackRecord, bool, error) { + return nil, false, errors.New("not implemented") +} + +func (m *mockFeedbackRecordsRepo) Update( + _ context.Context, _ uuid.UUID, _ *models.UpdateFeedbackRecordRequest, +) (*models.FeedbackRecord, error) { + return nil, errors.New("not implemented") +} + +func (m *mockFeedbackRecordsRepo) Delete(_ context.Context, id uuid.UUID) error { + m.deletedID = id + + return nil +} + +func (m *mockFeedbackRecordsRepo) BulkDelete( + _ context.Context, filters *models.BulkDeleteFilters, +) ([]models.DeletedFeedbackRecordsByTenant, error) { + m.bulkDeleteFilters = filters + + return m.bulkGroups, nil +} + +func TestFeedbackRecordsService_DeleteFeedbackRecord_PublishesTenantAwareDeletedEvent(t *testing.T) { + ctx := context.Background() + recordID := uuid.Must(uuid.NewV7()) + tenantID := "org-123" + repo := &mockFeedbackRecordsRepo{record: &models.FeedbackRecord{ID: recordID, TenantID: tenantID}} + publisher := &capturePublisher{} + svc := NewFeedbackRecordsService(repo, nil, "", publisher, nil, "", 0) + + err := svc.DeleteFeedbackRecord(ctx, recordID) + if err != nil { + t.Fatalf("DeleteFeedbackRecord() error = %v", err) + } + + if repo.deletedID != recordID { + t.Fatalf("deletedID = %v, want %v", repo.deletedID, recordID) + } + + assertDeletedEventData(t, publisher, datatypes.FeedbackRecordDeleted, tenantID, []uuid.UUID{recordID}) +} + +func TestFeedbackRecordsService_CreateFeedbackRecord_NormalizesTenantID(t *testing.T) { + ctx := context.Background() + inputTenantID := " org-123 " + repo := &mockFeedbackRecordsRepo{} + publisher := &capturePublisher{} + svc := NewFeedbackRecordsService(repo, nil, "", publisher, nil, "", 0) + + record, err := svc.CreateFeedbackRecord(ctx, &models.CreateFeedbackRecordRequest{ + SourceType: "formbricks", + FieldID: "field-1", + FieldType: models.FieldTypeText, + TenantID: inputTenantID, + SubmissionID: "submission-1", + }) + if err != nil { + t.Fatalf("CreateFeedbackRecord() error = %v", err) + } + + if repo.createReq == nil { + t.Fatal("repo Create request = nil") + } + + if repo.createReq.TenantID != "org-123" { + t.Fatalf("repo TenantID = %q, want org-123", repo.createReq.TenantID) + } + + if record.TenantID != "org-123" { + t.Fatalf("record TenantID = %q, want org-123", record.TenantID) + } + + if publisher.callCount != 1 || publisher.eventType != datatypes.FeedbackRecordCreated { + t.Fatalf("published event = (%d, %s), want one feedback_record.created", publisher.callCount, publisher.eventType) + } +} + +func TestFeedbackRecordsService_BulkDeleteFeedbackRecords_PublishesTenantAwareDeletedEventsByTenant(t *testing.T) { + ctx := context.Background() + tenantA := "org-123" + tenantB := "org-456" + tenantAIDs := []uuid.UUID{uuid.Must(uuid.NewV7()), uuid.Must(uuid.NewV7())} + tenantBIDs := []uuid.UUID{uuid.Must(uuid.NewV7())} + repo := &mockFeedbackRecordsRepo{ + bulkGroups: []models.DeletedFeedbackRecordsByTenant{ + {TenantID: tenantA, IDs: tenantAIDs}, + {TenantID: tenantB, IDs: tenantBIDs}, + }, + } + publisher := &capturePublisher{} + svc := NewFeedbackRecordsService(repo, nil, "", publisher, nil, "", 0) + + count, err := svc.BulkDeleteFeedbackRecords(ctx, &models.BulkDeleteFilters{UserID: "user-123"}) + if err != nil { + t.Fatalf("BulkDeleteFeedbackRecords() error = %v", err) + } + + if repo.bulkDeleteFilters == nil { + t.Fatal("repo BulkDelete filters = nil") + } + + if repo.bulkDeleteFilters.UserID != "user-123" { + t.Fatalf("repo UserID = %q, want user-123", repo.bulkDeleteFilters.UserID) + } + + if repo.bulkDeleteFilters.TenantID != nil { + t.Fatalf("repo TenantID = %q, want nil for all-tenant delete", *repo.bulkDeleteFilters.TenantID) + } + + if count != len(tenantAIDs)+len(tenantBIDs) { + t.Fatalf("count = %d, want %d", count, len(tenantAIDs)+len(tenantBIDs)) + } + + assertDeletedEventDataAt(t, publisher, 0, datatypes.FeedbackRecordDeleted, tenantA, tenantAIDs) + assertDeletedEventDataAt(t, publisher, 1, datatypes.FeedbackRecordDeleted, tenantB, tenantBIDs) +} + +func TestFeedbackRecordsService_BulkDeleteFeedbackRecords_NormalizesTenantFilter(t *testing.T) { + ctx := context.Background() + tenantID := " org-123 " + deletedID := uuid.Must(uuid.NewV7()) + repo := &mockFeedbackRecordsRepo{ + bulkGroups: []models.DeletedFeedbackRecordsByTenant{ + {TenantID: "org-123", IDs: []uuid.UUID{deletedID}}, + }, + } + publisher := &capturePublisher{} + svc := NewFeedbackRecordsService(repo, nil, "", publisher, nil, "", 0) + + count, err := svc.BulkDeleteFeedbackRecords(ctx, &models.BulkDeleteFilters{ + UserID: "user-123", + TenantID: &tenantID, + }) + if err != nil { + t.Fatalf("BulkDeleteFeedbackRecords() error = %v", err) + } + + if count != 1 { + t.Fatalf("count = %d, want 1", count) + } + + if repo.bulkDeleteFilters == nil || repo.bulkDeleteFilters.TenantID == nil { + t.Fatal("repo TenantID = nil, want normalized tenant") + } + + if *repo.bulkDeleteFilters.TenantID != "org-123" { + t.Fatalf("repo TenantID = %q, want org-123", *repo.bulkDeleteFilters.TenantID) + } + + assertDeletedEventData(t, publisher, datatypes.FeedbackRecordDeleted, "org-123", []uuid.UUID{deletedID}) +} + +func TestFeedbackRecordsService_BulkDeleteFeedbackRecords_RequiresUserID(t *testing.T) { + ctx := context.Background() + repo := &mockFeedbackRecordsRepo{ + bulkGroups: []models.DeletedFeedbackRecordsByTenant{ + {TenantID: "org-123", IDs: []uuid.UUID{uuid.Must(uuid.NewV7())}}, + }, + } + publisher := &capturePublisher{} + svc := NewFeedbackRecordsService(repo, nil, "", publisher, nil, "", 0) + + count, err := svc.BulkDeleteFeedbackRecords(ctx, &models.BulkDeleteFilters{}) + if !errors.Is(err, ErrUserIDRequired) { + t.Fatalf("BulkDeleteFeedbackRecords() error = %v, want ErrUserIDRequired", err) + } + + if count != 0 { + t.Fatalf("count = %d, want 0", count) + } + + if publisher.callCount != 0 { + t.Fatalf("published %d events, want 0", publisher.callCount) + } +} + +func assertDeletedEventDataAt( + t *testing.T, + publisher *capturePublisher, + index int, + eventType datatypes.EventType, + tenantID string, + ids []uuid.UUID, +) { + t.Helper() + + if publisher.callCount <= index { + t.Fatalf("published %d events, want event at index %d", publisher.callCount, index) + } + + event := publisher.events[index] + if event.eventType != eventType { + t.Fatalf("published event type = %s, want %s", event.eventType, eventType) + } + + data, ok := event.data.(models.DeletedIDsEventData) + if !ok { + t.Fatalf("published data type = %T, want DeletedIDsEventData", event.data) + } + + if data.TenantID != tenantID { + t.Errorf("TenantID = %q, want %q", data.TenantID, tenantID) + } + + if len(data.IDs) != len(ids) { + t.Fatalf("IDs length = %d, want %d", len(data.IDs), len(ids)) + } + + for i := range ids { + if data.IDs[i] != ids[i] { + t.Errorf("IDs[%d] = %v, want %v", i, data.IDs[i], ids[i]) + } + } +} + +func assertDeletedEventData( + t *testing.T, publisher *capturePublisher, eventType datatypes.EventType, tenantID string, ids []uuid.UUID, +) { + t.Helper() + + if publisher.callCount != 1 || publisher.eventType != eventType { + t.Fatalf("published event = (%d, %s), want one %s", publisher.callCount, publisher.eventType, eventType) + } + + data, ok := publisher.data.(models.DeletedIDsEventData) + if !ok { + t.Fatalf("published data type = %T, want DeletedIDsEventData", publisher.data) + } + + if data.TenantID != tenantID { + t.Errorf("TenantID = %q, want %q", data.TenantID, tenantID) + } + + if len(data.IDs) != len(ids) { + t.Fatalf("IDs length = %d, want %d", len(data.IDs), len(ids)) + } + + for i := range ids { + if data.IDs[i] != ids[i] { + t.Errorf("IDs[%d] = %v, want %v", i, data.IDs[i], ids[i]) + } + } +} diff --git a/internal/service/tenant_validation.go b/internal/service/tenant_validation.go new file mode 100644 index 0000000..0ecf632 --- /dev/null +++ b/internal/service/tenant_validation.go @@ -0,0 +1,24 @@ +package service + +import ( + "strings" + + "github.com/formbricks/hub/internal/huberrors" +) + +func normalizeRequiredTenantID(tenantID *string) (string, error) { + if tenantID == nil { + return "", huberrors.NewValidationError("tenant_id", "tenant_id is required") + } + + return normalizeRequiredTenantIDValue(*tenantID) +} + +func normalizeRequiredTenantIDValue(tenantID string) (string, error) { + normalized := strings.TrimSpace(tenantID) + if normalized == "" { + return "", huberrors.NewValidationError("tenant_id", "tenant_id is required and cannot be empty") + } + + return normalized, nil +} diff --git a/internal/service/webhook_dispatch_args.go b/internal/service/webhook_dispatch_args.go index 7c530d5..8d7715c 100644 --- a/internal/service/webhook_dispatch_args.go +++ b/internal/service/webhook_dispatch_args.go @@ -19,6 +19,7 @@ type WebhookDispatchArgs struct { Timestamp time.Time `json:"timestamp"` Data any `json:"data"` ChangedFields []string `json:"changed_fields,omitempty"` + TenantID *string `json:"tenant_id,omitempty"` WebhookID uuid.UUID `json:"webhook_id" river:"unique"` } diff --git a/internal/service/webhook_payload.go b/internal/service/webhook_payload.go index 4526d0c..d070ae2 100644 --- a/internal/service/webhook_payload.go +++ b/internal/service/webhook_payload.go @@ -1,9 +1,13 @@ package service import ( + "encoding/json" "time" "github.com/google/uuid" + + "github.com/formbricks/hub/internal/datatypes" + "github.com/formbricks/hub/internal/models" ) // WebhookPayload represents a generic webhook payload structure for all event types. @@ -12,6 +16,161 @@ type WebhookPayload struct { ID uuid.UUID `json:"id"` // Unique event id (UUID v7) Type string `json:"type"` // Event type as string (e.g., "feedback_record.created", "webhook.created") Timestamp time.Time `json:"timestamp"` // Event creation timestamp + TenantID *string `json:"tenant_id,omitempty"` // Tenant boundary for the event Data any `json:"data"` // Event data (FeedbackRecord, Webhook, etc.) ChangedFields []string `json:"changed_fields,omitempty"` // Only for update events (optional) } + +// NewWebhookPayload builds the public webhook payload from internal dispatch args. +func NewWebhookPayload(args WebhookDispatchArgs) *WebhookPayload { + tenantID := clonePayloadTenantID(args.TenantID) + if tenantID == nil { + tenantID = TenantIDPointerFromEventData(args.Data) + } + + return &WebhookPayload{ + ID: args.EventID, + Type: args.EventType, + Timestamp: args.Timestamp, + TenantID: tenantID, + Data: publicWebhookData(args.EventType, args.Data), + ChangedFields: args.ChangedFields, + } +} + +func publicWebhookData(eventType string, data any) any { + if !isDeletedIDsEvent(eventType) { + return data + } + + ids, ok := deletedIDsFromEventData(data) + if !ok { + return data + } + + return ids +} + +func isDeletedIDsEvent(eventType string) bool { + return eventType == datatypes.FeedbackRecordDeleted.String() || + eventType == datatypes.WebhookDeleted.String() +} + +func deletedIDsFromEventData(data any) ([]uuid.UUID, bool) { + switch payload := data.(type) { + case models.DeletedIDsEventData: + return cloneUUIDs(payload.IDs), true + case *models.DeletedIDsEventData: + if payload == nil { + return nil, true + } + + return cloneUUIDs(payload.IDs), true + case map[string]any: + return deletedIDsFromValue(payload["ids"]) + case map[string][]uuid.UUID: + return cloneUUIDs(payload["ids"]), true + case map[string][]string: + return deletedIDsFromStrings(payload["ids"]) + case []uuid.UUID: + return cloneUUIDs(payload), true + case []string: + return deletedIDsFromStrings(payload) + case []any: + return deletedIDsFromValues(payload) + case json.RawMessage: + return deletedIDsFromRawJSON(payload) + default: + return deletedIDsFromJSON(data) + } +} + +func clonePayloadTenantID(tenantID *string) *string { + if tenantID == nil { + return nil + } + + return stringPointer(*tenantID) +} + +func cloneUUIDs(ids []uuid.UUID) []uuid.UUID { + if ids == nil { + return nil + } + + return append([]uuid.UUID(nil), ids...) +} + +func deletedIDsFromValue(value any) ([]uuid.UUID, bool) { + switch ids := value.(type) { + case []uuid.UUID: + return cloneUUIDs(ids), true + case []string: + return deletedIDsFromStrings(ids) + case []any: + return deletedIDsFromValues(ids) + default: + return nil, false + } +} + +func deletedIDsFromStrings(values []string) ([]uuid.UUID, bool) { + ids := make([]uuid.UUID, 0, len(values)) + for _, value := range values { + id, err := uuid.Parse(value) + if err != nil { + return nil, false + } + + ids = append(ids, id) + } + + return ids, true +} + +func deletedIDsFromValues(values []any) ([]uuid.UUID, bool) { + ids := make([]uuid.UUID, 0, len(values)) + for _, value := range values { + switch id := value.(type) { + case uuid.UUID: + ids = append(ids, id) + case string: + parsed, err := uuid.Parse(id) + if err != nil { + return nil, false + } + + ids = append(ids, parsed) + default: + return nil, false + } + } + + return ids, true +} + +func deletedIDsFromJSON(data any) ([]uuid.UUID, bool) { + payload, err := json.Marshal(data) + if err != nil { + return nil, false + } + + return deletedIDsFromRawJSON(payload) +} + +func deletedIDsFromRawJSON(payload []byte) ([]uuid.UUID, bool) { + var envelope struct { + IDs []uuid.UUID `json:"ids"` + } + + if err := json.Unmarshal(payload, &envelope); err == nil && envelope.IDs != nil { + return cloneUUIDs(envelope.IDs), true + } + + var ids []uuid.UUID + if err := json.Unmarshal(payload, &ids); err != nil { + return nil, false + } + + return cloneUUIDs(ids), true +} diff --git a/internal/service/webhook_payload_test.go b/internal/service/webhook_payload_test.go new file mode 100644 index 0000000..3c5db91 --- /dev/null +++ b/internal/service/webhook_payload_test.go @@ -0,0 +1,132 @@ +package service + +import ( + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/formbricks/hub/internal/models" +) + +func TestNewWebhookPayload_MapsDeletedIDsEventDataToPublicPayload(t *testing.T) { + tenantID := "org-123" + ids := []uuid.UUID{uuid.Must(uuid.NewV7()), uuid.Must(uuid.NewV7())} + args := WebhookDispatchArgs{ + EventID: uuid.Must(uuid.NewV7()), + EventType: "feedback_record.deleted", + Timestamp: time.Now(), + TenantID: &tenantID, + Data: models.DeletedIDsEventData{TenantID: tenantID, IDs: ids}, + WebhookID: uuid.Must(uuid.NewV7()), + } + + payload := NewWebhookPayload(args) + + if payload.TenantID == nil || *payload.TenantID != tenantID { + t.Fatalf("TenantID = %v, want %q", payload.TenantID, tenantID) + } + + gotIDs, ok := payload.Data.([]uuid.UUID) + if !ok { + t.Fatalf("Data type = %T, want []uuid.UUID", payload.Data) + } + + if len(gotIDs) != len(ids) { + t.Fatalf("Data length = %d, want %d", len(gotIDs), len(ids)) + } + + for i := range ids { + if gotIDs[i] != ids[i] { + t.Errorf("Data[%d] = %v, want %v", i, gotIDs[i], ids[i]) + } + } + + ids[0] = uuid.Must(uuid.NewV7()) + if gotIDs[0] == ids[0] { + t.Error("Data aliases internal deleted ID slice") + } +} + +func TestNewWebhookPayload_MapsJSONRoundTrippedDeletedIDsEventDataToPublicPayload(t *testing.T) { + tenantID := "org-123" + ids := []uuid.UUID{uuid.Must(uuid.NewV7()), uuid.Must(uuid.NewV7())} + args := WebhookDispatchArgs{ + EventID: uuid.Must(uuid.NewV7()), + EventType: "feedback_record.deleted", + Timestamp: time.Now(), + Data: map[string]any{ + "tenant_id": tenantID, + "ids": []any{ids[0].String(), ids[1].String()}, + }, + WebhookID: uuid.Must(uuid.NewV7()), + } + + payload := NewWebhookPayload(args) + + if payload.TenantID == nil || *payload.TenantID != tenantID { + t.Fatalf("TenantID = %v, want %q", payload.TenantID, tenantID) + } + + assertWebhookPayloadIDs(t, payload.Data, ids) +} + +func TestNewWebhookPayload_MapsRawJSONDeletedIDsEventDataToPublicPayload(t *testing.T) { + tenantID := "org-123" + ids := []uuid.UUID{uuid.Must(uuid.NewV7()), uuid.Must(uuid.NewV7())} + args := WebhookDispatchArgs{ + EventID: uuid.Must(uuid.NewV7()), + EventType: "webhook.deleted", + Timestamp: time.Now(), + Data: json.RawMessage(`{ + "tenant_id": "` + tenantID + `", + "ids": ["` + ids[0].String() + `", "` + ids[1].String() + `"] + }`), + WebhookID: uuid.Must(uuid.NewV7()), + } + + payload := NewWebhookPayload(args) + + if payload.TenantID == nil || *payload.TenantID != tenantID { + t.Fatalf("TenantID = %v, want %q", payload.TenantID, tenantID) + } + + assertWebhookPayloadIDs(t, payload.Data, ids) +} + +func TestNewWebhookPayload_DerivesTenantFromLegacyArgsData(t *testing.T) { + tenantID := "org-123" + args := WebhookDispatchArgs{ + EventID: uuid.Must(uuid.NewV7()), + EventType: "feedback_record.created", + Timestamp: time.Now(), + Data: map[string]any{"tenant_id": tenantID}, + WebhookID: uuid.Must(uuid.NewV7()), + } + + payload := NewWebhookPayload(args) + + if payload.TenantID == nil || *payload.TenantID != tenantID { + t.Fatalf("TenantID = %v, want %q", payload.TenantID, tenantID) + } +} + +func assertWebhookPayloadIDs(t *testing.T, data any, want []uuid.UUID) { + t.Helper() + + got, ok := data.([]uuid.UUID) + if !ok { + t.Fatalf("Data type = %T, want []uuid.UUID", data) + } + + if len(got) != len(want) { + t.Fatalf("Data length = %d, want %d", len(got), len(want)) + } + + for i := range want { + if got[i] != want[i] { + t.Errorf("Data[%d] = %v, want %v", i, got[i], want[i]) + } + } +} diff --git a/internal/service/webhook_provider.go b/internal/service/webhook_provider.go index 3032771..8e40167 100644 --- a/internal/service/webhook_provider.go +++ b/internal/service/webhook_provider.go @@ -10,6 +10,7 @@ import ( "github.com/riverqueue/river" "github.com/riverqueue/river/rivertype" + "github.com/formbricks/hub/internal/models" "github.com/formbricks/hub/internal/observability" ) @@ -18,9 +19,14 @@ type WebhookDispatchInserter interface { InsertMany(ctx context.Context, params []river.InsertManyParams) ([]*rivertype.JobInsertResult, error) } +// WebhookProviderRepository lists tenant-scoped webhooks eligible for event fan-out. +type WebhookProviderRepository interface { + ListEnabledForEventTypeAndTenant(ctx context.Context, eventType string, tenantID *string) ([]models.Webhook, error) +} + // WebhookProvider implements eventPublisher by enqueueing one River job per (event, webhook). type WebhookProvider struct { - repo WebhooksRepository + repo WebhookProviderRepository inserter WebhookDispatchInserter maxAttempts int maxFanOut int @@ -35,7 +41,7 @@ type WebhookProvider struct { // enqueueMaxRetries, enqueueInitialBackoff, enqueueMaxBackoff configure retries when InsertMany fails (transient River/DB errors). // metrics may be nil when metrics are disabled. func NewWebhookProvider( - inserter WebhookDispatchInserter, repo WebhooksRepository, + inserter WebhookDispatchInserter, repo WebhookProviderRepository, maxAttempts, maxFanOut int, enqueueMaxRetries int, enqueueInitialBackoff, enqueueMaxBackoff time.Duration, metrics observability.WebhookMetrics, @@ -52,10 +58,26 @@ func NewWebhookProvider( } } -// PublishEvent lists all enabled webhooks for the event type and enqueues one job per webhook, -// in batches of maxFanOut to avoid oversized InsertMany calls. +// PublishEvent lists enabled webhooks for the event type and tenant, then enqueues one job per webhook. +// Webhooks are only eligible when the event payload has the same tenant_id. func (p *WebhookProvider) PublishEvent(ctx context.Context, event Event) { - webhooks, err := p.repo.ListEnabledForEventType(ctx, event.Type.String()) + tenantID := TenantIDPointerFromEventData(event.Data) + if tenantID == nil { + if p.metrics != nil { + p.metrics.RecordProviderError(ctx, "missing_tenant_id") + } + + slog.Warn("webhook provider: event has no tenant_id; skipping webhook fan-out", + "event_id", event.ID, + "event_type", event.Type, + ) + + return + } + + tenantIDValue := *tenantID + + webhooks, err := p.repo.ListEnabledForEventTypeAndTenant(ctx, event.Type.String(), tenantID) if err != nil { if p.metrics != nil { p.metrics.RecordProviderError(ctx, "list_failed") @@ -64,12 +86,23 @@ func (p *WebhookProvider) PublishEvent(ctx context.Context, event Event) { slog.Error("failed to list enabled webhooks for event type", "event_id", event.ID, "event_type", event.Type, + "tenant_id", tenantIDValue, "error", err, ) return } + webhooks, skipped := filterWebhooksByTenant(webhooks, tenantID) + if skipped > 0 { + slog.Warn("webhook provider: skipped tenant-mismatched webhooks returned by repository", + "event_id", event.ID, + "event_type", event.Type, + "tenant_id", tenantIDValue, + "skipped", skipped, + ) + } + if len(webhooks) == 0 { return } @@ -83,7 +116,7 @@ func (p *WebhookProvider) PublishEvent(ctx context.Context, event Event) { ByPeriod: uniqueByPeriodHours * time.Hour, }, } - baseArgs := p.eventToArgs(event) + baseArgs := p.eventToArgs(event, tenantID) var enqueued int64 @@ -129,6 +162,7 @@ func (p *WebhookProvider) PublishEvent(ctx context.Context, event Event) { slog.Error("failed to enqueue webhook jobs after retries", "event_id", event.ID, "event_type", event.Type, + "tenant_id", tenantIDValue, "error", insertErr, ) @@ -160,14 +194,27 @@ func (p *WebhookProvider) enqueueBackoffWithJitter(attempt int) time.Duration { return exp + jitter } +func filterWebhooksByTenant(webhooks []models.Webhook, tenantID *string) ([]models.Webhook, int) { + filtered := make([]models.Webhook, 0, len(webhooks)) + + for i := range webhooks { + if WebhookMatchesTenant(&webhooks[i], tenantID) { + filtered = append(filtered, webhooks[i]) + } + } + + return filtered, len(webhooks) - len(filtered) +} + // eventToArgs converts an Event to WebhookDispatchArgs (WebhookID must be set per webhook). -func (p *WebhookProvider) eventToArgs(event Event) WebhookDispatchArgs { +func (p *WebhookProvider) eventToArgs(event Event, tenantID *string) WebhookDispatchArgs { return WebhookDispatchArgs{ EventID: event.ID, EventType: event.Type.String(), Timestamp: event.Timestamp, Data: event.Data, ChangedFields: event.ChangedFields, + TenantID: tenantID, WebhookID: uuid.Nil, // set per webhook in PublishEvent } } diff --git a/internal/service/webhook_provider_test.go b/internal/service/webhook_provider_test.go index e536e2f..73a2786 100644 --- a/internal/service/webhook_provider_test.go +++ b/internal/service/webhook_provider_test.go @@ -37,13 +37,22 @@ func (m *mockWebhookInserter) InsertMany(_ context.Context, params []river.Inser return results, nil } -// mockProviderRepo implements only ListEnabledForEventType for provider tests. +// mockProviderRepo implements webhook listing for provider tests. type mockProviderRepo struct { - webhooks []models.Webhook - err error + webhooks []models.Webhook + err error + eventType string + tenantID *string + listCallCount int } -func (m *mockProviderRepo) ListEnabledForEventType(_ context.Context, _ string) ([]models.Webhook, error) { +func (m *mockProviderRepo) ListEnabledForEventTypeAndTenant( + _ context.Context, eventType string, tenantID *string, +) ([]models.Webhook, error) { + m.eventType = eventType + m.tenantID = cloneStringPointer(tenantID) + m.listCallCount++ + if m.err != nil { return nil, m.err } @@ -51,52 +60,18 @@ func (m *mockProviderRepo) ListEnabledForEventType(_ context.Context, _ string) return m.webhooks, nil } -// Stub other WebhooksRepository methods so mockProviderRepo can be used as WebhooksRepository. -func (m *mockProviderRepo) Create(_ context.Context, _ *models.CreateWebhookRequest) (*models.Webhook, error) { - return nil, errors.New("not implemented") -} - -func (m *mockProviderRepo) GetByID(_ context.Context, _ uuid.UUID) (*models.Webhook, error) { - return nil, errors.New("not implemented") -} - -func (m *mockProviderRepo) List(_ context.Context, _ *models.ListWebhooksFilters) ([]models.Webhook, bool, error) { - return nil, false, errors.New("not implemented") -} - -func (m *mockProviderRepo) ListAfterCursor( - _ context.Context, _ *models.ListWebhooksFilters, _ time.Time, _ uuid.UUID, -) ([]models.Webhook, bool, error) { - return nil, false, errors.New("not implemented") -} - -func (m *mockProviderRepo) Count(_ context.Context, _ *models.ListWebhooksFilters) (int64, error) { - return 0, errors.New("not implemented") -} - -func (m *mockProviderRepo) Update(_ context.Context, _ uuid.UUID, _ *models.UpdateWebhookRequest) (*models.Webhook, error) { - return nil, errors.New("not implemented") -} - -func (m *mockProviderRepo) Delete(_ context.Context, _ uuid.UUID) error { - return errors.New("not implemented") -} - -func (m *mockProviderRepo) ListEnabled(_ context.Context) ([]models.Webhook, error) { - return nil, errors.New("not implemented") -} - func TestWebhookProvider_PublishEvent(t *testing.T) { ctx := context.Background() eventID := uuid.Must(uuid.NewV7()) eventType := datatypes.FeedbackRecordCreated wh1 := uuid.Must(uuid.NewV7()) wh2 := uuid.Must(uuid.NewV7()) + tenantID := "org-123" t.Run("inserts one job per webhook via InsertMany with correct opts", func(t *testing.T) { inserter := &mockWebhookInserter{} repo := &mockProviderRepo{ - webhooks: []models.Webhook{{ID: wh1}, {ID: wh2}}, + webhooks: []models.Webhook{{ID: wh1, TenantID: &tenantID}, {ID: wh2, TenantID: &tenantID}}, } provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) @@ -104,11 +79,19 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { ID: eventID, Type: eventType, Timestamp: time.Now(), - Data: map[string]string{"id": "123"}, + Data: map[string]string{"id": "123", "tenant_id": tenantID}, } provider.PublishEvent(ctx, event) + if repo.eventType != eventType.String() { + t.Errorf("eventType = %q, want %q", repo.eventType, eventType.String()) + } + + if repo.tenantID == nil || *repo.tenantID != tenantID { + t.Errorf("tenantID = %v, want %q", repo.tenantID, tenantID) + } + if n := len(inserter.insertManyCalls); n != 1 { t.Fatalf("InsertMany called %d times, want 1", n) } @@ -137,6 +120,10 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { t.Errorf("param %d WebhookID = %v, want %v", i, args.WebhookID, wantID) } + if args.TenantID == nil || *args.TenantID != tenantID { + t.Errorf("param %d TenantID = %v, want %q", i, args.TenantID, tenantID) + } + if p.InsertOpts == nil || p.InsertOpts.MaxAttempts != 3 { t.Errorf("param %d MaxAttempts = %v, want 3", i, p.InsertOpts) } @@ -151,7 +138,7 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { inserter := &mockWebhookInserter{} repo := &mockProviderRepo{webhooks: nil} provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) - event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: nil} + event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: map[string]string{"tenant_id": tenantID}} provider.PublishEvent(ctx, event) if len(inserter.insertManyCalls) != 0 { @@ -163,7 +150,7 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { inserter := &mockWebhookInserter{} repo := &mockProviderRepo{err: errors.New("db error")} provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) - event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: nil} + event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: map[string]string{"tenant_id": tenantID}} provider.PublishEvent(ctx, event) if len(inserter.insertManyCalls) != 0 { @@ -173,9 +160,11 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { t.Run("when InsertMany returns error, provider logs and returns", func(t *testing.T) { inserter := &mockWebhookInserter{insertManyErr: errors.New("river error")} - repo := &mockProviderRepo{webhooks: []models.Webhook{{ID: wh1}, {ID: wh2}}} + repo := &mockProviderRepo{ + webhooks: []models.Webhook{{ID: wh1, TenantID: &tenantID}, {ID: wh2, TenantID: &tenantID}}, + } provider := NewWebhookProvider(inserter, repo, 5, 500, 0, 0, 0, nil) - event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: nil} + event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: map[string]string{"tenant_id": tenantID}} provider.PublishEvent(ctx, event) // InsertMany was still called once (batch fails as a whole). if len(inserter.insertManyCalls) != 1 { @@ -196,12 +185,12 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { webhooks := make([]models.Webhook, 501) for i := range webhooks { - webhooks[i] = models.Webhook{ID: uuid.Must(uuid.NewV7())} + webhooks[i] = models.Webhook{ID: uuid.Must(uuid.NewV7()), TenantID: &tenantID} } repo := &mockProviderRepo{webhooks: webhooks} provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) - event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: nil} + event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: map[string]string{"tenant_id": tenantID}} provider.PublishEvent(ctx, event) if len(inserter.insertManyCalls) != 2 { @@ -216,4 +205,90 @@ func TestWebhookProvider_PublishEvent(t *testing.T) { t.Errorf("second batch params length = %d, want 1", len(inserter.insertManyCalls[1])) } }) + + t.Run("filters tenant-scoped webhooks before enqueue", func(t *testing.T) { + tenantA := "org-123" + tenantB := "org-other" + inserter := &mockWebhookInserter{} + repo := &mockProviderRepo{ + webhooks: []models.Webhook{ + {ID: wh1}, + {ID: wh2, TenantID: &tenantA}, + {ID: uuid.Must(uuid.NewV7()), TenantID: &tenantB}, + }, + } + provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) + event := Event{ + ID: eventID, + Type: eventType, + Timestamp: time.Now(), + Data: &models.FeedbackRecord{TenantID: tenantA}, + } + + provider.PublishEvent(ctx, event) + + if repo.tenantID == nil || *repo.tenantID != tenantA { + t.Fatalf("tenantID = %v, want %q", repo.tenantID, tenantA) + } + + if len(inserter.insertManyCalls) != 1 { + t.Fatalf("InsertMany called %d times, want 1", len(inserter.insertManyCalls)) + } + + params := inserter.insertManyCalls[0] + if len(params) != 1 { + t.Fatalf("InsertMany params length = %d, want 1", len(params)) + } + + gotWebhookIDs := map[uuid.UUID]bool{} + + for _, p := range params { + args, ok := p.Args.(WebhookDispatchArgs) + if !ok { + t.Fatalf("Args type = %T, want WebhookDispatchArgs", p.Args) + } + + if args.TenantID == nil || *args.TenantID != tenantA { + t.Errorf("TenantID = %v, want %q", args.TenantID, tenantA) + } + + gotWebhookIDs[args.WebhookID] = true + } + + if !gotWebhookIDs[wh2] || gotWebhookIDs[wh1] { + t.Errorf("enqueued webhook IDs = %v, want only matching tenant webhook", gotWebhookIDs) + } + }) + + t.Run("tenant-less events do not query or enqueue webhooks", func(t *testing.T) { + inserter := &mockWebhookInserter{} + repo := &mockProviderRepo{ + webhooks: []models.Webhook{ + {ID: wh1}, + {ID: wh2, TenantID: &tenantID}, + }, + } + provider := NewWebhookProvider(inserter, repo, 3, 500, 0, 0, 0, nil) + event := Event{ID: eventID, Type: eventType, Timestamp: time.Now(), Data: []uuid.UUID{uuid.Must(uuid.NewV7())}} + + provider.PublishEvent(ctx, event) + + if repo.listCallCount != 0 { + t.Fatalf("ListEnabledForEventTypeAndTenant called %d times, want 0", repo.listCallCount) + } + + if len(inserter.insertManyCalls) != 0 { + t.Fatalf("InsertMany called %d times, want 0", len(inserter.insertManyCalls)) + } + }) +} + +func cloneStringPointer(value *string) *string { + if value == nil { + return nil + } + + v := *value + + return &v } diff --git a/internal/service/webhook_sender.go b/internal/service/webhook_sender.go index ef5ef6d..66224f3 100644 --- a/internal/service/webhook_sender.go +++ b/internal/service/webhook_sender.go @@ -12,6 +12,7 @@ import ( "strconv" "time" + "github.com/google/uuid" standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "github.com/formbricks/hub/internal/models" @@ -29,9 +30,14 @@ type WebhookSender interface { Send(ctx context.Context, webhook *models.Webhook, payload *WebhookPayload) error } +// WebhookSenderRepository persists webhook state changes caused by delivery. +type WebhookSenderRepository interface { + Update(ctx context.Context, id uuid.UUID, req *models.UpdateWebhookRequest) (*models.Webhook, error) +} + // WebhookSenderImpl implements WebhookSender with Standard Webhooks conformance. type WebhookSenderImpl struct { - repo WebhooksRepository + repo WebhookSenderRepository httpClient *http.Client metrics observability.WebhookMetrics urlHostBlacklist map[string]struct{} @@ -44,7 +50,7 @@ type WebhookSenderImpl struct { // metrics may be nil when metrics are disabled. // If httpClient is non-nil, it is used as-is (e.g. for tests that hit loopback); otherwise a secured client is built. func NewWebhookSenderImpl( - repo WebhooksRepository, metrics observability.WebhookMetrics, urlHostBlacklist map[string]struct{}, + repo WebhookSenderRepository, metrics observability.WebhookMetrics, urlHostBlacklist map[string]struct{}, httpTimeout time.Duration, httpClient *http.Client, ) *WebhookSenderImpl { if httpClient == nil { diff --git a/internal/service/webhook_sender_test.go b/internal/service/webhook_sender_test.go index e1f6ea1..1e57613 100644 --- a/internal/service/webhook_sender_test.go +++ b/internal/service/webhook_sender_test.go @@ -24,40 +24,6 @@ func (m *mockSenderRepo) Update(_ context.Context, _ uuid.UUID, _ *models.Update return nil, m.updateErr } -func (m *mockSenderRepo) Create(_ context.Context, _ *models.CreateWebhookRequest) (*models.Webhook, error) { - return nil, nil -} - -func (m *mockSenderRepo) GetByID(_ context.Context, _ uuid.UUID) (*models.Webhook, error) { - return nil, nil -} - -func (m *mockSenderRepo) List(_ context.Context, _ *models.ListWebhooksFilters) ([]models.Webhook, bool, error) { - return nil, false, nil -} - -func (m *mockSenderRepo) ListAfterCursor( - _ context.Context, _ *models.ListWebhooksFilters, _ time.Time, _ uuid.UUID, -) ([]models.Webhook, bool, error) { - return nil, false, nil -} - -func (m *mockSenderRepo) Count(_ context.Context, _ *models.ListWebhooksFilters) (int64, error) { - return 0, nil -} - -func (m *mockSenderRepo) Delete(_ context.Context, _ uuid.UUID) error { - return nil -} - -func (m *mockSenderRepo) ListEnabled(_ context.Context) ([]models.Webhook, error) { - return nil, nil -} - -func (m *mockSenderRepo) ListEnabledForEventType(_ context.Context, _ string) ([]models.Webhook, error) { - return nil, nil -} - func TestWebhookSenderImpl_Send(t *testing.T) { ctx := context.Background() webhookID := uuid.Must(uuid.NewV7()) diff --git a/internal/service/webhook_tenant.go b/internal/service/webhook_tenant.go new file mode 100644 index 0000000..00b88b9 --- /dev/null +++ b/internal/service/webhook_tenant.go @@ -0,0 +1,128 @@ +package service + +import ( + "encoding/json" + + "github.com/formbricks/hub/internal/models" +) + +// TenantIDFromEventData extracts tenant_id from known event payload shapes. +func TenantIDFromEventData(data any) (string, bool) { + switch payload := data.(type) { + case *models.FeedbackRecord: + if payload == nil { + return "", false + } + + return tenantIDFromString(payload.TenantID) + case models.FeedbackRecord: + return tenantIDFromString(payload.TenantID) + case *models.Webhook: + if payload == nil { + return "", false + } + + return tenantIDFromPointer(payload.TenantID) + case models.Webhook: + return tenantIDFromPointer(payload.TenantID) + case *models.WebhookPublic: + if payload == nil { + return "", false + } + + return tenantIDFromPointer(payload.TenantID) + case models.WebhookPublic: + return tenantIDFromPointer(payload.TenantID) + case *models.DeletedIDsEventData: + if payload == nil { + return "", false + } + + return tenantIDFromString(payload.TenantID) + case models.DeletedIDsEventData: + return tenantIDFromString(payload.TenantID) + case map[string]any: + return tenantIDFromMapValue(payload["tenant_id"]) + case map[string]string: + return tenantIDFromString(payload["tenant_id"]) + case json.RawMessage: + return tenantIDFromRawJSON(payload) + } + + return tenantIDFromJSON(data) +} + +// TenantIDPointerFromEventData returns a detached pointer so it can be safely stored in job args. +func TenantIDPointerFromEventData(data any) *string { + tenantID, ok := TenantIDFromEventData(data) + if !ok { + return nil + } + + return stringPointer(tenantID) +} + +// WebhookMatchesTenant reports whether a webhook may receive an event with tenantID. +func WebhookMatchesTenant(webhook *models.Webhook, tenantID *string) bool { + if webhook == nil { + return false + } + + if webhook.TenantID == nil || tenantID == nil { + return false + } + + return *webhook.TenantID == *tenantID +} + +func tenantIDFromMapValue(value any) (string, bool) { + tenantID, ok := value.(string) + if !ok { + return "", false + } + + return tenantIDFromString(tenantID) +} + +func tenantIDFromPointer(tenantID *string) (string, bool) { + if tenantID == nil { + return "", false + } + + return tenantIDFromString(*tenantID) +} + +func tenantIDFromString(tenantID string) (string, bool) { + if tenantID == "" { + return "", false + } + + return tenantID, true +} + +func tenantIDFromJSON(data any) (string, bool) { + payload, err := json.Marshal(data) + if err != nil { + return "", false + } + + return tenantIDFromRawJSON(payload) +} + +func tenantIDFromRawJSON(payload []byte) (string, bool) { + var envelope struct { + TenantID *string `json:"tenant_id"` + } + + if err := json.Unmarshal(payload, &envelope); err != nil { + return "", false + } + + return tenantIDFromPointer(envelope.TenantID) +} + +func stringPointer(value string) *string { + v := value + + return &v +} diff --git a/internal/service/webhook_tenant_test.go b/internal/service/webhook_tenant_test.go new file mode 100644 index 0000000..b5ef8d0 --- /dev/null +++ b/internal/service/webhook_tenant_test.go @@ -0,0 +1,121 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/formbricks/hub/internal/models" +) + +func TestTenantIDFromEventData(t *testing.T) { + tenantID := "org-123" + + tests := []struct { + name string + data any + want string + ok bool + }{ + { + name: "feedback record pointer", + data: &models.FeedbackRecord{TenantID: tenantID}, + want: tenantID, + ok: true, + }, + { + name: "feedback record value", + data: models.FeedbackRecord{TenantID: tenantID}, + want: tenantID, + ok: true, + }, + { + name: "webhook pointer", + data: &models.Webhook{TenantID: &tenantID}, + want: tenantID, + ok: true, + }, + { + name: "deleted IDs event data", + data: models.DeletedIDsEventData{TenantID: tenantID}, + want: tenantID, + ok: true, + }, + { + name: "map any", + data: map[string]any{"tenant_id": tenantID}, + want: tenantID, + ok: true, + }, + { + name: "map string", + data: map[string]string{"tenant_id": tenantID}, + want: tenantID, + ok: true, + }, + { + name: "raw json", + data: json.RawMessage(`{"tenant_id":"org-123"}`), + want: tenantID, + ok: true, + }, + { + name: "struct with json tag fallback", + data: struct { + TenantID string `json:"tenant_id"` + }{TenantID: tenantID}, + want: tenantID, + ok: true, + }, + { + name: "tenant-less data", + data: []string{"record-id"}, + ok: false, + }, + { + name: "empty tenant", + data: map[string]any{"tenant_id": ""}, + ok: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := TenantIDFromEventData(tt.data) + if ok != tt.ok { + t.Fatalf("ok = %v, want %v", ok, tt.ok) + } + + if got != tt.want { + t.Errorf("tenantID = %q, want %q", got, tt.want) + } + }) + } +} + +func TestWebhookMatchesTenant(t *testing.T) { + tenantID := "org-123" + otherTenantID := "org-other" + + tests := []struct { + name string + webhook *models.Webhook + tenantID *string + want bool + }{ + {name: "webhook without tenant rejects tenant event", webhook: &models.Webhook{}, tenantID: &tenantID, want: false}, + {name: "webhook without tenant rejects tenant-less event", webhook: &models.Webhook{}, tenantID: nil, want: false}, + {name: "scoped webhook matches same tenant", webhook: &models.Webhook{TenantID: &tenantID}, tenantID: &tenantID, want: true}, + {name: "scoped webhook rejects different tenant", webhook: &models.Webhook{TenantID: &tenantID}, tenantID: &otherTenantID, want: false}, + {name: "scoped webhook rejects tenant-less event", webhook: &models.Webhook{TenantID: &tenantID}, tenantID: nil, want: false}, + {name: "nil webhook rejects", webhook: nil, tenantID: &tenantID, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WebhookMatchesTenant(tt.webhook, tt.tenantID) + if got != tt.want { + t.Errorf("WebhookMatchesTenant() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/service/webhooks_service.go b/internal/service/webhooks_service.go index a6f4b89..a249341 100644 --- a/internal/service/webhooks_service.go +++ b/internal/service/webhooks_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/base64" "fmt" + "log/slog" "net" "net/netip" "net/url" @@ -31,9 +32,7 @@ type WebhooksRepository interface { ) ([]models.Webhook, bool, error) Count(ctx context.Context, filters *models.ListWebhooksFilters) (int64, error) Update(ctx context.Context, id uuid.UUID, req *models.UpdateWebhookRequest) (*models.Webhook, error) - Delete(ctx context.Context, id uuid.UUID) error - ListEnabled(ctx context.Context) ([]models.Webhook, error) - ListEnabledForEventType(ctx context.Context, eventType string) ([]models.Webhook, error) + Delete(ctx context.Context, id uuid.UUID) (*models.DeletedWebhook, error) } // WebhooksService handles business logic for webhooks. @@ -59,6 +58,10 @@ func NewWebhooksService( // CreateWebhook creates a new webhook. func (s *WebhooksService) CreateWebhook(ctx context.Context, req *models.CreateWebhookRequest) (*models.Webhook, error) { + if err := normalizeRequiredWebhookTenantID(req.TenantID); err != nil { + return nil, err + } + count, err := s.repo.Count(ctx, &models.ListWebhooksFilters{}) if err != nil { return nil, fmt.Errorf("count webhooks: %w", err) @@ -347,6 +350,10 @@ func (s *WebhooksService) ListWebhooks(ctx context.Context, filters *models.List // UpdateWebhook updates an existing webhook. func (s *WebhooksService) UpdateWebhook(ctx context.Context, id uuid.UUID, req *models.UpdateWebhookRequest) (*models.Webhook, error) { + if err := normalizeOptionalWebhookTenantID(req.TenantID); err != nil { + return nil, err + } + if req.URL != nil { if err := validateWebhookURLHost(ctx, *req.URL, s.urlHostBlacklist); err != nil { return nil, err @@ -369,14 +376,54 @@ func (s *WebhooksService) UpdateWebhook(ctx context.Context, id uuid.UUID, req * return webhook, nil } +func normalizeRequiredWebhookTenantID(tenantID *string) error { + normalized, err := normalizeRequiredTenantID(tenantID) + if err != nil { + return err + } + + *tenantID = normalized + + return nil +} + +func normalizeOptionalWebhookTenantID(tenantID *string) error { + if tenantID == nil { + return nil + } + + return normalizeWebhookTenantID(tenantID) +} + +func normalizeWebhookTenantID(tenantID *string) error { + normalized, err := normalizeRequiredTenantID(tenantID) + if err != nil { + return err + } + + *tenantID = normalized + + return nil +} + // DeleteWebhook deletes a webhook by ID. -// Publishes WebhookDeleted with data = [id] (array of deleted IDs) for consistency with feedback record deletes. +// Publishes WebhookDeleted with tenant-aware deleted IDs. func (s *WebhooksService) DeleteWebhook(ctx context.Context, id uuid.UUID) error { - if err := s.repo.Delete(ctx, id); err != nil { + webhook, err := s.repo.Delete(ctx, id) + if err != nil { return fmt.Errorf("delete webhook: %w", err) } - s.publisher.PublishEvent(ctx, datatypes.WebhookDeleted, []uuid.UUID{id}) + if webhook.TenantID == nil { + slog.Warn("webhook delete: tenant_id missing, skipping webhook event", "webhook_id", id) + + return nil + } + + s.publisher.PublishEvent(ctx, datatypes.WebhookDeleted, models.DeletedIDsEventData{ + TenantID: *webhook.TenantID, + IDs: []uuid.UUID{id}, + }) return nil } diff --git a/internal/service/webhooks_service_test.go b/internal/service/webhooks_service_test.go index 4ad8efc..f29f58b 100644 --- a/internal/service/webhooks_service_test.go +++ b/internal/service/webhooks_service_test.go @@ -15,7 +15,11 @@ import ( ) type mockWebhooksRepo struct { - count int64 + count int64 + webhook *models.Webhook + deleted *models.DeletedWebhook + deletedID uuid.UUID + getByIDCalls int } func (m *mockWebhooksRepo) Create(_ context.Context, _ *models.CreateWebhookRequest) (*models.Webhook, error) { @@ -23,6 +27,12 @@ func (m *mockWebhooksRepo) Create(_ context.Context, _ *models.CreateWebhookRequ } func (m *mockWebhooksRepo) GetByID(_ context.Context, _ uuid.UUID) (*models.Webhook, error) { + m.getByIDCalls++ + + if m.webhook != nil { + return m.webhook, nil + } + return nil, nil } @@ -44,16 +54,10 @@ func (m *mockWebhooksRepo) Update(_ context.Context, _ uuid.UUID, _ *models.Upda return nil, nil } -func (m *mockWebhooksRepo) Delete(_ context.Context, _ uuid.UUID) error { - return nil -} - -func (m *mockWebhooksRepo) ListEnabled(_ context.Context) ([]models.Webhook, error) { - return nil, nil -} +func (m *mockWebhooksRepo) Delete(_ context.Context, id uuid.UUID) (*models.DeletedWebhook, error) { + m.deletedID = id -func (m *mockWebhooksRepo) ListEnabledForEventType(_ context.Context, _ string) ([]models.Webhook, error) { - return nil, nil + return m.deleted, nil } type noopPublisher struct{} @@ -63,13 +67,46 @@ func (noopPublisher) PublishEvent(_ context.Context, _ datatypes.EventType, _ an func (noopPublisher) PublishEventWithChangedFields(_ context.Context, _ datatypes.EventType, _ any, _ []string) { } +type capturePublisher struct { + eventType datatypes.EventType + data any + changedFields []string + callCount int + events []capturedEvent +} + +type capturedEvent struct { + eventType datatypes.EventType + data any + changedFields []string +} + +func (p *capturePublisher) PublishEvent(_ context.Context, eventType datatypes.EventType, data any) { + p.eventType = eventType + p.data = data + p.callCount++ + p.events = append(p.events, capturedEvent{eventType: eventType, data: data}) +} + +func (p *capturePublisher) PublishEventWithChangedFields( + _ context.Context, eventType datatypes.EventType, data any, changedFields []string, +) { + p.eventType = eventType + p.data = data + p.changedFields = changedFields + p.callCount++ + p.events = append(p.events, capturedEvent{eventType: eventType, data: data, changedFields: changedFields}) +} + func TestWebhooksService_CreateWebhook_InvalidSigningKey(t *testing.T) { ctx := context.Background() svc := NewWebhooksService(&mockWebhooksRepo{count: 0}, noopPublisher{}, 10, nil) + tenantID := "org-123" req := &models.CreateWebhookRequest{ URL: "https://example.com/webhook", SigningKey: "not-valid", + TenantID: &tenantID, EventTypes: []datatypes.EventType{datatypes.FeedbackRecordCreated}, } @@ -107,6 +144,7 @@ func TestWebhooksService_CreateWebhook_RejectsSSRFHosts(t *testing.T) { ctx := context.Background() svc := NewWebhooksService(&mockWebhooksRepo{count: 0}, noopPublisher{}, 10, ssrfBlacklist) validKey := "whsec_" + "abcdefghijklmnopqrstuvwxyz123456" + tenantID := "org-123" tests := []struct { name string @@ -125,6 +163,7 @@ func TestWebhooksService_CreateWebhook_RejectsSSRFHosts(t *testing.T) { req := &models.CreateWebhookRequest{ URL: tt.url, SigningKey: validKey, + TenantID: &tenantID, EventTypes: []datatypes.EventType{datatypes.FeedbackRecordCreated}, } @@ -141,6 +180,75 @@ func TestWebhooksService_CreateWebhook_RejectsSSRFHosts(t *testing.T) { } } +func TestWebhooksService_CreateWebhook_RequiresTenantID(t *testing.T) { + ctx := context.Background() + svc := NewWebhooksService(&mockWebhooksRepo{count: 0}, noopPublisher{}, 10, nil) + + req := &models.CreateWebhookRequest{ + URL: "https://example.com/webhook", + SigningKey: "whsec_" + "abcdefghijklmnopqrstuvwxyz123456", + EventTypes: []datatypes.EventType{datatypes.FeedbackRecordCreated}, + } + + _, err := svc.CreateWebhook(ctx, req) + if !errors.Is(err, huberrors.ErrValidation) { + t.Fatalf("expected ErrValidation, got %v", err) + } +} + +func TestWebhooksService_UpdateWebhook_RejectsEmptyTenantID(t *testing.T) { + ctx := context.Background() + svc := NewWebhooksService(&mockWebhooksRepo{count: 0}, noopPublisher{}, 10, nil) + id := uuid.Must(uuid.NewV7()) + tenantID := " " + + req := &models.UpdateWebhookRequest{TenantID: &tenantID} + + _, err := svc.UpdateWebhook(ctx, id, req) + if !errors.Is(err, huberrors.ErrValidation) { + t.Fatalf("expected ErrValidation, got %v", err) + } +} + +func TestWebhooksService_DeleteWebhook_PublishesTenantAwareDeletedEvent(t *testing.T) { + ctx := context.Background() + webhookID := uuid.Must(uuid.NewV7()) + tenantID := "org-123" + repo := &mockWebhooksRepo{deleted: &models.DeletedWebhook{ID: webhookID, TenantID: &tenantID}} + publisher := &capturePublisher{} + svc := NewWebhooksService(repo, publisher, 10, nil) + + err := svc.DeleteWebhook(ctx, webhookID) + if err != nil { + t.Fatalf("DeleteWebhook() error = %v", err) + } + + if repo.deletedID != webhookID { + t.Fatalf("deletedID = %v, want %v", repo.deletedID, webhookID) + } + + if repo.getByIDCalls != 0 { + t.Fatalf("GetByID called %d times, want 0; delete should return the deleted row atomically", repo.getByIDCalls) + } + + if publisher.callCount != 1 || publisher.eventType != datatypes.WebhookDeleted { + t.Fatalf("published event = (%d, %s), want one webhook.deleted", publisher.callCount, publisher.eventType) + } + + data, ok := publisher.data.(models.DeletedIDsEventData) + if !ok { + t.Fatalf("published data type = %T, want DeletedIDsEventData", publisher.data) + } + + if data.TenantID != tenantID { + t.Errorf("TenantID = %q, want %q", data.TenantID, tenantID) + } + + if len(data.IDs) != 1 || data.IDs[0] != webhookID { + t.Errorf("IDs = %v, want [%v]", data.IDs, webhookID) + } +} + func TestWebhooksService_UpdateWebhook_RejectsSSRFHosts(t *testing.T) { ctx := context.Background() svc := NewWebhooksService(&mockWebhooksRepo{count: 0}, noopPublisher{}, 10, ssrfBlacklist) diff --git a/internal/workers/webhook_dispatch.go b/internal/workers/webhook_dispatch.go index 5c68a38..d45ac63 100644 --- a/internal/workers/webhook_dispatch.go +++ b/internal/workers/webhook_dispatch.go @@ -84,7 +84,68 @@ func (w *WebhookDispatchWorker) Work(ctx context.Context, job *river.Job[service return nil } - payload := argsToPayload(args) + jobTenantID := args.TenantID + dataTenantID := service.TenantIDPointerFromEventData(args.Data) + + if jobTenantID != nil && dataTenantID != nil && *jobTenantID != *dataTenantID { + if w.metrics != nil { + w.metrics.RecordDispatchError(ctx, "tenant_mismatch") + } + + slog.Error("webhook dispatch: job tenant_id conflicts with payload tenant_id, skipping delivery", + "event_id", args.EventID, + "webhook_id", args.WebhookID, + "job_tenant_id", *jobTenantID, + "payload_tenant_id", *dataTenantID, + ) + + return nil + } + + tenantID := jobTenantID + if tenantID == nil { + tenantID = dataTenantID + } + + if tenantID == nil { + if w.metrics != nil { + w.metrics.RecordDispatchError(ctx, "missing_tenant_id") + } + + slog.Error("webhook dispatch: event tenant_id missing, skipping delivery", + "event_id", args.EventID, + "webhook_id", args.WebhookID, + ) + + return nil + } + + if !service.WebhookMatchesTenant(webhook, tenantID) { + if w.metrics != nil { + w.metrics.RecordDispatchError(ctx, "tenant_mismatch") + } + + var webhookTenantID any + if webhook.TenantID != nil { + webhookTenantID = *webhook.TenantID + } + + var eventTenantID any + if tenantID != nil { + eventTenantID = *tenantID + } + + slog.Error("webhook dispatch: tenant scope mismatch, skipping delivery", + "event_id", args.EventID, + "webhook_id", args.WebhookID, + "webhook_tenant_id", webhookTenantID, + "event_tenant_id", eventTenantID, + ) + + return nil + } + + payload := service.NewWebhookPayload(args) err = w.sender.Send(ctx, webhook, payload) if err == nil { @@ -146,14 +207,3 @@ func (w *WebhookDispatchWorker) Work(ctx context.Context, job *river.Job[service return fmt.Errorf("webhook send: %w", err) } - -// argsToPayload builds a WebhookPayload from job args. -func argsToPayload(args service.WebhookDispatchArgs) *service.WebhookPayload { - return &service.WebhookPayload{ - ID: args.EventID, - Type: args.EventType, - Timestamp: args.Timestamp, - Data: args.Data, - ChangedFields: args.ChangedFields, - } -} diff --git a/internal/workers/webhook_dispatch_test.go b/internal/workers/webhook_dispatch_test.go index 750a842..72fc72f 100644 --- a/internal/workers/webhook_dispatch_test.go +++ b/internal/workers/webhook_dispatch_test.go @@ -31,10 +31,15 @@ func (m *mockDispatchRepo) Update(_ context.Context, _ uuid.UUID, req *models.Up } type mockSender struct { - err error + err error + calls int + payloads []*service.WebhookPayload } -func (m *mockSender) Send(_ context.Context, _ *models.Webhook, _ *service.WebhookPayload) error { +func (m *mockSender) Send(_ context.Context, _ *models.Webhook, payload *service.WebhookPayload) error { + m.calls++ + m.payloads = append(m.payloads, payload) + return m.err } @@ -42,11 +47,13 @@ func TestWebhookDispatchWorker_Work(t *testing.T) { ctx := context.Background() eventID := uuid.Must(uuid.NewV7()) webhookID := uuid.Must(uuid.NewV7()) + tenantID := "org-123" args := service.WebhookDispatchArgs{ EventID: eventID, EventType: "feedback_record.created", Timestamp: time.Now(), Data: nil, + TenantID: &tenantID, WebhookID: webhookID, } @@ -75,7 +82,9 @@ func TestWebhookDispatchWorker_Work(t *testing.T) { }) t.Run("returns nil on send success", func(t *testing.T) { - repo := &mockDispatchRepo{webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk"}} + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk", TenantID: &tenantID}, + } sender := &mockSender{} worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: args} @@ -88,10 +97,189 @@ func TestWebhookDispatchWorker_Work(t *testing.T) { if repo.update != nil { t.Error("Update should not be called on success") } + + if sender.calls != 1 { + t.Errorf("Send called %d times, want 1", sender.calls) + } + }) + + t.Run("returns nil without send when scoped webhook tenant mismatches job tenant", func(t *testing.T) { + eventTenant := "org-other" + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &tenantID, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.TenantID = &eventTenant + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 0 { + t.Errorf("Send called %d times, want 0", sender.calls) + } + + if repo.update != nil { + t.Error("Update should not be called for tenant mismatch") + } + }) + + t.Run("returns nil without send when legacy job data tenant mismatches scoped webhook", func(t *testing.T) { + webhookTenant := "org-123" + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &webhookTenant, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.Data = map[string]any{"tenant_id": "org-other"} + scopedArgs.TenantID = nil + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 0 { + t.Errorf("Send called %d times, want 0", sender.calls) + } + }) + + t.Run("returns nil without send when job has no tenant boundary", func(t *testing.T) { + webhookTenant := "org-123" + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &webhookTenant, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.Data = nil + scopedArgs.TenantID = nil + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 0 { + t.Errorf("Send called %d times, want 0", sender.calls) + } + }) + + t.Run("returns nil without send when job tenant conflicts with payload tenant", func(t *testing.T) { + payloadTenant := "org-other" + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &tenantID, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.TenantID = &tenantID + scopedArgs.Data = map[string]any{"tenant_id": payloadTenant} + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 0 { + t.Errorf("Send called %d times, want 0", sender.calls) + } + }) + + t.Run("sends when scoped webhook tenant matches job tenant", func(t *testing.T) { + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &tenantID, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.TenantID = &tenantID + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 1 { + t.Errorf("Send called %d times, want 1", sender.calls) + } + }) + + t.Run("sends legacy job and includes derived tenant in payload", func(t *testing.T) { + webhookTenant := "org-123" + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ + ID: webhookID, + Enabled: true, + URL: "http://x", + SigningKey: "sk", + TenantID: &webhookTenant, + }, + } + sender := &mockSender{} + worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) + scopedArgs := args + scopedArgs.Data = map[string]any{"tenant_id": webhookTenant} + scopedArgs.TenantID = nil + job := &river.Job[service.WebhookDispatchArgs]{JobRow: &rivertype.JobRow{}, Args: scopedArgs} + + err := worker.Work(ctx, job) + if err != nil { + t.Errorf("Work() error = %v, want nil", err) + } + + if sender.calls != 1 { + t.Fatalf("Send called %d times, want 1", sender.calls) + } + + if sender.payloads[0].TenantID == nil || *sender.payloads[0].TenantID != webhookTenant { + t.Errorf("payload tenant_id = %v, want %q", sender.payloads[0].TenantID, webhookTenant) + } }) t.Run("returns error and does not update when send fails and attempt < max", func(t *testing.T) { - repo := &mockDispatchRepo{webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk"}} + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk", TenantID: &tenantID}, + } sender := &mockSender{err: errors.New("network error")} worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) job := &river.Job[service.WebhookDispatchArgs]{ @@ -110,7 +298,9 @@ func TestWebhookDispatchWorker_Work(t *testing.T) { }) t.Run("updates webhook and returns error when send fails on last attempt", func(t *testing.T) { - repo := &mockDispatchRepo{webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk"}} + repo := &mockDispatchRepo{ + webhook: &models.Webhook{ID: webhookID, Enabled: true, URL: "http://x", SigningKey: "sk", TenantID: &tenantID}, + } sender := &mockSender{err: errors.New("final failure")} worker := NewWebhookDispatchWorker(repo, sender, 15*time.Second, nil) job := &river.Job[service.WebhookDispatchArgs]{ diff --git a/migrations/008_require_webhook_tenant_id.sql b/migrations/008_require_webhook_tenant_id.sql new file mode 100644 index 0000000..576fea5 --- /dev/null +++ b/migrations/008_require_webhook_tenant_id.sql @@ -0,0 +1,18 @@ +-- +goose Up +-- Webhooks are tenant-owned dispatch configuration. Disable legacy global rows and +-- prevent new NULL or empty tenant IDs without blocking this migration on old data. +UPDATE webhooks +SET + enabled = false, + disabled_reason = COALESCE(disabled_reason, 'Disabled by migration: tenant_id is required for webhook isolation'), + disabled_at = COALESCE(disabled_at, NOW()), + updated_at = NOW() +WHERE tenant_id IS NULL OR btrim(tenant_id) = ''; + +ALTER TABLE webhooks + ADD CONSTRAINT webhooks_tenant_id_required + CHECK (tenant_id IS NOT NULL AND btrim(tenant_id) <> '') NOT VALID; + +-- +goose Down +ALTER TABLE webhooks + DROP CONSTRAINT IF EXISTS webhooks_tenant_id_required; diff --git a/openapi.yaml b/openapi.yaml index 19051e3..1dc868d 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -381,25 +381,26 @@ paths: tags: - Feedback Records summary: Bulk delete feedback records by user ID - description: Permanently deletes all feedback record data points matching the specified user_id. This endpoint supports GDPR Article 17 (Right to Erasure) requests. + description: Permanently deletes feedback record data points matching the specified user_id. Omit tenant_id to delete that user_id across all tenants for GDPR Article 17 (Right to Erasure) requests. Provide tenant_id to restrict deletion to that tenant only. operationId: bulk-delete-feedback-records parameters: - name: user_id in: query - description: Delete all records matching this user ID (required). NULL bytes not allowed. + description: Delete records matching this user ID (required). NULL bytes not allowed. required: true schema: type: string - description: Delete all records matching this user ID (required). NULL bytes not allowed. + description: Delete records matching this user ID (required). NULL bytes not allowed. minLength: 1 pattern: '^[^\x00]*$' example: "user-abc-123" - name: tenant_id in: query - description: Filter by tenant ID (optional, for multi-tenant deployments). NULL bytes not allowed. + description: Optional tenant scope. Omit this parameter to delete all records matching user_id across tenants; provide it to delete only records for this tenant. Empty strings and NULL bytes are not allowed. schema: type: string - description: Filter by tenant ID (optional, for multi-tenant deployments). NULL bytes not allowed. + description: Optional tenant scope. Omit this parameter to delete all records matching user_id across tenants; provide it to delete only records for this tenant. Empty strings and NULL bytes are not allowed. + minLength: 1 pattern: '^[^\x00]*$' example: "org-123" responses: @@ -837,6 +838,7 @@ paths: Creates a new webhook endpoint. When events occur (e.g. feedback_record.created), the Hub POSTs a signed payload to the webhook URL. If signing_key is omitted, a key is auto-generated (Standard Webhooks format, whsec_...). + tenant_id is required; webhooks only receive events from that exact tenant. See WebhookDeliveryPayload for the payload structure sent to your URL. operationId: create-webhook requestBody: @@ -849,6 +851,7 @@ paths: summary: Subscribe to feedback events value: url: "https://example.com/hub-events" + tenant_id: "org-123" event_types: - "feedback_record.created" - "feedback_record.updated" @@ -1430,7 +1433,8 @@ components: description: Whether the webhook is active (default true) tenant_id: type: string - description: Tenant/organization identifier. NULL bytes not allowed. + description: Tenant/organization identifier. Required for webhook isolation; NULL bytes not allowed. + minLength: 1 maxLength: 255 pattern: '^[^\x00]*$' example: "org-123" @@ -1443,6 +1447,7 @@ components: $ref: '#/components/schemas/WebhookEventType' required: - url + - tenant_id ListWebhooksOutputBody: type: object additionalProperties: false @@ -1492,8 +1497,9 @@ components: type: boolean description: Enable or disable the webhook tenant_id: - type: [string, "null"] - description: Omit or send null to leave unchanged. Send empty string to clear (store as null). + type: string + description: Omit to leave unchanged. Empty strings are rejected; webhooks cannot be global. + minLength: 1 maxLength: 255 pattern: '^[^\x00]*$' example: "org-123" @@ -1622,6 +1628,10 @@ components: The webhook-id header is a stable identifier per event: the same value is sent for every delivery attempt (all endpoints and all retries) for that event. Use it as an idempotency key (e.g. store seen IDs for a short window and skip duplicate processing). additionalProperties: false properties: + id: + type: string + format: uuid + description: Stable event ID. Matches the webhook-id header and can be used as an idempotency key. type: type: string description: Event type that occurred @@ -1636,13 +1646,16 @@ components: type: string format: date-time description: When the event was published (ISO 8601) + tenant_id: + type: string + description: Tenant/organization identifier for the event. data: description: | Event payload. Shape depends on event type: - feedback_record.created / feedback_record.updated: FeedbackRecordData (object). - feedback_record.deleted: DeletedIdsPayload (array of deleted feedback record IDs). - webhook.created / webhook.updated: WebhookData (object). - - webhook.deleted: DeletedIdsPayload (array of deleted webhook IDs; one element for single delete). + - webhook.deleted: DeletedIdsPayload (array of deleted webhook IDs; one ID for single delete). oneOf: - $ref: '#/components/schemas/FeedbackRecordData' - $ref: '#/components/schemas/WebhookData' @@ -1653,8 +1666,10 @@ components: items: type: string required: + - id - type - timestamp + - tenant_id - data UpdateFeedbackRecordInputBody: type: object diff --git a/tests/integration_test.go b/tests/integration_test.go index 6db9ef3..9a6e4b6 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -885,13 +885,16 @@ func TestBulkDeleteFeedbackRecords(t *testing.T) { defer cleanup() client := &http.Client{} - userID := "bulk-delete-test-user-123" + userID := "bulk-delete-test-user-" + uuid.New().String() subID := uuid.New().String() // unique per run to avoid 409 from leftover data - // Create several feedback records with the same user_id - tenantID := "test-tenant" - createPayload := func(fieldID string, valueNum float64) map[string]any { - return map[string]any{ + // Create several feedback records with the same user_id across tenants. + tenantA := "test-tenant-a" + tenantB := "test-tenant-b" + createRecord := func(fieldID, tenantID string, valueNum float64) string { + t.Helper() + + body, err := json.Marshal(map[string]any{ "source_type": "formbricks", "submission_id": subID, "tenant_id": tenantID, @@ -899,38 +902,49 @@ func TestBulkDeleteFeedbackRecords(t *testing.T) { "field_id": fieldID, "field_type": "number", "value_number": valueNum, - } - } - createdIDs := make([]string, 0, 3) - - for i, p := range []map[string]any{ - createPayload("nps_1", 8), - createPayload("nps_2", 9), - createPayload("nps_3", 10), - } { - body, err := json.Marshal(p) + }) require.NoError(t, err) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/v1/feedback-records", bytes.NewBuffer(body)) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+testAPIKey) req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) require.NoError(t, err) - require.Equal(t, http.StatusCreated, resp.StatusCode, "create record %d", i+1) + require.Equal(t, http.StatusCreated, resp.StatusCode, "create record %s/%s", tenantID, fieldID) var rec models.FeedbackRecord err = decodeData(resp, &rec) require.NoError(t, err) - createdIDs = append(createdIDs, rec.ID.String()) - require.NoError(t, resp.Body.Close()) + + return rec.ID.String() } - // Bulk delete by user_id - bulkDelURL := server.URL + "/v1/feedback-records?user_id=" + userID - req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, bulkDelURL, http.NoBody) + requireStatus := func(id string, status int) { + t.Helper() + + getReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+"/v1/feedback-records/"+id, http.NoBody) + require.NoError(t, err) + getReq.Header.Set("Authorization", "Bearer "+testAPIKey) + + getResp, err := client.Do(getReq) + require.NoError(t, err) + assert.Equal(t, status, getResp.StatusCode) + require.NoError(t, getResp.Body.Close()) + } + + tenantAID := createRecord("nps_1", tenantA, 8) + tenantBID1 := createRecord("nps_2", tenantB, 9) + tenantBID2 := createRecord("nps_3", tenantB, 10) + + // Providing tenant_id scopes deletion to only that tenant. + scopedDelURL := fmt.Sprintf("%s/v1/feedback-records?user_id=%s&tenant_id=%s", + server.URL, url.QueryEscape(userID), url.QueryEscape(tenantA)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, scopedDelURL, http.NoBody) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+testAPIKey) @@ -938,6 +952,29 @@ func TestBulkDeleteFeedbackRecords(t *testing.T) { require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + var scopedResp models.BulkDeleteResponse + + err = decodeData(resp, &scopedResp) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + assert.Equal(t, int64(1), scopedResp.DeletedCount) + + requireStatus(tenantAID, http.StatusNotFound) + requireStatus(tenantBID1, http.StatusOK) + requireStatus(tenantBID2, http.StatusOK) + + tenantAID2 := createRecord("nps_4", tenantA, 7) + + // Omitting tenant_id deletes every remaining matching record for GDPR erasure, regardless of tenant. + bulkDelURL := server.URL + "/v1/feedback-records?user_id=" + url.QueryEscape(userID) + req, err = http.NewRequestWithContext(context.Background(), http.MethodDelete, bulkDelURL, http.NoBody) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+testAPIKey) + + resp, err = client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + var bulkResp models.BulkDeleteResponse err = decodeData(resp, &bulkResp) @@ -947,18 +984,12 @@ func TestBulkDeleteFeedbackRecords(t *testing.T) { assert.Equal(t, "Successfully deleted 3 feedback records", bulkResp.Message) // Verify records are gone - for _, id := range createdIDs { - getReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+"/v1/feedback-records/"+id, http.NoBody) - require.NoError(t, err) - getReq.Header.Set("Authorization", "Bearer "+testAPIKey) - getResp, err := client.Do(getReq) - require.NoError(t, err) - assert.Equal(t, http.StatusNotFound, getResp.StatusCode) - require.NoError(t, getResp.Body.Close()) + for _, id := range []string{tenantAID2, tenantBID1, tenantBID2} { + requireStatus(id, http.StatusNotFound) } // Bulk delete again (no matching records) returns 0 - bulkDelURL2 := server.URL + "/v1/feedback-records?user_id=" + userID + bulkDelURL2 := server.URL + "/v1/feedback-records?user_id=" + url.QueryEscape(userID) req2, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, bulkDelURL2, http.NoBody) require.NoError(t, err) req2.Header.Set("Authorization", "Bearer "+testAPIKey) @@ -972,70 +1003,6 @@ func TestBulkDeleteFeedbackRecords(t *testing.T) { require.NoError(t, err) require.NoError(t, resp2.Body.Close()) assert.Equal(t, int64(0), bulkResp2.DeletedCount) - - // Bulk delete with tenant_id: only records for that tenant are deleted - t.Run("Bulk delete with tenant_id filter", func(t *testing.T) { - tenantA, tenantB := "tenant-bulk-a", "tenant-bulk-b" - userIDTenant := "bulk-delete-tenant-user" - - // Create one record with tenant_a, two with tenant_b - for _, item := range []struct { - tenantID string - fieldID string - }{ - {tenantA, "fa"}, - {tenantB, "fb1"}, - {tenantB, "fb2"}, - } { - body, err := json.Marshal(map[string]any{ - "source_type": "formbricks", - "submission_id": userIDTenant + "-" + item.fieldID, - "user_id": userIDTenant, - "tenant_id": item.tenantID, - "field_id": item.fieldID, - "field_type": "text", - "value_text": "x", - }) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/v1/feedback-records", bytes.NewBuffer(body)) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+testAPIKey) - req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusCreated, resp.StatusCode) - require.NoError(t, resp.Body.Close()) - } - - // Delete only tenant_a - delURL := server.URL + "/v1/feedback-records?user_id=" + userIDTenant + "&tenant_id=" + tenantA - delReq, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, delURL, http.NoBody) - require.NoError(t, err) - delReq.Header.Set("Authorization", "Bearer "+testAPIKey) - delResp, err := client.Do(delReq) - require.NoError(t, err) - require.Equal(t, http.StatusOK, delResp.StatusCode) - - var delResult models.BulkDeleteResponse - - err = decodeData(delResp, &delResult) - require.NoError(t, err) - require.NoError(t, delResp.Body.Close()) - assert.Equal(t, int64(1), delResult.DeletedCount) - - // Delete remaining (tenant_b) — should delete 2 - delURL2 := server.URL + "/v1/feedback-records?user_id=" + userIDTenant + "&tenant_id=" + tenantB - delReq2, err := http.NewRequestWithContext(context.Background(), http.MethodDelete, delURL2, http.NoBody) - require.NoError(t, err) - delReq2.Header.Set("Authorization", "Bearer "+testAPIKey) - delResp2, err := client.Do(delReq2) - require.NoError(t, err) - require.Equal(t, http.StatusOK, delResp2.StatusCode) - err = decodeData(delResp2, &delResult) - require.NoError(t, err) - require.NoError(t, delResp2.Body.Close()) - assert.Equal(t, int64(2), delResult.DeletedCount) - }) } // TestFeedbackRecordsRepository_BulkDelete tests the repository BulkDelete return value (deleted IDs). @@ -1060,12 +1027,14 @@ func TestFeedbackRecordsRepository_BulkDelete(t *testing.T) { defer db.Close() repo := repository.NewFeedbackRecordsRepository(db) - userID := "repo-bulk-delete-user" + userID := "repo-bulk-delete-user-" + uuid.New().String() sourceType := "formbricks" - // Create two records with same user_id + // Create records with same user_id across tenants. const bulkDeleteTenant = "bulk-delete-tenant" + const otherBulkDeleteTenant = "bulk-delete-tenant-other" + req1 := &models.CreateFeedbackRecordRequest{ SourceType: sourceType, SubmissionID: userID, @@ -1073,7 +1042,7 @@ func TestFeedbackRecordsRepository_BulkDelete(t *testing.T) { FieldID: "f1", FieldType: models.FieldTypeNumber, ValueNumber: new(1.0), - UserID: new(userID), + UserID: &userID, } rec1, err := repo.Create(ctx, req1) require.NoError(t, err) @@ -1086,19 +1055,54 @@ func TestFeedbackRecordsRepository_BulkDelete(t *testing.T) { FieldID: "f2", FieldType: models.FieldTypeNumber, ValueNumber: new(2.0), - UserID: new(userID), + UserID: &userID, } rec2, err := repo.Create(ctx, req2) require.NoError(t, err) require.NotEmpty(t, rec2.ID) - // BulkDelete returns the deleted IDs - deletedIDs, err := repo.BulkDelete(ctx, userID, nil) + valueText := "delete me too" + req3 := &models.CreateFeedbackRecordRequest{ + SourceType: sourceType, + SubmissionID: userID, + TenantID: otherBulkDeleteTenant, + FieldID: "f3", + FieldType: models.FieldTypeText, + ValueText: &valueText, + UserID: &userID, + } + rec3, err := repo.Create(ctx, req3) + require.NoError(t, err) + require.NotEmpty(t, rec3.ID) + + // BulkDelete with tenant_id restricts deletion to that tenant and returns tenant-safe groups. + tenantFilter := bulkDeleteTenant + deletedGroups, err := repo.BulkDelete(ctx, &models.BulkDeleteFilters{ + UserID: userID, + TenantID: &tenantFilter, + }) + require.NoError(t, err) + require.Len(t, deletedGroups, 1) + require.Equal(t, bulkDeleteTenant, deletedGroups[0].TenantID) + assert.ElementsMatch(t, []uuid.UUID{rec1.ID, rec2.ID}, deletedGroups[0].IDs) + + _, err = repo.GetByID(ctx, rec1.ID) + require.Error(t, err) + _, err = repo.GetByID(ctx, rec2.ID) + require.Error(t, err) + remaining, err := repo.GetByID(ctx, rec3.ID) + require.NoError(t, err) + require.Equal(t, otherBulkDeleteTenant, remaining.TenantID) + + // Omitting tenant_id deletes the rest of the user records across tenants. + deletedGroups, err = repo.BulkDelete(ctx, &models.BulkDeleteFilters{UserID: userID}) require.NoError(t, err) - require.Len(t, deletedIDs, 2) - ids := map[string]bool{deletedIDs[0].String(): true, deletedIDs[1].String(): true} - assert.True(t, ids[rec1.ID.String()]) - assert.True(t, ids[rec2.ID.String()]) + require.Len(t, deletedGroups, 1) + require.Equal(t, otherBulkDeleteTenant, deletedGroups[0].TenantID) + assert.ElementsMatch(t, []uuid.UUID{rec3.ID}, deletedGroups[0].IDs) + + _, err = repo.GetByID(ctx, rec3.ID) + require.Error(t, err) } // TestWebhooksCRUD tests webhook create, get, list, update, delete. @@ -1109,8 +1113,10 @@ func TestWebhooksCRUD(t *testing.T) { client := &http.Client{} // Create webhook (no signing key = auto-generated) + webhookTenantID := "org-123" createBody := map[string]any{ "url": testWebhookURL, + "tenant_id": webhookTenantID, "event_types": []string{"feedback_record.created", "feedback_record.updated"}, } body, err := json.Marshal(createBody) @@ -1133,6 +1139,8 @@ func TestWebhooksCRUD(t *testing.T) { assert.Equal(t, testWebhookURL, created.URL) assert.NotEmpty(t, created.SigningKey) assert.True(t, created.Enabled) + require.NotNil(t, created.TenantID) + assert.Equal(t, webhookTenantID, *created.TenantID) assert.Len(t, created.EventTypes, 2) // Get webhook @@ -1202,7 +1210,7 @@ func TestWebhooksCRUD(t *testing.T) { updateBody := map[string]any{ "url": testWebhookURLV2, "enabled": false, - "tenant_id": "org-123", + "tenant_id": "org-456", } updateJSON, err := json.Marshal(updateBody) require.NoError(t, err) @@ -1224,9 +1232,9 @@ func TestWebhooksCRUD(t *testing.T) { assert.Equal(t, testWebhookURLV2, updated.URL) assert.False(t, updated.Enabled) require.NotNil(t, updated.TenantID) - assert.Equal(t, "org-123", *updated.TenantID) + assert.Equal(t, "org-456", *updated.TenantID) - // PATCH tenant_id to empty string to clear it + // PATCH tenant_id to empty string is rejected; webhooks cannot be global. clearTenantBody := map[string]any{"tenant_id": ""} clearTenantJSON, err := json.Marshal(clearTenantBody) require.NoError(t, err) @@ -1238,14 +1246,8 @@ func TestWebhooksCRUD(t *testing.T) { clearTenantReq.Header.Set("Content-Type", "application/json") clearTenantResp, err := client.Do(clearTenantReq) require.NoError(t, err) - assert.Equal(t, http.StatusOK, clearTenantResp.StatusCode) - - var afterClear models.Webhook - - err = decodeData(clearTenantResp, &afterClear) - require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, clearTenantResp.StatusCode) require.NoError(t, clearTenantResp.Body.Close()) - assert.Nil(t, afterClear.TenantID) // Delete webhook deleteWebhookURL := fmt.Sprintf("%s/v1/webhooks/%s", server.URL, created.ID) @@ -1268,6 +1270,34 @@ func TestWebhooksCRUD(t *testing.T) { require.NoError(t, getAfterResp.Body.Close()) } +func TestWebhooksCreateRequiresTenantID(t *testing.T) { + server, cleanup := setupTestServer(t) + defer cleanup() + + body, err := json.Marshal(map[string]any{ + "url": testWebhookURL, + "event_types": []string{"feedback_record.created"}, + }) + require.NoError(t, err) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/v1/webhooks", bytes.NewBuffer(body)) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+testAPIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + var problem response.ProblemDetails + + err = json.NewDecoder(resp.Body).Decode(&problem) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + assert.Equal(t, "Validation Error", problem.Title) + assert.Contains(t, problem.Detail, "TenantID is required") +} + // TestWebhooksInvalidSigningKey asserts that create and update reject invalid signing_key with 400. func TestWebhooksInvalidSigningKey(t *testing.T) { server, cleanup := setupTestServer(t) @@ -1279,6 +1309,7 @@ func TestWebhooksInvalidSigningKey(t *testing.T) { createBody := map[string]any{ "url": testWebhookURL, "signing_key": "not-valid", + "tenant_id": "org-123", "event_types": []string{"feedback_record.created"}, } body, err := json.Marshal(createBody) @@ -1306,6 +1337,7 @@ func TestWebhooksInvalidSigningKey(t *testing.T) { // Create a valid webhook first for update test validBody := map[string]any{ "url": testWebhookURL, + "tenant_id": "org-123", "event_types": []string{"feedback_record.created"}, } validJSON, err := json.Marshal(validBody) diff --git a/tests/webhooks_repository_test.go b/tests/webhooks_repository_test.go new file mode 100644 index 0000000..45fead7 --- /dev/null +++ b/tests/webhooks_repository_test.go @@ -0,0 +1,173 @@ +package tests + +import ( + "context" + "os" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/formbricks/hub/internal/config" + "github.com/formbricks/hub/internal/datatypes" + "github.com/formbricks/hub/internal/huberrors" + "github.com/formbricks/hub/internal/models" + "github.com/formbricks/hub/internal/repository" + "github.com/formbricks/hub/pkg/database" +) + +func TestWebhooksRepository_ListEnabledForEventTypeAndTenant(t *testing.T) { + ctx := context.Background() + urlPrefix := "https://tenant-scope.test/" + uuid.NewString() + "/" + + databaseURL := os.Getenv("DATABASE_URL") + if databaseURL == "" { + databaseURL = defaultTestDatabaseURL + } + + t.Setenv("API_KEY", testAPIKey) + t.Setenv("DATABASE_URL", databaseURL) + + cfg, err := config.Load() + require.NoError(t, err) + + db, err := database.NewPostgresPool(ctx, cfg.Database.URL, + database.WithPoolConfig(cfg.Database.PoolConfig()), + ) + require.NoError(t, err) + + defer db.Close() + + cleanupRepositoryWebhookScopeTestRows := func() { + _, cleanupErr := db.Exec(ctx, "DELETE FROM webhooks WHERE url LIKE $1", urlPrefix+"%") + require.NoError(t, cleanupErr) + } + + cleanupRepositoryWebhookScopeTestRows() + defer cleanupRepositoryWebhookScopeTestRows() + + repo := repository.NewWebhooksRepository(db) + tenantA := "repo-scope-tenant-a" + tenantB := "repo-scope-tenant-b" + disabled := false + feedbackCreated := []datatypes.EventType{datatypes.FeedbackRecordCreated} + feedbackUpdated := []datatypes.EventType{datatypes.FeedbackRecordUpdated} + + tenantAWebhook := createWebhookForRepositoryScopeTest(ctx, t, repo, urlPrefix, "tenant-a", &tenantA, feedbackCreated) + tenantBWebhook := createWebhookForRepositoryScopeTest(ctx, t, repo, urlPrefix, "tenant-b", &tenantB, feedbackCreated) + disabledTenantAWebhook := createWebhookForRepositoryScopeTest(ctx, t, repo, urlPrefix, "disabled-a", &tenantA, feedbackCreated) + _, err = repo.Update(ctx, disabledTenantAWebhook.ID, &models.UpdateWebhookRequest{Enabled: &disabled}) + require.NoError(t, err) + + createWebhookForRepositoryScopeTest(ctx, t, repo, urlPrefix, "updated-only-a", &tenantA, feedbackUpdated) + + tenantAWebhooks, err := repo.ListEnabledForEventTypeAndTenant(ctx, datatypes.FeedbackRecordCreated.String(), &tenantA) + require.NoError(t, err) + assertRepositoryScopeWebhookIDs(t, tenantAWebhooks, urlPrefix, map[uuid.UUID]bool{ + tenantAWebhook.ID: true, + }) + + tenantBWebhooks, err := repo.ListEnabledForEventTypeAndTenant(ctx, datatypes.FeedbackRecordCreated.String(), &tenantB) + require.NoError(t, err) + assertRepositoryScopeWebhookIDs(t, tenantBWebhooks, urlPrefix, map[uuid.UUID]bool{ + tenantBWebhook.ID: true, + }) + + tenantlessWebhooks, err := repo.ListEnabledForEventTypeAndTenant(ctx, datatypes.FeedbackRecordCreated.String(), nil) + require.NoError(t, err) + assertRepositoryScopeWebhookIDs(t, tenantlessWebhooks, urlPrefix, map[uuid.UUID]bool{}) +} + +func TestWebhooksRepository_DeleteReturnsDeletedWebhook(t *testing.T) { + ctx := context.Background() + urlPrefix := "https://tenant-delete.test/" + uuid.NewString() + "/" + + databaseURL := os.Getenv("DATABASE_URL") + if databaseURL == "" { + databaseURL = defaultTestDatabaseURL + } + + t.Setenv("API_KEY", testAPIKey) + t.Setenv("DATABASE_URL", databaseURL) + + cfg, err := config.Load() + require.NoError(t, err) + + db, err := database.NewPostgresPool(ctx, cfg.Database.URL, + database.WithPoolConfig(cfg.Database.PoolConfig()), + ) + require.NoError(t, err) + + defer db.Close() + + cleanupRepositoryWebhookDeleteTestRows := func() { + _, cleanupErr := db.Exec(ctx, "DELETE FROM webhooks WHERE url LIKE $1", urlPrefix+"%") + require.NoError(t, cleanupErr) + } + + cleanupRepositoryWebhookDeleteTestRows() + defer cleanupRepositoryWebhookDeleteTestRows() + + repo := repository.NewWebhooksRepository(db) + tenantID := "repo-delete-tenant" + webhook := createWebhookForRepositoryScopeTest( + ctx, t, repo, urlPrefix, "delete-returning", &tenantID, []datatypes.EventType{datatypes.WebhookDeleted}, + ) + + deleted, err := repo.Delete(ctx, webhook.ID) + require.NoError(t, err) + require.NotNil(t, deleted) + require.NotNil(t, deleted.TenantID) + assert.Equal(t, webhook.ID, deleted.ID) + assert.Equal(t, tenantID, *deleted.TenantID) + + _, err = repo.GetByID(ctx, webhook.ID) + assert.ErrorIs(t, err, huberrors.ErrNotFound) +} + +func createWebhookForRepositoryScopeTest( + ctx context.Context, + t *testing.T, + repo *repository.WebhooksRepository, + urlPrefix string, + path string, + tenantID *string, + eventTypes []datatypes.EventType, +) *models.Webhook { + t.Helper() + + webhook, err := repo.Create(ctx, &models.CreateWebhookRequest{ + URL: urlPrefix + path, + SigningKey: "whsec_abcdefghijklmnopqrstuvwxyz123456", + TenantID: tenantID, + EventTypes: eventTypes, + }) + require.NoError(t, err) + + return webhook +} + +func assertRepositoryScopeWebhookIDs(t *testing.T, webhooks []models.Webhook, urlPrefix string, wantIDs map[uuid.UUID]bool) { + t.Helper() + + gotIDs := make(map[uuid.UUID]bool, len(webhooks)) + for _, webhook := range webhooks { + if !strings.HasPrefix(webhook.URL, urlPrefix) { + continue + } + + if !wantIDs[webhook.ID] { + t.Fatalf("unexpected scoped test webhook returned: %+v", webhook) + } + + gotIDs[webhook.ID] = true + } + + assert.Len(t, gotIDs, len(wantIDs), "webhooks = %+v", webhooks) + + for id := range wantIDs { + assert.True(t, gotIDs[id], "missing webhook %s in %+v", id, webhooks) + } +}