diff --git a/config/config.go b/config/config.go
index 3f0cd85e91..8f1c210827 100644
--- a/config/config.go
+++ b/config/config.go
@@ -129,6 +129,8 @@ type Config struct {
 	DialRanker network.DialRanker
 
 	SwarmOpts []swarm.Option
+
+	DisableIdentifyAddressDiscovery bool
 }
 
 func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
@@ -290,19 +292,20 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {
 func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
 	h, err := bhost.NewHost(swrm, &bhost.HostOpts{
-		EventBus:             eventBus,
-		ConnManager:          cfg.ConnManager,
-		AddrsFactory:         cfg.AddrsFactory,
-		NATManager:           cfg.NATManager,
-		EnablePing:           !cfg.DisablePing,
-		UserAgent:            cfg.UserAgent,
-		ProtocolVersion:      cfg.ProtocolVersion,
-		EnableHolePunching:   cfg.EnableHolePunching,
-		HolePunchingOptions:  cfg.HolePunchingOptions,
-		EnableRelayService:   cfg.EnableRelayService,
-		RelayServiceOpts:     cfg.RelayServiceOpts,
-		EnableMetrics:        !cfg.DisableMetrics,
-		PrometheusRegisterer: cfg.PrometheusRegisterer,
+		EventBus:                        eventBus,
+		ConnManager:                     cfg.ConnManager,
+		AddrsFactory:                    cfg.AddrsFactory,
+		NATManager:                      cfg.NATManager,
+		EnablePing:                      !cfg.DisablePing,
+		UserAgent:                       cfg.UserAgent,
+		ProtocolVersion:                 cfg.ProtocolVersion,
+		EnableHolePunching:              cfg.EnableHolePunching,
+		HolePunchingOptions:             cfg.HolePunchingOptions,
+		EnableRelayService:              cfg.EnableRelayService,
+		RelayServiceOpts:                cfg.RelayServiceOpts,
+		EnableMetrics:                   !cfg.DisableMetrics,
+		PrometheusRegisterer:            cfg.PrometheusRegisterer,
+		DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
 	})
 	if err != nil {
 		return nil, err
diff --git a/core/peerstore/peerstore.go b/core/peerstore/peerstore.go
index 4c9227f811..0ef09df9fe 100644
--- a/core/peerstore/peerstore.go
+++ b/core/peerstore/peerstore.go
@@ -28,9 +28,10 @@ var (
 
 	// RecentlyConnectedAddrTTL is used when we recently connected to a peer.
 	// It means that we are reasonably certain of the peer's address.
-	RecentlyConnectedAddrTTL = time.Minute * 30
+	RecentlyConnectedAddrTTL = time.Minute * 15
 
 	// OwnObservedAddrTTL is used for our own external addresses observed by peers.
+	// Deprecated: observed addresses are maintained until we disconnect from the peer that provided them.
 	OwnObservedAddrTTL = time.Minute * 30
 )
 
diff --git a/libp2p_test.go b/libp2p_test.go
index 7e4ac1b61e..343681be85 100644
--- a/libp2p_test.go
+++ b/libp2p_test.go
@@ -376,6 +376,12 @@ func TestAutoNATService(t *testing.T) {
 	h.Close()
 }
 
+func TestDisableIdentifyAddressDiscovery(t *testing.T) {
+	h, err := New(DisableIdentifyAddressDiscovery())
+	require.NoError(t, err)
+	h.Close()
+}
+
 func TestMain(m *testing.M) {
 	goleak.VerifyTestMain(
 		m,
diff --git a/options.go b/options.go
index 747d6c55e6..de95251ad3 100644
--- a/options.go
+++ b/options.go
@@ -598,3 +598,14 @@ func SwarmOpts(opts ...swarm.Option) Option {
 		return nil
 	}
 }
+
+// DisableIdentifyAddressDiscovery disables address discovery using peer-provided observed addresses
+// in identify. If you know your public addresses upfront, the recommended approach is to use
+// AddrsFactory to provide the external addresses to the host and use this option to disable
+// discovery from identify.
+func DisableIdentifyAddressDiscovery() Option {
+	return func(cfg *Config) error {
+		cfg.DisableIdentifyAddressDiscovery = true
+		return nil
+	}
+}
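A minimal usage sketch of the new option combined with AddrsFactory (the address and the static factory below are placeholders, not part of this diff):

```go
package main

import (
	"fmt"

	"github.com/libp2p/go-libp2p"
	ma "github.com/multiformats/go-multiaddr"
)

func main() {
	// Placeholder: a public address known upfront (e.g. from deployment config).
	public := ma.StringCast("/ip4/203.0.113.7/tcp/4001")
	h, err := libp2p.New(
		// Advertise only the addresses we already know about...
		libp2p.AddrsFactory(func([]ma.Multiaddr) []ma.Multiaddr {
			return []ma.Multiaddr{public}
		}),
		// ...and don't learn addresses from peers' identify observations.
		libp2p.DisableIdentifyAddressDiscovery(),
	)
	if err != nil {
		panic(err)
	}
	defer h.Close()
	fmt.Println("advertised addrs:", h.Addrs())
}
```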
diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go
index d9cdee5abf..8fc808e6b6 100644
--- a/p2p/host/basic/basic_host.go
+++ b/p2p/host/basic/basic_host.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"slices"
 	"sync"
 	"time"
 
@@ -53,6 +54,8 @@ var (
 	DefaultAddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
 )
 
+const maxPeerRecordSize = 8 * 1024 // 8k to be compatible with identify's limit
+
 // AddrsFactory functions can be passed to New in order to override
 // addresses returned by Addrs.
 type AddrsFactory func([]ma.Multiaddr) []ma.Multiaddr
@@ -161,6 +164,9 @@ type HostOpts struct {
 	EnableMetrics bool
 	// PrometheusRegisterer is the PrometheusRegisterer used for metrics
 	PrometheusRegisterer prometheus.Registerer
+
+	// DisableIdentifyAddressDiscovery disables address discovery using peer-provided observed addresses in identify
+	DisableIdentifyAddressDiscovery bool
 }
 
 // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
@@ -244,6 +250,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
 			identify.WithMetricsTracer(
 				identify.NewMetricsTracer(identify.WithRegisterer(opts.PrometheusRegisterer))))
 	}
+	if opts.DisableIdentifyAddressDiscovery {
+		idOpts = append(idOpts, identify.DisableObservedAddrManager())
+	}
 
 	h.ids, err = identify.NewIDService(h, idOpts...)
 	if err != nil {
@@ -482,15 +491,18 @@ func makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddresses
 	return &evt
 }
 
-func (h *BasicHost) makeSignedPeerRecord(evt *event.EvtLocalAddressesUpdated) (*record.Envelope, error) {
-	current := make([]ma.Multiaddr, 0, len(evt.Current))
-	for _, a := range evt.Current {
-		current = append(current, a.Address)
+func (h *BasicHost) makeSignedPeerRecord(addrs []ma.Multiaddr) (*record.Envelope, error) {
+	// Limit the length of currentAddrs to ensure that our signed peer records aren't rejected
+	peerRecordSize := 64 // HostID
+	k, err := h.signKey.Raw()
+	if err == nil {
+		peerRecordSize += 2 * len(k) // 1 for signature, 1 for public key
 	}
-
+	// we want the final address list to be small so that the signed peer record stays within the size limit
+	addrs = trimHostAddrList(addrs, maxPeerRecordSize-peerRecordSize-256) // 256 B of buffer
 	rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{
 		ID:    h.ID(),
-		Addrs: current,
+		Addrs: addrs,
 	})
 	return record.Seal(rec, h.signKey)
 }
@@ -513,7 +525,7 @@ func (h *BasicHost) background() {
 
 			if !h.disableSignedPeerRecord {
 				// add signed peer record to the event
-				sr, err := h.makeSignedPeerRecord(changeEvt)
+				sr, err := h.makeSignedPeerRecord(currentAddrs)
 				if err != nil {
 					log.Errorf("error creating a signed peer record from the set of current addresses, err=%s", err)
 					return
@@ -805,6 +817,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
 			addrs[i] = addrWithCerthash
 		}
 	}
+
 	return addrs
 }
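For intuition on the size budget in makeSignedPeerRecord, a rough sketch (the helper name and the Ed25519 figures are illustrative, mirroring the arithmetic above):

```go
package sizing

// addrBudget mirrors the arithmetic in makeSignedPeerRecord: the bytes
// left for multiaddrs inside a maxPeerRecordSize (8 KiB) signed record.
func addrBudget(rawPubKeyLen int) int {
	const maxPeerRecordSize = 8 * 1024
	peerRecordSize := 64 + 2*rawPubKeyLen // host ID + signature/public key estimate
	return maxPeerRecordSize - peerRecordSize - 256 // 256 B of buffer
}

// For a 32-byte Ed25519 public key: 8192 - (64 + 64) - 256 = 7808 bytes,
// i.e. room for a few hundred typical multiaddrs before trimming applies.
```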
@@ -997,6 +1010,58 @@ func inferWebtransportAddrsFromQuic(in []ma.Multiaddr) []ma.Multiaddr {
 	return out
 }
 
+func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
+	totalSize := 0
+	for _, a := range addrs {
+		totalSize += len(a.Bytes())
+	}
+	if totalSize <= maxSize {
+		return addrs
+	}
+
+	score := func(addr ma.Multiaddr) int {
+		var res int
+		if manet.IsPublicAddr(addr) {
+			res |= 1 << 12
+		} else if !manet.IsIPLoopback(addr) {
+			res |= 1 << 11
+		}
+		var protocolWeight int
+		ma.ForEach(addr, func(c ma.Component) bool {
+			switch c.Protocol().Code {
+			case ma.P_QUIC_V1:
+				protocolWeight = 5
+			case ma.P_TCP:
+				protocolWeight = 4
+			case ma.P_WSS:
+				protocolWeight = 3
+			case ma.P_WEBTRANSPORT:
+				protocolWeight = 2
+			case ma.P_WEBRTC_DIRECT:
+				protocolWeight = 1
+			case ma.P_P2P:
+				return false
+			}
+			return true
+		})
+		res |= 1 << protocolWeight
+		return res
+	}
+
+	slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int {
+		return score(b) - score(a) // b-a for reverse order
+	})
+	totalSize = 0
+	for i, a := range addrs {
+		totalSize += len(a.Bytes())
+		if totalSize > maxSize {
+			addrs = addrs[:i]
+			break
+		}
+	}
+	return addrs
+}
+
 // SetAutoNat sets the autonat service for the host.
 func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) {
 	h.addrMu.Lock()
diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go
index 1fb3b4a397..c4f0680ea2 100644
--- a/p2p/host/basic/basic_host_test.go
+++ b/p2p/host/basic/basic_host_test.go
@@ -896,3 +896,55 @@ func TestInferWebtransportAddrsFromQuic(t *testing.T) {
 	}
 
 }
+
+func TestTrimHostAddrList(t *testing.T) {
+	type testCase struct {
+		name      string
+		in        []ma.Multiaddr
+		threshold int
+		out       []ma.Multiaddr
+	}
+
+	tcpPublic := ma.StringCast("/ip4/1.1.1.1/tcp/1")
+	quicPublic := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1")
+
+	tcpPrivate := ma.StringCast("/ip4/192.168.1.1/tcp/1")
+	quicPrivate := ma.StringCast("/ip4/192.168.1.1/udp/1/quic-v1")
+
+	tcpLocal := ma.StringCast("/ip4/127.0.0.1/tcp/1")
+	quicLocal := ma.StringCast("/ip4/127.0.0.1/udp/1/quic-v1")
+
+	testCases := []testCase{
+		{
+			name:      "Public preferred over private",
+			in:        []ma.Multiaddr{tcpPublic, quicPrivate},
+			threshold: len(tcpLocal.Bytes()),
+			out:       []ma.Multiaddr{tcpPublic},
+		},
+		{
+			name:      "Public and private preferred over local",
+			in:        []ma.Multiaddr{tcpPublic, tcpPrivate, quicLocal},
+			threshold: len(tcpPublic.Bytes()) + len(tcpPrivate.Bytes()),
+			out:       []ma.Multiaddr{tcpPublic, tcpPrivate},
+		},
+		{
+			name:      "quic preferred over tcp",
+			in:        []ma.Multiaddr{tcpPublic, quicPublic},
+			threshold: len(quicPublic.Bytes()),
+			out:       []ma.Multiaddr{quicPublic},
+		},
+		{
+			name:      "no filtering on large threshold",
+			in:        []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
+			threshold: 10000,
+			out:       []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			got := trimHostAddrList(tc.in, tc.threshold)
+			require.ElementsMatch(t, got, tc.out)
+		})
+	}
+}
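The ordering produced by the score function above can be checked by hand; a stripped-down sketch of the same weighting (only two of the protocol cases shown):

```go
package main

import (
	"fmt"

	ma "github.com/multiformats/go-multiaddr"
	manet "github.com/multiformats/go-multiaddr/net"
)

// score mirrors the weighting used by trimHostAddrList above.
func score(addr ma.Multiaddr) int {
	var res int
	if manet.IsPublicAddr(addr) {
		res |= 1 << 12
	} else if !manet.IsIPLoopback(addr) {
		res |= 1 << 11
	}
	var protocolWeight int
	ma.ForEach(addr, func(c ma.Component) bool {
		switch c.Protocol().Code {
		case ma.P_QUIC_V1:
			protocolWeight = 5
		case ma.P_TCP:
			protocolWeight = 4
		}
		return true
	})
	return res | 1<<protocolWeight
}

func main() {
	fmt.Println(score(ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1"))) // 4128: public QUIC sorts first
	fmt.Println(score(ma.StringCast("/ip4/1.1.1.1/tcp/1")))         // 4112: public TCP second
	fmt.Println(score(ma.StringCast("/ip4/127.0.0.1/tcp/1")))       // 16: loopback is trimmed first
}
```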
diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go
index 7ae4feb935..a91cc4f92e 100644
--- a/p2p/protocol/identify/id.go
+++ b/p2p/protocol/identify/id.go
@@ -34,24 +34,27 @@ import (
 
 var log = logging.Logger("net/identify")
 
+var Timeout = 60 * time.Second // timeout on all incoming Identify interactions
+
 const (
 	// ID is the protocol.ID of version 1.0.0 of the identify service.
 	ID = "/ipfs/id/1.0.0"
 
 	// IDPush is the protocol.ID of the Identify push protocol.
 	// It sends full identify messages containing the current state of the peer.
 	IDPush = "/ipfs/id/push/1.0.0"
-)
-
-const ServiceName = "libp2p.identify"
-
-const maxPushConcurrency = 32
-
-var Timeout = 60 * time.Second // timeout on all incoming Identify interactions
-
-const (
-	legacyIDSize = 2 * 1024 // 2k Bytes
-	signedIDSize = 8 * 1024 // 8K
-	maxMessages  = 10
+
+	ServiceName = "libp2p.identify"
+
+	legacyIDSize          = 2 * 1024
+	signedIDSize          = 8 * 1024
+	maxOwnIdentifyMsgSize = 4 * 1024 // smaller than what we accept. This is 4k to be compatible with rust-libp2p
+	maxMessages           = 10
+	maxPushConcurrency    = 32
+	// number of addresses to keep for peers we have disconnected from for peerstore.RecentlyConnectedAddrTTL time
+	// This number can be small as we already filter peer addresses based on whether the peer is connected to us over
+	// localhost, private IP or public IP address
+	recentlyConnectedPeerMaxAddrs = 20
+	connectedPeerMaxAddrs         = 500
 )
 
 var defaultUserAgent = "github.com/libp2p/go-libp2p"
@@ -159,7 +162,8 @@ type idService struct {
 	addrMu sync.Mutex
 
 	// our own observed addresses.
-	observedAddrs *ObservedAddrManager
+	observedAddrMgr            *ObservedAddrManager
+	disableObservedAddrManager bool
 
 	emitters struct {
 		evtPeerProtocolsUpdated event.Emitter
@@ -171,6 +175,12 @@ type idService struct {
 		sync.Mutex
 		snapshot identifySnapshot
 	}
+
+	natEmitter *natEmitter
+}
+
+type normalizer interface {
+	NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
 }
 
 // NewIDService constructs a new *idService and activates it by
@@ -199,11 +209,27 @@ func NewIDService(h host.Host, opts ...Option) (*idService, error) {
 		metricsTracer: cfg.metricsTracer,
 	}
 
-	observedAddrs, err := NewObservedAddrManager(h)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create observed address manager: %s", err)
+	var normalize func(ma.Multiaddr) ma.Multiaddr
+	if hn, ok := h.(normalizer); ok {
+		normalize = hn.NormalizeMultiaddr
+	}
+
+	var err error
+	if cfg.disableObservedAddrManager {
+		s.disableObservedAddrManager = true
+	} else {
+		observedAddrs, err := NewObservedAddrManager(h.Network().ListenAddresses,
+			h.Addrs, h.Network().InterfaceListenAddresses, normalize)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create observed address manager: %s", err)
+		}
+		natEmitter, err := newNATEmitter(h, observedAddrs, time.Minute)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create nat emitter: %s", err)
+		}
+		s.natEmitter = natEmitter
+		s.observedAddrMgr = observedAddrs
 	}
-	s.observedAddrs = observedAddrs
 
 	s.emitters.evtPeerProtocolsUpdated, err = h.EventBus().Emitter(&event.EvtPeerProtocolsUpdated{})
 	if err != nil {
@@ -341,17 +367,26 @@ func (ids *idService) sendPushes(ctx context.Context) {
 
 // Close shuts down the idService
 func (ids *idService) Close() error {
 	ids.ctxCancel()
-	ids.observedAddrs.Close()
+	if !ids.disableObservedAddrManager {
+		ids.observedAddrMgr.Close()
+		ids.natEmitter.Close()
+	}
 	ids.refCount.Wait()
 	return nil
 }
 
 func (ids *idService) OwnObservedAddrs() []ma.Multiaddr {
-	return ids.observedAddrs.Addrs()
+	if ids.disableObservedAddrManager {
+		return nil
+	}
+	return ids.observedAddrMgr.Addrs()
 }
 
 func (ids *idService) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr {
-	return ids.observedAddrs.AddrsFor(local)
+	if ids.disableObservedAddrManager {
+		return nil
+	}
+	return ids.observedAddrMgr.AddrsFor(local)
 }
 
 // IdentifyConn runs the Identify protocol on a connection.
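A sketch of opting out when constructing the identify service directly (the helper is hypothetical; libp2p.DisableIdentifyAddressDiscovery wires this up for BasicHost):

```go
package identifyutil

import (
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/p2p/protocol/identify"
)

// newIdentifyWithoutDiscovery is a hypothetical helper showing the new option.
func newIdentifyWithoutDiscovery(h host.Host) (identify.IDService, error) {
	ids, err := identify.NewIDService(h, identify.DisableObservedAddrManager())
	if err != nil {
		return nil, err
	}
	ids.Start()
	// With the manager disabled, OwnObservedAddrs() and ObservedAddrsFor(addr)
	// both return nil, and no NAT device type events are emitted.
	return ids, nil
}
```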
@@ -553,10 +588,18 @@ func readAllIDMessages(r pbio.Reader, finalMsg proto.Message) error { } func (ids *idService) updateSnapshot() (updated bool) { - addrs := ids.Host.Addrs() - slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return bytes.Compare(a.Bytes(), b.Bytes()) }) protos := ids.Host.Mux().Protocols() slices.Sort(protos) + + addrs := ids.Host.Addrs() + slices.SortFunc(addrs, func(a, b ma.Multiaddr) int { return bytes.Compare(a.Bytes(), b.Bytes()) }) + + usedSpace := len(ids.ProtocolVersion) + len(ids.UserAgent) + for i := 0; i < len(protos); i++ { + usedSpace += len(protos[i]) + } + addrs = trimHostAddrList(addrs, maxOwnIdentifyMsgSize-usedSpace-256) // 256 bytes of buffer + snapshot := identifySnapshot{ addrs: addrs, protocols: protos, @@ -715,9 +758,9 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo obsAddr = nil } - if obsAddr != nil { + if obsAddr != nil && !ids.disableObservedAddrManager { // TODO refactor this to use the emitted events instead of having this func call explicitly. - ids.observedAddrs.Record(c, obsAddr) + ids.observedAddrMgr.Record(c, obsAddr) } // mes.ListenAddrs @@ -777,7 +820,12 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo } else { addrs = lmaddrs } - ids.Host.Peerstore().AddAddrs(p, filterAddrs(addrs, c.RemoteMultiaddr()), ttl) + addrs = filterAddrs(addrs, c.RemoteMultiaddr()) + if len(addrs) > connectedPeerMaxAddrs { + addrs = addrs[:connectedPeerMaxAddrs] + } + + ids.Host.Peerstore().AddAddrs(p, addrs, ttl) // Finally, expire all temporary addrs. ids.Host.Peerstore().UpdateAddrs(p, peerstore.TempAddrTTL, 0) @@ -981,15 +1029,36 @@ func (nn *netNotifiee) Disconnected(_ network.Network, c network.Conn) { delete(ids.conns, c) ids.connsMu.Unlock() - switch ids.Host.Network().Connectedness(c.RemotePeer()) { - case network.Connected, network.Limited: - return + if !ids.disableObservedAddrManager { + ids.observedAddrMgr.removeConn(c) } + // Last disconnect. // Undo the setting of addresses to peer.ConnectedAddrTTL we did ids.addrMu.Lock() defer ids.addrMu.Unlock() - ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.RecentlyConnectedAddrTTL) + + // This check MUST happen after acquiring the Lock as identify on a different connection + // might be trying to add addresses. 
+ switch ids.Host.Network().Connectedness(c.RemotePeer()) { + case network.Connected, network.Limited: + return + } + // peerstore returns the elements in a random order as it uses a map to store the addresses + addrs := ids.Host.Peerstore().Addrs(c.RemotePeer()) + n := len(addrs) + if n > recentlyConnectedPeerMaxAddrs { + // We want to always save the address we are connected to + for i, a := range addrs { + if a.Equal(c.RemoteMultiaddr()) { + addrs[i], addrs[0] = addrs[0], addrs[i] + } + } + n = recentlyConnectedPeerMaxAddrs + } + ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.TempAddrTTL) + ids.Host.Peerstore().AddAddrs(c.RemotePeer(), addrs[:n], peerstore.RecentlyConnectedAddrTTL) + ids.Host.Peerstore().UpdateAddrs(c.RemotePeer(), peerstore.TempAddrTTL, 0) } func (nn *netNotifiee) Listen(n network.Network, a ma.Multiaddr) {} @@ -1008,3 +1077,55 @@ func filterAddrs(addrs []ma.Multiaddr, remote ma.Multiaddr) []ma.Multiaddr { } return ma.FilterAddrs(addrs, manet.IsPublicAddr) } + +func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr { + totalSize := 0 + for _, a := range addrs { + totalSize += len(a.Bytes()) + } + if totalSize <= maxSize { + return addrs + } + + score := func(addr ma.Multiaddr) int { + var res int + if manet.IsPublicAddr(addr) { + res |= 1 << 12 + } else if !manet.IsIPLoopback(addr) { + res |= 1 << 11 + } + var protocolWeight int + ma.ForEach(addr, func(c ma.Component) bool { + switch c.Protocol().Code { + case ma.P_QUIC_V1: + protocolWeight = 5 + case ma.P_TCP: + protocolWeight = 4 + case ma.P_WSS: + protocolWeight = 3 + case ma.P_WEBTRANSPORT: + protocolWeight = 2 + case ma.P_WEBRTC_DIRECT: + protocolWeight = 1 + case ma.P_P2P: + return false + } + return true + }) + res |= 1 << protocolWeight + return res + } + + slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int { + return score(b) - score(a) // b-a for reverse order + }) + totalSize = 0 + for i, a := range addrs { + totalSize += len(a.Bytes()) + if totalSize > maxSize { + addrs = addrs[:i] + break + } + } + return addrs +} diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 2352abef0b..a65d64f24e 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -107,104 +107,121 @@ func emitAddrChangeEvt(t *testing.T, h host.Host) { // this is because it used to be concurrent. Now, Dial wait till the // id service is done. func TestIDService(t *testing.T) { - if race.WithRace() { - t.Skip("This test modifies peerstore.RecentlyConnectedAddrTTL, which is racy.") - } - // This test is highly timing dependent, waiting on timeouts/expiration. 
- oldTTL := peerstore.RecentlyConnectedAddrTTL - peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond - t.Cleanup(func() { peerstore.RecentlyConnectedAddrTTL = oldTTL }) - - clk := mockClock.NewMock() - swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk)) - swarm2 := swarmt.GenSwarm(t, swarmt.WithClock(clk)) - h1 := blhost.NewBlankHost(swarm1) - h2 := blhost.NewBlankHost(swarm2) - - h1p := h1.ID() - h2p := h2.ID() - - ids1, err := identify.NewIDService(h1) - require.NoError(t, err) - defer ids1.Close() - ids1.Start() - - ids2, err := identify.NewIDService(h2) - require.NoError(t, err) - defer ids2.Close() - ids2.Start() - - sub, err := ids1.Host.EventBus().Subscribe(new(event.EvtPeerIdentificationCompleted)) - if err != nil { - t.Fatal(err) - } - - testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing - testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing - - // the forgetMe addr represents an address for h1 that h2 has learned out of band - // (not via identify protocol). During the identify exchange, it will be - // forgotten and replaced by the addrs h1 sends. - forgetMe, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") - - h2.Peerstore().AddAddr(h1p, forgetMe, peerstore.RecentlyConnectedAddrTTL) - h2pi := h2.Peerstore().PeerInfo(h2p) - require.NoError(t, h1.Connect(context.Background(), h2pi)) - - h1t2c := h1.Network().ConnsToPeer(h2p) - require.NotEmpty(t, h1t2c, "should have a conn here") - - ids1.IdentifyConn(h1t2c[0]) - - // the idService should be opened automatically, by the network. - // what we should see now is that both peers know about each others listen addresses. - t.Log("test peer1 has peer2 addrs correctly") - testKnowsAddrs(t, h1, h2p, h2.Addrs()) // has them - testHasAgentVersion(t, h1, h2p) - testHasPublicKey(t, h1, h2p, h2.Peerstore().PubKey(h2p)) // h1 should have h2's public key - - // now, this wait we do have to do. it's the wait for the Listening side - // to be done identifying the connection. - c := h2.Network().ConnsToPeer(h1.ID()) - require.NotEmpty(t, c, "should have connection by now at least.") - ids2.IdentifyConn(c[0]) + for _, withObsAddrManager := range []bool{false, true} { + t.Run(fmt.Sprintf("withObsAddrManager=%t", withObsAddrManager), func(t *testing.T) { + if race.WithRace() { + t.Skip("This test modifies peerstore.RecentlyConnectedAddrTTL, which is racy.") + } + // This test is highly timing dependent, waiting on timeouts/expiration. + oldTTL := peerstore.RecentlyConnectedAddrTTL + oldTempTTL := peerstore.TempAddrTTL + peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond + peerstore.TempAddrTTL = 50 * time.Millisecond + t.Cleanup(func() { + peerstore.RecentlyConnectedAddrTTL = oldTTL + peerstore.TempAddrTTL = oldTempTTL + }) - // and the protocol versions. - t.Log("test peer2 has peer1 addrs correctly") - testKnowsAddrs(t, h2, h1p, h1.Addrs()) // has them - testHasAgentVersion(t, h2, h1p) - testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key + clk := mockClock.NewMock() + swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk)) + swarm2 := swarmt.GenSwarm(t, swarmt.WithClock(clk)) + h1 := blhost.NewBlankHost(swarm1) + h2 := blhost.NewBlankHost(swarm2) - // Need both sides to actually notice that the connection has been closed. 
- sentDisconnect1 := waitForDisconnectNotification(swarm1) - sentDisconnect2 := waitForDisconnectNotification(swarm2) - h1.Network().ClosePeer(h2p) - h2.Network().ClosePeer(h1p) - if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 { - t.Fatal("should have no connections") - } + h1p := h1.ID() + h2p := h2.ID() - t.Log("testing addrs just after disconnect") - // addresses don't immediately expire on disconnect, so we should still have them - testKnowsAddrs(t, h2, h1p, h1.Addrs()) - testKnowsAddrs(t, h1, h2p, h2.Addrs()) - - <-sentDisconnect1 - <-sentDisconnect2 + opts := []identify.Option{} + if !withObsAddrManager { + opts = append(opts, identify.DisableObservedAddrManager()) + } + ids1, err := identify.NewIDService(h1, opts...) + require.NoError(t, err) + defer ids1.Close() + ids1.Start() + + opts = []identify.Option{} + if !withObsAddrManager { + opts = append(opts, identify.DisableObservedAddrManager()) + } + ids2, err := identify.NewIDService(h2, opts...) + require.NoError(t, err) + defer ids2.Close() + ids2.Start() + + sub, err := ids1.Host.EventBus().Subscribe(new(event.EvtPeerIdentificationCompleted)) + if err != nil { + t.Fatal(err) + } - // the addrs had their TTLs reduced on disconnect, and - // will be forgotten soon after - t.Log("testing addrs after TTL expiration") - clk.Add(time.Second) - testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) - testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) + testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing + testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing + + // the forgetMe addr represents an address for h1 that h2 has learned out of band + // (not via identify protocol). During the identify exchange, it will be + // forgotten and replaced by the addrs h1 sends. + forgetMe, _ := ma.NewMultiaddr("/ip4/1.2.3.4/tcp/1234") + + h2.Peerstore().AddAddr(h1p, forgetMe, peerstore.RecentlyConnectedAddrTTL) + h2pi := h2.Peerstore().PeerInfo(h2p) + require.NoError(t, h1.Connect(context.Background(), h2pi)) + + h1t2c := h1.Network().ConnsToPeer(h2p) + require.NotEmpty(t, h1t2c, "should have a conn here") + + ids1.IdentifyConn(h1t2c[0]) + + // the idService should be opened automatically, by the network. + // what we should see now is that both peers know about each others listen addresses. + t.Log("test peer1 has peer2 addrs correctly") + testKnowsAddrs(t, h1, h2p, h2.Addrs()) // has them + testHasAgentVersion(t, h1, h2p) + testHasPublicKey(t, h1, h2p, h2.Peerstore().PubKey(h2p)) // h1 should have h2's public key + + // now, this wait we do have to do. it's the wait for the Listening side + // to be done identifying the connection. + c := h2.Network().ConnsToPeer(h1.ID()) + require.NotEmpty(t, c, "should have connection by now at least.") + ids2.IdentifyConn(c[0]) + + // and the protocol versions. + t.Log("test peer2 has peer1 addrs correctly") + testKnowsAddrs(t, h2, h1p, h1.Addrs()) // has them + testHasAgentVersion(t, h2, h1p) + testHasPublicKey(t, h2, h1p, h1.Peerstore().PubKey(h1p)) // h1 should have h2's public key + + // Need both sides to actually notice that the connection has been closed. + sentDisconnect1 := waitForDisconnectNotification(swarm1) + sentDisconnect2 := waitForDisconnectNotification(swarm2) + h1.Network().ClosePeer(h2p) + h2.Network().ClosePeer(h1p) + if len(h2.Network().ConnsToPeer(h1.ID())) != 0 || len(h1.Network().ConnsToPeer(h2.ID())) != 0 { + t.Fatal("should have no connections") + } - // test that we received the "identify completed" event. 
- select { - case evtAny := <-sub.Out(): - assertCorrectEvtPeerIdentificationCompleted(t, evtAny, h2) - case <-time.After(3 * time.Second): - t.Fatalf("expected EvtPeerIdentificationCompleted event within 10 seconds; none received") + t.Log("testing addrs just after disconnect") + // addresses don't immediately expire on disconnect, so we should still have them + testKnowsAddrs(t, h2, h1p, h1.Addrs()) + testKnowsAddrs(t, h1, h2p, h2.Addrs()) + + <-sentDisconnect1 + <-sentDisconnect2 + + // the addrs had their TTLs reduced on disconnect, and + // will be forgotten soon after + t.Log("testing addrs after TTL expiration") + clk.Add(time.Second) + testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) + testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) + + // test that we received the "identify completed" event. + select { + case evtAny := <-sub.Out(): + assertCorrectEvtPeerIdentificationCompleted(t, evtAny, h2) + case <-time.After(3 * time.Second): + t.Fatalf("expected EvtPeerIdentificationCompleted event within 10 seconds; none received") + } + }) } } @@ -603,8 +620,13 @@ func TestLargeIdentifyMessage(t *testing.T) { t.Skip("setting peerstore.RecentlyConnectedAddrTTL is racy") } oldTTL := peerstore.RecentlyConnectedAddrTTL + oldTempTTL := peerstore.TempAddrTTL peerstore.RecentlyConnectedAddrTTL = 500 * time.Millisecond - t.Cleanup(func() { peerstore.RecentlyConnectedAddrTTL = oldTTL }) + peerstore.TempAddrTTL = 50 * time.Millisecond + t.Cleanup(func() { + peerstore.RecentlyConnectedAddrTTL = oldTTL + peerstore.TempAddrTTL = oldTempTTL + }) clk := mockClock.NewMock() swarm1 := swarmt.GenSwarm(t, swarmt.WithClock(clk)) diff --git a/p2p/protocol/identify/nat_emitter.go b/p2p/protocol/identify/nat_emitter.go new file mode 100644 index 0000000000..fec9b68fe2 --- /dev/null +++ b/p2p/protocol/identify/nat_emitter.go @@ -0,0 +1,119 @@ +package identify + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" +) + +type natEmitter struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + reachabilitySub event.Subscription + reachability network.Reachability + eventInterval time.Duration + + currentUDPNATDeviceType network.NATDeviceType + currentTCPNATDeviceType network.NATDeviceType + emitNATDeviceTypeChanged event.Emitter + + observedAddrMgr *ObservedAddrManager +} + +func newNATEmitter(h host.Host, o *ObservedAddrManager, eventInterval time.Duration) (*natEmitter, error) { + ctx, cancel := context.WithCancel(context.Background()) + n := &natEmitter{ + observedAddrMgr: o, + ctx: ctx, + cancel: cancel, + eventInterval: eventInterval, + } + reachabilitySub, err := h.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("identify (nat emitter)")) + if err != nil { + return nil, fmt.Errorf("failed to subscribe to reachability event: %s", err) + } + n.reachabilitySub = reachabilitySub + + emitter, err := h.EventBus().Emitter(new(event.EvtNATDeviceTypeChanged), eventbus.Stateful) + if err != nil { + return nil, fmt.Errorf("failed to create emitter for NATDeviceType: %s", err) + } + n.emitNATDeviceTypeChanged = emitter + + n.wg.Add(1) + go n.worker() + return n, nil +} + +func (n *natEmitter) worker() { + defer n.wg.Done() + subCh := n.reachabilitySub.Out() + ticker := time.NewTicker(n.eventInterval) + pendingUpdate := false + enoughTimeSinceLastUpdate := true + for { + select { + case evt, ok := <-subCh: 
+			if !ok {
+				subCh = nil
+				continue
+			}
+			ev, ok := evt.(event.EvtLocalReachabilityChanged)
+			if !ok {
+				log.Errorf("invalid event: %v", evt)
+				continue
+			}
+			n.reachability = ev.Reachability
+		case <-ticker.C:
+			enoughTimeSinceLastUpdate = true
+			if pendingUpdate {
+				n.maybeNotify()
+				pendingUpdate = false
+				enoughTimeSinceLastUpdate = false
+			}
+		case <-n.observedAddrMgr.addrRecordedNotif:
+			pendingUpdate = true
+			if enoughTimeSinceLastUpdate {
+				n.maybeNotify()
+				pendingUpdate = false
+				enoughTimeSinceLastUpdate = false
+			}
+		case <-n.ctx.Done():
+			return
+		}
+	}
+}
+
+func (n *natEmitter) maybeNotify() {
+	if n.reachability == network.ReachabilityPrivate {
+		tcpNATType, udpNATType := n.observedAddrMgr.getNATType()
+		if tcpNATType != n.currentTCPNATDeviceType {
+			n.currentTCPNATDeviceType = tcpNATType
+			n.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{
+				TransportProtocol: network.NATTransportTCP,
+				NatDeviceType:     n.currentTCPNATDeviceType,
+			})
+		}
+		if udpNATType != n.currentUDPNATDeviceType {
+			n.currentUDPNATDeviceType = udpNATType
+			n.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{
+				TransportProtocol: network.NATTransportUDP,
+				NatDeviceType:     n.currentUDPNATDeviceType,
+			})
+		}
+	}
+}
+
+func (n *natEmitter) Close() {
+	n.cancel()
+	n.wg.Wait()
+	n.reachabilitySub.Close()
+	n.emitNATDeviceTypeChanged.Close()
+}
diff --git a/p2p/protocol/identify/obsaddr.go b/p2p/protocol/identify/obsaddr.go
index 70a7eccd49..4437c4b011 100644
--- a/p2p/protocol/identify/obsaddr.go
+++ b/p2p/protocol/identify/obsaddr.go
@@ -3,16 +3,12 @@ package identify
 import (
 	"context"
 	"fmt"
+	"net"
+	"slices"
+	"sort"
 	"sync"
-	"time"
 
-	"golang.org/x/exp/slices"
-
-	"github.com/libp2p/go-libp2p/core/event"
-	"github.com/libp2p/go-libp2p/core/host"
 	"github.com/libp2p/go-libp2p/core/network"
-	"github.com/libp2p/go-libp2p/core/peerstore"
-	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
 
 	ma "github.com/multiformats/go-multiaddr"
 	manet "github.com/multiformats/go-multiaddr/net"
@@ -25,217 +21,252 @@ import (
 // the GC rounds set by GCInterval.
 var ActivationThresh = 4
 
-// GCInterval specicies how often to make a round cleaning seen events and
-// observed addresses. An address will be cleaned if it has not been seen in
-// OwnObservedAddressTTL (10 minutes). A "seen" event will be cleaned up if
-// it is older than OwnObservedAddressTTL * ActivationThresh (40 minutes).
-var GCInterval = 10 * time.Minute
-
 // observedAddrManagerWorkerChannelSize defines how many addresses can be enqueued
 // for adding to an ObservedAddrManager.
 var observedAddrManagerWorkerChannelSize = 16
 
-// maxObservedAddrsPerIPAndTransport is the maximum number of observed addresses
-// we will return for each (IPx/TCP or UDP) group.
-var maxObservedAddrsPerIPAndTransport = 2
-
-// observation records an address observation from an "observer" (where every IP
-// address is a unique observer).
-type observation struct {
-	// seenTime is the last time this observation was made.
-	seenTime time.Time
-	// inbound indicates whether or not this observation has been made from
-	// an inbound connection. This remains true even if we an observation
-	// from a subsequent outbound connection.
-	inbound bool
-}
+const maxExternalThinWaistAddrsPerLocalAddr = 3
 
-// observedAddr is an entry for an address reported by our peers.
-// We only use addresses that:
-// - have been observed at least 4 times in last 40 minutes.
(counter symmetric nats) -// - have been observed at least once recently (10 minutes), because our position in the -// network, or network port mapppings, may have changed. -type observedAddr struct { - addr ma.Multiaddr - seenBy map[string]observation // peer(observer) address -> observation info - lastSeen time.Time - numInbound int +// thinWaist is a struct that stores the address along with it's thin waist prefix and rest of the multiaddr +type thinWaist struct { + Addr, TW, Rest ma.Multiaddr } -func (oa *observedAddr) activated() bool { - - // We only activate if other peers observed the same address - // of ours at least 4 times. SeenBy peers are removed by GC if - // they say the address more than ttl*ActivationThresh - return len(oa.seenBy) >= ActivationThresh +// thinWaistWithCount is a thinWaist along with the count of the connection that have it as the local address +type thinWaistWithCount struct { + thinWaist + Count int } -// GroupKey returns the group in which this observation belongs. Currently, an -// observed address's group is just the address with all ports set to 0. This -// means we can advertise the most commonly observed external ports without -// advertising _every_ observed port. -func (oa *observedAddr) groupKey() string { - key := make([]byte, 0, len(oa.addr.Bytes())) - ma.ForEach(oa.addr, func(c ma.Component) bool { - switch proto := c.Protocol(); proto.Code { - case ma.P_TCP, ma.P_UDP: - key = append(key, proto.VCode...) - key = append(key, 0, 0) // zero in two bytes - default: - key = append(key, c.Bytes()...) +func thinWaistForm(a ma.Multiaddr) (thinWaist, error) { + i := 0 + tw, rest := ma.SplitFunc(a, func(c ma.Component) bool { + if i > 1 { + return true } - return true + switch i { + case 0: + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + i++ + return false + } + return true + case 1: + if c.Protocol().Code == ma.P_TCP || c.Protocol().Code == ma.P_UDP { + i++ + return false + } + return true + } + return false }) - - return string(key) + if i <= 1 { + return thinWaist{}, fmt.Errorf("not a thinwaist address: %s", a) + } + return thinWaist{Addr: a, TW: tw, Rest: rest}, nil } -type newObservation struct { - conn network.Conn - observed ma.Multiaddr +// getObserver returns the observer for the multiaddress +// For an IPv4 multiaddress the observer is the IP address +// For an IPv6 multiaddress the observer is the first /56 prefix of the IP address +func getObserver(a ma.Multiaddr) (string, error) { + ip, err := manet.ToIP(a) + if err != nil { + return "", err + } + if ip4 := ip.To4(); ip4 != nil { + return ip4.String(), nil + } + // Count /56 prefix as a single observer. + return ip.Mask(net.CIDRMask(56, 128)).String(), nil } -// ObservedAddrManager keeps track of a ObservedAddrs. -type ObservedAddrManager struct { - host host.Host - - closeOnce sync.Once - refCount sync.WaitGroup - ctx context.Context // the context is canceled when Close is called - ctxCancel context.CancelFunc +// connMultiaddrs provides IsClosed along with network.ConnMultiaddrs. 
It is easier to mock this than network.Conn +type connMultiaddrs interface { + network.ConnMultiaddrs + IsClosed() bool +} - // latest observation from active connections - // we'll "re-observe" these when we gc - activeConnsMu sync.Mutex - // active connection -> most recent observation - activeConns map[network.Conn]ma.Multiaddr +// observerSetCacheSize is the number of transport sharing the same thinwaist (tcp, ws, wss), (quic, webtransport, webrtc-direct) +// This is 3 in practice right now, but keep a buffer of 3 extra elements +const observerSetCacheSize = 5 - mu sync.RWMutex - closed bool - // local(internal) address -> list of observed(external) addresses - addrs map[string][]*observedAddr - ttl time.Duration - refreshTimer *time.Timer +// observerSet is the set of observers who have observed ThinWaistAddr +type observerSet struct { + ObservedTWAddr ma.Multiaddr + ObservedBy map[string]int - // this is the worker channel - wch chan newObservation + mu sync.RWMutex // protects following + cachedMultiaddrs map[string]ma.Multiaddr // cache of localMultiaddr rest(addr - thinwaist) => output multiaddr +} - reachabilitySub event.Subscription - reachability network.Reachability +func (s *observerSet) cacheMultiaddr(addr ma.Multiaddr) ma.Multiaddr { + if addr == nil { + return s.ObservedTWAddr + } + addrStr := string(addr.Bytes()) + s.mu.RLock() + res, ok := s.cachedMultiaddrs[addrStr] + s.mu.RUnlock() + if ok { + return res + } + + s.mu.Lock() + defer s.mu.Unlock() + // Check if some other go routine added this while we were waiting + res, ok = s.cachedMultiaddrs[addrStr] + if ok { + return res + } + if s.cachedMultiaddrs == nil { + s.cachedMultiaddrs = make(map[string]ma.Multiaddr, observerSetCacheSize) + } + if len(s.cachedMultiaddrs) == observerSetCacheSize { + // remove one entry if we will go over the limit + for k := range s.cachedMultiaddrs { + delete(s.cachedMultiaddrs, k) + break + } + } + s.cachedMultiaddrs[addrStr] = ma.Join(s.ObservedTWAddr, addr) + return s.cachedMultiaddrs[addrStr] +} - currentUDPNATDeviceType network.NATDeviceType - currentTCPNATDeviceType network.NATDeviceType - emitNATDeviceTypeChanged event.Emitter +type observation struct { + conn connMultiaddrs + observed ma.Multiaddr } -// NewObservedAddrManager returns a new address manager using -// peerstore.OwnObservedAddressTTL as the TTL. 
-func NewObservedAddrManager(host host.Host) (*ObservedAddrManager, error) { - oas := &ObservedAddrManager{ - addrs: make(map[string][]*observedAddr), - ttl: peerstore.OwnObservedAddrTTL, - wch: make(chan newObservation, observedAddrManagerWorkerChannelSize), - host: host, - activeConns: make(map[network.Conn]ma.Multiaddr), - // refresh every ttl/2 so we don't forget observations from connected peers - refreshTimer: time.NewTimer(peerstore.OwnObservedAddrTTL / 2), - } - oas.ctx, oas.ctxCancel = context.WithCancel(context.Background()) - - reachabilitySub, err := host.EventBus().Subscribe(new(event.EvtLocalReachabilityChanged), eventbus.Name("identify (obsaddr)")) - if err != nil { - return nil, fmt.Errorf("failed to subscribe to reachability event: %s", err) - } - oas.reachabilitySub = reachabilitySub +// ObservedAddrManager maps connection's local multiaddrs to their externally observable multiaddress +type ObservedAddrManager struct { + // Our listen addrs + listenAddrs func() []ma.Multiaddr + // Our listen addrs with interface addrs for unspecified addrs + interfaceListenAddrs func() ([]ma.Multiaddr, error) + // All host addrs + hostAddrs func() []ma.Multiaddr + // Any normalization required before comparing. Useful to remove certhash + normalize func(ma.Multiaddr) ma.Multiaddr + // worker channel for new observations + wch chan observation + // notified on recording an observation + addrRecordedNotif chan struct{} + + // for closing + wg sync.WaitGroup + ctx context.Context + ctxCancel context.CancelFunc - emitter, err := host.EventBus().Emitter(new(event.EvtNATDeviceTypeChanged), eventbus.Stateful) - if err != nil { - return nil, fmt.Errorf("failed to create emitter for NATDeviceType: %s", err) - } - oas.emitNATDeviceTypeChanged = emitter + mu sync.RWMutex + // local thin waist => external thin waist => observerSet + externalAddrs map[string]map[string]*observerSet + // connObservedTWAddrs maps the connection to the last observed thin waist multiaddr on that connection + connObservedTWAddrs map[connMultiaddrs]ma.Multiaddr + // localMultiaddr => thin waist form with the count of the connections the multiaddr + // was seen on for tracking our local listen addresses + localAddrs map[string]*thinWaistWithCount +} - oas.host.Network().Notify((*obsAddrNotifiee)(oas)) - oas.refCount.Add(1) - go oas.worker() - return oas, nil +// NewObservedAddrManager returns a new address manager using peerstore.OwnObservedAddressTTL as the TTL. +func NewObservedAddrManager(listenAddrs, hostAddrs func() []ma.Multiaddr, + interfaceListenAddrs func() ([]ma.Multiaddr, error), normalize func(ma.Multiaddr) ma.Multiaddr) (*ObservedAddrManager, error) { + if normalize == nil { + normalize = func(addr ma.Multiaddr) ma.Multiaddr { return addr } + } + o := &ObservedAddrManager{ + externalAddrs: make(map[string]map[string]*observerSet), + connObservedTWAddrs: make(map[connMultiaddrs]ma.Multiaddr), + localAddrs: make(map[string]*thinWaistWithCount), + wch: make(chan observation, observedAddrManagerWorkerChannelSize), + addrRecordedNotif: make(chan struct{}, 1), + listenAddrs: listenAddrs, + interfaceListenAddrs: interfaceListenAddrs, + hostAddrs: hostAddrs, + normalize: normalize, + } + o.ctx, o.ctxCancel = context.WithCancel(context.Background()) + + o.wg.Add(1) + go o.worker() + return o, nil } // AddrsFor return all activated observed addresses associated with the given // (resolved) listen address. 
-func (oas *ObservedAddrManager) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr) { - oas.mu.RLock() - defer oas.mu.RUnlock() - - if len(oas.addrs) == 0 { +func (o *ObservedAddrManager) AddrsFor(addr ma.Multiaddr) (addrs []ma.Multiaddr) { + if addr == nil { return nil } - - observedAddrs, ok := oas.addrs[string(addr.Bytes())] - if !ok { - return + o.mu.RLock() + defer o.mu.RUnlock() + tw, err := thinWaistForm(o.normalize(addr)) + if err != nil { + return nil } - return oas.filter(observedAddrs) + observerSets := o.getTopExternalAddrs(string(tw.TW.Bytes())) + res := make([]ma.Multiaddr, 0, len(observerSets)) + for _, s := range observerSets { + res = append(res, s.cacheMultiaddr(tw.Rest)) + } + return res } // Addrs return all activated observed addresses -func (oas *ObservedAddrManager) Addrs() []ma.Multiaddr { - oas.mu.RLock() - defer oas.mu.RUnlock() - - if len(oas.addrs) == 0 { - return nil - } - - var allObserved []*observedAddr - for _, addrs := range oas.addrs { - allObserved = append(allObserved, addrs...) +func (o *ObservedAddrManager) Addrs() []ma.Multiaddr { + o.mu.RLock() + defer o.mu.RUnlock() + + m := make(map[string][]*observerSet) + for localTWStr := range o.externalAddrs { + m[localTWStr] = append(m[localTWStr], o.getTopExternalAddrs(localTWStr)...) + } + addrs := make([]ma.Multiaddr, 0, maxExternalThinWaistAddrsPerLocalAddr*5) // assume 5 transports + for _, t := range o.localAddrs { + for _, s := range m[string(t.TW.Bytes())] { + addrs = append(addrs, s.cacheMultiaddr(t.Rest)) + } } - return oas.filter(allObserved) + return addrs } -func (oas *ObservedAddrManager) filter(observedAddrs []*observedAddr) []ma.Multiaddr { - pmap := make(map[string][]*observedAddr) - now := time.Now() - - for i := range observedAddrs { - a := observedAddrs[i] - if now.Sub(a.lastSeen) <= oas.ttl && a.activated() { - // group addresses by their IPX/Transport Protocol(TCP or UDP) pattern. - pat := a.groupKey() - pmap[pat] = append(pmap[pat], a) - +func (o *ObservedAddrManager) getTopExternalAddrs(localTWStr string) []*observerSet { + observerSets := make([]*observerSet, 0, len(o.externalAddrs[localTWStr])) + for _, v := range o.externalAddrs[localTWStr] { + if len(v.ObservedBy) >= ActivationThresh { + observerSets = append(observerSets, v) } } - - addrs := make([]ma.Multiaddr, 0, len(observedAddrs)) - for pat := range pmap { - s := pmap[pat] - - slices.SortFunc(s, func(first, second *observedAddr) int { - // We prefer inbound connection observations over outbound. - if first.numInbound > second.numInbound { - return -1 - } - // For ties, we prefer the ones with more votes. - if first.numInbound == second.numInbound && len(first.seenBy) > len(second.seenBy) { - return -1 - } + slices.SortFunc(observerSets, func(a, b *observerSet) int { + diff := len(b.ObservedBy) - len(a.ObservedBy) + if diff != 0 { + return diff + } + // In case we have elements with equal counts, + // keep the address list stable by using the lexicographically smaller address + as := a.ObservedTWAddr.String() + bs := b.ObservedTWAddr.String() + if as < bs { + return -1 + } else if as > bs { return 1 - }) - - for i := 0; i < maxObservedAddrsPerIPAndTransport && i < len(s); i++ { - addrs = append(addrs, s[i].addr) + } else { + return 0 } - } - return addrs + }) + n := len(observerSets) + if n > maxExternalThinWaistAddrsPerLocalAddr { + n = maxExternalThinWaistAddrsPerLocalAddr + } + return observerSets[:n] } -// Record records an address observation, if valid. 
-func (oas *ObservedAddrManager) Record(conn network.Conn, observed ma.Multiaddr) { +// Record enqueues an observation for recording +func (o *ObservedAddrManager) Record(conn connMultiaddrs, observed ma.Multiaddr) { select { - case oas.wch <- newObservation{ + case o.wch <- observation{ conn: conn, observed: observed, }: @@ -247,182 +278,72 @@ func (oas *ObservedAddrManager) Record(conn network.Conn, observed ma.Multiaddr) } } -func (oas *ObservedAddrManager) worker() { - defer oas.refCount.Done() - - ticker := time.NewTicker(GCInterval) - defer ticker.Stop() +func (o *ObservedAddrManager) worker() { + defer o.wg.Done() - subChan := oas.reachabilitySub.Out() for { select { - case evt, ok := <-subChan: - if !ok { - subChan = nil - continue - } - ev := evt.(event.EvtLocalReachabilityChanged) - oas.reachability = ev.Reachability - case obs := <-oas.wch: - oas.maybeRecordObservation(obs.conn, obs.observed) - case <-ticker.C: - oas.gc() - case <-oas.refreshTimer.C: - oas.refresh() - case <-oas.ctx.Done(): + case obs := <-o.wch: + o.maybeRecordObservation(obs.conn, obs.observed) + case <-o.ctx.Done(): return } } } -func (oas *ObservedAddrManager) refresh() { - oas.activeConnsMu.Lock() - recycledObservations := make([]newObservation, 0, len(oas.activeConns)) - for conn, observed := range oas.activeConns { - recycledObservations = append(recycledObservations, newObservation{ - conn: conn, - observed: observed, - }) +func (o *ObservedAddrManager) shouldRecordObservation(conn connMultiaddrs, observed ma.Multiaddr) (shouldRecord bool, localTW thinWaist, observedTW thinWaist) { + if conn == nil || observed == nil { + return false, thinWaist{}, thinWaist{} } - oas.activeConnsMu.Unlock() - - oas.mu.Lock() - defer oas.mu.Unlock() - for _, obs := range recycledObservations { - oas.recordObservationUnlocked(obs.conn, obs.observed) - } - // refresh every ttl/2 so we don't forget observations from connected peers - oas.refreshTimer.Reset(oas.ttl / 2) -} - -func (oas *ObservedAddrManager) gc() { - oas.mu.Lock() - defer oas.mu.Unlock() - - now := time.Now() - for local, observedAddrs := range oas.addrs { - filteredAddrs := observedAddrs[:0] - for _, a := range observedAddrs { - // clean up SeenBy set - for k, ob := range a.seenBy { - if now.Sub(ob.seenTime) > oas.ttl*time.Duration(ActivationThresh) { - delete(a.seenBy, k) - if ob.inbound { - a.numInbound-- - } - } - } - - // leave only alive observed addresses - if now.Sub(a.lastSeen) <= oas.ttl { - filteredAddrs = append(filteredAddrs, a) - } - } - if len(filteredAddrs) > 0 { - oas.addrs[local] = filteredAddrs - } else { - delete(oas.addrs, local) - } - } -} - -func (oas *ObservedAddrManager) addConn(conn network.Conn, observed ma.Multiaddr) { - oas.activeConnsMu.Lock() - defer oas.activeConnsMu.Unlock() - - // We need to make sure we haven't received a disconnect event for this - // connection yet. The only way to do that right now is to make sure the - // swarm still has the connection. - // - // Doing this under a lock that we _also_ take in a disconnect event - // handler ensures everything happens in the right order. - for _, c := range oas.host.Network().ConnsToPeer(conn.RemotePeer()) { - if c == conn { - oas.activeConns[conn] = observed - return - } - } -} - -func (oas *ObservedAddrManager) removeConn(conn network.Conn) { - // DO NOT remove this lock. - // This ensures we don't call addConn at the same time: - // 1. see that we have a connection and pause inside addConn right before recording it. - // 2. process a disconnect event. - // 3. 
record the connection (leaking it). - - oas.activeConnsMu.Lock() - delete(oas.activeConns, conn) - oas.activeConnsMu.Unlock() -} - -type normalizeMultiaddrer interface { - NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr -} - -type addrsProvider interface { - Addrs() []ma.Multiaddr -} - -type listenAddrsProvider interface { - ListenAddresses() []ma.Multiaddr - InterfaceListenAddresses() ([]ma.Multiaddr, error) -} - -func shouldRecordObservation(host addrsProvider, network listenAddrsProvider, conn network.ConnMultiaddrs, observed ma.Multiaddr) bool { - // First, determine if this observation is even worth keeping... - // Ignore observations from loopback nodes. We already know our loopback // addresses. if manet.IsIPLoopback(observed) { - return false + return false, thinWaist{}, thinWaist{} } // Provided by NAT64 peers, these addresses are specific to the peer and not publicly routable if manet.IsNAT64IPv4ConvertedIPv6Addr(observed) { - return false + return false, thinWaist{}, thinWaist{} } // we should only use ObservedAddr when our connection's LocalAddr is one // of our ListenAddrs. If we Dial out using an ephemeral addr, knowing that // address's external mapping is not very useful because the port will not be // the same as the listen addr. - ifaceaddrs, err := network.InterfaceListenAddresses() + ifaceaddrs, err := o.interfaceListenAddrs() if err != nil { log.Infof("failed to get interface listen addrs", err) - return false + return false, thinWaist{}, thinWaist{} } - normalizer, canNormalize := host.(normalizeMultiaddrer) - - if canNormalize { - for i, a := range ifaceaddrs { - ifaceaddrs[i] = normalizer.NormalizeMultiaddr(a) - } + for i, a := range ifaceaddrs { + ifaceaddrs[i] = o.normalize(a) } - local := conn.LocalMultiaddr() - if canNormalize { - local = normalizer.NormalizeMultiaddr(local) - } + local := o.normalize(conn.LocalMultiaddr()) - listenAddrs := network.ListenAddresses() - if canNormalize { - for i, a := range listenAddrs { - listenAddrs[i] = normalizer.NormalizeMultiaddr(a) - } + listenAddrs := o.listenAddrs() + for i, a := range listenAddrs { + listenAddrs[i] = o.normalize(a) } if !ma.Contains(ifaceaddrs, local) && !ma.Contains(listenAddrs, local) { // not in our list - return false + return false, thinWaist{}, thinWaist{} } - hostAddrs := host.Addrs() - if canNormalize { - for i, a := range hostAddrs { - hostAddrs[i] = normalizer.NormalizeMultiaddr(a) - } + localTW, err = thinWaistForm(local) + if err != nil { + return false, thinWaist{}, thinWaist{} + } + observedTW, err = thinWaistForm(o.normalize(observed)) + if err != nil { + return false, thinWaist{}, thinWaist{} + } + + hostAddrs := o.hostAddrs() + for i, a := range hostAddrs { + hostAddrs[i] = o.normalize(a) } // We should reject the connection if the observation doesn't match the @@ -434,207 +355,192 @@ func shouldRecordObservation(host addrsProvider, network listenAddrsProvider, co "from", conn.RemoteMultiaddr(), "observed", observed, ) - return false + return false, thinWaist{}, thinWaist{} } - return true + return true, localTW, observedTW } -func (oas *ObservedAddrManager) maybeRecordObservation(conn network.Conn, observed ma.Multiaddr) { - shouldRecord := shouldRecordObservation(oas.host, oas.host.Network(), conn, observed) - if shouldRecord { - // Ok, the observation is good, record it. 
- log.Debugw("added own observed listen addr", "observed", observed) - defer oas.addConn(conn, observed) - - oas.mu.Lock() - defer oas.mu.Unlock() - oas.recordObservationUnlocked(conn, observed) +func (o *ObservedAddrManager) maybeRecordObservation(conn connMultiaddrs, observed ma.Multiaddr) { + shouldRecord, localTW, observedTW := o.shouldRecordObservation(conn, observed) + if !shouldRecord { + return + } + log.Debugw("added own observed listen addr", "observed", observed) - if oas.reachability == network.ReachabilityPrivate { - oas.emitAllNATTypes() - } + o.mu.Lock() + defer o.mu.Unlock() + o.recordObservationUnlocked(conn, localTW, observedTW) + select { + case o.addrRecordedNotif <- struct{}{}: + default: } } -func (oas *ObservedAddrManager) recordObservationUnlocked(conn network.Conn, observed ma.Multiaddr) { - now := time.Now() - observerString := observerGroup(conn.RemoteMultiaddr()) - localString := string(conn.LocalMultiaddr().Bytes()) - ob := observation{ - seenTime: now, - inbound: conn.Stat().Direction == network.DirInbound, - } - - // check if observed address seen yet, if so, update it - for _, observedAddr := range oas.addrs[localString] { - if observedAddr.addr.Equal(observed) { - // Don't trump an outbound observation with an inbound - // one. - wasInbound := observedAddr.seenBy[observerString].inbound - isInbound := ob.inbound - ob.inbound = isInbound || wasInbound - - if !wasInbound && isInbound { - observedAddr.numInbound++ - } +func (o *ObservedAddrManager) recordObservationUnlocked(conn connMultiaddrs, localTW, observedTW thinWaist) { + if conn.IsClosed() { + // dont record if the connection is already closed. Any previous observations will be removed in + // the disconnected callback + return + } + localTWStr := string(localTW.TW.Bytes()) + observedTWStr := string(observedTW.TW.Bytes()) + observer, err := getObserver(conn.RemoteMultiaddr()) + if err != nil { + return + } - observedAddr.seenBy[observerString] = ob - observedAddr.lastSeen = now + prevObservedTWAddr, ok := o.connObservedTWAddrs[conn] + if !ok { + t, ok := o.localAddrs[string(localTW.Addr.Bytes())] + if !ok { + t = &thinWaistWithCount{ + thinWaist: localTW, + } + o.localAddrs[string(localTW.Addr.Bytes())] = t + } + t.Count++ + } else { + if prevObservedTWAddr.Equal(observedTW.TW) { + // we have received the same observation again, nothing to do return } + // if we have a previous entry remove it from externalAddrs + o.removeExternalAddrsUnlocked(observer, localTWStr, string(prevObservedTWAddr.Bytes())) + // no need to change the localAddrs map here } + o.connObservedTWAddrs[conn] = observedTW.TW + o.addExternalAddrsUnlocked(observedTW.TW, observer, localTWStr, observedTWStr) +} - // observed address not seen yet, append it - oa := &observedAddr{ - addr: observed, - seenBy: map[string]observation{ - observerString: ob, - }, - lastSeen: now, +func (o *ObservedAddrManager) removeExternalAddrsUnlocked(observer, localTWStr, observedTWStr string) { + s, ok := o.externalAddrs[localTWStr][observedTWStr] + if !ok { + return + } + s.ObservedBy[observer]-- + if s.ObservedBy[observer] <= 0 { + delete(s.ObservedBy, observer) + } + if len(s.ObservedBy) == 0 { + delete(o.externalAddrs[localTWStr], observedTWStr) + } + if len(o.externalAddrs[localTWStr]) == 0 { + delete(o.externalAddrs, localTWStr) } - if ob.inbound { - oa.numInbound++ +} + +func (o *ObservedAddrManager) addExternalAddrsUnlocked(observedTWAddr ma.Multiaddr, observer, localTWStr, observedTWStr string) { + s, ok := 
o.externalAddrs[localTWStr][observedTWStr] + if !ok { + s = &observerSet{ + ObservedTWAddr: observedTWAddr, + ObservedBy: make(map[string]int), + } + if _, ok := o.externalAddrs[localTWStr]; !ok { + o.externalAddrs[localTWStr] = make(map[string]*observerSet) + } + o.externalAddrs[localTWStr][observedTWStr] = s } - oas.addrs[localString] = append(oas.addrs[localString], oa) + s.ObservedBy[observer]++ } -// For a given transport Protocol (TCP/UDP): -// -// 1. If we have an activated address, we are behind an Cone NAT. -// With regards to RFC 3489, this could be either a Full Cone NAT, a Restricted Cone NAT or a -// Port Restricted Cone NAT. However, we do NOT differentiate between them here and simply classify all such NATs as a Cone NAT. -// -// 2. If four different peers observe a different address for us on outbound connections, we -// are MOST probably behind a Symmetric NAT. -// -// Please see the documentation on the enumerations for `network.NATDeviceType` for more details about these NAT Device types -// and how they relate to NAT traversal via Hole Punching. -func (oas *ObservedAddrManager) emitAllNATTypes() { - var allObserved []*observedAddr - for _, addrs := range oas.addrs { - allObserved = append(allObserved, addrs...) +func (o *ObservedAddrManager) removeConn(conn connMultiaddrs) { + if conn == nil { + return } + o.mu.Lock() + defer o.mu.Unlock() - hasChanged, natType := oas.emitSpecificNATType(allObserved, ma.P_TCP, network.NATTransportTCP, oas.currentTCPNATDeviceType) - if hasChanged { - oas.currentTCPNATDeviceType = natType + // normalize before obtaining the thinWaist so that we are always dealing + // with the normalized form of the address + localTW, err := thinWaistForm(o.normalize(conn.LocalMultiaddr())) + if err != nil { + return + } + t, ok := o.localAddrs[string(localTW.Addr.Bytes())] + if !ok { + return + } + t.Count-- + if t.Count <= 0 { + delete(o.localAddrs, string(localTW.Addr.Bytes())) } - hasChanged, natType = oas.emitSpecificNATType(allObserved, ma.P_UDP, network.NATTransportUDP, oas.currentUDPNATDeviceType) - if hasChanged { - oas.currentUDPNATDeviceType = natType + observedTWAddr, ok := o.connObservedTWAddrs[conn] + if !ok { + return + } + delete(o.connObservedTWAddrs, conn) + observer, err := getObserver(conn.RemoteMultiaddr()) + if err != nil { + return } -} -// returns true along with the new NAT device type if the NAT device type for the given protocol has changed. -// returns false otherwise. -func (oas *ObservedAddrManager) emitSpecificNATType(addrs []*observedAddr, protoCode int, transportProto network.NATTransportProtocol, - currentNATType network.NATDeviceType) (bool, network.NATDeviceType) { - now := time.Now() - seenBy := make(map[string]struct{}) - cnt := 0 - - for _, oa := range addrs { - _, err := oa.addr.ValueForProtocol(protoCode) - if err != nil { - continue - } + o.removeExternalAddrsUnlocked(observer, string(localTW.TW.Bytes()), string(observedTWAddr.Bytes())) + select { + case o.addrRecordedNotif <- struct{}{}: + default: + } +} - // if we have an activated addresses, it's a Cone NAT. 
- if now.Sub(oa.lastSeen) <= oas.ttl && oa.activated() { - if currentNATType != network.NATDeviceTypeCone { - oas.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{ - TransportProtocol: transportProto, - NatDeviceType: network.NATDeviceTypeCone, - }) - return true, network.NATDeviceTypeCone +func (o *ObservedAddrManager) getNATType() (tcpNATType, udpNATType network.NATDeviceType) { + o.mu.RLock() + defer o.mu.RUnlock() + + var tcpCounts, udpCounts []int + var tcpTotal, udpTotal int + for _, m := range o.externalAddrs { + isTCP := false + for _, v := range m { + if _, err := v.ObservedTWAddr.ValueForProtocol(ma.P_TCP); err == nil { + isTCP = true } - - // our current NAT Device Type is already CONE, nothing to do here. - return false, 0 + break } - - // An observed address on an outbound connection that has ONLY been seen by one peer - if now.Sub(oa.lastSeen) <= oas.ttl && oa.numInbound == 0 && len(oa.seenBy) == 1 { - cnt++ - for s := range oa.seenBy { - seenBy[s] = struct{}{} + for _, v := range m { + if isTCP { + tcpCounts = append(tcpCounts, len(v.ObservedBy)) + tcpTotal += len(v.ObservedBy) + } else { + udpCounts = append(udpCounts, len(v.ObservedBy)) + udpTotal += len(v.ObservedBy) } } } - // If four different peers observe a different address for us on each of four outbound connections, we - // are MOST probably behind a Symmetric NAT. - if cnt >= ActivationThresh && len(seenBy) >= ActivationThresh { - if currentNATType != network.NATDeviceTypeSymmetric { - oas.emitNATDeviceTypeChanged.Emit(event.EvtNATDeviceTypeChanged{ - TransportProtocol: transportProto, - NatDeviceType: network.NATDeviceTypeSymmetric, - }) - return true, network.NATDeviceTypeSymmetric - } - } + sort.Sort(sort.Reverse(sort.IntSlice(tcpCounts))) + sort.Sort(sort.Reverse(sort.IntSlice(udpCounts))) - return false, 0 -} - -func (oas *ObservedAddrManager) Close() error { - oas.closeOnce.Do(func() { - oas.ctxCancel() - - oas.mu.Lock() - oas.closed = true - oas.refreshTimer.Stop() - oas.mu.Unlock() - - oas.refCount.Wait() - oas.reachabilitySub.Close() - oas.host.Network().StopNotify((*obsAddrNotifiee)(oas)) - }) - return nil -} - -// observerGroup is a function that determines what part of -// a multiaddr counts as a different observer. for example, -// two ipfs nodes at the same IP/TCP transport would get -// the exact same NAT mapping; they would count as the -// same observer. This may protect against NATs who assign -// different ports to addresses at different IP hosts, but -// not TCP ports. -// -// Here, we use the root multiaddr address. This is mostly -// IP addresses. In practice, this is what we want. -func observerGroup(m ma.Multiaddr) string { - // TODO: If IPv6 rolls out we should mark /64 routing zones as one group - first, _ := ma.SplitFirst(m) - return string(first.Bytes()) -} - -// SetTTL sets the TTL of an observed address manager. -func (oas *ObservedAddrManager) SetTTL(ttl time.Duration) { - oas.mu.Lock() - defer oas.mu.Unlock() - if oas.closed { - return + tcpTopCounts, udpTopCounts := 0, 0 + for i := 0; i < maxExternalThinWaistAddrsPerLocalAddr && i < len(tcpCounts); i++ { + tcpTopCounts += tcpCounts[i] + } + for i := 0; i < maxExternalThinWaistAddrsPerLocalAddr && i < len(udpCounts); i++ { + udpTopCounts += udpCounts[i] } - oas.ttl = ttl - // refresh every ttl/2 so we don't forget observations from connected peers - oas.refreshTimer.Reset(ttl / 2) -} -// TTL gets the TTL of an observed address manager. 
-func (oas *ObservedAddrManager) TTL() time.Duration { - oas.mu.RLock() - defer oas.mu.RUnlock() - return oas.ttl + // If the top elements cover more than 1/2 of all the observations, there's a > 50% chance that + // hole punching based on outputs of observed address manager will succeed + if tcpTotal >= 3*maxExternalThinWaistAddrsPerLocalAddr { + if tcpTopCounts >= tcpTotal/2 { + tcpNATType = network.NATDeviceTypeCone + } else { + tcpNATType = network.NATDeviceTypeSymmetric + } + } + if udpTotal >= 3*maxExternalThinWaistAddrsPerLocalAddr { + if udpTopCounts >= udpTotal/2 { + udpNATType = network.NATDeviceTypeCone + } else { + udpNATType = network.NATDeviceTypeSymmetric + } + } + return } -type obsAddrNotifiee ObservedAddrManager - -func (on *obsAddrNotifiee) Listen(n network.Network, a ma.Multiaddr) {} -func (on *obsAddrNotifiee) ListenClose(n network.Network, a ma.Multiaddr) {} -func (on *obsAddrNotifiee) Connected(n network.Network, v network.Conn) {} -func (on *obsAddrNotifiee) Disconnected(n network.Network, v network.Conn) { - (*ObservedAddrManager)(on).removeConn(v) +func (o *ObservedAddrManager) Close() error { + o.ctxCancel() + o.wg.Wait() + return nil } diff --git a/p2p/protocol/identify/obsaddr_glass_test.go b/p2p/protocol/identify/obsaddr_glass_test.go index d7535de830..31fd4f5726 100644 --- a/p2p/protocol/identify/obsaddr_glass_test.go +++ b/p2p/protocol/identify/obsaddr_glass_test.go @@ -5,77 +5,16 @@ package identify import ( "fmt" + "sync/atomic" "testing" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) -func TestObservedAddrGroupKey(t *testing.T) { - oa1 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/tcp/2345")} - oa2 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/tcp/1231")} - oa3 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.5/tcp/1231")} - oa4 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1231")} - oa5 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1531")} - oa6 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1531/quic-v1")} - oa7 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.4/udp/1111/quic-v1")} - oa8 := &observedAddr{addr: ma.StringCast("/ip4/1.2.3.5/udp/1111/quic-v1")} - - // different ports, same IP => same key - require.Equal(t, oa1.groupKey(), oa2.groupKey()) - // different IPs => different key - require.NotEqual(t, oa2.groupKey(), oa3.groupKey()) - // same port, different protos => different keys - require.NotEqual(t, oa3.groupKey(), oa4.groupKey()) - // same port, same address, different protos => different keys - require.NotEqual(t, oa2.groupKey(), oa4.groupKey()) - // udp works as well - require.Equal(t, oa4.groupKey(), oa5.groupKey()) - // udp and quic are different - require.NotEqual(t, oa5.groupKey(), oa6.groupKey()) - // quic works as well - require.Equal(t, oa6.groupKey(), oa7.groupKey()) - require.NotEqual(t, oa7.groupKey(), oa8.groupKey()) -} - -type mockHost struct { - addrs []ma.Multiaddr - listenAddrs []ma.Multiaddr - ifaceListenAddrs []ma.Multiaddr -} - -// InterfaceListenAddresses implements listenAddrsProvider -func (h *mockHost) InterfaceListenAddresses() ([]ma.Multiaddr, error) { - return h.ifaceListenAddrs, nil -} - -// ListenAddresses implements listenAddrsProvider -func (h *mockHost) ListenAddresses() []ma.Multiaddr { - return h.listenAddrs -} - -// Addrs implements addrsProvider -func (h *mockHost) Addrs() []ma.Multiaddr { - return h.addrs -} - -// NormalizeMultiaddr implements normalizeMultiaddrer -func (h *mockHost) NormalizeMultiaddr(m ma.Multiaddr) ma.Multiaddr { - 
original := m - for { - rest, tail := ma.SplitLast(m) - if rest == nil { - return original - } - if tail.Protocol().Code == ma.P_WEBTRANSPORT { - return m - } - m = rest - } -} - type mockConn struct { local, remote ma.Multiaddr + isClosed atomic.Bool } // LocalMultiaddr implements connMultiaddrProvider @@ -88,21 +27,30 @@ func (c *mockConn) RemoteMultiaddr() ma.Multiaddr { return c.remote } +func (c *mockConn) Close() { + c.isClosed.Store(true) +} + +func (c *mockConn) IsClosed() bool { + return c.isClosed.Load() +} + func TestShouldRecordObservationWithWebTransport(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/udp/0/quic-v1/webtransport/certhash/uEgNmb28") ifaceAddr := ma.StringCast("/ip4/10.0.0.2/udp/9999/quic-v1/webtransport/certhash/uEgNmb28") - h := &mockHost{ - listenAddrs: []ma.Multiaddr{listenAddr}, - ifaceListenAddrs: []ma.Multiaddr{ifaceAddr}, - addrs: []ma.Multiaddr{listenAddr}, - } + listenAddrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} } + ifaceListenAddrs := func() ([]ma.Multiaddr, error) { return []ma.Multiaddr{ifaceAddr}, nil } + addrs := func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr} } + c := &mockConn{ local: listenAddr, remote: ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/webtransport"), } observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/webtransport") - - require.True(t, shouldRecordObservation(h, h, c, observedAddr)) + o, err := NewObservedAddrManager(listenAddrs, addrs, ifaceListenAddrs, normalize) + require.NoError(t, err) + shouldRecord, _, _ := o.shouldRecordObservation(c, observedAddr) + require.True(t, shouldRecord) } func TestShouldRecordObservationWithNAT64Addr(t *testing.T) { @@ -111,11 +59,11 @@ func TestShouldRecordObservationWithNAT64Addr(t *testing.T) { listenAddr2 := ma.StringCast("/ip6/::/tcp/1234") ifaceAddr2 := ma.StringCast("/ip6/1::1/tcp/4321") - h := &mockHost{ - listenAddrs: []ma.Multiaddr{listenAddr1, listenAddr2}, - ifaceListenAddrs: []ma.Multiaddr{ifaceAddr1, ifaceAddr2}, - addrs: []ma.Multiaddr{listenAddr1, listenAddr2}, - } + var ( + listenAddrs = func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr1, listenAddr2} } + ifaceListenAddrs = func() ([]ma.Multiaddr, error) { return []ma.Multiaddr{ifaceAddr1, ifaceAddr2}, nil } + addrs = func() []ma.Multiaddr { return []ma.Multiaddr{listenAddr1, listenAddr2} } + ) c := &mockConn{ local: listenAddr1, remote: ma.StringCast("/ip4/1.2.3.6/tcp/4321"), @@ -142,12 +90,70 @@ func TestShouldRecordObservationWithNAT64Addr(t *testing.T) { failureReason: "NAT64 IPv6 address shouldn't be observed", }, } + + o, err := NewObservedAddrManager(listenAddrs, addrs, ifaceListenAddrs, normalize) + require.NoError(t, err) for i, tc := range cases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - if shouldRecordObservation(h, h, c, tc.addr) != tc.want { + if shouldRecord, _, _ := o.shouldRecordObservation(c, tc.addr); shouldRecord != tc.want { t.Fatalf("%s %s", tc.addr, tc.failureReason) } }) } } + +func TestThinWaistForm(t *testing.T) { + tc := []struct { + input string + tw string + rest string + err bool + }{{ + input: "/ip4/1.2.3.4/tcp/1", + tw: "/ip4/1.2.3.4/tcp/1", + rest: "", + }, { + input: "/ip4/1.2.3.4/tcp/1/ws", + tw: "/ip4/1.2.3.4/tcp/1", + rest: "/ws", + }, { + input: "/ip4/127.0.0.1/udp/1/quic-v1", + tw: "/ip4/127.0.0.1/udp/1", + rest: "/quic-v1", + }, { + input: "/ip4/1.2.3.4/udp/1/quic-v1/webtransport", + tw: "/ip4/1.2.3.4/udp/1", + rest: "/quic-v1/webtransport", + }, { + input: "/ip4/1.2.3.4/", + err: true, + }, { + input: "/tcp/1", + err: true, 
+ }, { + input: "/ip6/::1/tcp/1", + tw: "/ip6/::1/tcp/1", + rest: "", + }} + for i, tt := range tc { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + inputAddr := ma.StringCast(tt.input) + tw, err := thinWaistForm(inputAddr) + if tt.err { + require.Equal(t, tw, thinWaist{}) + require.Error(t, err) + return + } + wantTW := ma.StringCast(tt.tw) + var restTW ma.Multiaddr + if tt.rest != "" { + restTW = ma.StringCast(tt.rest) + } + require.Equal(t, tw.Addr, inputAddr, "%s %s", tw.Addr, inputAddr) + require.Equal(t, wantTW, tw.TW, "%s %s", tw.TW, wantTW) + require.Equal(t, restTW, tw.Rest, "%s %s", restTW, tw.Rest) + }) + } + +} diff --git a/p2p/protocol/identify/obsaddr_test.go b/p2p/protocol/identify/obsaddr_test.go index a738dc5384..9c2d8dee57 100644 --- a/p2p/protocol/identify/obsaddr_test.go +++ b/p2p/protocol/identify/obsaddr_test.go @@ -1,435 +1,662 @@ -package identify_test +package identify import ( - "crypto/rand" + crand "crypto/rand" + "fmt" + "net" "testing" "time" - ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/event" - "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" "github.com/libp2p/go-libp2p/p2p/host/eventbus" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" - "github.com/libp2p/go-libp2p/p2p/protocol/identify" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) -type harness struct { - t *testing.T - - mocknet mocknet.Mocknet - host host.Host - - oas *identify.ObservedAddrManager +func newConn(local, remote ma.Multiaddr) *mockConn { + return &mockConn{local: local, remote: remote} } -func (h *harness) add(observer ma.Multiaddr) peer.ID { - // create a new fake peer. 
- sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader) - if err != nil { - h.t.Fatal(err) - } - h2, err := h.mocknet.AddPeer(sk, observer) - if err != nil { - h.t.Fatal(err) - } - _, err = h.mocknet.LinkPeers(h.host.ID(), h2.ID()) - if err != nil { - h.t.Fatal(err) +func normalize(addr ma.Multiaddr) ma.Multiaddr { + for { + out, last := ma.SplitLast(addr) + if last == nil { + return addr + } + if _, err := last.ValueForProtocol(ma.P_CERTHASH); err != nil { + return addr + } + addr = out } - return h2.ID() } -func (h *harness) conn(observer peer.ID) network.Conn { - c, err := h.mocknet.ConnectPeers(h.host.ID(), observer) - if err != nil { - h.t.Fatal(err) +func addrsEqual(a, b []ma.Multiaddr) bool { + if len(b) != len(a) { + return false } - if c.Stat().Direction != network.DirOutbound { - h.t.Fatal("expected conn direction to be outbound") + for _, x := range b { + found := false + for _, y := range a { + if y.Equal(x) { + found = true + break + } + } + if !found { + return false + } + } + for _, x := range a { + found := false + for _, y := range b { + if y.Equal(x) { + found = true + break + } + } + if !found { + return false + } } - return c + return true } -func (h *harness) connInbound(observer peer.ID) network.Conn { - c, err := h.mocknet.ConnectPeers(observer, h.host.ID()) - if err != nil { - h.t.Fatal(err) +func TestObservedAddrManager(t *testing.T) { + tcp4ListenAddr := ma.StringCast("/ip4/192.168.1.100/tcp/1") + quic4ListenAddr := ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1") + webTransport4ListenAddr := ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1/webtransport/certhash/uEgNmb28") + tcp6ListenAddr := ma.StringCast("/ip6/2004::1/tcp/1") + quic6ListenAddr := ma.StringCast("/ip6/::/udp/1/quic-v1") + webTransport6ListenAddr := ma.StringCast("/ip6/::/udp/1/quic-v1/webtransport/certhash/uEgNmb28") + newObservedAddrMgr := func() *ObservedAddrManager { + listenAddrs := []ma.Multiaddr{ + tcp4ListenAddr, quic4ListenAddr, webTransport4ListenAddr, tcp6ListenAddr, quic6ListenAddr, webTransport6ListenAddr, + } + listenAddrsFunc := func() []ma.Multiaddr { + return listenAddrs + } + interfaceListenAddrsFunc := func() ([]ma.Multiaddr, error) { + return listenAddrs, nil + } + o, err := NewObservedAddrManager(listenAddrsFunc, listenAddrsFunc, + interfaceListenAddrsFunc, normalize) + if err != nil { + t.Fatal(err) + } + return o } - c = mocknet.ConnComplement(c) - if c.Stat().Direction != network.DirInbound { - h.t.Fatal("expected conn direction to be inbound") + checkAllEntriesRemoved := func(o *ObservedAddrManager) bool { + return len(o.Addrs()) == 0 && len(o.externalAddrs) == 0 && len(o.connObservedTWAddrs) == 0 && len(o.localAddrs) == 0 } - return c -} + t.Run("Single Observation", func(t *testing.T) { + o := newObservedAddrMgr() + defer o.Close() + observed := ma.StringCast("/ip4/2.2.2.2/tcp/2") + c1 := newConn(tcp4ListenAddr, ma.StringCast("/ip4/1.2.3.1/tcp/1")) + c2 := newConn(tcp4ListenAddr, ma.StringCast("/ip4/1.2.3.2/tcp/1")) + c3 := newConn(tcp4ListenAddr, ma.StringCast("/ip4/1.2.3.3/tcp/1")) + c4 := newConn(tcp4ListenAddr, ma.StringCast("/ip4/1.2.3.4/tcp/1")) + o.Record(c1, observed) + o.Record(c2, observed) + o.Record(c3, observed) + o.Record(c4, observed) + require.Eventually(t, func() bool { + return addrsEqual(o.Addrs(), []ma.Multiaddr{observed}) + }, 1*time.Second, 100*time.Millisecond) + o.removeConn(c1) + o.removeConn(c2) + o.removeConn(c3) + o.removeConn(c4) + require.Eventually(t, func() bool { + return checkAllEntriesRemoved(o) + }, 1*time.Second, 100*time.Millisecond) + }) -func (h 
*harness) observe(observed ma.Multiaddr, observer peer.ID) network.Conn {
-	c := h.conn(observer)
-	h.oas.Record(c, observed)
-	time.Sleep(200 * time.Millisecond) // let the worker run
-	return c
-}
+	t.Run("WebTransport inferred from QUIC", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
+		observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1/webtransport")
+		c1 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1"))
+		c2 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1"))
+		c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport"))
+		c4 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport"))
+		o.Record(c1, observedQuic)
+		o.Record(c2, observedQuic)
+		o.Record(c3, observedWebTransport)
+		o.Record(c4, observedWebTransport)
+		require.Eventually(t, func() bool {
+			return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
+		}, 1*time.Second, 100*time.Millisecond)
+		o.removeConn(c1)
+		o.removeConn(c2)
+		o.removeConn(c3)
+		o.removeConn(c4)
+		require.Eventually(t, func() bool {
+			return checkAllEntriesRemoved(o)
+		}, 1*time.Second, 100*time.Millisecond)
+	})
 
-func (h *harness) observeInbound(observed ma.Multiaddr, observer peer.ID) network.Conn {
-	c := h.connInbound(observer)
-	h.oas.Record(c, observed)
-	time.Sleep(200 * time.Millisecond) // let the worker run
-	return c
-}
+	t.Run("SameObservers", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
 
-func newHarness(t *testing.T) harness {
-	return newHarnessWithMa(t, ma.StringCast("/ip4/127.0.0.1/tcp/10086"))
-}
+		observedQuic := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
+
+		const N = 4 // ActivationThresh
+		var ob1, ob2 [N]connMultiaddrs
+		for i := 0; i < N; i++ {
+			ob1[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+			ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+		}
+		for i := 0; i < N-1; i++ {
+			o.Record(ob1[i], observedQuic)
+			o.Record(ob2[i], observedQuic)
+		}
+		time.Sleep(100 * time.Millisecond)
+		require.Equal(t, o.Addrs(), []ma.Multiaddr{})
+
+		// We should have a valid address now
+		o.Record(ob1[N-1], observedQuic)
+		o.Record(ob2[N-1], observedQuic)
+		require.Eventually(t, func() bool {
+			return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic})
+		}, 2*time.Second, 100*time.Millisecond)
+
+		// Now disconnect first observer group
+		for i := 0; i < N; i++ {
+			o.removeConn(ob1[i])
+		}
+		time.Sleep(100 * time.Millisecond)
+		if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic}) {
+			t.Fatalf("address removed too early %v %v", o.Addrs(), observedQuic)
+		}
 
-func newHarnessWithMa(t *testing.T, listenAddr ma.Multiaddr) harness {
-	mn := mocknet.New()
-	sk, _, err := ic.GenerateECDSAKeyPair(rand.Reader)
-	require.NoError(t, err)
-	h, err := mn.AddPeer(sk, listenAddr)
-	require.NoError(t, err)
-	oas, err := identify.NewObservedAddrManager(h)
-	require.NoError(t, err)
-	t.Cleanup(func() {
-		mn.Close()
-		oas.Close()
+		// Now disconnect the second group to check cleanup
+		for i := 0; i < N; i++ {
+			o.removeConn(ob2[i])
+		}
+		require.Eventually(t, func() bool {
+			return checkAllEntriesRemoved(o)
+		}, 2*time.Second, 100*time.Millisecond)
 	})
 
+	t.Run("SameObserversDifferentAddrs", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
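+		// Same observer IPs as in the previous subtest, but they now report two
+		// different observed addresses; both should activate and be tracked
+		// independently for the same local listen address.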
+
+		observedQuic1 := ma.StringCast("/ip4/2.2.2.2/udp/2/quic-v1")
+		observedQuic2 := ma.StringCast("/ip4/2.2.2.2/udp/3/quic-v1")
+
+		const N = 4 // ActivationThresh
+		var ob1, ob2 [N]connMultiaddrs
+		for i := 0; i < N; i++ {
+			ob1[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+			ob2[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+		}
+		for i := 0; i < N-1; i++ {
+			o.Record(ob1[i], observedQuic1)
+			o.Record(ob2[i], observedQuic2)
+		}
+		time.Sleep(100 * time.Millisecond)
+		require.Equal(t, o.Addrs(), []ma.Multiaddr{})
+
+		// We should have a valid address now
+		o.Record(ob1[N-1], observedQuic1)
+		o.Record(ob2[N-1], observedQuic2)
+		require.Eventually(t, func() bool {
+			return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic1, observedQuic2})
+		}, 2*time.Second, 100*time.Millisecond)
+
+		// Now disconnect first observer group
+		for i := 0; i < N; i++ {
+			o.removeConn(ob1[i])
+		}
+		time.Sleep(100 * time.Millisecond)
+		if !addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic2}) {
+			t.Fatalf("address removed too early %v %v", o.Addrs(), observedQuic2)
+		}
 
-	return harness{
-		oas:     oas,
-		mocknet: mn,
-		host:    h,
-		t:       t,
-	}
-}
+		// Now disconnect the second group to check cleanup
+		for i := 0; i < N; i++ {
+			o.removeConn(ob2[i])
+		}
+		require.Eventually(t, func() bool {
+			return checkAllEntriesRemoved(o)
+		}, 2*time.Second, 100*time.Millisecond)
+	})
 
+	t.Run("Old observations discarded", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		c1 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1"))
+		c2 := newConn(quic4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1"))
+		c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport"))
+		c4 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport"))
+		var observedQuic, observedWebTransport ma.Multiaddr
+		for i := 0; i < 10; i++ {
+			// Change the IP address in each observation
+			observedQuic = ma.StringCast(fmt.Sprintf("/ip4/2.2.2.%d/udp/2/quic-v1", i))
+			observedWebTransport = ma.StringCast(fmt.Sprintf("/ip4/2.2.2.%d/udp/2/quic-v1/webtransport", i))
+			o.Record(c1, observedQuic)
+			o.Record(c2, observedQuic)
+			o.Record(c3, observedWebTransport)
+			o.Record(c4, observedWebTransport)
+			time.Sleep(20 * time.Millisecond)
+		}
 
-// TestObsAddrSet
-func TestObsAddrSet(t *testing.T) {
-	addrsMatch := func(a, b []ma.Multiaddr) bool {
-		if len(a) != len(b) {
-			return false
-		}
-		for _, aa := range a {
-			if !ma.Contains(b, aa) {
-				return false
-			}
-		}
-		return true
-	}
+		require.Eventually(t, func() bool {
+			return addrsEqual(o.Addrs(), []ma.Multiaddr{observedQuic, observedWebTransport})
+		}, 1*time.Second, 100*time.Millisecond)
 
-	a1 := ma.StringCast("/ip4/1.2.3.4/tcp/1231")
-	a2 := ma.StringCast("/ip4/1.2.3.4/tcp/1232")
-	a3 := ma.StringCast("/ip4/1.2.3.4/tcp/1233")
-	a4 := ma.StringCast("/ip4/1.2.3.4/tcp/1234")
-	a5 := ma.StringCast("/ip4/1.2.3.4/tcp/1235")
-
-	b1 := ma.StringCast("/ip4/1.2.3.6/tcp/1236")
-	b2 := ma.StringCast("/ip4/1.2.3.7/tcp/1237")
-	b3 := ma.StringCast("/ip4/1.2.3.8/tcp/1237")
-	b4 := ma.StringCast("/ip4/1.2.3.9/tcp/1237")
-	b5 := ma.StringCast("/ip4/1.2.3.10/tcp/1237")
-
-	harness := newHarness(t)
-	if !addrsMatch(harness.oas.Addrs(), nil) {
-		t.Error("addrs should be empty")
-	}
+		tw, err := thinWaistForm(quic4ListenAddr)
+		require.NoError(t, err)
+		require.Less(t, len(o.externalAddrs[string(tw.TW.Bytes())]), 2)
 
-	pa4 := harness.add(a4)
-	pa5 := harness.add(a5)
-
-	pb1 := harness.add(b1)
-	pb2 := harness.add(b2)
-	pb3 := harness.add(b3)
-	pb4 := harness.add(b4)
-	pb5 := harness.add(b5)
+		require.Equal(t, o.AddrsFor(webTransport4ListenAddr), []ma.Multiaddr{observedWebTransport})
+		require.Equal(t, o.AddrsFor(quic4ListenAddr), []ma.Multiaddr{observedQuic})
 
-	harness.observe(a1, pa4)
-	harness.observe(a2, pa4)
-	harness.observe(a3, pa4)
+		o.removeConn(c1)
+		o.removeConn(c2)
+		o.removeConn(c3)
+		o.removeConn(c4)
+		require.Eventually(t, func() bool {
+			return checkAllEntriesRemoved(o)
+		}, 1*time.Second, 100*time.Millisecond)
+	})
 
-	// these are all different so we should not yet get them.
-	if !addrsMatch(harness.oas.Addrs(), nil) {
-		t.Error("addrs should _still_ be empty (once)")
-	}
+	t.Run("Many connections many observations", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		const N = 100
+		var tcpConns, quicConns, webTransportConns [N]*mockConn
+		for i := 0; i < N; i++ {
+			tcpConns[i] = newConn(tcp4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/tcp/1", i)))
+			quicConns[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+			webTransportConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i)))
+		}
+		var observedQuic, observedWebTransport, observedTCP ma.Multiaddr
+		for i := 0; i < N; i++ {
+			for j := 0; j < 5; j++ {
+				// ip addr has the form 2.2.X.Y
+				observedQuic = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/udp/2/quic-v1", i/10, j))
+				observedWebTransport = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/udp/2/quic-v1/webtransport", i/10, j))
+				observedTCP = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/tcp/2", i/10, j))
+				o.Record(tcpConns[i], observedTCP)
+				o.Record(quicConns[i], observedQuic)
+				o.Record(webTransportConns[i], observedWebTransport)
+				time.Sleep(10 * time.Millisecond)
+			}
+		}
+		// At this point we have 10 groups of N / 10 with 10 observations for every connection
+		// The output should remain stable
+		require.Eventually(t, func() bool {
+			return len(o.Addrs()) == 3*maxExternalThinWaistAddrsPerLocalAddr
+		}, 1*time.Second, 100*time.Millisecond)
+		addrs := o.Addrs()
+		for i := 0; i < 10; i++ {
+			require.ElementsMatch(t, o.Addrs(), addrs, "%s %s", o.Addrs(), addrs)
+			time.Sleep(10 * time.Millisecond)
+		}
 
-	// same observer, so should not yet get them.
-	harness.observe(a1, pa4)
-	harness.observe(a2, pa4)
-	harness.observe(a3, pa4)
-	if !addrsMatch(harness.oas.Addrs(), nil) {
-		t.Error("addrs should _still_ be empty (same obs)")
-	}
+		// Now we bias a few address counts and check for sorting correctness
+		var resTCPAddrs, resQuicAddrs, resWebTransportAddrs [maxExternalThinWaistAddrsPerLocalAddr]ma.Multiaddr
+		for i := 0; i < maxExternalThinWaistAddrsPerLocalAddr; i++ {
+			resTCPAddrs[i] = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/tcp/2", 9-i))
+			resQuicAddrs[i] = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/udp/2/quic-v1", 9-i))
+			resWebTransportAddrs[i] = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/udp/2/quic-v1/webtransport", 9-i))
+			o.Record(tcpConns[i], resTCPAddrs[i])
+			o.Record(quicConns[i], resQuicAddrs[i])
+			o.Record(webTransportConns[i], resWebTransportAddrs[i])
+			time.Sleep(10 * time.Millisecond)
+		}
+		var allAddrs []ma.Multiaddr
+		allAddrs = append(allAddrs, resTCPAddrs[:]...)
+		allAddrs = append(allAddrs, resQuicAddrs[:]...)
+		allAddrs = append(allAddrs, resWebTransportAddrs[:]...)
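+		// Each biased 2.2.X.4 address was just observed over one extra
+		// connection, making it the most-observed candidate in its group; the
+		// reported set should converge on exactly these addresses.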
+ require.Eventually(t, func() bool { + return addrsEqual(o.Addrs(), allAddrs) + }, 1*time.Second, 100*time.Millisecond) + + for i := 0; i < N; i++ { + o.removeConn(tcpConns[i]) + o.removeConn(quicConns[i]) + o.removeConn(webTransportConns[i]) + } + require.Eventually(t, func() bool { + return checkAllEntriesRemoved(o) + }, 1*time.Second, 100*time.Millisecond) + }) - // different observer, but same observer group. - harness.observe(a1, pa5) - harness.observe(a2, pa5) - harness.observe(a3, pa5) - if !addrsMatch(harness.oas.Addrs(), nil) { - t.Error("addrs should _still_ be empty (same obs group)") - } + t.Run("WebTransport certhash", func(t *testing.T) { + o := newObservedAddrMgr() + observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport") + c1 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.1/udp/1/quic-v1/webtransport")) + c2 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.2/udp/1/quic-v1/webtransport")) + c3 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.3/udp/1/quic-v1/webtransport")) + c4 := newConn(webTransport4ListenAddr, ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport")) + o.Record(c1, observedWebTransport) + o.Record(c2, observedWebTransport) + o.Record(c3, observedWebTransport) + o.Record(c4, observedWebTransport) + require.Eventually(t, func() bool { + return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport}) + }, 1*time.Second, 100*time.Millisecond) + o.removeConn(c1) + o.removeConn(c2) + o.removeConn(c3) + o.removeConn(c4) + require.Eventually(t, func() bool { + return checkAllEntriesRemoved(o) + }, 1*time.Second, 100*time.Millisecond) + }) - harness.observe(a1, pb1) - harness.observe(a1, pb2) - harness.observe(a1, pb3) - if !addrsMatch(harness.oas.Addrs(), []ma.Multiaddr{a1}) { - t.Error("addrs should only have a1") - } + t.Run("getNATType", func(t *testing.T) { + o := newObservedAddrMgr() + defer o.Close() - harness.observe(a2, pa5) - harness.observe(a1, pa5) - harness.observe(a1, pa5) - harness.observe(a2, pb1) - harness.observe(a1, pb1) - harness.observe(a1, pb1) - harness.observe(a2, pb2) - harness.observe(a1, pb2) - harness.observe(a1, pb2) - harness.observe(a2, pb4) - harness.observe(a2, pb5) - if !addrsMatch(harness.oas.Addrs(), []ma.Multiaddr{a1, a2}) { - t.Error("addrs should only have a1, a2") - } + observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport") + var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs + for i := 0; i < len(udpConns); i++ { + udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i))) + o.Record(udpConns[i], observedWebTransport) + time.Sleep(10 * time.Millisecond) + } + require.Eventually(t, func() bool { + return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport}) + }, 1*time.Second, 100*time.Millisecond) - // force a refresh. - harness.oas.SetTTL(time.Millisecond * 200) - require.Eventuallyf(t, - func() bool { return addrsMatch(harness.oas.Addrs(), []ma.Multiaddr{a1, a2}) }, - time.Second, - 50*time.Millisecond, - "addrs should only have %s, %s; have %s", a1, a2, harness.oas.Addrs(), - ) - - // disconnect from all but b5. 
-	for _, p := range harness.host.Network().Peers() {
-		if p == pb5 {
-			continue
-		}
-		harness.host.Network().ClosePeer(p)
-	}
+		tcpNAT, udpNAT := o.getNATType()
+		require.Equal(t, tcpNAT, network.NATDeviceTypeUnknown)
+		require.Equal(t, udpNAT, network.NATDeviceTypeCone)
+	})
+	t.Run("NATTypeSymmetric", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		const N = 100
+		var tcpConns, quicConns [N]*mockConn
+		for i := 0; i < N; i++ {
+			tcpConns[i] = newConn(tcp4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/tcp/1", i)))
+			quicConns[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+		}
+		var observedQuic, observedTCP ma.Multiaddr
+		for i := 0; i < N; i++ {
+			// ip addr has the form 2.2.X.2
+			observedQuic = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.2/udp/2/quic-v1", i%20))
+			observedTCP = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.2/tcp/2", i%20))
+			o.Record(tcpConns[i], observedTCP)
+			o.Record(quicConns[i], observedQuic)
+			time.Sleep(10 * time.Millisecond)
+		}
+		// At this point we have 20 groups with 5 observations for every connection
+		// The output should remain stable
+		require.Eventually(t, func() bool {
+			return len(o.Addrs()) == 2*maxExternalThinWaistAddrsPerLocalAddr
+		}, 1*time.Second, 100*time.Millisecond)
+
+		tcpNAT, udpNAT := o.getNATType()
+		require.Equal(t, tcpNAT, network.NATDeviceTypeSymmetric)
+		require.Equal(t, udpNAT, network.NATDeviceTypeSymmetric)
+
+		for i := 0; i < N; i++ {
+			o.removeConn(tcpConns[i])
+			o.removeConn(quicConns[i])
+		}
+		require.Eventually(t, func() bool {
+			return checkAllEntriesRemoved(o)
+		}, 1*time.Second, 100*time.Millisecond)
+	})
+	t.Run("Nil Input", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		o.maybeRecordObservation(nil, nil)
+		remoteAddr := ma.StringCast("/ip4/1.2.3.4/tcp/1")
+		o.maybeRecordObservation(newConn(tcp4ListenAddr, remoteAddr), nil)
+		o.maybeRecordObservation(nil, remoteAddr)
+		o.AddrsFor(nil)
+		o.removeConn(nil)
+	})
+
+	t.Run("Nat Emitter", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		bus := eventbus.NewBus()
+
+		s := swarmt.GenSwarm(t, swarmt.EventBus(bus))
+		h := blankhost.NewBlankHost(s, blankhost.WithEventBus(bus))
+		defer h.Close()
+		// make reachability private
+		emitter, err := bus.Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful)
+		require.NoError(t, err)
+		emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPrivate})
+
+		// start nat emitter
+		n, err := newNATEmitter(h, o, 10*time.Millisecond)
+		require.NoError(t, err)
+		defer n.Close()
+
+		sub, err := bus.Subscribe(new(event.EvtNATDeviceTypeChanged))
+		require.NoError(t, err)
+		observedWebTransport := ma.StringCast("/ip4/2.2.2.2/udp/1/quic-v1/webtransport")
+		var udpConns [5 * maxExternalThinWaistAddrsPerLocalAddr]connMultiaddrs
+		for i := 0; i < len(udpConns); i++ {
+			udpConns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i)))
+			o.Record(udpConns[i], observedWebTransport)
+			time.Sleep(10 * time.Millisecond)
+		}
+		require.Eventually(t, func() bool {
+			return addrsEqual(o.Addrs(), []ma.Multiaddr{observedWebTransport})
+		}, 1*time.Second, 100*time.Millisecond)
+
+		var e interface{}
+		select {
+		case e = <-sub.Out():
+		case <-time.After(2 * time.Second):
+			t.Fatalf("expected NAT change event")
+		}
+		evt := e.(event.EvtNATDeviceTypeChanged)
+		require.Equal(t, evt.TransportProtocol, network.NATTransportUDP)
+		require.Equal(t, evt.NatDeviceType, network.NATDeviceTypeCone)
+	})
+	t.Run("Many connections many observations IP4 and IP6", func(t *testing.T) {
+		o := newObservedAddrMgr()
+		defer o.Close()
+		const N = 100
+		var tcp4Conns, quic4Conns, webTransport4Conns [N]*mockConn
+		var tcp6Conns, quic6Conns, webTransport6Conns [N]*mockConn
+		for i := 0; i < N; i++ {
+			tcp4Conns[i] = newConn(tcp4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/tcp/1", i)))
+			quic4Conns[i] = newConn(quic4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1", i)))
+			webTransport4Conns[i] = newConn(webTransport4ListenAddr, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.%d/udp/1/quic-v1/webtransport", i)))
+
+			tcp6Conns[i] = newConn(tcp6ListenAddr, ma.StringCast(fmt.Sprintf("/ip6/20%02x::/tcp/1", i)))
+			quic6Conns[i] = newConn(quic6ListenAddr, ma.StringCast(fmt.Sprintf("/ip6/20%02x::/udp/1/quic-v1", i)))
+			webTransport6Conns[i] = newConn(webTransport6ListenAddr, ma.StringCast(fmt.Sprintf("/ip6/20%02x::/udp/1/quic-v1/webtransport", i)))
+		}
+		var observedQUIC4, observedWebTransport4, observedTCP4 ma.Multiaddr
+		var observedQUIC6, observedWebTransport6, observedTCP6 ma.Multiaddr
+		for i := 0; i < N; i++ {
+			for j := 0; j < 5; j++ {
+				// ip addr has the form 2.2.X.Y
+				observedQUIC4 = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/udp/2/quic-v1", i/10, j))
+				observedWebTransport4 = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/udp/2/quic-v1/webtransport", i/10, j))
+				observedTCP4 = ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.%d/tcp/2", i/10, j))
+
+				// ip addr has the form 20XX::YY
+				observedQUIC6 = ma.StringCast(fmt.Sprintf("/ip6/20%02x::%02x/udp/2/quic-v1", i/10, j))
+				observedWebTransport6 = ma.StringCast(fmt.Sprintf("/ip6/20%02x::%02x/udp/2/quic-v1/webtransport", i/10, j))
+				observedTCP6 = ma.StringCast(fmt.Sprintf("/ip6/20%02x::%02x/tcp/2", i/10, j))
+
+				o.maybeRecordObservation(tcp4Conns[i], observedTCP4)
+				o.maybeRecordObservation(quic4Conns[i], observedQUIC4)
+				o.maybeRecordObservation(webTransport4Conns[i], observedWebTransport4)
+
+				o.maybeRecordObservation(tcp6Conns[i], observedTCP6)
+				o.maybeRecordObservation(quic6Conns[i], observedQUIC6)
+				o.maybeRecordObservation(webTransport6Conns[i], observedWebTransport6)
+			}
+		}
+		// At this point we have 10 groups of N / 10 with 10 observations for every connection
+		// The output should remain stable
+		require.Eventually(t, func() bool {
+			return len(o.Addrs()) == 2*3*maxExternalThinWaistAddrsPerLocalAddr
+		}, 1*time.Second, 100*time.Millisecond)
+		addrs := o.Addrs()
+		for i := 0; i < 10; i++ {
+			require.ElementsMatch(t, o.Addrs(), addrs, "%s %s", o.Addrs(), addrs)
+			time.Sleep(10 * time.Millisecond)
+		}
 
-	// Wait for all other addresses to time out.
-	// After that, we should still have a2.
- require.Eventuallyf(t, - func() bool { return addrsMatch(harness.oas.Addrs(), []ma.Multiaddr{a2}) }, - time.Second, - 50*time.Millisecond, - "should only have a2 (%s), have: %v", a2, harness.oas.Addrs(), - ) - harness.host.Network().ClosePeer(pb5) - - // wait for all addresses to timeout - require.Eventually(t, - func() bool { return len(harness.oas.Addrs()) == 0 }, - 400*time.Millisecond, - 20*time.Millisecond, - "addrs should have timed out", - ) + // Now we bias a few address counts and check for sorting correctness + var resTCPAddrs, resQuicAddrs, resWebTransportAddrs []ma.Multiaddr + + for i, idx := 0, 0; i < maxExternalThinWaistAddrsPerLocalAddr; i++ { + resTCPAddrs = append(resTCPAddrs, ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/tcp/2", 9-i))) + resQuicAddrs = append(resQuicAddrs, ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/udp/2/quic-v1", 9-i))) + resWebTransportAddrs = append(resWebTransportAddrs, ma.StringCast(fmt.Sprintf("/ip4/2.2.%d.4/udp/2/quic-v1/webtransport", 9-i))) + + o.maybeRecordObservation(tcp4Conns[i], resTCPAddrs[idx]) + o.maybeRecordObservation(quic4Conns[i], resQuicAddrs[idx]) + o.maybeRecordObservation(webTransport4Conns[i], resWebTransportAddrs[idx]) + idx++ + + resTCPAddrs = append(resTCPAddrs, ma.StringCast(fmt.Sprintf("/ip6/20%02x::04/tcp/2", 9-i))) + resQuicAddrs = append(resQuicAddrs, ma.StringCast(fmt.Sprintf("/ip6/20%02x::04/udp/2/quic-v1", 9-i))) + resWebTransportAddrs = append(resWebTransportAddrs, ma.StringCast(fmt.Sprintf("/ip6/20%02x::04/udp/2/quic-v1/webtransport", 9-i))) + o.maybeRecordObservation(tcp6Conns[i], resTCPAddrs[idx]) + o.maybeRecordObservation(quic6Conns[i], resQuicAddrs[idx]) + o.maybeRecordObservation(webTransport6Conns[i], resWebTransportAddrs[idx]) + idx++ + } + var allAddrs []ma.Multiaddr + allAddrs = append(allAddrs, resTCPAddrs[:]...) + allAddrs = append(allAddrs, resQuicAddrs[:]...) + allAddrs = append(allAddrs, resWebTransportAddrs[:]...) + require.Eventually(t, func() bool { + return addrsEqual(o.Addrs(), allAddrs) + }, 1*time.Second, 100*time.Millisecond) + + for i := 0; i < N; i++ { + o.removeConn(tcp4Conns[i]) + o.removeConn(quic4Conns[i]) + o.removeConn(webTransport4Conns[i]) + o.removeConn(tcp6Conns[i]) + o.removeConn(quic6Conns[i]) + o.removeConn(webTransport6Conns[i]) + } + require.Eventually(t, func() bool { + return checkAllEntriesRemoved(o) + }, 1*time.Second, 100*time.Millisecond) + }) } -func TestObservedAddrFiltering(t *testing.T) { - harness := newHarness(t) - require.Empty(t, harness.oas.Addrs()) - - // IP4/TCP - it1 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - it2 := ma.StringCast("/ip4/1.2.3.4/tcp/1232") - it3 := ma.StringCast("/ip4/1.2.3.4/tcp/1233") - it4 := ma.StringCast("/ip4/1.2.3.4/tcp/1234") - it5 := ma.StringCast("/ip4/1.2.3.4/tcp/1235") - it6 := ma.StringCast("/ip4/1.2.3.4/tcp/1236") - it7 := ma.StringCast("/ip4/1.2.3.4/tcp/1237") - - // observers - b1 := ma.StringCast("/ip4/1.2.3.6/tcp/1236") - b2 := ma.StringCast("/ip4/1.2.3.7/tcp/1237") - b3 := ma.StringCast("/ip4/1.2.3.8/tcp/1237") - b4 := ma.StringCast("/ip4/1.2.3.9/tcp/1237") - b5 := ma.StringCast("/ip4/1.2.3.10/tcp/1237") - - b6 := ma.StringCast("/ip4/1.2.3.11/tcp/1237") - b7 := ma.StringCast("/ip4/1.2.3.12/tcp/1237") - - // These are all observers in the same group. 
- b8 := ma.StringCast("/ip4/1.2.3.13/tcp/1237") - b9 := ma.StringCast("/ip4/1.2.3.13/tcp/1238") - b10 := ma.StringCast("/ip4/1.2.3.13/tcp/1239") - - peers := []peer.ID{ - harness.add(b1), - harness.add(b2), - harness.add(b3), - harness.add(b4), - harness.add(b5), - - harness.add(b6), - harness.add(b7), - - harness.add(b8), - harness.add(b9), - harness.add(b10), - } - for i := 0; i < 4; i++ { - harness.observe(it1, peers[i]) - harness.observe(it2, peers[i]) - harness.observe(it3, peers[i]) - harness.observe(it4, peers[i]) - harness.observe(it5, peers[i]) - harness.observe(it6, peers[i]) - harness.observe(it7, peers[i]) +func genIPMultiaddr(ip6 bool) ma.Multiaddr { + var ipB [16]byte + crand.Read(ipB[:]) + var ip net.IP + if ip6 { + ip = net.IP(ipB[:]) + } else { + ip = net.IP(ipB[:4]) } - - harness.observe(it1, peers[4]) - harness.observe(it7, peers[4]) - - addrs := harness.oas.Addrs() - require.Len(t, addrs, 2) - require.Contains(t, addrs, it1) - require.Contains(t, addrs, it7) - - // Bump the number of observations so 1 & 7 have 7 observations. - harness.observe(it1, peers[5]) - harness.observe(it1, peers[6]) - harness.observe(it7, peers[5]) - harness.observe(it7, peers[6]) - - // Add an observation from IP 1.2.3.13 - // 2 & 3 now have 5 observations - harness.observe(it2, peers[7]) - harness.observe(it3, peers[7]) - - addrs = harness.oas.Addrs() - require.Len(t, addrs, 2) - require.Contains(t, addrs, it1) - require.Contains(t, addrs, it7) - - // Add an inbound observation from IP 1.2.3.13, it should override the - // existing observation and it should make these addresses win even - // though we have fewer observations. - // - // 2 & 3 now have 6 observations. - harness.observeInbound(it2, peers[8]) - harness.observeInbound(it3, peers[8]) - addrs = harness.oas.Addrs() - require.Len(t, addrs, 2) - require.Contains(t, addrs, it2) - require.Contains(t, addrs, it3) - - // Adding an outbound observation shouldn't "downgrade" it. - // - // 2 & 3 now have 7 observations. 
- harness.observe(it2, peers[9]) - harness.observe(it3, peers[9]) - addrs = harness.oas.Addrs() - require.Len(t, addrs, 2) - require.Contains(t, addrs, it2) - require.Contains(t, addrs, it3) + addr, _ := manet.FromIP(ip) + return addr } -func TestEmitNATDeviceTypeSymmetric(t *testing.T) { - harness := newHarness(t) - require.Empty(t, harness.oas.Addrs()) - emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) - require.NoError(t, err) - require.NoError(t, emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPrivate})) - - // TCP - it1 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - it2 := ma.StringCast("/ip4/1.2.3.4/tcp/1232") - it3 := ma.StringCast("/ip4/1.2.3.4/tcp/1233") - it4 := ma.StringCast("/ip4/1.2.3.4/tcp/1234") - - // observers - b1 := ma.StringCast("/ip4/1.2.3.6/tcp/1236") - b2 := ma.StringCast("/ip4/1.2.3.7/tcp/1237") - b3 := ma.StringCast("/ip4/1.2.3.8/tcp/1237") - b4 := ma.StringCast("/ip4/1.2.3.9/tcp/1237") - - peers := []peer.ID{ - harness.add(b1), - harness.add(b2), - harness.add(b3), - harness.add(b4), +func FuzzObservedAddrManager(f *testing.F) { + protos := []string{ + "/webrtc-direct", + "/quic-v1", + "/quic-v1/webtransport", } - - harness.observe(it1, peers[0]) - harness.observe(it2, peers[1]) - harness.observe(it3, peers[2]) - harness.observe(it4, peers[3]) - - sub, err := harness.host.EventBus().Subscribe(new(event.EvtNATDeviceTypeChanged)) - require.NoError(t, err) - select { - case ev := <-sub.Out(): - evt := ev.(event.EvtNATDeviceTypeChanged) - require.Equal(t, network.NATDeviceTypeSymmetric, evt.NatDeviceType) - require.Equal(t, network.NATTransportTCP, evt.TransportProtocol) - case <-time.After(5 * time.Second): - t.Fatal("did not get Symmetric NAT event") + tcp4 := ma.StringCast("/ip4/192.168.1.100/tcp/1") + quic4 := ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1") + wt4 := ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1/webtransport/certhash/uEgNmb28") + tcp6 := ma.StringCast("/ip6/1::1/tcp/1") + quic6 := ma.StringCast("/ip6/::/udp/1/quic-v1") + wt6 := ma.StringCast("/ip6/::/udp/1/quic-v1/webtransport/certhash/uEgNmb28") + newObservedAddrMgr := func() *ObservedAddrManager { + listenAddrs := []ma.Multiaddr{ + tcp4, quic4, wt4, tcp6, quic6, wt6, + } + listenAddrsFunc := func() []ma.Multiaddr { + return listenAddrs + } + interfaceListenAddrsFunc := func() ([]ma.Multiaddr, error) { + return listenAddrs, nil + } + o, err := NewObservedAddrManager(listenAddrsFunc, listenAddrsFunc, + interfaceListenAddrsFunc, normalize) + if err != nil { + panic(err) + } + return o } + + f.Fuzz(func(t *testing.T, port uint16) { + addrs := []ma.Multiaddr{genIPMultiaddr(true), genIPMultiaddr(false)} + n := len(addrs) + for i := 0; i < n; i++ { + addrs = append(addrs, addrs[i].Encapsulate(ma.StringCast(fmt.Sprintf("/tcp/%d", port)))) + addrs = append(addrs, addrs[i].Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", port)))) + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/tcp/%d", port))) + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/udp/%d", port))) + } + n = len(addrs) + for i := 0; i < n; i++ { + for j := 0; j < len(protos); j++ { + protoAddr := ma.StringCast(protos[j]) + addrs = append(addrs, addrs[i].Encapsulate(protoAddr)) + addrs = append(addrs, protoAddr) + } + } + o := newObservedAddrMgr() + defer o.Close() + for i := 0; i < len(addrs); i++ { + for _, l := range o.listenAddrs() { + c := newConn(l, addrs[i]) + o.maybeRecordObservation(c, addrs[i]) + o.maybeRecordObservation(c, nil) + 
o.maybeRecordObservation(nil, addrs[i]) + o.removeConn(c) + } + } + }) } -func TestEmitNATDeviceTypeCone(t *testing.T) { - harness := newHarness(t) - require.Empty(t, harness.oas.Addrs()) - emitter, err := harness.host.EventBus().Emitter(new(event.EvtLocalReachabilityChanged), eventbus.Stateful) - require.NoError(t, err) - require.NoError(t, emitter.Emit(event.EvtLocalReachabilityChanged{Reachability: network.ReachabilityPrivate})) - - it1 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - it2 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - it3 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - it4 := ma.StringCast("/ip4/1.2.3.4/tcp/1231") - - // observers - b1 := ma.StringCast("/ip4/1.2.3.6/tcp/1236") - b2 := ma.StringCast("/ip4/1.2.3.7/tcp/1237") - b3 := ma.StringCast("/ip4/1.2.3.8/tcp/1237") - b4 := ma.StringCast("/ip4/1.2.3.9/tcp/1237") - - peers := []peer.ID{ - harness.add(b1), - harness.add(b2), - harness.add(b3), - harness.add(b4), +func TestObserver(t *testing.T) { + tests := []struct { + addr ma.Multiaddr + want string + }{ + { + addr: ma.StringCast("/ip4/1.2.3.4/tcp/1"), + want: "1.2.3.4", + }, + { + addr: ma.StringCast("/ip4/192.168.0.1/tcp/1"), + want: "192.168.0.1", + }, + { + addr: ma.StringCast("/ip6/200::1/udp/1/quic-v1"), + want: "200::", + }, + { + addr: ma.StringCast("/ip6/::1/udp/1/quic-v1"), + want: "::", + }, } - harness.observe(it1, peers[0]) - harness.observe(it2, peers[1]) - harness.observe(it3, peers[2]) - harness.observe(it4, peers[3]) - - sub, err := harness.host.EventBus().Subscribe(new(event.EvtNATDeviceTypeChanged)) - require.NoError(t, err) - select { - case ev := <-sub.Out(): - evt := ev.(event.EvtNATDeviceTypeChanged) - require.Equal(t, network.NATDeviceTypeCone, evt.NatDeviceType) - case <-time.After(5 * time.Second): - t.Fatal("did not get Cone NAT event") + for i, tc := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + got, err := getObserver(tc.addr) + require.NoError(t, err) + require.Equal(t, got, tc.want) + }) } } - -func TestObserveWebtransport(t *testing.T) { - listenAddr := ma.StringCast("/ip4/1.2.3.4/udp/9999/quic-v1/webtransport/certhash/uEgNmb28") - observedAddr := ma.StringCast("/ip4/1.2.3.4/udp/1231/quic-v1/webtransport") - - harness := newHarnessWithMa(t, listenAddr) - - pb1 := harness.add(ma.StringCast("/ip4/1.2.3.6/udp/1236/quic-v1/webtransport")) - pb2 := harness.add(ma.StringCast("/ip4/1.2.3.7/udp/1237/quic-v1/webtransport")) - pb3 := harness.add(ma.StringCast("/ip4/1.2.3.8/udp/1237/quic-v1/webtransport")) - pb4 := harness.add(ma.StringCast("/ip4/1.2.3.9/udp/1237/quic-v1/webtransport")) - pb5 := harness.add(ma.StringCast("/ip4/1.2.3.10/udp/1237/quic-v1/webtransport")) - - harness.observe(observedAddr, pb1) - harness.observe(observedAddr, pb2) - harness.observe(observedAddr, pb3) - harness.observe(observedAddr, pb4) - harness.observe(observedAddr, pb5) - - require.Len(t, harness.oas.Addrs(), 1) - require.Equal(t, "/ip4/1.2.3.4/udp/1231/quic-v1/webtransport", harness.oas.Addrs()[0].String()) -} diff --git a/p2p/protocol/identify/opts.go b/p2p/protocol/identify/opts.go index f188665686..bd0fd896b8 100644 --- a/p2p/protocol/identify/opts.go +++ b/p2p/protocol/identify/opts.go @@ -1,10 +1,11 @@ package identify type config struct { - protocolVersion string - userAgent string - disableSignedPeerRecord bool - metricsTracer MetricsTracer + protocolVersion string + userAgent string + disableSignedPeerRecord bool + metricsTracer MetricsTracer + disableObservedAddrManager bool } // Option is an option function for identify. 
@@ -38,3 +39,11 @@ func WithMetricsTracer(tr MetricsTracer) Option { cfg.metricsTracer = tr } } + +// DisableObservedAddrManager disables the observed address manager. It also +// effectively disables the nat emitter and EvtNATDeviceTypeChanged +func DisableObservedAddrManager() Option { + return func(cfg *config) { + cfg.disableObservedAddrManager = true + } +}
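Reviewer sketch (not part of the patch): a minimal example of how the new identify option could be wired up. It assumes an existing host.Host value; identify.NewIDService and the DisableObservedAddrManager option are the APIs this diff touches, while the helper name is hypothetical.

package example

import (
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/libp2p/go-libp2p/p2p/protocol/identify"
)

// newIdentifyWithoutAddrDiscovery is a hypothetical helper showing the option
// in context: the returned identify service never records peer-provided
// observed addresses, so the host's external addresses must be supplied by
// other means.
func newIdentifyWithoutAddrDiscovery(h host.Host) (identify.IDService, error) {
	ids, err := identify.NewIDService(h, identify.DisableObservedAddrManager())
	if err != nil {
		return nil, err
	}
	// As the option's doc comment notes, this also effectively disables the
	// NAT emitter and EvtNATDeviceTypeChanged events.
	return ids, nil
}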