Skip to content

Commit

Permalink
Only treat a *net.UnixConn of unixgram as a packet conn (#1322)
Browse files Browse the repository at this point in the history
* Refactor net.PacketConn checks into helper function

* Only treat a *net.UnixConn of unixgram as a packet conn

* Handle wrapped net.Conn types in isPacketConn

* Use Error instead of Fatal where appropriate in TestIsPacketConn
  • Loading branch information
tmthrgd committed Dec 28, 2021
1 parent af5144a commit 0544c8b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 31 deletions.
20 changes: 16 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ const (
tcpIdleTimeout time.Duration = 8 * time.Second
)

func isPacketConn(c net.Conn) bool {
if _, ok := c.(net.PacketConn); !ok {
return false
}

if ua, ok := c.LocalAddr().(*net.UnixAddr); ok {
return ua.Net == "unixgram"
}

return true
}

// A Conn represents a connection to a DNS server.
type Conn struct {
net.Conn // a net.Conn holding the connection
Expand Down Expand Up @@ -221,7 +233,7 @@ func (c *Client) exchangeContext(ctx context.Context, m *Msg, co *Conn) (r *Msg,
return nil, 0, err
}

if _, ok := co.Conn.(net.PacketConn); ok {
if isPacketConn(co.Conn) {
for {
r, err = co.ReadMsg()
// Ignore replies with mismatched IDs because they might be
Expand Down Expand Up @@ -282,7 +294,7 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
err error
)

if _, ok := co.Conn.(net.PacketConn); ok {
if isPacketConn(co.Conn) {
if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize)
} else {
Expand Down Expand Up @@ -322,7 +334,7 @@ func (co *Conn) Read(p []byte) (n int, err error) {
return 0, ErrConnEmpty
}

if _, ok := co.Conn.(net.PacketConn); ok {
if isPacketConn(co.Conn) {
// UDP connection
return co.Conn.Read(p)
}
Expand Down Expand Up @@ -371,7 +383,7 @@ func (co *Conn) Write(p []byte) (int, error) {
return 0, &Error{err: "message too large"}
}

if _, ok := co.Conn.(net.PacketConn); ok {
if isPacketConn(co.Conn) {
return co.Conn.Write(p)
}

Expand Down
75 changes: 75 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,87 @@ import (
"errors"
"fmt"
"net"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)

func TestIsPacketConn(t *testing.T) {
// UDP
s, addrstr, _, err := RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c, err := net.Dial("udp", addrstr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close()
if !isPacketConn(c) {
t.Error("UDP connection should be a packet conn")
}
if !isPacketConn(struct{ *net.UDPConn }{c.(*net.UDPConn)}) {
t.Error("UDP connection (wrapped type) should be a packet conn")
}

// TCP
s, addrstr, _, err = RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c, err = net.Dial("tcp", addrstr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close()
if isPacketConn(c) {
t.Error("TCP connection should not be a packet conn")
}
if isPacketConn(struct{ *net.TCPConn }{c.(*net.TCPConn)}) {
t.Error("TCP connection (wrapped type) should not be a packet conn")
}

// Unix datagram
s, addrstr, _, err = RunLocalUnixGramServer(filepath.Join(t.TempDir(), "unixgram.sock"))
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c, err = net.Dial("unixgram", addrstr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close()
if !isPacketConn(c) {
t.Error("Unix datagram connection should be a packet conn")
}
if !isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
t.Error("Unix datagram connection (wrapped type) should be a packet conn")
}

// Unix stream
s, addrstr, _, err = RunLocalUnixServer(filepath.Join(t.TempDir(), "unixstream.sock"))
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
c, err = net.Dial("unix", addrstr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close()
if isPacketConn(c) {
t.Error("Unix stream connection should not be a packet conn")
}
if isPacketConn(struct{ *net.UnixConn }{c.(*net.UnixConn)}) {
t.Error("Unix stream connection (wrapped type) should not be a packet conn")
}
}

func TestDialUDP(t *testing.T) {
HandleFunc("miek.nl.", HelloServer)
defer HandleRemove("miek.nl.")
Expand Down
76 changes: 49 additions & 27 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ func AnotherHelloServer(w ResponseWriter, req *Msg) {
w.WriteMsg(m)
}

func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", nil, err
func RunLocalServer(pc net.PacketConn, l net.Listener, opts ...func(*Server)) (*Server, string, chan error, error) {
server := &Server{
PacketConn: pc,
Listener: l,

ReadTimeout: time.Hour,
WriteTimeout: time.Hour,
}
server := &Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}

waitLock := sync.Mutex{}
waitLock.Lock()
Expand All @@ -82,18 +84,39 @@ func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
opt(server)
}

var (
addr string
closer io.Closer
)
if l != nil {
addr = l.Addr().String()
closer = l
} else {
addr = pc.LocalAddr().String()
closer = pc
}

// fin must be buffered so the goroutine below won't block
// forever if fin is never read from. This always happens
// if the channel is discarded and can happen in TestShutdownUDP.
fin := make(chan error, 1)

go func() {
fin <- server.ActivateAndServe()
pc.Close()
closer.Close()
}()

waitLock.Lock()
return server, pc.LocalAddr().String(), fin, nil
return server, addr, fin, nil
}

func RunLocalUDPServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
pc, err := net.ListenPacket("udp", laddr)
if err != nil {
return nil, "", nil, err
}

return RunLocalServer(pc, nil, opts...)
}

func RunLocalPacketConnServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
Expand All @@ -109,26 +132,7 @@ func RunLocalTCPServer(laddr string, opts ...func(*Server)) (*Server, string, ch
return nil, "", nil, err
}

server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour}

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

for _, opt := range opts {
opt(server)
}

// See the comment in RunLocalUDPServer as to why fin must be buffered.
fin := make(chan error, 1)

go func() {
fin <- server.ActivateAndServe()
l.Close()
}()

waitLock.Lock()
return server, l.Addr().String(), fin, nil
return RunLocalServer(nil, l, opts...)
}

func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan error, error) {
Expand All @@ -137,6 +141,24 @@ func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, chan
})
}

func RunLocalUnixServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
l, err := net.Listen("unix", laddr)
if err != nil {
return nil, "", nil, err
}

return RunLocalServer(nil, l, opts...)
}

func RunLocalUnixGramServer(laddr string, opts ...func(*Server)) (*Server, string, chan error, error) {
pc, err := net.ListenPacket("unixgram", laddr)
if err != nil {
return nil, "", nil, err
}

return RunLocalServer(pc, nil, opts...)
}

func TestServing(t *testing.T) {
for _, tc := range []struct {
name string
Expand Down

0 comments on commit 0544c8b

Please sign in to comment.