From bc3a2ec625b73abf3ee84e3aa986a3d7f620508b Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Aug 2023 17:13:50 +0530 Subject: [PATCH] swarm: return errors on filtered addresses when dialling --- p2p/net/swarm/black_hole_detector.go | 7 ++- p2p/net/swarm/black_hole_detector_test.go | 44 +++++++++++----- p2p/net/swarm/dial_worker.go | 13 +++-- p2p/net/swarm/swarm_dial.go | 61 +++++++++++++++-------- p2p/net/swarm/swarm_dial_test.go | 21 ++++---- 5 files changed, 96 insertions(+), 50 deletions(-) diff --git a/p2p/net/swarm/black_hole_detector.go b/p2p/net/swarm/black_hole_detector.go index 078b1126c4..dd7849eea6 100644 --- a/p2p/net/swarm/black_hole_detector.go +++ b/p2p/net/swarm/black_hole_detector.go @@ -178,7 +178,7 @@ type blackHoleDetector struct { } // FilterAddrs filters the peer's addresses removing black holed addresses -func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { +func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) { hasUDP, hasIPv6 := false, false for _, a := range addrs { if !manet.IsPublicAddr(a) { @@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { ipv6Res = d.ipv6.HandleRequest() } + blackHoled = make([]ma.Multiaddr, 0, len(addrs)) return ma.FilterAddrs( addrs, func(a ma.Multiaddr) bool { @@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { } if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) { + blackHoled = append(blackHoled, a) return false } if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) { + blackHoled = append(blackHoled, a) return false } return true }, - ) + ), blackHoled } // RecordResult updates the state of the relevant `blackHoleFilter`s for addr diff --git a/p2p/net/swarm/black_hole_detector_test.go b/p2p/net/swarm/black_hole_detector_test.go index 7b10fc88a6..dfbb30f90d 100644 --- a/p2p/net/swarm/black_hole_detector_test.go +++ b/p2p/net/swarm/black_hole_detector_test.go @@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"), } for i := 0; i < 1000; i++ { - filteredAddrs := bhd.FilterAddrs(addrs) + filteredAddrs, _ := bhd.FilterAddrs(addrs) require.ElementsMatch(t, addrs, filteredAddrs) for j := 0; j < len(addrs); j++ { bhd.RecordResult(addrs[j], false) @@ -101,8 +101,12 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) { for i := 0; i < 100; i++ { bhd.RecordResult(publicAddr, false) } - addrs := []ma.Multiaddr{publicAddr, privAddr} - require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs)) + wantAddrs := []ma.Multiaddr{publicAddr, privAddr} + wantRemovedAddrs := make([]ma.Multiaddr, 0) + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) } func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { @@ -110,11 +114,16 @@ func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil) publicAddr := ma.StringCast("/ip6/1::1/tcp/1234") privAddr := ma.StringCast("/ip6/::1/tcp/1234") - addrs := []ma.Multiaddr{publicAddr, privAddr} for i := 0; i < 100; i++ { bhd.RecordResult(publicAddr, false) } - require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs)) + + wantAddrs := []ma.Multiaddr{publicAddr, privAddr} + wantRemovedAddrs := make([]ma.Multiaddr, 0) + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) } func TestBlackHoleDetectorProbes(t *testing.T) { @@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) { bhd.RecordResult(udp6Addr, false) } for i := 1; i < 100; i++ { - filteredAddrs := bhd.FilterAddrs(addrs) + filteredAddrs, _ := bhd.FilterAddrs(addrs) if i%2 == 0 || i%3 == 0 { if len(filteredAddrs) == 0 { t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter") @@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) { func TestBlackHoleDetectorAddrFiltering(t *testing.T) { udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1") udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1") - upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1") tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1") tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1") @@ -158,7 +167,7 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"}, } for i := 0; i < 100; i++ { - bhd.RecordResult(upd4Pub, !udpBlocked) + bhd.RecordResult(udp4Pub, !udpBlocked) } for i := 0; i < 100; i++ { bhd.RecordResult(tcp6Pub, !ipv6Blocked) @@ -166,18 +175,27 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { return bhd } - allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri, + allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri} udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri} + udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub} bhd := makeBHD(true, false) - require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput) + require.ElementsMatch(t, udpBlockedOutput, gotAddrs) + require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs) - ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub} bhd = makeBHD(false, true) - require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput) + require.ElementsMatch(t, ip6BlockedOutput, gotAddrs) + require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs) bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub} bhd = makeBHD(true, true) - require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput) + require.ElementsMatch(t, bothBlockedOutput, gotAddrs) + require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs) } diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 379fbf9ba6..78795f17b3 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -165,9 +165,14 @@ loop: continue loop } - addrs, err := w.s.addrsForDial(req.ctx, w.peer) + addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer) if err != nil { - req.resch <- dialResponse{err: err} + req.resch <- dialResponse{ + err: &DialError{ + Peer: w.peer, + DialErrors: addrErrs, + Cause: err, + }} continue loop } @@ -179,8 +184,8 @@ loop: // create the pending request object pr := &pendRequest{ req: req, - err: &DialError{Peer: w.peer}, addrs: make(map[string]struct{}, len(addrRanking)), + err: &DialError{Peer: w.peer, DialErrors: addrErrs}, } for _, adelay := range addrRanking { pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} @@ -221,6 +226,7 @@ loop: if len(todial) == 0 && len(tojoin) == 0 { // all request applicable addrs have been dialed, we must have errored + pr.err.Cause = ErrAllDialsFailed req.resch <- dialResponse{err: pr.err} continue loop } @@ -371,6 +377,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { if c != nil { pr.req.resch <- dialResponse{conn: c} } else { + pr.err.Cause = ErrAllDialsFailed pr.req.resch <- dialResponse{err: pr.err} } delete(w.pendingRequests, pr) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index f2df93af2f..e17c27353b 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -280,10 +280,10 @@ func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { w.loop() } -func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) { +func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Multiaddr, addrErrs []TransportError, err error) { peerAddrs := s.peers.Addrs(p) if len(peerAddrs) == 0 { - return nil, ErrNoAddresses + return nil, nil, ErrNoAddresses } peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs)) @@ -308,22 +308,22 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er Addrs: peerAddrsAfterTransportResolved, }) if err != nil { - return nil, err + return nil, nil, err } - goodAddrs := s.filterKnownUndialables(p, resolved) + goodAddrs = ma.Unique(resolved) + goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs) if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) } - goodAddrs = ma.Unique(goodAddrs) if len(goodAddrs) == 0 { - return nil, ErrNoGoodAddresses + return nil, addrErrs, ErrNoGoodAddresses } s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL) - return goodAddrs, nil + return goodAddrs, addrErrs, nil } func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) { @@ -402,11 +402,6 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, return nil } -func (s *Swarm) canDial(addr ma.Multiaddr) bool { - t := s.TransportForDialing(addr) - return t != nil && t.CanDial(addr) -} - func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() @@ -418,7 +413,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { // addresses that we know to be our own, and addresses with a better tranport // available. This is an optimization to avoid wasting time on dials that we // know are going to fail or for which we have a better alternative. -func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr { +func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) { lisAddrs, _ := s.InterfaceListenAddresses() var ourAddrs []ma.Multiaddr for _, addr := range lisAddrs { @@ -431,27 +426,49 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul }) } - // The order of these two filters is important. If we can only dial /webtransport, - // we don't want to filter /webtransport addresses out because the peer had a /quic-v1 - // address + addrErrs = make([]TransportError, 0, len(addrs)) - // filter addresses we cannot dial - addrs = ma.FilterAddrs(addrs, s.canDial) + // The order of checking for transport and filtering low priority addrs is important. If we + // can only dial /webtransport, we don't want to filter /webtransport addresses out because + // the peer had a /quic-v1 address + + // filter addresses with no transport + addrs = ma.FilterAddrs(addrs, func(a ma.Multiaddr) bool { + if s.TransportForDialing(a) == nil { + addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrNoTransport}) + return false + } + return true + }) // filter low priority addresses among the addresses we can dial + // We don't return an error for these addresses addrs = filterLowPriorityAddresses(addrs) // remove black holed addrs - addrs = s.bhd.FilterAddrs(addrs) + addrs, blackHoledAddrs := s.bhd.FilterAddrs(addrs) + for _, a := range blackHoledAddrs { + addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrDialRefusedBlackHole}) + } return ma.FilterAddrs(addrs, - func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) }, + func(addr ma.Multiaddr) bool { + if ma.Contains(ourAddrs, addr) { + addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialToSelf}) + return false + } + return true + }, // TODO: Consider allowing link-local addresses func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) }, func(addr ma.Multiaddr) bool { - return s.gater == nil || s.gater.InterceptAddrDial(p, addr) + if s.gater != nil && !s.gater.InterceptAddrDial(p, addr) { + addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrGaterDisallowedConnection}) + return false + } + return true }, - ) + ), addrErrs } // limitedDial will start a dial to the given peer when diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 9538ae731d..f365444110 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "errors" "net" "sort" "testing" @@ -65,7 +66,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour) ctx := context.Background() - mas, err := s.addrsForDial(ctx, otherPeer) + mas, _, err := s.addrsForDial(ctx, otherPeer) require.NoError(t, err) require.NotZero(t, len(mas)) @@ -110,7 +111,7 @@ func TestDedupAddrsForDial(t *testing.T) { ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour) ctx := context.Background() - mas, err := s.addrsForDial(ctx, otherPeer) + mas, _, err := s.addrsForDial(ctx, otherPeer) require.NoError(t, err) require.Equal(t, 1, len(mas)) @@ -183,7 +184,7 @@ func TestAddrResolution(t *testing.T) { tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - mas, err := s.addrsForDial(tctx, p1) + mas, _, err := s.addrsForDial(tctx, p1) require.NoError(t, err) require.Len(t, mas, 1) @@ -241,7 +242,7 @@ func TestAddrResolutionRecursive(t *testing.T) { tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL) - _, err = s.addrsForDial(tctx, p1) + _, _, err = s.addrsForDial(tctx, p1) require.NoError(t, err) addrs1 := s.Peerstore().Addrs(pi1.ID) @@ -253,7 +254,7 @@ func TestAddrResolutionRecursive(t *testing.T) { require.NoError(t, err) s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL) - _, err = s.addrsForDial(tctx, p2) + _, _, err = s.addrsForDial(tctx, p2) // This never resolves to a good address require.Equal(t, ErrNoGoodAddresses, err) @@ -315,7 +316,7 @@ func TestAddrsForDialFiltering(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s.Peerstore().ClearAddrs(p1) s.Peerstore().AddAddrs(p1, tc.input, peerstore.PermanentAddrTTL) - result, err := s.addrsForDial(ctx, p1) + result, _, err := s.addrsForDial(ctx, p1) require.NoError(t, err) sort.Slice(result, func(i, j int) bool { return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 }) sort.Slice(tc.output, func(i, j int) bool { return bytes.Compare(tc.output[i].Bytes(), tc.output[j].Bytes()) < 0 }) @@ -366,10 +367,10 @@ func TestBlackHoledAddrBlocked(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() conn, err := s.DialPeer(ctx, p) - if conn != nil { - t.Fatalf("expected dial to be blocked") - } - if err != ErrNoGoodAddresses { + require.Nil(t, conn) + var de *DialError + if !errors.As(err, &de) { t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err) } + require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialRefusedBlackHole}) }