From 6b9c11680e0bb412ef7515f266531c23245015dc Mon Sep 17 00:00:00 2001 From: Sukun Date: Fri, 27 Jan 2023 15:09:59 +0530 Subject: [PATCH] consistently use protocol.ID instead of strings (#2004) * Change PeerStore interface to use protocol.ID This reduces the string to protocol.ID translations happening at various places in the code * Fix misc cases of protocol.ID conversion * Merge multistream changes * Use protocol.ID in network.ConnectionState * don't update examples * fix error message tests * merge new go-multistream changes * update test-plans go mod * change transport back to string --- core/host/host.go | 2 +- core/network/conn.go | 5 ++- core/peerstore/peerstore.go | 15 +++---- core/protocol/switch.go | 14 ++++--- go.mod | 2 +- go.sum | 4 +- p2p/host/autorelay/relay_finder.go | 4 +- p2p/host/basic/basic_host.go | 37 ++++++++---------- p2p/host/basic/basic_host_test.go | 12 +++--- p2p/host/blank/blank.go | 28 ++++++------- p2p/host/peerstore/pstoreds/metadata.go | 3 +- p2p/host/peerstore/pstoreds/protobook.go | 25 ++++++------ p2p/host/peerstore/pstoremem/protobook.go | 31 ++++++++------- p2p/host/peerstore/test/peerstore_suite.go | 25 +++++++----- p2p/host/pstoremanager/mock_peerstore_test.go | 19 ++++----- p2p/host/resource-manager/extapi.go | 2 +- p2p/host/routed/routed.go | 2 +- p2p/net/swarm/swarm_metrics.go | 6 +-- p2p/net/upgrader/conn.go | 4 +- p2p/net/upgrader/listener_test.go | 2 +- p2p/net/upgrader/upgrader.go | 38 +++++++++--------- .../circuitv2/client/reservation_test.go | 2 +- p2p/protocol/holepunch/holepunch_test.go | 2 +- p2p/protocol/identify/id.go | 13 ++++--- p2p/protocol/identify/id_test.go | 18 ++++----- p2p/protocol/identify/peer_loop.go | 6 +-- p2p/security/noise/session.go | 3 +- p2p/security/noise/transport.go | 18 ++++----- p2p/security/noise/transport_test.go | 39 ++++++++++--------- p2p/security/tls/transport.go | 2 +- p2p/security/tls/transport_test.go | 4 +- p2p/test/negotiation/muxer_test.go | 4 +- p2p/test/negotiation/security_test.go | 5 +-- test-plans/go.mod | 2 +- test-plans/go.sum | 4 +- 35 files changed, 204 insertions(+), 198 deletions(-) diff --git a/core/host/host.go b/core/host/host.go index 5fac751a8b..e62be281f1 100644 --- a/core/host/host.go +++ b/core/host/host.go @@ -52,7 +52,7 @@ type Host interface { // SetStreamHandlerMatch sets the protocol handler on the Host's Mux // using a matching function for protocol selection. - SetStreamHandlerMatch(protocol.ID, func(string) bool, network.StreamHandler) + SetStreamHandlerMatch(protocol.ID, func(protocol.ID) bool, network.StreamHandler) // RemoveStreamHandler removes a handler on the mux that was set by // SetStreamHandler diff --git a/core/network/conn.go b/core/network/conn.go index 6d16d298cf..4191a9d544 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -6,6 +6,7 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ma "github.com/multiformats/go-multiaddr" ) @@ -37,9 +38,9 @@ type Conn interface { // ConnectionState holds information about the connection. type ConnectionState struct { // The stream multiplexer used on this connection (if any). For example: /yamux/1.0.0 - StreamMultiplexer string + StreamMultiplexer protocol.ID // The security protocol used on this connection (if any). For example: /tls/1.0.0 - Security string + Security protocol.ID // the transport used on this connection. For example: tcp Transport string } diff --git a/core/peerstore/peerstore.go b/core/peerstore/peerstore.go index 7561f32be2..b63582afe6 100644 --- a/core/peerstore/peerstore.go +++ b/core/peerstore/peerstore.go @@ -11,6 +11,7 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/record" ma "github.com/multiformats/go-multiaddr" @@ -230,19 +231,19 @@ type Metrics interface { // ProtoBook tracks the protocols supported by peers. type ProtoBook interface { - GetProtocols(peer.ID) ([]string, error) - AddProtocols(peer.ID, ...string) error - SetProtocols(peer.ID, ...string) error - RemoveProtocols(peer.ID, ...string) error + GetProtocols(peer.ID) ([]protocol.ID, error) + AddProtocols(peer.ID, ...protocol.ID) error + SetProtocols(peer.ID, ...protocol.ID) error + RemoveProtocols(peer.ID, ...protocol.ID) error // SupportsProtocols returns the set of protocols the peer supports from among the given protocols. // If the returned error is not nil, the result is indeterminate. - SupportsProtocols(peer.ID, ...string) ([]string, error) + SupportsProtocols(peer.ID, ...protocol.ID) ([]protocol.ID, error) // FirstSupportedProtocol returns the first protocol that the peer supports among the given protocols. - // If the peer does not support any of the given protocols, this function will return an empty string and a nil error. + // If the peer does not support any of the given protocols, this function will return an empty protocol.ID and a nil error. // If the returned error is not nil, the result is indeterminate. - FirstSupportedProtocol(peer.ID, ...string) (string, error) + FirstSupportedProtocol(peer.ID, ...protocol.ID) (protocol.ID, error) // RemovePeer removes all protocols associated with a peer. RemovePeer(peer.ID) diff --git a/core/protocol/switch.go b/core/protocol/switch.go index f839e0163c..683ef56fef 100644 --- a/core/protocol/switch.go +++ b/core/protocol/switch.go @@ -3,6 +3,8 @@ package protocol import ( "io" + + "github.com/multiformats/go-multistream" ) // HandlerFunc is a user-provided function used by the Router to @@ -11,7 +13,7 @@ import ( // Will be invoked with the protocol ID string as the first argument, // which may differ from the ID used for registration if the handler // was registered using a match function. -type HandlerFunc = func(protocol string, rwc io.ReadWriteCloser) error +type HandlerFunc = multistream.HandlerFunc[ID] // Router is an interface that allows users to add and remove protocol handlers, // which will be invoked when incoming stream requests for registered protocols @@ -25,7 +27,7 @@ type Router interface { // AddHandler registers the given handler to be invoked for // an exact literal match of the given protocol ID string. - AddHandler(protocol string, handler HandlerFunc) + AddHandler(protocol ID, handler HandlerFunc) // AddHandlerWithFunc registers the given handler to be invoked // when the provided match function returns true. @@ -35,17 +37,17 @@ type Router interface { // the protocol. Note that the protocol ID argument is not // used for matching; if you want to match the protocol ID // string exactly, you must check for it in your match function. - AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc) + AddHandlerWithFunc(protocol ID, match func(ID) bool, handler HandlerFunc) // RemoveHandler removes the registered handler (if any) for the // given protocol ID string. - RemoveHandler(protocol string) + RemoveHandler(protocol ID) // Protocols returns a list of all registered protocol ID strings. // Note that the Router may be able to handle protocol IDs not // included in this list if handlers were added with match functions // using AddHandlerWithFunc. - Protocols() []string + Protocols() []ID } // Negotiator is a component capable of reaching agreement over what protocols @@ -55,7 +57,7 @@ type Negotiator interface { // inbound stream, returning after the protocol has been determined and the // Negotiator has finished using the stream for negotiation. Returns an // error if negotiation fails. - Negotiate(rwc io.ReadWriteCloser) (string, HandlerFunc, error) + Negotiate(rwc io.ReadWriteCloser) (ID, HandlerFunc, error) // Handle calls Negotiate to determine which protocol handler to use for an // inbound stream, then invokes the protocol handler function, passing it diff --git a/go.mod b/go.mod index 36b87781d9..a7835981e3 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/multiformats/go-multibase v0.1.1 github.com/multiformats/go-multicodec v0.7.0 github.com/multiformats/go-multihash v0.2.1 - github.com/multiformats/go-multistream v0.3.3 + github.com/multiformats/go-multistream v0.4.0 github.com/multiformats/go-varint v0.0.7 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/prometheus/client_golang v1.14.0 diff --git a/go.sum b/go.sum index 58713e5c46..89d7b78c30 100644 --- a/go.sum +++ b/go.sum @@ -384,8 +384,8 @@ github.com/multiformats/go-multicodec v0.7.0/go.mod h1:GUC8upxSBE4oG+q3kWZRw/+6y github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= github.com/multiformats/go-multihash v0.2.1 h1:aem8ZT0VA2nCHHk7bPJ1BjUbHNciqZC/d16Vve9l108= github.com/multiformats/go-multihash v0.2.1/go.mod h1:WxoMcYG85AZVQUyRyo9s4wULvW5qrI9vb2Lt6evduFc= -github.com/multiformats/go-multistream v0.3.3 h1:d5PZpjwRgVlbwfdTDjife7XszfZd8KYWfROYFlGcR8o= -github.com/multiformats/go-multistream v0.3.3/go.mod h1:ODRoqamLUsETKS9BNcII4gcRsJBU5VAwRIv7O39cEXg= +github.com/multiformats/go-multistream v0.4.0 h1:5i4JbawClkbuaX+mIVXiHQYVPxUW+zjv6w7jtSRukxc= +github.com/multiformats/go-multistream v0.4.0/go.mod h1:BS6ZSYcA4NwYEaIMeCtpJydp2Dc+fNRA6uJMSu/m8+4= github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= diff --git a/p2p/host/autorelay/relay_finder.go b/p2p/host/autorelay/relay_finder.go index 851d1422e5..20d71240b9 100644 --- a/p2p/host/autorelay/relay_finder.go +++ b/p2p/host/autorelay/relay_finder.go @@ -23,8 +23,8 @@ import ( ) const ( - protoIDv1 = string(relayv1.ProtoID) - protoIDv2 = string(circuitv2_proto.ProtoIDv2Hop) + protoIDv1 = relayv1.ProtoID + protoIDv2 = circuitv2_proto.ProtoIDv2Hop ) // Terminology: diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 82e29a37d2..2a500151fc 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -70,7 +70,7 @@ type BasicHost struct { network network.Network psManager *pstoremanager.PeerstoreManager - mux *msmux.MultistreamMuxer + mux *msmux.MultistreamMuxer[protocol.ID] ids identify.IDService hps *holepunch.Service pings *ping.PingService @@ -108,7 +108,7 @@ var _ host.Host = (*BasicHost)(nil) // customize construction of the *BasicHost. type HostOpts struct { // MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted. - MultistreamMuxer *msmux.MultistreamMuxer + MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID] // NegotiationTimeout determines the read and write timeouts on streams. // If 0 or omitted, it will use DefaultNegotiationTimeout. @@ -168,7 +168,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h := &BasicHost{ network: n, psManager: psManager, - mux: msmux.NewMultistreamMuxer(), + mux: msmux.NewMultistreamMuxer[protocol.ID](), negtimeout: DefaultNegotiationTimeout, AddrsFactory: DefaultAddrsFactory, maResolver: madns.DefaultResolver, @@ -407,7 +407,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { } } - if err := s.SetProtocol(protocol.ID(protoID)); err != nil { + if err := s.SetProtocol(protoID); err != nil { log.Debugf("error setting stream protocol: %s", err) s.Reset() return @@ -571,9 +571,9 @@ func (h *BasicHost) EventBus() event.Bus { // // (Threadsafe) func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { - h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error { + h.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error { is := rwc.(network.Stream) - is.SetProtocol(protocol.ID(p)) + is.SetProtocol(p) handler(is) return nil }) @@ -584,10 +584,10 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHand // SetStreamHandlerMatch sets the protocol handler on the Host's Mux // using a matching function to do protocol comparisons -func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) { - h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error { +func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) { + h.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error { is := rwc.(network.Stream) - is.SetProtocol(protocol.ID(p)) + is.SetProtocol(p) handler(is) return nil }) @@ -598,7 +598,7 @@ func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, // RemoveStreamHandler returns .. func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) { - h.Mux().RemoveHandler(string(pid)) + h.Mux().RemoveHandler(pid) h.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{ Removed: []protocol.ID{pid}, }) @@ -637,9 +637,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, ctx.Err() } - pidStrings := protocol.ConvertToStrings(pids) - - pref, err := h.preferredProtocol(p, pidStrings) + pref, err := h.preferredProtocol(p, pids) if err != nil { _ = s.Reset() return nil, err @@ -647,7 +645,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I if pref != "" { s.SetProtocol(pref) - lzcon := msmux.NewMSSelect(s, string(pref)) + lzcon := msmux.NewMSSelect(s, pref) return &streamWrapper{ Stream: s, rw: lzcon, @@ -655,10 +653,10 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } // Negotiate the protocol in the background, obeying the context. - var selected string + var selected protocol.ID errCh := make(chan error, 1) go func() { - selected, err = msmux.SelectOneOf(pidStrings, s) + selected, err = msmux.SelectOneOf(pids, s) errCh <- err }() select { @@ -674,13 +672,12 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, ctx.Err() } - selpid := protocol.ID(selected) - s.SetProtocol(selpid) + s.SetProtocol(selected) h.Peerstore().AddProtocols(p, selected) return s, nil } -func (h *BasicHost) preferredProtocol(p peer.ID, pids []string) (protocol.ID, error) { +func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) (protocol.ID, error) { supported, err := h.Peerstore().SupportsProtocols(p, pids...) if err != nil { return "", err @@ -688,7 +685,7 @@ func (h *BasicHost) preferredProtocol(p peer.ID, pids []string) (protocol.ID, er var out protocol.ID if len(supported) > 0 { - out = protocol.ID(supported[0]) + out = supported[0] } return out, nil } diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 9093beea01..cd56684f98 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -158,8 +158,8 @@ func TestProtocolHandlerEvents(t *testing.T) { h.SetStreamHandler(protocol.TestingID, func(s network.Stream) {}) assert([]protocol.ID{protocol.TestingID}, nil) - h.SetStreamHandler(protocol.ID("foo"), func(s network.Stream) {}) - assert([]protocol.ID{protocol.ID("foo")}, nil) + h.SetStreamHandler("foo", func(s network.Stream) {}) + assert([]protocol.ID{"foo"}, nil) h.RemoveStreamHandler(protocol.TestingID) assert(nil, []protocol.ID{protocol.TestingID}) } @@ -273,9 +273,9 @@ func TestHostProtoPreference(t *testing.T) { defer h2.Close() const ( - protoOld = protocol.ID("/testing") - protoNew = protocol.ID("/testing/1.1.0") - protoMinor = protocol.ID("/testing/1.2.0") + protoOld = "/testing" + protoNew = "/testing/1.1.0" + protoMinor = "/testing/1.2.0" ) connectedOn := make(chan protocol.ID) @@ -299,7 +299,7 @@ func TestHostProtoPreference(t *testing.T) { assertWait(t, connectedOn, protoOld) s.Close() - h2.SetStreamHandlerMatch(protoMinor, func(string) bool { return true }, handler) + h2.SetStreamHandlerMatch(protoMinor, func(protocol.ID) bool { return true }, handler) // remembered preference will be chosen first, even when the other side newly supports it s2, err := h1.NewStream(context.Background(), h2.ID(), protoMinor, protoNew, protoOld) require.NoError(t, err) diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 16753eb0d5..08f1332643 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -27,7 +27,7 @@ var log = logging.Logger("blankhost") // BlankHost is the thinnest implementation of the host.Host interface type BlankHost struct { n network.Network - mux *mstream.MultistreamMuxer + mux *mstream.MultistreamMuxer[protocol.ID] cmgr connmgr.ConnManager eventbus event.Bus emitters struct { @@ -65,7 +65,7 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost { bh := &BlankHost{ n: n, cmgr: cfg.cmgr, - mux: mstream.NewMultistreamMuxer(), + mux: mstream.NewMultistreamMuxer[protocol.ID](), } if bh.eventbus == nil { bh.eventbus = eventbus.NewBus() @@ -158,35 +158,29 @@ func (bh *BlankHost) NewStream(ctx context.Context, p peer.ID, protos ...protoco return nil, err } - protoStrs := make([]string, len(protos)) - for i, pid := range protos { - protoStrs[i] = string(pid) - } - - selected, err := mstream.SelectOneOf(protoStrs, s) + selected, err := mstream.SelectOneOf(protos, s) if err != nil { s.Reset() return nil, err } - selpid := protocol.ID(selected) - s.SetProtocol(selpid) + s.SetProtocol(selected) bh.Peerstore().AddProtocols(p, selected) return s, nil } func (bh *BlankHost) RemoveStreamHandler(pid protocol.ID) { - bh.Mux().RemoveHandler(string(pid)) + bh.Mux().RemoveHandler(pid) bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{ Removed: []protocol.ID{pid}, }) } func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) { - bh.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error { + bh.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error { is := rwc.(network.Stream) - is.SetProtocol(protocol.ID(p)) + is.SetProtocol(p) handler(is) return nil }) @@ -195,10 +189,10 @@ func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHan }) } -func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) { - bh.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error { +func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) { + bh.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error { is := rwc.(network.Stream) - is.SetProtocol(protocol.ID(p)) + is.SetProtocol(p) handler(is) return nil }) @@ -216,7 +210,7 @@ func (bh *BlankHost) newStreamHandler(s network.Stream) { return } - s.SetProtocol(protocol.ID(protoID)) + s.SetProtocol(protoID) go handle(protoID, s) } diff --git a/p2p/host/peerstore/pstoreds/metadata.go b/p2p/host/peerstore/pstoreds/metadata.go index 64c24fd782..c4f2819458 100644 --- a/p2p/host/peerstore/pstoreds/metadata.go +++ b/p2p/host/peerstore/pstoreds/metadata.go @@ -8,6 +8,7 @@ import ( pool "github.com/libp2p/go-buffer-pool" "github.com/libp2p/go-libp2p/core/peer" pstore "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" ds "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/query" @@ -28,7 +29,7 @@ func init() { // Gob registers basic types by default. // // Register complex types used by the peerstore itself. - gob.Register(make(map[string]struct{})) + gob.Register(make(map[protocol.ID]struct{})) } // NewPeerMetadata creates a metadata store backed by a persistent db. It uses gob for serialisation. diff --git a/p2p/host/peerstore/pstoreds/protobook.go b/p2p/host/peerstore/pstoreds/protobook.go index cfc6ef7dd0..f5d76573b4 100644 --- a/p2p/host/peerstore/pstoreds/protobook.go +++ b/p2p/host/peerstore/pstoreds/protobook.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" pstore "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" ) type protoSegment struct { @@ -58,12 +59,12 @@ func NewProtoBook(meta pstore.PeerMetadata, opts ...ProtoBookOption) (*dsProtoBo return pb, nil } -func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { +func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error { if len(protos) > pb.maxProtos { return errTooManyProtocols } - protomap := make(map[string]struct{}, len(protos)) + protomap := make(map[protocol.ID]struct{}, len(protos)) for _, proto := range protos { protomap[proto] = struct{}{} } @@ -75,7 +76,7 @@ func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { return pb.meta.Put(p, "protocols", protomap) } -func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { +func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error { s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -95,7 +96,7 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { return pb.meta.Put(p, "protocols", pmap) } -func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { +func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -105,7 +106,7 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { return nil, err } - res := make([]string, 0, len(pmap)) + res := make([]protocol.ID, 0, len(pmap)) for proto := range pmap { res = append(res, proto) } @@ -113,7 +114,7 @@ func (pb *dsProtoBook) GetProtocols(p peer.ID) ([]string, error) { return res, nil } -func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { +func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -123,7 +124,7 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, return nil, err } - res := make([]string, 0, len(protos)) + res := make([]protocol.ID, 0, len(protos)) for _, proto := range protos { if _, ok := pmap[proto]; ok { res = append(res, proto) @@ -133,7 +134,7 @@ func (pb *dsProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, return res, nil } -func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (string, error) { +func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() @@ -151,7 +152,7 @@ func (pb *dsProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (stri return "", nil } -func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { +func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error { s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -173,15 +174,15 @@ func (pb *dsProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { return pb.meta.Put(p, "protocols", pmap) } -func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[string]struct{}, error) { +func (pb *dsProtoBook) getProtocolMap(p peer.ID) (map[protocol.ID]struct{}, error) { iprotomap, err := pb.meta.Get(p, "protocols") switch err { default: return nil, err case pstore.ErrNotFound: - return make(map[string]struct{}), nil + return make(map[protocol.ID]struct{}), nil case nil: - cast, ok := iprotomap.(map[string]struct{}) + cast, ok := iprotomap.(map[protocol.ID]struct{}) if !ok { return nil, fmt.Errorf("stored protocol set was not a map") } diff --git a/p2p/host/peerstore/pstoremem/protobook.go b/p2p/host/peerstore/pstoremem/protobook.go index 7a955c0769..0000f97ff1 100644 --- a/p2p/host/peerstore/pstoremem/protobook.go +++ b/p2p/host/peerstore/pstoremem/protobook.go @@ -6,11 +6,12 @@ import ( "github.com/libp2p/go-libp2p/core/peer" pstore "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" ) type protoSegment struct { sync.RWMutex - protocols map[peer.ID]map[string]struct{} + protocols map[peer.ID]map[protocol.ID]struct{} } type protoSegments [256]*protoSegment @@ -27,7 +28,7 @@ type memoryProtoBook struct { maxProtos int lk sync.RWMutex - interned map[string]string + interned map[protocol.ID]protocol.ID } var _ pstore.ProtoBook = (*memoryProtoBook)(nil) @@ -43,11 +44,11 @@ func WithMaxProtocols(num int) ProtoBookOption { func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) { pb := &memoryProtoBook{ - interned: make(map[string]string, 256), + interned: make(map[protocol.ID]protocol.ID, 256), segments: func() (ret protoSegments) { for i := range ret { ret[i] = &protoSegment{ - protocols: make(map[peer.ID]map[string]struct{}), + protocols: make(map[peer.ID]map[protocol.ID]struct{}), } } return ret @@ -63,7 +64,7 @@ func NewProtoBook(opts ...ProtoBookOption) (*memoryProtoBook, error) { return pb, nil } -func (pb *memoryProtoBook) internProtocol(proto string) string { +func (pb *memoryProtoBook) internProtocol(proto protocol.ID) protocol.ID { // check if it is interned with the read lock pb.lk.RLock() interned, ok := pb.interned[proto] @@ -87,12 +88,12 @@ func (pb *memoryProtoBook) internProtocol(proto string) string { return proto } -func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { +func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...protocol.ID) error { if len(protos) > pb.maxProtos { return errTooManyProtocols } - newprotos := make(map[string]struct{}, len(protos)) + newprotos := make(map[protocol.ID]struct{}, len(protos)) for _, proto := range protos { newprotos[pb.internProtocol(proto)] = struct{}{} } @@ -105,14 +106,14 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { return nil } -func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { +func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...protocol.ID) error { s := pb.segments.get(p) s.Lock() defer s.Unlock() protomap, ok := s.protocols[p] if !ok { - protomap = make(map[string]struct{}) + protomap = make(map[protocol.ID]struct{}) s.protocols[p] = protomap } if len(protomap)+len(protos) > pb.maxProtos { @@ -125,12 +126,12 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { return nil } -func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { +func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() - out := make([]string, 0, len(s.protocols[p])) + out := make([]protocol.ID, 0, len(s.protocols[p])) for k := range s.protocols[p] { out = append(out, k) } @@ -138,7 +139,7 @@ func (pb *memoryProtoBook) GetProtocols(p peer.ID) ([]string, error) { return out, nil } -func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { +func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...protocol.ID) error { s := pb.segments.get(p) s.Lock() defer s.Unlock() @@ -155,12 +156,12 @@ func (pb *memoryProtoBook) RemoveProtocols(p peer.ID, protos ...string) error { return nil } -func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]string, error) { +func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...protocol.ID) ([]protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() - out := make([]string, 0, len(protos)) + out := make([]protocol.ID, 0, len(protos)) for _, proto := range protos { if _, ok := s.protocols[p][proto]; ok { out = append(out, proto) @@ -170,7 +171,7 @@ func (pb *memoryProtoBook) SupportsProtocols(p peer.ID, protos ...string) ([]str return out, nil } -func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...string) (string, error) { +func (pb *memoryProtoBook) FirstSupportedProtocol(p peer.ID, protos ...protocol.ID) (protocol.ID, error) { s := pb.segments.get(p) s.RLock() defer s.RUnlock() diff --git a/p2p/host/peerstore/test/peerstore_suite.go b/p2p/host/peerstore/test/peerstore_suite.go index 325399dec3..576b8fd47a 100644 --- a/p2p/host/peerstore/test/peerstore_suite.go +++ b/p2p/host/peerstore/test/peerstore_suite.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" pstore "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -44,6 +45,10 @@ func TestPeerstore(t *testing.T, factory PeerstoreFactory) { } } +func sortProtos(protos []protocol.ID) { + sort.Slice(protos, func(i, j int) bool { return protos[i] < protos[j] }) +} + func testAddrStream(ps pstore.Peerstore) func(t *testing.T) { return func(t *testing.T) { addrs, pid := getAddrs(t, 100), peer.ID("testpeer") @@ -209,14 +214,14 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { return func(t *testing.T) { t.Run("adding and removing protocols", func(t *testing.T) { p1 := peer.ID("TESTPEER") - protos := []string{"a", "b", "c", "d"} + protos := []protocol.ID{"a", "b", "c", "d"} require.NoError(t, ps.AddProtocols(p1, protos...)) out, err := ps.GetProtocols(p1) require.NoError(t, err) require.Len(t, out, len(protos), "got wrong number of protocols back") - sort.Strings(out) + sortProtos(out) for i, p := range protos { if out[i] != p { t.Fatal("got wrong protocol") @@ -233,7 +238,7 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { b, err := ps.FirstSupportedProtocol(p1, "q", "w", "a", "y", "b") require.NoError(t, err) - require.Equal(t, "a", b) + require.Equal(t, protocol.ID("a"), b) b, err = ps.FirstSupportedProtocol(p1, "q", "x", "z") require.NoError(t, err) @@ -241,9 +246,9 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { b, err = ps.FirstSupportedProtocol(p1, "a") require.NoError(t, err) - require.Equal(t, "a", b) + require.Equal(t, protocol.ID("a"), b) - protos = []string{"other", "yet another", "one more"} + protos = []protocol.ID{"other", "yet another", "one more"} require.NoError(t, ps.SetProtocols(p1, protos...)) supported, err = ps.SupportsProtocols(p1, "q", "w", "a", "y", "b") @@ -253,8 +258,8 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { supported, err = ps.GetProtocols(p1) require.NoError(t, err) - sort.Strings(supported) - sort.Strings(protos) + sortProtos(supported) + sortProtos(protos) if !reflect.DeepEqual(supported, protos) { t.Fatalf("expected previously set protos; expected: %v, have: %v", protos, supported) } @@ -270,7 +275,7 @@ func testPeerstoreProtoStore(ps pstore.Peerstore) func(t *testing.T) { t.Run("removing peer", func(t *testing.T) { p := peer.ID("foobar") - protos := []string{"a", "b"} + protos := []protocol.ID{"a", "b"} require.NoError(t, ps.SetProtocols(p, protos...)) out, err := ps.GetProtocols(p) @@ -383,9 +388,9 @@ func getAddrs(t *testing.T, n int) []ma.Multiaddr { func TestPeerstoreProtoStoreLimits(t *testing.T, ps pstore.Peerstore, limit int) { p := peer.ID("foobar") - protocols := make([]string, limit) + protocols := make([]protocol.ID, limit) for i := 0; i < limit; i++ { - protocols[i] = fmt.Sprintf("protocol %d", i) + protocols[i] = protocol.ID(fmt.Sprintf("protocol %d", i)) } t.Run("setting protocols", func(t *testing.T) { diff --git a/p2p/host/pstoremanager/mock_peerstore_test.go b/p2p/host/pstoremanager/mock_peerstore_test.go index 8bd6578850..fb9a282501 100644 --- a/p2p/host/pstoremanager/mock_peerstore_test.go +++ b/p2p/host/pstoremanager/mock_peerstore_test.go @@ -12,6 +12,7 @@ import ( gomock "github.com/golang/mock/gomock" crypto "github.com/libp2p/go-libp2p/core/crypto" peer "github.com/libp2p/go-libp2p/core/peer" + protocol "github.com/libp2p/go-libp2p/core/protocol" multiaddr "github.com/multiformats/go-multiaddr" ) @@ -77,7 +78,7 @@ func (mr *MockPeerstoreMockRecorder) AddPrivKey(arg0, arg1 interface{}) *gomock. } // AddProtocols mocks base method. -func (m *MockPeerstore) AddProtocols(arg0 peer.ID, arg1 ...string) error { +func (m *MockPeerstore) AddProtocols(arg0 peer.ID, arg1 ...protocol.ID) error { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { @@ -164,14 +165,14 @@ func (mr *MockPeerstoreMockRecorder) Close() *gomock.Call { } // FirstSupportedProtocol mocks base method. -func (m *MockPeerstore) FirstSupportedProtocol(arg0 peer.ID, arg1 ...string) (string, error) { +func (m *MockPeerstore) FirstSupportedProtocol(arg0 peer.ID, arg1 ...protocol.ID) (protocol.ID, error) { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "FirstSupportedProtocol", varargs...) - ret0, _ := ret[0].(string) + ret0, _ := ret[0].(protocol.ID) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -199,10 +200,10 @@ func (mr *MockPeerstoreMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { } // GetProtocols mocks base method. -func (m *MockPeerstore) GetProtocols(arg0 peer.ID) ([]string, error) { +func (m *MockPeerstore) GetProtocols(arg0 peer.ID) ([]protocol.ID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetProtocols", arg0) - ret0, _ := ret[0].([]string) + ret0, _ := ret[0].([]protocol.ID) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -350,7 +351,7 @@ func (mr *MockPeerstoreMockRecorder) RemovePeer(arg0 interface{}) *gomock.Call { } // RemoveProtocols mocks base method. -func (m *MockPeerstore) RemoveProtocols(arg0 peer.ID, arg1 ...string) error { +func (m *MockPeerstore) RemoveProtocols(arg0 peer.ID, arg1 ...protocol.ID) error { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { @@ -393,7 +394,7 @@ func (mr *MockPeerstoreMockRecorder) SetAddrs(arg0, arg1, arg2 interface{}) *gom } // SetProtocols mocks base method. -func (m *MockPeerstore) SetProtocols(arg0 peer.ID, arg1 ...string) error { +func (m *MockPeerstore) SetProtocols(arg0 peer.ID, arg1 ...protocol.ID) error { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { @@ -412,14 +413,14 @@ func (mr *MockPeerstoreMockRecorder) SetProtocols(arg0 interface{}, arg1 ...inte } // SupportsProtocols mocks base method. -func (m *MockPeerstore) SupportsProtocols(arg0 peer.ID, arg1 ...string) ([]string, error) { +func (m *MockPeerstore) SupportsProtocols(arg0 peer.ID, arg1 ...protocol.ID) ([]protocol.ID, error) { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "SupportsProtocols", varargs...) - ret0, _ := ret[0].([]string) + ret0, _ := ret[0].([]protocol.ID) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/p2p/host/resource-manager/extapi.go b/p2p/host/resource-manager/extapi.go index 302678e198..03edcd79ea 100644 --- a/p2p/host/resource-manager/extapi.go +++ b/p2p/host/resource-manager/extapi.go @@ -87,7 +87,7 @@ func (r *resourceManager) ListProtocols() []protocol.ID { } sort.Slice(result, func(i, j int) bool { - return strings.Compare(string(result[i]), string(result[j])) < 0 + return result[i] < result[j] }) return result diff --git a/p2p/host/routed/routed.go b/p2p/host/routed/routed.go index 4188d2fcd8..0cedd48f54 100644 --- a/p2p/host/routed/routed.go +++ b/p2p/host/routed/routed.go @@ -188,7 +188,7 @@ func (rh *RoutedHost) SetStreamHandler(pid protocol.ID, handler network.StreamHa rh.host.SetStreamHandler(pid, handler) } -func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) { +func (rh *RoutedHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) { rh.host.SetStreamHandlerMatch(pid, m, handler) } diff --git a/p2p/net/swarm/swarm_metrics.go b/p2p/net/swarm/swarm_metrics.go index fec82c17a1..16f431b53b 100644 --- a/p2p/net/swarm/swarm_metrics.go +++ b/p2p/net/swarm/swarm_metrics.go @@ -122,12 +122,12 @@ func appendConnectionState(tags []string, cs network.ConnectionState) []string { // This shouldn't happen, unless the transport doesn't properly set the Transport field in the ConnectionState. tags = append(tags, "unknown") } else { - tags = append(tags, cs.Transport) + tags = append(tags, string(cs.Transport)) } // These might be empty, depending on the transport. // For example, QUIC doesn't set security nor muxer. - tags = append(tags, cs.Security) - tags = append(tags, cs.StreamMultiplexer) + tags = append(tags, string(cs.Security)) + tags = append(tags, string(cs.StreamMultiplexer)) return tags } diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 7a079b29fe..5db2175517 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -56,8 +56,8 @@ func (t *transportConn) Close() error { func (t *transportConn) ConnState() network.ConnectionState { return network.ConnectionState{ - StreamMultiplexer: string(t.muxer), - Security: string(t.security), + StreamMultiplexer: t.muxer, + Security: t.security, Transport: "tcp", } } diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index ce11b606a5..c6b1b8850b 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -405,7 +405,7 @@ func TestNoCommonSecurityProto(t *testing.T) { }() _, err = dial(t, ub, ln.Multiaddr(), idA, &network.NullScope{}) - require.EqualError(t, err, "failed to negotiate security protocol: protocol not supported") + require.ErrorContains(t, err, "failed to negotiate security protocol: protocols not supported") select { case <-done: t.Fatal("didn't expect to accept a connection") diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 38c6faea4c..658e5df7b2 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -53,13 +53,13 @@ type upgrader struct { connGater connmgr.ConnectionGater rcmgr network.ResourceManager - muxerMuxer *mss.MultistreamMuxer + muxerMuxer *mss.MultistreamMuxer[protocol.ID] muxers []StreamMuxer - muxerIDs []string + muxerIDs []protocol.ID security []sec.SecureTransport - securityMuxer *mss.MultistreamMuxer - securityIDs []string + securityMuxer *mss.MultistreamMuxer[protocol.ID] + securityIDs []protocol.ID // AcceptTimeout is the maximum duration an Accept is allowed to take. // This includes the time between accepting the raw network connection, @@ -77,10 +77,10 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc rcmgr: rcmgr, connGater: connGater, psk: psk, - muxerMuxer: mss.NewMultistreamMuxer(), + muxerMuxer: mss.NewMultistreamMuxer[protocol.ID](), muxers: muxers, security: security, - securityMuxer: mss.NewMultistreamMuxer(), + securityMuxer: mss.NewMultistreamMuxer[protocol.ID](), } for _, opt := range opts { if err := opt(u); err != nil { @@ -90,15 +90,15 @@ func New(security []sec.SecureTransport, muxers []StreamMuxer, psk ipnet.PSK, rc if u.rcmgr == nil { u.rcmgr = &network.NullResourceManager{} } - u.muxerIDs = make([]string, 0, len(muxers)) + u.muxerIDs = make([]protocol.ID, 0, len(muxers)) for _, m := range muxers { - u.muxerMuxer.AddHandler(string(m.ID), nil) - u.muxerIDs = append(u.muxerIDs, string(m.ID)) + u.muxerMuxer.AddHandler(m.ID, nil) + u.muxerIDs = append(u.muxerIDs, m.ID) } - u.securityIDs = make([]string, 0, len(security)) + u.securityIDs = make([]protocol.ID, 0, len(security)) for _, s := range security { - u.securityMuxer.AddHandler(string(s.ID()), nil) - u.securityIDs = append(u.securityIDs, string(s.ID())) + u.securityMuxer.AddHandler(s.ID(), nil) + u.securityIDs = append(u.securityIDs, s.ID()) } return u, nil } @@ -219,7 +219,7 @@ func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, err return nil, err } - var proto string + var proto protocol.ID if isServer { selected, _, err := u.muxerMuxer.Negotiate(nc) if err != nil { @@ -244,9 +244,9 @@ func (u *upgrader) negotiateMuxer(nc net.Conn, isServer bool) (*StreamMuxer, err return nil, fmt.Errorf("selected protocol we don't have a transport for") } -func (u *upgrader) getMuxerByID(id string) *StreamMuxer { +func (u *upgrader) getMuxerByID(id protocol.ID) *StreamMuxer { for _, m := range u.muxers { - if string(m.ID) == id { + if m.ID == id { return &m } } @@ -265,7 +265,7 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b if err != nil { return "", nil, err } - return protocol.ID(muxerSelected), c, nil + return muxerSelected, c, nil } type result struct { @@ -298,9 +298,9 @@ func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server b } } -func (u *upgrader) getSecurityByID(id string) sec.SecureTransport { +func (u *upgrader) getSecurityByID(id protocol.ID) sec.SecureTransport { for _, s := range u.security { - if string(s.ID()) == id { + if s.ID() == id { return s } } @@ -309,7 +309,7 @@ func (u *upgrader) getSecurityByID(id string) sec.SecureTransport { func (u *upgrader) negotiateSecurity(ctx context.Context, insecure net.Conn, server bool) (sec.SecureTransport, bool, error) { type result struct { - proto string + proto protocol.ID iamserver bool err error } diff --git a/p2p/protocol/circuitv2/client/reservation_test.go b/p2p/protocol/circuitv2/client/reservation_test.go index 15b0dd9dec..31e2865448 100644 --- a/p2p/protocol/circuitv2/client/reservation_test.go +++ b/p2p/protocol/circuitv2/client/reservation_test.go @@ -27,7 +27,7 @@ func TestReservationFailures(t *testing.T) { { name: "unsupported protocol", streamHandler: nil, - err: "protocol not supported", + err: "protocols not supported", }, { name: "wrong message type", diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 0e9c520188..da8bf7c508 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -224,7 +224,7 @@ func TestFailuresOnInitiator(t *testing.T) { hps := addHolePunchService(t, h2, opts...) // wait until the hole punching protocol has actually started require.Eventually(t, func() bool { - protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), string(holepunch.Protocol)) + protos, _ := h2.Peerstore().SupportsProtocols(h1.ID(), holepunch.Protocol) return len(protos) > 0 }, 200*time.Millisecond, 10*time.Millisecond) diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index 4348d69cdd..2a1af713c7 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -500,7 +500,7 @@ func (ids *idService) createBaseIdentifyResponse( localAddr := conn.LocalMultiaddr() // set protocols this node is currently handling - mes.Protocols = snapshot.protocols + mes.Protocols = protocol.ConvertToStrings(snapshot.protocols) // observed address so other side is informed of their // "public" address, at least in relation to us. @@ -560,7 +560,7 @@ func (ids *idService) getSignedRecord(snapshot *identifySnapshot) []byte { } // diff takes two slices of strings (a and b) and computes which elements were added and removed in b -func diff(a, b []string) (added, removed []string) { +func diff(a, b []protocol.ID) (added, removed []protocol.ID) { // This is O(n^2), but it's fine because the slices are small. for _, x := range b { var found bool @@ -593,13 +593,14 @@ func (ids *idService) consumeMessage(mes *pb.Identify, c network.Conn, isPush bo p := c.RemotePeer() supported, _ := ids.Host.Peerstore().GetProtocols(p) - added, removed := diff(supported, mes.Protocols) - ids.Host.Peerstore().SetProtocols(p, mes.Protocols...) + mesProtocols := protocol.ConvertFromStrings(mes.Protocols) + added, removed := diff(supported, mesProtocols) + ids.Host.Peerstore().SetProtocols(p, mesProtocols...) if isPush { ids.emitters.evtPeerProtocolsUpdated.Emit(event.EvtPeerProtocolsUpdated{ Peer: p, - Added: protocol.ConvertFromStrings(added), - Removed: protocol.ConvertFromStrings(removed), + Added: added, + Removed: removed, }) } diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 2a7305c46a..e95333ccee 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -383,7 +383,7 @@ func TestIdentifyPushWhileIdentifyingConn(t *testing.T) { handler := func(s network.Stream) { <-block w := pbio.NewDelimitedWriter(s) - w.WriteMsg(&pb.Identify{Protocols: h1.Mux().Protocols()}) + w.WriteMsg(&pb.Identify{Protocols: protocol.ConvertToStrings(h1.Mux().Protocols())}) s.Close() } h1.RemoveStreamHandler(identify.ID) @@ -587,14 +587,14 @@ func TestSendPush(t *testing.T) { // h1 starts listening on a new protocol and h2 finds out about that through a push h1.SetStreamHandler("rand", func(network.Stream) {}) require.Eventually(t, func() bool { - sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...) + sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []protocol.ID{"rand"}...) return err == nil && len(sup) == 1 && sup[0] == "rand" }, time.Second, 10*time.Millisecond) // h1 stops listening on a protocol and h2 finds out about it via a push h1.RemoveStreamHandler("rand") require.Eventually(t, func() bool { - sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...) + sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []protocol.ID{"rand"}...) return err == nil && len(sup) == 0 }, time.Second, 10*time.Millisecond) } @@ -613,9 +613,9 @@ func TestLargeIdentifyMessage(t *testing.T) { // add protocol strings to make the message larger // about 2K of protocol strings for i := 0; i < 500; i++ { - r := fmt.Sprintf("rand%d", i) - h1.SetStreamHandler(protocol.ID(r), func(network.Stream) {}) - h2.SetStreamHandler(protocol.ID(r), func(network.Stream) {}) + r := protocol.ID(fmt.Sprintf("rand%d", i)) + h1.SetStreamHandler(r, func(network.Stream) {}) + h2.SetStreamHandler(r, func(network.Stream) {}) } h1p := h1.ID() @@ -719,9 +719,9 @@ func TestLargePushMessage(t *testing.T) { // add protocol strings to make the message larger // about 2K of protocol strings for i := 0; i < 500; i++ { - r := fmt.Sprintf("rand%d", i) - h1.SetStreamHandler(protocol.ID(r), func(network.Stream) {}) - h2.SetStreamHandler(protocol.ID(r), func(network.Stream) {}) + r := protocol.ID(fmt.Sprintf("rand%d", i)) + h1.SetStreamHandler(r, func(network.Stream) {}) + h2.SetStreamHandler(r, func(network.Stream) {}) } h1p := h1.ID() diff --git a/p2p/protocol/identify/peer_loop.go b/p2p/protocol/identify/peer_loop.go index 589f04a471..5462dff616 100644 --- a/p2p/protocol/identify/peer_loop.go +++ b/p2p/protocol/identify/peer_loop.go @@ -16,7 +16,7 @@ import ( var errProtocolNotSupported = errors.New("protocol not supported") type identifySnapshot struct { - protocols []string + protocols []protocol.ID addrs []ma.Multiaddr record *record.Envelope } @@ -103,7 +103,7 @@ func (ph *peerHandler) sendPush(ctx context.Context) error { return nil } -func (ph *peerHandler) openStream(ctx context.Context, proto string) (network.Stream, error) { +func (ph *peerHandler) openStream(ctx context.Context, proto protocol.ID) (network.Stream, error) { // wait for the other peer to send us an Identify response on "all" connections we have with it // so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol // if we know for a fact that it doesn't support the protocol. @@ -127,5 +127,5 @@ func (ph *peerHandler) openStream(ctx context.Context, proto string) (network.St // negotiate a stream without opening a new connection as we "should" already have a connection. ctx, cancel := context.WithTimeout(network.WithNoDial(ctx, "should already have connection"), 30*time.Second) defer cancel() - return ph.ids.Host.NewStream(ctx, ph.pid, protocol.ID(proto)) + return ph.ids.Host.NewStream(ctx, ph.pid, proto) } diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 93ce5217be..803ed7cba3 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -12,6 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" ) type secureSession struct { @@ -134,7 +135,7 @@ func (s *secureSession) Close() error { return s.insecureConn.Close() } -func SessionWithConnState(s *secureSession, muxer string) *secureSession { +func SessionWithConnState(s *secureSession, muxer protocol.ID) *secureSession { if s != nil { s.connectionState.StreamMultiplexer = muxer } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index 6e2882b8bf..e42cea1bf7 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -23,7 +23,7 @@ type Transport struct { protocolID protocol.ID localID peer.ID privateKey crypto.PrivKey - muxers []string + muxers []protocol.ID } var _ sec.SecureTransport = &Transport{} @@ -36,16 +36,16 @@ func New(id protocol.ID, privkey crypto.PrivKey, muxers []tptu.StreamMuxer) (*Tr return nil, err } - smuxers := make([]string, 0, len(muxers)) + muxerIDs := make([]protocol.ID, 0, len(muxers)) for _, m := range muxers { - smuxers = append(smuxers, string(m.ID)) + muxerIDs = append(muxerIDs, m.ID) } return &Transport{ protocolID: id, localID: localID, privateKey: privkey, - muxers: smuxers, + muxers: muxerIDs, }, nil } @@ -87,7 +87,7 @@ func (t *Transport) ID() protocol.ID { return t.protocolID } -func matchMuxers(initiatorMuxers, responderMuxers []string) string { +func matchMuxers(initiatorMuxers, responderMuxers []protocol.ID) protocol.ID { for _, initMuxer := range initiatorMuxers { for _, respMuxer := range responderMuxers { if initMuxer == respMuxer { @@ -100,7 +100,7 @@ func matchMuxers(initiatorMuxers, responderMuxers []string) string { type transportEarlyDataHandler struct { transport *Transport - receivedMuxers []string + receivedMuxers []protocol.ID } var _ EarlyDataHandler = &transportEarlyDataHandler{} @@ -111,19 +111,19 @@ func newTransportEDH(t *Transport) *transportEarlyDataHandler { func (i *transportEarlyDataHandler) Send(context.Context, net.Conn, peer.ID) *pb.NoiseExtensions { return &pb.NoiseExtensions{ - StreamMuxers: i.transport.muxers, + StreamMuxers: protocol.ConvertToStrings(i.transport.muxers), } } func (i *transportEarlyDataHandler) Received(_ context.Context, _ net.Conn, extension *pb.NoiseExtensions) error { // Discard messages with size or the number of protocols exceeding extension limit for security. if extension != nil && len(extension.StreamMuxers) <= maxProtoNum { - i.receivedMuxers = extension.GetStreamMuxers() + i.receivedMuxers = protocol.ConvertFromStrings(extension.GetStreamMuxers()) } return nil } -func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) string { +func (i *transportEarlyDataHandler) MatchMuxers(isInitiator bool) protocol.ID { if isInitiator { return matchMuxers(i.transport.muxers, i.receivedMuxers) } diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 0912dab448..4006c9095a 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/p2p/security/noise/pb" @@ -37,7 +38,7 @@ func newTestTransport(t *testing.T, typ, bits int) *Transport { } } -func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []string) *Transport { +func newTestTransportWithMuxers(t *testing.T, typ, bits int, muxers []protocol.ID) *Transport { transport := newTestTransport(t, typ, bits) transport.muxers = muxers return transport @@ -632,9 +633,9 @@ func TestEarlyfffDataAcceptedWithNoHandler(t *testing.T) { } type noiseEarlyDataTestCase struct { - clientProtos []string - serverProtos []string - expectedResult string + clientProtos []protocol.ID + serverProtos []protocol.ID + expectedResult protocol.ID } func TestHandshakeWithTransportEarlyData(t *testing.T) { @@ -645,43 +646,43 @@ func TestHandshakeWithTransportEarlyData(t *testing.T) { expectedResult: "", }, { - clientProtos: []string{"muxer1"}, - serverProtos: []string{"muxer1"}, + clientProtos: []protocol.ID{"muxer1"}, + serverProtos: []protocol.ID{"muxer1"}, expectedResult: "muxer1", }, { - clientProtos: []string{"muxer1"}, - serverProtos: []string{}, + clientProtos: []protocol.ID{"muxer1"}, + serverProtos: []protocol.ID{}, expectedResult: "", }, { - clientProtos: []string{}, - serverProtos: []string{"muxer2"}, + clientProtos: []protocol.ID{}, + serverProtos: []protocol.ID{"muxer2"}, expectedResult: "", }, { - clientProtos: []string{"muxer2"}, - serverProtos: []string{"muxer1"}, + clientProtos: []protocol.ID{"muxer2"}, + serverProtos: []protocol.ID{"muxer1"}, expectedResult: "", }, { - clientProtos: []string{"muxer1", "muxer2"}, - serverProtos: []string{"muxer2", "muxer1"}, + clientProtos: []protocol.ID{"muxer1", "muxer2"}, + serverProtos: []protocol.ID{"muxer2", "muxer1"}, expectedResult: "muxer1", }, { - clientProtos: []string{"muxer3", "muxer2", "muxer1"}, - serverProtos: []string{"muxer2", "muxer1"}, + clientProtos: []protocol.ID{"muxer3", "muxer2", "muxer1"}, + serverProtos: []protocol.ID{"muxer2", "muxer1"}, expectedResult: "muxer2", }, { - clientProtos: []string{"muxer1", "muxer2"}, - serverProtos: []string{"muxer3"}, + clientProtos: []protocol.ID{"muxer1", "muxer2"}, + serverProtos: []protocol.ID{"muxer3"}, expectedResult: "", }, } - noiseHandshake := func(t *testing.T, initProtos, respProtos []string, expectedProto string) { + noiseHandshake := func(t *testing.T, initProtos, respProtos []protocol.ID, expectedProto protocol.ID) { initTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, initProtos) respTransport := newTestTransportWithMuxers(t, crypto.Ed25519, 2048, respProtos) diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index b036bf8911..3ca837e5ef 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -171,7 +171,7 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se privKey: t.privKey, remotePeer: remotePeerID, remotePubKey: remotePubKey, - connectionState: network.ConnectionState{StreamMultiplexer: nextProto}, + connectionState: network.ConnectionState{StreamMultiplexer: protocol.ID(nextProto)}, }, nil } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index 0f5247fbb9..ab7cb5f382 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -180,7 +180,7 @@ func TestHandshakeSucceeds(t *testing.T) { type testcase struct { clientProtos []protocol.ID serverProtos []protocol.ID - expectedResult string + expectedResult protocol.ID } func TestHandshakeWithNextProtoSucceeds(t *testing.T) { @@ -225,7 +225,7 @@ func TestHandshakeWithNextProtoSucceeds(t *testing.T) { clientID, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer string) { + handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer protocol.ID) { clientInsecureConn, serverInsecureConn := connect(t) serverConnChan := make(chan sec.SecureConn) diff --git a/p2p/test/negotiation/muxer_test.go b/p2p/test/negotiation/muxer_test.go index 9fc13fd312..9301e66c2c 100644 --- a/p2p/test/negotiation/muxer_test.go +++ b/p2p/test/negotiation/muxer_test.go @@ -69,7 +69,7 @@ func TestMuxerNegotiation(t *testing.T) { Name: "no preference overlap", ServerPreference: []libp2p.Option{yamuxOpt}, ClientPreference: []libp2p.Option{mplexOpt}, - Error: "failed to negotiate stream multiplexer: protocol not supported", + Error: "failed to negotiate stream multiplexer: protocols not supported", }, } @@ -119,7 +119,7 @@ func TestMuxerNegotiation(t *testing.T) { require.NoError(t, err) conns := client.Network().ConnsToPeer(server.ID()) require.Len(t, conns, 1, "expected exactly one connection") - require.Equal(t, tc.Expected, protocol.ID(conns[0].ConnState().StreamMultiplexer)) + require.Equal(t, tc.Expected, conns[0].ConnState().StreamMultiplexer) }) } } diff --git a/p2p/test/negotiation/security_test.go b/p2p/test/negotiation/security_test.go index 3c309e2422..b7324744bf 100644 --- a/p2p/test/negotiation/security_test.go +++ b/p2p/test/negotiation/security_test.go @@ -8,7 +8,6 @@ import ( "github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" "github.com/libp2p/go-libp2p/p2p/transport/tcp" @@ -45,7 +44,7 @@ func TestSecurityNegotiation(t *testing.T) { Name: "no overlap", ServerPreference: []libp2p.Option{noiseOpt}, ClientPreference: []libp2p.Option{tlsOpt}, - Error: "failed to negotiate security protocol: protocol not supported", + Error: "failed to negotiate security protocol: protocols not supported", }, } @@ -84,7 +83,7 @@ func TestSecurityNegotiation(t *testing.T) { require.NoError(t, err) conns := client.Network().ConnsToPeer(server.ID()) require.Len(t, conns, 1, "expected exactly one connection") - require.Equal(t, tc.Expected, protocol.ID(conns[0].ConnState().Security)) + require.Equal(t, tc.Expected, conns[0].ConnState().Security) }) } } diff --git a/test-plans/go.mod b/test-plans/go.mod index 3db83ad7aa..75557ea396 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -67,7 +67,7 @@ require ( github.com/multiformats/go-multibase v0.1.1 // indirect github.com/multiformats/go-multicodec v0.7.0 // indirect github.com/multiformats/go-multihash v0.2.1 // indirect - github.com/multiformats/go-multistream v0.3.3 // indirect + github.com/multiformats/go-multistream v0.4.0 // indirect github.com/multiformats/go-varint v0.0.7 // indirect github.com/onsi/ginkgo/v2 v2.5.1 // indirect github.com/opencontainers/runtime-spec v1.0.2 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index d13b82c018..37218590bf 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -335,8 +335,8 @@ github.com/multiformats/go-multicodec v0.7.0/go.mod h1:GUC8upxSBE4oG+q3kWZRw/+6y github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew= github.com/multiformats/go-multihash v0.2.1 h1:aem8ZT0VA2nCHHk7bPJ1BjUbHNciqZC/d16Vve9l108= github.com/multiformats/go-multihash v0.2.1/go.mod h1:WxoMcYG85AZVQUyRyo9s4wULvW5qrI9vb2Lt6evduFc= -github.com/multiformats/go-multistream v0.3.3 h1:d5PZpjwRgVlbwfdTDjife7XszfZd8KYWfROYFlGcR8o= -github.com/multiformats/go-multistream v0.3.3/go.mod h1:ODRoqamLUsETKS9BNcII4gcRsJBU5VAwRIv7O39cEXg= +github.com/multiformats/go-multistream v0.4.0 h1:5i4JbawClkbuaX+mIVXiHQYVPxUW+zjv6w7jtSRukxc= +github.com/multiformats/go-multistream v0.4.0/go.mod h1:BS6ZSYcA4NwYEaIMeCtpJydp2Dc+fNRA6uJMSu/m8+4= github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=