-
Notifications
You must be signed in to change notification settings - Fork 58
/
user_validator.go
123 lines (97 loc) · 2.92 KB
/
user_validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package validator
import (
"errors"
"fmt"
"time"
internal_errors "github.com/bricks-cloud/bricksllm/internal/errors"
"github.com/bricks-cloud/bricksllm/internal/key"
"github.com/bricks-cloud/bricksllm/internal/user"
)
// UserValidator decides whether a user may make a request, checking
// revocation, TTL, a per-time-unit rate limit, a per-time-unit cost limit,
// and a total cost limit (see Validate).
type UserValidator struct {
	clc costLimitCache   // per-time-unit cost counters, read by validateCostLimitOverTime
	rlc rateLimitCache   // per-time-unit request counters, read by validateRateLimitOverTime
	cls costLimitStorage // total cost counters, read by validateCostLimit
}
// NewUserValidator builds a UserValidator from its three counter sources:
// clc holds cost per time unit, rlc holds request counts per time unit, and
// cls holds the total accumulated cost for each user.
func NewUserValidator(
	clc costLimitCache,
	rlc rateLimitCache,
	cls costLimitStorage,
) *UserValidator {
	v := new(UserValidator)
	v.clc, v.rlc, v.cls = clc, rlc, cls
	return v
}
// Validate returns nil when u may issue a request. Otherwise it returns the
// error from the first check that fails. The checks run in this order:
//
//   - u must be non-nil and not revoked (validation error)
//   - u must still be within its TTL (expiration error)
//   - the per-unit rate limit, per-unit cost limit, and total cost limit
//     must not have been reached (errors from the respective helpers)
//
// promptCost is part of the signature but is not consulted by any check here.
func (v *UserValidator) Validate(u *user.User, promptCost float64) error {
	switch {
	case u == nil:
		return internal_errors.NewValidationError("empty user")
	case u.Revoked:
		return internal_errors.NewValidationError("user revoked")
	}

	// The parse error is deliberately ignored: an empty or malformed Ttl
	// yields a zero duration, which validateTtl treats as "never expires".
	ttl, _ := time.ParseDuration(u.Ttl)
	if !v.validateTtl(u.CreatedAt, ttl) {
		return internal_errors.NewExpirationError("user expired", internal_errors.TtlExpiration)
	}

	if err := v.validateRateLimitOverTime(u.Id, u.RateLimitOverTime, u.RateLimitUnit); err != nil {
		return err
	}
	if err := v.validateCostLimitOverTime(u.Id, u.CostLimitInUsdOverTime, u.CostLimitInUsdUnit); err != nil {
		return err
	}
	return v.validateCostLimit(u.Id, u.CostLimitInUsd)
}
// validateTtl reports whether something created at createdAt (Unix seconds)
// is still inside its ttl window. The ttl is truncated to whole seconds. A
// truncated value of zero (including any sub-second ttl) means "no expiry"
// and always reports true.
func (v *UserValidator) validateTtl(createdAt int64, ttl time.Duration) bool {
	if lifetime := int64(ttl.Seconds()); lifetime != 0 {
		deadline := createdAt + lifetime
		return time.Now().Unix() < deadline
	}
	return true
}
// errUserRateLimitCounter marks failures to read a user's rate limit counter.
var errUserRateLimitCounter = errors.New("failed to get rate limit counter")

// validateRateLimitOverTime returns a rate limit error once the user's request
// counter for rateLimitUnit has reached rateLimitOverTime. A limit of 0
// disables the check.
//
// If the counter cannot be read, the returned error wraps
// errUserRateLimitCounter (so errors.Is matches it). Its message also
// includes the underlying cache error.
func (v *UserValidator) validateRateLimitOverTime(userId string, rateLimitOverTime int, rateLimitUnit key.TimeUnit) error {
	if rateLimitOverTime == 0 {
		return nil
	}

	c, err := v.rlc.GetCounter(userId, rateLimitUnit)
	if err != nil {
		// Keep the cache error's text so backend outages are diagnosable.
		return fmt.Errorf("%w: %v", errUserRateLimitCounter, err)
	}

	if c >= int64(rateLimitOverTime) {
		return internal_errors.NewRateLimitError(fmt.Sprintf("user exceeded rate limit %d requests per %s", rateLimitOverTime, rateLimitUnit))
	}
	return nil
}
// errUserCachedCost marks failures to read a user's per-time-unit cost counter.
var errUserCachedCost = errors.New("failed to get cached token cost")

// validateCostLimitOverTime returns a cost limit error once the user's cached
// cost for costLimitUnit has reached costLimitOverTime (a USD amount, converted
// with convertDollarToMicroDollars before comparison). A limit of 0 disables
// the check.
//
// If the counter cannot be read, the returned error wraps errUserCachedCost
// (so errors.Is matches it). Its message also includes the underlying cache
// error.
func (v *UserValidator) validateCostLimitOverTime(userId string, costLimitOverTime float64, costLimitUnit key.TimeUnit) error {
	if costLimitOverTime == 0 {
		return nil
	}

	cachedCost, err := v.clc.GetCounter(userId, costLimitUnit)
	if err != nil {
		// Keep the cache error's text so backend outages are diagnosable.
		return fmt.Errorf("%w: %v", errUserCachedCost, err)
	}

	if cachedCost >= convertDollarToMicroDollars(costLimitOverTime) {
		return internal_errors.NewCostLimitError(fmt.Sprintf("cost limit: %f has been reached for the current time period: %s", costLimitOverTime, costLimitUnit))
	}
	return nil
}
// errUserTotalCost marks failures to read a user's total cost counter.
var errUserTotalCost = errors.New("failed to get total token cost")

// validateCostLimit returns an expiration error (CostLimitExpiration) once the
// user's stored total cost has reached costLimit (a USD amount, converted with
// convertDollarToMicroDollars before comparison). A limit of 0 disables the
// check.
//
// If the counter cannot be read, the returned error wraps errUserTotalCost
// (so errors.Is matches it). Its message also includes the underlying
// storage error.
func (v *UserValidator) validateCostLimit(userId string, costLimit float64) error {
	if costLimit == 0 {
		return nil
	}

	existingTotalCost, err := v.cls.GetCounter(userId)
	if err != nil {
		// Keep the storage error's text so backend outages are diagnosable.
		return fmt.Errorf("%w: %v", errUserTotalCost, err)
	}

	if existingTotalCost >= convertDollarToMicroDollars(costLimit) {
		return internal_errors.NewExpirationError(fmt.Sprintf("total cost limit: %f has been reached", costLimit), internal_errors.CostLimitExpiration)
	}
	return nil
}