diff --git a/config/config.go b/config/config.go index cea8ef1314..0416c18414 100644 --- a/config/config.go +++ b/config/config.go @@ -128,6 +128,13 @@ type Config struct { DialRanker network.DialRanker SwarmOpts []swarm.Option + + DisableAutoNATv2 bool + + UDPBlackHoleFilter *swarm.BlackHoleFilter + CustomUDPBlackHoleFilter bool + IPv6BlackHoleFilter *swarm.BlackHoleFilter + CustomIPv6BlackHoleFilter bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -162,7 +169,10 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa return nil, err } - opts := cfg.SwarmOpts + opts := append(cfg.SwarmOpts, + swarm.WithUDPBlackHoleFilter(cfg.UDPBlackHoleFilter), + swarm.WithIPv6BlackHoleFilter(cfg.IPv6BlackHoleFilter), + ) if cfg.Reporter != nil { opts = append(opts, swarm.WithMetrics(cfg.Reporter)) } @@ -190,6 +200,50 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa return swarm.NewSwarm(pid, cfg.Peerstore, eventBus, opts...) } +func (cfg *Config) makeAutoNATHost() (*blankhost.BlankHost, error) { + autonatPrivKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + if err != nil { + return nil, err + } + ps, err := pstoremem.NewPeerstore() + if err != nil { + return nil, err + } + + // Pull out the pieces of the config that we _actually_ care about. + // Specifically, don't set up things like autorelay, listeners, + // identify, etc. + autoNatCfg := Config{ + Transports: cfg.Transports, + Muxers: cfg.Muxers, + SecurityTransports: cfg.SecurityTransports, + Insecure: cfg.Insecure, + PSK: cfg.PSK, + ConnectionGater: cfg.ConnectionGater, + Reporter: cfg.Reporter, + PeerKey: autonatPrivKey, + Peerstore: ps, + DialRanker: swarm.NoDelayDialRanker, + UDPBlackHoleFilter: cfg.UDPBlackHoleFilter, + IPv6BlackHoleFilter: cfg.IPv6BlackHoleFilter, + SwarmOpts: []swarm.Option{ + // Don't update black hole state for failed autonat dials + swarm.WithReadOnlyBlackHoleDetector(), + }, + } + + dialer, err := autoNatCfg.makeSwarm(eventbus.NewBus(), false) + if err != nil { + return nil, err + } + dialerHost := blankhost.NewBlankHost(dialer) + if err := autoNatCfg.addTransports(dialerHost); err != nil { + dialerHost.Close() + return nil, err + } + return dialerHost, nil +} + func (cfg *Config) addTransports(h host.Host) error { swrm, ok := h.Network().(transport.TransportNetwork) if !ok { @@ -305,6 +359,15 @@ func (cfg *Config) NewNode() (host.Host, error) { rcmgr.MustRegisterWith(cfg.PrometheusRegisterer) } + var autonatv2Dialer *blankhost.BlankHost + if !cfg.DisableAutoNATv2 { + ah, err := cfg.makeAutoNATHost() + if err != nil { + return nil, err + } + autonatv2Dialer = ah + } + h, err := bhost.NewHost(swrm, &bhost.HostOpts{ EventBus: eventBus, ConnManager: cfg.ConnManager, @@ -319,6 +382,8 @@ func (cfg *Config) NewNode() (host.Host, error) { RelayServiceOpts: cfg.RelayServiceOpts, EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, + EnableAutoNATv2: !cfg.DisableAutoNATv2, + AutoNATv2Dialer: autonatv2Dialer, }) if err != nil { swrm.Close() @@ -396,46 +461,15 @@ func (cfg *Config) NewNode() (host.Host, error) { autonat.WithPeerThrottling(cfg.AutoNATConfig.ThrottlePeerLimit)) } if cfg.AutoNATConfig.EnableService { - autonatPrivKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + ah, err := cfg.makeAutoNATHost() if err != nil { - return nil, err - } - ps, err := pstoremem.NewPeerstore() - if err != nil { - return nil, err - } - - // Pull out the pieces of the config that we _actually_ care about. - // Specifically, don't set up things like autorelay, listeners, - // identify, etc. - autoNatCfg := Config{ - Transports: cfg.Transports, - Muxers: cfg.Muxers, - SecurityTransports: cfg.SecurityTransports, - Insecure: cfg.Insecure, - PSK: cfg.PSK, - ConnectionGater: cfg.ConnectionGater, - Reporter: cfg.Reporter, - PeerKey: autonatPrivKey, - Peerstore: ps, - DialRanker: swarm.NoDelayDialRanker, - } - - dialer, err := autoNatCfg.makeSwarm(eventbus.NewBus(), false) - if err != nil { - h.Close() - return nil, err - } - dialerHost := blankhost.NewBlankHost(dialer) - if err := autoNatCfg.addTransports(dialerHost); err != nil { - dialerHost.Close() h.Close() return nil, err } // NOTE: We're dropping the blank host here but that's fine. It // doesn't really _do_ anything and doesn't even need to be // closed (as long as we close the underlying network). - autonatOpts = append(autonatOpts, autonat.EnableService(dialerHost.Network())) + autonatOpts = append(autonatOpts, autonat.EnableService(ah.Network())) } if cfg.AutoNATConfig.ForceReachability != nil { autonatOpts = append(autonatOpts, autonat.WithReachability(*cfg.AutoNATConfig.ForceReachability)) diff --git a/core/network/network.go b/core/network/network.go index 66b0a1cd34..9928f52641 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -185,6 +185,9 @@ type Dialer interface { // Notify/StopNotify register and unregister a notifiee for signals Notify(Notifiee) StopNotify(Notifiee) + + // CanDial returns whether the dialer can dial peer p at addr + CanDial(p peer.ID, addr ma.Multiaddr) bool } // AddrDelay provides an address along with the delay after which the address diff --git a/defaults.go b/defaults.go index d11302690c..56212875d4 100644 --- a/defaults.go +++ b/defaults.go @@ -10,6 +10,7 @@ import ( rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/net/connmgr" + "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" quic "github.com/libp2p/go-libp2p/p2p/transport/quic" @@ -133,6 +134,18 @@ var DefaultPrometheusRegisterer = func(cfg *Config) error { return cfg.Apply(PrometheusRegisterer(prometheus.DefaultRegisterer)) } +var defaultUDPBlackHoleDetector = func(cfg *Config) error { + // A black hole is a binary property. On a network if UDP dials are blocked, all dials will + // fail. So a low success rate of 5 out 100 dials is good enough. + return cfg.Apply(UDPBlackHoleFilter(&swarm.BlackHoleFilter{N: 100, MinSuccesses: 5, Name: "UDP"})) +} + +var defaultIPv6BlackHoleDetector = func(cfg *Config) error { + // A black hole is a binary property. On a network if there is no IPv6 connectivity, all + // dials will fail. So a low success rate of 5 out 100 dials is good enough. + return cfg.Apply(IPv6BlackHoleFilter(&swarm.BlackHoleFilter{N: 100, MinSuccesses: 5, Name: "IPv6"})) +} + // Complete list of default options and when to fallback on them. // // Please *DON'T* specify default options any other way. Putting this all here @@ -189,6 +202,18 @@ var defaults = []struct { fallback: func(cfg *Config) bool { return !cfg.DisableMetrics && cfg.PrometheusRegisterer == nil }, opt: DefaultPrometheusRegisterer, }, + { + fallback: func(cfg *Config) bool { + return !cfg.CustomUDPBlackHoleFilter && cfg.UDPBlackHoleFilter == nil + }, + opt: defaultUDPBlackHoleDetector, + }, + { + fallback: func(cfg *Config) bool { + return !cfg.CustomIPv6BlackHoleFilter && cfg.IPv6BlackHoleFilter == nil + }, + opt: defaultIPv6BlackHoleDetector, + }, } // Defaults configures libp2p to use the default options. Can be combined with diff --git a/options.go b/options.go index beb4930f7c..ac9e21af4b 100644 --- a/options.go +++ b/options.go @@ -597,3 +597,29 @@ func SwarmOpts(opts ...swarm.Option) Option { return nil } } + +// DisableAutoNATv2 disables autonat +func DisableAutoNATv2() Option { + return func(cfg *Config) error { + cfg.DisableAutoNATv2 = true + return nil + } +} + +// UDPBlackHoleFilter configures libp2p to use f as the black hole filter for UDP addrs +func UDPBlackHoleFilter(f *swarm.BlackHoleFilter) Option { + return func(cfg *Config) error { + cfg.UDPBlackHoleFilter = f + cfg.CustomUDPBlackHoleFilter = true + return nil + } +} + +// IPv6BlackHoleFilter configures libp2p to use f as the black hole filter for IPv6 addrs +func IPv6BlackHoleFilter(f *swarm.BlackHoleFilter) Option { + return func(cfg *Config) error { + cfg.IPv6BlackHoleFilter = f + cfg.CustomIPv6BlackHoleFilter = true + return nil + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 71ad396768..098080bd6e 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -23,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -101,6 +102,8 @@ type BasicHost struct { caBook peerstore.CertifiedAddrBook autoNat autonat.AutoNAT + + autonatv2 *autonatv2.AutoNAT } var _ host.Host = (*BasicHost)(nil) @@ -160,6 +163,9 @@ type HostOpts struct { EnableMetrics bool // PrometheusRegisterer is the PrometheusRegisterer used for metrics PrometheusRegisterer prometheus.Registerer + + EnableAutoNATv2 bool + AutoNATv2Dialer host.Host } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. @@ -301,6 +307,13 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.pings = ping.NewPingService(h) } + if opts.EnableAutoNATv2 { + h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + } + n.SetStreamHandler(h.newStreamHandler) // register to be notified when the network's listen addrs change, @@ -1025,6 +1038,9 @@ func (h *BasicHost) Close() error { if h.hps != nil { h.hps.Close() } + if h.autonatv2 != nil { + h.autonatv2.Close() + } _ = h.emitters.evtLocalProtocolsUpdated.Close() _ = h.emitters.evtLocalAddrsUpdated.Close() diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 24304498b0..bf00ecd565 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -63,9 +63,10 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost { } bh := &BlankHost{ - n: n, - cmgr: cfg.cmgr, - mux: mstream.NewMultistreamMuxer[protocol.ID](), + n: n, + cmgr: cfg.cmgr, + mux: mstream.NewMultistreamMuxer[protocol.ID](), + eventbus: cfg.eventBus, } if bh.eventbus == nil { bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer())) diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 2e56b7f2bb..0b525d3e64 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -434,3 +434,7 @@ func (pn *peernet) notifyAll(notification func(f network.Notifiee)) { func (pn *peernet) ResourceManager() network.ResourceManager { return &network.NullResourceManager{} } + +func (pn *peernet) CanDial(p peer.ID, addr ma.Multiaddr) bool { + return true +} diff --git a/p2p/net/swarm/black_hole_detector.go b/p2p/net/swarm/black_hole_detector.go index dd7849eea6..37b37fc715 100644 --- a/p2p/net/swarm/black_hole_detector.go +++ b/p2p/net/swarm/black_hole_detector.go @@ -29,35 +29,25 @@ func (st blackHoleState) String() string { } } -type blackHoleResult int - -const ( - blackHoleResultAllowed blackHoleResult = iota - blackHoleResultProbing - blackHoleResultBlocked -) - -// blackHoleFilter provides black hole filtering for dials. This filter should be used in -// concert with a UDP of IPv6 address filter to detect UDP or IPv6 black hole. In a black -// holed environments dial requests are blocked and only periodic probes to check the -// state of the black hole are allowed. -// -// Requests are blocked if the number of successes in the last n dials is less than -// minSuccesses. If a request succeeds in Blocked state, the filter state is reset and n -// subsequent requests are allowed before reevaluating black hole state. Dials cancelled -// when some other concurrent dial succeeded are counted as failures. A sufficiently large -// n prevents false negatives in such cases. -type blackHoleFilter struct { - // n serves the dual purpose of being the minimum number of requests after which we - // probe the state of the black hole in blocked state and the minimum number of - // completed dials required before evaluating black hole state. - n int - // minSuccesses is the minimum number of Success required in the last n dials +// BlackHoleFilter provides black hole filtering for dials. This filter should be used in concert +// with a UDP or IPv6 address filter to detect UDP or IPv6 black hole. In a black holed environment, +// dial requests are refused Requests are blocked if the number of successes in the last N dials is +// less than MinSuccesses. +// If a request succeeds in Blocked state, the filter state is reset and N subsequent requests are +// allowed before reevaluating black hole state. Dials cancelled when some other concurrent dial +// succeeded are counted as failures. A sufficiently large N prevents false negatives in such cases. +type BlackHoleFilter struct { + // N serves the dual purpose of being the minimum number of completed dials required before + // evaluating black hole state and the minimum number of requests after which we probe the + // state of the black hole in blocked state + N int + // MinSuccesses is the minimum number of Success required in the last n dials // to consider we are not blocked. - minSuccesses int - // name for the detector. - name string + MinSuccesses int + // Name for the detector. + Name string + mu sync.Mutex // requests counts number of dial requests to peers. We handle request at a peer // level and record results at individual address dial level. requests int @@ -67,22 +57,19 @@ type blackHoleFilter struct { successes int // state is the current state of the detector state blackHoleState - - mu sync.Mutex - metricsTracer MetricsTracer } -// RecordResult records the outcome of a dial. A successful dial will change the state -// of the filter to Allowed. A failed dial only blocks subsequent requests if the success +// RecordResult records the outcome of a dial. A successful dial in Blocked state will change the +// state of the filter to Probing. A failed dial only blocks subsequent requests if the success // fraction over the last n outcomes is less than the minSuccessFraction of the filter. -func (b *blackHoleFilter) RecordResult(success bool) { +func (b *BlackHoleFilter) RecordResult(success bool) { b.mu.Lock() defer b.mu.Unlock() if b.state == blackHoleStateBlocked && success { // If the call succeeds in a blocked state we reset to allowed. // This is better than slowly accumulating values till we cross the minSuccessFraction - // threshold since a blackhole is a binary property. + // threshold since a black hole is a binary property. b.reset() return } @@ -92,7 +79,7 @@ func (b *blackHoleFilter) RecordResult(success bool) { } b.dialResults = append(b.dialResults, success) - if len(b.dialResults) > b.n { + if len(b.dialResults) > b.N { if b.dialResults[0] { b.successes-- } @@ -100,58 +87,68 @@ func (b *blackHoleFilter) RecordResult(success bool) { } b.updateState() - b.trackMetrics() } // HandleRequest returns the result of applying the black hole filter for the request. -func (b *blackHoleFilter) HandleRequest() blackHoleResult { +func (b *BlackHoleFilter) HandleRequest() blackHoleState { b.mu.Lock() defer b.mu.Unlock() b.requests++ - b.trackMetrics() - if b.state == blackHoleStateAllowed { - return blackHoleResultAllowed - } else if b.state == blackHoleStateProbing || b.requests%b.n == 0 { - return blackHoleResultProbing + return blackHoleStateAllowed + } else if b.state == blackHoleStateProbing || b.requests%b.N == 0 { + return blackHoleStateProbing } else { - return blackHoleResultBlocked + return blackHoleStateBlocked } } -func (b *blackHoleFilter) reset() { +func (b *BlackHoleFilter) reset() { b.successes = 0 b.dialResults = b.dialResults[:0] b.requests = 0 b.updateState() } -func (b *blackHoleFilter) updateState() { +func (b *BlackHoleFilter) updateState() { st := b.state - if len(b.dialResults) < b.n { + if len(b.dialResults) < b.N { b.state = blackHoleStateProbing - } else if b.successes >= b.minSuccesses { + } else if b.successes >= b.MinSuccesses { b.state = blackHoleStateAllowed } else { b.state = blackHoleStateBlocked } if st != b.state { - log.Debugf("%s blackHoleDetector state changed from %s to %s", b.name, st, b.state) + log.Debugf("%s blackHoleDetector state changed from %s to %s", b.Name, st, b.state) } } -func (b *blackHoleFilter) trackMetrics() { - if b.metricsTracer == nil { - return - } +func (b *BlackHoleFilter) State() blackHoleState { + b.mu.Lock() + defer b.mu.Unlock() + + return b.state +} - nextRequestAllowedAfter := 0 +type blackHoleInfo struct { + name string + state blackHoleState + nextProbeAfter int + successFraction float64 +} + +func (b *BlackHoleFilter) info() blackHoleInfo { + b.mu.Lock() + defer b.mu.Unlock() + + nextProbeAfter := 0 if b.state == blackHoleStateBlocked { - nextRequestAllowedAfter = b.n - (b.requests % b.n) + nextProbeAfter = b.N - (b.requests % b.N) } successFraction := 0.0 @@ -159,22 +156,27 @@ func (b *blackHoleFilter) trackMetrics() { successFraction = float64(b.successes) / float64(len(b.dialResults)) } - b.metricsTracer.UpdatedBlackHoleFilterState( - b.name, - b.state, - nextRequestAllowedAfter, - successFraction, - ) + return blackHoleInfo{ + name: b.Name, + state: b.state, + nextProbeAfter: nextProbeAfter, + successFraction: successFraction, + } } -// blackHoleDetector provides UDP and IPv6 black hole detection using a `blackHoleFilter` -// for each. For details of the black hole detection logic see `blackHoleFilter`. +// blackHoleDetector provides UDP and IPv6 black hole detection using a `blackHoleFilter` for each. +// For details of the black hole detection logic see `blackHoleFilter`. +// In Read Only mode, detector doesn't update the state of underlying filters and refuses requests +// when black hole state is unknown. This is useful for Swarms made specifically for services like +// AutoNAT where we care about accurately reporting the reachability of a peer. // -// black hole filtering is done at a peer dial level to ensure that periodic probes to -// detect change of the black hole state are actually dialed and are not skipped -// because of dial prioritisation logic. +// Black hole filtering is done at a peer dial level to ensure that periodic probes to detect change +// of the black hole state are actually dialed and are not skipped because of dial prioritisation +// logic. type blackHoleDetector struct { - udp, ipv6 *blackHoleFilter + udp, ipv6 *BlackHoleFilter + mt MetricsTracer + readOnly bool } // FilterAddrs filters the peer's addresses removing black holed addresses @@ -192,14 +194,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia } } - udpRes := blackHoleResultAllowed + udpRes := blackHoleStateAllowed if d.udp != nil && hasUDP { - udpRes = d.udp.HandleRequest() + udpRes = d.getFilterState(d.udp) + d.trackMetrics(d.udp) } - ipv6Res := blackHoleResultAllowed + ipv6Res := blackHoleStateAllowed if d.ipv6 != nil && hasIPv6 { - ipv6Res = d.ipv6.HandleRequest() + ipv6Res = d.getFilterState(d.ipv6) + d.trackMetrics(d.ipv6) } blackHoled = make([]ma.Multiaddr, 0, len(addrs)) @@ -210,19 +214,19 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia return true } // allow all UDP addresses while probing irrespective of IPv6 black hole state - if udpRes == blackHoleResultProbing && isProtocolAddr(a, ma.P_UDP) { + if udpRes == blackHoleStateProbing && isProtocolAddr(a, ma.P_UDP) { return true } // allow all IPv6 addresses while probing irrespective of UDP black hole state - if ipv6Res == blackHoleResultProbing && isProtocolAddr(a, ma.P_IP6) { + if ipv6Res == blackHoleStateProbing && isProtocolAddr(a, ma.P_IP6) { return true } - if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) { + if udpRes == blackHoleStateBlocked && isProtocolAddr(a, ma.P_UDP) { blackHoled = append(blackHoled, a) return false } - if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) { + if ipv6Res == blackHoleStateBlocked && isProtocolAddr(a, ma.P_IP6) { blackHoled = append(blackHoled, a) return false } @@ -231,49 +235,36 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multia ), blackHoled } -// RecordResult updates the state of the relevant `blackHoleFilter`s for addr +// RecordResult updates the state of the relevant blackHoleFilters for addr func (d *blackHoleDetector) RecordResult(addr ma.Multiaddr, success bool) { - if !manet.IsPublicAddr(addr) { + if d.readOnly || !manet.IsPublicAddr(addr) { return } if d.udp != nil && isProtocolAddr(addr, ma.P_UDP) { d.udp.RecordResult(success) + d.trackMetrics(d.udp) } if d.ipv6 != nil && isProtocolAddr(addr, ma.P_IP6) { d.ipv6.RecordResult(success) + d.trackMetrics(d.ipv6) } } -// blackHoleConfig is the config used for black hole detection -type blackHoleConfig struct { - // Enabled enables black hole detection - Enabled bool - // N is the size of the sliding window used to evaluate black hole state - N int - // MinSuccesses is the minimum number of successes out of N required to not - // block requests - MinSuccesses int -} - -func newBlackHoleDetector(udpConfig, ipv6Config blackHoleConfig, mt MetricsTracer) *blackHoleDetector { - d := &blackHoleDetector{} - - if udpConfig.Enabled { - d.udp = &blackHoleFilter{ - n: udpConfig.N, - minSuccesses: udpConfig.MinSuccesses, - name: "UDP", - metricsTracer: mt, +func (d *blackHoleDetector) getFilterState(f *BlackHoleFilter) blackHoleState { + if d.readOnly { + if f.State() != blackHoleStateAllowed { + return blackHoleStateBlocked } + return blackHoleStateAllowed } + return f.HandleRequest() +} - if ipv6Config.Enabled { - d.ipv6 = &blackHoleFilter{ - n: ipv6Config.N, - minSuccesses: ipv6Config.MinSuccesses, - name: "IPv6", - metricsTracer: mt, - } +func (d *blackHoleDetector) trackMetrics(f *BlackHoleFilter) { + if d.readOnly || d.mt == nil { + return } - return d + // Track metrics only in non readOnly state + info := f.info() + d.mt.UpdatedBlackHoleFilterState(info.name, info.state, info.nextProbeAfter, info.successFraction) } diff --git a/p2p/net/swarm/black_hole_detector_test.go b/p2p/net/swarm/black_hole_detector_test.go index dfbb30f90d..667f0b0881 100644 --- a/p2p/net/swarm/black_hole_detector_test.go +++ b/p2p/net/swarm/black_hole_detector_test.go @@ -10,36 +10,48 @@ import ( func TestBlackHoleFilterReset(t *testing.T) { n := 10 - bhf := &blackHoleFilter{n: n, minSuccesses: 2, name: "test"} + bhf := &BlackHoleFilter{N: n, MinSuccesses: 2, Name: "test"} var i = 0 // calls up to n should be probing for i = 1; i <= n; i++ { - if bhf.HandleRequest() != blackHoleResultProbing { + if bhf.HandleRequest() != blackHoleStateProbing { t.Fatalf("expected calls up to n to be probes") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } // after threshold calls every nth call should be a probe for i = n + 1; i < 42; i++ { result := bhf.HandleRequest() - if (i%n == 0 && result != blackHoleResultProbing) || (i%n != 0 && result != blackHoleResultBlocked) { + if (i%n == 0 && result != blackHoleStateProbing) || (i%n != 0 && result != blackHoleStateBlocked) { t.Fatalf("expected every nth dial to be a probe") } + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } bhf.RecordResult(true) // check if calls up to n are probes again for i = 0; i < n; i++ { - if bhf.HandleRequest() != blackHoleResultProbing { + if bhf.HandleRequest() != blackHoleStateProbing { t.Fatalf("expected black hole detector state to reset after success") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } // next call should be blocked - if bhf.HandleRequest() != blackHoleResultBlocked { + if bhf.HandleRequest() != blackHoleStateBlocked { t.Fatalf("expected dial to be blocked") + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } } @@ -47,19 +59,19 @@ func TestBlackHoleFilterSuccessFraction(t *testing.T) { n := 10 tests := []struct { minSuccesses, successes int - result blackHoleResult + result blackHoleState }{ - {minSuccesses: 5, successes: 5, result: blackHoleResultAllowed}, - {minSuccesses: 3, successes: 3, result: blackHoleResultAllowed}, - {minSuccesses: 5, successes: 4, result: blackHoleResultBlocked}, - {minSuccesses: 5, successes: 7, result: blackHoleResultAllowed}, - {minSuccesses: 3, successes: 1, result: blackHoleResultBlocked}, - {minSuccesses: 0, successes: 0, result: blackHoleResultAllowed}, - {minSuccesses: 10, successes: 10, result: blackHoleResultAllowed}, + {minSuccesses: 5, successes: 5, result: blackHoleStateAllowed}, + {minSuccesses: 3, successes: 3, result: blackHoleStateAllowed}, + {minSuccesses: 5, successes: 4, result: blackHoleStateBlocked}, + {minSuccesses: 5, successes: 7, result: blackHoleStateAllowed}, + {minSuccesses: 3, successes: 1, result: blackHoleStateBlocked}, + {minSuccesses: 0, successes: 0, result: blackHoleStateAllowed}, + {minSuccesses: 10, successes: 10, result: blackHoleStateAllowed}, } for i, tc := range tests { t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) { - bhf := blackHoleFilter{n: n, minSuccesses: tc.minSuccesses} + bhf := BlackHoleFilter{N: n, MinSuccesses: tc.minSuccesses} for i := 0; i < tc.successes; i++ { bhf.RecordResult(true) } @@ -75,9 +87,9 @@ func TestBlackHoleFilterSuccessFraction(t *testing.T) { } func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { - udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - ipv6Config := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(udpConfig, ipv6Config, nil) + udpF := &BlackHoleFilter{N: 10, MinSuccesses: 5} + ipv6F := &BlackHoleFilter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF, ipv6: ipv6F} addrs := []ma.Multiaddr{ ma.StringCast("/ip4/1.2.3.4/tcp/1234"), ma.StringCast("/ip4/1.2.3.4/tcp/1233"), @@ -94,8 +106,8 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { } func TestBlackHoleDetectorUDPDisabled(t *testing.T) { - ipv6Config := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(blackHoleConfig{Enabled: false}, ipv6Config, nil) + ipv6F := &BlackHoleFilter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{ipv6: ipv6F} publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") privAddr := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1") for i := 0; i < 100; i++ { @@ -110,8 +122,8 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) { } func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { - udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5} - bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil) + udpF := &BlackHoleFilter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF} publicAddr := ma.StringCast("/ip6/1::1/tcp/1234") privAddr := ma.StringCast("/ip6/::1/tcp/1234") for i := 0; i < 100; i++ { @@ -128,8 +140,8 @@ func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { func TestBlackHoleDetectorProbes(t *testing.T) { bhd := &blackHoleDetector{ - udp: &blackHoleFilter{n: 2, minSuccesses: 1, name: "udp"}, - ipv6: &blackHoleFilter{n: 3, minSuccesses: 1, name: "ipv6"}, + udp: &BlackHoleFilter{N: 2, MinSuccesses: 1, Name: "udp"}, + ipv6: &BlackHoleFilter{N: 3, MinSuccesses: 1, Name: "ipv6"}, } udp6Addr := ma.StringCast("/ip6/1::1/udp/1234/quic-v1") addrs := []ma.Multiaddr{udp6Addr} @@ -163,8 +175,8 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { makeBHD := func(udpBlocked, ipv6Blocked bool) *blackHoleDetector { bhd := &blackHoleDetector{ - udp: &blackHoleFilter{n: 100, minSuccesses: 10, name: "udp"}, - ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"}, + udp: &BlackHoleFilter{N: 100, MinSuccesses: 10, Name: "udp"}, + ipv6: &BlackHoleFilter{N: 100, MinSuccesses: 10, Name: "ipv6"}, } for i := 0; i < 100; i++ { bhd.RecordResult(udp4Pub, !udpBlocked) @@ -199,3 +211,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { require.ElementsMatch(t, bothBlockedOutput, gotAddrs) require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs) } + +func TestBlackHoleDetectorReadOnlyMode(t *testing.T) { + udpF := &BlackHoleFilter{N: 10, MinSuccesses: 5} + ipv6F := &BlackHoleFilter{N: 10, MinSuccesses: 5} + bhd := &blackHoleDetector{udp: udpF, ipv6: ipv6F, readOnly: true} + publicAddr := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + privAddr := ma.StringCast("/ip6/::1/tcp/1234") + for i := 0; i < 100; i++ { + bhd.RecordResult(publicAddr, true) + } + allAddr := []ma.Multiaddr{privAddr, publicAddr} + // public addr filtered because state is probing + wantAddrs := []ma.Multiaddr{privAddr} + wantRemovedAddrs := []ma.Multiaddr{publicAddr} + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allAddr) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) + + // a non readonly shared state black hole detector + nbhd := &blackHoleDetector{udp: bhd.udp, ipv6: bhd.ipv6, readOnly: false} + for i := 0; i < 100; i++ { + nbhd.RecordResult(publicAddr, true) + } + // no addresses filtered because state is allowed + wantAddrs = []ma.Multiaddr{privAddr, publicAddr} + wantRemovedAddrs = []ma.Multiaddr{} + + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allAddr) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) +} diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 5155cd2228..483ae074d1 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -111,22 +111,33 @@ func WithDialRanker(d network.DialRanker) Option { } } -// WithUDPBlackHoleConfig configures swarm to use c as the config for UDP black hole detection +// WithUDPBlackHoleFilter configures swarm to use the provided config for UDP black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests -func WithUDPBlackHoleConfig(enabled bool, n, min int) Option { +func WithUDPBlackHoleFilter(f *BlackHoleFilter) Option { return func(s *Swarm) error { - s.udpBlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min} + s.udpBHF = f return nil } } -// WithIPv6BlackHoleConfig configures swarm to use c as the config for IPv6 black hole detection +// WithIPv6BlackHoleFilter configures swarm to use the provided config for IPv6 black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests -func WithIPv6BlackHoleConfig(enabled bool, n, min int) Option { +func WithIPv6BlackHoleFilter(f *BlackHoleFilter) Option { return func(s *Swarm) error { - s.ipv6BlackHoleConfig = blackHoleConfig{Enabled: enabled, N: n, MinSuccesses: min} + s.ipv6BHF = f + return nil + } +} + +// WithReadOnlyBlackHoleDetector configures the swarm to use the black hole detector in +// read only mode. In Read Only mode dial requests are refused in unknown state and +// no updates to the detector state are made. This is useful for services like AutoNAT that +// care about accurately providing reachability info. +func WithReadOnlyBlackHoleDetector() Option { + return func(s *Swarm) error { + s.readOnlyBHD = true return nil } } @@ -197,9 +208,10 @@ type Swarm struct { dialRanker network.DialRanker - udpBlackHoleConfig blackHoleConfig - ipv6BlackHoleConfig blackHoleConfig - bhd *blackHoleDetector + udpBHF *BlackHoleFilter + ipv6BHF *BlackHoleFilter + bhd *blackHoleDetector + readOnlyBHD bool } // NewSwarm constructs a Swarm. @@ -223,8 +235,8 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts // A black hole is a binary property. On a network if UDP dials are blocked or there is // no IPv6 connectivity, all dials will fail. So a low success rate of 5 out 100 dials // is good enough. - udpBlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5}, - ipv6BlackHoleConfig: blackHoleConfig{Enabled: true, N: 100, MinSuccesses: 5}, + udpBHF: &BlackHoleFilter{N: 100, MinSuccesses: 5, Name: "UDP"}, + ipv6BHF: &BlackHoleFilter{N: 100, MinSuccesses: 5, Name: "IPv6"}, } s.conns.m = make(map[peer.ID][]*Conn) @@ -246,8 +258,12 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts s.limiter = newDialLimiter(s.dialAddr) s.backf.init(s.ctx) - s.bhd = newBlackHoleDetector(s.udpBlackHoleConfig, s.ipv6BlackHoleConfig, s.metricsTracer) - + s.bhd = &blackHoleDetector{ + udp: s.udpBHF, + ipv6: s.ipv6BHF, + mt: s.metricsTracer, + readOnly: s.readOnlyBHD, + } return s, nil } diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 1e5638f379..451505bd25 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -417,6 +417,11 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, return nil } +func (s *Swarm) CanDial(p peer.ID, addr ma.Multiaddr) bool { + dialable, _ := s.filterKnownUndialables(p, []ma.Multiaddr{addr}) + return len(dialable) > 0 +} + func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 47310978fe..989d6508fb 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -341,7 +341,7 @@ func TestBlackHoledAddrBlocked(t *testing.T) { defer s.Close() n := 3 - s.bhd.ipv6 = &blackHoleFilter{n: n, minSuccesses: 1, name: "IPv6"} + s.bhd.ipv6 = &BlackHoleFilter{N: n, MinSuccesses: 1, Name: "IPv6"} // all dials to the address will fail. RFC6666 Discard Prefix addr := ma.StringCast("/ip6/0100::1/tcp/54321/") diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go new file mode 100644 index 0000000000..26811d0943 --- /dev/null +++ b/p2p/protocol/autonatv2/autonat.go @@ -0,0 +1,250 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "golang.org/x/exp/rand" + "golang.org/x/exp/slices" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +const ( + ServiceName = "libp2p.autonatv2" + DialBackProtocol = "/libp2p/autonat/2/dial-back" + DialProtocol = "/libp2p/autonat/2/dial-request" + + maxMsgSize = 8192 + streamTimeout = time.Minute + dialBackStreamTimeout = 5 * time.Second + dialBackDialTimeout = 30 * time.Second + minHandshakeSizeBytes = 30_000 // for amplification attack prevention + maxHandshakeSizeBytes = 100_000 + // maxPeerAddresses is the number of addresses in a dial request the server + // will inspect, rest are ignored. + maxPeerAddresses = 50 +) + +var ( + ErrNoValidPeers = errors.New("no valid peers for autonat v2") + ErrDialRefused = errors.New("dial refused") + + log = logging.Logger("autonatv2") +) + +// Request is the request to verify reachability of a single address +type Request struct { + // Addr is the multiaddr to verify + Addr ma.Multiaddr + // SendDialData indicates whether to send dial data if the server requests it for Addr + SendDialData bool +} + +// Result is the result of the CheckReachability call +type Result struct { + // Idx is the index of the dialed address + Idx int + // Addr is the dialed address + Addr ma.Multiaddr + // Reachability of the dialed address + Reachability network.Reachability + // Status is the outcome of the dialback + Status pb.DialStatus +} + +// AutoNAT implements the AutoNAT v2 client and server. +// Users can check reachability for their addresses using the CheckReachability method. +// The server provides amplification attack prevention and rate limiting. +type AutoNAT struct { + host host.Host + sub event.Subscription + + // for cleanly closing + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + srv *server + cli *client + + mx sync.Mutex + peers *peersMap + + allowAllAddrs bool // for testing +} + +// New returns a new AutoNAT instance. The returned instance runs the server when the provided host +// is publicly reachable. +// host and dialerHost should have the same dialing capabilities. In case the host doesn't support +// a transport, dial back requests for address for that transport will be ignored. +func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { + s := defaultSettings() + for _, o := range opts { + if err := o(s); err != nil { + return nil, fmt.Errorf("failed to apply option: %w", err) + } + } + // We are listening on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged + // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. + // + // We listen on event.EvtLocalReachabilityChanged to Disable the server if we are not + // publicly reachable. Currently this event is sent by the AutoNAT v1 module. During the + // transition period from AutoNAT v1 to v2, there won't be enough v2 servers on the network + // and most clients will be unable to discover a peer which supports AutoNAT v2. So, we use + // v1 to determine reachability for the transition period. + // + // Once there are enough v2 servers on the network for nodes to determine their reachability + // using AutoNAT v2, we'll use Address Pipeline + // (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a future release) + // to determine reachability using v2 client and send this event from Address Pipeline, if + // we are publicly reachable. + sub, err := host.EventBus().Subscribe([]interface{}{ + new(event.EvtLocalReachabilityChanged), + new(event.EvtPeerProtocolsUpdated), + new(event.EvtPeerConnectednessChanged), + new(event.EvtPeerIdentificationCompleted), + }) + if err != nil { + return nil, fmt.Errorf("event subscription: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + an := &AutoNAT{ + host: host, + ctx: ctx, + cancel: cancel, + sub: sub, + srv: newServer(host, dialerHost, s), + cli: newClient(host), + allowAllAddrs: s.allowAllAddrs, + peers: newPeersMap(), + } + an.cli.RegisterDialBack() + + an.wg.Add(1) + go an.background() + return an, nil +} + +func (an *AutoNAT) background() { + for { + select { + case <-an.ctx.Done(): + an.srv.Disable() + an.srv.Close() + an.peers = nil + an.wg.Done() + return + case e := <-an.sub.Out(): + switch evt := e.(type) { + case event.EvtLocalReachabilityChanged: + if evt.Reachability == network.ReachabilityPrivate { + an.srv.Disable() + } else { + an.srv.Enable() + } + case event.EvtPeerProtocolsUpdated: + an.updatePeer(evt.Peer) + case event.EvtPeerConnectednessChanged: + an.updatePeer(evt.Peer) + case event.EvtPeerIdentificationCompleted: + an.updatePeer(evt.Peer) + } + } + } +} + +func (an *AutoNAT) Close() { + an.cancel() + an.wg.Wait() +} + +// CheckReachability makes a single dial request for checking reachability for requested addresses +func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Result, error) { + if !an.allowAllAddrs { + for _, r := range reqs { + if !manet.IsPublicAddr(r.Addr) { + return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr) + } + } + } + p := an.peers.GetRand() + if p == "" { + return Result{}, ErrNoValidPeers + } + + res, err := an.cli.CheckReachability(ctx, p, reqs) + if err != nil { + log.Debugf("reachability check with %s failed, err: %s", p, err) + return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) + } + log.Debugf("reachability check with %s successful", p) + return res, nil +} + +func (an *AutoNAT) updatePeer(p peer.ID) { + an.mx.Lock() + defer an.mx.Unlock() + + // There are no ordering gurantees between identify and swarm events. Check peerstore + // and swarm for the current state + protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) + connectedness := an.host.Network().Connectedness(p) + if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected { + an.peers.Put(p) + } else { + an.peers.Delete(p) + } +} + +// peersMap provides random access to a set of peers. This is useful when the map iteration order is +// not sufficiently random. +type peersMap struct { + peerIdx map[peer.ID]int + peers []peer.ID +} + +func newPeersMap() *peersMap { + return &peersMap{ + peerIdx: make(map[peer.ID]int), + peers: make([]peer.ID, 0), + } +} + +func (p *peersMap) GetRand() peer.ID { + if len(p.peers) == 0 { + return "" + } + return p.peers[rand.Intn(len(p.peers))] +} + +func (p *peersMap) Put(pid peer.ID) { + if _, ok := p.peerIdx[pid]; ok { + return + } + p.peers = append(p.peers, pid) + p.peerIdx[pid] = len(p.peers) - 1 +} + +func (p *peersMap) Delete(pid peer.ID) { + idx, ok := p.peerIdx[pid] + if !ok { + return + } + p.peers[idx] = p.peers[len(p.peers)-1] + p.peerIdx[p.peers[idx]] = idx + p.peers = p.peers[:len(p.peers)-1] + delete(p.peerIdx, pid) +} diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go new file mode 100644 index 0000000000..fa0b236f90 --- /dev/null +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -0,0 +1,644 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT { + t.Helper() + b := eventbus.NewBus() + h := bhost.NewBlankHost( + swarmt.GenSwarm(t, swarmt.EventBus(b)), bhost.WithEventBus(b)) + if dialer == nil { + dialer = bhost.NewBlankHost( + swarmt.GenSwarm(t, + swarmt.WithSwarmOpts( + swarm.WithUDPBlackHoleFilter(nil), + swarm.WithIPv6BlackHoleFilter(nil)))) + } + an, err := New(h, dialer, opts...) + if err != nil { + t.Error(err) + } + an.srv.Enable() + an.cli.RegisterDialBack() + return an +} + +func parseAddrs(t *testing.T, msg *pb.Message) []ma.Multiaddr { + t.Helper() + req := msg.GetDialRequest() + addrs := make([]ma.Multiaddr, 0) + for _, ab := range req.Addrs { + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + t.Error("invalid addr bytes", ab) + } + addrs = append(addrs, a) + } + return addrs +} + +// idAndConnect identifies b to a and connects them +func idAndConnect(t *testing.T, a, b host.Host) { + a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL) + a.Peerstore().AddProtocols(b.ID(), DialProtocol) + + err := a.Connect(context.Background(), peer.AddrInfo{ID: b.ID()}) + require.NoError(t, err) +} + +// waitForPeer waits for a to have 1 peer in the peerMap +func waitForPeer(t *testing.T, a *AutoNAT) { + t.Helper() + require.Eventually(t, func() bool { + a.mx.Lock() + defer a.mx.Unlock() + return a.peers.GetRand() != "" + }, 5*time.Second, 100*time.Millisecond) +} + +// idAndWait provides server address and protocol to client +func idAndWait(t *testing.T, cli *AutoNAT, srv *AutoNAT) { + idAndConnect(t, cli.host, srv.host) + waitForPeer(t, cli) +} + +func TestAutoNATPrivateAddr(t *testing.T) { + an := newAutoNAT(t, nil) + res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) + require.Equal(t, res, Result{}) + require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") +} + +func TestClientRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + addrs := an.host.Addrs() + addrbs := make([][]byte, len(addrs)) + for i := 0; i < len(addrs); i++ { + addrbs[i] = addrs[i].Bytes() + } + + var receivedRequest atomic.Bool + b.SetStreamHandler(DialProtocol, func(s network.Stream) { + receivedRequest.Store(true) + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + assert.NotNil(t, msg.GetDialRequest()) + assert.Equal(t, addrbs, msg.GetDialRequest().Addrs) + s.Reset() + }) + + res, err := an.CheckReachability(context.Background(), []Request{ + {Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.True(t, receivedRequest.Load()) +} + +func TestClientServerError(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + errorStr string + }{ + { + handler: func(s network.Stream) { + s.Reset() + }, + errorStr: "stream reset", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialRequest{DialRequest: &pb.DialRequest{}}})) + }, + errorStr: "invalid msg type", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }}, + )) + }, + errorStr: ErrDialRefused.Error(), + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + res, err := an.CheckReachability( + context.Background(), + newTestRequests(addrs, false)) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.Contains(t, err.Error(), tc.errorStr) + }) + } +} + +func TestClientDataRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + name string + }{ + { + name: "provides dial data", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + var dialData []byte + for len(dialData) < 10000 { + if err := r.ReadMsg(&msg); err != nil { + t.Error(err) + s.Reset() + return + } + if msg.GetDialDataResponse() == nil { + t.Errorf("expected to receive msg of type DialDataResponse") + s.Reset() + return + } + dialData = append(dialData, msg.GetDialDataResponse().Data...) + } + s.Reset() + }, + }, + { + name: "low priority addr", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 1, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + { + name: "too high dial data request", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 1 << 32, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + + res, err := an.CheckReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + }) + } +} + +func TestClientDialBacks(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + + readReq := func(r pbio.Reader) ([]ma.Multiaddr, uint64, error) { + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + return nil, 0, err + } + if msg.GetDialRequest() == nil { + return nil, 0, errors.New("no dial request in msg") + } + addrs := parseAddrs(t, &msg) + return addrs, msg.GetDialRequest().GetNonce(), nil + } + + writeNonce := func(addr ma.Multiaddr, nonce uint64) error { + pid := an.host.ID() + dialerHost.Peerstore().AddAddr(pid, addr, peerstore.PermanentAddrTTL) + defer func() { + dialerHost.Network().ClosePeer(pid) + dialerHost.Peerstore().RemovePeer(pid) + dialerHost.Peerstore().ClearAddrs(pid) + }() + as, err := dialerHost.NewStream(context.Background(), pid, DialBackProtocol) + if err != nil { + return err + } + w := pbio.NewDelimitedWriter(as) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + return err + } + as.CloseWrite() + data := make([]byte, 1) + as.Read(data) + as.Close() + return nil + } + + tests := []struct { + name string + handler func(network.Stream) + success bool + }{ + { + name: "correct dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 1, + }, + }, + }) + s.Close() + }, + success: true, + }, + { + name: "no dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + if _, _, err := readReq(r); err != nil { + s.Reset() + t.Error(err) + return + } + resp := &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: resp, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid reported address", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid nonce", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[0], nonce-1); err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid addr index", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + _, _, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 10, + }, + }, + }) + s.Close() + }, + success: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addrs := an.host.Addrs() + b.SetStreamHandler(DialProtocol, tc.handler) + res, err := an.CheckReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + if !tc.success { + require.Error(t, err) + require.Equal(t, Result{}, res) + } else { + require.NoError(t, err) + require.Equal(t, res.Reachability, network.ReachabilityPublic) + require.Equal(t, res.Status, pb.DialStatus_OK) + } + }) + } +} + +func TestEventSubscription(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + c := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer c.Close() + + idAndConnect(t, an.host, b) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + idAndConnect(t, an.host, c) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 2 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(b.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(c.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 0 + }, 5*time.Second, 100*time.Millisecond) +} + +func TestPeersMap(t *testing.T) { + emptyPeerID := peer.ID("") + + t.Run("single_item", func(t *testing.T) { + p := newPeersMap() + p.Put("peer1") + p.Delete("peer1") + p.Put("peer1") + require.Equal(t, peer.ID("peer1"), p.GetRand()) + p.Delete("peer1") + require.Equal(t, emptyPeerID, p.GetRand()) + }) + + t.Run("multiple_items", func(t *testing.T) { + p := newPeersMap() + require.Equal(t, emptyPeerID, p.GetRand()) + + allPeers := make(map[peer.ID]bool) + for i := 0; i < 20; i++ { + pid := peer.ID(fmt.Sprintf("peer-%d", i)) + allPeers[pid] = true + p.Put(pid) + } + foundPeers := make(map[peer.ID]bool) + for i := 0; i < 1000; i++ { + pid := p.GetRand() + require.NotEqual(t, emptyPeerID, p) + require.True(t, allPeers[pid]) + foundPeers[pid] = true + if len(foundPeers) == len(allPeers) { + break + } + } + for pid := range allPeers { + p.Delete(pid) + } + require.Equal(t, emptyPeerID, p.GetRand()) + }) +} + +func TestAreAddrsConsistency(t *testing.T) { + tests := []struct { + name string + localAddr ma.Multiaddr + dialAddr ma.Multiaddr + success bool + }{ + { + name: "simple match", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "nat64 match", + localAddr: ma.StringCast("/ip6/1::1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "simple mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/23232/quic-v1"), + success: false, + }, + { + name: "quic-vs-webtransport", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/webtransport"), + success: false, + }, + { + name: "nat64 mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), + success: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if areAddrsConsistent(tc.localAddr, tc.dialAddr) != tc.success { + wantStr := "match" + if !tc.success { + wantStr = "mismatch" + } + t.Errorf("expected %s between\nlocal addr: %s\ndial addr: %s", wantStr, tc.localAddr, tc.dialAddr) + } + }) + } + +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go new file mode 100644 index 0000000000..e2a7db75cd --- /dev/null +++ b/p2p/protocol/autonatv2/client.go @@ -0,0 +1,286 @@ +package autonatv2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "golang.org/x/exp/rand" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +// client implements the client for making dial requests for AutoNAT v2. It verifies successful +// dials and provides an option to send data for dial requests. +type client struct { + host host.Host + dialData []byte + + mu sync.Mutex + // dialBackQueues maps nonce to the channel for providing the local multiaddr of the connection + // the nonce was received on + dialBackQueues map[uint64]chan ma.Multiaddr +} + +func newClient(h host.Host) *client { + return &client{host: h, dialData: make([]byte, 8000), dialBackQueues: make(map[uint64]chan ma.Multiaddr)} +} + +// RegisterDialBack registers the client to receive DialBack streams initiated by the server to send the nonce. +func (ac *client) RegisterDialBack() { + ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) +} + +// CheckReachability verifies address reachability with a AutoNAT v2 server p. +func (ac *client) CheckReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) { + ctx, cancel := context.WithTimeout(ctx, streamTimeout) + defer cancel() + + s, err := ac.host.NewStream(ctx, p, DialProtocol) + if err != nil { + return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err) + } + + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err) + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err) + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(streamTimeout)) + defer s.Close() + + nonce := rand.Uint64() + ch := make(chan ma.Multiaddr, 1) + ac.mu.Lock() + ac.dialBackQueues[nonce] = ch + ac.mu.Unlock() + defer func() { + ac.mu.Lock() + delete(ac.dialBackQueues, nonce) + ac.mu.Unlock() + }() + + msg := newDialRequest(reqs, nonce) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial request write failed: %w", err) + } + + r := pbio.NewDelimitedReader(s, maxMsgSize) + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial msg read failed: %w", err) + } + + switch { + case msg.GetDialResponse() != nil: + break + // provide dial data if appropriate + case msg.GetDialDataRequest() != nil: + idx := int(msg.GetDialDataRequest().AddrIdx) + if idx >= len(reqs) { // invalid address index + s.Reset() + return Result{}, fmt.Errorf("dial data: addr index out of range: %d [0-%d)", idx, len(reqs)) + } + if msg.GetDialDataRequest().NumBytes > maxHandshakeSizeBytes { // data request is too high + s.Reset() + return Result{}, fmt.Errorf("dial data requested too high: %d", msg.GetDialDataRequest().NumBytes) + } + if !reqs[idx].SendDialData { // low priority addr + s.Reset() + return Result{}, fmt.Errorf("dial data requested for low priority addr: %s index %d", reqs[idx].Addr, idx) + } + + // dial data request is valid and we want to send data + if err := ac.sendDialData(msg.GetDialDataRequest(), w, &msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial data send failed: %w", err) + } + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial response read failed: %w", err) + } + if msg.GetDialResponse() == nil { + s.Reset() + return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg) + } + default: + s.Reset() + return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg) + } + + resp := msg.GetDialResponse() + if resp.GetStatus() != pb.DialResponse_OK { + // E_DIAL_REFUSED has implication for deciding future address verificiation priorities + // wrap a distinct error for convenient errors.Is usage + if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED { + return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused) + } + return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(), + pb.DialStatus_name[int32(resp.GetStatus())]) + } + if resp.GetDialStatus() == pb.DialStatus_UNUSED { + return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED") + } + if int(resp.AddrIdx) >= len(reqs) { + return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs)) + } + + // wait for nonce from the server + var dialBackAddr ma.Multiaddr + if resp.GetDialStatus() == pb.DialStatus_OK { + timer := time.NewTimer(dialBackStreamTimeout) + select { + case at := <-ch: + dialBackAddr = at + case <-ctx.Done(): + case <-timer.C: + } + timer.Stop() + } + return ac.newResult(resp, reqs, dialBackAddr) +} + +func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { + idx := int(resp.AddrIdx) + addr := reqs[idx].Addr + + var rch network.Reachability + switch resp.DialStatus { + case pb.DialStatus_OK: + if !areAddrsConsistent(dialBackAddr, addr) { + // The server reported a successful dial back but we didn't receive the nonce. + // Discard the response and fail. + return Result{}, fmt.Errorf("invalid repsonse: no dialback received") + } + rch = network.ReachabilityPublic + case pb.DialStatus_E_DIAL_ERROR: + rch = network.ReachabilityPrivate + case pb.DialStatus_E_DIAL_BACK_ERROR: + rch = network.ReachabilityUnknown + default: + // Unexpected response code. Discard the response and fail. + log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) + return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus) + } + + return Result{ + Idx: idx, + Addr: addr, + Reachability: rch, + Status: resp.DialStatus, + }, nil +} + +func (ac *client) sendDialData(req *pb.DialDataRequest, w pbio.Writer, msg *pb.Message) error { + nb := req.GetNumBytes() + ddResp := &pb.DialDataResponse{Data: ac.dialData} + *msg = pb.Message{ + Msg: &pb.Message_DialDataResponse{ + DialDataResponse: ddResp, + }, + } + for remain := int(nb); remain > 0; { + end := remain + if end > len(ddResp.Data) { + end = len(ddResp.Data) + } + ddResp.Data = ddResp.Data[:end] + if err := w.WriteMsg(msg); err != nil { + return err + } + remain -= end + } + return nil +} + +func newDialRequest(reqs []Request, nonce uint64) pb.Message { + addrbs := make([][]byte, len(reqs)) + for i, r := range reqs { + addrbs[i] = r.Addr.Bytes() + } + return pb.Message{ + Msg: &pb.Message_DialRequest{ + DialRequest: &pb.DialRequest{ + Addrs: addrbs, + Nonce: nonce, + }, + }, + } +} + +// handleDialBack receives the nonce on the dial-back stream +func (ac *client) handleDialBack(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + s.Reset() + return + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("failed to reserve memory for stream %s: %w", DialBackProtocol, err) + s.Reset() + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(dialBackStreamTimeout)) + defer s.Close() + + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.DialBack + if err := r.ReadMsg(&msg); err != nil { + log.Debugf("failed to read dialback msg from %s: %s", s.Conn().RemotePeer(), err) + s.Reset() + return + } + nonce := msg.GetNonce() + + ac.mu.Lock() + ch := ac.dialBackQueues[nonce] + ac.mu.Unlock() + select { + case ch <- s.Conn().LocalMultiaddr(): + default: + } +} + +func areAddrsConsistent(local, external ma.Multiaddr) bool { + if local == nil || external == nil { + return false + } + + localProtos := local.Protocols() + externalProtos := external.Protocols() + if len(localProtos) != len(externalProtos) { + return false + } + for i := 0; i < len(localProtos); i++ { + if i == 0 { + if localProtos[i].Code == externalProtos[i].Code || + localProtos[i].Code == ma.P_IP6 && externalProtos[i].Code == ma.P_IP4 /* NAT64 */ { + continue + } + return false + } else { + if localProtos[i].Code != externalProtos[i].Code { + return false + } + } + } + return true +} diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go new file mode 100644 index 0000000000..3a59d8d823 --- /dev/null +++ b/p2p/protocol/autonatv2/options.go @@ -0,0 +1,48 @@ +package autonatv2 + +import "time" + +// autoNATSettings is used to configure AutoNAT +type autoNATSettings struct { + allowAllAddrs bool + serverRPM int + serverPerPeerRPM int + serverDialDataRPM int + dataRequestPolicy dataRequestPolicyFunc + now func() time.Time +} + +func defaultSettings() *autoNATSettings { + return &autoNATSettings{ + allowAllAddrs: false, + // TODO: confirm rate limiting defaults + serverRPM: 20, + serverPerPeerRPM: 2, + serverDialDataRPM: 5, + dataRequestPolicy: amplificationAttackPrevention, + now: time.Now, + } +} + +type AutoNATOption func(s *autoNATSettings) error + +func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption { + return func(s *autoNATSettings) error { + s.serverRPM = rpm + s.serverPerPeerRPM = perPeerRPM + s.serverDialDataRPM = dialDataRPM + return nil + } +} + +func withDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { + return func(s *autoNATSettings) error { + s.dataRequestPolicy = drp + return nil + } +} + +func allowAllAddrs(s *autoNATSettings) error { + s.allowAllAddrs = true + return nil +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.pb.go b/p2p/protocol/autonatv2/pb/autonatv2.pb.go new file mode 100644 index 0000000000..f94b077acf --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.pb.go @@ -0,0 +1,706 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: pb/autonatv2.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DialStatus int32 + +const ( + DialStatus_UNUSED DialStatus = 0 + DialStatus_E_DIAL_ERROR DialStatus = 100 + DialStatus_E_DIAL_BACK_ERROR DialStatus = 101 + DialStatus_OK DialStatus = 200 +) + +// Enum value maps for DialStatus. +var ( + DialStatus_name = map[int32]string{ + 0: "UNUSED", + 100: "E_DIAL_ERROR", + 101: "E_DIAL_BACK_ERROR", + 200: "OK", + } + DialStatus_value = map[string]int32{ + "UNUSED": 0, + "E_DIAL_ERROR": 100, + "E_DIAL_BACK_ERROR": 101, + "OK": 200, + } +) + +func (x DialStatus) Enum() *DialStatus { + p := new(DialStatus) + *p = x + return p +} + +func (x DialStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[0].Descriptor() +} + +func (DialStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[0] +} + +func (x DialStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialStatus.Descriptor instead. +func (DialStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +type DialResponse_ResponseStatus int32 + +const ( + DialResponse_E_INTERNAL_ERROR DialResponse_ResponseStatus = 0 + DialResponse_E_REQUEST_REJECTED DialResponse_ResponseStatus = 100 + DialResponse_E_DIAL_REFUSED DialResponse_ResponseStatus = 101 + DialResponse_OK DialResponse_ResponseStatus = 200 +) + +// Enum value maps for DialResponse_ResponseStatus. +var ( + DialResponse_ResponseStatus_name = map[int32]string{ + 0: "E_INTERNAL_ERROR", + 100: "E_REQUEST_REJECTED", + 101: "E_DIAL_REFUSED", + 200: "OK", + } + DialResponse_ResponseStatus_value = map[string]int32{ + "E_INTERNAL_ERROR": 0, + "E_REQUEST_REJECTED": 100, + "E_DIAL_REFUSED": 101, + "OK": 200, + } +) + +func (x DialResponse_ResponseStatus) Enum() *DialResponse_ResponseStatus { + p := new(DialResponse_ResponseStatus) + *p = x + return p +} + +func (x DialResponse_ResponseStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialResponse_ResponseStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[1].Descriptor() +} + +func (DialResponse_ResponseStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[1] +} + +func (x DialResponse_ResponseStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialResponse_ResponseStatus.Descriptor instead. +func (DialResponse_ResponseStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3, 0} +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Msg: + // + // *Message_DialRequest + // *Message_DialResponse + // *Message_DialDataRequest + // *Message_DialDataResponse + Msg isMessage_Msg `protobuf_oneof:"msg"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +func (m *Message) GetMsg() isMessage_Msg { + if m != nil { + return m.Msg + } + return nil +} + +func (x *Message) GetDialRequest() *DialRequest { + if x, ok := x.GetMsg().(*Message_DialRequest); ok { + return x.DialRequest + } + return nil +} + +func (x *Message) GetDialResponse() *DialResponse { + if x, ok := x.GetMsg().(*Message_DialResponse); ok { + return x.DialResponse + } + return nil +} + +func (x *Message) GetDialDataRequest() *DialDataRequest { + if x, ok := x.GetMsg().(*Message_DialDataRequest); ok { + return x.DialDataRequest + } + return nil +} + +func (x *Message) GetDialDataResponse() *DialDataResponse { + if x, ok := x.GetMsg().(*Message_DialDataResponse); ok { + return x.DialDataResponse + } + return nil +} + +type isMessage_Msg interface { + isMessage_Msg() +} + +type Message_DialRequest struct { + DialRequest *DialRequest `protobuf:"bytes,1,opt,name=dialRequest,proto3,oneof"` +} + +type Message_DialResponse struct { + DialResponse *DialResponse `protobuf:"bytes,2,opt,name=dialResponse,proto3,oneof"` +} + +type Message_DialDataRequest struct { + DialDataRequest *DialDataRequest `protobuf:"bytes,3,opt,name=dialDataRequest,proto3,oneof"` +} + +type Message_DialDataResponse struct { + DialDataResponse *DialDataResponse `protobuf:"bytes,4,opt,name=dialDataResponse,proto3,oneof"` +} + +func (*Message_DialRequest) isMessage_Msg() {} + +func (*Message_DialResponse) isMessage_Msg() {} + +func (*Message_DialDataRequest) isMessage_Msg() {} + +func (*Message_DialDataResponse) isMessage_Msg() {} + +type DialRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Addrs [][]byte `protobuf:"bytes,1,rep,name=addrs,proto3" json:"addrs,omitempty"` + Nonce uint64 `protobuf:"fixed64,2,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialRequest) Reset() { + *x = DialRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialRequest) ProtoMessage() {} + +func (x *DialRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialRequest.ProtoReflect.Descriptor instead. +func (*DialRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{1} +} + +func (x *DialRequest) GetAddrs() [][]byte { + if x != nil { + return x.Addrs + } + return nil +} + +func (x *DialRequest) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +type DialDataRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AddrIdx uint32 `protobuf:"varint,1,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + NumBytes uint64 `protobuf:"varint,2,opt,name=numBytes,proto3" json:"numBytes,omitempty"` +} + +func (x *DialDataRequest) Reset() { + *x = DialDataRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataRequest) ProtoMessage() {} + +func (x *DialDataRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataRequest.ProtoReflect.Descriptor instead. +func (*DialDataRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{2} +} + +func (x *DialDataRequest) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialDataRequest) GetNumBytes() uint64 { + if x != nil { + return x.NumBytes + } + return 0 +} + +type DialResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status DialResponse_ResponseStatus `protobuf:"varint,1,opt,name=status,proto3,enum=autonatv2.pb.DialResponse_ResponseStatus" json:"status,omitempty"` + AddrIdx uint32 `protobuf:"varint,2,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + DialStatus DialStatus `protobuf:"varint,3,opt,name=dialStatus,proto3,enum=autonatv2.pb.DialStatus" json:"dialStatus,omitempty"` +} + +func (x *DialResponse) Reset() { + *x = DialResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialResponse) ProtoMessage() {} + +func (x *DialResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialResponse.ProtoReflect.Descriptor instead. +func (*DialResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3} +} + +func (x *DialResponse) GetStatus() DialResponse_ResponseStatus { + if x != nil { + return x.Status + } + return DialResponse_E_INTERNAL_ERROR +} + +func (x *DialResponse) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialResponse) GetDialStatus() DialStatus { + if x != nil { + return x.DialStatus + } + return DialStatus_UNUSED +} + +type DialDataResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *DialDataResponse) Reset() { + *x = DialDataResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataResponse) ProtoMessage() {} + +func (x *DialDataResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataResponse.ProtoReflect.Descriptor instead. +func (*DialDataResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{4} +} + +func (x *DialDataResponse) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type DialBack struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Nonce uint64 `protobuf:"fixed64,1,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialBack) Reset() { + *x = DialBack{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialBack) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialBack) ProtoMessage() {} + +func (x *DialBack) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialBack.ProtoReflect.Descriptor instead. +func (*DialBack) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{5} +} + +func (x *DialBack) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +var File_pb_autonatv2_proto protoreflect.FileDescriptor + +var file_pb_autonatv2_proto_rawDesc = []byte{ + 0x0a, 0x12, 0x70, 0x62, 0x2f, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x22, 0xaa, 0x02, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3d, + 0x0a, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, + 0x52, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x40, 0x0a, + 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, + 0x00, 0x52, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x49, 0x0a, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, + 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, + 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x4c, 0x0a, 0x10, 0x64, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x10, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, + 0x39, 0x0a, 0x0b, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x61, 0x64, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x05, 0x61, + 0x64, 0x64, 0x72, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x22, 0x47, 0x0a, 0x0f, 0x44, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, + 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, + 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x22, 0x82, 0x02, 0x0a, 0x0c, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, + 0x64, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, + 0x78, 0x12, 0x38, 0x0a, 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, + 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x5b, 0x0a, 0x0e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, + 0x10, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x45, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x45, 0x53, 0x54, + 0x5f, 0x52, 0x45, 0x4a, 0x45, 0x43, 0x54, 0x45, 0x44, 0x10, 0x64, 0x12, 0x12, 0x0a, 0x0e, 0x45, + 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x10, 0x65, 0x12, + 0x07, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0xc8, 0x01, 0x22, 0x26, 0x0a, 0x10, 0x44, 0x69, 0x61, 0x6c, + 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, + 0x22, 0x20, 0x0a, 0x08, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x12, 0x14, 0x0a, 0x05, + 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, + 0x63, 0x65, 0x2a, 0x4a, 0x0a, 0x0a, 0x44, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x0a, 0x0a, 0x06, 0x55, 0x4e, 0x55, 0x53, 0x45, 0x44, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, + 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x64, 0x12, 0x15, + 0x0a, 0x11, 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x42, 0x41, 0x43, 0x4b, 0x5f, 0x45, 0x52, + 0x52, 0x4f, 0x52, 0x10, 0x65, 0x12, 0x07, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0xc8, 0x01, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pb_autonatv2_proto_rawDescOnce sync.Once + file_pb_autonatv2_proto_rawDescData = file_pb_autonatv2_proto_rawDesc +) + +func file_pb_autonatv2_proto_rawDescGZIP() []byte { + file_pb_autonatv2_proto_rawDescOnce.Do(func() { + file_pb_autonatv2_proto_rawDescData = protoimpl.X.CompressGZIP(file_pb_autonatv2_proto_rawDescData) + }) + return file_pb_autonatv2_proto_rawDescData +} + +var file_pb_autonatv2_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_pb_autonatv2_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_pb_autonatv2_proto_goTypes = []interface{}{ + (DialStatus)(0), // 0: autonatv2.pb.DialStatus + (DialResponse_ResponseStatus)(0), // 1: autonatv2.pb.DialResponse.ResponseStatus + (*Message)(nil), // 2: autonatv2.pb.Message + (*DialRequest)(nil), // 3: autonatv2.pb.DialRequest + (*DialDataRequest)(nil), // 4: autonatv2.pb.DialDataRequest + (*DialResponse)(nil), // 5: autonatv2.pb.DialResponse + (*DialDataResponse)(nil), // 6: autonatv2.pb.DialDataResponse + (*DialBack)(nil), // 7: autonatv2.pb.DialBack +} +var file_pb_autonatv2_proto_depIdxs = []int32{ + 3, // 0: autonatv2.pb.Message.dialRequest:type_name -> autonatv2.pb.DialRequest + 5, // 1: autonatv2.pb.Message.dialResponse:type_name -> autonatv2.pb.DialResponse + 4, // 2: autonatv2.pb.Message.dialDataRequest:type_name -> autonatv2.pb.DialDataRequest + 6, // 3: autonatv2.pb.Message.dialDataResponse:type_name -> autonatv2.pb.DialDataResponse + 1, // 4: autonatv2.pb.DialResponse.status:type_name -> autonatv2.pb.DialResponse.ResponseStatus + 0, // 5: autonatv2.pb.DialResponse.dialStatus:type_name -> autonatv2.pb.DialStatus + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_pb_autonatv2_proto_init() } +func file_pb_autonatv2_proto_init() { + if File_pb_autonatv2_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pb_autonatv2_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialBack); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pb_autonatv2_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*Message_DialRequest)(nil), + (*Message_DialResponse)(nil), + (*Message_DialDataRequest)(nil), + (*Message_DialDataResponse)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pb_autonatv2_proto_rawDesc, + NumEnums: 2, + NumMessages: 6, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pb_autonatv2_proto_goTypes, + DependencyIndexes: file_pb_autonatv2_proto_depIdxs, + EnumInfos: file_pb_autonatv2_proto_enumTypes, + MessageInfos: file_pb_autonatv2_proto_msgTypes, + }.Build() + File_pb_autonatv2_proto = out.File + file_pb_autonatv2_proto_rawDesc = nil + file_pb_autonatv2_proto_goTypes = nil + file_pb_autonatv2_proto_depIdxs = nil +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.proto b/p2p/protocol/autonatv2/pb/autonatv2.proto new file mode 100644 index 0000000000..1a02286060 --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package autonatv2.pb; + +message Message { + oneof msg { + DialRequest dialRequest = 1; + DialResponse dialResponse = 2; + DialDataRequest dialDataRequest = 3; + DialDataResponse dialDataResponse = 4; + } +} + +message DialRequest { + repeated bytes addrs = 1; + fixed64 nonce = 2; +} + + +message DialDataRequest { + uint32 addrIdx = 1; + uint64 numBytes = 2; +} + + +enum DialStatus { + UNUSED = 0; + E_DIAL_ERROR = 100; + E_DIAL_BACK_ERROR = 101; + OK = 200; +} + + +message DialResponse { + enum ResponseStatus { + E_INTERNAL_ERROR = 0; + E_REQUEST_REJECTED = 100; + E_DIAL_REFUSED = 101; + OK = 200; + } + + ResponseStatus status = 1; + uint32 addrIdx = 2; + DialStatus dialStatus = 3; +} + + +message DialDataResponse { + bytes data = 1; +} + + +message DialBack { + fixed64 nonce = 1; +} \ No newline at end of file diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go new file mode 100644 index 0000000000..bbf651c2af --- /dev/null +++ b/p2p/protocol/autonatv2/server.go @@ -0,0 +1,354 @@ +package autonatv2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "golang.org/x/exp/rand" +) + +type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool + +// server implements the AutoNATv2 server. +// It can ask client to provide dial data before attempting the requested dial. +// It rate limits requests on a global level, per peer level and on whether the request requires dial data. +type server struct { + host host.Host + dialerHost host.Host + limiter *rateLimiter + + // dialDataRequestPolicy is used to determine whether dialing the address requires receiving dial data. + // It is set to amplification attack prevention by default. + dialDataRequestPolicy dataRequestPolicyFunc + + // for tests + now func() time.Time + allowAllAddrs bool +} + +func newServer(host, dialer host.Host, s *autoNATSettings) *server { + return &server{ + dialerHost: dialer, + host: host, + dialDataRequestPolicy: s.dataRequestPolicy, + allowAllAddrs: s.allowAllAddrs, + limiter: &rateLimiter{ + RPM: s.serverRPM, + PerPeerRPM: s.serverPerPeerRPM, + DialDataRPM: s.serverDialDataRPM, + now: s.now, + }, + now: s.now, + } +} + +// Enable attaches the stream handler to the host. +func (as *server) Enable() { + as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) +} + +// Disable removes the stream handles from the host. +func (as *server) Disable() { + as.host.RemoveStreamHandler(DialProtocol) +} + +func (as *server) Close() { + as.dialerHost.Close() +} + +// handleDialRequest is the dial-request protocol stream handler +func (as *server) handleDialRequest(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + return + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + log.Debugf("failed to reserve memory for stream %s: %w", DialProtocol, err) + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(as.now().Add(streamTimeout)) + defer s.Close() + + p := s.Conn().RemotePeer() + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to read request from %s: %s", p, err) + return + } + if msg.GetDialRequest() == nil { + s.Reset() + log.Debugf("invalid message type from %s: %T", p, msg.Msg) + return + } + + nonce := msg.GetDialRequest().Nonce + // parse peer's addresses + var dialAddr ma.Multiaddr + var addrIdx int + for i, ab := range msg.GetDialRequest().GetAddrs() { + if i >= maxPeerAddresses { + break + } + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + continue + } + if !as.allowAllAddrs && !manet.IsPublicAddr(a) { + continue + } + if !as.dialerHost.Network().CanDial(p, a) { + continue + } + dialAddr = a + addrIdx = i + break + } + w := pbio.NewDelimitedWriter(s) + // No dialable address + if dialAddr == nil { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } + return + } + + isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + + if !as.limiter.Accept(p, isDialDataRequired) { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_REQUEST_REJECTED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } + log.Debugf("rejecting request from %s: rate limit exceeded", p) + return + } + defer as.limiter.CompleteRequest(p) + + if isDialDataRequired { + if err := getDialData(w, r, &msg, addrIdx); err != nil { + s.Reset() + log.Debugf("%s refused dial data request: %s", p, err) + return + } + } + + dialStatus := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce) + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: dialStatus, + AddrIdx: uint32(addrIdx), + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } +} + +// getDialData gets data from the client for dialing the address +func getDialData(w pbio.Writer, r pbio.Reader, msg *pb.Message, addrIdx int) error { + numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes) + *msg = pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: uint32(addrIdx), + NumBytes: uint64(numBytes), + }, + }, + } + if err := w.WriteMsg(msg); err != nil { + return fmt.Errorf("dial data write: %w", err) + } + for remain := numBytes; remain > 0; { + if err := r.ReadMsg(msg); err != nil { + return fmt.Errorf("dial data read: %w", err) + } + if msg.GetDialDataResponse() == nil { + return fmt.Errorf("invalid msg type %T", msg.Msg) + } + remain -= len(msg.GetDialDataResponse().Data) + } + return nil +} + +func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus { + ctx, cancel := context.WithTimeout(context.Background(), dialBackDialTimeout) + ctx = network.WithForceDirectDial(ctx, "autonatv2") + as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL) + defer func() { + cancel() + as.dialerHost.Network().ClosePeer(p) + as.dialerHost.Peerstore().ClearAddrs(p) + as.dialerHost.Peerstore().RemovePeer(p) + }() + + err := as.dialerHost.Connect(ctx, peer.AddrInfo{ID: p}) + if err != nil { + return pb.DialStatus_E_DIAL_ERROR + } + + s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol) + if err != nil { + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + defer s.Close() + s.SetDeadline(as.now().Add(dialBackStreamTimeout)) + + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + s.Reset() + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + // Since the underlying connection is on a separate dialer, it'll be closed after this function returns. + // Connection close will drop all the queued writes. To ensure message delivery, do a CloseWrite and + // wait a second for the peer to Close its end of the stream. + s.CloseWrite() + s.SetDeadline(as.now().Add(1 * time.Second)) + b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately + s.Read(b) + + return pb.DialStatus_OK +} + +// rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request +// per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data. +type rateLimiter struct { + // PerPeerRPM is the rate limit per peer + PerPeerRPM int + // RPM is the global rate limit + RPM int + // DialDataRPM is the rate limit for requests that require dial data + DialDataRPM int + + mu sync.Mutex + reqs []time.Time + peerReqs map[peer.ID][]time.Time + dialDataReqs []time.Time + // ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the + // same peer + ongoingReqs map[peer.ID]struct{} + + now func() time.Time // for tests +} + +func (r *rateLimiter) Accept(p peer.ID, requiresData bool) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.peerReqs == nil { + r.peerReqs = make(map[peer.ID][]time.Time) + r.ongoingReqs = make(map[peer.ID]struct{}) + } + + nw := r.now() + r.cleanup(p, nw) + + if _, ok := r.ongoingReqs[p]; ok { + return false + } + if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM { + return false + } + if requiresData && len(r.dialDataReqs) >= r.DialDataRPM { + return false + } + + r.ongoingReqs[p] = struct{}{} + r.reqs = append(r.reqs, nw) + r.peerReqs[p] = append(r.peerReqs[p], nw) + if requiresData { + r.dialDataReqs = append(r.dialDataReqs, nw) + } + return true +} + +// cleanup removes stale requests. +// +// This is fast enough in rate limited cases and the state is small enough to +// clean up quickly when blocking requests. +func (r *rateLimiter) cleanup(p peer.ID, now time.Time) { + idx := len(r.reqs) + for i, t := range r.reqs { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.reqs = r.reqs[idx:] + + idx = len(r.dialDataReqs) + for i, t := range r.dialDataReqs { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.dialDataReqs = r.dialDataReqs[idx:] + + idx = len(r.peerReqs[p]) + for i, t := range r.peerReqs[p] { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.peerReqs[p] = r.peerReqs[p][idx:] +} + +func (r *rateLimiter) CompleteRequest(p peer.ID) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.ongoingReqs, p) +} + +// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed +// IP address is different from the dial back IP address +func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { + connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) + if err != nil { + return true + } + dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr + return !connIP.Equal(dialIP) +} diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go new file mode 100644 index 0000000000..ff8b062519 --- /dev/null +++ b/p2p/protocol/autonatv2/server_test.go @@ -0,0 +1,208 @@ +package autonatv2 + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/net/swarm" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) { + reqs = make([]Request, len(addrs)) + for i := 0; i < len(addrs); i++ { + reqs[i] = Request{Addr: addrs[i], SendDialData: sendDialData} + } + return +} + +func TestServerInvalidAddrsRejected(t *testing.T) { + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + t.Run("no transport", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + fmt.Println(res, err) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("black holed addr", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm( + t, swarmt.WithSwarmOpts(swarm.WithReadOnlyBlackHoleDetector()))) + an := newAutoNAT(t, dialer) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.CheckReachability(context.Background(), + []Request{{ + Addr: ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1"), + SendDialData: true, + }}) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("private addrs", func(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) +} + +func TestServerDataRequest(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowAllAddrs, withDataRequestPolicy( + func(s network.Stream, dialAddr ma.Multiaddr) bool { + if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return true + } + return false + }), + WithServerRateLimit(10, 10, 10), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + quicAddr = a + } else if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + _, err := c.CheckReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}}) + require.Error(t, err) + + res, err := c.CheckReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + require.NoError(t, err) + + require.Equal(t, Result{ + Idx: 0, + Addr: quicAddr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) +} + +func TestServerDial(t *testing.T) { + an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowAllAddrs) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + unreachableAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") + hostAddrs := c.host.Addrs() + + t.Run("unreachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), + append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: unreachableAddr, + Reachability: network.ReachabilityPrivate, + Status: pb.DialStatus_E_DIAL_ERROR, + }, res) + }) + + t.Run("reachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + }) + + t.Run("dialback error", func(t *testing.T) { + c.host.RemoveStreamHandler(DialBackProtocol) + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityUnknown, + Status: pb.DialStatus_E_DIAL_BACK_ERROR, + }, res) + }) +} + +func TestRateLimiter(t *testing.T) { + cl := test.NewMockClock() + r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now} + + require.True(t, r.Accept("peer1", false)) + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1", false)) // first request is still active + r.CompleteRequest("peer1") + + require.True(t, r.Accept("peer1", false)) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1", false)) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer2", false)) + r.CompleteRequest("peer2") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer3", false)) + + cl.AdvanceBy(21 * time.Second) // first request expired + require.True(t, r.Accept("peer1", false)) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer3", true)) + r.CompleteRequest("peer3") + + cl.AdvanceBy(50 * time.Second) + require.False(t, r.Accept("peer3", true)) + + cl.AdvanceBy(11 * time.Second) + require.True(t, r.Accept("peer3", true)) +}