diff --git a/pkg/server/service.go b/pkg/server/service.go index 1fe2f57c1..795ffbdfd 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -90,6 +90,11 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c auth.AuthenticationLoggingInterceptor, middlewareInterceptors, ) + if cfg.Security.RateLimit.Enabled { + rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration) + rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter) + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors) + } } else { logger.Infof(ctx, "Creating gRPC server without authentication") chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) @@ -257,6 +262,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } oauth2ResourceServer = oauth2Provider + } else { oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) if err != nil { diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 0197c0cf3..d32cdf985 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -1,14 +1,20 @@ package plugins import ( + "context" + "errors" "fmt" "sync" "time" + auth "github.com/flyteorg/flyteadmin/auth" "golang.org/x/time/rate" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -type RateLimitError error +type RateLimitExceeded error // define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map type accessRecords struct { @@ -16,7 +22,7 @@ type accessRecords struct { lastAccess time.Time } -type Limiter struct { +type LimiterStore struct { accessPerUser map[string]*accessRecords mutex *sync.Mutex requestPerSec int @@ -27,7 +33,7 @@ type Limiter struct { // define a function named Allow that takes userID and returns RateLimitError // the function check if the user is in the map, if not, create a new accessRecords for the user // then it check if the user can access the resource, if not, return RateLimitError -func (l *Limiter) Allow(userID string) error { +func (l *LimiterStore) Allow(userID string) error { l.mutex.Lock() defer l.mutex.Unlock() if _, ok := l.accessPerUser[userID]; !ok { @@ -38,13 +44,13 @@ func (l *Limiter) Allow(userID string) error { } if !l.accessPerUser[userID].limiter.Allow() { - return RateLimitError(fmt.Errorf("rate limit exceeded")) + return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) } return nil } -func (l *Limiter) clean() { +func (l *LimiterStore) clean() { l.mutex.Lock() defer l.mutex.Unlock() for userID, accessRecord := range l.accessPerUser { @@ -54,8 +60,8 @@ func (l *Limiter) clean() { } } -func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *Limiter { - l := &Limiter{ +func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { + l := &LimiterStore{ accessPerUser: make(map[string]*accessRecords), mutex: &sync.Mutex{}, requestPerSec: requestPerSec, @@ -72,3 +78,35 @@ func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Durat return l } + +type RateLimiter struct { + limiter *LimiterStore +} + +func (r *RateLimiter) Limit(ctx context.Context) error { + IdenCtx := auth.IdentityContextFromContext(ctx) + if IdenCtx.IsEmpty() { + return errors.New("no identity context found") + } + userID := IdenCtx.UserID() + if err := r.limiter.Allow(userID); err != nil { + return err + } + return nil +} + +func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter { + limiter := newRateLimitStore(requestPerSec, burstSize, cleanupInterval) + return &RateLimiter{limiter: limiter} +} + +func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + if err := limiter.Limit(ctx); err != nil { + return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded") + } + + return handler(ctx, req) + } +} diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index 0b045ed44..87907b2f1 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -1,56 +1,95 @@ package plugins import ( + "context" "testing" "time" + auth "github.com/flyteorg/flyteadmin/auth" "github.com/stretchr/testify/assert" ) func TestNewRateLimiter(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NotNil(t, rl) + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NotNil(t, rlStore) } -func TestLimiter_Allow(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("hello")) - // assert error type is RateLimitError - assert.Error(t, rl.Allow("hello")) +func TestLimiterAllow(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) - assert.NoError(t, rl.Allow("hello")) + assert.NoError(t, rlStore.Allow("hello")) } -func TestLimiter_AllowBurst(t *testing.T) { - rl := NewRateLimiter(1, 2, time.Second) - assert.NoError(t, rl.Allow("hello")) - assert.NoError(t, rl.Allow("hello")) - assert.Error(t, rl.Allow("hello")) - assert.NoError(t, rl.Allow("world")) +func TestLimiterAllowBurst(t *testing.T) { + rlStore := newRateLimitStore(1, 2, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("world")) } -func TestLimiter_Clean(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("hello")) - assert.Error(t, rl.Allow("hello")) +func TestLimiterClean(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) - rl.clean() - assert.NoError(t, rl.Allow("hello")) + rlStore.clean() + assert.NoError(t, rlStore.Allow("hello")) } -func TestLimiter_AllowOnMultipleRequests(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("a")) - assert.NoError(t, rl.Allow("b")) - assert.NoError(t, rl.Allow("c")) - assert.Error(t, rl.Allow("a")) - assert.Error(t, rl.Allow("b")) +func TestLimiterAllowOnMultipleRequests(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) + assert.Error(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("b")) time.Sleep(time.Second) - assert.NoError(t, rl.Allow("a")) - assert.Error(t, rl.Allow("a")) - assert.NoError(t, rl.Allow("b")) - assert.Error(t, rl.Allow("b")) - assert.NoError(t, rl.Allow("c")) + assert.NoError(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.Error(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) +} + +func TestRateLimiterLimitPass(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + +} + +func TestRateLimiterLimitStop(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + + err = rateLimit.Limit(ctx) + assert.Error(t, err) + +} + +func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + ctx := context.TODO() + + err := rateLimit.Limit(ctx) + assert.Error(t, err) }