-
-
Notifications
You must be signed in to change notification settings - Fork 80
/
server.go
202 lines (186 loc) · 4.7 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
package server
import (
"context"
"errors"
"fmt"
"github.com/anacrolix/torrent/bencode"
"github.com/bitmagnet-io/bitmagnet/internal/protocol/dht"
"github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/responder"
"go.uber.org/zap"
"net/netip"
"sync"
"time"
)
// Server is the DHT message transport: it sends queries to remote nodes and
// is started and stopped by the package that owns it.
type Server interface {
	// Query sends query q with args to addr and blocks until the matching
	// reply arrives, the query times out, or ctx is cancelled.
	Query(ctx context.Context, addr netip.AddrPort, q string, args dht.MsgArgs) (dht.RecvMsg, error)
	// start opens the underlying socket and begins serving.
	start() error
	// stop shuts the server down.
	stop()
}
// server is the concrete Server implementation. It owns the socket opened in
// start and tracks the outgoing queries that are still awaiting a reply.
type server struct {
	// stopped is closed by stop; the goroutine launched in start waits on it
	// to cancel the read loop and close the socket.
	stopped chan struct{}
	// mutex guards queries, which is accessed concurrently by Query and by
	// the per-message handleResponse goroutines.
	mutex sync.Mutex
	// localAddr is the address the socket is opened on in start.
	localAddr netip.AddrPort
	socket    Socket
	// queryTimeout bounds how long Query waits for a reply, in addition to
	// the caller's ctx.
	queryTimeout time.Duration
	// queries maps an outgoing transaction ID to the channel its Query call
	// is waiting on.
	queries map[string]chan dht.RecvMsg
	// responder produces the answers to incoming queries in handleQuery.
	responder responder.Responder
	// responderTimeout bounds each responder.Respond call in handleQuery.
	responderTimeout time.Duration
	// idIssuer hands out transaction IDs for outgoing queries.
	idIssuer IdIssuer
	logger   *zap.SugaredLogger
}
// start opens the socket on localAddr, launches the read loop, and arranges
// for both to be torn down once stop is called. It returns an error only if
// the socket cannot be opened.
func (s *server) start() error {
	openErr := s.socket.Open(s.localAddr)
	if openErr != nil {
		return fmt.Errorf("could not open socket: %w", openErr)
	}
	readCtx, cancelRead := context.WithCancel(context.Background())
	go s.read(readCtx)
	go func() {
		<-s.stopped
		// Cancel before closing: read treats a receive error as a normal
		// shutdown only when its context is already cancelled, and panics
		// otherwise.
		cancelRead()
		_ = s.socket.Close()
	}()
	return nil
}
// stop signals the goroutine started by start to cancel the read loop and
// close the socket. It is idempotent: calls after the first are no-ops rather
// than panicking with "close of closed channel". The check-and-close runs
// under s.mutex so concurrent callers cannot both reach close.
func (s *server) stop() {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	select {
	case <-s.stopped:
		// Already stopped; nothing to do.
	default:
		close(s.stopped)
	}
}
// read is the receive loop. It blocks on the socket, bencode-decodes each
// datagram into a dht.Msg and dispatches it on a new goroutine:
//   - queries go to handleQuery;
//   - responses and errors go to handleResponse;
//   - any other "y" value is dropped.
//
// Undecodable packets are logged at debug level and skipped. The loop exits
// once ctx is cancelled; a receive error while ctx is still live panics.
func (s *server) read(ctx context.Context) {
	// 65,507 bytes is the largest UDP payload over IPv4: 65,535 minus the
	// 8-byte UDP header and the 20-byte IPv4 header
	// (https://en.wikipedia.org/wiki/User_Datagram_Protocol).
	// NOTE(review): over IPv6 a payload can reach 65,527 bytes (more with
	// RFC 2675 jumbograms), so such datagrams would not fit this buffer.
	const maxDatagramSize = 65507
	packet := make([]byte, maxDatagramSize)
	for ctx.Err() == nil {
		size, sender, recvErr := s.socket.Receive(packet)
		if recvErr != nil {
			// The socket was most likely closed. That is expected during
			// shutdown (ctx already cancelled) and fatal otherwise.
			if ctx.Err() != nil {
				return
			}
			panic(fmt.Errorf("socket read error: %w", recvErr))
		}
		// Zero-length datagrams are legal but carry nothing to decode.
		if size == 0 {
			continue
		}
		var decoded dht.Msg
		if decodeErr := bencode.Unmarshal(packet[:size], &decoded); decodeErr != nil {
			s.logger.Debugw("could not unmarshal packet data", "error", decodeErr)
			continue
		}
		received := dht.RecvMsg{Msg: decoded, From: sender}
		switch decoded.Y {
		case dht.YQuery:
			go s.handleQuery(received)
		case dht.YResponse, dht.YError:
			go s.handleResponse(received)
		}
	}
}
// handleQuery answers one incoming query. It asks the responder for a result,
// bounded by responderTimeout, and replies to the sender with the query's
// transaction ID. The reply is one of:
//   - a response ("y" = dht.YResponse) carrying R, or
//   - an error message ("y" = dht.YError) carrying E.
//
// Errors that are not dht.Error values are reported to the peer as a generic
// server error and logged locally. A failure to send the reply is logged at
// debug level.
func (s *server) handleQuery(msg dht.RecvMsg) {
	ctx, cancel := context.WithTimeout(context.Background(), s.responderTimeout)
	defer cancel()
	res := dht.Msg{
		T: msg.Msg.T,
		Y: dht.YResponse,
	}
	ret, retErr := s.responder.Respond(ctx, msg)
	if retErr != nil {
		// BEP 5: an error reply is its own message type ("y": "e"). Leaving Y
		// as a response makes peers (including the YError check in Query)
		// misread it as a response that lacks return values.
		res.Y = dht.YError
		var dhtErr dht.Error
		if errors.As(retErr, &dhtErr) {
			// A protocol-level error from the responder is relayed as is.
			res.E = &dhtErr
		} else {
			// Anything else is an internal failure: hide the details from the
			// remote peer and log them locally.
			res.E = &dht.Error{
				Code: dht.ErrorCodeServerError,
				Msg:  "server error",
			}
			s.logger.Errorw("server error", "msg", msg, "retErr", retErr)
		}
	} else {
		res.R = &ret
	}
	if sendErr := s.send(msg.From, res); sendErr != nil {
		s.logger.Debugw("could not send response", "msg", msg, "sendErr", sendErr)
	}
}
// handleResponse routes a response or error message to the Query call that is
// waiting on its transaction ID. Messages whose transaction ID is not pending
// are dropped.
func (s *server) handleResponse(msg dht.RecvMsg) {
	s.mutex.Lock()
	ch, ok := s.queries[msg.Msg.T]
	s.mutex.Unlock()
	if !ok {
		return
	}
	// Query creates a 1-slot channel and reads from it at most once. If the
	// slot is already full (for example, a peer sent several replies for the
	// same transaction), a blocking send would park this goroutine forever.
	// Extra replies are therefore dropped.
	select {
	case ch <- msg:
	default:
	}
}
// Query sends query q with args to addr and waits for the reply carrying the
// same transaction ID. The wait is bounded by both ctx and queryTimeout.
//
// Results:
//   - On success, r is the reply and err is nil.
//   - If the peer answers with an error message, r is that reply and err is
//     the peer's error, or a generic error when the "e" field is absent.
//   - A reply without return values is also reported as an error, with r set.
//   - Send failures and timeouts return a zero r.
func (s *server) Query(ctx context.Context, addr netip.AddrPort, q string, args dht.MsgArgs) (r dht.RecvMsg, err error) {
	transactionId := s.idIssuer.Issue()
	// Buffered so handleResponse can deliver one reply without waiting on us.
	ch := make(chan dht.RecvMsg, 1)
	s.mutex.Lock()
	s.queries[transactionId] = ch
	s.mutex.Unlock()
	defer func() {
		s.mutex.Lock()
		delete(s.queries, transactionId)
		s.mutex.Unlock()
	}()
	msg := dht.Msg{
		Q: q,
		T: transactionId,
		A: &args,
		Y: dht.YQuery,
	}
	if sendErr := s.send(addr, msg); sendErr != nil {
		s.logger.Debugw("could not send query", "msg", msg, "sendErr", sendErr)
		return r, sendErr
	}
	queryCtx, cancel := context.WithTimeout(ctx, s.queryTimeout)
	defer cancel()
	select {
	case <-queryCtx.Done():
		return r, queryCtx.Err()
	case res, ok := <-ch:
		if !ok {
			return r, errors.New("channel closed")
		}
		switch {
		case res.Msg.Y == dht.YError:
			// Check the field itself, not the converted error. Assigning a nil
			// *dht.Error to an error interface yields a non-nil error wrapping
			// a nil pointer, which would hide the "missing" case.
			if res.Msg.E == nil {
				return res, errors.New("error missing from response")
			}
			return res, res.Msg.E
		case res.Msg.R == nil:
			return res, errors.New("return data missing from response")
		default:
			return res, nil
		}
	}
}
// send bencodes msg and writes the bytes to addr through the server's socket.
// Encoding and socket errors are returned to the caller unwrapped.
func (s *server) send(addr netip.AddrPort, msg dht.Msg) error {
	payload, err := bencode.Marshal(msg)
	if err != nil {
		return err
	}
	return s.socket.Send(addr, payload)
}