Skip to content

Commit

Permalink
fix(server): use tenant_id for sequence creation (#3519)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Matheus Nogueira <matheus.nogueira2008@gmail.com>
  • Loading branch information
schoren and mathnogueira committed Jan 11, 2024
1 parent dc5f609 commit 2ec83f8
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 16 deletions.
22 changes: 22 additions & 0 deletions server/migrations/38_migrate_run_sequences.up.sql
@@ -0,0 +1,22 @@
DO $$
DECLARE
temprow record;
BEGIN
FOR temprow IN
SELECT max(id)+1 AS next_value, test_id, tenant_id
FROM test_runs
GROUP BY test_id, tenant_id
LOOP
EXECUTE format('CREATE SEQUENCE IF NOT EXISTS runs_test_%s_seq START WITH %s',
MD5(FORMAT('%s%s', temprow.test_id, temprow.tenant_id)), temprow.next_value);
END LOOP;

FOR temprow IN
SELECT max(id::int)+1 AS next_value, test_suite_id, tenant_id
FROM test_suite_runs
GROUP BY test_suite_id, tenant_id
LOOP
EXECUTE format('CREATE SEQUENCE IF NOT EXISTS runs_test_suite_%s_seq START WITH %s',
MD5(FORMAT('%s%s', temprow.test_suite_id, temprow.tenant_id)), temprow.next_value);
END LOOP;
END $$;
9 changes: 9 additions & 0 deletions server/pkg/sqlutil/tenant.go
Expand Up @@ -44,6 +44,15 @@ func TenantInsert(ctx context.Context, params ...any) []any {
return append(params, *tenantID)
}

func TenantIDString(ctx context.Context) string {
tenantID := TenantID(ctx)
if tenantID == nil {
return ""
}

return *tenantID
}

func TenantID(ctx context.Context) *string {
tenantID := ctx.Value(middleware.TenantIDKey)

Expand Down
6 changes: 4 additions & 2 deletions server/test/run_repository.go
Expand Up @@ -169,7 +169,9 @@ func (r *runRepository) CreateRun(ctx context.Context, test Test, run Run) (Run,
return Run{}, fmt.Errorf("sql beginTx: %w", err)
}

_, err = tx.ExecContext(ctx, replaceRunSequenceName(createSequeceQuery, test.ID))
tenantID := sqlutil.TenantIDString(ctx)

_, err = tx.ExecContext(ctx, replaceRunSequenceName(createSequeceQuery, test.ID, tenantID))
if err != nil {
tx.Rollback()
return Run{}, fmt.Errorf("sql exec: %w", err)
Expand All @@ -196,7 +198,7 @@ func (r *runRepository) CreateRun(ctx context.Context, test Test, run Run) (Run,
)

var runID int
err = tx.QueryRowContext(ctx, replaceRunSequenceName(createRunQuery, test.ID), params...).Scan(&runID)
err = tx.QueryRowContext(ctx, replaceRunSequenceName(createRunQuery, test.ID, tenantID), params...).Scan(&runID)
if err != nil {
tx.Rollback()
return Run{}, fmt.Errorf("sql exec: %w", err)
Expand Down
10 changes: 5 additions & 5 deletions server/test/test_repository.go
Expand Up @@ -517,7 +517,7 @@ func (r *repository) Delete(ctx context.Context, id id.ID) error {
}
}

dropSequence(ctx, tx, id)
dropSequence(ctx, tx, id, sqlutil.TenantIDString(ctx))

err = tx.Commit()
if err != nil {
Expand All @@ -533,10 +533,10 @@ const (
runSequenceName = "%sequence_name%"
)

func dropSequence(ctx context.Context, tx *sql.Tx, testID id.ID) error {
func dropSequence(ctx context.Context, tx *sql.Tx, testID id.ID, tenantID string) error {
_, err := tx.ExecContext(
ctx,
replaceRunSequenceName(dropSequenceQuery, testID),
replaceRunSequenceName(dropSequenceQuery, testID, tenantID),
)

return err
Expand All @@ -547,12 +547,12 @@ func md5Hash(text string) string {
return hex.EncodeToString(hash[:])
}

func replaceRunSequenceName(sql string, testID id.ID) string {
func replaceRunSequenceName(sql string, testID id.ID, tenantID string) string {
// postgres doesn't like uppercase chars in sequence names.
// testID might contain uppercase chars, and we cannot lowercase them
// because they might lose their uniqueness.
// md5 creates a unique, lowercase hash.
seqName := "runs_test_" + md5Hash(testID.String()) + "_seq"
seqName := "runs_test_" + md5Hash(testID.String()+tenantID) + "_seq"
return strings.ReplaceAll(sql, runSequenceName, seqName)
}

Expand Down
20 changes: 11 additions & 9 deletions server/testsuite/testsuite_run_repository.go
Expand Up @@ -93,12 +93,12 @@ func md5Hash(text string) string {
return hex.EncodeToString(hash[:])
}

func replaceTestSuiteRunSequenceName(sql string, ID id.ID) string {
func replaceTestSuiteRunSequenceName(sql string, ID id.ID, tenantID string) string {
// postgres doesn't like uppercase chars in sequence names.
// transactionID might contain uppercase chars, and we cannot lowercase them
// because they might lose their uniqueness.
// md5 creates a unique, lowercase hash.
seqName := "runs_test_suite_" + md5Hash(ID.String()) + "_seq"
seqName := "runs_test_suite_" + md5Hash(ID.String()+tenantID) + "_seq"
return strings.ReplaceAll(sql, runSequenceName, seqName)
}

Expand All @@ -118,12 +118,6 @@ func (td *RunRepository) CreateRun(ctx context.Context, tr TestSuiteRun) (TestSu
return TestSuiteRun{}, fmt.Errorf("sql beginTx: %w", err)
}

_, err = tx.ExecContext(ctx, replaceTestSuiteRunSequenceName(createSequenceQuery, tr.TestSuiteID))
if err != nil {
tx.Rollback()
return TestSuiteRun{}, fmt.Errorf("sql exec: %w", err)
}

params := sqlutil.TenantInsert(ctx,
tr.TestSuiteID,
tr.TestSuiteVersion,
Expand All @@ -134,10 +128,18 @@ func (td *RunRepository) CreateRun(ctx context.Context, tr TestSuiteRun) (TestSu
jsonVariableSet,
)

tenantID := sqlutil.TenantIDString(ctx)

_, err = tx.ExecContext(ctx, replaceTestSuiteRunSequenceName(createSequenceQuery, tr.TestSuiteID, tenantID))
if err != nil {
tx.Rollback()
return TestSuiteRun{}, fmt.Errorf("sql exec: %w", err)
}

var runID int
err = tx.QueryRowContext(
ctx,
replaceTestSuiteRunSequenceName(createTestSuiteRunQuery, tr.TestSuiteID),
replaceTestSuiteRunSequenceName(createTestSuiteRunQuery, tr.TestSuiteID, tenantID),
params...,
).Scan(&runID)
if err != nil {
Expand Down

0 comments on commit 2ec83f8

Please sign in to comment.