Skip to content

Commit

Permalink
swarm: Dedup addresses to dial (#2322)
Browse files Browse the repository at this point in the history
* Dedup addresses to dial

Co-authored-by: Aayush Rajasekaran <arajasek94@gmail.com>

* Move DedupAddrs test

* Typo

---------

Co-authored-by: Aayush Rajasekaran <arajasek94@gmail.com>
  • Loading branch information
MarcoPolo and arajasek committed Jun 2, 2023
1 parent fc89448 commit fd88935
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 52 deletions.
22 changes: 22 additions & 0 deletions core/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package network

import (
"bytes"
"context"
"io"
"sort"
"time"

"github.com/libp2p/go-libp2p/core/peer"
Expand Down Expand Up @@ -184,3 +186,23 @@ type Dialer interface {
Notify(Notifiee)
StopNotify(Notifiee)
}

// DedupAddrs deduplicates addresses in place, leave only unique addresses.
// It doesn't allocate.
func DedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
if len(addrs) == 0 {
return addrs
}
sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 })
idx := 1
for i := 1; i < len(addrs); i++ {
if !addrs[i-1].Equal(addrs[i]) {
addrs[idx] = addrs[i]
idx++
}
}
for i := idx; i < len(addrs); i++ {
addrs[i] = nil
}
return addrs[:idx]
}
36 changes: 36 additions & 0 deletions core/network/network_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package network

import (
"fmt"
"testing"

ma "github.com/multiformats/go-multiaddr"

"github.com/stretchr/testify/require"
)

func TestDedupAddrs(t *testing.T) {
tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234")
quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1")
wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws")

type testcase struct {
in, out []ma.Multiaddr
}

for i, tc := range []testcase{
{in: nil, out: nil},
{in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}},
{in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}},
{in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}},
{in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}},
} {
tc := tc
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
deduped := DedupAddrs(tc.in)
for _, a := range tc.out {
require.Contains(t, deduped, a)
}
})
}
}
26 changes: 2 additions & 24 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package basichost

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sort"
"sync"
"time"

Expand Down Expand Up @@ -816,26 +814,6 @@ func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr {
return addr
}

// dedupAddrs deduplicates addresses in place, leave only unique addresses.
// It doesn't allocate.
func dedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
if len(addrs) == 0 {
return addrs
}
sort.Slice(addrs, func(i, j int) bool { return bytes.Compare(addrs[i].Bytes(), addrs[j].Bytes()) < 0 })
idx := 1
for i := 1; i < len(addrs); i++ {
if !addrs[i-1].Equal(addrs[i]) {
addrs[idx] = addrs[i]
idx++
}
}
for i := idx; i < len(addrs); i++ {
addrs[i] = nil
}
return addrs[:idx]
}

// AllAddrs returns all the addresses of BasicHost at this moment in time.
// It's ok to not include addresses if they're not available to be used now.
func (h *BasicHost) AllAddrs() []ma.Multiaddr {
Expand All @@ -860,7 +838,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
finalAddrs = append(finalAddrs, resolved...)
}

finalAddrs = dedupAddrs(finalAddrs)
finalAddrs = network.DedupAddrs(finalAddrs)

var natMappings []inat.Mapping

Expand Down Expand Up @@ -1010,7 +988,7 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr {
}
finalAddrs = append(finalAddrs, observedAddrs...)
}
finalAddrs = dedupAddrs(finalAddrs)
finalAddrs = network.DedupAddrs(finalAddrs)
finalAddrs = inferWebtransportAddrsFromQuic(finalAddrs)

return finalAddrs
Expand Down
26 changes: 0 additions & 26 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,32 +825,6 @@ func TestNormalizeMultiaddr(t *testing.T) {
require.Equal(t, "/ip4/1.2.3.4/udp/9999/quic-v1/webtransport", h1.NormalizeMultiaddr(ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28")).String())
}

func TestDedupAddrs(t *testing.T) {
tcpAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234")
quicAddr := ma.StringCast("/ip4/127.0.0.1/udp/1234/quic-v1")
wsAddr := ma.StringCast("/ip4/127.0.0.1/tcp/1234/ws")

type testcase struct {
in, out []ma.Multiaddr
}

for i, tc := range []testcase{
{in: nil, out: nil},
{in: []ma.Multiaddr{tcpAddr}, out: []ma.Multiaddr{tcpAddr}},
{in: []ma.Multiaddr{tcpAddr, tcpAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr}},
{in: []ma.Multiaddr{tcpAddr, quicAddr, tcpAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr}},
{in: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}, out: []ma.Multiaddr{tcpAddr, quicAddr, wsAddr}},
} {
tc := tc
t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
deduped := dedupAddrs(tc.in)
for _, a := range tc.out {
require.Contains(t, deduped, a)
}
})
}
}

func TestInferWebtransportAddrsFromQuic(t *testing.T) {
type testCase struct {
name string
Expand Down
14 changes: 12 additions & 2 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,12 @@ loop:
case <-w.triggerDial:
for _, addr := range w.nextDial {
// spawn the dial
ad := w.pending[string(addr.Bytes())]
ad, ok := w.pending[string(addr.Bytes())]
if !ok {
log.Warn("unexpectedly missing pending addrDial for addr")
// Assume nothing to dial here
continue
}
err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch)
if err != nil {
w.dispatchError(ad, err)
Expand All @@ -195,7 +200,12 @@ loop:
w.connected = true
}

ad := w.pending[string(res.Addr.Bytes())]
ad, ok := w.pending[string(res.Addr.Bytes())]
if !ok {
log.Warn("unexpectedly missing pending addrDial res")
// Assume nothing to do here
continue
}

if res.Conn != nil {
// we got a connection, add it to the swarm
Expand Down
1 change: 1 addition & 0 deletions p2p/net/swarm/swarm_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
goodAddrs = network.DedupAddrs(goodAddrs)

if len(goodAddrs) == 0 {
return nil, ErrNoGoodAddresses
Expand Down
45 changes: 45 additions & 0 deletions p2p/net/swarm/swarm_dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,51 @@ func TestAddrsForDial(t *testing.T) {
require.NotZero(t, len(mas))
}

func TestDedupAddrsForDial(t *testing.T) {
mockResolver := madns.MockResolver{IP: make(map[string][]net.IPAddr)}
ipaddr, err := net.ResolveIPAddr("ip4", "1.2.3.4")
if err != nil {
t.Fatal(err)
}
mockResolver.IP["example.com"] = []net.IPAddr{*ipaddr}

resolver, err := madns.NewResolver(madns.WithDomainResolver("example.com", &mockResolver))
if err != nil {
t.Fatal(err)
}

priv, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
id, err := peer.IDFromPrivateKey(priv)
require.NoError(t, err)

ps, err := pstoremem.NewPeerstore()
require.NoError(t, err)
ps.AddPubKey(id, priv.GetPublic())
ps.AddPrivKey(id, priv)
t.Cleanup(func() { ps.Close() })

s, err := NewSwarm(id, ps, eventbus.NewBus(), WithMultiaddrResolver(resolver))
require.NoError(t, err)
defer s.Close()

tpt, err := tcp.NewTCPTransport(nil, &network.NullResourceManager{})
require.NoError(t, err)
err = s.AddTransport(tpt)
require.NoError(t, err)

otherPeer := test.RandPeerIDFatal(t)

ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234"), time.Hour)
ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour)

ctx := context.Background()
mas, err := s.addrsForDial(ctx, otherPeer)
require.NoError(t, err)

require.Equal(t, 1, len(mas))
}

func newTestSwarmWithResolver(t *testing.T, resolver *madns.Resolver) *Swarm {
priv, _, err := crypto.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
Expand Down

0 comments on commit fd88935

Please sign in to comment.