-
Notifications
You must be signed in to change notification settings - Fork 66
/
tenant.go
58 lines (44 loc) · 1.22 KB
/
tenant.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package sqlutil
import (
"context"
"fmt"
"strings"
"github.com/kubeshop/tracetest/server/http/middleware"
)
func Tenant(ctx context.Context, query string, params ...any) (string, []any) {
tenantID := TenantID(ctx)
if tenantID == nil {
return query, params
}
prefix := getQueryPrefix(query)
paramNumber := len(params) + 1
condition := fmt.Sprintf(" %s tenant_id = $%d", prefix, paramNumber)
return query + condition, append(params, *tenantID)
}
// TenantWithPrefix scopes query to the tenant stored in ctx, qualifying the
// tenant_id column with prefix. For example, pass "t." to produce
// "t.tenant_id = $N". The prefix is inserted verbatim, so include any
// separator it needs.
//
// The placeholder number N is len(params)+1. The condition is joined with
// WHERE or AND, depending on whether getQueryPrefix finds an existing WHERE
// keyword in query.
//
// It returns the extended query and a parameter list ending with the tenant
// ID. When ctx carries no tenant ID, query and params are returned unchanged.
func TenantWithPrefix(ctx context.Context, query string, prefix string, params ...any) (string, []any) {
	tenantID := TenantID(ctx)
	if tenantID == nil {
		return query, params
	}

	queryPrefix := getQueryPrefix(query)
	paramNumber := len(params) + 1
	condition := fmt.Sprintf(" %s %stenant_id = $%d", queryPrefix, prefix, paramNumber)

	// The full slice expression limits capacity to len(params), forcing append
	// to copy into a fresh array. This prevents writing the tenant ID into
	// spare capacity of the caller's slice, which could be shared with other
	// queries.
	args := append(params[:len(params):len(params)], *tenantID)
	return query + condition, args
}
// TenantID returns a pointer to the tenant ID stored in ctx under
// middleware.TenantIDKey. It returns nil when no value is stored or when the
// stored value is the empty string.
//
// Any other stored value must be a string. A value of a different type makes
// the type assertion below panic.
// NOTE(review): that panic is arguably safer than returning nil, because a nil
// result makes Tenant/TenantWithPrefix leave the query unfiltered.
func TenantID(ctx context.Context) *string {
	switch value := ctx.Value(middleware.TenantIDKey); value {
	case nil, "":
		return nil
	default:
		id := value.(string)
		return &id
	}
}
// getQueryPrefix returns the SQL connective for appending one more condition
// to query. It returns "AND " if query already contains the keyword WHERE,
// otherwise "WHERE ". Both results carry a trailing space.
//
// The keyword is matched case-insensitively and only as a whole word.
// Identifiers that merely contain the letters "where" (for example a column
// named "nowhere" or "where_clause") therefore do not count as a WHERE clause.
//
// NOTE: this is a lexical check, not a SQL parse. A WHERE that appears only
// inside a subquery, a string literal or a quoted identifier is still treated
// as belonging to the outer query.
func getQueryPrefix(query string) string {
	const keyword = "where"
	lower := strings.ToLower(query)

	start := 0
	for {
		i := strings.Index(lower[start:], keyword)
		if i < 0 {
			return "WHERE "
		}
		begin := start + i
		end := begin + len(keyword)
		// Only a match bounded by non-identifier characters (or the ends of
		// the string) is the keyword itself.
		if (begin == 0 || !isIdentByte(lower[begin-1])) &&
			(end == len(lower) || !isIdentByte(lower[end])) {
			return "AND "
		}
		start = end
	}
}

// isIdentByte reports whether b can be part of an unquoted SQL identifier.
// Bytes belonging to multi-byte UTF-8 sequences are treated as identifier
// characters.
func isIdentByte(b byte) bool {
	return b == '_' ||
		('a' <= b && b <= 'z') ||
		('A' <= b && b <= 'Z') ||
		('0' <= b && b <= '9') ||
		b >= 0x80
}