Skip to content

Commit

Permalink
net: implement ip protocol name to number resolver for windows
Browse files Browse the repository at this point in the history
Fixes #2215.
Fixes #2216.

R=golang-dev, dave, rsc
CC=golang-dev
https://golang.org/cl/5248055
  • Loading branch information
alexbrainman committed Oct 11, 2011
1 parent d69b820 commit 059c68b
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 37 deletions.
3 changes: 2 additions & 1 deletion src/pkg/net/ipraw_test.go
Expand Up @@ -11,6 +11,7 @@ import (
"bytes"
"flag"
"os"
"runtime"
"testing"
)

Expand Down Expand Up @@ -64,7 +65,7 @@ var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO

// test (raw) IP socket using ICMP
func TestICMP(t *testing.T) {
if os.Getuid() != 0 {
if runtime.GOOS != "windows" && os.Getuid() != 0 {
t.Logf("test disabled; must be root")
return
}
Expand Down
38 changes: 4 additions & 34 deletions src/pkg/net/iprawsock_posix.go
Expand Up @@ -10,12 +10,9 @@ package net

import (
"os"
"sync"
"syscall"
)

var onceReadProtocols sync.Once

func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
Expand Down Expand Up @@ -209,33 +206,7 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
return c.WriteToIP(b, a)
}

var protocols map[string]int

func readProtocols() {
protocols = make(map[string]int)
if file, err := open("/etc/protocols"); err == nil {
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// tcp 6 TCP # transmission control protocol
if i := byteIndex(line, '#'); i >= 0 {
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
if proto, _, ok := dtoi(f[1], 0); ok {
protocols[f[0]] = proto
for _, alias := range f[2:] {
protocols[alias] = proto
}
}
}
file.close()
}
}

func splitNetProto(netProto string) (net string, proto int, err os.Error) {
onceReadProtocols.Do(readProtocols)
i := last(netProto, ':')
if i < 0 { // no colon
return "", 0, os.NewError("no IP protocol specified")
Expand All @@ -244,13 +215,12 @@ func splitNetProto(netProto string) (net string, proto int, err os.Error) {
protostr := netProto[i+1:]
proto, i, ok := dtoi(protostr, 0)
if !ok || i != len(protostr) {
// lookup by name
proto, ok = protocols[protostr]
if ok {
return
proto, err = lookupProtocol(protostr)
if err != nil {
return "", 0, err
}
}
return
return net, proto, nil
}

// DialIP connects to the remote address raddr on the network net,
Expand Down
42 changes: 42 additions & 0 deletions src/pkg/net/lookup_unix.go
Expand Up @@ -8,8 +8,50 @@ package net

import (
"os"
"sync"
)

var (
protocols map[string]int
onceReadProtocols sync.Once
)

// readProtocols loads contents of /etc/protocols into protocols map
// for quick access.
func readProtocols() {
protocols = make(map[string]int)
if file, err := open("/etc/protocols"); err == nil {
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// tcp 6 TCP # transmission control protocol
if i := byteIndex(line, '#'); i >= 0 {
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
if proto, _, ok := dtoi(f[1], 0); ok {
protocols[f[0]] = proto
for _, alias := range f[2:] {
protocols[alias] = proto
}
}
}
file.close()
}
}

// lookupProtocol looks up IP protocol name in /etc/protocols and
// returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err os.Error) {
onceReadProtocols.Do(readProtocols)
proto, found := protocols[name]
if !found {
return 0, os.NewError("unknown IP protocol specified: " + name)
}
return
}

// LookupHost looks up the given host using the local resolver.
// It returns an array of that host's addresses.
func LookupHost(host string) (addrs []string, err os.Error) {
Expand Down
18 changes: 16 additions & 2 deletions src/pkg/net/lookup_windows.go
Expand Up @@ -11,8 +11,22 @@ import (
"sync"
)

var hostentLock sync.Mutex
var serventLock sync.Mutex
var (
protoentLock sync.Mutex
hostentLock sync.Mutex
serventLock sync.Mutex
)

// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err os.Error) {
protoentLock.Lock()
defer protoentLock.Unlock()
p, e := syscall.GetProtoByName(name)
if e != 0 {
return 0, os.NewSyscallError("GetProtoByName", e)
}
return int(p.Proto), nil
}

func LookupHost(name string) (addrs []string, err os.Error) {
ips, err := LookupIP(name)
Expand Down
1 change: 1 addition & 0 deletions src/pkg/syscall/syscall_windows.go
Expand Up @@ -502,6 +502,7 @@ func Chmod(path string, mode uint32) (errno int) {
//sys GetHostByName(name string) (h *Hostent, errno int) [failretval==nil] = ws2_32.gethostbyname
//sys GetServByName(name string, proto string) (s *Servent, errno int) [failretval==nil] = ws2_32.getservbyname
//sys Ntohs(netshort uint16) (u uint16) = ws2_32.ntohs
//sys GetProtoByName(name string) (p *Protoent, errno int) [failretval==nil] = ws2_32.getprotobyname
//sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) = dnsapi.DnsQuery_W
//sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
//sys GetIfEntry(pIfRow *MibIfRow) (errcode int) = iphlpapi.GetIfEntry
Expand Down
16 changes: 16 additions & 0 deletions src/pkg/syscall/zsyscall_windows_386.go
Expand Up @@ -101,6 +101,7 @@ var (
procgethostbyname = modws2_32.NewProc("gethostbyname")
procgetservbyname = modws2_32.NewProc("getservbyname")
procntohs = modws2_32.NewProc("ntohs")
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
Expand Down Expand Up @@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
return
}

func GetProtoByName(name string) (p *Protoent, errno int) {
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
p = (*Protoent)(unsafe.Pointer(r0))
if p == nil {
if e1 != 0 {
errno = int(e1)
} else {
errno = EINVAL
}
} else {
errno = 0
}
return
}

func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
status = uint32(r0)
Expand Down
16 changes: 16 additions & 0 deletions src/pkg/syscall/zsyscall_windows_amd64.go
Expand Up @@ -101,6 +101,7 @@ var (
procgethostbyname = modws2_32.NewProc("gethostbyname")
procgetservbyname = modws2_32.NewProc("getservbyname")
procntohs = modws2_32.NewProc("ntohs")
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
Expand Down Expand Up @@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
return
}

func GetProtoByName(name string) (p *Protoent, errno int) {
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
p = (*Protoent)(unsafe.Pointer(r0))
if p == nil {
if e1 != 0 {
errno = int(e1)
} else {
errno = EINVAL
}
} else {
errno = 0
}
return
}

func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
status = uint32(r0)
Expand Down
6 changes: 6 additions & 0 deletions src/pkg/syscall/ztypes_windows.go
Expand Up @@ -411,6 +411,12 @@ type Hostent struct {
AddrList **byte
}

type Protoent struct {
Name *byte
Aliases **byte
Proto uint16
}

const (
DNS_TYPE_A = 0x0001
DNS_TYPE_NS = 0x0002
Expand Down

0 comments on commit 059c68b

Please sign in to comment.