Skip to content

Commit

Permalink
consistently use protocol.ID instead of strings (#2004)
Browse files Browse the repository at this point in the history
* Change PeerStore interface to use protocol.ID

This reduces the string to protocol.ID translations happening
at various places in the code

* Fix misc cases of protocol.ID conversion

* Merge multistream changes

* Use protocol.ID in network.ConnectionState

* don't update examples

* fix error message tests

* merge new go-multistream changes

* update test-plans go mod

* change transport back to string
  • Loading branch information
sukunrt authored Jan 27, 2023
1 parent 3919359 commit 6b9c116
Show file tree
Hide file tree
Showing 35 changed files with 204 additions and 198 deletions.
2 changes: 1 addition & 1 deletion core/host/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Host interface {

// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function for protocol selection.
SetStreamHandlerMatch(protocol.ID, func(string) bool, network.StreamHandler)
SetStreamHandlerMatch(protocol.ID, func(protocol.ID) bool, network.StreamHandler)

// RemoveStreamHandler removes a handler on the mux that was set by
// SetStreamHandler
Expand Down
5 changes: 3 additions & 2 deletions core/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"

ma "github.com/multiformats/go-multiaddr"
)
Expand Down Expand Up @@ -37,9 +38,9 @@ type Conn interface {
// ConnectionState holds information about the connection.
type ConnectionState struct {
// The stream multiplexer used on this connection (if any). For example: /yamux/1.0.0
StreamMultiplexer string
StreamMultiplexer protocol.ID
// The security protocol used on this connection (if any). For example: /tls/1.0.0
Security string
Security protocol.ID
// the transport used on this connection. For example: tcp
Transport string
}
Expand Down
15 changes: 8 additions & 7 deletions core/peerstore/peerstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/record"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -230,19 +231,19 @@ type Metrics interface {

// ProtoBook tracks the protocols supported by peers.
type ProtoBook interface {
GetProtocols(peer.ID) ([]string, error)
AddProtocols(peer.ID, ...string) error
SetProtocols(peer.ID, ...string) error
RemoveProtocols(peer.ID, ...string) error
GetProtocols(peer.ID) ([]protocol.ID, error)
AddProtocols(peer.ID, ...protocol.ID) error
SetProtocols(peer.ID, ...protocol.ID) error
RemoveProtocols(peer.ID, ...protocol.ID) error

// SupportsProtocols returns the set of protocols the peer supports from among the given protocols.
// If the returned error is not nil, the result is indeterminate.
SupportsProtocols(peer.ID, ...string) ([]string, error)
SupportsProtocols(peer.ID, ...protocol.ID) ([]protocol.ID, error)

// FirstSupportedProtocol returns the first protocol that the peer supports among the given protocols.
// If the peer does not support any of the given protocols, this function will return an empty string and a nil error.
// If the peer does not support any of the given protocols, this function will return an empty protocol.ID and a nil error.
// If the returned error is not nil, the result is indeterminate.
FirstSupportedProtocol(peer.ID, ...string) (string, error)
FirstSupportedProtocol(peer.ID, ...protocol.ID) (protocol.ID, error)

// RemovePeer removes all protocols associated with a peer.
RemovePeer(peer.ID)
Expand Down
14 changes: 8 additions & 6 deletions core/protocol/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package protocol

import (
"io"

"github.com/multiformats/go-multistream"
)

// HandlerFunc is a user-provided function used by the Router to
Expand All @@ -11,7 +13,7 @@ import (
// Will be invoked with the protocol ID string as the first argument,
// which may differ from the ID used for registration if the handler
// was registered using a match function.
type HandlerFunc = func(protocol string, rwc io.ReadWriteCloser) error
type HandlerFunc = multistream.HandlerFunc[ID]

// Router is an interface that allows users to add and remove protocol handlers,
// which will be invoked when incoming stream requests for registered protocols
Expand All @@ -25,7 +27,7 @@ type Router interface {

// AddHandler registers the given handler to be invoked for
// an exact literal match of the given protocol ID string.
AddHandler(protocol string, handler HandlerFunc)
AddHandler(protocol ID, handler HandlerFunc)

// AddHandlerWithFunc registers the given handler to be invoked
// when the provided match function returns true.
Expand All @@ -35,17 +37,17 @@ type Router interface {
// the protocol. Note that the protocol ID argument is not
// used for matching; if you want to match the protocol ID
// string exactly, you must check for it in your match function.
AddHandlerWithFunc(protocol string, match func(string) bool, handler HandlerFunc)
AddHandlerWithFunc(protocol ID, match func(ID) bool, handler HandlerFunc)

// RemoveHandler removes the registered handler (if any) for the
// given protocol ID string.
RemoveHandler(protocol string)
RemoveHandler(protocol ID)

// Protocols returns a list of all registered protocol ID strings.
// Note that the Router may be able to handle protocol IDs not
// included in this list if handlers were added with match functions
// using AddHandlerWithFunc.
Protocols() []string
Protocols() []ID
}

// Negotiator is a component capable of reaching agreement over what protocols
Expand All @@ -55,7 +57,7 @@ type Negotiator interface {
// inbound stream, returning after the protocol has been determined and the
// Negotiator has finished using the stream for negotiation. Returns an
// error if negotiation fails.
Negotiate(rwc io.ReadWriteCloser) (string, HandlerFunc, error)
Negotiate(rwc io.ReadWriteCloser) (ID, HandlerFunc, error)

// Handle calls Negotiate to determine which protocol handler to use for an
// inbound stream, then invokes the protocol handler function, passing it
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ require (
github.com/multiformats/go-multibase v0.1.1
github.com/multiformats/go-multicodec v0.7.0
github.com/multiformats/go-multihash v0.2.1
github.com/multiformats/go-multistream v0.3.3
github.com/multiformats/go-multistream v0.4.0
github.com/multiformats/go-varint v0.0.7
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58
github.com/prometheus/client_golang v1.14.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ github.com/multiformats/go-multicodec v0.7.0/go.mod h1:GUC8upxSBE4oG+q3kWZRw/+6y
github.com/multiformats/go-multihash v0.0.8/go.mod h1:YSLudS+Pi8NHE7o6tb3D8vrpKa63epEDmG8nTduyAew=
github.com/multiformats/go-multihash v0.2.1 h1:aem8ZT0VA2nCHHk7bPJ1BjUbHNciqZC/d16Vve9l108=
github.com/multiformats/go-multihash v0.2.1/go.mod h1:WxoMcYG85AZVQUyRyo9s4wULvW5qrI9vb2Lt6evduFc=
github.com/multiformats/go-multistream v0.3.3 h1:d5PZpjwRgVlbwfdTDjife7XszfZd8KYWfROYFlGcR8o=
github.com/multiformats/go-multistream v0.3.3/go.mod h1:ODRoqamLUsETKS9BNcII4gcRsJBU5VAwRIv7O39cEXg=
github.com/multiformats/go-multistream v0.4.0 h1:5i4JbawClkbuaX+mIVXiHQYVPxUW+zjv6w7jtSRukxc=
github.com/multiformats/go-multistream v0.4.0/go.mod h1:BS6ZSYcA4NwYEaIMeCtpJydp2Dc+fNRA6uJMSu/m8+4=
github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
Expand Down
4 changes: 2 additions & 2 deletions p2p/host/autorelay/relay_finder.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
)

const (
protoIDv1 = string(relayv1.ProtoID)
protoIDv2 = string(circuitv2_proto.ProtoIDv2Hop)
protoIDv1 = relayv1.ProtoID
protoIDv2 = circuitv2_proto.ProtoIDv2Hop
)

// Terminology:
Expand Down
37 changes: 17 additions & 20 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ type BasicHost struct {

network network.Network
psManager *pstoremanager.PeerstoreManager
mux *msmux.MultistreamMuxer
mux *msmux.MultistreamMuxer[protocol.ID]
ids identify.IDService
hps *holepunch.Service
pings *ping.PingService
Expand Down Expand Up @@ -108,7 +108,7 @@ var _ host.Host = (*BasicHost)(nil)
// customize construction of the *BasicHost.
type HostOpts struct {
// MultistreamMuxer is essential for the *BasicHost and will use a sensible default value if omitted.
MultistreamMuxer *msmux.MultistreamMuxer
MultistreamMuxer *msmux.MultistreamMuxer[protocol.ID]

// NegotiationTimeout determines the read and write timeouts on streams.
// If 0 or omitted, it will use DefaultNegotiationTimeout.
Expand Down Expand Up @@ -168,7 +168,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
h := &BasicHost{
network: n,
psManager: psManager,
mux: msmux.NewMultistreamMuxer(),
mux: msmux.NewMultistreamMuxer[protocol.ID](),
negtimeout: DefaultNegotiationTimeout,
AddrsFactory: DefaultAddrsFactory,
maResolver: madns.DefaultResolver,
Expand Down Expand Up @@ -407,7 +407,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) {
}
}

if err := s.SetProtocol(protocol.ID(protoID)); err != nil {
if err := s.SetProtocol(protoID); err != nil {
log.Debugf("error setting stream protocol: %s", err)
s.Reset()
return
Expand Down Expand Up @@ -571,9 +571,9 @@ func (h *BasicHost) EventBus() event.Bus {
//
// (Threadsafe)
func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
h.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
h.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
Expand All @@ -584,10 +584,10 @@ func (h *BasicHost) SetStreamHandler(pid protocol.ID, handler network.StreamHand

// SetStreamHandlerMatch sets the protocol handler on the Host's Mux
// using a matching function to do protocol comparisons
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) {
h.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
h.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
Expand All @@ -598,7 +598,7 @@ func (h *BasicHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool,

// RemoveStreamHandler returns ..
func (h *BasicHost) RemoveStreamHandler(pid protocol.ID) {
h.Mux().RemoveHandler(string(pid))
h.Mux().RemoveHandler(pid)
h.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Removed: []protocol.ID{pid},
})
Expand Down Expand Up @@ -637,28 +637,26 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, ctx.Err()
}

pidStrings := protocol.ConvertToStrings(pids)

pref, err := h.preferredProtocol(p, pidStrings)
pref, err := h.preferredProtocol(p, pids)
if err != nil {
_ = s.Reset()
return nil, err
}

if pref != "" {
s.SetProtocol(pref)
lzcon := msmux.NewMSSelect(s, string(pref))
lzcon := msmux.NewMSSelect(s, pref)
return &streamWrapper{
Stream: s,
rw: lzcon,
}, nil
}

// Negotiate the protocol in the background, obeying the context.
var selected string
var selected protocol.ID
errCh := make(chan error, 1)
go func() {
selected, err = msmux.SelectOneOf(pidStrings, s)
selected, err = msmux.SelectOneOf(pids, s)
errCh <- err
}()
select {
Expand All @@ -674,21 +672,20 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I
return nil, ctx.Err()
}

selpid := protocol.ID(selected)
s.SetProtocol(selpid)
s.SetProtocol(selected)
h.Peerstore().AddProtocols(p, selected)
return s, nil
}

func (h *BasicHost) preferredProtocol(p peer.ID, pids []string) (protocol.ID, error) {
func (h *BasicHost) preferredProtocol(p peer.ID, pids []protocol.ID) (protocol.ID, error) {
supported, err := h.Peerstore().SupportsProtocols(p, pids...)
if err != nil {
return "", err
}

var out protocol.ID
if len(supported) > 0 {
out = protocol.ID(supported[0])
out = supported[0]
}
return out, nil
}
Expand Down
12 changes: 6 additions & 6 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ func TestProtocolHandlerEvents(t *testing.T) {

h.SetStreamHandler(protocol.TestingID, func(s network.Stream) {})
assert([]protocol.ID{protocol.TestingID}, nil)
h.SetStreamHandler(protocol.ID("foo"), func(s network.Stream) {})
assert([]protocol.ID{protocol.ID("foo")}, nil)
h.SetStreamHandler("foo", func(s network.Stream) {})
assert([]protocol.ID{"foo"}, nil)
h.RemoveStreamHandler(protocol.TestingID)
assert(nil, []protocol.ID{protocol.TestingID})
}
Expand Down Expand Up @@ -273,9 +273,9 @@ func TestHostProtoPreference(t *testing.T) {
defer h2.Close()

const (
protoOld = protocol.ID("/testing")
protoNew = protocol.ID("/testing/1.1.0")
protoMinor = protocol.ID("/testing/1.2.0")
protoOld = "/testing"
protoNew = "/testing/1.1.0"
protoMinor = "/testing/1.2.0"
)

connectedOn := make(chan protocol.ID)
Expand All @@ -299,7 +299,7 @@ func TestHostProtoPreference(t *testing.T) {
assertWait(t, connectedOn, protoOld)
s.Close()

h2.SetStreamHandlerMatch(protoMinor, func(string) bool { return true }, handler)
h2.SetStreamHandlerMatch(protoMinor, func(protocol.ID) bool { return true }, handler)
// remembered preference will be chosen first, even when the other side newly supports it
s2, err := h1.NewStream(context.Background(), h2.ID(), protoMinor, protoNew, protoOld)
require.NoError(t, err)
Expand Down
28 changes: 11 additions & 17 deletions p2p/host/blank/blank.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ var log = logging.Logger("blankhost")
// BlankHost is the thinnest implementation of the host.Host interface
type BlankHost struct {
n network.Network
mux *mstream.MultistreamMuxer
mux *mstream.MultistreamMuxer[protocol.ID]
cmgr connmgr.ConnManager
eventbus event.Bus
emitters struct {
Expand Down Expand Up @@ -65,7 +65,7 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost {
bh := &BlankHost{
n: n,
cmgr: cfg.cmgr,
mux: mstream.NewMultistreamMuxer(),
mux: mstream.NewMultistreamMuxer[protocol.ID](),
}
if bh.eventbus == nil {
bh.eventbus = eventbus.NewBus()
Expand Down Expand Up @@ -158,35 +158,29 @@ func (bh *BlankHost) NewStream(ctx context.Context, p peer.ID, protos ...protoco
return nil, err
}

protoStrs := make([]string, len(protos))
for i, pid := range protos {
protoStrs[i] = string(pid)
}

selected, err := mstream.SelectOneOf(protoStrs, s)
selected, err := mstream.SelectOneOf(protos, s)
if err != nil {
s.Reset()
return nil, err
}

selpid := protocol.ID(selected)
s.SetProtocol(selpid)
s.SetProtocol(selected)
bh.Peerstore().AddProtocols(p, selected)

return s, nil
}

func (bh *BlankHost) RemoveStreamHandler(pid protocol.ID) {
bh.Mux().RemoveHandler(string(pid))
bh.Mux().RemoveHandler(pid)
bh.emitters.evtLocalProtocolsUpdated.Emit(event.EvtLocalProtocolsUpdated{
Removed: []protocol.ID{pid},
})
}

func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHandler) {
bh.Mux().AddHandler(string(pid), func(p string, rwc io.ReadWriteCloser) error {
bh.Mux().AddHandler(pid, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
Expand All @@ -195,10 +189,10 @@ func (bh *BlankHost) SetStreamHandler(pid protocol.ID, handler network.StreamHan
})
}

func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(string) bool, handler network.StreamHandler) {
bh.Mux().AddHandlerWithFunc(string(pid), m, func(p string, rwc io.ReadWriteCloser) error {
func (bh *BlankHost) SetStreamHandlerMatch(pid protocol.ID, m func(protocol.ID) bool, handler network.StreamHandler) {
bh.Mux().AddHandlerWithFunc(pid, m, func(p protocol.ID, rwc io.ReadWriteCloser) error {
is := rwc.(network.Stream)
is.SetProtocol(protocol.ID(p))
is.SetProtocol(p)
handler(is)
return nil
})
Expand All @@ -216,7 +210,7 @@ func (bh *BlankHost) newStreamHandler(s network.Stream) {
return
}

s.SetProtocol(protocol.ID(protoID))
s.SetProtocol(protoID)

go handle(protoID, s)
}
Expand Down
Loading

0 comments on commit 6b9c116

Please sign in to comment.