diff --git a/p2p/transport/webrtc/udpmux/mux_test.go b/p2p/transport/webrtc/udpmux/mux_test.go index 4121b6fdf5..243c8dabb1 100644 --- a/p2p/transport/webrtc/udpmux/mux_test.go +++ b/p2p/transport/webrtc/udpmux/mux_test.go @@ -1,89 +1,224 @@ package udpmux import ( + "context" + "fmt" "net" "testing" "time" + "github.com/pion/stun" "github.com/stretchr/testify/require" ) -var _ net.PacketConn = dummyPacketConn{} - -type dummyPacketConn struct{} - -// Close implements net.PacketConn -func (dummyPacketConn) Close() error { - return nil -} - -// LocalAddr implements net.PacketConn -func (dummyPacketConn) LocalAddr() net.Addr { - return nil -} - -// ReadFrom implements net.PacketConn -func (dummyPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - return 0, &net.UDPAddr{}, nil -} - -// SetDeadline implements net.PacketConn -func (dummyPacketConn) SetDeadline(t time.Time) error { - return nil +func getSTUNBindingRequest(ufrag string) *stun.Message { + msg := stun.New() + msg.SetType(stun.BindingRequest) + uattr := stun.RawAttribute{ + Type: stun.AttrUsername, + Value: []byte(fmt.Sprintf("%s:%s", ufrag, ufrag)), // This is the format we expect in our connections + } + uattr.AddTo(msg) + msg.Encode() + return msg } -// SetReadDeadline implements net.PacketConn -func (dummyPacketConn) SetReadDeadline(t time.Time) error { - return nil +func setupMapping(t *testing.T, ufrag string, from net.PacketConn, m *UDPMux) { + t.Helper() + msg := getSTUNBindingRequest(ufrag) + _, err := from.WriteTo(msg.Raw, m.GetListenAddresses()[0]) + require.NoError(t, err) } -// SetWriteDeadline implements net.PacketConn -func (dummyPacketConn) SetWriteDeadline(t time.Time) error { - return nil +func newPacketConn(t *testing.T) net.PacketConn { + t.Helper() + udpPort0 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + c, err := net.ListenUDP("udp", udpPort0) + require.NoError(t, err) + t.Cleanup(func() { c.Close() }) + return c } -// WriteTo implements net.PacketConn -func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return 0, nil +func TestAccept(t *testing.T) { + c := newPacketConn(t) + defer c.Close() + m := NewUDPMux(c) + m.Start() + defer m.Close() + + ufrags := []string{"a", "b", "c", "d"} + conns := make([]net.PacketConn, len(ufrags)) + for i, ufrag := range ufrags { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + for i, ufrag := range ufrags { + c, err := m.Accept(context.Background()) + require.NoError(t, err) + require.Equal(t, c.Ufrag, ufrag) + require.Equal(t, c.Addr, conns[i].LocalAddr()) + } + + for i, ufrag := range ufrags { + // should not be accepted + setupMapping(t, ufrag, conns[i], m) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err := m.Accept(ctx) + require.Error(t, err) + + // should not be accepted + cc := newPacketConn(t) + setupMapping(t, ufrag, cc, m) + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err = m.Accept(ctx) + require.Error(t, err) + } } -func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool { - m.mx.Lock() - _, ok := m.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}] - m.mx.Unlock() - return ok +func TestGetConn(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + ufrags := []string{"a", "b", "c", "d"} + conns := make([]net.PacketConn, len(ufrags)) + for i, ufrag := range ufrags { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + for i, ufrag := range ufrags { + c, err := m.Accept(context.Background()) + require.NoError(t, err) + require.Equal(t, c.Ufrag, ufrag) + require.Equal(t, c.Addr, conns[i].LocalAddr()) + } + + for i, ufrag := range ufrags { + c, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + require.Equal(t, c.(*muxedConnection).RemoteAddr(), conns[i].LocalAddr()) + msg := make([]byte, 100) + _, _, err = c.ReadFrom(msg) + require.NoError(t, err) + } + + for i, ufrag := range ufrags { + cc := newPacketConn(t) + // setupMapping of cc to ufrags[0] and remove the stun binding request from the queue + setupMapping(t, ufrag, cc, m) + mc, err := m.GetConn(ufrag, cc.LocalAddr()) + require.NoError(t, err) + msg := make([]byte, 100) + _, _, err = mc.ReadFrom(msg) + require.NoError(t, err) + + // Write from new connection should provide the new address on ReadFrom + _, err = cc.WriteTo([]byte("test1"), c.LocalAddr()) + require.NoError(t, err) + n, addr, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr, cc.LocalAddr()) + require.Equal(t, string(msg[:n]), "test1") + + // Write from original connection should provide the original address + _, err = conns[i].WriteTo([]byte("test2"), c.LocalAddr()) + require.NoError(t, err) + n, addr, err = mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr, conns[i].LocalAddr()) + require.Equal(t, string(msg[:n]), "test2") + } } -var ( - addrV4 = net.UDPAddr{IP: net.IPv4zero, Port: 1234} - addrV6 = net.UDPAddr{IP: net.IPv6zero, Port: 1234} -) - -func TestUDPMux_GetConn(t *testing.T) { - m := NewUDPMux(dummyPacketConn{}) - require.False(t, hasConn(m, "test", false)) - conn, err := m.GetConn("test", &addrV4) +func TestRemoveConnByUfrag(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + // Map each ufrag to two addresses + ufrag := "a" + count := 10 + conns := make([]net.PacketConn, count) + for i := 0; i < 10; i++ { + conns[i] = newPacketConn(t) + setupMapping(t, ufrag, conns[i], m) + } + mc, err := m.GetConn(ufrag, conns[0].LocalAddr()) require.NoError(t, err) - require.NotNil(t, conn) - - require.False(t, hasConn(m, "test", true)) - connv6, err := m.GetConn("test", &addrV6) - require.NoError(t, err) - require.NotNil(t, connv6) - - require.NotEqual(t, conn, connv6) -} - -func TestUDPMux_RemoveConnectionOnClose(t *testing.T) { - mux := NewUDPMux(dummyPacketConn{}) - conn, err := mux.GetConn("test", &addrV4) + for i := 0; i < 10; i++ { + mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + require.Equal(t, mc1, mc) + } + + // Now remove the ufrag + m.RemoveConnByUfrag(ufrag) + + // All connections should now be associated with b + ufrag = "b" + for i := 0; i < 10; i++ { + setupMapping(t, ufrag, conns[i], m) + } + mc, err = m.GetConn(ufrag, conns[0].LocalAddr()) require.NoError(t, err) - require.NotNil(t, conn) - - require.True(t, hasConn(mux, "test", false)) - - err = conn.Close() + for i := 0; i < 10; i++ { + mc1, err := m.GetConn(ufrag, conns[i].LocalAddr()) + require.NoError(t, err) + require.Equal(t, mc1, mc) + } + + // Should be different even if the address is the same + mc1, err := m.GetConn("a", conns[0].LocalAddr()) require.NoError(t, err) + require.NotEqual(t, mc1, mc) +} - require.False(t, hasConn(mux, "test", false)) +func TestMuxedConnection(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + + msgCount := 3 + connCount := 3 + + ufrags := []string{"a", "b", "c"} + addrUfragMap := make(map[string]string) + for _, ufrag := range ufrags { + go func(ufrag string) { + for i := 0; i < connCount; i++ { + cc := newPacketConn(t) + addrUfragMap[cc.LocalAddr().String()] = ufrag + setupMapping(t, ufrag, cc, m) + for j := 0; j < msgCount; j++ { + cc.WriteTo([]byte(ufrag), c.LocalAddr()) + } + } + }(ufrag) + } + + for _, ufrag := range ufrags { + mc, err := m.GetConn(ufrag, c.LocalAddr()) // the address is irrelevant + require.NoError(t, err) + for i := 0; i < connCount; i++ { + msg := make([]byte, 100) + // Read the binding request + _, addr1, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addrUfragMap[addr1.String()], ufrag) + // Read individual msgs + for i := 0; i < msgCount; i++ { + n, addr2, err := mc.ReadFrom(msg) + require.NoError(t, err) + require.Equal(t, addr2, addr1) + require.Equal(t, ufrag, string(msg[:n])) + } + delete(addrUfragMap, addr1.String()) + } + } + require.Equal(t, len(addrUfragMap), 0) }