Skip to content

Commit

Permalink
Use single goroutine per UDP connection
Browse files Browse the repository at this point in the history
  • Loading branch information
jtackaberry committed Aug 14, 2023
1 parent ca42e7e commit b91d7e8
Showing 1 changed file with 140 additions and 20 deletions.
160 changes: 140 additions & 20 deletions layer4/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package layer4
import (
"bytes"
"fmt"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -76,23 +77,69 @@ func (s Server) serve(ln net.Listener) error {
}

func (s Server) servePacket(pc net.PacketConn) error {
// Spawn a goroutine whose only job is to consume packets from the socket
// and send to the packets channel.
packets := make(chan packet, 10)
go func(packets chan packet) {
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
packets <- packet{err: err}
return
}
packets <- packet{
pooledBuf: buf,
n: n,
addr: addr,
}
}
}(packets)

// udpConns tracks active packetConns by downstream address:port. They will
// be removed from this map after being closed.
udpConns := make(map[string]*packetConn)
// closeCh is used to receive notifications of socket closures from
// packetConn, which allows us to to remove stale connections (whose
// proxy handlers have completed) from the udpConns map.
closeCh := make(chan string, 10)
for {
buf := udpBufPool.Get().([]byte)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
select {
case addr := <-closeCh:
// UDP connection is closed (either implicitly through timeout or by
// explicit call to Close()).
delete(udpConns, addr)

case pkt := <-packets:
if pkt.err != nil {
return pkt.err
}
return err
conn, ok := udpConns[pkt.addr.String()]
if !ok {
// No existing proxy handler is running for this downstream.
// Create one now.
conn = &packetConn{
PacketConn: pc,
readCh: make(chan *packet, 5),
addr: pkt.addr,
closeCh: closeCh,
}
udpConns[pkt.addr.String()] = conn
go func(conn *packetConn) {
s.handle(conn)
// It might seem cleaner to send to closeCh here rather than
// in packetConn, but doing it earlier in packetConn closes
// the gap between the proxy handler shutting down and new
// packets coming in from the same downstream. Should that
// happen, we'll just spin up a new handler concurrent to
// the old one shutting down.
}(conn)
}
conn.readCh <- &pkt
}
go func(buf []byte, n int, addr net.Addr) {
defer udpBufPool.Put(buf)
s.handle(packetConn{
PacketConn: pc,
buf: bytes.NewBuffer(buf[:n]),
addr: addr,
})
}(buf, n, addr)
}
}

Expand Down Expand Up @@ -120,22 +167,95 @@ func (s Server) handle(conn net.Conn) {
)
}

type packet struct {
// The underlying bytes slice that was gotten from udpBufPool. It's up to
// packetConn to return it to udpBufPool once it's consumed.
pooledBuf []byte
// Number of bytes read from socket
n int
// Error that occurred while reading from socket
err error
// Address of downstream
addr net.Addr
}

type packetConn struct {
net.PacketConn
buf *bytes.Buffer
addr net.Addr
addr net.Addr
readCh chan *packet
closeCh chan string
// If not nil, then the previous Read() call didn't consume all the data
// from the buffer, and this packet will be reused in the next Read()
// without waiting for readCh.
lastPacket *packet
lastBuf *bytes.Buffer
}

func (pc packetConn) Read(b []byte) (n int, err error) {
return pc.buf.Read(b)
func (pc *packetConn) Read(b []byte) (n int, err error) {
if pc.lastPacket != nil {
// There is a partial buffer to continue reading from the previous
// packet.
n, err = pc.lastBuf.Read(b)
if pc.lastBuf.Len() == 0 {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
pc.lastBuf = nil
}
return
}
select {
case pkt := <-pc.readCh:
if pkt == nil {
// Channel is closed. Return EOF below.
break
}
buf := bytes.NewBuffer(pkt.pooledBuf[:pkt.n])
n, err = buf.Read(b)
if buf.Len() == 0 {
// Buffer fully consumed, release it.
udpBufPool.Put(pkt.pooledBuf)
} else {
// Buffer only partially consumed. Keep track of it for
// next Read() call.
pc.lastPacket = pkt
pc.lastBuf = buf
}
return
// TODO: idle timeout should be configurable per server
case <-time.After(30 * time.Second):
break
}
// Idle timeout simulates socket closure.
//
// Although Close() also does this, we inform the server loop early about
// the closure to ensure that if any new packets are received from this
// connection in the meantime, a new handler will be started.
pc.closeCh <- pc.addr.String()
// Returning EOF here ensures that io.Copy() waiting on the downstream for
// reads will terminate.
return 0, io.EOF
}

func (pc packetConn) Write(b []byte) (n int, err error) {
return pc.PacketConn.WriteTo(b, pc.addr)
}

func (pc packetConn) Close() error {
// Do nothing, we don't want to close the UDP server
func (pc *packetConn) Close() error {
if pc.lastPacket != nil {
udpBufPool.Put(pc.lastPacket.pooledBuf)
pc.lastPacket = nil
}
// This will abort any active Read() from another goroutine and return EOF
close(pc.readCh)
// Drain pending packets to ensure we release buffers back to the pool
for pkt := range pc.readCh {
udpBufPool.Put(pkt.pooledBuf)
}
// We may have already done this earlier in Read(), but just in case
// Read() wasn't being called, (re-)notify server loop we're closed.
pc.closeCh <- pc.addr.String()
// We don't call net.PacketConn.Close() here as we would stop the UDP
// server.
return nil
}

Expand Down

0 comments on commit b91d7e8

Please sign in to comment.