Skip to content

Commit

Permalink
feat: add support for ReuseAddr (#1510)
Browse files Browse the repository at this point in the history
* feat: add support for ReuseAddr

* Update listen_reuseport.go

* Update listen_reuseport.go

* fixup! feat: add support for ReuseAddr

---------

Co-authored-by: Miek Gieben <miek@miek.nl>
  • Loading branch information
jimlambrt and miekg committed Nov 15, 2023
1 parent 3d593a6 commit 257e89e
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 11 deletions.
10 changes: 6 additions & 4 deletions listen_no_reuseport.go
Expand Up @@ -7,16 +7,18 @@ import "net"

const supportsReusePort = false

func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
if reuseport {
func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
if reuseport || reuseaddr {
// TODO(tmthrgd): return an error?
}

return net.Listen(network, addr)
}

func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
if reuseport {
const supportsReuseAddr = false

func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
if reuseport || reuseaddr {
// TODO(tmthrgd): return an error?
}

Expand Down
30 changes: 26 additions & 4 deletions listen_reuseport.go
Expand Up @@ -25,19 +25,41 @@ func reuseportControl(network, address string, c syscall.RawConn) error {
return opErr
}

func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
const supportsReuseAddr = true

func reuseaddrControl(network, address string, c syscall.RawConn) error {
var opErr error
err := c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
})
if err != nil {
return err
}

return opErr
}

func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) {
var lc net.ListenConfig
if reuseport {
switch {
case reuseaddr && reuseport:
case reuseport:
lc.Control = reuseportControl
case reuseaddr:
lc.Control = reuseaddrControl
}

return lc.Listen(context.Background(), network, addr)
}

func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) {
var lc net.ListenConfig
if reuseport {
switch {
case reuseaddr && reuseport:
case reuseport:
lc.Control = reuseportControl
case reuseaddr:
lc.Control = reuseaddrControl
}

return lc.ListenPacket(context.Background(), network, addr)
Expand Down
10 changes: 7 additions & 3 deletions server.go
Expand Up @@ -226,6 +226,10 @@ type Server struct {
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
// It is only supported on certain GOOSes and when using ListenAndServe.
ReusePort bool
// Whether to set the SO_REUSEADDR socket option, allowing multiple listeners to be bound to a single address.
// Crucially this allows binding when an existing server is listening on `0.0.0.0` or `::`.
// It is only supported on certain GOOSes and when using ListenAndServe.
ReuseAddr bool
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
// By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
Expand Down Expand Up @@ -304,7 +308,7 @@ func (srv *Server) ListenAndServe() error {

switch srv.Net {
case "tcp", "tcp4", "tcp6":
l, err := listenTCP(srv.Net, addr, srv.ReusePort)
l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
if err != nil {
return err
}
Expand All @@ -317,7 +321,7 @@ func (srv *Server) ListenAndServe() error {
return errors.New("dns: neither Certificates nor GetCertificate set in Config")
}
network := strings.TrimSuffix(srv.Net, "-tls")
l, err := listenTCP(network, addr, srv.ReusePort)
l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr)
if err != nil {
return err
}
Expand All @@ -327,7 +331,7 @@ func (srv *Server) ListenAndServe() error {
unlock()
return srv.serveTCP(l)
case "udp", "udp4", "udp6":
l, err := listenUDP(srv.Net, addr, srv.ReusePort)
l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr)
if err != nil {
return err
}
Expand Down
171 changes: 171 additions & 0 deletions server_test.go
Expand Up @@ -3,6 +3,7 @@ package dns
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -1041,6 +1042,176 @@ func TestServerReuseport(t *testing.T) {
}
}

func TestServerReuseaddr(t *testing.T) {
startServerFn := func(t *testing.T, network, addr string, expectSuccess bool) (*Server, chan error) {
t.Helper()
wait := make(chan struct{})
srv := &Server{
Net: network,
Addr: addr,
NotifyStartedFunc: func() { close(wait) },
ReuseAddr: true,
}

fin := make(chan error, 1)
go func() {
fin <- srv.ListenAndServe()
}()

select {
case <-wait:
case err := <-fin:
switch {
case expectSuccess:
t.Fatalf("%s: failed to start server: %v", t.Name(), err)
default:
fin <- err
return nil, fin
}
}
return srv, fin
}

externalIPFn := func(t *testing.T) (string, error) {
t.Helper()
ifaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}

freePortFn := func(t *testing.T) int {
t.Helper()
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
t.Fatalf("unable resolve tcp addr: %s", err)
}

l, err := net.ListenTCP("tcp", addr)
if err != nil {
t.Fatalf("unable listen tcp: %s", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port
}

t.Run("should-fail-tcp", func(t *testing.T) {
// ReuseAddr should fail if you try to bind to exactly the same
// combination of source address and port.
// This should fail whether or not ReuseAddr is supported on a
// particular OS
ip, err := externalIPFn(t)
if err != nil {
t.Skip("no external IPs found")
return
}
port := freePortFn(t)
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), false)
switch {
case srv2 != nil && srv2.started:
t.Fatalf("second ListenAndServe should not have started")
default:
if err := <-fin2; err == nil {
t.Fatalf("second ListenAndServe should have returned a startup error: %v", err)
}
}

if err := srv1.Shutdown(); err != nil {
t.Fatalf("failed to shutdown first server: %v", err)
}
if err := <-fin1; err != nil {
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
}
})
t.Run("should-succeed-tcp", func(t *testing.T) {
if !supportsReuseAddr {
t.Skip("reuseaddr is not supported")
}
ip, err := externalIPFn(t)
if err != nil {
t.Skip("no external IPs found")
return
}
port := freePortFn(t)

// ReuseAddr should succeed if you try to bind to the same port but a different source address
srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("localhost:%d", port), true)
srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true)

if err := srv1.Shutdown(); err != nil {
t.Fatalf("failed to shutdown first server: %v", err)
}
if err := srv2.Shutdown(); err != nil {
t.Fatalf("failed to shutdown second server: %v", err)
}
if err := <-fin1; err != nil {
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
}
if err := <-fin2; err != nil {
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
}
})
t.Run("should-succeed-udp", func(t *testing.T) {
if !supportsReuseAddr {
t.Skip("reuseaddr is not supported")
}
ip, err := externalIPFn(t)
if err != nil {
t.Skip("no external IPs found")
return
}
port := freePortFn(t)

// ReuseAddr should succeed if you try to bind to the same port but a different source address
srv1, fin1 := startServerFn(t, "udp", fmt.Sprintf("localhost:%d", port), true)
srv2, fin2 := startServerFn(t, "udp", fmt.Sprintf("%s:%d", ip, port), true)

if err := srv1.Shutdown(); err != nil {
t.Fatalf("failed to shutdown first server: %v", err)
}
if err := srv2.Shutdown(); err != nil {
t.Fatalf("failed to shutdown second server: %v", err)
}
if err := <-fin1; err != nil {
t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err)
}
if err := <-fin2; err != nil {
t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err)
}
})
}

func TestServerRoundtripTsig(t *testing.T) {
secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}

Expand Down

0 comments on commit 257e89e

Please sign in to comment.