Skip to content

Commit

Permalink
webrtc: correctly report incoming packet address on muxed connection
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Sep 26, 2023
1 parent 7f72151 commit 9ce1216
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
38 changes: 25 additions & 13 deletions p2p/transport/webrtc/udpmux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ type UDPMux struct {

queue chan Candidate

mx sync.Mutex
mx sync.Mutex
// ufragMap allows us to multiplex incoming STUN packets based on ufrag
ufragMap map[ufragConnKey]*muxedConnection
addrMap map[string]*muxedConnection
// addrMap allows us to correctly direct incoming packets after the connection
// is established and ufrag isn't available on all packets
addrMap map[string]*muxedConnection
// ufragAddrMap allows us to clean up all addresses from the addrMap once
// connection is closed
ufragAddrMap map[ufragConnKey][]net.Addr

// the context controls the lifecycle of the mux
wg sync.WaitGroup
Expand All @@ -57,12 +63,13 @@ var _ ice.UDPMux = &UDPMux{}
func NewUDPMux(socket net.PacketConn) *UDPMux {
ctx, cancel := context.WithCancel(context.Background())
mux := &UDPMux{
ctx: ctx,
cancel: cancel,
socket: socket,
ufragMap: make(map[ufragConnKey]*muxedConnection),
addrMap: make(map[string]*muxedConnection),
queue: make(chan Candidate, 32),
ctx: ctx,
cancel: cancel,
socket: socket,
ufragMap: make(map[ufragConnKey]*muxedConnection),
addrMap: make(map[string]*muxedConnection),
ufragAddrMap: make(map[ufragConnKey][]net.Addr),
queue: make(chan Candidate, 32),
}

return mux
Expand Down Expand Up @@ -157,7 +164,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
conn, ok := mux.addrMap[addr.String()]
mux.mx.Unlock()
if ok {
if err := conn.Push(buf); err != nil {
if err := conn.Push(buf, addr); err != nil {
log.Debugf("could not push packet: %v", err)
return false
}
Expand Down Expand Up @@ -196,7 +203,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
}
}

if err := conn.Push(buf); err != nil {
if err := conn.Push(buf, addr); err != nil {
log.Debugf("could not push packet: %v", err)
return false
}
Expand Down Expand Up @@ -250,9 +257,12 @@ func (mux *UDPMux) RemoveConnByUfrag(ufrag string) {

for _, isIPv6 := range [...]bool{true, false} {
key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}
if conn, ok := mux.ufragMap[key]; ok {
if _, ok := mux.ufragMap[key]; ok {
delete(mux.ufragMap, key)
delete(mux.addrMap, conn.RemoteAddr().String())
for _, addr := range mux.ufragAddrMap[key] {
delete(mux.addrMap, addr.String())
}
delete(mux.ufragAddrMap, key)
}
}
}
Expand All @@ -264,12 +274,14 @@ func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, _ *UDPMux, addr ne
defer mux.mx.Unlock()

if conn, ok := mux.ufragMap[key]; ok {
mux.addrMap[addr.String()] = conn
mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr)
return false, conn
}

conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }, addr)
mux.ufragMap[key] = conn
mux.addrMap[addr.String()] = conn

mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr)
return true, conn
}
27 changes: 16 additions & 11 deletions p2p/transport/webrtc/udpmux/muxed_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
pool "github.com/libp2p/go-buffer-pool"
)

type packet struct {
buf []byte
addr net.Addr
}

var _ net.PacketConn = &muxedConnection{}

const queueLen = 128
Expand All @@ -21,7 +26,7 @@ type muxedConnection struct {
ctx context.Context
cancel context.CancelFunc
onClose func()
queue chan []byte
queue chan packet
remote net.Addr
mux *UDPMux
}
Expand All @@ -33,36 +38,36 @@ func newMuxedConnection(mux *UDPMux, onClose func(), remote net.Addr) *muxedConn
return &muxedConnection{
ctx: ctx,
cancel: cancel,
queue: make(chan []byte, queueLen),
queue: make(chan packet, queueLen),
onClose: onClose,
remote: remote,
mux: mux,
}
}

func (c *muxedConnection) Push(buf []byte) error {
func (c *muxedConnection) Push(buf []byte, addr net.Addr) error {
select {
case <-c.ctx.Done():
return errors.New("closed")
default:
}
select {
case c.queue <- buf:
case c.queue <- packet{buf: buf, addr: addr}:
return nil
default:
return errors.New("queue full")
}
}

func (c *muxedConnection) ReadFrom(p []byte) (int, net.Addr, error) {
func (c *muxedConnection) ReadFrom(buf []byte) (int, net.Addr, error) {
select {
case buf := <-c.queue:
n := copy(p, buf) // This might discard parts of the packet, if p is too short
if n < len(buf) {
log.Debugf("short read, had %d, read %d", len(buf), n)
case p := <-c.queue:
n := copy(buf, p.buf) // This might discard parts of the packet, if p is too short
if n < len(p.buf) {
log.Debugf("short read, had %d, read %d", len(p.buf), n)
}
pool.Put(buf)
return n, c.remote, nil
pool.Put(p.buf)
return n, p.addr, nil
case <-c.ctx.Done():
return 0, nil, c.ctx.Err()
}
Expand Down

0 comments on commit 9ce1216

Please sign in to comment.