diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go index b43c2fb5f5..df49dae356 100644 --- a/p2p/protocol/autonatv2/autonat.go +++ b/p2p/protocol/autonatv2/autonat.go @@ -64,8 +64,9 @@ type Result struct { Status pb.DialStatus } -// AutoNAT implements the AutoNAT v2 client and server. Users can check reachability -// for their addresses using the CheckReachability method. +// AutoNAT implements the AutoNAT v2 client and server. +// Users can check reachability for their addresses using the CheckReachability method. +// The server provides amplification attack prevention and rate limiting. type AutoNAT struct { host host.Host sub event.Subscription @@ -140,6 +141,7 @@ func (an *AutoNAT) background() { select { case <-an.ctx.Done(): an.srv.Disable() + an.srv.Close() an.peers = nil an.wg.Done() return diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go index 1873b6d897..f1e8299cb2 100644 --- a/p2p/protocol/autonatv2/autonat_test.go +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -2,8 +2,8 @@ package autonatv2 import ( "context" + "errors" "fmt" - "reflect" "sync/atomic" "testing" "time" @@ -19,6 +19,7 @@ import ( "github.com/libp2p/go-msgio/pbio" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -52,6 +53,7 @@ func parseAddrs(t *testing.T, msg *pb.Message) []ma.Multiaddr { return addrs } +// idAndConnect identifies b to a and connects them func idAndConnect(t *testing.T, a, b host.Host) { a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL) a.Peerstore().AddProtocols(b.ID(), DialProtocol) @@ -60,7 +62,7 @@ func idAndConnect(t *testing.T, a, b host.Host) { require.NoError(t, err) } -// waitForPeer waits for a to process all peer events +// waitForPeer waits for a to have 1 peer in the peerMap func waitForPeer(t *testing.T, a *AutoNAT) { t.Helper() require.Eventually(t, func() bool { @@ -70,8 +72,8 @@ func waitForPeer(t *testing.T, a *AutoNAT) { }, 5*time.Second, 100*time.Millisecond) } -// identify provides server address and protocol to client -func identify(t *testing.T, cli *AutoNAT, srv *AutoNAT) { +// idAndWait provides server address and protocol to client +func idAndWait(t *testing.T, cli *AutoNAT, srv *AutoNAT) { idAndConnect(t, cli.host, srv.host) waitForPeer(t, cli) } @@ -80,187 +82,204 @@ func TestAutoNATPrivateAddr(t *testing.T) { an := newAutoNAT(t, nil) res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) require.Equal(t, res, Result{}) - require.NotNil(t, err) + require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") } func TestClientRequest(t *testing.T) { an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) addrs := an.host.Addrs() + addrbs := make([][]byte, len(addrs)) + for i := 0; i < len(addrs); i++ { + addrbs[i] = addrs[i].Bytes() + } - var gotReq atomic.Bool - p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - p.SetStreamHandler(DialProtocol, func(s network.Stream) { - gotReq.Store(true) + var receivedRequest atomic.Bool + b.SetStreamHandler(DialProtocol, func(s network.Stream) { + receivedRequest.Store(true) r := pbio.NewDelimitedReader(s, maxMsgSize) var msg pb.Message - if err := r.ReadMsg(&msg); err != nil { - t.Error(err) - return - } - if msg.GetDialRequest() == nil { - t.Errorf("expected message to be of type DialRequest, got %T", msg.Msg) - return - } - addrsb := make([][]byte, len(addrs)) - for i := 0; i < len(addrs); i++ { - addrsb[i] = addrs[i].Bytes() - } - if !reflect.DeepEqual(addrsb, msg.GetDialRequest().Addrs) { - t.Errorf("expected elements to be equal want: %s got: %s", addrsb, msg.GetDialRequest().Addrs) - } + assert.NoError(t, r.ReadMsg(&msg)) + assert.NotNil(t, msg.GetDialRequest()) + assert.Equal(t, addrbs, msg.GetDialRequest().Addrs) s.Reset() }) - idAndConnect(t, an.host, p) - waitForPeer(t, an) - - res, err := an.CheckReachability( - context.Background(), - []Request{ - {Addr: addrs[0], SendDialData: true}, - {Addr: addrs[1]}, - }) + res, err := an.CheckReachability(context.Background(), []Request{ + {Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]}, + }) require.Equal(t, res, Result{}) require.NotNil(t, err) - require.True(t, gotReq.Load()) + require.True(t, receivedRequest.Load()) } func TestClientServerError(t *testing.T) { an := newAutoNAT(t, nil, allowAllAddrs) - addrs := an.host.Addrs() + defer an.Close() + defer an.host.Close() - p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - idAndConnect(t, an.host, p) + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) waitForPeer(t, an) - done := make(chan bool) tests := []struct { - handler func(network.Stream) + handler func(network.Stream) + errorStr string }{ - {handler: func(s network.Stream) { - s.Reset() - done <- true - }}, - {handler: func(s network.Stream) { - r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - w := pbio.NewDelimitedWriter(s) - w.WriteMsg(&pb.Message{ - Msg: &pb.Message_DialRequest{ - DialRequest: &pb.DialRequest{ - Addrs: [][]byte{}, - Nonce: 0, - }, - }, - }) - if err := r.ReadMsg(&msg); err == nil { - t.Errorf("expected read to fail: %T", msg.Msg) - } - done <- true - }}, + { + handler: func(s network.Stream) { + s.Reset() + }, + errorStr: "stream reset", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialRequest{DialRequest: &pb.DialRequest{}}})) + }, + errorStr: "invalid msg type", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }}, + )) + }, + errorStr: ErrDialRefused.Error(), + }, } for i, tc := range tests { t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { - p.SetStreamHandler(DialProtocol, tc.handler) + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() res, err := an.CheckReachability( context.Background(), - []Request{ - {Addr: addrs[0], SendDialData: true}, - {Addr: addrs[1]}, - }) + newTestRequests(addrs, false)) require.Equal(t, res, Result{}) require.NotNil(t, err) - <-done + require.Contains(t, err.Error(), tc.errorStr) }) } } func TestClientDataRequest(t *testing.T) { an := newAutoNAT(t, nil, allowAllAddrs) - addrs := an.host.Addrs() + defer an.Close() + defer an.host.Close() - p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - idAndConnect(t, an.host, p) + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) waitForPeer(t, an) - done := make(chan bool) tests := []struct { handler func(network.Stream) + name string }{ - {handler: func(s network.Stream) { - r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - w := pbio.NewDelimitedWriter(s) - w.WriteMsg(&pb.Message{ - Msg: &pb.Message_DialDataRequest{ - DialDataRequest: &pb.DialDataRequest{ - AddrIdx: 0, - NumBytes: 10000, - }, - }}, - ) - remain := 10000 - for remain > 0 { - if err := r.ReadMsg(&msg); err != nil { - t.Errorf("expected a valid data response") - break + { + name: "provides dial data", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return } - if msg.GetDialDataResponse() == nil { - t.Errorf("expected type DialDataResponse got %T", msg.Msg) - break + var dialData []byte + for len(dialData) < 10000 { + if err := r.ReadMsg(&msg); err != nil { + t.Error(err) + s.Reset() + return + } + if msg.GetDialDataResponse() == nil { + t.Errorf("expected to receive msg of type DialDataResponse") + s.Reset() + return + } + dialData = append(dialData, msg.GetDialDataResponse().Data...) } - remain -= len(msg.GetDialDataResponse().Data) - } - s.Reset() - done <- true - }}, - {handler: func(s network.Stream) { - r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - w := pbio.NewDelimitedWriter(s) - w.WriteMsg(&pb.Message{ - Msg: &pb.Message_DialDataRequest{ - DialDataRequest: &pb.DialDataRequest{ - AddrIdx: 1, - NumBytes: 10000, - }, - }}, - ) - if err := r.ReadMsg(&msg); err == nil { - t.Errorf("expected to reject data request for low priority address") - } - s.Reset() - done <- true - }}, - {handler: func(s network.Stream) { - r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - w := pbio.NewDelimitedWriter(s) - w.WriteMsg(&pb.Message{ - Msg: &pb.Message_DialDataRequest{ - DialDataRequest: &pb.DialDataRequest{ - AddrIdx: 0, - NumBytes: 1000_000, - }, - }}, - ) - if err := r.ReadMsg(&msg); err == nil { - t.Errorf("expected to reject request for 1MB dial data") - } - s.Reset() - done <- true - }}, + s.Reset() + }, + }, + { + name: "low priority addr", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 1, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + { + name: "too high dial data request", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 1 << 32, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, } - for i, tc := range tests { - t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { - p.SetStreamHandler(DialProtocol, tc.handler) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + res, err := an.CheckReachability( context.Background(), []Request{ @@ -269,74 +288,111 @@ func TestClientDataRequest(t *testing.T) { }) require.Equal(t, res, Result{}) require.NotNil(t, err) - <-done }) } } func TestClientDialBacks(t *testing.T) { an := newAutoNAT(t, nil, allowAllAddrs) - addrs := an.host.Addrs() + defer an.Close() + defer an.host.Close() - p := bhost.NewBlankHost(swarmt.GenSwarm(t)) - idAndConnect(t, an.host, p) + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) waitForPeer(t, an) + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + + readReq := func(r pbio.Reader) ([]ma.Multiaddr, uint64, error) { + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + return nil, 0, err + } + if msg.GetDialRequest() == nil { + return nil, 0, errors.New("no dial request in msg") + } + addrs := parseAddrs(t, &msg) + return addrs, msg.GetDialRequest().GetNonce(), nil + } + + writeNonce := func(addr ma.Multiaddr, nonce uint64) error { + pid := an.host.ID() + dialerHost.Peerstore().AddAddr(pid, addr, peerstore.PermanentAddrTTL) + defer func() { + dialerHost.Network().ClosePeer(pid) + dialerHost.Peerstore().RemovePeer(pid) + dialerHost.Peerstore().ClearAddrs(pid) + }() + as, err := dialerHost.NewStream(context.Background(), pid, DialBackProtocol) + if err != nil { + return err + } + w := pbio.NewDelimitedWriter(as) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + return err + } + as.CloseWrite() + data := make([]byte, 1) + as.Read(data) + as.Close() + return nil + } + tests := []struct { + name string handler func(network.Stream) success bool - isError bool }{ { + name: "correct dial attempt", handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - if err := r.ReadMsg(&msg); err != nil { + w := pbio.NewDelimitedWriter(s) + + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() t.Error(err) + return } - resp := &pb.DialResponse{ - Status: pb.DialResponse_OK, - DialStatus: pb.DialStatus_OK, - AddrIdx: 0, + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return } - w := pbio.NewDelimitedWriter(s) w.WriteMsg(&pb.Message{ Msg: &pb.Message_DialResponse{ - DialResponse: resp, + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 1, + }, }, }) s.Close() }, - success: false, + success: true, }, { + name: "no dial attempt", handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - req := msg.GetDialRequest() - addrs := parseAddrs(t, &msg) - hh := bhost.NewBlankHost(swarmt.GenSwarm(t)) - defer hh.Close() - hh.Peerstore().AddAddr(s.Conn().RemotePeer(), addrs[1], peerstore.PermanentAddrTTL) - as, err := hh.NewStream(context.Background(), s.Conn().RemotePeer(), DialBackProtocol) - if err != nil { - t.Error("failed to open stream", err) + if _, _, err := readReq(r); err != nil { s.Reset() + t.Error(err) return } - w := pbio.NewDelimitedWriter(as) - w.WriteMsg(&pb.DialBack{Nonce: req.Nonce}) - as.CloseWrite() - - w = pbio.NewDelimitedWriter(s) + resp := &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + } + w := pbio.NewDelimitedWriter(s) w.WriteMsg(&pb.Message{ Msg: &pb.Message_DialResponse{ - DialResponse: &pb.DialResponse{ - Status: pb.DialResponse_OK, - DialStatus: pb.DialStatus_OK, - AddrIdx: 0, - }, + DialResponse: resp, }, }) s.Close() @@ -344,34 +400,21 @@ func TestClientDialBacks(t *testing.T) { success: false, }, { + name: "invalid reported address", handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - req := msg.GetDialRequest() - addrs := parseAddrs(t, &msg) - hh := bhost.NewBlankHost(swarmt.GenSwarm(t)) - defer hh.Close() - hh.Peerstore().AddAddr(s.Conn().RemotePeer(), addrs[1], peerstore.PermanentAddrTTL) - as, err := hh.NewStream(context.Background(), s.Conn().RemotePeer(), DialBackProtocol) - as.SetDeadline(time.Now().Add(5 * time.Second)) + addrs, nonce, err := readReq(r) if err != nil { - t.Error("failed to open stream", err) s.Reset() + t.Error(err) return } - ww := pbio.NewDelimitedWriter(as) - if err := ww.WriteMsg(&pb.DialBack{Nonce: req.Nonce - 1}); err != nil { + + if err := writeNonce(addrs[1], nonce); err != nil { s.Reset() - as.Reset() + t.Error(err) return } - as.CloseWrite() - defer func() { - data := make([]byte, 1) - as.Read(data) - as.Close() - }() w := pbio.NewDelimitedWriter(s) w.WriteMsg(&pb.Message{ @@ -388,86 +431,45 @@ func TestClientDialBacks(t *testing.T) { success: false, }, { + name: "invalid nonce", handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - req := msg.GetDialRequest() - addrs := parseAddrs(t, &msg) - - hh := bhost.NewBlankHost(swarmt.GenSwarm(t)) - defer hh.Close() - hh.Peerstore().AddAddr(s.Conn().RemotePeer(), addrs[1], peerstore.PermanentAddrTTL) - as, err := hh.NewStream(context.Background(), s.Conn().RemotePeer(), DialBackProtocol) + addrs, nonce, err := readReq(r) if err != nil { - t.Error("failed to open stream", err) s.Reset() + t.Error(err) return } - - w := pbio.NewDelimitedWriter(as) - if err := w.WriteMsg(&pb.DialBack{Nonce: req.Nonce}); err != nil { - t.Error("failed to write nonce", err) + if err := writeNonce(addrs[0], nonce-1); err != nil { s.Reset() - as.Reset() + t.Error(err) return } - as.CloseWrite() - defer func() { - data := make([]byte, 1) - as.Read(data) - as.Close() - }() - - w = pbio.NewDelimitedWriter(s) - + w := pbio.NewDelimitedWriter(s) w.WriteMsg(&pb.Message{ Msg: &pb.Message_DialResponse{ DialResponse: &pb.DialResponse{ Status: pb.DialResponse_OK, DialStatus: pb.DialStatus_OK, - AddrIdx: 1, + AddrIdx: 0, }, }, }) s.Close() }, - success: true, + success: false, }, { + name: "invalid addr index", handler: func(s network.Stream) { r := pbio.NewDelimitedReader(s, maxMsgSize) - var msg pb.Message - r.ReadMsg(&msg) - req := msg.GetDialRequest() - addrs := parseAddrs(t, &msg) - - hh := bhost.NewBlankHost(swarmt.GenSwarm(t)) - defer hh.Close() - hh.Peerstore().AddAddr(s.Conn().RemotePeer(), addrs[1], peerstore.PermanentAddrTTL) - as, err := hh.NewStream(context.Background(), s.Conn().RemotePeer(), DialBackProtocol) + _, _, err := readReq(r) if err != nil { - t.Error("failed to open stream", err) - s.Reset() - return - } - - w := pbio.NewDelimitedWriter(as) - if err := w.WriteMsg(&pb.DialBack{Nonce: req.Nonce}); err != nil { - t.Error("failed to write nonce", err) s.Reset() - as.Reset() + t.Error(err) return } - as.CloseWrite() - defer func() { - data := make([]byte, 1) - as.Read(data) - as.Close() - }() - - w = pbio.NewDelimitedWriter(s) - + w := pbio.NewDelimitedWriter(s) w.WriteMsg(&pb.Message{ Msg: &pb.Message_DialResponse{ DialResponse: &pb.DialResponse{ @@ -480,13 +482,13 @@ func TestClientDialBacks(t *testing.T) { s.Close() }, success: false, - isError: true, }, } - for i, tc := range tests { - t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { - p.SetStreamHandler(DialProtocol, tc.handler) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addrs := an.host.Addrs() + b.SetStreamHandler(DialProtocol, tc.handler) res, err := an.CheckReachability( context.Background(), []Request{ @@ -494,14 +496,8 @@ func TestClientDialBacks(t *testing.T) { {Addr: addrs[1]}, }) if !tc.success { - if tc.isError { - require.Equal(t, res, Result{}) - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, res.Reachability, network.ReachabilityUnknown) - require.NotEqual(t, res.Status, pb.DialStatus_OK, "got: %d", res.Status) - } + require.Error(t, err) + require.Equal(t, Result{}, res) } else { require.NoError(t, err) require.Equal(t, res.Reachability, network.ReachabilityPublic) @@ -588,3 +584,55 @@ func TestPeersMap(t *testing.T) { require.Equal(t, emptyPeerID, p.GetRand()) }) } + +func TestAreAddrsConsistency(t *testing.T) { + tests := []struct { + name string + localAddr ma.Multiaddr + dialAddr ma.Multiaddr + success bool + }{ + { + name: "simple match", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "nat64 match", + localAddr: ma.StringCast("/ip6/1::1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "simple mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/23232/quic-v1"), + success: false, + }, + { + name: "quic-vs-webtransport", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/webtransport"), + success: false, + }, + { + name: "nat64 mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), + success: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if areAddrsConsistent(tc.localAddr, tc.dialAddr) != tc.success { + wantStr := "match" + if !tc.success { + wantStr = "mismatch" + } + t.Errorf("expected %s between\nlocal addr: %s\ndial addr: %s", wantStr, tc.localAddr, tc.dialAddr) + } + }) + } + +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go index 5bde7f28b2..e2a7db75cd 100644 --- a/p2p/protocol/autonatv2/client.go +++ b/p2p/protocol/autonatv2/client.go @@ -30,7 +30,7 @@ type client struct { } func newClient(h host.Host) *client { - return &client{host: h, dialData: make([]byte, 4096), dialBackQueues: make(map[uint64]chan ma.Multiaddr)} + return &client{host: h, dialData: make([]byte, 8000), dialBackQueues: make(map[uint64]chan ma.Multiaddr)} } // RegisterDialBack registers the client to receive DialBack streams initiated by the server to send the nonce. @@ -152,32 +152,38 @@ func (ac *client) CheckReachability(ctx context.Context, p peer.ID, reqs []Reque } timer.Stop() } - return ac.newResult(resp, reqs, dialBackAddr), nil + return ac.newResult(resp, reqs, dialBackAddr) } -func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) Result { +func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { idx := int(resp.AddrIdx) addr := reqs[idx].Addr - rch := network.ReachabilityUnknown - status := resp.DialStatus - switch status { + var rch network.Reachability + switch resp.DialStatus { case pb.DialStatus_OK: - if areAddrsConsistent(dialBackAddr, addr) { - rch = network.ReachabilityPublic - } else { - status = pb.DialStatus_E_DIAL_BACK_ERROR + if !areAddrsConsistent(dialBackAddr, addr) { + // The server reported a successful dial back but we didn't receive the nonce. + // Discard the response and fail. + return Result{}, fmt.Errorf("invalid repsonse: no dialback received") } + rch = network.ReachabilityPublic case pb.DialStatus_E_DIAL_ERROR: rch = network.ReachabilityPrivate + case pb.DialStatus_E_DIAL_BACK_ERROR: + rch = network.ReachabilityUnknown + default: + // Unexpected response code. Discard the response and fail. + log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) + return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus) } return Result{ Idx: idx, Addr: addr, Reachability: rch, - Status: status, - } + Status: resp.DialStatus, + }, nil } func (ac *client) sendDialData(req *pb.DialDataRequest, w pbio.Writer, msg *pb.Message) error { @@ -217,6 +223,7 @@ func newDialRequest(reqs []Request, nonce uint64) pb.Message { } } +// handleDialBack receives the nonce on the dial-back stream func (ac *client) handleDialBack(s network.Stream) { if err := s.Scope().SetService(ServiceName); err != nil { log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go index 823ac63caf..3a59d8d823 100644 --- a/p2p/protocol/autonatv2/options.go +++ b/p2p/protocol/autonatv2/options.go @@ -35,7 +35,7 @@ func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption { } } -func WithDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { +func withDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { return func(s *autoNATSettings) error { s.dataRequestPolicy = drp return nil diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index cb4797c05c..886c9d22e3 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -21,7 +21,7 @@ import ( type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool // server implements the AutoNATv2 server. -// +// It can ask client to provide dial data before attempting the requested dial. // It rate limits requests on a global level, per peer level and on whether the request requires dial data. type server struct { host host.Host @@ -53,14 +53,21 @@ func newServer(host, dialer host.Host, s *autoNATSettings) *server { } } +// Enable attaches the stream handler to the host. func (as *server) Enable() { as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) } +// Disable removes the stream handles from the host. func (as *server) Disable() { as.host.RemoveStreamHandler(DialProtocol) } +func (as *server) Close() { + as.dialerHost.Close() +} + +// handleDialRequest is the dial-request protocol stream handler func (as *server) handleDialRequest(s network.Stream) { if err := s.Scope().SetService(ServiceName); err != nil { s.Reset() @@ -160,7 +167,6 @@ func (as *server) handleDialRequest(s network.Stream) { } dialStatus := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce) - msg = pb.Message{ Msg: &pb.Message_DialResponse{ DialResponse: &pb.DialResponse{ @@ -205,6 +211,7 @@ func getDialData(w pbio.Writer, r pbio.Reader, msg *pb.Message, addrIdx int) err func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus { ctx, cancel := context.WithTimeout(context.Background(), dialBackDialTimeout) + ctx = network.WithForceDirectDial(ctx, "autonatv2") as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL) defer func() { cancel() @@ -212,10 +219,17 @@ func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt as.dialerHost.Peerstore().ClearAddrs(p) as.dialerHost.Peerstore().RemovePeer(p) }() - s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol) + + err := as.dialerHost.Connect(ctx, peer.AddrInfo{ID: p}) if err != nil { return pb.DialStatus_E_DIAL_ERROR } + + s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol) + if err != nil { + return pb.DialStatus_E_DIAL_BACK_ERROR + } + defer s.Close() s.SetDeadline(as.now().Add(dialBackStreamTimeout)) @@ -239,15 +253,20 @@ func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialSt // rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request // per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data. type rateLimiter struct { - PerPeerRPM int - RPM int + // PerPeerRPM is the rate limit per peer + PerPeerRPM int + // RPM is the global rate limit + RPM int + // DialDataRPM is the rate limit for requests that require dial data DialDataRPM int mu sync.Mutex reqs []time.Time peerReqs map[peer.ID][]time.Time dialDataReqs []time.Time - ongoingReqs map[peer.ID]struct{} + // ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the + // same peer + ongoingReqs map[peer.ID]struct{} now func() time.Time // for tests } diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 6d274375d5..00409bac6d 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -22,44 +22,42 @@ func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) { return } -func TestServerAllAddrsInvalid(t *testing.T) { - dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) - an := newAutoNAT(t, dialer, allowAllAddrs) - defer an.Close() - defer an.host.Close() - an.srv.Enable() - +func TestServerInvalidAddrsRejected(t *testing.T) { c := newAutoNAT(t, nil, allowAllAddrs) defer c.Close() defer c.host.Close() - identify(t, c, an) + t.Run("no transport", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowAllAddrs) + defer an.Close() + defer an.host.Close() - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) -} + idAndWait(t, c, an) -func TestServerPrivateRejected(t *testing.T) { - an := newAutoNAT(t, nil) - defer an.Close() - defer an.host.Close() - an.srv.Enable() + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) - c := newAutoNAT(t, nil, allowAllAddrs) - defer c.Close() - defer c.host.Close() + t.Run("private addrs", func(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() - identify(t, c, an) + idAndWait(t, c, an) - res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) - require.ErrorIs(t, err, ErrDialRefused) - require.Equal(t, Result{}, res) + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) } func TestServerDataRequest(t *testing.T) { + // server will skip all tcp addresses dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) - an := newAutoNAT(t, dialer, allowAllAddrs, WithDataRequestPolicy( + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowAllAddrs, withDataRequestPolicy( func(s network.Stream, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true @@ -68,15 +66,14 @@ func TestServerDataRequest(t *testing.T) { }), WithServerRateLimit(10, 10, 10), ) - an.srv.Enable() + defer an.Close() defer an.host.Close() - c := newAutoNAT(t, nil) - c.allowAllAddrs = true + c := newAutoNAT(t, nil, allowAllAddrs) defer c.Close() defer c.host.Close() - identify(t, c, an) + idAndWait(t, c, an) var quicAddr, tcpAddr ma.Multiaddr for _, a := range c.host.Addrs() { @@ -103,35 +100,52 @@ func TestServerDataRequest(t *testing.T) { func TestServerDial(t *testing.T) { an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowAllAddrs) + defer an.Close() defer an.host.Close() - an.srv.Enable() c := newAutoNAT(t, nil, allowAllAddrs) defer c.Close() defer c.host.Close() - identify(t, c, an) + idAndWait(t, c, an) - randAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") + unreachableAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") hostAddrs := c.host.Addrs() - res, err := c.CheckReachability(context.Background(), - append([]Request{{Addr: randAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) - require.NoError(t, err) - require.Equal(t, Result{ - Idx: 0, - Addr: randAddr, - Reachability: network.ReachabilityPrivate, - Status: pb.DialStatus_E_DIAL_ERROR, - }, res) - res, err = c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) - require.NoError(t, err) - require.Equal(t, Result{ - Idx: 0, - Addr: hostAddrs[0], - Reachability: network.ReachabilityPublic, - Status: pb.DialStatus_OK, - }, res) + t.Run("unreachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), + append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: unreachableAddr, + Reachability: network.ReachabilityPrivate, + Status: pb.DialStatus_E_DIAL_ERROR, + }, res) + }) + + t.Run("reachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + }) + + t.Run("dialback error", func(t *testing.T) { + c.host.RemoveStreamHandler(DialBackProtocol) + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityUnknown, + Status: pb.DialStatus_E_DIAL_BACK_ERROR, + }, res) + }) } func TestRateLimiter(t *testing.T) {