-
Notifications
You must be signed in to change notification settings - Fork 7
/
host.go
90 lines (78 loc) · 2.2 KB
/
host.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// SPDX-FileCopyrightText: 2014-2024 caixw
//
// SPDX-License-Identifier: MIT
package group
import (
"net/http"
"strings"
"github.com/issue9/mux/v8"
"github.com/issue9/mux/v8/internal/syntax"
"github.com/issue9/mux/v8/internal/tree"
"github.com/issue9/mux/v8/types"
)
// Hosts 限定域名的匹配工具
type Hosts struct {
i *syntax.Interceptors
tree *tree.Tree[any]
}
// NewHosts 声明新的 [Hosts] 实例
func NewHosts(lock bool, domain ...string) *Hosts {
i := syntax.NewInterceptors()
f := func(types.Node) any { return nil }
t := tree.New(lock, i, nil, false, f, f)
h := &Hosts{tree: t, i: i}
h.Add(domain...)
return h
}
// RegisterInterceptor 注册拦截器
//
// NOTE: 拦截器只有在注册之后添加的域名才有效果。
func (hs *Hosts) RegisterInterceptor(f mux.InterceptorFunc, name ...string) {
hs.i.Add(f, name...)
}
func (hs *Hosts) Match(r *http.Request, ctx *types.Context) bool {
h := r.Host // r.URL.Hostname() 可能为空,r.Host 一直有值!
if i := strings.LastIndexByte(h, ':'); i != -1 && validOptionalPort(h[i:]) {
h = h[:i]
}
if strings.HasPrefix(h, "[") && strings.HasSuffix(h, "]") { // ipv6
h = h[1 : len(h)-1]
}
ctx.Path = strings.ToLower(h)
_, _, exists := hs.tree.Handler(ctx, http.MethodGet)
return exists
}
// 源自 https://github.com/golang/go/blob/d8762b2f4532cc2e5ec539670b88bbc469a13938/src/net/url/url.go#L769
func validOptionalPort(port string) bool {
if port == "" {
return true
}
if port[0] != ':' {
return false
}
for _, b := range port[1:] {
if b < '0' || b > '9' {
return false
}
}
return true
}
// Add 添加新的域名
//
// 域名的格式和路由的语法格式是一样的,比如:
//
// api.example.com
// {sub:[a-z]+}.example.com
//
// 如果存在命名参数,也可以通过也可通过 [types.Params] 接口获取。
// 当语法错误时,会触发 panic,可通过 [mux.CheckSyntax] 检测语法的正确性。
func (hs *Hosts) Add(domain ...string) {
for _, d := range domain {
err := hs.tree.Add(strings.ToLower(d), hs.emptyHandlerFunc, nil, http.MethodGet)
if err != nil {
panic(err)
}
}
}
func (hs *Hosts) Delete(domain string) { hs.tree.Remove(domain) }
func (hs *Hosts) emptyHandlerFunc() {}