-
Notifications
You must be signed in to change notification settings - Fork 0
/
listener.go
152 lines (130 loc) · 3.51 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package brpc
import (
"context"
"errors"
"github.com/quic-go/quic-go"
"go.uber.org/multierr"
"io"
"log/slog"
"net"
"os"
"sync"
"syscall"
)
var _ net.Listener = &multiListener{}
// multiListener is an implementation of net.Listener that wraps many net.Listeners
// which can be added to the multiListener. Callers can add new net.Listeners and
// other callers can block on Accept() to receive a net.Conn from any of the added
// net.Listeners.
//
// In our case, brpc servers use this because we want to handle incoming connections
// ourselves first (to negotiate a yamux session, pass client ids, etc), and then
// pass the yamux.Session (which implements net.Listener) into the multiListener
// so that our gRPC server can accept all future connections from the yamux.Session.
type multiListener struct {
listenersLock sync.Mutex
listeners []net.Listener
connChan chan net.Conn
errChan chan error // errors on this channel
closeChan chan struct{}
wg sync.WaitGroup
logger *slog.Logger
}
// newMultiListener returns an empty multiListener ready for AddListener and
// Accept. It has the following settings:
//   - connChan and closeChan are unbuffered.
//   - errChan has room for a single error.
//   - Logging goes through slog.Default().
func newMultiListener() *multiListener {
	ml := new(multiListener)
	ml.connChan = make(chan net.Conn)
	ml.errChan = make(chan error, 1) // room for one pending error
	ml.closeChan = make(chan struct{})
	ml.logger = slog.Default()
	return ml
}
// AddListener registers l with the multiListener and starts a goroutine that
// forwards every connection accepted from l to callers of Accept.
//
// The goroutine exits when l.Accept returns an error or when the multiListener
// is closed. Accept errors that isTransientError does not recognize are logged.
// If the multiListener has already been closed, l is closed immediately and is
// not registered, so neither the listener nor a goroutine is leaked.
func (ml *multiListener) AddListener(l net.Listener) {
	ml.listenersLock.Lock()
	select {
	case <-ml.closeChan:
		ml.listenersLock.Unlock()
		// Already shut down: nothing will ever Accept from l. The close error is
		// deliberately ignored; there is no caller to report it to.
		_ = l.Close()
		return
	default:
	}
	ml.listeners = append(ml.listeners, l)
	// Increment the WaitGroup while still holding the lock so the listener and
	// its goroutine are registered together with respect to other lock holders.
	ml.wg.Add(1)
	ml.listenersLock.Unlock()

	logger := ml.logger
	if logger == nil {
		logger = slog.Default()
	}

	go func() {
		defer ml.wg.Done()
		for {
			conn, err := l.Accept()
			if err != nil {
				if !isTransientError(err) {
					logger.Warn("error accepting connection", "error", err)
				}
				return
			}
			select {
			case ml.connChan <- conn:
			case <-ml.closeChan:
				// Nobody will receive this connection any more; close it rather
				// than leak it. Its close error has nowhere useful to go.
				_ = conn.Close()
				return
			}
		}
	}()
}
// Accept implements net.Listener. It blocks until one of three things happens:
//   - A registered listener yields a connection; it is returned with a nil error.
//   - An error arrives on errChan; it is returned with a nil connection.
//   - The multiListener is closed; Accept returns net.ErrClosed.
func (ml *multiListener) Accept() (net.Conn, error) {
	var (
		conn net.Conn
		err  error
	)
	select {
	case conn = <-ml.connChan:
	case err = <-ml.errChan:
	case <-ml.closeChan:
		err = net.ErrClosed
	}
	return conn, err
}
// Close implements net.Listener. It does the following, in order:
//   - Causes blocked Accept calls to return net.ErrClosed.
//   - Closes every registered listener.
//   - Waits for the per-listener goroutines started by AddListener to exit.
//
// It returns the combined error from closing the listeners. Calling Close more
// than once is safe: later calls do nothing and return nil.
//
// The listeners must be closed BEFORE waiting on wg. Each forwarding goroutine
// spends most of its life blocked in l.Accept and only notices closeChan after
// that call returns. Waiting first would block until every listener happened to
// receive another connection.
func (ml *multiListener) Close() error {
	ml.listenersLock.Lock()
	select {
	case <-ml.closeChan:
		// A previous Close already shut everything down.
		ml.listenersLock.Unlock()
		return nil
	default:
	}
	close(ml.closeChan)
	// Take ownership of the slice under the lock; close outside it so a slow
	// listener Close does not block AddListener.
	listeners := ml.listeners
	ml.listeners = nil
	ml.listenersLock.Unlock()

	var err error
	for _, l := range listeners {
		err = multierr.Append(err, l.Close())
	}
	ml.wg.Wait()
	return err
}
// Addr implements net.Listener. A multiListener aggregates many listeners and
// has no single address of its own, so it reports a zero-valued *net.TCPAddr.
func (ml *multiListener) Addr() net.Addr {
	return new(net.TCPAddr)
}
// isTransientError reports whether err is an expected Accept failure that
// AddListener does not log. It returns true in three cases:
//   - err wraps net.ErrClosed.
//   - err wraps io.EOF.
//   - err is (or wraps) a *net.OpError whose inner error is an *os.SyscallError
//     carrying ECONNRESET.
//
// A bare syscall.ECONNRESET that is not inside a *net.OpError is not treated as
// transient.
func isTransientError(err error) bool {
	switch {
	case errors.Is(err, net.ErrClosed), errors.Is(err, io.EOF):
		return true
	}

	var opErr *net.OpError
	if !errors.As(err, &opErr) || opErr.Err == nil {
		return false
	}

	var sysErr *os.SyscallError
	return errors.As(opErr.Err, &sysErr) && sysErr.Err == syscall.ECONNRESET
}
// Compile-time check that *quicConn satisfies net.Conn.
var _ net.Conn = (*quicConn)(nil)

// quicConn adapts a single QUIC stream to net.Conn. Every net.Conn method other
// than the address accessors defined below is promoted from the embedded
// quic.Stream.
type quicConn struct {
	quic.Stream
}
// LocalAddr implements net.Conn. The wrapped stream exposes no address, so this
// returns a nil *net.TCPAddr boxed in a net.Addr interface. Callers comparing
// the result to nil will see a non-nil value. NOTE(review): this appears to rely
// on net.TCPAddr's methods tolerating a nil receiver.
func (q *quicConn) LocalAddr() net.Addr {
	var none *net.TCPAddr
	return none
}
// RemoteAddr implements net.Conn. Like LocalAddr, it returns a nil *net.TCPAddr
// inside a non-nil net.Addr interface because the wrapped stream carries no
// peer address.
func (q *quicConn) RemoteAddr() net.Addr {
	var none *net.TCPAddr
	return none
}
// Compile-time check that *quicListener satisfies net.Listener.
var _ net.Listener = (*quicListener)(nil)

// quicListener exposes a quic.Connection as a net.Listener. Each Accept returns
// the next bidirectional stream opened on the connection, wrapped as a net.Conn.
type quicListener struct {
	// conn is the QUIC connection whose incoming streams are accepted.
	conn quic.Connection
}
// Accept implements net.Listener by waiting for the next stream from
// q.conn.AcceptStream and wrapping it in a quicConn. Any error from AcceptStream
// is returned unchanged with a nil net.Conn.
//
// The wait uses context.Background(), so no context ever cancels it.
// NOTE(review): presumably it ends only when a stream arrives or the connection
// is closed (e.g. through Close); confirm against the quic-go docs.
func (q *quicListener) Accept() (net.Conn, error) {
	ctx := context.Background()
	stream, err := q.conn.AcceptStream(ctx)
	if err != nil {
		return nil, err
	}
	return &quicConn{Stream: stream}, nil
}
// Close implements net.Listener by closing the entire underlying QUIC
// connection. It sends the NoError code with an empty reason string and
// returns whatever CloseWithError returns.
func (q *quicListener) Close() error {
	code := quic.ApplicationErrorCode(quic.NoError)
	return q.conn.CloseWithError(code, "")
}
// Addr implements net.Listener. No address is reported. The result is a nil
// *net.TCPAddr inside a non-nil net.Addr interface, matching quicConn's address
// accessors.
func (q *quicListener) Addr() net.Addr {
	var none *net.TCPAddr
	return none
}