Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(backend) Tenant ID support #2980

Merged
merged 2 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions server/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/kubeshop/tracetest/server/executor/trigger"
httpServer "github.com/kubeshop/tracetest/server/http"
"github.com/kubeshop/tracetest/server/http/mappings"
"github.com/kubeshop/tracetest/server/http/middleware"
"github.com/kubeshop/tracetest/server/http/websocket"
"github.com/kubeshop/tracetest/server/linter/analyzer"
"github.com/kubeshop/tracetest/server/model"
Expand Down Expand Up @@ -269,6 +270,9 @@ func (app *App) Start(opts ...appOption) error {
// use the analytics middleware on complete router
router.Use(analyticsMW)

// use the tenant middleware on complete router
router.Use(middleware.TenantMiddleware)

apiRouter := router.
PathPrefix(app.cfg.ServerPathPrefix()).
PathPrefix("/api").
Expand Down
26 changes: 18 additions & 8 deletions server/config/demo/demo_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ const insertQuery = `INSERT INTO demos (
"enabled",
"type",
"pokeshop",
"opentelemetry_store"
"opentelemetry_store",
"tenant_id"
)
VALUES ($1, $2, $3, $4, $5, $6)`
VALUES ($1, $2, $3, $4, $5, $6, $7)`

func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) {
tx, err := r.db.BeginTx(ctx, nil)
Expand All @@ -50,13 +51,16 @@ func (r *Repository) Create(ctx context.Context, demo Demo) (Demo, error) {
return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = tx.ExecContext(ctx, insertQuery,
demo.ID,
demo.Name,
demo.Enabled,
demo.Type,
pokeshopJSONData,
openTelemetryStoreJSONData,
tenantID,
)

if err != nil {
Expand Down Expand Up @@ -102,7 +106,7 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) {
return Demo{}, fmt.Errorf("could not get JSON data from opentelemetry store example: %w", err)
}

_, err = tx.ExecContext(ctx, updateQuery,
query, params := sqlutil.Tenant(ctx, updateQuery,
oldDemo.ID,
demo.Name,
demo.Enabled,
Expand All @@ -111,6 +115,8 @@ func (r *Repository) Update(ctx context.Context, demo Demo) (Demo, error) {
openTelemetryStoreJSONData,
)

_, err = tx.ExecContext(ctx, query, params...)

if err != nil {
tx.Rollback()
return Demo{}, fmt.Errorf("sql exec: %w", err)
Expand All @@ -133,17 +139,18 @@ const (
"type",
"pokeshop",
"opentelemetry_store"
FROM demos `
FROM demos`

getQuery = baseSelect + `WHERE "id" = $1`
getDefaultQuery = baseSelect + `WHERE "default" = true`
getQuery = baseSelect + ` WHERE "id" = $1`
getDefaultQuery = baseSelect + ` WHERE "default" = true`
)

func (r *Repository) Get(ctx context.Context, id id.ID) (Demo, error) {
return r.get(ctx, getQuery, id)
}

func (r *Repository) get(ctx context.Context, query string, args ...any) (Demo, error) {
query, args = sqlutil.Tenant(ctx, query, args...)
row := r.db.QueryRowContext(ctx, query, args...)
return readRow(row)
}
Expand All @@ -161,7 +168,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
return err
}

_, err = tx.ExecContext(ctx, deleteQuery, demo.ID)
query, params := sqlutil.Tenant(ctx, deleteQuery, demo.ID)
_, err = tx.ExecContext(ctx, query, params...)

if err != nil {
tx.Rollback()
Expand Down Expand Up @@ -191,6 +199,7 @@ func listQuery(baseSQL, query string, params []any) (string, []any) {

func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, sortDirection string) ([]Demo, error) {
q, params := listQuery(baseSelect, query, []any{take, skip})
q, params = sqlutil.Tenant(ctx, q, params...)

sortingFields := map[string]string{
"id": "id",
Expand Down Expand Up @@ -240,8 +249,9 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) {

count := 0

countQuery, params := sqlutil.Tenant(ctx, countQuery)
err := r.db.
QueryRowContext(ctx, countQuery).
QueryRowContext(ctx, countQuery, params...).
Scan(&count)

if err != nil {
Expand Down
15 changes: 11 additions & 4 deletions server/datastore/datastore_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/kubeshop/tracetest/server/pkg/id"
"github.com/kubeshop/tracetest/server/pkg/sqlutil"
)

func NewRepository(db *sql.DB) *Repository {
Expand All @@ -33,8 +34,9 @@ INSERT INTO data_stores (
"type",
"is_default",
"values",
"created_at"
) VALUES ($1, $2, $3, $4, $5, $6)`
"created_at",
"tenant_id"
) VALUES ($1, $2, $3, $4, $5, $6, $7)`

const deleteQuery = `DELETE FROM data_stores WHERE "id" = $1`

Expand Down Expand Up @@ -99,13 +101,16 @@ func (r *Repository) Update(ctx context.Context, dataStore DataStore) (DataStore
return DataStore{}, fmt.Errorf("could not marshal values field configuration: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = tx.ExecContext(ctx, insertQuery,
dataStore.ID,
dataStore.Name,
dataStore.Type,
dataStore.Default,
valuesJSON,
dataStore.CreatedAt,
tenantID,
)
if err != nil {
return DataStore{}, fmt.Errorf("datastore repository sql exec create: %w", err)
Expand All @@ -126,7 +131,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery, dataStoreSingleID)
query, params := sqlutil.Tenant(ctx, deleteQuery, id)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("datastore repository sql exec delete: %w", err)
}
Expand Down Expand Up @@ -160,7 +166,8 @@ func (r *Repository) Current(ctx context.Context) (DataStore, error) {
}

func (r *Repository) Get(ctx context.Context, id id.ID) (DataStore, error) {
row := r.db.QueryRowContext(ctx, getQuery, id)
query, params := sqlutil.Tenant(ctx, getQuery, id)
row := r.db.QueryRowContext(ctx, query, params...)

dataStore, err := r.readRow(row)
if err != nil && errors.Is(err, sql.ErrNoRows) {
Expand Down
67 changes: 31 additions & 36 deletions server/environment/environment_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/kubeshop/tracetest/server/pkg/id"
Expand All @@ -31,8 +30,9 @@ const (
"name",
"description",
"created_at",
"values"
) VALUES ($1, $2, $3, $4, $5)`
"values",
"tenant_id"
) VALUES ($1, $2, $3, $4, $5, $6)`

updateQuery = `
UPDATE environments SET
Expand Down Expand Up @@ -107,7 +107,8 @@ func (r *Repository) Delete(ctx context.Context, id id.ID) error {
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery, id)
query, params := sqlutil.Tenant(ctx, deleteQuery, id)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return fmt.Errorf("sql error: %w", err)
}
Expand Down Expand Up @@ -144,6 +145,7 @@ func (r *Repository) List(ctx context.Context, take, skip int, query, sortBy, so
sql = sqlutil.Sort(sql, sortBy, sortDirection, "created", sortingFields)
sql += ` LIMIT $1 OFFSET $2 `

sql, params = sqlutil.Tenant(ctx, sql, params...)
stmt, err := r.db.Prepare(sql)
if err != nil {
return []Environment{}, err
Expand Down Expand Up @@ -173,30 +175,11 @@ func (r *Repository) Count(ctx context.Context, query string) (int, error) {
return r.countEnvironments(ctx, query)
}

func sortQuery(sql, sortBy, sortDirection string, sortingFields map[string]string) string {
sortField, ok := sortingFields[sortBy]

if !ok {
sortField = sortingFields["created"]
}

dir := "DESC"
if strings.ToLower(sortDirection) == "asc" {
dir = "ASC"
}

return fmt.Sprintf("%s ORDER BY %s %s", sql, sortField, dir)
}

func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) {
stmt, err := r.db.Prepare(getQuery + " WHERE e.id = $1")
query, params := sqlutil.Tenant(ctx, getQuery+" WHERE e.id = $1", id)
row := r.db.QueryRowContext(ctx, query, params...)

if err != nil {
return Environment{}, fmt.Errorf("prepare: %w", err)
}
defer stmt.Close()

environment, err := r.readEnvironmentRow(ctx, stmt.QueryRowContext(ctx, id))
environment, err := r.readEnvironmentRow(ctx, row)
if err != nil {
return Environment{}, err
}
Expand All @@ -207,7 +190,8 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (Environment, error) {
func (r *Repository) Exists(ctx context.Context, id id.ID) (bool, error) {
exists := false

row := r.db.QueryRowContext(ctx, idExistsQuery, id)
query, params := sqlutil.Tenant(ctx, idExistsQuery, id)
row := r.db.QueryRowContext(ctx, query, params...)

err := row.Scan(&exists)

Expand Down Expand Up @@ -251,6 +235,7 @@ func (r *Repository) countEnvironments(ctx context.Context, query string) (int,

condition := " WHERE (e.name ilike $1 OR e.description ilike $1)"
sql, params := sqlutil.Search(countQuery, condition, query, params)
sql, params = sqlutil.Tenant(ctx, sql, params...)

err := r.db.
QueryRowContext(ctx, sql, params...).
Expand All @@ -274,13 +259,16 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env
return Environment{}, fmt.Errorf("encoding error: %w", err)
}

tenantID := sqlutil.TenantID(ctx)

_, err = stmt.ExecContext(
ctx,
environment.ID,
environment.Name,
environment.Description,
environment.CreatedAt,
jsonValues,
tenantID,
)

if err != nil {
Expand All @@ -291,30 +279,37 @@ func (r *Repository) insertIntoEnvironments(ctx context.Context, environment Env
}

func (r *Repository) updateIntoEnvironments(ctx context.Context, environment Environment, oldId id.ID) (Environment, error) {
stmt, err := r.db.Prepare(updateQuery)
if err != nil {
return Environment{}, fmt.Errorf("sql prepare: %w", err)
}
defer stmt.Close()

jsonValues, err := json.Marshal(environment.Values)
if err != nil {
return Environment{}, fmt.Errorf("encoding error: %w", err)
}
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return Environment{}, err
}
defer tx.Rollback()
Comment on lines +287 to +290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


_, err = stmt.ExecContext(
ctx,
oldId,
query, params := sqlutil.Tenant(ctx, updateQuery, oldId,
environment.Name,
environment.Description,
environment.CreatedAt,
jsonValues,
)

_, err = tx.ExecContext(
ctx,
query,
params...,
)
if err != nil {
return Environment{}, fmt.Errorf("sql exec: %w", err)
}

err = tx.Commit()
if err != nil {
return Environment{}, fmt.Errorf("commit: %w", err)
}

return environment, nil
}

Expand Down
16 changes: 11 additions & 5 deletions server/executor/pollingprofile/polling_profile_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"

"github.com/kubeshop/tracetest/server/pkg/id"
"github.com/kubeshop/tracetest/server/pkg/sqlutil"
)

func NewRepository(db *sql.DB) *Repository {
Expand Down Expand Up @@ -35,9 +36,10 @@ const (
"name",
"default",
"strategy",
"periodic"
"periodic",
"tenant_id"
)
VALUES ($1, $2, $3, $4, $5)`
VALUES ($1, $2, $3, $4, $5, $6)`
deleteQuery = `DELETE FROM polling_profiles`
)

Expand All @@ -47,12 +49,13 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin
updated.Default = true

tx, err := r.db.BeginTx(ctx, nil)
defer tx.Rollback()
if err != nil {
return PollingProfile{}, err
}
defer tx.Rollback()

_, err = tx.ExecContext(ctx, deleteQuery)
query, params := sqlutil.Tenant(ctx, deleteQuery)
_, err = tx.ExecContext(ctx, query, params...)
if err != nil {
return PollingProfile{}, fmt.Errorf("sql exec delete: %w", err)
}
Expand All @@ -65,12 +68,14 @@ func (r *Repository) Update(ctx context.Context, updated PollingProfile) (Pollin
}
}

tenantID := sqlutil.TenantID(ctx)
_, err = tx.ExecContext(ctx, insertQuery,
updated.ID,
updated.Name,
updated.Default,
updated.Strategy,
periodicJSON,
tenantID,
)
if err != nil {
return PollingProfile{}, fmt.Errorf("sql exec insert: %w", err)
Expand Down Expand Up @@ -106,8 +111,9 @@ func (r *Repository) Get(ctx context.Context, id id.ID) (PollingProfile, error)
}

var periodicJSON []byte
query, params := sqlutil.Tenant(ctx, getQuery)
err := r.db.
QueryRowContext(ctx, getQuery).
QueryRowContext(ctx, query, params...).
Scan(
&profile.Name,
&profile.Strategy,
Expand Down