Skip to content

Commit

Permalink
net: avoid nil pointer dereference when RemoteAddr.String method chai…
Browse files Browse the repository at this point in the history
…n is called

Fixes #3721.

R=dave, rsc
CC=golang-dev
https://golang.org/cl/6395055
  • Loading branch information
cixtor committed Aug 23, 2012
1 parent e80f6a4 commit 6cf77f2
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 31 deletions.
5 changes: 2 additions & 3 deletions src/pkg/net/fd.go
Expand Up @@ -612,11 +612,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
syscall.ForkLock.RUnlock()

if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
syscall.Close(s)
closesocket(s)
return nil, err
}
lsa, _ := syscall.Getsockname(netfd.sysfd)
netfd.setAddr(toAddr(lsa), toAddr(rsa))
netfd.setAddr(localSockname(fd, toAddr), toAddr(rsa))
return netfd, nil
}

Expand Down
14 changes: 6 additions & 8 deletions src/pkg/net/file.go
Expand Up @@ -25,8 +25,8 @@ func newFileFD(f *os.File) (*netFD, error) {

family := syscall.AF_UNSPEC
toAddr := sockaddrToTCP
sa, _ := syscall.Getsockname(fd)
switch sa.(type) {
lsa, _ := syscall.Getsockname(fd)
switch lsa.(type) {
default:
closesocket(fd)
return nil, syscall.EINVAL
Expand All @@ -53,16 +53,14 @@ func newFileFD(f *os.File) (*netFD, error) {
toAddr = sockaddrToUnixpacket
}
}
laddr := toAddr(sa)
sa, _ = syscall.Getpeername(fd)
raddr := toAddr(sa)
laddr := toAddr(lsa)

netfd, err := newFD(fd, family, sotype, laddr.Network())
if err != nil {
closesocket(fd)
return nil, err
}
netfd.setAddr(laddr, raddr)
netfd.setAddr(laddr, remoteSockname(netfd, toAddr))
return netfd, nil
}

Expand All @@ -80,10 +78,10 @@ func FileConn(f *os.File) (c Conn, err error) {
return newTCPConn(fd), nil
case *UDPAddr:
return newUDPConn(fd), nil
case *UnixAddr:
return newUnixConn(fd), nil
case *IPAddr:
return newIPConn(fd), nil
case *UnixAddr:
return newUnixConn(fd), nil
}
fd.Close()
return nil, syscall.EINVAL
Expand Down
51 changes: 50 additions & 1 deletion src/pkg/net/ipraw_test.go
Expand Up @@ -14,6 +14,55 @@ import (
"time"
)

var ipConnAddrStringTests = []struct {
net string
laddr string
raddr string
ipv6 bool
}{
{"ip:icmp", "127.0.0.1", "", false},
{"ip:icmp", "::1", "", true},
{"ip:icmp", "", "127.0.0.1", false},
{"ip:icmp", "", "::1", true},
}

func TestIPConnAddrString(t *testing.T) {
if os.Getuid() != 0 {
t.Logf("skipping test; must be root")
return
}

for i, tt := range ipConnAddrStringTests {
if tt.ipv6 && !supportsIPv6 {
continue
}
var (
err error
c *IPConn
mode string
)
if tt.raddr == "" {
mode = "listen"
la, _ := ResolveIPAddr(tt.net, tt.laddr)
c, err = ListenIP(tt.net, la)
if err != nil {
t.Fatalf("ListenIP(%q, %q) failed: %v", tt.net, la.String(), err)
}
} else {
mode = "dial"
la, _ := ResolveIPAddr(tt.net, tt.laddr)
ra, _ := ResolveIPAddr(tt.net, tt.raddr)
c, err = DialIP(tt.net, la, ra)
if err != nil {
t.Fatalf("DialIP(%q, %q) failed: %v", tt.net, ra.String(), err)
}
}
t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
c.Close()
}
}

var icmpTests = []struct {
net string
laddr string
Expand All @@ -26,7 +75,7 @@ var icmpTests = []struct {

func TestICMP(t *testing.T) {
if os.Getuid() != 0 {
t.Logf("test disabled; must be root")
t.Logf("skipping test; must be root")
return
}

Expand Down
67 changes: 48 additions & 19 deletions src/pkg/net/sock.go
Expand Up @@ -16,7 +16,7 @@ import (
var listenerBacklog = maxListenerBacklog()

// Generic socket creation.
func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
// See ../syscall/exec.go for description of ForkLock.
syscall.ForkLock.RLock()
s, err := syscall.Socket(f, t, p)
Expand All @@ -27,21 +27,18 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
syscall.CloseOnExec(s)
syscall.ForkLock.RUnlock()

err = setDefaultSockopts(s, f, t, ipv6only)
if err != nil {
if err = setDefaultSockopts(s, f, t, ipv6only); err != nil {
closesocket(s)
return nil, err
}

var bla syscall.Sockaddr
if la != nil {
bla, err = listenerSockaddr(s, f, la, toAddr)
if err != nil {
var blsa syscall.Sockaddr
if ulsa != nil {
if blsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil {
closesocket(s)
return nil, err
}
err = syscall.Bind(s, bla)
if err != nil {
if err = syscall.Bind(s, blsa); err != nil {
closesocket(s)
return nil, err
}
Expand All @@ -52,26 +49,22 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
return nil, err
}

if ra != nil {
if err = fd.connect(ra); err != nil {
if ursa != nil {
if err = fd.connect(ursa); err != nil {
closesocket(s)
fd.Close()
return nil, err
}
fd.isConnected = true
}

sa, _ := syscall.Getsockname(s)
var laddr Addr
if la != nil && bla != la {
laddr = toAddr(la)
if ulsa != nil && blsa != ulsa {
laddr = toAddr(ulsa)
} else {
laddr = toAddr(sa)
laddr = localSockname(fd, toAddr)
}
sa, _ = syscall.Getpeername(s)
raddr := toAddr(sa)

fd.setAddr(laddr, raddr)
fd.setAddr(laddr, remoteSockname(fd, toAddr))
return fd, nil
}

Expand All @@ -85,3 +78,39 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
// Use wrapper to hide existing r.ReadFrom from io.Copy.
return io.Copy(writerOnly{w}, r)
}

func localSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
sa, _ := syscall.Getsockname(fd.sysfd)
if sa == nil {
return nullProtocolAddr(fd.family, fd.sotype)
}
return toAddr(sa)
}

func remoteSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
sa, _ := syscall.Getpeername(fd.sysfd)
if sa == nil {
return nullProtocolAddr(fd.family, fd.sotype)
}
return toAddr(sa)
}

func nullProtocolAddr(f, t int) Addr {
switch f {
case syscall.AF_INET, syscall.AF_INET6:
switch t {
case syscall.SOCK_STREAM:
return (*TCPAddr)(nil)
case syscall.SOCK_DGRAM:
return (*UDPAddr)(nil)
case syscall.SOCK_RAW:
return (*IPAddr)(nil)
}
case syscall.AF_UNIX:
switch t {
case syscall.SOCK_STREAM, syscall.SOCK_DGRAM, syscall.SOCK_SEQPACKET:
return (*UnixAddr)(nil)
}
}
panic("unreachable")
}
27 changes: 27 additions & 0 deletions src/pkg/net/udp_test.go
Expand Up @@ -9,6 +9,33 @@ import (
"testing"
)

var udpConnAddrStringTests = []struct {
net string
laddr string
raddr string
ipv6 bool
}{
{"udp", "127.0.0.1:0", "", false},
{"udp", "[::1]:0", "", true},
}

func TestUDPConnAddrString(t *testing.T) {
for i, tt := range udpConnAddrStringTests {
if tt.ipv6 && !supportsIPv6 {
continue
}
mode := "listen"
la, _ := ResolveUDPAddr(tt.net, tt.laddr)
c, err := ListenUDP(tt.net, la)
if err != nil {
t.Fatalf("ListenUDP(%q, %q) failed: %v", tt.net, la.String(), err)
}
t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
c.Close()
}
}

func TestWriteToUDP(t *testing.T) {
switch runtime.GOOS {
case "plan9":
Expand Down

0 comments on commit 6cf77f2

Please sign in to comment.