Skip to content

Commit

Permalink
feature(backend) Tenant ID support (#2980)
Browse files Browse the repository at this point in the history
* feature(backend) Tenant ID support

* PR review comments
  • Loading branch information
xoscar committed Jul 24, 2023
1 parent f9e71aa commit 6fc605c
Show file tree
Hide file tree
Showing 16 changed files with 530 additions and 162 deletions.
4 changes: 4 additions & 0 deletions server/app/app.go
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/kubeshop/tracetest/server/executor/trigger"
httpServer "github.com/kubeshop/tracetest/server/http"
"github.com/kubeshop/tracetest/server/http/mappings"
"github.com/kubeshop/tracetest/server/http/middleware"
"github.com/kubeshop/tracetest/server/http/websocket"
"github.com/kubeshop/tracetest/server/linter/analyzer"
"github.com/kubeshop/tracetest/server/model"
Expand Down Expand Up @@ -269,6 +270,9 @@ func (app *App) Start(opts ...appOption) error {
// use the analytics middleware on complete router
router.Use(analyticsMW)

// use the tenant middleware on complete router
router.Use(middleware.TenantMiddleware)

apiRouter := router.
PathPrefix(app.cfg.ServerPathPrefix()).
PathPrefix("/api").
Expand Down
26 changes: 18 additions & 8 deletions server/config/demo/demo_repository.go
Expand Up @@ -30,9 +30,10 @@ const insertQuery = `INSERT INTO demos (
"enabled",
"type",
"pokeshop",
"opentelemetry_store"
"opentelemetry_store",
"tenant_id"
)
VALUES ($1, $2, $3, $4, $5, $6)`
VALUES ($1, $2, $3, $4, $5, $6, $7)`

func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) {
tx, err := r.db.BeginTx(ctx, nil)
Expand All @@ -50,13 +51,16 @@ func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) {
return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = tx.ExecContext(ctx, insertQuery,
demo.ID,
demo.Name,
demo.Enabled,
demo.Type,
pokeshopJSONData,
openTelemetryStoreJSONData,
tenantID,
)

if err != nil {
Expand Down Expand Up @@ -102,7 +106,7 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) {
return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err)
}

_, err = tx.ExecContext(ctx, updateQuery,
query, params := sqlutil.Tenant(ctx, updateQuery,
oldDemo.ID,
demo.Name,
demo.Enabled,
Expand All @@ -111,6 +115,8 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) {
openTelemetryStoreJSONData,
)

_, err = tx.ExecContext(ctx, query, params...)

if err != nil {
tx.Rollback()
return Demo{}, fmt.Errorf("sql exec: %w", err)
Expand All @@ -133,17 +139,18 @@ const (
"type",
"pokeshop",
"opentelemetry_store"
FROM demos `
FROM demos`

getQuery = baseSelect + `WHERE "id" = $1`
getDefaultQuery = baseSelect + `WHERE "default" = true`
getQuery = baseSelect + ` WHERE "id" = $1`
getDefaultQuery = baseSelect + ` WHERE "default" = true`
)

func (r *Repository) Get(ctx context.Context, id id.ID) (Demo, error) {
return r.get(ctx, getQuery, id)
}

func (r *Repository) get(ctx context.Context, query string, args ...any) (Demo, error) {
query, args = sqlutil.Tenant(ctx, query, args...)
row := r.db.QueryRowContext(ctx, query, args...)
return readRow(row)
}
Expand All @@ -161,7 +168,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
return err
}

_, err = tx.ExecContext(ctx, deleteQuery, demo.ID)
query, params := sqlutil.Tenant(ctx, deleteQuery, demo.ID)
_, err = tx.ExecContext(ctx, query, params...)

if err != nil {
tx.Rollback()
Expand Down Expand Up @@ -191,6 +199,7 @@ func listQuery(baseSQL, query string, params []any) (string, []any) {

func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, sortDirection string) ([]Demo, error) {
q, params := listQuery(baseSelect, query, []any{take, skip})
q, params = sqlutil.Tenant(ctx, q, params...)

sortingFields := map[string]string{
"id": "id",
Expand Down Expand Up @@ -240,8 +249,9 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) {

count := 0

countQuery, params := sqlutil.Tenant(ctx, countQuery)
err := r.db.
QueryRowContext(ctx, countQuery).
QueryRowContext(ctx, countQuery, params...).
Scan(&count)

if err != nil {
Expand Down
15 changes: 11 additions & 4 deletions server/datastore/datastore_repository.go
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/kubeshop/tracetest/server/pkg/id"
"github.com/kubeshop/tracetest/server/pkg/sqlutil"
)

func NewRepository(db *sql.DB) *Repository {
Expand All @@ -33,8 +34,9 @@ INSERT INTO data_stores (
"type",
"is_default",
"values",
"created_at"
) VALUES ($1, $2, $3, $4, $5, $6)`
"created_at",
"tenant_id"
) VALUES ($1, $2, $3, $4, $5, $6, $7)`

const deleteQuery = `DELETE FROM data_stores WHERE "id" = $1`

Expand Down Expand Up @@ -99,13 +101,16 @@ func (r *Repository) Update(ctx context.Context, dataStore DataStore) (DataStore
return DataStore{}, fmt.Errorf("could not marshal values field configuration: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = tx.ExecContext(ctx, insertQuery,
dataStore.ID,
dataStore.Name,
dataStore.Type,
dataStore.Default,
valuesJSON,
dataStore.CreatedAt,
tenantID,
)
if err != nil {
return DataStore{}, fmt.Errorf("datastore repository sql exec create: %w", err)
Expand All @@ -126,7 +131,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery, dataStoreSingleID)
query, params := sqlutil.Tenant(ctx, deleteQuery, id)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("datastore repository sql exec delete: %w", err)
}
Expand Down Expand Up @@ -160,7 +166,8 @@ func (r *Repository) Current(ctx context.Context) (DataStore, error) {
}

func (r *Repository) Get(ctx context.Context, id id.ID) (DataStore, error) {
row := r.db.QueryRowContext(ctx, getQuery, id)
query, params := sqlutil.Tenant(ctx, getQuery, id)
row := r.db.QueryRowContext(ctx, query, params...)

dataStore, err := r.readRow(row)
if err != nil && errors.Is(err, sql.ErrNoRows) {
Expand Down
67 changes: 31 additions & 36 deletions server/environment/environment_repository.go
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/kubeshop/tracetest/server/pkg/id"
Expand All @@ -31,8 +30,9 @@ const (
"name",
"description",
"created_at",
"values"
) VALUES ($1, $2, $3, $4, $5)`
"values",
"tenant_id"
) VALUES ($1, $2, $3, $4, $5, $6)`

updateQuery = `
UPDATE environments SET
Expand Down Expand Up @@ -107,7 +107,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery, id)
query, params := sqlutil.Tenant(ctx, deleteQuery, id)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("sql error: %w", err)
}
Expand Down Expand Up @@ -144,6 +145,7 @@ func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, so
sql = sqlutil.Sort(sql, sortBy, sortDirection, "created", sortingFields)
sql += ` LIMIT $1 OFFSET $2 `

sql, params = sqlutil.Tenant(ctx, sql, params...)
stmt, err := r.db.Prepare(sql)
if err != nil {
return []Environment{}, err
Expand Down Expand Up @@ -173,30 +175,11 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) {
return r.countEnvironments(ctx, query)
}

func sortQuery(sql, sortBy, sortDirection string, sortingFields map[string]string) string {
sortField, ok := sortingFields[sortBy]

if !ok {
sortField = sortingFields["created"]
}

dir := "DESC"
if strings.ToLower(sortDirection) == "asc" {
dir = "ASC"
}

return fmt.Sprintf("%s ORDER BY %s %s", sql, sortField, dir)
}

func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) {
stmt, err := r.db.Prepare(getQuery + " WHERE e.id = $1")
query, params := sqlutil.Tenant(ctx, getQuery+" WHERE e.id = $1", id)
row := r.db.QueryRowContext(ctx, query, params...)

if err != nil {
return Environment{}, fmt.Errorf("prepare: %w", err)
}
defer stmt.Close()

environment, err := r.readEnvironmentRow(ctx, stmt.QueryRowContext(ctx, id))
environment, err := r.readEnvironmentRow(ctx, row)
if err != nil {
return Environment{}, err
}
Expand All @@ -207,7 +190,8 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) {
func (r *Repository) Exists(ctx context.Context, id id.ID) (bool, error) {
exists := false

row := r.db.QueryRowContext(ctx, idExistsQuery, id)
query, params := sqlutil.Tenant(ctx, idExistsQuery, id)
row := r.db.QueryRowContext(ctx, query, params...)

err := row.Scan(&exists)

Expand Down Expand Up @@ -251,6 +235,7 @@ func (r *Repository) countEnvironments(ctx context.Context, query string) (int,

condition := " WHERE (e.name ilike $1 OR e.description ilike $1)"
sql, params := sqlutil.Search(countQuery, condition, query, params)
sql, params = sqlutil.Tenant(ctx, sql, params...)

err := r.db.
QueryRowContext(ctx, sql, params...).
Expand All @@ -274,13 +259,16 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env
return Environment{}, fmt.Errorf("encoding error: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = stmt.ExecContext(
ctx,
environment.ID,
environment.Name,
environment.Description,
environment.CreatedAt,
jsonValues,
tenantID,
)

if err != nil {
Expand All @@ -291,30 +279,37 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env
}

func (r *Repository) updateIntoEnvironments(ctx context.Context, environment Environment, oldId id.ID) (Environment, error) {
stmt, err := r.db.Prepare(updateQuery)
if err != nil {
return Environment{}, fmt.Errorf("sql prepare: %w", err)
}
defer stmt.Close()

jsonValues, err := json.Marshal(environment.Values)
if err != nil {
return Environment{}, fmt.Errorf("encoding error: %w", err)
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return Environment{}, err
}
defer tx.Rollback()

_, err = stmt.ExecContext(
ctx,
oldId,
query, params := sqlutil.Tenant(ctx, updateQuery, oldId,
environment.Name,
environment.Description,
environment.CreatedAt,
jsonValues,
)

_, err = tx.ExecContext(
ctx,
query,
params...,
)
if err != nil {
return Environment{}, fmt.Errorf("sql exec: %w", err)
}

err = tx.Commit()
if err != nil {
return Environment{}, fmt.Errorf("commit: %w", err)
}

return environment, nil
}

Expand Down
16 changes: 11 additions & 5 deletions server/executor/pollingprofile/polling_profile_repository.go
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"

"github.com/kubeshop/tracetest/server/pkg/id"
"github.com/kubeshop/tracetest/server/pkg/sqlutil"
)

func NewRepository(db *sql.DB) *Repository {
Expand Down Expand Up @@ -35,9 +36,10 @@ const (
"name",
"default",
"strategy",
"periodic"
"periodic",
"tenant_id"
)
VALUES ($1, $2, $3, $4, $5)`
VALUES ($1, $2, $3, $4, $5, $6)`
deleteQuery = `DELETE FROM polling_profiles`
)

Expand All @@ -47,12 +49,13 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin
updated.Default = true

tx, err := r.db.BeginTx(ctx, nil)
defer tx.Rollback()
if err != nil {
return PollingProfile{}, err
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery)
query, params := sqlutil.Tenant(ctx, deleteQuery)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return PollingProfile{}, fmt.Errorf("sql exec delete: %w", err)
}
Expand All @@ -65,12 +68,14 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin
}
}

tenantID := sqlutil.TenantID(ctx)
_, err = tx.ExecContext(ctx, insertQuery,
updated.ID,
updated.Name,
updated.Default,
updated.Strategy,
periodicJSON,
tenantID,
)
if err != nil {
return PollingProfile{}, fmt.Errorf("sql exec insert: %w", err)
Expand Down Expand Up @@ -106,8 +111,9 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (PollingProfile, error)
}

var periodicJSON []byte
query, params := sqlutil.Tenant(ctx, getQuery)
err := r.db.
QueryRowContext(ctx, getQuery).
QueryRowContext(ctx, query, params...).
Scan(
&profile.Name,
&profile.Strategy,
Expand Down

0 comments on commit 6fc605c

Please sign in to comment.