Skip to content

Commit

Permalink
refactor: blacklist and whitelist middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
fufuok committed Mar 7, 2024
1 parent f7504f5 commit 8d95590
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 67 deletions.
38 changes: 38 additions & 0 deletions common/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package common

import (
"net"
"strconv"

"github.com/fufuok/utils/xhash"
)

// LookupIPNetsString 从 IP 段集合中查询并返回对应数值
Expand All @@ -23,3 +26,38 @@ func LookupIPNets(ip net.IP, ipNets map[*net.IPNet]int64) (int64, bool) {
}
return 0, false
}

// GenSign 使用时间戳和密钥生成简单签名字符串
// 算法: md5(ts+key)
// 结果: ts+sign
func GenSign(ts int64, key string) string {
tss := strconv.FormatInt(ts, 10)
return GenSignString(tss, key)
}

// GenSignString 字符串类型的时间戳生成签名
func GenSignString(ts, key string) string {
if len(ts) != 10 || key == "" {
return ""
}
sign := xhash.MD5Hex(ts + key)
return ts + sign
}

// VerifySign 校验签名
func VerifySign(key, sign string) bool {
if key == "" || len(sign) != 42 {
return false
}
return sign == GenSignString(sign[:10], key)
}

// VerifySignTTL 校验签名及签名有效期(当前时间 **秒 范围内有效)
func VerifySignTTL(key, sign string, second int64) bool {
if ok := VerifySign(key, sign); !ok {
return false
}
ts, _ := strconv.ParseInt(sign[:10], 10, 64)
now := GTimestamp()
return ts >= now-second && ts <= now+second
}
17 changes: 17 additions & 0 deletions common/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package common

import (
"testing"
"time"

"github.com/fufuok/utils"
"github.com/fufuok/utils/assert"
)

func TestGenSign(t *testing.T) {
key := utils.RandString(18)
ts := time.Now().Unix()
sign := GenSign(ts, key)
t.Log("sign:", sign)
assert.True(t, VerifySignTTL(key, sign, 1))
}
30 changes: 12 additions & 18 deletions web/fiber/middleware/blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,31 @@ package middleware

import (
"fmt"
"net/http"

"github.com/gofiber/fiber/v2"

"github.com/fufuok/pkg/common"
"github.com/fufuok/pkg/config"
"github.com/fufuok/pkg/logger/sampler"
"github.com/fufuok/pkg/web/fiber/proxy"
"github.com/fufuok/pkg/web/fiber/response"
)

// CheckBlacklist 接口黑名单检查
func CheckBlacklist(asAPI bool) fiber.Handler {
errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName)
return func(c *fiber.Ctx) error {
if len(config.Blacklist) > 0 {
clientIP := proxy.GetClientIP(c)
if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok {
msg := errMsg + clientIP
sampler.Info().
Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)).
Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)).
Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP).
Msg(msg)
if asAPI {
return response.APIException(c, http.StatusForbidden, msg, nil)
} else {
return response.TxtException(c, http.StatusForbidden, msg)
}
}
if BlacklistChecker(c) {
return responseForbidden(c, errMsg, asAPI)
}
return c.Next()
}
}

// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false)
func BlacklistChecker(c *fiber.Ctx) bool {
clientIP := proxy.GetClientIP(c)
if len(config.Blacklist) > 0 {
_, ok := common.LookupIPNetsString(clientIP, config.Blacklist)
return ok
}
return false
}
66 changes: 51 additions & 15 deletions web/fiber/middleware/whitelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,62 @@ import (
"github.com/fufuok/pkg/web/fiber/response"
)

type ForbiddenChecker = func(*fiber.Ctx) bool

// CheckWhitelist 接口白名单检查
func CheckWhitelist(asAPI bool) fiber.Handler {
errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName)
return func(c *fiber.Ctx) error {
if len(config.Whitelist) > 0 {
clientIP := proxy.GetClientIP(c)
if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok {
msg := errMsg + clientIP
sampler.Info().
Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)).
Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)).
Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP).
Msg(msg)
if asAPI {
return response.APIException(c, http.StatusForbidden, msg, nil)
} else {
return response.TxtException(c, http.StatusForbidden, msg)
}
}
if !WhitelistChecker(c) {
return responseForbidden(c, errMsg, asAPI)
}
return c.Next()
}
}

// CheckWhitelistOr 校验接口白名单或自定义检查器
func CheckWhitelistOr(checker ForbiddenChecker, asAPI bool) fiber.Handler {
errMsg := fmt.Sprintf("[ERROR] 禁止来访(%s): ", config.AppName)
return func(c *fiber.Ctx) error {
if !WhitelistChecker(c) && !checker(c) {
return responseForbidden(c, errMsg, asAPI)
}
return c.Next()
}
}

// CheckWhitelistAnd 同时校验接口白名单和自定义检查器
func CheckWhitelistAnd(checker ForbiddenChecker, asAPI bool) fiber.Handler {
errMsg := fmt.Sprintf("[ERROR] 禁止访问(%s): ", config.AppName)
return func(c *fiber.Ctx) error {
if !WhitelistChecker(c) || !checker(c) {
return responseForbidden(c, errMsg, asAPI)
}
return c.Next()
}
}

// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true)
func WhitelistChecker(c *fiber.Ctx) bool {
clientIP := proxy.GetClientIP(c)
if len(config.Whitelist) > 0 {
_, ok := common.LookupIPNetsString(clientIP, config.Whitelist)
return ok
}
return true
}

func responseForbidden(c *fiber.Ctx, msg string, asAPI bool) error {
clientIP := proxy.GetClientIP(c)
msg += clientIP
sampler.Info().
Str("cip", c.IP()).Str("x_forwarded_for", c.Get(fiber.HeaderXForwardedFor)).
Str(proxy.HeaderXProxyClientIP, c.Get(proxy.HeaderXProxyClientIP)).
Str("method", c.Method()).Str("uri", c.OriginalURL()).Str("client_ip", clientIP).
Msg(msg)

if asAPI {
return response.APIException(c, http.StatusForbidden, msg, nil)
}
return response.TxtException(c, http.StatusForbidden, msg)
}
32 changes: 13 additions & 19 deletions web/gin/middleware/blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,31 @@ package middleware

import (
"fmt"
"net/http"

"github.com/gin-gonic/gin"

"github.com/fufuok/pkg/common"
"github.com/fufuok/pkg/config"
"github.com/fufuok/pkg/logger/sampler"
"github.com/fufuok/pkg/web/gin/response"
)

// CheckBlacklist 接口黑名单检查
func CheckBlacklist(asAPI bool) gin.HandlerFunc {
errMsg := fmt.Sprintf("[ERROR] 非法访问(%s): ", config.AppName)
return func(c *gin.Context) {
if len(config.Blacklist) > 0 {
clientIP := c.ClientIP()
if _, ok := common.LookupIPNetsString(clientIP, config.Blacklist); ok {
msg := errMsg + clientIP
sampler.Info().
Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")).
Str("method", c.Request.Method).Str("uri", c.Request.RequestURI).
Msg(msg)
if asAPI {
response.APIException(c, http.StatusForbidden, msg, nil)
} else {
response.TxtException(c, http.StatusForbidden, msg)
}
return
}
if BlacklistChecker(c) {
responseForbidden(c, errMsg, asAPI)
return
}

c.Next()
}
}

// BlacklistChecker 是否存在于黑名单, true 是黑名单 (黑名单为空时: 放过, false)
func BlacklistChecker(c *gin.Context) bool {
clientIP := c.ClientIP()
if len(config.Blacklist) > 0 {
_, ok := common.LookupIPNetsString(clientIP, config.Blacklist)
return ok
}
return false
}
68 changes: 53 additions & 15 deletions web/gin/middleware/whitelist.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,65 @@ import (
"github.com/fufuok/pkg/web/gin/response"
)

type ForbiddenChecker = func(*gin.Context) bool

// CheckWhitelist 接口白名单检查
func CheckWhitelist(asAPI bool) gin.HandlerFunc {
errMsg := fmt.Sprintf("[ERROR] 非法来访(%s): ", config.AppName)
return func(c *gin.Context) {
if len(config.Whitelist) > 0 {
clientIP := c.ClientIP()
if _, ok := common.LookupIPNetsString(clientIP, config.Whitelist); !ok {
msg := errMsg + clientIP
sampler.Info().
Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")).
Str("method", c.Request.Method).Str("uri", c.Request.RequestURI).
Msg(msg)
if asAPI {
response.APIException(c, http.StatusForbidden, msg, nil)
} else {
response.TxtException(c, http.StatusForbidden, msg)
}
return
}
if !WhitelistChecker(c) {
responseForbidden(c, errMsg, asAPI)
return
}
c.Next()
}
}

// CheckWhitelistOr 校验接口白名单或自定义检查器
func CheckWhitelistOr(checker ForbiddenChecker, asAPI bool) gin.HandlerFunc {
errMsg := fmt.Sprintf("[ERROR] 禁止来访(%s): ", config.AppName)
return func(c *gin.Context) {
if !WhitelistChecker(c) && !checker(c) {
responseForbidden(c, errMsg, asAPI)
return
}
c.Next()
}
}

// CheckWhitelistAnd 同时校验接口白名单和自定义检查器
func CheckWhitelistAnd(checker ForbiddenChecker, asAPI bool) gin.HandlerFunc {
errMsg := fmt.Sprintf("[ERROR] 禁止访问(%s): ", config.AppName)
return func(c *gin.Context) {
if !WhitelistChecker(c) || !checker(c) {
responseForbidden(c, errMsg, asAPI)
return
}
c.Next()
}
}

// WhitelistChecker 是否通过了白名单检查, true 是白名单 (白名单为空时: 通过, true)
func WhitelistChecker(c *gin.Context) bool {
clientIP := c.ClientIP()
if len(config.Whitelist) > 0 {
_, ok := common.LookupIPNetsString(clientIP, config.Whitelist)
return ok
}
return true
}

func responseForbidden(c *gin.Context, msg string, asAPI bool) {
clientIP := c.ClientIP()
msg += clientIP
sampler.Info().
Str("cip", clientIP).Str("x_forwarded_for", c.GetHeader("X-Forwarded-For")).
Str("method", c.Request.Method).Str("uri", c.Request.RequestURI).
Msg(msg)

if asAPI {
response.APIException(c, http.StatusForbidden, msg, nil)
} else {
response.TxtException(c, http.StatusForbidden, msg)
}
}

0 comments on commit 8d95590

Please sign in to comment.