Skip to content

Commit

Permalink
令牌锁
Browse files Browse the repository at this point in the history
  • Loading branch information
yisar committed Apr 22, 2024
1 parent d95e56b commit 1a71d90
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 29 deletions.
7 changes: 0 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,3 @@ require (
github.com/golang-jwt/jwt/v4 v4.4.1
github.com/lib/pq v1.10.7
)

require (
github.com/didip/tollbooth v4.0.2+incompatible // indirect
github.com/didip/tollbooth_httprouter v0.0.0-20170928042012-f7d42d1bfca5 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
golang.org/x/time v0.5.0 // indirect
)
24 changes: 2 additions & 22 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ import (
"io/fs"
"net/http"

"time"

"github.com/cliclitv/go-clicli/svc"
"github.com/didip/tollbooth"
"github.com/didip/tollbooth/limiter"
"github.com/julienschmidt/httprouter"
)

Expand Down Expand Up @@ -41,22 +37,6 @@ func NewMiddleWareHandler(r *httprouter.Router) http.Handler {

var whiteOriginsSet = make(map[string]bool)

func LimitHandler(handler httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
lmt := tollbooth.NewLimiter(1, &limiter.ExpirableOptions{DefaultExpirationTTL: time.Minute})

httpError := tollbooth.LimitByRequest(lmt, w, r)
if httpError != nil {
w.Header().Add("Content-Type", lmt.GetMessageContentType())
w.WriteHeader(httpError.StatusCode)
w.Write([]byte(httpError.Message))
return
}

handler(w, r, ps)
}
}

func (m middleWareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if whiteOriginsSet[origin] {
Expand All @@ -72,13 +52,13 @@ func (m middleWareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func RegisterHandler() *httprouter.Router {
router := httprouter.New()

router.POST("/user/register", LimitHandler(svc.Register)) // 需要限流
router.POST("/user/register", svc.Register) // 需要限流
router.POST("/user/login", svc.Login)
router.POST("/user/logout", svc.Logout) // 前端清空 localstorage
router.POST("/user/update/:id", svc.UpdateUser)
router.GET("/users", svc.GetUsers)
router.GET("/user", svc.GetUser)
router.POST("/comment/add", LimitHandler(svc.AddComment)) // 需要限流
router.POST("/comment/add", svc.AddComment) // 需要限流
router.POST("/comment/read", svc.ReadComments)
router.GET("/comment/delete/:id", svc.DeleteComment)
router.GET("/comments", svc.GetComments)
Expand Down
7 changes: 7 additions & 0 deletions svc/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ import (
"strconv"

"github.com/cliclitv/go-clicli/db"
"github.com/cliclitv/go-clicli/util"
"github.com/julienschmidt/httprouter"
)

var tb = util.NewTokenBucket(10, 20)

func AddComment(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
if !tb.TryConsume() {
sendMsg(w, 429, "请求限速")
return
}
req, _ := ioutil.ReadAll(r.Body)
body := &db.Comment{}

Expand Down
4 changes: 4 additions & 0 deletions svc/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func IsNumber(str string) bool {
}

func Register(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
if !tb.TryConsume() {
sendMsg(w, 429, "请求限速")
return
}
req, _ := io.ReadAll(r.Body)
ubody := &db.User{}

Expand Down
47 changes: 47 additions & 0 deletions util/tokenlock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package util

import (
"sync"
"time"
)

type TokenBucket struct {
rate float64
capacity float64
tokens float64
lastFilled time.Time
mutex sync.Mutex
}

func NewTokenBucket(rate float64, capacity float64) *TokenBucket {
return &TokenBucket{
rate: rate,
capacity: capacity,
tokens: capacity,
lastFilled: time.Now(),
}
}

func (tb *TokenBucket) fillTokens() {
now := time.Now()
delta := now.Sub(tb.lastFilled).Seconds()
tb.tokens = tb.tokens + tb.rate*delta
if tb.tokens > tb.capacity {
tb.tokens = tb.capacity
}
tb.lastFilled = now
}

func (tb *TokenBucket) TryConsume() bool {
tb.mutex.Lock()
defer tb.mutex.Unlock()

tb.fillTokens()

if tb.tokens >= 1 {
tb.tokens = tb.tokens - 1
return true
}

return false
}

0 comments on commit 1a71d90

Please sign in to comment.