Skip to content

Commit

Permalink
swarm: return errors on filtered addresses when dialling
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 21, 2023
1 parent 37319a6 commit bc3a2ec
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 50 deletions.
7 changes: 5 additions & 2 deletions p2p/net/swarm/black_hole_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ type blackHoleDetector struct {
}

// FilterAddrs filters the peer's addresses removing black holed addresses
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) {
hasUDP, hasIPv6 := false, false
for _, a := range addrs {
if !manet.IsPublicAddr(a) {
Expand All @@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
ipv6Res = d.ipv6.HandleRequest()
}

blackHoled = make([]ma.Multiaddr, 0, len(addrs))
return ma.FilterAddrs(
addrs,
func(a ma.Multiaddr) bool {
Expand All @@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
}

if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) {
blackHoled = append(blackHoled, a)
return false
}
if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) {
blackHoled = append(blackHoled, a)
return false
}
return true
},
)
), blackHoled
}

// RecordResult updates the state of the relevant `blackHoleFilter`s for addr
Expand Down
44 changes: 31 additions & 13 deletions p2p/net/swarm/black_hole_detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) {
ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"),
}
for i := 0; i < 1000; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
require.ElementsMatch(t, addrs, filteredAddrs)
for j := 0; j < len(addrs); j++ {
bhd.RecordResult(addrs[j], false)
Expand All @@ -101,20 +101,29 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) {
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
addrs := []ma.Multiaddr{publicAddr, privAddr}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))
wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorIPv6Disabled(t *testing.T) {
udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5}
bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil)
publicAddr := ma.StringCast("/ip6/1::1/tcp/1234")
privAddr := ma.StringCast("/ip6/::1/tcp/1234")
addrs := []ma.Multiaddr{publicAddr, privAddr}
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))

wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorProbes(t *testing.T) {
Expand All @@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
bhd.RecordResult(udp6Addr, false)
}
for i := 1; i < 100; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
if i%2 == 0 || i%3 == 0 {
if len(filteredAddrs) == 0 {
t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter")
Expand All @@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1")
udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1")
upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1")
tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1")
tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1")
Expand All @@ -158,26 +167,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"},
}
for i := 0; i < 100; i++ {
bhd.RecordResult(upd4Pub, !udpBlocked)
bhd.RecordResult(udp4Pub, !udpBlocked)
}
for i := 0; i < 100; i++ {
bhd.RecordResult(tcp6Pub, !ipv6Blocked)
}
return bhd
}

allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri,
allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri,
tcp4Pub, tcp4Pri}

udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri}
udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub}
bhd := makeBHD(true, false)
require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput)
require.ElementsMatch(t, udpBlockedOutput, gotAddrs)
require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs)

ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub}
bhd = makeBHD(false, true)
require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, ip6BlockedOutput, gotAddrs)
require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs)

bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub}
bhd = makeBHD(true, true)
require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, bothBlockedOutput, gotAddrs)
require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs)
}
13 changes: 10 additions & 3 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,14 @@ loop:
continue loop
}

addrs, err := w.s.addrsForDial(req.ctx, w.peer)
addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: err}
req.resch <- dialResponse{
err: &DialError{
Peer: w.peer,
DialErrors: addrErrs,
Cause: err,
}}
continue loop
}

Expand All @@ -179,8 +184,8 @@ loop:
// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer},
addrs: make(map[string]struct{}, len(addrRanking)),
err: &DialError{Peer: w.peer, DialErrors: addrErrs},
}
for _, adelay := range addrRanking {
pr.addrs[string(adelay.Addr.Bytes())] = struct{}{}
Expand Down Expand Up @@ -221,6 +226,7 @@ loop:

if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
pr.err.Cause = ErrAllDialsFailed
req.resch <- dialResponse{err: pr.err}
continue loop
}
Expand Down Expand Up @@ -371,6 +377,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
pr.err.Cause = ErrAllDialsFailed
pr.req.resch <- dialResponse{err: pr.err}
}
delete(w.pendingRequests, pr)
Expand Down
61 changes: 39 additions & 22 deletions p2p/net/swarm/swarm_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,10 @@ func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) {
w.loop()
}

func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Multiaddr, addrErrs []TransportError, err error) {
peerAddrs := s.peers.Addrs(p)
if len(peerAddrs) == 0 {
return nil, ErrNoAddresses
return nil, nil, ErrNoAddresses
}

peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs))
Expand All @@ -308,22 +308,22 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er
Addrs: peerAddrsAfterTransportResolved,
})
if err != nil {
return nil, err
return nil, nil, err
}

goodAddrs := s.filterKnownUndialables(p, resolved)
goodAddrs = ma.Unique(resolved)
goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs)
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr)
}
goodAddrs = ma.Unique(goodAddrs)

if len(goodAddrs) == 0 {
return nil, ErrNoGoodAddresses
return nil, addrErrs, ErrNoGoodAddresses
}

s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL)

return goodAddrs, nil
return goodAddrs, addrErrs, nil
}

func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) {
Expand Down Expand Up @@ -402,11 +402,6 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr,
return nil
}

func (s *Swarm) canDial(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return t != nil && t.CanDial(addr)
}

func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
t := s.TransportForDialing(addr)
return !t.Proxy()
Expand All @@ -418,7 +413,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
// addresses that we know to be our own, and addresses with a better tranport
// available. This is an optimization to avoid wasting time on dials that we
// know are going to fail or for which we have a better alternative.
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr {
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) {
lisAddrs, _ := s.InterfaceListenAddresses()
var ourAddrs []ma.Multiaddr
for _, addr := range lisAddrs {
Expand All @@ -431,27 +426,49 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul
})
}

// The order of these two filters is important. If we can only dial /webtransport,
// we don't want to filter /webtransport addresses out because the peer had a /quic-v1
// address
addrErrs = make([]TransportError, 0, len(addrs))

// filter addresses we cannot dial
addrs = ma.FilterAddrs(addrs, s.canDial)
// The order of checking for transport and filtering low priority addrs is important. If we
// can only dial /webtransport, we don't want to filter /webtransport addresses out because
// the peer had a /quic-v1 address

// filter addresses with no transport
addrs = ma.FilterAddrs(addrs, func(a ma.Multiaddr) bool {
if s.TransportForDialing(a) == nil {
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrNoTransport})
return false
}
return true
})

// filter low priority addresses among the addresses we can dial
// We don't return an error for these addresses
addrs = filterLowPriorityAddresses(addrs)

// remove black holed addrs
addrs = s.bhd.FilterAddrs(addrs)
addrs, blackHoledAddrs := s.bhd.FilterAddrs(addrs)
for _, a := range blackHoledAddrs {
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrDialRefusedBlackHole})
}

return ma.FilterAddrs(addrs,
func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) },
func(addr ma.Multiaddr) bool {
if ma.Contains(ourAddrs, addr) {
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialToSelf})
return false
}
return true
},
// TODO: Consider allowing link-local addresses
func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) },
func(addr ma.Multiaddr) bool {
return s.gater == nil || s.gater.InterceptAddrDial(p, addr)
if s.gater != nil && !s.gater.InterceptAddrDial(p, addr) {
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrGaterDisallowedConnection})
return false
}
return true
},
)
), addrErrs
}

// limitedDial will start a dial to the given peer when
Expand Down
21 changes: 11 additions & 10 deletions p2p/net/swarm/swarm_dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/rand"
"errors"
"net"
"sort"
"testing"
Expand Down Expand Up @@ -65,7 +66,7 @@ func TestAddrsForDial(t *testing.T) {
ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour)

ctx := context.Background()
mas, err := s.addrsForDial(ctx, otherPeer)
mas, _, err := s.addrsForDial(ctx, otherPeer)
require.NoError(t, err)

require.NotZero(t, len(mas))
Expand Down Expand Up @@ -110,7 +111,7 @@ func TestDedupAddrsForDial(t *testing.T) {
ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour)

ctx := context.Background()
mas, err := s.addrsForDial(ctx, otherPeer)
mas, _, err := s.addrsForDial(ctx, otherPeer)
require.NoError(t, err)

require.Equal(t, 1, len(mas))
Expand Down Expand Up @@ -183,7 +184,7 @@ func TestAddrResolution(t *testing.T) {

tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel()
mas, err := s.addrsForDial(tctx, p1)
mas, _, err := s.addrsForDial(tctx, p1)
require.NoError(t, err)

require.Len(t, mas, 1)
Expand Down Expand Up @@ -241,7 +242,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
defer cancel()
s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL)
_, err = s.addrsForDial(tctx, p1)
_, _, err = s.addrsForDial(tctx, p1)
require.NoError(t, err)

addrs1 := s.Peerstore().Addrs(pi1.ID)
Expand All @@ -253,7 +254,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
require.NoError(t, err)

s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL)
_, err = s.addrsForDial(tctx, p2)
_, _, err = s.addrsForDial(tctx, p2)
// This never resolves to a good address
require.Equal(t, ErrNoGoodAddresses, err)

Expand Down Expand Up @@ -315,7 +316,7 @@ func TestAddrsForDialFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
s.Peerstore().ClearAddrs(p1)
s.Peerstore().AddAddrs(p1, tc.input, peerstore.PermanentAddrTTL)
result, err := s.addrsForDial(ctx, p1)
result, _, err := s.addrsForDial(ctx, p1)
require.NoError(t, err)
sort.Slice(result, func(i, j int) bool { return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 })
sort.Slice(tc.output, func(i, j int) bool { return bytes.Compare(tc.output[i].Bytes(), tc.output[j].Bytes()) < 0 })
Expand Down Expand Up @@ -366,10 +367,10 @@ func TestBlackHoledAddrBlocked(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
conn, err := s.DialPeer(ctx, p)
if conn != nil {
t.Fatalf("expected dial to be blocked")
}
if err != ErrNoGoodAddresses {
require.Nil(t, conn)
var de *DialError
if !errors.As(err, &de) {
t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err)
}
require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialRefusedBlackHole})
}

0 comments on commit bc3a2ec

Please sign in to comment.