diff --git a/net/convert.go b/net/convert.go new file mode 100644 index 0000000..07057e8 --- /dev/null +++ b/net/convert.go @@ -0,0 +1,330 @@ +package manet + +import ( + "fmt" + "net" + "path/filepath" + "runtime" + "strings" + + ma "github.com/multiformats/go-multiaddr" +) + +var errIncorrectNetAddr = fmt.Errorf("incorrect network addr conversion") +var errNotIP = fmt.Errorf("multiaddr does not start with an IP address") + +// FromNetAddr converts a net.Addr type to a Multiaddr. +func FromNetAddr(a net.Addr) (ma.Multiaddr, error) { + return defaultCodecs.FromNetAddr(a) +} + +// FromNetAddr converts a net.Addr to Multiaddress. +func (cm *CodecMap) FromNetAddr(a net.Addr) (ma.Multiaddr, error) { + if a == nil { + return nil, fmt.Errorf("nil multiaddr") + } + p, err := cm.getAddrParser(a.Network()) + if err != nil { + return nil, err + } + + return p(a) +} + +// ToNetAddr converts a Multiaddr to a net.Addr +// Must be ThinWaist. acceptable protocol stacks are: +// /ip{4,6}/{tcp, udp} +func ToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { + return defaultCodecs.ToNetAddr(maddr) +} + +// ToNetAddr converts a Multiaddress to a standard net.Addr. +func (cm *CodecMap) ToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { + protos := maddr.Protocols() + final := protos[len(protos)-1] + + p, err := cm.getMaddrParser(final.Name) + if err != nil { + return nil, err + } + + return p(maddr) +} + +func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) { + network, host, err := DialArgs(maddr) + if err != nil { + return nil, err + } + + switch network { + case "tcp", "tcp4", "tcp6": + return net.ResolveTCPAddr(network, host) + case "udp", "udp4", "udp6": + return net.ResolveUDPAddr(network, host) + case "ip", "ip4", "ip6": + return net.ResolveIPAddr(network, host) + case "unix": + return net.ResolveUnixAddr(network, host) + } + + return nil, fmt.Errorf("network not supported: %s", network) +} + +func FromIPAndZone(ip net.IP, zone string) (ma.Multiaddr, error) { + switch { + case ip.To4() != nil: + return ma.NewComponent("ip4", ip.String()) + case ip.To16() != nil: + ip6, err := ma.NewComponent("ip6", ip.String()) + if err != nil { + return nil, err + } + if zone == "" { + return ip6, nil + } else { + zone, err := ma.NewComponent("ip6zone", zone) + if err != nil { + return nil, err + } + return zone.Encapsulate(ip6), nil + } + default: + return nil, errIncorrectNetAddr + } +} + +// FromIP converts a net.IP type to a Multiaddr. +func FromIP(ip net.IP) (ma.Multiaddr, error) { + return FromIPAndZone(ip, "") +} + +// ToIP converts a Multiaddr to a net.IP when possible +func ToIP(addr ma.Multiaddr) (net.IP, error) { + var ip net.IP + ma.ForEach(addr, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_IP6ZONE: + // we can't return these anyways. + return true + case ma.P_IP6, ma.P_IP4: + ip = net.IP(c.RawValue()) + return false + } + return false + }) + if ip == nil { + return nil, errNotIP + } + return ip, nil +} + +// DialArgs is a convenience function that returns network and address as +// expected by net.Dial. See https://godoc.org/net#Dial for an overview of +// possible return values (we do not support the unixpacket ones yet). Unix +// addresses do not, at present, compose. +func DialArgs(m ma.Multiaddr) (string, string, error) { + zone, network, ip, port, hostname, err := dialArgComponents(m) + if err != nil { + return "", "", err + } + + // If we have a hostname (dns*), we don't want any fancy ipv6 formatting + // logic (zone, brackets, etc.). + if hostname { + switch network { + case "ip", "ip4", "ip6": + return network, ip, nil + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": + return network, ip + ":" + port, nil + } + // Hostname is only true when network is one of the above. + panic("unreachable") + } + + switch network { + case "ip6": + if zone != "" { + ip += "%" + zone + } + fallthrough + case "ip4": + return network, ip, nil + case "tcp4", "udp4": + return network, ip + ":" + port, nil + case "tcp6", "udp6": + if zone != "" { + ip += "%" + zone + } + return network, "[" + ip + "]" + ":" + port, nil + case "unix": + if runtime.GOOS == "windows" { + // convert /c:/... to c:\... + ip = filepath.FromSlash(strings.TrimLeft(ip, "/")) + } + return network, ip, nil + default: + return "", "", fmt.Errorf("%s is not a 'thin waist' address", m) + } +} + +// dialArgComponents extracts the raw pieces used in dialing a Multiaddr +func dialArgComponents(m ma.Multiaddr) (zone, network, ip, port string, hostname bool, err error) { + ma.ForEach(m, func(c ma.Component) bool { + switch network { + case "": + switch c.Protocol().Code { + case ma.P_IP6ZONE: + if zone != "" { + err = fmt.Errorf("%s has multiple zones", m) + return false + } + zone = c.Value() + return true + case ma.P_IP6: + network = "ip6" + ip = c.Value() + return true + case ma.P_IP4: + if zone != "" { + err = fmt.Errorf("%s has ip4 with zone", m) + return false + } + network = "ip4" + ip = c.Value() + return true + case ma.P_DNS: + network = "ip" + hostname = true + ip = c.Value() + return true + case ma.P_DNS4: + network = "ip4" + hostname = true + ip = c.Value() + return true + case ma.P_DNS6: + network = "ip6" + hostname = true + ip = c.Value() + return true + case ma.P_UNIX: + network = "unix" + ip = c.Value() + return false + } + case "ip": + switch c.Protocol().Code { + case ma.P_UDP: + network = "udp" + case ma.P_TCP: + network = "tcp" + default: + return false + } + port = c.Value() + case "ip4": + switch c.Protocol().Code { + case ma.P_UDP: + network = "udp4" + case ma.P_TCP: + network = "tcp4" + default: + return false + } + port = c.Value() + case "ip6": + switch c.Protocol().Code { + case ma.P_UDP: + network = "udp6" + case ma.P_TCP: + network = "tcp6" + default: + return false + } + port = c.Value() + } + // Done. + return false + }) + return +} + +func parseTCPNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.TCPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIPAndZone(ac.IP, ac.Zone) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get TCP Addr + tcpm, err := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(tcpm), nil +} + +func parseUDPNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.UDPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIPAndZone(ac.IP, ac.Zone) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + udpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(udpm), nil +} + +func parseIPNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIPAndZone(ac.IP, ac.Zone) +} + +func parseIPPlusNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPNet) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIP(ac.IP) +} + +func parseUnixNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.UnixAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + path := ac.Name + if runtime.GOOS == "windows" { + // Convert c:\foobar\... to c:/foobar/... + path = filepath.ToSlash(path) + } + if len(path) == 0 || path[0] != '/' { + // convert "" and "c:/..." to "/..." + path = "/" + path + } + + return ma.NewComponent("unix", path) +} diff --git a/net/convert_test.go b/net/convert_test.go new file mode 100644 index 0000000..1701824 --- /dev/null +++ b/net/convert_test.go @@ -0,0 +1,204 @@ +package manet + +import ( + "net" + "runtime" + "testing" + + ma "github.com/multiformats/go-multiaddr" +) + +type GenFunc func() (ma.Multiaddr, error) + +func testConvert(t *testing.T, s string, gen GenFunc) { + m, err := gen() + if err != nil { + t.Fatal("failed to generate.") + } + + if s2 := m.String(); err != nil || s2 != s { + t.Fatal("failed to convert: " + s + " != " + s2) + } +} + +func testToNetAddr(t *testing.T, maddr, ntwk, addr string) { + m, err := ma.NewMultiaddr(maddr) + if err != nil { + t.Fatal("failed to generate.") + } + + naddr, err := ToNetAddr(m) + if addr == "" { // should fail + if err == nil { + t.Fatalf("failed to error: %s", m) + } + return + } + + // shouldn't fail + if err != nil { + t.Fatalf("failed to convert to net addr: %s", m) + } + + if naddr.String() != addr { + t.Fatalf("naddr.Address() == %s != %s", naddr, addr) + } + + if naddr.Network() != ntwk { + t.Fatalf("naddr.Network() == %s != %s", naddr.Network(), ntwk) + } + + // should convert properly + switch ntwk { + case "tcp": + taddr := naddr.(*net.TCPAddr) + if ip, err := ToIP(m); err != nil || !taddr.IP.Equal(ip) { + t.Fatalf("ToIP() and ToNetAddr diverged: %s != %s", taddr, ip) + } + case "udp": + uaddr := naddr.(*net.UDPAddr) + if ip, err := ToIP(m); err != nil || !uaddr.IP.Equal(ip) { + t.Fatalf("ToIP() and ToNetAddr diverged: %s != %s", uaddr, ip) + } + case "ip": + ipaddr := naddr.(*net.IPAddr) + if ip, err := ToIP(m); err != nil || !ipaddr.IP.Equal(ip) { + t.Fatalf("ToIP() and ToNetAddr diverged: %s != %s", ipaddr, ip) + } + } +} + +func TestFromIP4(t *testing.T) { + testConvert(t, "/ip4/10.20.30.40", func() (ma.Multiaddr, error) { + return FromNetAddr(&net.IPAddr{IP: net.ParseIP("10.20.30.40")}) + }) +} + +func TestFromUnix(t *testing.T) { + path := "/C:/foo/bar" + if runtime.GOOS == "windows" { + path = `C:\foo\bar` + } + testConvert(t, "/unix/C:/foo/bar", func() (ma.Multiaddr, error) { + return FromNetAddr(&net.UnixAddr{Name: path, Net: "unix"}) + }) +} + +func TestToUnix(t *testing.T) { + path := "/C:/foo/bar" + if runtime.GOOS == "windows" { + path = `C:\foo\bar` + } + testToNetAddr(t, "/unix/C:/foo/bar", "unix", path) +} + +func TestFromIP6(t *testing.T) { + testConvert(t, "/ip6/2001:4860:0:2001::68", func() (ma.Multiaddr, error) { + return FromNetAddr(&net.IPAddr{IP: net.ParseIP("2001:4860:0:2001::68")}) + }) +} + +func TestFromTCP(t *testing.T) { + testConvert(t, "/ip4/10.20.30.40/tcp/1234", func() (ma.Multiaddr, error) { + return FromNetAddr(&net.TCPAddr{ + IP: net.ParseIP("10.20.30.40"), + Port: 1234, + }) + }) +} + +func TestFromUDP(t *testing.T) { + testConvert(t, "/ip4/10.20.30.40/udp/1234", func() (ma.Multiaddr, error) { + return FromNetAddr(&net.UDPAddr{ + IP: net.ParseIP("10.20.30.40"), + Port: 1234, + }) + }) +} + +func TestThinWaist(t *testing.T) { + addrs := map[string]bool{ + "/ip4/127.0.0.1/udp/1234": true, + "/ip4/127.0.0.1/tcp/1234": true, + "/ip4/127.0.0.1/udp/1234/tcp/1234": true, + "/ip4/127.0.0.1/tcp/12345/ip4/1.2.3.4": true, + "/ip6/::1/tcp/80": true, + "/ip6/::1/udp/80": true, + "/ip6/::1": true, + "/ip6zone/hello/ip6/fe80::1/tcp/80": true, + "/ip6zone/hello/ip6/fe80::1": true, + "/tcp/1234/ip4/1.2.3.4": false, + "/tcp/1234": false, + "/tcp/1234/udp/1234": false, + "/ip4/1.2.3.4/ip4/2.3.4.5": true, + "/ip6/fe80::1/ip4/2.3.4.5": true, + "/ip6zone/hello/ip6/fe80::1/ip4/2.3.4.5": true, + + // Invalid ip6zone usage: + "/ip6zone/hello": false, + "/ip6zone/hello/ip4/1.1.1.1": false, + } + + for a, res := range addrs { + m, err := ma.NewMultiaddr(a) + if err != nil { + t.Fatalf("failed to construct Multiaddr: %s", a) + } + + if IsThinWaist(m) != res { + t.Fatalf("IsThinWaist(%s) != %v", a, res) + } + } +} + +func TestDialArgs(t *testing.T) { + test := func(e_maddr, e_nw, e_host string) { + m, err := ma.NewMultiaddr(e_maddr) + if err != nil { + t.Fatal("failed to construct", e_maddr) + } + + nw, host, err := DialArgs(m) + if err != nil { + t.Fatal("failed to get dial args", e_maddr, m, err) + } + + if nw != e_nw { + t.Error("failed to get udp network Dial Arg", e_nw, nw) + } + + if host != e_host { + t.Error("failed to get host:port Dial Arg", e_host, host) + } + } + + test_error := func(e_maddr string) { + m, err := ma.NewMultiaddr(e_maddr) + if err != nil { + t.Fatal("failed to construct", e_maddr) + } + + _, _, err = DialArgs(m) + if err == nil { + t.Fatal("expected DialArgs to fail on", e_maddr) + } + } + + test("/ip4/127.0.0.1/udp/1234", "udp4", "127.0.0.1:1234") + test("/ip4/127.0.0.1/tcp/4321", "tcp4", "127.0.0.1:4321") + test("/ip6/::1/udp/1234", "udp6", "[::1]:1234") + test("/ip6/::1/tcp/4321", "tcp6", "[::1]:4321") + test("/ip6/::1", "ip6", "::1") // Just an IP + test("/ip4/1.2.3.4", "ip4", "1.2.3.4") // Just an IP + test("/ip6zone/foo/ip6/::1/tcp/4321", "tcp6", "[::1%foo]:4321") // zone + test("/ip6zone/foo/ip6/::1/udp/4321", "udp6", "[::1%foo]:4321") // zone + test("/ip6zone/foo/ip6/::1", "ip6", "::1%foo") // no TCP + test_error("/ip6zone/foo/ip4/127.0.0.1") // IP4 doesn't take zone + test("/ip6zone/foo/ip6/::1/ip6zone/bar", "ip6", "::1%foo") // IP over IP + test_error("/ip6zone/foo/ip6zone/bar/ip6/::1") // Only one zone per IP6 + test("/dns/abc.com/tcp/1234", "tcp", "abc.com:1234") // DNS4:port + test("/dns4/abc.com/tcp/1234", "tcp4", "abc.com:1234") // DNS4:port + test("/dns4/abc.com", "ip4", "abc.com") // Just DNS4 + test("/dns6/abc.com/udp/1234", "udp6", "abc.com:1234") // DNS6:port + test("/dns6/abc.com", "ip6", "abc.com") // Just DNS6 +} diff --git a/net/doc.go b/net/doc.go new file mode 100644 index 0000000..040ad3f --- /dev/null +++ b/net/doc.go @@ -0,0 +1,5 @@ +// Package manet provides Multiaddr specific versions of common +// functions in stdlib's net package. This means wrappers of +// standard net symbols like net.Dial and net.Listen, as well +// as conversion to/from net.Addr. +package manet diff --git a/net/ip.go b/net/ip.go new file mode 100644 index 0000000..1cf9a77 --- /dev/null +++ b/net/ip.go @@ -0,0 +1,118 @@ +package manet + +import ( + "net" + + ma "github.com/multiformats/go-multiaddr" +) + +// Loopback Addresses +var ( + // IP4Loopback is the ip4 loopback multiaddr + IP4Loopback = ma.StringCast("/ip4/127.0.0.1") + + // IP6Loopback is the ip6 loopback multiaddr + IP6Loopback = ma.StringCast("/ip6/::1") + + // IP4MappedIP6Loopback is the IPv4 Mapped IPv6 loopback address. + IP4MappedIP6Loopback = ma.StringCast("/ip6/::ffff:127.0.0.1") +) + +// Unspecified Addresses (used for ) +var ( + IP4Unspecified = ma.StringCast("/ip4/0.0.0.0") + IP6Unspecified = ma.StringCast("/ip6/::") +) + +// IsThinWaist returns whether a Multiaddr starts with "Thin Waist" Protocols. +// This means: /{IP4, IP6}[/{TCP, UDP}] +func IsThinWaist(m ma.Multiaddr) bool { + m = zoneless(m) + if m == nil { + return false + } + p := m.Protocols() + + // nothing? not even a waist. + if len(p) == 0 { + return false + } + + if p[0].Code != ma.P_IP4 && p[0].Code != ma.P_IP6 { + return false + } + + // only IP? still counts. + if len(p) == 1 { + return true + } + + switch p[1].Code { + case ma.P_TCP, ma.P_UDP, ma.P_IP4, ma.P_IP6: + return true + default: + return false + } +} + +// IsIPLoopback returns whether a Multiaddr starts with a "Loopback" IP address +// This means either /ip4/127.*.*.*/*, /ip6/::1/*, or /ip6/::ffff:127.*.*.*.*/*, +// or /ip6zone//ip6//* +func IsIPLoopback(m ma.Multiaddr) bool { + m = zoneless(m) + c, _ := ma.SplitFirst(m) + if c == nil { + return false + } + switch c.Protocol().Code { + case ma.P_IP4, ma.P_IP6: + return net.IP(c.RawValue()).IsLoopback() + } + return false +} + +// IsIP6LinkLocal returns whether a Multiaddr starts with an IPv6 link-local +// multiaddress (with zero or one leading zone). These addresses are non +// routable. +func IsIP6LinkLocal(m ma.Multiaddr) bool { + m = zoneless(m) + c, _ := ma.SplitFirst(m) + if c == nil || c.Protocol().Code != ma.P_IP6 { + return false + } + ip := net.IP(c.RawValue()) + return ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() +} + +// IsIPUnspecified returns whether a Multiaddr starts with an Unspecified IP address +// This means either /ip4/0.0.0.0/* or /ip6/::/* +func IsIPUnspecified(m ma.Multiaddr) bool { + m = zoneless(m) + if m == nil { + return false + } + c, _ := ma.SplitFirst(m) + return net.IP(c.RawValue()).IsUnspecified() +} + +// If m matches [zone,ip6,...], return [ip6,...] +// else if m matches [], [zone], or [zone,...], return nil +// else return m +func zoneless(m ma.Multiaddr) ma.Multiaddr { + head, tail := ma.SplitFirst(m) + if head == nil { + return nil + } + if head.Protocol().Code == ma.P_IP6ZONE { + if tail == nil { + return nil + } + tailhead, _ := ma.SplitFirst(tail) + if tailhead.Protocol().Code != ma.P_IP6 { + return nil + } + return tail + } else { + return m + } +} diff --git a/net/net.go b/net/net.go new file mode 100644 index 0000000..95439af --- /dev/null +++ b/net/net.go @@ -0,0 +1,422 @@ +// Package manet provides Multiaddr +// (https://github.com/multiformats/go-multiaddr) specific versions of common +// functions in Go's standard `net` package. This means wrappers of standard +// net symbols like `net.Dial` and `net.Listen`, as well as conversion to +// and from `net.Addr`. +package manet + +import ( + "context" + "fmt" + "net" + + ma "github.com/multiformats/go-multiaddr" +) + +// Conn is the equivalent of a net.Conn object. It is the +// result of calling the Dial or Listen functions in this +// package, with associated local and remote Multiaddrs. +type Conn interface { + net.Conn + + // LocalMultiaddr returns the local Multiaddr associated + // with this connection + LocalMultiaddr() ma.Multiaddr + + // RemoteMultiaddr returns the remote Multiaddr associated + // with this connection + RemoteMultiaddr() ma.Multiaddr +} + +type halfOpen interface { + net.Conn + CloseRead() error + CloseWrite() error +} + +func wrap(nconn net.Conn, laddr, raddr ma.Multiaddr) Conn { + endpts := maEndpoints{ + laddr: laddr, + raddr: raddr, + } + // This sucks. However, it's the only way to reliably expose the + // underlying methods. This way, users that need access to, e.g., + // CloseRead and CloseWrite, can do so via type assertions. + switch nconn := nconn.(type) { + case *net.TCPConn: + return &struct { + *net.TCPConn + maEndpoints + }{nconn, endpts} + case *net.UDPConn: + return &struct { + *net.UDPConn + maEndpoints + }{nconn, endpts} + case *net.IPConn: + return &struct { + *net.IPConn + maEndpoints + }{nconn, endpts} + case *net.UnixConn: + return &struct { + *net.UnixConn + maEndpoints + }{nconn, endpts} + case halfOpen: + return &struct { + halfOpen + maEndpoints + }{nconn, endpts} + default: + return &struct { + net.Conn + maEndpoints + }{nconn, endpts} + } +} + +// WrapNetConn wraps a net.Conn object with a Multiaddr friendly Conn. +// +// This function does it's best to avoid "hiding" methods exposed by the wrapped +// type. Guarantees: +// +// * If the wrapped connection exposes the "half-open" closer methods +// (CloseWrite, CloseRead), these will be available on the wrapped connection +// via type assertions. +// * If the wrapped connection is a UnixConn, IPConn, TCPConn, or UDPConn, all +// methods on these wrapped connections will be available via type assertions. +func WrapNetConn(nconn net.Conn) (Conn, error) { + if nconn == nil { + return nil, fmt.Errorf("failed to convert nconn.LocalAddr: nil") + } + + laddr, err := FromNetAddr(nconn.LocalAddr()) + if err != nil { + return nil, fmt.Errorf("failed to convert nconn.LocalAddr: %s", err) + } + + raddr, err := FromNetAddr(nconn.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("failed to convert nconn.RemoteAddr: %s", err) + } + + return wrap(nconn, laddr, raddr), nil +} + +type maEndpoints struct { + laddr ma.Multiaddr + raddr ma.Multiaddr +} + +// LocalMultiaddr returns the local address associated with +// this connection +func (c *maEndpoints) LocalMultiaddr() ma.Multiaddr { + return c.laddr +} + +// RemoteMultiaddr returns the remote address associated with +// this connection +func (c *maEndpoints) RemoteMultiaddr() ma.Multiaddr { + return c.raddr +} + +// Dialer contains options for connecting to an address. It +// is effectively the same as net.Dialer, but its LocalAddr +// and RemoteAddr options are Multiaddrs, instead of net.Addrs. +type Dialer struct { + + // Dialer is just an embedded net.Dialer, with all its options. + net.Dialer + + // LocalAddr is the local address to use when dialing an + // address. The address must be of a compatible type for the + // network being dialed. + // If nil, a local address is automatically chosen. + LocalAddr ma.Multiaddr +} + +// Dial connects to a remote address, using the options of the +// Dialer. Dialer uses an underlying net.Dialer to Dial a +// net.Conn, then wraps that in a Conn object (with local and +// remote Multiaddrs). +func (d *Dialer) Dial(remote ma.Multiaddr) (Conn, error) { + return d.DialContext(context.Background(), remote) +} + +// DialContext allows to provide a custom context to Dial(). +func (d *Dialer) DialContext(ctx context.Context, remote ma.Multiaddr) (Conn, error) { + // if a LocalAddr is specified, use it on the embedded dialer. + if d.LocalAddr != nil { + // convert our multiaddr to net.Addr friendly + naddr, err := ToNetAddr(d.LocalAddr) + if err != nil { + return nil, err + } + + // set the dialer's LocalAddr as naddr + d.Dialer.LocalAddr = naddr + } + + // get the net.Dial friendly arguments from the remote addr + rnet, rnaddr, err := DialArgs(remote) + if err != nil { + return nil, err + } + + // ok, Dial! + var nconn net.Conn + switch rnet { + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix": + nconn, err = d.Dialer.DialContext(ctx, rnet, rnaddr) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unrecognized network: %s", rnet) + } + + // get local address (pre-specified or assigned within net.Conn) + local := d.LocalAddr + // This block helps us avoid parsing addresses in transports (such as unix + // sockets) that don't have local addresses when dialing out. + if local == nil && nconn.LocalAddr().String() != "" { + local, err = FromNetAddr(nconn.LocalAddr()) + if err != nil { + return nil, err + } + } + return wrap(nconn, local, remote), nil +} + +// Dial connects to a remote address. It uses an underlying net.Conn, +// then wraps it in a Conn object (with local and remote Multiaddrs). +func Dial(remote ma.Multiaddr) (Conn, error) { + return (&Dialer{}).Dial(remote) +} + +// A Listener is a generic network listener for stream-oriented protocols. +// it uses an embedded net.Listener, overriding net.Listener.Accept to +// return a Conn and providing Multiaddr. +type Listener interface { + // Accept waits for and returns the next connection to the listener. + // Returns a Multiaddr friendly Conn + Accept() (Conn, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Multiaddr returns the listener's (local) Multiaddr. + Multiaddr() ma.Multiaddr + + // Addr returns the net.Listener's network address. + Addr() net.Addr +} + +type netListenerAdapter struct { + Listener +} + +func (nla *netListenerAdapter) Accept() (net.Conn, error) { + return nla.Listener.Accept() +} + +// NetListener turns this Listener into a net.Listener. +// +// * Connections returned from Accept implement multiaddr/net Conn. +// * Calling WrapNetListener on the net.Listener returned by this function will +// return the original (underlying) multiaddr/net Listener. +func NetListener(l Listener) net.Listener { + return &netListenerAdapter{l} +} + +// maListener implements Listener +type maListener struct { + net.Listener + laddr ma.Multiaddr +} + +// Accept waits for and returns the next connection to the listener. +// Returns a Multiaddr friendly Conn +func (l *maListener) Accept() (Conn, error) { + nconn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + var raddr ma.Multiaddr + // This block protects us in transports (i.e. unix sockets) that don't have + // remote addresses for inbound connections. + if nconn.RemoteAddr().String() != "" { + raddr, err = FromNetAddr(nconn.RemoteAddr()) + if err != nil { + return nil, fmt.Errorf("failed to convert conn.RemoteAddr: %s", err) + } + } + + return wrap(nconn, l.laddr, raddr), nil +} + +// Multiaddr returns the listener's (local) Multiaddr. +func (l *maListener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +// Addr returns the listener's network address. +func (l *maListener) Addr() net.Addr { + return l.Listener.Addr() +} + +// Listen announces on the local network address laddr. +// The Multiaddr must be a "ThinWaist" stream-oriented network: +// ip4/tcp, ip6/tcp, (TODO: unix, unixpacket) +// See Dial for the syntax of laddr. +func Listen(laddr ma.Multiaddr) (Listener, error) { + + // get the net.Listen friendly arguments from the remote addr + lnet, lnaddr, err := DialArgs(laddr) + if err != nil { + return nil, err + } + + nl, err := net.Listen(lnet, lnaddr) + if err != nil { + return nil, err + } + + // we want to fetch the new multiaddr from the listener, as it may + // have resolved to some other value. WrapNetListener does it for us. + return WrapNetListener(nl) +} + +// WrapNetListener wraps a net.Listener with a manet.Listener. +func WrapNetListener(nl net.Listener) (Listener, error) { + if nla, ok := nl.(*netListenerAdapter); ok { + return nla.Listener, nil + } + + laddr, err := FromNetAddr(nl.Addr()) + if err != nil { + return nil, err + } + + return &maListener{ + Listener: nl, + laddr: laddr, + }, nil +} + +// A PacketConn is a generic packet oriented network connection which uses an +// underlying net.PacketConn, wrapped with the locally bound Multiaddr. +type PacketConn interface { + net.PacketConn + + LocalMultiaddr() ma.Multiaddr + + ReadFromMultiaddr(b []byte) (int, ma.Multiaddr, error) + WriteToMultiaddr(b []byte, maddr ma.Multiaddr) (int, error) +} + +// maPacketConn implements PacketConn +type maPacketConn struct { + net.PacketConn + laddr ma.Multiaddr +} + +var _ PacketConn = (*maPacketConn)(nil) + +// LocalMultiaddr returns the bound local Multiaddr. +func (l *maPacketConn) LocalMultiaddr() ma.Multiaddr { + return l.laddr +} + +func (l *maPacketConn) ReadFromMultiaddr(b []byte) (int, ma.Multiaddr, error) { + n, addr, err := l.ReadFrom(b) + maddr, _ := FromNetAddr(addr) + return n, maddr, err +} + +func (l *maPacketConn) WriteToMultiaddr(b []byte, maddr ma.Multiaddr) (int, error) { + addr, err := ToNetAddr(maddr) + if err != nil { + return 0, err + } + return l.WriteTo(b, addr) +} + +// ListenPacket announces on the local network address laddr. +// The Multiaddr must be a packet driven network, like udp4 or udp6. +// See Dial for the syntax of laddr. +func ListenPacket(laddr ma.Multiaddr) (PacketConn, error) { + lnet, lnaddr, err := DialArgs(laddr) + if err != nil { + return nil, err + } + + pc, err := net.ListenPacket(lnet, lnaddr) + if err != nil { + return nil, err + } + + // We want to fetch the new multiaddr from the listener, as it may + // have resolved to some other value. WrapPacketConn does this. + return WrapPacketConn(pc) +} + +// WrapPacketConn wraps a net.PacketConn with a manet.PacketConn. +func WrapPacketConn(pc net.PacketConn) (PacketConn, error) { + laddr, err := FromNetAddr(pc.LocalAddr()) + if err != nil { + return nil, err + } + + return &maPacketConn{ + PacketConn: pc, + laddr: laddr, + }, nil +} + +// InterfaceMultiaddrs will return the addresses matching net.InterfaceAddrs +func InterfaceMultiaddrs() ([]ma.Multiaddr, error) { + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + + maddrs := make([]ma.Multiaddr, len(addrs)) + for i, a := range addrs { + maddrs[i], err = FromNetAddr(a) + if err != nil { + return nil, err + } + } + return maddrs, nil +} + +// AddrMatch returns the Multiaddrs that match the protocol stack on addr +func AddrMatch(match ma.Multiaddr, addrs []ma.Multiaddr) []ma.Multiaddr { + + // we should match transports entirely. + p1s := match.Protocols() + + out := make([]ma.Multiaddr, 0, len(addrs)) + for _, a := range addrs { + p2s := a.Protocols() + if len(p1s) != len(p2s) { + continue + } + + match := true + for i, p2 := range p2s { + if p1s[i].Code != p2.Code { + match = false + break + } + } + if match { + out = append(out, a) + } + } + return out +} diff --git a/net/net_test.go b/net/net_test.go new file mode 100644 index 0000000..51a71e5 --- /dev/null +++ b/net/net_test.go @@ -0,0 +1,666 @@ +package manet + +import ( + "bytes" + "fmt" + "io/ioutil" + "net" + "os" + "path/filepath" + "sync" + "testing" + "time" + + ma "github.com/multiformats/go-multiaddr" +) + +func newMultiaddr(t *testing.T, m string) ma.Multiaddr { + maddr, err := ma.NewMultiaddr(m) + if err != nil { + t.Fatal("failed to construct multiaddr:", m, err) + } + return maddr +} + +func TestDial(t *testing.T) { + + listener, err := net.Listen("tcp", "127.0.0.1:4321") + if err != nil { + t.Fatal("failed to listen") + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + + cB, err := listener.Accept() + if err != nil { + t.Fatal("failed to accept") + } + + // echo out + buf := make([]byte, 1024) + for { + _, err := cB.Read(buf) + if err != nil { + break + } + cB.Write(buf) + } + + wg.Done() + }() + + maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4321") + cA, err := Dial(maddr) + if err != nil { + t.Fatal("failed to dial") + } + + buf := make([]byte, 1024) + if _, err := cA.Write([]byte("beep boop")); err != nil { + t.Fatal("failed to write:", err) + } + + if _, err := cA.Read(buf); err != nil { + t.Fatal("failed to read:", buf, err) + } + + if !bytes.Equal(buf[:9], []byte("beep boop")) { + t.Fatal("failed to echo:", buf) + } + + maddr2 := cA.RemoteMultiaddr() + if !maddr2.Equal(maddr) { + t.Fatal("remote multiaddr not equal:", maddr, maddr2) + } + + cA.Close() + wg.Wait() +} + +func TestUnixSockets(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "manettest") + if err != nil { + t.Fatal(err) + } + path := filepath.Join(dir, "listen.sock") + maddr := newMultiaddr(t, "/unix/"+path) + + listener, err := Listen(maddr) + if err != nil { + t.Fatal(err) + } + + payload := []byte("hello") + + // listen + done := make(chan struct{}, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + t.Fatal(err) + } + if n != len(payload) { + t.Fatal("failed to read appropriate number of bytes") + } + if !bytes.Equal(buf[0:n], payload) { + t.Fatal("payload did not match") + } + done <- struct{}{} + }() + + // dial + conn, err := Dial(maddr) + if err != nil { + t.Fatal(err) + } + n, err := conn.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != len(payload) { + t.Fatal("failed to write appropriate number of bytes") + } + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for read") + } +} + +func TestListen(t *testing.T) { + + maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4322") + listener, err := Listen(maddr) + if err != nil { + t.Fatal("failed to listen") + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + + cB, err := listener.Accept() + if err != nil { + t.Fatal("failed to accept") + } + + if !cB.LocalMultiaddr().Equal(maddr) { + t.Fatal("local multiaddr not equal:", maddr, cB.LocalMultiaddr()) + } + + // echo out + buf := make([]byte, 1024) + for { + _, err := cB.Read(buf) + if err != nil { + break + } + cB.Write(buf) + } + + wg.Done() + }() + + cA, err := net.Dial("tcp", "127.0.0.1:4322") + if err != nil { + t.Fatal("failed to dial") + } + + buf := make([]byte, 1024) + if _, err := cA.Write([]byte("beep boop")); err != nil { + t.Fatal("failed to write:", err) + } + + if _, err := cA.Read(buf); err != nil { + t.Fatal("failed to read:", buf, err) + } + + if !bytes.Equal(buf[:9], []byte("beep boop")) { + t.Fatal("failed to echo:", buf) + } + + maddr2, err := FromNetAddr(cA.RemoteAddr()) + if err != nil { + t.Fatal("failed to convert", err) + } + if !maddr2.Equal(maddr) { + t.Fatal("remote multiaddr not equal:", maddr, maddr2) + } + + cA.Close() + wg.Wait() +} + +func TestListenAddrs(t *testing.T) { + + test := func(addr, resaddr string, succeed bool) { + if resaddr == "" { + resaddr = addr + } + + maddr := newMultiaddr(t, addr) + l, err := Listen(maddr) + if !succeed { + if err == nil { + t.Fatal("succeeded in listening", addr) + } + return + } + if succeed && err != nil { + t.Error("failed to listen", addr, err) + } + if l == nil { + t.Error("failed to listen", addr, succeed, err) + } + if l.Multiaddr().String() != resaddr { + t.Error("listen addr did not resolve properly", l.Multiaddr().String(), resaddr, succeed, err) + } + + if err = l.Close(); err != nil { + t.Fatal("failed to close listener", addr, err) + } + } + + test("/ip4/127.0.0.1/tcp/4324", "", true) + test("/ip4/127.0.0.1/udp/4325", "", false) + test("/ip4/127.0.0.1/udp/4326/udt", "", false) + test("/ip4/0.0.0.0/tcp/4324", "", true) + test("/ip4/0.0.0.0/udp/4325", "", false) + test("/ip4/0.0.0.0/udp/4326/udt", "", false) + + test("/ip6/::1/tcp/4324", "", true) + test("/ip6/::1/udp/4325", "", false) + test("/ip6/::1/udp/4326/udt", "", false) + test("/ip6/::/tcp/4324", "", true) + test("/ip6/::/udp/4325", "", false) + test("/ip6/::/udp/4326/udt", "", false) + + /* "An implementation should also support the concept of a "default" + * zone for each scope. And, when supported, the index value zero + * at each scope SHOULD be reserved to mean "use the default zone"." + * -- rfc4007. So, this _should_ work everywhere(?). */ + test("/ip6zone/0/ip6/::1/tcp/4324", "/ip6/::1/tcp/4324", true) + test("/ip6zone/0/ip6/::1/udp/4324", "", false) +} + +func TestListenAndDial(t *testing.T) { + + maddr := newMultiaddr(t, "/ip4/127.0.0.1/tcp/4323") + listener, err := Listen(maddr) + if err != nil { + t.Fatal("failed to listen") + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + + cB, err := listener.Accept() + if err != nil { + t.Fatal("failed to accept") + } + + if !cB.LocalMultiaddr().Equal(maddr) { + t.Fatal("local multiaddr not equal:", maddr, cB.LocalMultiaddr()) + } + + // echo out + buf := make([]byte, 1024) + for { + _, err := cB.Read(buf) + if err != nil { + break + } + cB.Write(buf) + } + + wg.Done() + }() + + cA, err := Dial(newMultiaddr(t, "/ip4/127.0.0.1/tcp/4323")) + if err != nil { + t.Fatal("failed to dial") + } + + buf := make([]byte, 1024) + if _, err := cA.Write([]byte("beep boop")); err != nil { + t.Fatal("failed to write:", err) + } + + if _, err := cA.Read(buf); err != nil { + t.Fatal("failed to read:", buf, err) + } + + if !bytes.Equal(buf[:9], []byte("beep boop")) { + t.Fatal("failed to echo:", buf) + } + + maddr2 := cA.RemoteMultiaddr() + if !maddr2.Equal(maddr) { + t.Fatal("remote multiaddr not equal:", maddr, maddr2) + } + + cA.Close() + wg.Wait() +} + +func TestListenPacketAndDial(t *testing.T) { + maddr := newMultiaddr(t, "/ip4/127.0.0.1/udp/4324") + pc, err := ListenPacket(maddr) + if err != nil { + t.Fatal("failed to listen", err) + } + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + if !pc.LocalMultiaddr().Equal(maddr) { + t.Fatal("connection multiaddr not equal:", maddr, pc.LocalMultiaddr()) + } + + buffer := make([]byte, 1024) + _, addr, err := pc.ReadFrom(buffer) + if err != nil { + t.Fatal("failed to read into buffer", err) + } + pc.WriteTo(buffer, addr) + + wg.Done() + }() + + cn, err := Dial(maddr) + if err != nil { + t.Fatal("failed to dial", err) + } + + buf := make([]byte, 1024) + if _, err := cn.Write([]byte("beep boop")); err != nil { + t.Fatal("failed to write", err) + } + + if _, err := cn.Read(buf); err != nil { + t.Fatal("failed to read:", buf, err) + } + + if !bytes.Equal(buf[:9], []byte("beep boop")) { + t.Fatal("failed to echk:", buf) + } + + maddr2 := cn.RemoteMultiaddr() + if !maddr2.Equal(maddr) { + t.Fatal("remote multiaddr not equal:", maddr, maddr2) + } + + cn.Close() + pc.Close() + wg.Wait() +} + +func TestIPLoopback(t *testing.T) { + if IP4Loopback.String() != "/ip4/127.0.0.1" { + t.Error("IP4Loopback incorrect:", IP4Loopback) + } + + if IP6Loopback.String() != "/ip6/::1" { + t.Error("IP6Loopback incorrect:", IP6Loopback) + } + + if IP4MappedIP6Loopback.String() != "/ip6/::ffff:127.0.0.1" { + t.Error("IP4MappedIP6Loopback incorrect:", IP4MappedIP6Loopback) + } + + if !IsIPLoopback(IP4Loopback) { + t.Error("IsIPLoopback failed (IP4Loopback)") + } + + if !IsIPLoopback(newMultiaddr(t, "/ip4/127.1.80.9")) { + t.Error("IsIPLoopback failed (/ip4/127.1.80.9)") + } + + if IsIPLoopback(newMultiaddr(t, "/ip4/112.123.11.1")) { + t.Error("IsIPLoopback false positive (/ip4/112.123.11.1)") + } + + if IsIPLoopback(newMultiaddr(t, "/ip4/192.168.0.1/ip6/::1")) { + t.Error("IsIPLoopback false positive (/ip4/192.168.0.1/ip6/::1)") + } + + if !IsIPLoopback(IP6Loopback) { + t.Error("IsIPLoopback failed (IP6Loopback)") + } + + if !IsIPLoopback(newMultiaddr(t, "/ip6/127.0.0.1")) { + t.Error("IsIPLoopback failed (/ip6/127.0.0.1)") + } + + if !IsIPLoopback(newMultiaddr(t, "/ip6/127.99.3.2")) { + t.Error("IsIPLoopback failed (/ip6/127.99.3.2)") + } + + if IsIPLoopback(newMultiaddr(t, "/ip6/::fffa:127.99.3.2")) { + t.Error("IsIPLoopback false positive (/ip6/::fffa:127.99.3.2)") + } + + if !IsIPLoopback(newMultiaddr(t, "/ip6zone/0/ip6/::1")) { + t.Error("IsIPLoopback failed (/ip6zone/0/ip6/::1)") + } + + if !IsIPLoopback(newMultiaddr(t, "/ip6zone/xxx/ip6/::1")) { + t.Error("IsIPLoopback failed (/ip6zone/xxx/ip6/::1)") + } + + if IsIPLoopback(newMultiaddr(t, "/ip6zone/0/ip6/1::1")) { + t.Errorf("IsIPLoopback false positive (/ip6zone/0/ip6/1::1)") + } +} + +func TestIPUnspecified(t *testing.T) { + if IP4Unspecified.String() != "/ip4/0.0.0.0" { + t.Error("IP4Unspecified incorrect:", IP4Unspecified) + } + + if IP6Unspecified.String() != "/ip6/::" { + t.Error("IP6Unspecified incorrect:", IP6Unspecified) + } + + if !IsIPUnspecified(IP4Unspecified) { + t.Error("IsIPUnspecified failed (IP4Unspecified)") + } + + if !IsIPUnspecified(IP6Unspecified) { + t.Error("IsIPUnspecified failed (IP6Unspecified)") + } + + if !IsIPUnspecified(newMultiaddr(t, "/ip6zone/xxx/ip6/::")) { + t.Error("IsIPUnspecified failed (/ip6zone/xxx/ip6/::)") + } +} + +func TestIP6LinkLocal(t *testing.T) { + for a := 0; a < 65536; a++ { + isLinkLocal := (a&0xffc0 == 0xfe80 || a&0xff0f == 0xff02) + m := newMultiaddr(t, fmt.Sprintf("/ip6/%x::1", a)) + if IsIP6LinkLocal(m) != isLinkLocal { + t.Errorf("IsIP6LinkLocal failed (%s != %v)", m, isLinkLocal) + } + } + + if !IsIP6LinkLocal(newMultiaddr(t, "/ip6zone/hello/ip6/fe80::9999")) { + t.Error("IsIP6LinkLocal failed (/ip6/fe80::9999)") + } +} + +func TestConvertNetAddr(t *testing.T) { + m1 := newMultiaddr(t, "/ip4/1.2.3.4/tcp/4001") + + n1, err := ToNetAddr(m1) + if err != nil { + t.Fatal(err) + } + + m2, err := FromNetAddr(n1) + if err != nil { + t.Fatal(err) + } + + if m1.String() != m2.String() { + t.Fatal("ToNetAddr + FromNetAddr did not work") + } +} + +func TestWrapNetConn(t *testing.T) { + // test WrapNetConn nil + if _, err := WrapNetConn(nil); err == nil { + t.Error("WrapNetConn(nil) should return an error") + } + + checkErr := func(err error, s string) { + if err != nil { + t.Fatal(s, err) + } + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + checkErr(err, "failed to listen") + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + go func() { + defer wg.Done() + cB, err := listener.Accept() + checkErr(err, "failed to accept") + _ = cB.(halfOpen) + cB.Close() + }() + + cA, err := net.Dial("tcp", listener.Addr().String()) + checkErr(err, "failed to dial") + defer cA.Close() + _ = cA.(halfOpen) + + lmaddr, err := FromNetAddr(cA.LocalAddr()) + checkErr(err, "failed to get local addr") + rmaddr, err := FromNetAddr(cA.RemoteAddr()) + checkErr(err, "failed to get remote addr") + + mcA, err := WrapNetConn(cA) + checkErr(err, "failed to wrap conn") + + _ = mcA.(halfOpen) + + if mcA.LocalAddr().String() != cA.LocalAddr().String() { + t.Error("wrapped conn local addr differs") + } + if mcA.RemoteAddr().String() != cA.RemoteAddr().String() { + t.Error("wrapped conn remote addr differs") + } + if mcA.LocalMultiaddr().String() != lmaddr.String() { + t.Error("wrapped conn local maddr differs") + } + if mcA.RemoteMultiaddr().String() != rmaddr.String() { + t.Error("wrapped conn remote maddr differs") + } +} + +func TestAddrMatch(t *testing.T) { + + test := func(m ma.Multiaddr, input, expect []ma.Multiaddr) { + actual := AddrMatch(m, input) + testSliceEqual(t, expect, actual) + } + + a := []ma.Multiaddr{ + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/2345"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/ip6/::1"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/ip6/::1"), + newMultiaddr(t, "/ip6/::1/tcp/1234"), + newMultiaddr(t, "/ip6/::1/tcp/2345"), + newMultiaddr(t, "/ip6/::1/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip6/::1/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip6/::1/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip6/::1/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip6/::1/tcp/1234/ip6/::1"), + newMultiaddr(t, "/ip6/::1/tcp/1234/ip6/::1"), + } + + test(a[0], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/2345"), + }) + test(a[2], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/tcp/2345"), + }) + test(a[4], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/udp/1234"), + }) + test(a[6], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/ip6/::1"), + newMultiaddr(t, "/ip4/1.2.3.4/tcp/1234/ip6/::1"), + }) + test(a[8], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip6/::1/tcp/1234"), + newMultiaddr(t, "/ip6/::1/tcp/2345"), + }) + test(a[10], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip6/::1/tcp/1234/tcp/2345"), + newMultiaddr(t, "/ip6/::1/tcp/1234/tcp/2345"), + }) + test(a[12], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip6/::1/tcp/1234/udp/1234"), + newMultiaddr(t, "/ip6/::1/tcp/1234/udp/1234"), + }) + test(a[14], a, []ma.Multiaddr{ + newMultiaddr(t, "/ip6/::1/tcp/1234/ip6/::1"), + newMultiaddr(t, "/ip6/::1/tcp/1234/ip6/::1"), + }) + +} + +func testSliceEqual(t *testing.T, a, b []ma.Multiaddr) { + if len(a) != len(b) { + t.Error("differ", a, b) + } + for i, addrA := range a { + if !addrA.Equal(b[i]) { + t.Error("differ", a, b) + } + } +} + +func TestInterfaceAddressesWorks(t *testing.T) { + _, err := InterfaceMultiaddrs() + if err != nil { + t.Fatal(err) + } +} + +func TestNetListener(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:1234") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + malist, err := WrapNetListener(listener) + if err != nil { + t.Fatal(err) + } + if !malist.Multiaddr().Equal(newMultiaddr(t, "/ip4/127.0.0.1/tcp/1234")) { + t.Fatal("unexpected multiaddr") + } + + go func() { + c, err := Dial(malist.Multiaddr()) + if err != nil { + t.Fatal("failed to dial") + } + if !c.RemoteMultiaddr().Equal(malist.Multiaddr()) { + t.Fatal("dialed wrong target") + } + c.Close() + + c, err = Dial(malist.Multiaddr()) + if err != nil { + t.Fatal("failed to dial") + } + c.Close() + }() + + c, err := malist.Accept() + if err != nil { + t.Fatal(err) + } + c.Close() + netList := NetListener(malist) + malist2, err := WrapNetListener(netList) + if err != nil { + t.Fatal(err) + } + if malist2 != malist { + t.Fatal("expected WrapNetListener(NetListener(malist)) == malist") + } + nc, err := netList.Accept() + if err != nil { + t.Fatal(err) + } + if !nc.(Conn).LocalMultiaddr().Equal(malist.Multiaddr()) { + t.Fatal("wrong multiaddr on conn") + } + nc.Close() +} diff --git a/net/private.go b/net/private.go new file mode 100644 index 0000000..26e547c --- /dev/null +++ b/net/private.go @@ -0,0 +1,116 @@ +package manet + +import ( + "net" + + ma "github.com/multiformats/go-multiaddr" +) + +// Private4 and Private6 are well-known private networks +var Private4, Private6 []*net.IPNet +var privateCIDR4 = []string{ + // localhost + "127.0.0.0/8", + // private networks + "10.0.0.0/8", + "100.64.0.0/10", + "172.16.0.0/12", + "192.168.0.0/16", + // link local + "169.254.0.0/16", +} +var privateCIDR6 = []string{ + // localhost + "::1/128", + // ULA reserved + "fc00::/7", + // link local + "fe80::/10", +} + +// Unroutable4 and Unroutable6 are well known unroutable address ranges +var Unroutable4, Unroutable6 []*net.IPNet +var unroutableCIDR4 = []string{ + "0.0.0.0/8", + "192.0.0.0/26", + "192.0.2.0/24", + "192.88.99.0/24", + "198.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "224.0.0.0/4", + "240.0.0.0/4", + "255.255.255.255/32", +} +var unroutableCIDR6 = []string{ + "ff00::/8", +} + +func init() { + Private4 = parseCIDR(privateCIDR4) + Private6 = parseCIDR(privateCIDR6) + Unroutable4 = parseCIDR(unroutableCIDR4) + Unroutable6 = parseCIDR(unroutableCIDR6) +} + +func parseCIDR(cidrs []string) []*net.IPNet { + ipnets := make([]*net.IPNet, len(cidrs)) + for i, cidr := range cidrs { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + ipnets[i] = ipnet + } + return ipnets +} + +// IsPublicAddr retruns true if the IP part of the multiaddr is a publicly routable address +func IsPublicAddr(a ma.Multiaddr) bool { + isPublic := false + ma.ForEach(a, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_IP6ZONE: + return true + default: + return false + case ma.P_IP4: + ip := net.IP(c.RawValue()) + isPublic = !inAddrRange(ip, Private4) && !inAddrRange(ip, Unroutable4) + case ma.P_IP6: + ip := net.IP(c.RawValue()) + isPublic = !inAddrRange(ip, Private6) && !inAddrRange(ip, Unroutable6) + } + return false + }) + return isPublic +} + +// IsPrivateAddr returns true if the IP part of the mutiaddr is in a private network +func IsPrivateAddr(a ma.Multiaddr) bool { + isPrivate := false + ma.ForEach(a, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_IP6ZONE: + return true + default: + return false + case ma.P_IP4: + isPrivate = inAddrRange(net.IP(c.RawValue()), Private4) + case ma.P_IP6: + isPrivate = inAddrRange(net.IP(c.RawValue()), Private6) + } + return false + }) + return isPrivate +} + +func inAddrRange(ip net.IP, ipnets []*net.IPNet) bool { + for _, ipnet := range ipnets { + if ipnet.Contains(ip) { + return true + } + } + + return false +} diff --git a/net/private_test.go b/net/private_test.go new file mode 100644 index 0000000..a4380a5 --- /dev/null +++ b/net/private_test.go @@ -0,0 +1,48 @@ +package manet + +import ( + "testing" + + ma "github.com/multiformats/go-multiaddr" +) + +func TestIsPublicAddr(t *testing.T) { + a, err := ma.NewMultiaddr("/ip4/192.168.1.1/tcp/80") + if err != nil { + t.Fatal(err) + } + + if IsPublicAddr(a) { + t.Fatal("192.168.1.1 is not a public address!") + } + + if !IsPrivateAddr(a) { + t.Fatal("192.168.1.1 is a private address!") + } + + a, err = ma.NewMultiaddr("/ip4/1.1.1.1/tcp/80") + if err != nil { + t.Fatal(err) + } + + if !IsPublicAddr(a) { + t.Fatal("1.1.1.1 is a public address!") + } + + if IsPrivateAddr(a) { + t.Fatal("1.1.1.1 is not a private address!") + } + + a, err = ma.NewMultiaddr("/tcp/80/ip4/1.1.1.1") + if err != nil { + t.Fatal(err) + } + + if IsPublicAddr(a) { + t.Fatal("shouldn't consider an address that starts with /tcp/ as *public*") + } + + if IsPrivateAddr(a) { + t.Fatal("shouldn't consider an address that starts with /tcp/ as *private*") + } +} diff --git a/net/registry.go b/net/registry.go new file mode 100644 index 0000000..fc6561c --- /dev/null +++ b/net/registry.go @@ -0,0 +1,133 @@ +package manet + +import ( + "fmt" + "net" + "sync" + + ma "github.com/multiformats/go-multiaddr" +) + +// FromNetAddrFunc is a generic function which converts a net.Addr to Multiaddress +type FromNetAddrFunc func(a net.Addr) (ma.Multiaddr, error) + +// ToNetAddrFunc is a generic function which converts a Multiaddress to net.Addr +type ToNetAddrFunc func(ma ma.Multiaddr) (net.Addr, error) + +var defaultCodecs = NewCodecMap() + +func init() { + defaultCodecs.RegisterFromNetAddr(parseTCPNetAddr, "tcp", "tcp4", "tcp6") + defaultCodecs.RegisterFromNetAddr(parseUDPNetAddr, "udp", "udp4", "udp6") + defaultCodecs.RegisterFromNetAddr(parseIPNetAddr, "ip", "ip4", "ip6") + defaultCodecs.RegisterFromNetAddr(parseIPPlusNetAddr, "ip+net") + defaultCodecs.RegisterFromNetAddr(parseUnixNetAddr, "unix") + + defaultCodecs.RegisterToNetAddr(parseBasicNetMaddr, "tcp", "udp", "ip6", "ip4", "unix") +} + +// CodecMap holds a map of NetCodecs indexed by their Protocol ID +// along with parsers for the addresses they use. +// It is used to keep a list of supported network address codecs (protocols +// which addresses can be converted to and from multiaddresses). +type CodecMap struct { + codecs map[string]*NetCodec + addrParsers map[string]FromNetAddrFunc + maddrParsers map[string]ToNetAddrFunc + lk sync.Mutex +} + +// NewCodecMap initializes and returns a CodecMap object. +func NewCodecMap() *CodecMap { + return &CodecMap{ + addrParsers: make(map[string]FromNetAddrFunc), + maddrParsers: make(map[string]ToNetAddrFunc), + } +} + +// NetCodec is used to identify a network codec, that is, a network type for +// which we are able to translate multiaddresses into standard Go net.Addr +// and back. +// +// Deprecated: Unfortunately, these mappings aren't one to one. This abstraction +// assumes that multiple "networks" can map to a single multiaddr protocol but +// not the reverse. For example, this abstraction supports `tcp6, tcp4, tcp -> +// /tcp/` really well but doesn't support `ip -> {/ip4/, /ip6/}`. +// +// Please use `RegisterFromNetAddr` and `RegisterToNetAddr` directly. +type NetCodec struct { + // NetAddrNetworks is an array of strings that may be returned + // by net.Addr.Network() calls on addresses belonging to this type + NetAddrNetworks []string + + // ProtocolName is the string value for Multiaddr address keys + ProtocolName string + + // ParseNetAddr parses a net.Addr belonging to this type into a multiaddr + ParseNetAddr FromNetAddrFunc + + // ConvertMultiaddr converts a multiaddr of this type back into a net.Addr + ConvertMultiaddr ToNetAddrFunc + + // Protocol returns the multiaddr protocol struct for this type + Protocol ma.Protocol +} + +// RegisterNetCodec adds a new NetCodec to the default codecs. +func RegisterNetCodec(a *NetCodec) { + defaultCodecs.RegisterNetCodec(a) +} + +// RegisterNetCodec adds a new NetCodec to the CodecMap. This function is +// thread safe. +func (cm *CodecMap) RegisterNetCodec(a *NetCodec) { + cm.lk.Lock() + defer cm.lk.Unlock() + for _, n := range a.NetAddrNetworks { + cm.addrParsers[n] = a.ParseNetAddr + } + + cm.maddrParsers[a.ProtocolName] = a.ConvertMultiaddr +} + +// RegisterFromNetAddr registers a conversion from net.Addr instances to multiaddrs +func (cm *CodecMap) RegisterFromNetAddr(from FromNetAddrFunc, networks ...string) { + cm.lk.Lock() + defer cm.lk.Unlock() + + for _, n := range networks { + cm.addrParsers[n] = from + } +} + +// RegisterToNetAddr registers a conversion from multiaddrs to net.Addr instances +func (cm *CodecMap) RegisterToNetAddr(to ToNetAddrFunc, protocols ...string) { + cm.lk.Lock() + defer cm.lk.Unlock() + + for _, p := range protocols { + cm.maddrParsers[p] = to + } +} + +func (cm *CodecMap) getAddrParser(net string) (FromNetAddrFunc, error) { + cm.lk.Lock() + defer cm.lk.Unlock() + + parser, ok := cm.addrParsers[net] + if !ok { + return nil, fmt.Errorf("unknown network %v", net) + } + return parser, nil +} + +func (cm *CodecMap) getMaddrParser(name string) (ToNetAddrFunc, error) { + cm.lk.Lock() + defer cm.lk.Unlock() + p, ok := cm.maddrParsers[name] + if !ok { + return nil, fmt.Errorf("network not supported: %s", name) + } + + return p, nil +} diff --git a/net/registry_test.go b/net/registry_test.go new file mode 100644 index 0000000..b3777ca --- /dev/null +++ b/net/registry_test.go @@ -0,0 +1,50 @@ +package manet + +import ( + "net" + "testing" + + ma "github.com/multiformats/go-multiaddr" +) + +func TestRegisterSpec(t *testing.T) { + cm := NewCodecMap() + myproto := &NetCodec{ + ProtocolName: "test", + NetAddrNetworks: []string{"test", "iptest", "blahtest"}, + ConvertMultiaddr: func(a ma.Multiaddr) (net.Addr, error) { return nil, nil }, + ParseNetAddr: func(a net.Addr) (ma.Multiaddr, error) { return nil, nil }, + } + + cm.RegisterNetCodec(myproto) + + _, ok := cm.addrParsers["test"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.addrParsers["iptest"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.addrParsers["blahtest"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["test"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["iptest"] + if ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["blahtest"] + if ok { + t.Fatal("myproto not properly registered") + } +}