/
port.go
117 lines (107 loc) · 2.26 KB
/
port.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package godiscovery
import (
"context"
"errors"
"fmt"
"math"
"math/rand"
"net"
"strconv"
"time"
"github.com/coreos/etcd/clientv3"
)
// Port holds the TCP port this node reserved through etcd (see Init)
// together with the cancellable context Init created for that work.
type Port struct {
	port      uint16             // reserved port; stays 0 until Init succeeds
	ctx       context.Context    // child of the context passed to Init
	ctxCancel context.CancelFunc // cancels ctx
}
// GetPort reports the port recorded by a successful Init, or 0 if Init has
// not succeeded on this Port.
func (p *Port) GetPort() uint16 { return p.port }
// Init reserves a TCP port for this node, coordinating with other nodes
// through a counter kept in etcd under the key "__#ETCDPORT#__".
//
// The stored counter is pre+port: port is the counter modulo math.MaxUint16
// and pre is the remaining multiple of math.MaxUint16. This encoding is kept
// unchanged so counters already stored in etcd stay readable. If the key does
// not exist, it is created with the value 1024.
//
// Init then loops:
//   - it picks the next port above the current one that can be bound locally
//     (getVaildPort);
//   - it publishes that port with a compare-and-swap on the key's version.
//
// When another node wins the race, Init continues from the value that node
// wrote, waits a random 1-5 seconds and tries again.
//
// All etcd calls use a child context of root. The context is stored in p.ctx
// and its cancel func in p.ctxCancel; the context is cancelled if Init fails.
//
// Init returns one of:
//   - an etcd or counter-parsing error;
//   - the context's error if it is cancelled while backing off;
//   - an error if no bindable port is found.
//
// On success the port is stored (see GetPort) and printed to stdout.
func (p *Port) Init(root context.Context, client *clientv3.Client) error {
	p.ctx, p.ctxCancel = context.WithCancel(root)
	port, err := p.reservePort(client, "__#ETCDPORT#__")
	if err != nil {
		// Release the derived context so a failed Init does not leak it.
		p.ctxCancel()
		return err
	}
	p.port = port
	fmt.Printf("node's port:%d\n", port)
	return nil
}

// reservePort runs the read / pick / compare-and-swap loop described on Init
// and returns the port that was successfully published under key.
func (p *Port) reservePort(client *clientv3.Client, key string) (uint16, error) {
	port, pre, version, err := p.loadPortCounter(client, key)
	if err != nil {
		return 0, err
	}
	for {
		port = p.getVaildPort(port)
		if port == 0 {
			return 0, errors.New("no valid port available")
		}
		data := strconv.FormatUint(uint64(pre)+uint64(port), 10)
		txnRep, err := client.Txn(p.ctx).
			If(clientv3.Compare(clientv3.Version(key), "=", version)).
			Then(clientv3.OpPut(key, data)).
			Else(clientv3.OpGet(key)).
			Commit()
		if err != nil {
			return 0, err
		}
		if txnRep.Succeeded {
			return port, nil
		}
		// Lost the race. Resume from the counter the winning node wrote
		// instead of our stale local copy. Otherwise we could overwrite the
		// shared counter with a value derived from outdated data.
		port, pre, version, err = portCounterFromTxn(txnRep, key)
		if err != nil {
			return 0, err
		}
		// Random back-off spreads out competing nodes. Abort promptly if
		// the context is cancelled.
		select {
		case <-p.ctx.Done():
			return 0, p.ctx.Err()
		case <-time.After(time.Duration(rand.Int31n(5)+1) * time.Second):
		}
	}
}

// loadPortCounter reads the counter under key, creating it with the value
// 1024 when absent. It returns the decoded port and prefix together with the
// key's version, for use in the first compare-and-swap.
func (p *Port) loadPortCounter(client *clientv3.Client, key string) (uint16, uint, int64, error) {
	const initialPort uint16 = 1024
	rep, err := client.Get(p.ctx, key)
	if err != nil {
		return 0, 0, 0, err
	}
	if len(rep.Kvs) > 0 {
		port, pre, err := decodePortCounter(rep.Kvs[0].Value)
		if err != nil {
			return 0, 0, 0, err
		}
		return port, pre, rep.Kvs[0].Version, nil
	}
	// Create the key only if nobody created it in the meantime; otherwise
	// fetch (and honour) what the other node stored.
	txnRep, err := client.Txn(p.ctx).
		If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
		Then(clientv3.OpPut(key, strconv.FormatUint(uint64(initialPort), 10))).
		Else(clientv3.OpGet(key)).
		Commit()
	if err != nil {
		return 0, 0, 0, err
	}
	if txnRep.Succeeded {
		// The key we just created is at version 1.
		return initialPort, 0, 1, nil
	}
	return portCounterFromTxn(txnRep, key)
}

// portCounterFromTxn decodes the counter returned by the Else(OpGet(key))
// branch of a transaction whose comparison failed.
func portCounterFromTxn(txnRep *clientv3.TxnResponse, key string) (uint16, uint, int64, error) {
	kvs := txnRep.Responses[0].GetResponseRange().Kvs
	if len(kvs) == 0 {
		// The key was deleted concurrently; indexing kvs[0] would panic.
		return 0, 0, 0, fmt.Errorf("port counter key %q vanished during allocation", key)
	}
	port, pre, err := decodePortCounter(kvs[0].Value)
	if err != nil {
		return 0, 0, 0, err
	}
	return port, pre, kvs[0].Version, nil
}

// decodePortCounter splits a stored counter into two parts:
//   - port: the value modulo math.MaxUint16;
//   - pre: the rest of the value.
//
// Negative or non-numeric values are rejected instead of wrapping through uint.
func decodePortCounter(value []byte) (uint16, uint, error) {
	n, err := strconv.ParseUint(string(value), 10, 0)
	if err != nil {
		return 0, 0, err
	}
	port := uint16(n % math.MaxUint16)
	return port, uint(n) - uint(port), nil
}
// getVaildPort probes the ports after port, one at a time. It wraps past
// 65535 and skips 0, and returns the first port on which a TCP listener can
// be opened on all interfaces.
//
// The probe listener is closed right away, so the returned port was only
// known to be free at the moment of the check. Returns 0 when nothing could
// be bound within math.MaxUint16+100 attempts.
func (p *Port) getVaildPort(port uint16) uint16 {
	for attempt := 1; attempt <= math.MaxUint16+100; attempt++ {
		port++ // uint16 arithmetic: 65535 wraps to 0
		if port == 0 {
			continue
		}
		addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf(":%d", port))
		if err != nil {
			continue
		}
		// net.ListenTCP returns a nil listener on error, so there is nothing
		// to close on the failure path.
		ln, err := net.ListenTCP("tcp", addr)
		if err != nil {
			continue
		}
		ln.Close()
		return port
	}
	return 0
}