diff --git a/server/app/app.go b/server/app/app.go index 838ccdab87..18b4775935 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -23,6 +23,7 @@ import ( "github.com/kubeshop/tracetest/server/executor/trigger" httpServer "github.com/kubeshop/tracetest/server/http" "github.com/kubeshop/tracetest/server/http/mappings" + "github.com/kubeshop/tracetest/server/http/middleware" "github.com/kubeshop/tracetest/server/http/websocket" "github.com/kubeshop/tracetest/server/linter/analyzer" "github.com/kubeshop/tracetest/server/model" @@ -269,6 +270,9 @@ func (app *App) Start(opts ...appOption) error { // use the analytics middleware on complete router router.Use(analyticsMW) + // use the tenant middleware on complete router + router.Use(middleware.TenantMiddleware) + apiRouter := router. PathPrefix(app.cfg.ServerPathPrefix()). PathPrefix("/api"). diff --git a/server/config/demo/demo_repository.go b/server/config/demo/demo_repository.go index 0d3acf08d6..fcee382bb4 100644 --- a/server/config/demo/demo_repository.go +++ b/server/config/demo/demo_repository.go @@ -30,9 +30,10 @@ const insertQuery = `INSERT INTO demos ( "enabled", "type", "pokeshop", - "opentelemetry_store" + "opentelemetry_store", + "tenant_id" ) - VALUES ($1, $2, $3, $4, $5, $6)` + VALUES ($1, $2, $3, $4, $5, $6, $7)` func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) { tx, err := r.db.BeginTx(ctx, nil) @@ -50,6 +51,8 @@ func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) { return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err) } + tenantID := sqlutil.TenantID(ctx) + _, err = tx.ExecContext(ctx, insertQuery, demo.ID, demo.Name, @@ -57,6 +60,7 @@ func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) { demo.Type, pokeshopJSONData, openTelemetryStoreJSONData, + tenantID, ) if err != nil { @@ -102,7 +106,7 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) { return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err) } - _, err = tx.ExecContext(ctx, updateQuery, + query, params := sqlutil.Tenant(ctx, updateQuery, oldDemo.ID, demo.Name, demo.Enabled, @@ -111,6 +115,8 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) { openTelemetryStoreJSONData, ) + _, err = tx.ExecContext(ctx, query, params...) + if err != nil { tx.Rollback() return Demo{}, fmt.Errorf("sql exec: %w", err) @@ -133,10 +139,10 @@ const ( "type", "pokeshop", "opentelemetry_store" - FROM demos ` + FROM demos` - getQuery = baseSelect + `WHERE "id" = $1` - getDefaultQuery = baseSelect + `WHERE "default" = true` + getQuery = baseSelect + ` WHERE "id" = $1` + getDefaultQuery = baseSelect + ` WHERE "default" = true` ) func (r *Repository) Get(ctx context.Context, id id.ID) (Demo, error) { @@ -144,6 +150,7 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Demo, error) { } func (r *Repository) get(ctx context.Context, query string, args ...any) (Demo, error) { + query, args = sqlutil.Tenant(ctx, query, args...) row := r.db.QueryRowContext(ctx, query, args...) return readRow(row) } @@ -161,7 +168,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error { return err } - _, err = tx.ExecContext(ctx, deleteQuery, demo.ID) + query, params := sqlutil.Tenant(ctx, deleteQuery, demo.ID) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { tx.Rollback() @@ -191,6 +199,7 @@ func listQuery(baseSQL, query string, params []any) (string, []any) { func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, sortDirection string) ([]Demo, error) { q, params := listQuery(baseSelect, query, []any{take, skip}) + q, params = sqlutil.Tenant(ctx, q, params...) sortingFields := map[string]string{ "id": "id", @@ -240,8 +249,9 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) { count := 0 + countQuery, params := sqlutil.Tenant(ctx, countQuery) err := r.db. - QueryRowContext(ctx, countQuery). + QueryRowContext(ctx, countQuery, params...). Scan(&count) if err != nil { diff --git a/server/datastore/datastore_repository.go b/server/datastore/datastore_repository.go index 1b1fdaf9da..08527621df 100644 --- a/server/datastore/datastore_repository.go +++ b/server/datastore/datastore_repository.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" ) func NewRepository(db *sql.DB) *Repository { @@ -33,8 +34,9 @@ INSERT INTO data_stores ( "type", "is_default", "values", - "created_at" -) VALUES ($1, $2, $3, $4, $5, $6)` + "created_at", + "tenant_id" +) VALUES ($1, $2, $3, $4, $5, $6, $7)` const deleteQuery = `DELETE FROM data_stores WHERE "id" = $1` @@ -99,6 +101,8 @@ func (r *Repository) Update(ctx context.Context, dataStore DataStore) (DataStore return DataStore{}, fmt.Errorf("could not marshal values field configuration: %w", err) } + tenantID := sqlutil.TenantID(ctx) + _, err = tx.ExecContext(ctx, insertQuery, dataStore.ID, dataStore.Name, @@ -106,6 +110,7 @@ func (r *Repository) Update(ctx context.Context, dataStore DataStore) (DataStore dataStore.Default, valuesJSON, dataStore.CreatedAt, + tenantID, ) if err != nil { return DataStore{}, fmt.Errorf("datastore repository sql exec create: %w", err) @@ -126,7 +131,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error { } defer tx.Rollback() - _, err = tx.ExecContext(ctx, deleteQuery, dataStoreSingleID) + query, params := sqlutil.Tenant(ctx, deleteQuery, id) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { return fmt.Errorf("datastore repository sql exec delete: %w", err) } @@ -160,7 +166,8 @@ func (r *Repository) Current(ctx context.Context) (DataStore, error) { } func (r *Repository) Get(ctx context.Context, id id.ID) (DataStore, error) { - row := r.db.QueryRowContext(ctx, getQuery, id) + query, params := sqlutil.Tenant(ctx, getQuery, id) + row := r.db.QueryRowContext(ctx, query, params...) dataStore, err := r.readRow(row) if err != nil && errors.Is(err, sql.ErrNoRows) { diff --git a/server/environment/environment_repository.go b/server/environment/environment_repository.go index 52b03fcdcc..23a715ae32 100644 --- a/server/environment/environment_repository.go +++ b/server/environment/environment_repository.go @@ -5,7 +5,6 @@ import ( "database/sql" "encoding/json" "fmt" - "strings" "time" "github.com/kubeshop/tracetest/server/pkg/id" @@ -31,8 +30,9 @@ const ( "name", "description", "created_at", - "values" - ) VALUES ($1, $2, $3, $4, $5)` + "values", + "tenant_id" + ) VALUES ($1, $2, $3, $4, $5, $6)` updateQuery = ` UPDATE environments SET @@ -107,7 +107,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error { } defer tx.Rollback() - _, err = tx.ExecContext(ctx, deleteQuery, id) + query, params := sqlutil.Tenant(ctx, deleteQuery, id) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { return fmt.Errorf("sql error: %w", err) } @@ -144,6 +145,7 @@ func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, so sql = sqlutil.Sort(sql, sortBy, sortDirection, "created", sortingFields) sql += ` LIMIT $1 OFFSET $2 ` + sql, params = sqlutil.Tenant(ctx, sql, params...) stmt, err := r.db.Prepare(sql) if err != nil { return []Environment{}, err @@ -173,30 +175,11 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) { return r.countEnvironments(ctx, query) } -func sortQuery(sql, sortBy, sortDirection string, sortingFields map[string]string) string { - sortField, ok := sortingFields[sortBy] - - if !ok { - sortField = sortingFields["created"] - } - - dir := "DESC" - if strings.ToLower(sortDirection) == "asc" { - dir = "ASC" - } - - return fmt.Sprintf("%s ORDER BY %s %s", sql, sortField, dir) -} - func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) { - stmt, err := r.db.Prepare(getQuery + " WHERE e.id = $1") + query, params := sqlutil.Tenant(ctx, getQuery+" WHERE e.id = $1", id) + row := r.db.QueryRowContext(ctx, query, params...) - if err != nil { - return Environment{}, fmt.Errorf("prepare: %w", err) - } - defer stmt.Close() - - environment, err := r.readEnvironmentRow(ctx, stmt.QueryRowContext(ctx, id)) + environment, err := r.readEnvironmentRow(ctx, row) if err != nil { return Environment{}, err } @@ -207,7 +190,8 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) { func (r *Repository) Exists(ctx context.Context, id id.ID) (bool, error) { exists := false - row := r.db.QueryRowContext(ctx, idExistsQuery, id) + query, params := sqlutil.Tenant(ctx, idExistsQuery, id) + row := r.db.QueryRowContext(ctx, query, params...) err := row.Scan(&exists) @@ -251,6 +235,7 @@ func (r *Repository) countEnvironments(ctx context.Context, query string) (int, condition := " WHERE (e.name ilike $1 OR e.description ilike $1)" sql, params := sqlutil.Search(countQuery, condition, query, params) + sql, params = sqlutil.Tenant(ctx, sql, params...) err := r.db. QueryRowContext(ctx, sql, params...). @@ -274,6 +259,8 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env return Environment{}, fmt.Errorf("encoding error: %w", err) } + tenantID := sqlutil.TenantID(ctx) + _, err = stmt.ExecContext( ctx, environment.ID, @@ -281,6 +268,7 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env environment.Description, environment.CreatedAt, jsonValues, + tenantID, ) if err != nil { @@ -291,30 +279,37 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env } func (r *Repository) updateIntoEnvironments(ctx context.Context, environment Environment, oldId id.ID) (Environment, error) { - stmt, err := r.db.Prepare(updateQuery) - if err != nil { - return Environment{}, fmt.Errorf("sql prepare: %w", err) - } - defer stmt.Close() - jsonValues, err := json.Marshal(environment.Values) if err != nil { return Environment{}, fmt.Errorf("encoding error: %w", err) } + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return Environment{}, err + } + defer tx.Rollback() - _, err = stmt.ExecContext( - ctx, - oldId, + query, params := sqlutil.Tenant(ctx, updateQuery, oldId, environment.Name, environment.Description, environment.CreatedAt, jsonValues, ) + _, err = tx.ExecContext( + ctx, + query, + params..., + ) if err != nil { return Environment{}, fmt.Errorf("sql exec: %w", err) } + err = tx.Commit() + if err != nil { + return Environment{}, fmt.Errorf("commit: %w", err) + } + return environment, nil } diff --git a/server/executor/pollingprofile/polling_profile_repository.go b/server/executor/pollingprofile/polling_profile_repository.go index 5f5d0b63a1..0be2462ec0 100644 --- a/server/executor/pollingprofile/polling_profile_repository.go +++ b/server/executor/pollingprofile/polling_profile_repository.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" ) func NewRepository(db *sql.DB) *Repository { @@ -35,9 +36,10 @@ const ( "name", "default", "strategy", - "periodic" + "periodic", + "tenant_id" ) - VALUES ($1, $2, $3, $4, $5)` + VALUES ($1, $2, $3, $4, $5, $6)` deleteQuery = `DELETE FROM polling_profiles` ) @@ -47,12 +49,13 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin updated.Default = true tx, err := r.db.BeginTx(ctx, nil) - defer tx.Rollback() if err != nil { return PollingProfile{}, err } + defer tx.Rollback() - _, err = tx.ExecContext(ctx, deleteQuery) + query, params := sqlutil.Tenant(ctx, deleteQuery) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { return PollingProfile{}, fmt.Errorf("sql exec delete: %w", err) } @@ -65,12 +68,14 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin } } + tenantID := sqlutil.TenantID(ctx) _, err = tx.ExecContext(ctx, insertQuery, updated.ID, updated.Name, updated.Default, updated.Strategy, periodicJSON, + tenantID, ) if err != nil { return PollingProfile{}, fmt.Errorf("sql exec insert: %w", err) @@ -106,8 +111,9 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (PollingProfile, error) } var periodicJSON []byte + query, params := sqlutil.Tenant(ctx, getQuery) err := r.db. - QueryRowContext(ctx, getQuery). + QueryRowContext(ctx, query, params...). Scan( &profile.Name, &profile.Strategy, diff --git a/server/executor/testrunner/testrunner_repository.go b/server/executor/testrunner/testrunner_repository.go index c0d96e83fe..9ba9b5be41 100644 --- a/server/executor/testrunner/testrunner_repository.go +++ b/server/executor/testrunner/testrunner_repository.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" ) func NewRepository(db *sql.DB) *Repository { @@ -33,9 +34,10 @@ const ( INSERT INTO test_runners( "id", "name", - "required_gates" + "required_gates", + "tenant_id" ) - VALUES ($1, $2, $3)` + VALUES ($1, $2, $3, $4)` deleteQuery = `DELETE FROM test_runners` ) @@ -53,6 +55,7 @@ func (r *Repository) Update(ctx context.Context, updated TestRunner) (TestRunner if err != nil { return TestRunner{}, fmt.Errorf("sql exec delete: %w", err) } + tenantID := sqlutil.TenantID(ctx) var requiredGatesJSON []byte if updated.RequiredGates != nil { @@ -66,6 +69,7 @@ func (r *Repository) Update(ctx context.Context, updated TestRunner) (TestRunner updated.ID, updated.Name, requiredGatesJSON, + tenantID, ) if err != nil { return TestRunner{}, fmt.Errorf("sql exec insert: %w", err) @@ -103,8 +107,9 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (TestRunner, error) { } var requiredGatesJSON []byte + query, params := sqlutil.Tenant(ctx, getQuery) err := r.db. - QueryRowContext(ctx, getQuery). + QueryRowContext(ctx, query, params...). Scan( &testRunner.Name, &requiredGatesJSON, diff --git a/server/http/middleware/tenant.go b/server/http/middleware/tenant.go new file mode 100644 index 0000000000..0f7d7f80c3 --- /dev/null +++ b/server/http/middleware/tenant.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + + "github.com/google/uuid" +) + +type key string + +var ( + TenantIDKey key = "tenantID" +) + +const HeaderTenantID = "X-Tracetest-TenantID" + +func TenantMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + tenantID := getTenantIDFromRequest(r) + + // if tenant id exists and is invalid we return a 400 error + if tenantID != "" && !isValidUUID(tenantID) { + err := fmt.Errorf("invalid tenant id: %s", tenantID) + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + ctx = context.WithValue(ctx, TenantIDKey, tenantID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func getTenantIDFromRequest(r *http.Request) string { + return r.Header.Get(HeaderTenantID) +} + +func isValidUUID(value string) bool { + _, err := uuid.Parse(value) + return err == nil +} diff --git a/server/http/middleware/tenant_test.go b/server/http/middleware/tenant_test.go new file mode 100644 index 0000000000..fd23086f14 --- /dev/null +++ b/server/http/middleware/tenant_test.go @@ -0,0 +1,49 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/kubeshop/tracetest/server/http/middleware" + "github.com/stretchr/testify/assert" +) + +type dummyHandler struct { + t *testing.T + OnRequest func(r *http.Request) +} + +func (d *dummyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + d.OnRequest(r) +} + +func TestMiddleware(t *testing.T) { + dummyHandler := &dummyHandler{t: t} + nextHandler := http.HandlerFunc(dummyHandler.ServeHTTP) + handlerToTest := middleware.TenantMiddleware(nextHandler) + + t.Run("should set the tenant id in the context", func(t *testing.T) { + uuid := "16700d36-8e0a-4169-9bb8-30a249281841" + onRequest := func(r *http.Request) { + assert.Equal(t, uuid, r.Context().Value(middleware.TenantIDKey).(string)) + } + + dummyHandler.OnRequest = onRequest + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(middleware.HeaderTenantID, uuid) + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) + }) + + t.Run("should set the tenant id as empty string", func(t *testing.T) { + onRequest := func(r *http.Request) { + assert.Equal(t, "", r.Context().Value(middleware.TenantIDKey).(string)) + } + + dummyHandler.OnRequest = onRequest + + req := httptest.NewRequest("GET", "http://testing", nil) + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) + }) +} diff --git a/server/linter/analyzer/analyzer_repository.go b/server/linter/analyzer/analyzer_repository.go index 7d24bead6b..2d4379d85c 100644 --- a/server/linter/analyzer/analyzer_repository.go +++ b/server/linter/analyzer/analyzer_repository.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" ) type Repository struct { @@ -27,8 +28,9 @@ const ( "name", "enabled", "minimum_score", - "plugins" - ) VALUES ($1, $2, $3, $4, $5)` + "plugins", + "tenant_id" + ) VALUES ($1, $2, $3, $4, $5, $6)` getQuery = ` SELECT @@ -74,7 +76,8 @@ func (r *Repository) Update(ctx context.Context, linter Linter) (Linter, error) return Linter{}, err } - _, err = tx.ExecContext(ctx, deleteQuery, updated.ID) + query, params := sqlutil.Tenant(ctx, deleteQuery, updated.ID) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { return Linter{}, fmt.Errorf("sql exec delete: %w", err) } @@ -87,6 +90,7 @@ func (r *Repository) Update(ctx context.Context, linter Linter) (Linter, error) } } + tenantID := sqlutil.TenantID(ctx) _, err = tx.ExecContext( ctx, insertQuery, @@ -95,6 +99,7 @@ func (r *Repository) Update(ctx context.Context, linter Linter) (Linter, error) updated.Enabled, updated.MinimumScore, pluginsJSON, + tenantID, ) if err != nil { return Linter{}, fmt.Errorf("sql exec insert: %w", err) @@ -154,8 +159,9 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Linter, error) { linter := defaultLinter var rawPlugins []byte + query, params := sqlutil.Tenant(ctx, getQuery) err := r.db. - QueryRowContext(ctx, getQuery). + QueryRowContext(ctx, query, params...). Scan( &linter.ID, &linter.Name, diff --git a/server/migrations/28_add_tenant_id.down.sql b/server/migrations/28_add_tenant_id.down.sql new file mode 100644 index 0000000000..6205e1df8f --- /dev/null +++ b/server/migrations/28_add_tenant_id.down.sql @@ -0,0 +1,42 @@ +BEGIN; + +ALTER TABLE + config DROP COLUMN tenant_id; + +ALTER TABLE + data_stores DROP COLUMN tenant_id; + +ALTER TABLE + demos DROP COLUMN tenant_id; + +ALTER TABLE + environments DROP COLUMN tenant_id; + +ALTER TABLE + linters DROP COLUMN tenant_id; + +ALTER TABLE + polling_profiles DROP COLUMN tenant_id; + +ALTER TABLE + test_runners DROP COLUMN tenant_id; + +ALTER TABLE + test_runs DROP COLUMN tenant_id; + +ALTER TABLE + tests DROP COLUMN tenant_id; + +ALTER TABLE + transaction_runs DROP COLUMN tenant_id; + +ALTER TABLE + transactions DROP COLUMN tenant_id; + +ALTER TABLE + transaction_run_steps DROP COLUMN tenant_id; + +ALTER TABLE + transaction_steps DROP COLUMN tenant_id; + +COMMIT; \ No newline at end of file diff --git a/server/migrations/28_add_tenant_id.up.sql b/server/migrations/28_add_tenant_id.up.sql new file mode 100644 index 0000000000..5dccf1c887 --- /dev/null +++ b/server/migrations/28_add_tenant_id.up.sql @@ -0,0 +1,94 @@ +BEGIN; + +ALTER TABLE + config +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_config_tenant_id ON config(tenant_id); + +ALTER TABLE + data_stores +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_data_stores_tenant_id ON data_stores(tenant_id); + +ALTER TABLE + demos +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_demos_tenant_id ON demos(tenant_id); + +ALTER TABLE + environments +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_environments_tenant_id ON environments(tenant_id); + +ALTER TABLE + linters +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_linters_tenant_id ON linters(tenant_id); + +ALTER TABLE + polling_profiles +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_polling_profiles_tenant_id ON polling_profiles(tenant_id); + +ALTER TABLE + test_runners +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_test_runners_tenant_id ON test_runners(tenant_id); + +ALTER TABLE + test_runs +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_test_runs_tenant_id ON test_runs(tenant_id); + +ALTER TABLE + tests +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_tests_tenant_id ON tests(tenant_id); + +ALTER TABLE + transaction_run_steps +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_transaction_run_steps_tenant_id ON transaction_run_steps(tenant_id); + +ALTER TABLE + transaction_runs +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_transaction_runs_tenant_id ON transaction_runs(tenant_id); + +ALTER TABLE + transaction_steps +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_transaction_steps_tenant_id ON transaction_steps(tenant_id); + +ALTER TABLE + transactions +ADD + COLUMN tenant_id uuid; + +CREATE INDEX idx_transactions_tenant_id ON transactions(tenant_id); + +COMMIT; \ No newline at end of file diff --git a/server/pkg/sqlutil/tenant.go b/server/pkg/sqlutil/tenant.go new file mode 100644 index 0000000000..f61a53b366 --- /dev/null +++ b/server/pkg/sqlutil/tenant.go @@ -0,0 +1,58 @@ +package sqlutil + +import ( + "context" + "fmt" + "strings" + + "github.com/kubeshop/tracetest/server/http/middleware" +) + +func Tenant(ctx context.Context, query string, params ...any) (string, []any) { + tenantID := TenantID(ctx) + if tenantID == nil { + return query, params + } + + prefix := getQueryPrefix(query) + + paramNumber := len(params) + 1 + condition := fmt.Sprintf(" %s tenant_id = $%d", prefix, paramNumber) + + return query + condition, append(params, tenantID) +} + +func TenantWithPrefix(ctx context.Context, query string, prefix string, params ...any) (string, []any) { + tenantID := TenantID(ctx) + if tenantID == nil { + return query, params + } + + queryPrefix := getQueryPrefix(query) + paramNumber := len(params) + 1 + condition := fmt.Sprintf(" %s %stenant_id = $%d)", queryPrefix, prefix, paramNumber) + + return query + condition, append(params, tenantID) +} + +func TenantID(ctx context.Context) *string { + tenantID := ctx.Value(middleware.TenantIDKey) + + if tenantID == "" || tenantID == nil { + return nil + } + + tenantIDString := tenantID.(string) + return &tenantIDString +} + +func getQueryPrefix(query string) string { + prefix := "" + if strings.Contains(strings.ToLower(query), "where") { + prefix = "AND " + } else { + prefix = "WHERE " + } + + return prefix +} diff --git a/server/test/run_repository.go b/server/test/run_repository.go index 51bd619888..86e4afedb0 100644 --- a/server/test/run_repository.go +++ b/server/test/run_repository.go @@ -10,6 +10,7 @@ import ( "github.com/kubeshop/tracetest/server/environment" "github.com/kubeshop/tracetest/server/executor/testrunner" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" "go.opentelemetry.io/otel/trace" ) @@ -74,7 +75,9 @@ INSERT INTO test_runs ( "linter", -- required gates - "required_gates_result" + "required_gates_result", + + "tenant_id" ) VALUES ( nextval('` + runSequenceName + `'), -- id $1, -- test_id @@ -104,7 +107,8 @@ INSERT INTO test_runs ( $12, -- metadata $13, -- environment $14, -- linter - $15 -- required_gates_result + $15, -- required_gates_result + $16 -- tenant_id ) RETURNING "id"` @@ -157,6 +161,8 @@ func (r *runRepository) CreateRun(ctx context.Context, test Test, run Run) (Run, return Run{}, fmt.Errorf("sql exec: %w", err) } + tenantID := sqlutil.TenantID(ctx) + var runID int err = tx.QueryRowContext( ctx, @@ -176,6 +182,7 @@ func (r *runRepository) CreateRun(ctx context.Context, test Test, run Run) (Run, jsonEnvironment, jsonlinter, jsonGatesResult, + tenantID, ).Scan(&runID) if err != nil { tx.Rollback() @@ -223,12 +230,6 @@ WHERE id = $16 AND test_id = $17 ` func (r *runRepository) UpdateRun(ctx context.Context, run Run) error { - stmt, err := r.db.Prepare(updateRunQuery) - if err != nil { - return fmt.Errorf("prepare: %w", err) - } - defer stmt.Close() - jsonTriggerResults, err := json.Marshal(run.TriggerResult) if err != nil { return fmt.Errorf("trigger results encoding error: %w", err) @@ -277,8 +278,9 @@ func (r *runRepository) UpdateRun(ctx context.Context, run Run) error { pass, fail := run.ResultsCount() - _, err = stmt.ExecContext( + query, params := sqlutil.Tenant( ctx, + updateRunQuery, run.ServiceTriggeredAt, run.ServiceTriggerCompletedAt, run.ObtainedTraceAt, @@ -300,6 +302,12 @@ func (r *runRepository) UpdateRun(ctx context.Context, run Run) error { jsonLinter, jsonGatesResult, ) + + _, err = r.db.ExecContext( + ctx, + query, + params..., + ) if err != nil { return fmt.Errorf("sql exec: %w", err) } @@ -319,7 +327,8 @@ func (r *runRepository) DeleteRun(ctx context.Context, run Run) error { } for _, sql := range queries { - _, err := tx.ExecContext(ctx, sql, run.ID, run.TestID) + query, params := sqlutil.Tenant(ctx, sql, run.ID, run.TestID) + _, err := tx.ExecContext(ctx, query, params...) if err != nil { tx.Rollback() return fmt.Errorf("sql error: %w", err) @@ -374,13 +383,9 @@ FROM ` func (r *runRepository) GetRun(ctx context.Context, testID id.ID, runID int) (Run, error) { - stmt, err := r.db.Prepare(selectRunQuery + " WHERE id = $1 AND test_id = $2") - if err != nil { - return Run{}, err - } - defer stmt.Close() + query, params := sqlutil.Tenant(ctx, selectRunQuery+" WHERE id = $1 AND test_id = $2", runID, testID) - run, err := readRunRow(stmt.QueryRowContext(ctx, runID, testID.String())) + run, err := readRunRow(r.db.QueryRowContext(ctx, query, params...)) if err != nil { return Run{}, fmt.Errorf("cannot read row: %w", err) } @@ -388,14 +393,14 @@ func (r *runRepository) GetRun(ctx context.Context, testID id.ID, runID int) (Ru } func (r *runRepository) GetTestRuns(ctx context.Context, test Test, take, skip int32) ([]Run, error) { - const condition = " WHERE test_id = $1" - stmt, err := r.db.Prepare(selectRunQuery + condition + " ORDER BY created_at DESC LIMIT $2 OFFSET $3") + query, params := sqlutil.Tenant(ctx, selectRunQuery+" WHERE test_id = $1", test.ID, take, skip) + stmt, err := r.db.Prepare(query + " ORDER BY created_at DESC LIMIT $2 OFFSET $3") if err != nil { return []Run{}, err } defer stmt.Close() - rows, err := stmt.QueryContext(ctx, test.ID, take, skip) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return []Run{}, err } @@ -405,25 +410,18 @@ func (r *runRepository) GetTestRuns(ctx context.Context, test Test, take, skip i return []Run{}, err } - var count int - err = r.db. - QueryRowContext(ctx, "SELECT COUNT(*) FROM test_runs"+condition, test.ID). - Scan(&count) - if err != nil { - return []Run{}, err - } - return runs, nil } func (r *runRepository) GetRunByTraceID(ctx context.Context, traceID trace.TraceID) (Run, error) { - stmt, err := r.db.Prepare(selectRunQuery + " WHERE trace_id = $1") + query, params := sqlutil.Tenant(ctx, selectRunQuery+" WHERE trace_id = $1", traceID.String()) + stmt, err := r.db.Prepare(query) if err != nil { return Run{}, err } defer stmt.Close() - run, err := readRunRow(stmt.QueryRowContext(ctx, traceID.String())) + run, err := readRunRow(stmt.QueryRowContext(ctx, params...)) if err != nil { return Run{}, fmt.Errorf("cannot read row: %w", err) } @@ -431,14 +429,15 @@ func (r *runRepository) GetRunByTraceID(ctx context.Context, traceID trace.Trace } func (r *runRepository) GetLatestRunByTestVersion(ctx context.Context, testID id.ID, version int) (Run, error) { - stmt, err := r.db.Prepare(selectRunQuery + " WHERE test_id = $1 AND test_version = $2 ORDER BY created_at DESC LIMIT 1") + query, params := sqlutil.Tenant(ctx, selectRunQuery+" WHERE test_id = $1 AND test_version = $2 ORDER BY created_at DESC LIMIT 1", testID.String(), version) + stmt, err := r.db.Prepare(query) if err != nil { return Run{}, err } defer stmt.Close() - run, err := readRunRow(stmt.QueryRowContext(ctx, testID.String(), version)) + run, err := readRunRow(stmt.QueryRowContext(ctx, params...)) if err != nil { return Run{}, err } @@ -603,6 +602,7 @@ func (r *runRepository) GetTransactionRunSteps(ctx context.Context, id id.ID, ru WHERE transaction_run_steps.transaction_run_id = $1 AND transaction_run_steps.transaction_run_transaction_id = $2 ORDER BY test_runs.completed_at ASC ` + query, params := sqlutil.Tenant(ctx, query, runID, id) stmt, err := r.db.Prepare(query) if err != nil { @@ -610,7 +610,7 @@ ORDER BY test_runs.completed_at ASC } defer stmt.Close() - rows, err := stmt.QueryContext(ctx, runID, id) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return []Run{}, fmt.Errorf("query context: %w", err) } diff --git a/server/test/test_repository.go b/server/test/test_repository.go index 6540e769cf..6cfb838439 100644 --- a/server/test/test_repository.go +++ b/server/test/test_repository.go @@ -125,6 +125,7 @@ func (r *repository) list(ctx context.Context, take, skip int, query, sortBy, so condition := fmt.Sprintf(" WHERE (t.name ilike $%d OR t.description ilike $%d)", paramNumber, paramNumber) q, params := sqlutil.Search(sql, condition, query, params) + q, params = sqlutil.TenantWithPrefix(ctx, q, "t.", params...) sortingFields := map[string]string{ "created": "t.created_at", @@ -160,6 +161,7 @@ func (r *repository) Count(ctx context.Context, query string) (int, error) { countQuery := "SELECT COUNT(*) FROM tests t" + testMaxVersionQuery const condition = " WHERE (t.name ilike $1 OR t.description ilike $1)" sql, params := sqlutil.Search(countQuery, condition, query, params) + sql, params = sqlutil.TenantWithPrefix(ctx, sql, "t.", params...) count := 0 @@ -188,14 +190,12 @@ func (r *repository) GetAugmented(ctx context.Context, id id.ID) (Test, error) { return r.get(ctx, id) } +const sortQuery = `ORDER BY t.version DESC LIMIT 1` + func (r *repository) get(ctx context.Context, id id.ID) (Test, error) { - stmt, err := r.db.Prepare(getTestSQL + " WHERE t.id = $1 ORDER BY t.version DESC LIMIT 1") - if err != nil { - return Test{}, fmt.Errorf("prepare: %w", err) - } - defer stmt.Close() + query, params := sqlutil.Tenant(ctx, getTestSQL+" WHERE t.id = $1", id) - test, err := r.readRow(ctx, stmt.QueryRowContext(ctx, id)) + test, err := r.readRow(ctx, r.db.QueryRowContext(ctx, query+sortQuery, params...)) if err != nil { return Test{}, err } @@ -204,14 +204,16 @@ func (r *repository) get(ctx context.Context, id id.ID) (Test, error) { } func (r *repository) GetTransactionSteps(ctx context.Context, id id.ID, version int) ([]Test, error) { - stmt, err := r.db.Prepare(getTestSQL + testMaxVersionQuery + ` INNER JOIN transaction_steps ts ON t.id = ts.test_id - WHERE ts.transaction_id = $1 AND ts.transaction_version = $2 ORDER BY ts.step_number ASC`) + sortQuery := `ORDER BY ts.step_number ASC` + query, params := sqlutil.Tenant(ctx, getTestSQL+testMaxVersionQuery+` INNER JOIN transaction_steps ts ON t.id = ts.test_id + WHERE ts.transaction_id = $1 AND ts.transaction_version = $2`, id, version) + stmt, err := r.db.Prepare(query + sortQuery) if err != nil { return []Test{}, fmt.Errorf("prepare 2: %w", err) } defer stmt.Close() - rows, err := stmt.QueryContext(ctx, id, version) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return []Test{}, fmt.Errorf("query context: %w", err) } @@ -322,8 +324,9 @@ INSERT INTO tests ( "service_under_test", "specs", "outputs", - "created_at" -) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)` + "created_at", + "tenant_id" +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)` func (r *repository) Create(ctx context.Context, test Test) (Test, error) { if test.HasID() { @@ -377,6 +380,8 @@ func (r *repository) insertTest(ctx context.Context, test Test) (Test, error) { return Test{}, fmt.Errorf("encoding error: %w", err) } + tenantID := sqlutil.TenantID(ctx) + _, err = stmt.ExecContext( ctx, test.ID, @@ -387,6 +392,7 @@ func (r *repository) insertTest(ctx context.Context, test Test) (Test, error) { specsJson, outputsJson, test.CreatedAt, + tenantID, ) if err != nil { return Test{}, fmt.Errorf("sql exec: %w", err) @@ -500,7 +506,8 @@ func (r *repository) Delete(ctx context.Context, id id.ID) error { defer tx.Rollback() for _, sql := range queries { - _, err := tx.ExecContext(ctx, sql, id) + sql, params := sqlutil.Tenant(ctx, sql, id) + _, err := tx.ExecContext(ctx, sql, params...) if err != nil { return fmt.Errorf("sql error: %w", err) } @@ -559,13 +566,14 @@ func (r *repository) Exists(ctx context.Context, id id.ID) (bool, error) { } func (r *repository) GetVersion(ctx context.Context, id id.ID, version int) (Test, error) { - stmt, err := r.db.Prepare(getTestSQL + " WHERE t.id = $1 AND t.version = $2") + query, params := sqlutil.Tenant(ctx, getTestSQL+" WHERE t.id = $1 AND t.version = $2", id, version) + stmt, err := r.db.Prepare(query) if err != nil { return Test{}, fmt.Errorf("prepare: %w", err) } defer stmt.Close() - test, err := r.readRow(ctx, stmt.QueryRowContext(ctx, id, version)) + test, err := r.readRow(ctx, stmt.QueryRowContext(ctx, params...)) if err != nil { return Test{}, err } diff --git a/server/transaction/transaction_repository.go b/server/transaction/transaction_repository.go index eb8db1d382..ed6b8a18a9 100644 --- a/server/transaction/transaction_repository.go +++ b/server/transaction/transaction_repository.go @@ -117,26 +117,29 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error { return err } - _, err = tx.ExecContext(ctx, "DELETE FROM transaction_steps WHERE transaction_id = $1", id) + query, params := sqlutil.Tenant(ctx, "DELETE FROM transaction_steps WHERE transaction_id = $1", id) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { tx.Rollback() return err } - q := "DELETE FROM transaction_run_steps WHERE transaction_run_id IN (SELECT id FROM transaction_runs WHERE transaction_id = $1)" - _, err = tx.ExecContext(ctx, q, id) + q, params := sqlutil.Tenant(ctx, "DELETE FROM transaction_run_steps WHERE transaction_run_id IN (SELECT id FROM transaction_runs WHERE transaction_id = $1)", id) + _, err = tx.ExecContext(ctx, q, params...) if err != nil { tx.Rollback() return err } - _, err = tx.ExecContext(ctx, "DELETE FROM transaction_runs WHERE transaction_id = $1", id) + q, params = sqlutil.Tenant(ctx, "DELETE FROM transaction_runs WHERE transaction_id = $1", id) + _, err = tx.ExecContext(ctx, q, params...) if err != nil { tx.Rollback() return err } - _, err = tx.ExecContext(ctx, "DELETE FROM transactions WHERE id = $1", id) + q, params = sqlutil.Tenant(ctx, "DELETE FROM transactions WHERE id = $1", id) + _, err = tx.ExecContext(ctx, q, params...) if err != nil { tx.Rollback() return err @@ -147,11 +150,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error { func (r *Repository) IDExists(ctx context.Context, id id.ID) (bool, error) { exists := false - row := r.db.QueryRowContext( - ctx, - "SELECT COUNT(*) > 0 as exists FROM transactions WHERE id = $1", - id, - ) + query, params := sqlutil.Tenant(ctx, "SELECT COUNT(*) > 0 as exists FROM transactions WHERE id = $1", id) + row := r.db.QueryRowContext(ctx, query, params...) err := row.Scan(&exists) if err != nil { @@ -197,7 +197,7 @@ func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, so } func (r *Repository) list(ctx context.Context, take, skip int, query, sortBy, sortDirection string, augmented bool) ([]Transaction, error) { - q, params := listQuery(querySelect(), query, []any{take, skip}) + q, params := listQuery(ctx, querySelect(), query, []any{take, skip}) sortingFields := map[string]string{ "created": "t.created_at", @@ -253,7 +253,7 @@ func querySelect() string { } func (r *Repository) Count(ctx context.Context, query string) (int, error) { - sql, params := listQuery(queryCount(), query, []any{}) + sql, params := listQuery(ctx, queryCount(), query, []any{}) count := 0 err := r.db. @@ -276,13 +276,14 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Transaction, error) { } func (r *Repository) get(ctx context.Context, id id.ID, augmented bool) (Transaction, error) { - stmt, err := r.db.Prepare(querySelect() + " WHERE t.id = $1 ORDER BY t.version DESC LIMIT 1") + query, params := sqlutil.Tenant(ctx, querySelect()+" WHERE t.id = $1", id) + stmt, err := r.db.Prepare(query + "ORDER BY t.version DESC LIMIT 1") if err != nil { return Transaction{}, fmt.Errorf("prepare: %w", err) } defer stmt.Close() - transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, id), augmented) + transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, params...), augmented) if err != nil { return Transaction{}, err } @@ -291,13 +292,14 @@ func (r *Repository) get(ctx context.Context, id id.ID, augmented bool) (Transac } func (r *Repository) GetVersion(ctx context.Context, id id.ID, version int) (Transaction, error) { - stmt, err := r.db.Prepare(querySelect() + " WHERE t.id = $1 AND t.version = $2") + query, params := sqlutil.Tenant(ctx, querySelect()+" WHERE t.id = $1 AND t.version = $2", id, version) + stmt, err := r.db.Prepare(query) if err != nil { return Transaction{}, fmt.Errorf("prepare 1: %w", err) } defer stmt.Close() - transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, id, version), true) + transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, params...), true) if err != nil { return Transaction{}, err } @@ -305,24 +307,26 @@ func (r *Repository) GetVersion(ctx context.Context, id id.ID, version int) (Tra return transaction, nil } -func listQuery(baseSQL, query string, params []any) (string, []any) { +func listQuery(ctx context.Context, baseSQL, query string, params []any) (string, []any) { paramNumber := len(params) + 1 condition := fmt.Sprintf(" AND (t.name ilike $%d OR t.description ilike $%d)", paramNumber, paramNumber) sql := baseSQL + transactionMaxVersionQuery sql, params = sqlutil.Search(sql, condition, query, params) + sql, params = sqlutil.Tenant(ctx, sql, params...) return sql, params } func (r *Repository) GetLatestVersion(ctx context.Context, id id.ID) (Transaction, error) { - stmt, err := r.db.Prepare(querySelect() + " WHERE t.id = $1 ORDER BY t.version DESC LIMIT 1") + query, params := sqlutil.Tenant(ctx, querySelect()+" WHERE t.id = $1", id) + stmt, err := r.db.Prepare(query + " ORDER BY t.version DESC LIMIT 1") if err != nil { return Transaction{}, fmt.Errorf("prepare: %w", err) } defer stmt.Close() - transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, id), true) + transaction, err := r.readRow(ctx, stmt.QueryRowContext(ctx, params...), true) if err != nil { return Transaction{}, err } @@ -336,8 +340,9 @@ INSERT INTO transactions ( "version", "name", "description", - "created_at" -) VALUES ($1, $2, $3, $4, $5)` + "created_at", + "tenant_id" +) VALUES ($1, $2, $3, $4, $5, $6)` func (r *Repository) insertIntoTransactions(ctx context.Context, transaction Transaction) (Transaction, error) { tx, err := r.db.BeginTx(ctx, &sql.TxOptions{}) @@ -352,6 +357,8 @@ func (r *Repository) insertIntoTransactions(ctx context.Context, transaction Tra } defer stmt.Close() + tenantID := sqlutil.TenantID(ctx) + _, err = stmt.ExecContext( ctx, transaction.ID, @@ -359,6 +366,7 @@ func (r *Repository) insertIntoTransactions(ctx context.Context, transaction Tra transaction.Name, transaction.Description, transaction.GetCreatedAt(), + tenantID, ) if err != nil { return Transaction{}, fmt.Errorf("sql exec: %w", err) @@ -369,12 +377,13 @@ func (r *Repository) insertIntoTransactions(ctx context.Context, transaction Tra func (r *Repository) setTransactionSteps(ctx context.Context, tx *sql.Tx, transaction Transaction) (Transaction, error) { // delete existing steps - stmt, err := tx.Prepare("DELETE FROM transaction_steps WHERE transaction_id = $1 AND transaction_version = $2") + query, params := sqlutil.Tenant(ctx, "DELETE FROM transaction_steps WHERE transaction_id = $1 AND transaction_version = $2", transaction.ID, transaction.GetVersion()) + stmt, err := tx.Prepare(query) if err != nil { return Transaction{}, err } - _, err = stmt.ExecContext(ctx, transaction.ID, transaction.GetVersion()) + _, err = stmt.ExecContext(ctx, params...) if err != nil { return Transaction{}, err } @@ -383,13 +392,23 @@ func (r *Repository) setTransactionSteps(ctx context.Context, tx *sql.Tx, transa return transaction, tx.Commit() } + tenantID := sqlutil.TenantID(ctx) + values := []string{} for i, testID := range transaction.StepIDs { stepNumber := i + 1 - values = append( - values, - fmt.Sprintf("('%s', %d, '%s', %d)", transaction.ID, transaction.GetVersion(), testID, stepNumber), - ) + + if tenantID == nil { + values = append( + values, + fmt.Sprintf("('%s', %d, '%s', %d, NULL)", transaction.ID, transaction.GetVersion(), testID, stepNumber), + ) + } else { + values = append( + values, + fmt.Sprintf("('%s', %d, '%s', %d, '%s')", transaction.ID, transaction.GetVersion(), testID, stepNumber, *tenantID), + ) + } } sql := "INSERT INTO transaction_steps VALUES " + strings.Join(values, ", ") diff --git a/server/transaction/transaction_run_repository.go b/server/transaction/transaction_run_repository.go index 33cf8c7b89..5e3b702c2b 100644 --- a/server/transaction/transaction_run_repository.go +++ b/server/transaction/transaction_run_repository.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/kubeshop/tracetest/server/pkg/id" + "github.com/kubeshop/tracetest/server/pkg/sqlutil" ) func NewRunRepository(db *sql.DB, stepsRepository transactionStepRunRepository) *RunRepository { @@ -47,7 +48,9 @@ INSERT INTO transaction_runs ( "metadata", -- environment - "environment" + "environment", + + "tenant_id" ) VALUES ( nextval('` + runSequenceName + `'), -- id $1, -- transaction_id @@ -68,7 +71,8 @@ INSERT INTO transaction_runs ( TRUE, -- all_steps_required_gates_passed $6, -- metadata - $7 -- environment + $7, -- environment + $8 -- tenant_id ) RETURNING "id"` @@ -114,6 +118,8 @@ func (td *RunRepository) CreateRun(ctx context.Context, tr TransactionRun) (Tran return TransactionRun{}, fmt.Errorf("sql exec: %w", err) } + tenantID := sqlutil.TenantID(ctx) + var runID int err = tx.QueryRowContext( ctx, @@ -125,6 +131,7 @@ func (td *RunRepository) CreateRun(ctx context.Context, tr TransactionRun) (Tran tr.CurrentTest, jsonMetadata, jsonEnvironment, + tenantID, ).Scan(&runID) if err != nil { tx.Rollback() @@ -171,12 +178,6 @@ func (td *RunRepository) UpdateRun(ctx context.Context, tr TransactionRun) error return fmt.Errorf("sql beginTx: %w", err) } - stmt, err := tx.Prepare(updateTransactionRunQuery) - if err != nil { - return fmt.Errorf("prepare: %w", err) - } - defer stmt.Close() - jsonMetadata, err := json.Marshal(tr.Metadata) if err != nil { return fmt.Errorf("failed to marshal transaction run metadata: %w", err) @@ -195,8 +196,9 @@ func (td *RunRepository) UpdateRun(ctx context.Context, tr TransactionRun) error pass, fail := tr.ResultsCount() allStepsRequiredGatesPassed := tr.StepsGatesValidation() - _, err = stmt.ExecContext( + query, params := sqlutil.Tenant( ctx, + updateTransactionRunQuery, tr.CompletedAt, tr.State, tr.CurrentTest, @@ -209,6 +211,13 @@ func (td *RunRepository) UpdateRun(ctx context.Context, tr TransactionRun) error tr.ID, tr.TransactionID, ) + stmt, err := tx.Prepare(query) + if err != nil { + return fmt.Errorf("prepare: %w", err) + } + defer stmt.Close() + + _, err = stmt.ExecContext(ctx, params...) if err != nil { tx.Rollback() @@ -220,12 +229,13 @@ func (td *RunRepository) UpdateRun(ctx context.Context, tr TransactionRun) error func (td *RunRepository) setTransactionRunSteps(ctx context.Context, tx *sql.Tx, tr TransactionRun) error { // delete existing steps - stmt, err := tx.Prepare("DELETE FROM transaction_run_steps WHERE transaction_run_id = $1 AND transaction_run_transaction_id = $2") + query, params := sqlutil.Tenant(ctx, "DELETE FROM transaction_run_steps WHERE transaction_run_id = $1 AND transaction_run_transaction_id = $2", tr.ID, tr.TransactionID) + stmt, err := tx.Prepare(query) if err != nil { return err } - _, err = stmt.ExecContext(ctx, tr.ID, tr.TransactionID) + _, err = stmt.ExecContext(ctx, params...) if err != nil { return err } @@ -234,16 +244,26 @@ func (td *RunRepository) setTransactionRunSteps(ctx context.Context, tx *sql.Tx, return tx.Commit() } + tenantID := sqlutil.TenantID(ctx) + values := []string{} for _, run := range tr.Steps { if run.ID == 0 { // step not set, skip continue } - values = append( - values, - fmt.Sprintf("('%d', '%s', %d, '%s')", tr.ID, tr.TransactionID, run.ID, run.TestID), - ) + + if tenantID == nil { + values = append( + values, + fmt.Sprintf("('%d', '%s', %d, '%s', NULL)", tr.ID, tr.TransactionID, run.ID, run.TestID), + ) + } else { + values = append( + values, + fmt.Sprintf("('%d', '%s', %d, '%s', '%s')", tr.ID, tr.TransactionID, run.ID, run.TestID, *tenantID), + ) + } } sql := "INSERT INTO transaction_run_steps VALUES " + strings.Join(values, ", ") @@ -260,19 +280,15 @@ func (td *RunRepository) DeleteTransactionRun(ctx context.Context, tr Transactio return fmt.Errorf("sql beginTx: %w", err) } - _, err = tx.ExecContext( - ctx, "DELETE FROM transaction_run_steps WHERE transaction_run_id = $1 AND transaction_run_transaction_id = $2", - tr.ID, tr.TransactionID, - ) + query, params := sqlutil.Tenant(ctx, "DELETE FROM transaction_run_steps WHERE transaction_run_id = $1 AND transaction_run_transaction_id = $2", tr.ID, tr.TransactionID) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { tx.Rollback() return fmt.Errorf("delete transaction run steps: %w", err) } - _, err = tx.ExecContext( - ctx, "DELETE FROM transaction_runs WHERE id = $1 AND transaction_id = $2", - tr.ID, tr.TransactionID, - ) + query, params = sqlutil.Tenant(ctx, "DELETE FROM transaction_runs WHERE id = $1 AND transaction_id = $2", tr.ID, tr.TransactionID) + _, err = tx.ExecContext(ctx, query, params...) if err != nil { tx.Rollback() return fmt.Errorf("delete transaction runs: %w", err) @@ -305,12 +321,13 @@ FROM transaction_runs ` func (td *RunRepository) GetTransactionRun(ctx context.Context, transactionID id.ID, runID int) (TransactionRun, error) { - stmt, err := td.db.Prepare(selectTransactionRunQuery + " WHERE id = $1 AND transaction_id = $2") + query, params := sqlutil.Tenant(ctx, selectTransactionRunQuery+" WHERE id = $1 AND transaction_id = $2", runID, transactionID) + stmt, err := td.db.Prepare(query) if err != nil { return TransactionRun{}, fmt.Errorf("prepare: %w", err) } - run, err := td.readRunRow(stmt.QueryRowContext(ctx, runID, transactionID)) + run, err := td.readRunRow(stmt.QueryRowContext(ctx, params...)) if err != nil { return TransactionRun{}, err } @@ -322,12 +339,14 @@ func (td *RunRepository) GetTransactionRun(ctx context.Context, transactionID id } func (td *RunRepository) GetLatestRunByTransactionVersion(ctx context.Context, transactionID id.ID, version int) (TransactionRun, error) { - stmt, err := td.db.Prepare(selectTransactionRunQuery + " WHERE transaction_id = $1 AND transaction_version = $2 ORDER BY created_at DESC LIMIT 1") + sortQuery := "ORDER BY created_at DESC LIMIT 1" + query, params := sqlutil.Tenant(ctx, selectTransactionRunQuery+" WHERE transaction_id = $1 AND transaction_version = $2", transactionID, version) + stmt, err := td.db.Prepare(query + sortQuery) if err != nil { return TransactionRun{}, fmt.Errorf("prepare: %w", err) } - run, err := td.readRunRow(stmt.QueryRowContext(ctx, transactionID, version)) + run, err := td.readRunRow(stmt.QueryRowContext(ctx, params...)) if err != nil { return TransactionRun{}, err } @@ -339,12 +358,14 @@ func (td *RunRepository) GetLatestRunByTransactionVersion(ctx context.Context, t } func (td *RunRepository) GetTransactionsRuns(ctx context.Context, transactionID id.ID, take, skip int32) ([]TransactionRun, error) { - stmt, err := td.db.Prepare(selectTransactionRunQuery + " WHERE transaction_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3") + sortQuery := "ORDER BY created_at DESC LIMIT $2 OFFSET $3" + query, params := sqlutil.Tenant(ctx, selectTransactionRunQuery+" WHERE transaction_id = $1", transactionID.String(), take, skip) + stmt, err := td.db.Prepare(query + sortQuery) if err != nil { return []TransactionRun{}, fmt.Errorf("prepare: %w", err) } - rows, err := stmt.QueryContext(ctx, transactionID.String(), take, skip) + rows, err := stmt.QueryContext(ctx, params...) if err != nil { return []TransactionRun{}, fmt.Errorf("query: %w", err) }