diff --git a/example_user_rate_limit_test.go b/example_user_rate_limit_test.go new file mode 100644 index 0000000..1947a8c --- /dev/null +++ b/example_user_rate_limit_test.go @@ -0,0 +1,91 @@ +// +build go1.7 + +package semaphore_test + +import ( + "context" + "net/http" + "net/http/httptest" + "runtime" + "strconv" + "sync" + + "github.com/kamilsk/semaphore" +) + +// User represents user ID. +type User int + +// Config contains abstract configuration fields. +type Config struct { + DefaultUser User + DefaultCapacity int + Capacity map[User]int +} + +// This variables can be a part of limiter provider service. +var ( + mx sync.RWMutex + limiters = make(map[User]semaphore.Semaphore) +) + +// LimiterForUser returns limiter for user found in request context. +func LimiterForUser(ctx context.Context, cnf Config) semaphore.Semaphore { + user, ok := ctx.Value("user").(User) + if !ok { + user = cnf.DefaultUser + } + + mx.RLock() + sem, ok := limiters[user] + mx.RUnlock() + + if !ok { + c, ok := cnf.Capacity[user] + if !ok { + c = cnf.DefaultCapacity + } + sem = semaphore.New(c) + + mx.Lock() + limiters[user] = sem + mx.Unlock() + } + return sem +} + +// RateLimiter performs rate limitation. +func RateLimiter(cnf Config, handler http.HandlerFunc) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + _ = LimiterForUser(req.Context(), cnf) + handler.ServeHTTP(rw, req) + } +} + +// UserToContext gets user ID from request header and puts it into request context. +func UserToContext(cnf Config, handler http.HandlerFunc) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + var user User = cnf.DefaultUser + + if id := req.Header.Get("user"); id != "" { + i, err := strconv.Atoi(id) + if err == nil { + user = User(i) + } + } + + handler.ServeHTTP(rw, req.WithContext(context.WithValue(req.Context(), "user", user))) + } +} + +// This example shows how to create user specific rate limiter. +func Example_userRateLimitation() { + var cnf Config = Config{ + DefaultUser: 1, + DefaultCapacity: 10, + Capacity: map[User]int{1: runtime.GOMAXPROCS(0)}, + } + + ts := httptest.NewServer(RateLimiter(cnf, UserToContext(cnf, func(rw http.ResponseWriter, req *http.Request) {}))) + defer ts.Close() +}