diff --git a/listen_no_reuseport.go b/listen_no_reuseport.go index 6ed50f86b..8cebb2f17 100644 --- a/listen_no_reuseport.go +++ b/listen_no_reuseport.go @@ -7,16 +7,18 @@ import "net" const supportsReusePort = false -func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { - if reuseport { +func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) { + if reuseport || reuseaddr { // TODO(tmthrgd): return an error? } return net.Listen(network, addr) } -func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { - if reuseport { +const supportsReuseAddr = false + +func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) { + if reuseport || reuseaddr { // TODO(tmthrgd): return an error? } diff --git a/listen_reuseport.go b/listen_reuseport.go index 89bac9034..41326f20b 100644 --- a/listen_reuseport.go +++ b/listen_reuseport.go @@ -25,19 +25,41 @@ func reuseportControl(network, address string, c syscall.RawConn) error { return opErr } -func listenTCP(network, addr string, reuseport bool) (net.Listener, error) { +const supportsReuseAddr = true + +func reuseaddrControl(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + }) + if err != nil { + return err + } + + return opErr +} + +func listenTCP(network, addr string, reuseport, reuseaddr bool) (net.Listener, error) { var lc net.ListenConfig - if reuseport { + switch { + case reuseaddr && reuseport: + case reuseport: lc.Control = reuseportControl + case reuseaddr: + lc.Control = reuseaddrControl } return lc.Listen(context.Background(), network, addr) } -func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) { +func listenUDP(network, addr string, reuseport, reuseaddr bool) (net.PacketConn, error) { var lc net.ListenConfig - if reuseport { + switch { + case reuseaddr && reuseport: + case reuseport: lc.Control = reuseportControl + case reuseaddr: + lc.Control = reuseaddrControl } return lc.ListenPacket(context.Background(), network, addr) diff --git a/server.go b/server.go index 64e388546..0207d6da2 100644 --- a/server.go +++ b/server.go @@ -226,6 +226,10 @@ type Server struct { // Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address. // It is only supported on certain GOOSes and when using ListenAndServe. ReusePort bool + // Whether to set the SO_REUSEADDR socket option, allowing multiple listeners to be bound to a single address. + // Crucially this allows binding when an existing server is listening on `0.0.0.0` or `::`. + // It is only supported on certain GOOSes and when using ListenAndServe. + ReuseAddr bool // AcceptMsgFunc will check the incoming message and will reject it early in the process. // By default DefaultMsgAcceptFunc will be used. MsgAcceptFunc MsgAcceptFunc @@ -304,7 +308,7 @@ func (srv *Server) ListenAndServe() error { switch srv.Net { case "tcp", "tcp4", "tcp6": - l, err := listenTCP(srv.Net, addr, srv.ReusePort) + l, err := listenTCP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr) if err != nil { return err } @@ -317,7 +321,7 @@ func (srv *Server) ListenAndServe() error { return errors.New("dns: neither Certificates nor GetCertificate set in Config") } network := strings.TrimSuffix(srv.Net, "-tls") - l, err := listenTCP(network, addr, srv.ReusePort) + l, err := listenTCP(network, addr, srv.ReusePort, srv.ReuseAddr) if err != nil { return err } @@ -327,7 +331,7 @@ func (srv *Server) ListenAndServe() error { unlock() return srv.serveTCP(l) case "udp", "udp4", "udp6": - l, err := listenUDP(srv.Net, addr, srv.ReusePort) + l, err := listenUDP(srv.Net, addr, srv.ReusePort, srv.ReuseAddr) if err != nil { return err } diff --git a/server_test.go b/server_test.go index aaaca7048..4fc2af329 100644 --- a/server_test.go +++ b/server_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -1041,6 +1042,176 @@ func TestServerReuseport(t *testing.T) { } } +func TestServerReuseaddr(t *testing.T) { + startServerFn := func(t *testing.T, network, addr string, expectSuccess bool) (*Server, chan error) { + t.Helper() + wait := make(chan struct{}) + srv := &Server{ + Net: network, + Addr: addr, + NotifyStartedFunc: func() { close(wait) }, + ReuseAddr: true, + } + + fin := make(chan error, 1) + go func() { + fin <- srv.ListenAndServe() + }() + + select { + case <-wait: + case err := <-fin: + switch { + case expectSuccess: + t.Fatalf("%s: failed to start server: %v", t.Name(), err) + default: + fin <- err + return nil, fin + } + } + return srv, fin + } + + externalIPFn := func(t *testing.T) (string, error) { + t.Helper() + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() { + continue + } + ip = ip.To4() + if ip == nil { + continue // not an ipv4 address + } + return ip.String(), nil + } + } + return "", errors.New("are you connected to the network?") + } + + freePortFn := func(t *testing.T) int { + t.Helper() + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + t.Fatalf("unable resolve tcp addr: %s", err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatalf("unable listen tcp: %s", err) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port + } + + t.Run("should-fail-tcp", func(t *testing.T) { + // ReuseAddr should fail if you try to bind to exactly the same + // combination of source address and port. + // This should fail whether or not ReuseAddr is supported on a + // particular OS + ip, err := externalIPFn(t) + if err != nil { + t.Skip("no external IPs found") + return + } + port := freePortFn(t) + srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true) + srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), false) + switch { + case srv2 != nil && srv2.started: + t.Fatalf("second ListenAndServe should not have started") + default: + if err := <-fin2; err == nil { + t.Fatalf("second ListenAndServe should have returned a startup error: %v", err) + } + } + + if err := srv1.Shutdown(); err != nil { + t.Fatalf("failed to shutdown first server: %v", err) + } + if err := <-fin1; err != nil { + t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) + } + }) + t.Run("should-succeed-tcp", func(t *testing.T) { + if !supportsReuseAddr { + t.Skip("reuseaddr is not supported") + } + ip, err := externalIPFn(t) + if err != nil { + t.Skip("no external IPs found") + return + } + port := freePortFn(t) + + // ReuseAddr should succeed if you try to bind to the same port but a different source address + srv1, fin1 := startServerFn(t, "tcp", fmt.Sprintf("localhost:%d", port), true) + srv2, fin2 := startServerFn(t, "tcp", fmt.Sprintf("%s:%d", ip, port), true) + + if err := srv1.Shutdown(); err != nil { + t.Fatalf("failed to shutdown first server: %v", err) + } + if err := srv2.Shutdown(); err != nil { + t.Fatalf("failed to shutdown second server: %v", err) + } + if err := <-fin1; err != nil { + t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) + } + if err := <-fin2; err != nil { + t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) + } + }) + t.Run("should-succeed-udp", func(t *testing.T) { + if !supportsReuseAddr { + t.Skip("reuseaddr is not supported") + } + ip, err := externalIPFn(t) + if err != nil { + t.Skip("no external IPs found") + return + } + port := freePortFn(t) + + // ReuseAddr should succeed if you try to bind to the same port but a different source address + srv1, fin1 := startServerFn(t, "udp", fmt.Sprintf("localhost:%d", port), true) + srv2, fin2 := startServerFn(t, "udp", fmt.Sprintf("%s:%d", ip, port), true) + + if err := srv1.Shutdown(); err != nil { + t.Fatalf("failed to shutdown first server: %v", err) + } + if err := srv2.Shutdown(); err != nil { + t.Fatalf("failed to shutdown second server: %v", err) + } + if err := <-fin1; err != nil { + t.Fatalf("first ListenAndServe returned error after Shutdown: %v", err) + } + if err := <-fin2; err != nil { + t.Fatalf("second ListenAndServe returned error after Shutdown: %v", err) + } + }) +} + func TestServerRoundtripTsig(t *testing.T) { secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}