Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

holepunch: add multiaddress filter #1839

Merged
merged 7 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions p2p/protocol/holepunch/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package holepunch

import (
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
)

// WithAddrFilter is a Service option that enables multiaddress filtering.
// It allows to only send a subset of observed addresses to the remote
// peer. E.g., only announce TCP or QUIC multi addresses instead of both.
// It also allows to only consider a subset of received multi addresses
// that remote peers announced to us.
// Theoretically, this API also allows to add multi addresses in both cases.
func WithAddrFilter(f AddrFilter) Option {
return func(hps *Service) error {
hps.filter = f
return nil
}
}

// AddrFilter defines the interface for the multi address filtering.
type AddrFilter interface {
// FilterLocal is a function that filters the multi addresses that we send to the remote peer.
dennis-tra marked this conversation as resolved.
Show resolved Hide resolved
FilterLocal(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
// FilterRemote is a function that filters the multi addresses which we received from the remote peer.
dennis-tra marked this conversation as resolved.
Show resolved Hide resolved
FilterRemote(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
}
79 changes: 63 additions & 16 deletions p2p/protocol/holepunch/holepunch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ func (m *mockEventTracer) getEvents() []*holepunch.Event {

var _ holepunch.EventTracer = &mockEventTracer{}

type mockMaddrFilter struct {
filterLocal func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
filterRemote func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
}

func (m mockMaddrFilter) FilterLocal(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
return m.filterLocal(remoteID, maddrs)
}

func (m mockMaddrFilter) FilterRemote(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
return m.filterRemote(remoteID, maddrs)
}

var _ holepunch.AddrFilter = &mockMaddrFilter{}

type mockIDService struct {
identify.IDService
}
Expand Down Expand Up @@ -110,7 +125,7 @@ func TestDirectDialWorks(t *testing.T) {
func TestEndToEndSimConnect(t *testing.T) {
h1tr := &mockEventTracer{}
h2tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(h1tr), holepunch.WithTracer(h2tr), true)
h1, h2, relay, _ := makeRelayedHosts(t, []holepunch.Option{holepunch.WithTracer(h1tr)}, []holepunch.Option{holepunch.WithTracer(h2tr)}, true)
defer h1.Close()
defer h2.Close()
defer relay.Close()
Expand Down Expand Up @@ -151,6 +166,7 @@ func TestFailuresOnInitiator(t *testing.T) {
rhandler func(s network.Stream)
errMsg string
holePunchTimeout time.Duration
filter func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
}{
"responder does NOT send a CONNECT message": {
rhandler: func(s network.Stream) {
Expand All @@ -175,6 +191,12 @@ func TestFailuresOnInitiator(t *testing.T) {
},
errMsg: "i/o deadline reached",
},
"no addrs after filtering": {
errMsg: "aborting hole punch initiation, as we have no public address after filtering",
filter: func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
return []ma.Multiaddr{}
},
},
}

for name, tc := range tcs {
Expand All @@ -190,7 +212,17 @@ func TestFailuresOnInitiator(t *testing.T) {
defer h1.Close()
defer h2.Close()
defer relay.Close()
hps := addHolePunchService(t, h2, holepunch.WithTracer(tr))

opts := []holepunch.Option{holepunch.WithTracer(tr)}
if tc.filter != nil {
f := mockMaddrFilter{
filterLocal: tc.filter,
filterRemote: tc.filter,
}
opts = append(opts, holepunch.WithAddrFilter(f))
}

hps := addHolePunchService(t, h2, opts...)

if tc.rhandler != nil {
h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
Expand Down Expand Up @@ -221,6 +253,7 @@ func TestFailuresOnResponder(t *testing.T) {
initiator func(s network.Stream)
errMsg string
holePunchTimeout time.Duration
filter func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr
}{
"initiator does NOT send a CONNECT message": {
initiator: func(s network.Stream) {
Expand Down Expand Up @@ -258,6 +291,19 @@ func TestFailuresOnResponder(t *testing.T) {
},
errMsg: "expected CONNECT message to contain at least one address",
},
"no addrs after filtering": {
errMsg: "rejecting hole punch request, as we don't have any public addresses",
initiator: func(s network.Stream) {
protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{
Type: holepunch_pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes([]ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/1234")}),
})
time.Sleep(10 * time.Second)
},
filter: func(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
return []ma.Multiaddr{}
},
},
}

for name, tc := range tcs {
Expand All @@ -267,9 +313,18 @@ func TestFailuresOnResponder(t *testing.T) {
holepunch.StreamTimeout = tc.holePunchTimeout
defer func() { holepunch.StreamTimeout = cpy }()
}

tr := &mockEventTracer{}
h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), nil, false)

opts := []holepunch.Option{holepunch.WithTracer(tr)}
if tc.filter != nil {
f := mockMaddrFilter{
filterLocal: tc.filter,
filterRemote: tc.filter,
}
opts = append(opts, holepunch.WithAddrFilter(f))
}

h1, h2, relay, _ := makeRelayedHosts(t, opts, nil, false)
defer h1.Close()
defer h2.Close()
defer relay.Close()
Expand Down Expand Up @@ -377,13 +432,9 @@ func mkHostWithStaticAutoRelay(t *testing.T, relay host.Host) host.Host {
return h
}

func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
func makeRelayedHosts(t *testing.T, h1opt, h2opt []holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
t.Helper()
var h1opts []holepunch.Option
if h1opt != nil {
h1opts = append(h1opts, h1opt)
}
h1, _ = mkHostWithHolePunchSvc(t, h1opts...)
h1, _ = mkHostWithHolePunchSvc(t, h1opt...)
var err error
relay, err = libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")), libp2p.DisableRelay())
require.NoError(t, err)
Expand All @@ -393,7 +444,7 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePunche

h2 = mkHostWithStaticAutoRelay(t, relay)
if addHolePuncher {
hps = addHolePunchService(t, h2, h2opt)
hps = addHolePunchService(t, h2, h2opt...)
}

// h1 has a relay addr
Expand All @@ -413,12 +464,8 @@ func makeRelayedHosts(t *testing.T, h1opt, h2opt holepunch.Option, addHolePunche
return
}

func addHolePunchService(t *testing.T, h host.Host, opt holepunch.Option) *holepunch.Service {
func addHolePunchService(t *testing.T, h host.Host, opts ...holepunch.Option) *holepunch.Service {
t.Helper()
var opts []holepunch.Option
if opt != nil {
opts = append(opts, opt)
}
hps, err := holepunch.NewService(h, newMockIDService(t, h), opts...)
require.NoError(t, err)
return hps
Expand Down
19 changes: 17 additions & 2 deletions p2p/protocol/holepunch/holepuncher.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ type holePuncher struct {
closed bool

tracer *tracer
filter AddrFilter
}

func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer) *holePuncher {
func newHolePuncher(h host.Host, ids identify.IDService, tracer *tracer, filter AddrFilter) *holePuncher {
hp := &holePuncher{
host: h,
ids: ids,
active: make(map[peer.ID]struct{}),
tracer: tracer,
filter: filter,
}
hp.ctx, hp.ctxCancel = context.WithCancel(context.Background())
h.Network().Notify((*netNotifiee)(hp))
Expand Down Expand Up @@ -204,10 +206,18 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
str.SetDeadline(time.Now().Add(StreamTimeout))

// send a CONNECT and start RTT measurement.
obsAddrs := removeRelayAddrs(hp.ids.OwnObservedAddrs())
if hp.filter != nil {
obsAddrs = hp.filter.FilterLocal(str.Conn().RemotePeer(), obsAddrs)
}
if len(obsAddrs) == 0 {
return nil, 0, errors.New("aborting hole punch initiation as we have no public address")
}

start := time.Now()
if err := w.WriteMsg(&pb.HolePunch{
Type: pb.HolePunch_CONNECT.Enum(),
ObsAddrs: addrsToBytes(removeRelayAddrs(hp.ids.OwnObservedAddrs())),
ObsAddrs: addrsToBytes(obsAddrs),
}); err != nil {
str.Reset()
return nil, 0, err
Expand All @@ -222,7 +232,12 @@ func (hp *holePuncher) initiateHolePunchImpl(str network.Stream) ([]ma.Multiaddr
if t := msg.GetType(); t != pb.HolePunch_CONNECT {
return nil, 0, fmt.Errorf("expect CONNECT message, got %s", t)
}

addrs := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs))
if hp.filter != nil {
addrs = hp.filter.FilterRemote(str.Conn().RemotePeer(), addrs)
}

if len(addrs) == 0 {
return nil, 0, errors.New("didn't receive any public addresses in CONNECT")
}
Expand Down
12 changes: 11 additions & 1 deletion p2p/protocol/holepunch/svc.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type Service struct {
hasPublicAddrsChan chan struct{}

tracer *tracer
filter AddrFilter

refCount sync.WaitGroup
}
Expand Down Expand Up @@ -140,7 +141,7 @@ func (s *Service) watchForPublicAddr() {
continue
}
s.holePuncherMx.Lock()
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer)
s.holePuncher = newHolePuncher(s.host, s.ids, s.tracer, s.filter)
s.holePuncherMx.Unlock()
close(s.hasPublicAddrsChan)
return
Expand Down Expand Up @@ -169,6 +170,10 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, addr
return 0, nil, fmt.Errorf("received hole punch stream: %s", str.Conn().RemoteMultiaddr())
}
ownAddrs := removeRelayAddrs(s.ids.OwnObservedAddrs())
if s.filter != nil {
ownAddrs = s.filter.FilterLocal(str.Conn().RemotePeer(), ownAddrs)
}

// If we can't tell the peer where to dial us, there's no point in starting the hole punching.
if len(ownAddrs) == 0 {
return 0, nil, errors.New("rejecting hole punch request, as we don't have any public addresses")
Expand All @@ -194,7 +199,12 @@ func (s *Service) incomingHolePunch(str network.Stream) (rtt time.Duration, addr
if t := msg.GetType(); t != pb.HolePunch_CONNECT {
return 0, nil, fmt.Errorf("expected CONNECT message from initiator but got %d", t)
}

obsDial := removeRelayAddrs(addrsFromBytes(msg.ObsAddrs))
if s.filter != nil {
obsDial = s.filter.FilterRemote(str.Conn().RemotePeer(), obsDial)
}

log.Debugw("received hole punch request", "peer", str.Conn().RemotePeer(), "addrs", obsDial)
if len(obsDial) == 0 {
return 0, nil, errors.New("expected CONNECT message to contain at least one address")
Expand Down