Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support node sort by tcp ping latency #920

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions cmd/gost/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ import (
)

type peerConfig struct {
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
FailTimeout time.Duration
period time.Duration // the period for live reloading
Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
stopped chan struct{}
Strategy string `json:"strategy"`
MaxFails int `json:"max_fails"`
FastestCount int `json:"fastest_count"` // topN fastest node count
FailTimeout time.Duration
period time.Duration // the period for live reloading

Nodes []string `json:"nodes"`
group *gost.NodeGroup
baseNodes []gost.Node
stopped chan struct{}
}

func newPeerConfig() *peerConfig {
Expand Down Expand Up @@ -52,6 +54,7 @@ func (cfg *peerConfig) Reload(r io.Reader) error {
FailTimeout: cfg.FailTimeout,
},
&gost.InvalidFilter{},
gost.NewFastestFilter(0, cfg.FastestCount),
),
gost.WithStrategy(gost.NewStrategy(cfg.Strategy)),
)
Expand Down Expand Up @@ -126,6 +129,8 @@ func (cfg *peerConfig) parse(r io.Reader) error {
cfg.Strategy = ss[1]
case "max_fails":
cfg.MaxFails, _ = strconv.Atoi(ss[1])
case "fastest_count":
cfg.FastestCount, _ = strconv.Atoi(ss[1])
case "fail_timeout":
cfg.FailTimeout, _ = time.ParseDuration(ss[1])
case "reload":
Expand Down
1 change: 1 addition & 0 deletions cmd/gost/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func (r *route) parseChain() (*gost.Chain, error) {
FailTimeout: nodes[0].GetDuration("fail_timeout"),
},
&gost.InvalidFilter{},
gost.NewFastestFilter(0, nodes[0].GetInt("fastest_count")),
),
gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))),
)
Expand Down
91 changes: 91 additions & 0 deletions selector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"errors"
"math/rand"
"net"
"sort"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/go-log/log"
)

var (
Expand Down Expand Up @@ -205,6 +208,94 @@ func (f *FailFilter) String() string {
return "fail"
}

// FastestFilter filter the fastest node
type FastestFilter struct {
mu sync.Mutex

pinger *net.Dialer
pingResult map[int]int
pingResultTTL map[int]int64

topCount int
}

func NewFastestFilter(pingTimeOut int, topCount int) *FastestFilter {
if pingTimeOut == 0 {
pingTimeOut = 3000 // 3s
}
return &FastestFilter{
mu: sync.Mutex{},
pinger: &net.Dialer{Timeout: time.Millisecond * time.Duration(pingTimeOut)},
pingResult: make(map[int]int, 0),
pingResultTTL: make(map[int]int64, 0),
topCount: topCount,
}
}

func (f *FastestFilter) Filter(nodes []Node) []Node {
// disabled
if f.topCount == 0 {
return nodes
}

// get latency with ttl cache
now := time.Now().Unix()

var getNodeLatency = func(node Node) int {
f.mu.Lock()
defer f.mu.Unlock()

if f.pingResultTTL[node.ID] < now {
f.pingResultTTL[node.ID] = now + 5 // tmp

// get latency
go func(node Node) {
latency := f.doTcpPing(node.Addr)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
ttl := 300 - int64(120*r.Float64())

f.mu.Lock()
defer f.mu.Unlock()

f.pingResult[node.ID] = latency
f.pingResultTTL[node.ID] = now + ttl
}(node)
}
return f.pingResult[node.ID]
}

// sort
sort.Slice(nodes, func(i, j int) bool {
return getNodeLatency(nodes[i]) < getNodeLatency(nodes[j])
})

// split
if len(nodes) <= f.topCount {
return nodes
}

return nodes[0:f.topCount]
}

func (f *FastestFilter) String() string {
return "fastest"
}

// doTcpPing
func (f *FastestFilter) doTcpPing(address string) int {
start := time.Now()
conn, err := f.pinger.Dial("tcp", address)
elapsed := time.Since(start)

if err == nil {
_ = conn.Close()
}

latency := int(elapsed.Milliseconds())
log.Logf("pingDoTCP: %s, latency: %d", address, latency)
return latency
}

// InvalidFilter filters the invalid node.
// A node is invalid if its port is invalid (negative or zero value).
type InvalidFilter struct{}
Expand Down
24 changes: 24 additions & 0 deletions selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,30 @@ func TestFailFilter(t *testing.T) {
}
}

func TestFastestFilter(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}, Addr: "1.0.0.1:80"},
Node{ID: 2, marker: &failMarker{}, Addr: "1.0.0.2:80"},
Node{ID: 3, marker: &failMarker{}, Addr: "1.0.0.3:80"},
}
filter := NewFastestFilter(0, 2)

var print = func(nodes []Node) []string {
var rows []string
for _, node := range nodes {
rows = append(rows, node.Addr)
}
return rows
}

result1 := filter.Filter(nodes)
t.Logf("result 1: %+v", print(result1))

time.Sleep(time.Second)
result2 := filter.Filter(nodes)
t.Logf("result 2: %+v", print(result2))
}

func TestSelector(t *testing.T) {
nodes := []Node{
Node{ID: 1, marker: &failMarker{}},
Expand Down