Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul UDP server #141

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 145 additions & 21 deletions layer4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package layer4
import (
"bytes"
"fmt"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -76,23 +77,69 @@ func (s Server) serve(ln net.Listener) error {
}

func (s Server) servePacket(pc net.PacketConn) error {
// Spawn a goroutine whose only job is to consume packets from the socket
// and send to the packets channel.
packets := make(chan packet, 10)
go func(packets chan packet) {
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
packets <- packet{err: err}
return
}
packets <- packet{
pooledBuf: buf,
n: n,
addr: addr,
}
}
}(packets)

// udpConns tracks active packetConns by downstream address:port. They will
// be removed from this map after being closed.
udpConns := make(map[string]*packetConn)
// closeCh is used to receive notifications of socket closures from
// packetConn, which allows us to to remove stale connections (whose
// proxy handlers have completed) from the udpConns map.
closeCh := make(chan string, 10)
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
select {
case addr := <-closeCh:
// UDP connection is closed (either implicitly through timeout or by
// explicit call to Close()).
delete(udpConns, addr)

case pkt := <-packets:
if pkt.err != nil {
return pkt.err
}
return err
conn, ok := udpConns[pkt.addr.String()]
if !ok {
// No existing proxy handler is running for this downstream.
// Create one now.
conn = &packetConn{
PacketConn: pc,
readCh: make(chan *packet, 5),
addr: pkt.addr,
closeCh: closeCh,
}
udpConns[pkt.addr.String()] = conn
go func(conn *packetConn) {
s.handle(conn)
// It might seem cleaner to send to closeCh here rather than
// in packetConn, but doing it earlier in packetConn closes
// the gap between the proxy handler shutting down and new
// packets coming in from the same downstream. Should that
// happen, we'll just spin up a new handler concurrent to
// the old one shutting down.
}(conn)
}
conn.readCh <- &pkt
}
go func(buf []byte, n int, addr net.Addr) {
defer udpBufPool.Put(buf)
s.handle(packetConn{
PacketConn: pc,
buf: bytes.NewBuffer(buf[:n]),
addr: addr,
})
}(buf, n, addr)
}
}

Expand Down Expand Up @@ -120,29 +167,106 @@ func (s Server) handle(conn net.Conn) {
)
}

type packet struct {
// The underlying bytes slice that was gotten from udpBufPool. It's up to
// packetConn to return it to udpBufPool once it's consumed.
pooledBuf []byte
// Number of bytes read from socket
n int
// Error that occurred while reading from socket
err error
// Address of downstream
addr net.Addr
}

type packetConn struct {
net.PacketConn
buf *bytes.Buffer
addr net.Addr
addr net.Addr
readCh chan *packet
closeCh chan string
// If not nil, then the previous Read() call didn't consume all the data
// from the buffer, and this packet will be reused in the next Read()
// without waiting for readCh.
lastPacket *packet
lastBuf *bytes.Buffer
}

func (pc packetConn) Read(b []byte) (n int, err error) {
return pc.buf.Read(b)
func (pc *packetConn) Read(b []byte) (n int, err error) {
if pc.lastPacket != nil {
// There is a partial buffer to continue reading from the previous
// packet.
n, err = pc.lastBuf.Read(b)
if pc.lastBuf.Len() == 0 {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
pc.lastBuf = nil
}
return
}
select {
case pkt := <-pc.readCh:
if pkt == nil {
// Channel is closed. Return EOF below.
break
}
buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n])
n, err = buf.Read(b)
if buf.Len() == 0 {
// Buffer fully consumed, release it.
udpBufPool.Put(pkt.pooledBuf)
} else {
// Buffer only partially consumed. Keep track of it for
// next Read() call.
pc.lastPacket = pkt
pc.lastBuf = buf
}
return
// TODO: idle timeout should be configurable per server
case <-time.After(30 * time.Second):
break
}
// Idle timeout simulates socket closure.
//
// Although Close() also does this, we inform the server loop early about
// the closure to ensure that if any new packets are received from this
// connection in the meantime, a new handler will be started.
pc.closeCh <- pc.addr.String()
// Returning EOF here ensures that io.Copy() waiting on the downstream for
// reads will terminate.
return 0, io.EOF
}

func (pc packetConn) Write(b []byte) (n int, err error) {
return pc.PacketConn.WriteTo(b, pc.addr)
}

func (pc packetConn) Close() error {
// Do nothing, we don't want to close the UDP server
func (pc *packetConn) Close() error {
if pc.lastPacket != nil {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
}
// This will abort any active Read() from another goroutine and return EOF
close(pc.readCh)
// Drain pending packets to ensure we release buffers back to the pool
for pkt := range pc.readCh {
udpBufPool.Put(pkt.pooledBuf)
}
// We may have already done this earlier in Read(), but just in case
// Read() wasn't being called, (re-)notify server loop we're closed.
pc.closeCh <- pc.addr.String()
// We don't call net.PacketConn.Close() here as we would stop the UDP
// server.
return nil
}

func (pc packetConn) RemoteAddr() net.Addr { return pc.addr }

var udpBufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1024)
// Buffers need to be as large as the largest datagram we'll consume, because
// ReadFrom() can't resume partial reads. (This is standard for UDP
// sockets on *nix.) So our buffer sizes are 9000 bytes to accommodate
// networks with jumbo frames. See also https://github.com/golang/go/issues/18056
return make([]byte, 9000)
},
}
26 changes: 21 additions & 5 deletions modules/l4proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net"
"runtime/debug"
"sync"
"sync/atomic"
"time"

"github.com/caddyserver/caddy/v2"
Expand Down Expand Up @@ -253,6 +254,7 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
}

var wg sync.WaitGroup
var downClosed atomic.Bool

for _, up := range upConns {
wg.Add(1)
Expand All @@ -261,11 +263,16 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
defer wg.Done()

if _, err := io.Copy(down, up); err != nil {
h.logger.Error("upstream connection",
zap.String("local_address", up.LocalAddr().String()),
zap.String("remote_address", up.RemoteAddr().String()),
zap.Error(err),
)
// If the downstream connection has been closed, we can assume this is
// the reason io.Copy() errored. That's normal operation for UDP
// connections after idle timeout, so don't log an error in that case.
if !downClosed.Load() {
h.logger.Error("upstream connection",
zap.String("local_address", up.LocalAddr().String()),
zap.String("remote_address", up.RemoteAddr().String()),
zap.Error(err),
)
}
}
}(up)
}
Expand All @@ -280,9 +287,18 @@ func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {

// Shut down the writing side of all upstream connections, in case
// that the downstream connection is half closed. (issue #40)
//
// UDP connections meanwhile don't implement CloseWrite(), but in order
// to ensure io.Copy() in the per-upstream goroutines (above) returns,
// we need to close the socket. This will cause io.Copy() return an
// error, which in this particular case is expected, so we signal the
// intentional closure by setting this flag.
downClosed.Store(true)
for _, up := range upConns {
if conn, ok := up.(closeWriter); ok {
_ = conn.CloseWrite()
} else {
up.Close()
}
}
}()
Expand Down