Skip to content

Commit

Permalink
Add currentDirection to RTPTransceiver
Browse files Browse the repository at this point in the history
add currentDirection to RTPTransceiver, don't reuse
transceiver if its currentDirection is sendrecv or sendonly
  • Loading branch information
cnderrauber committed Sep 15, 2022
1 parent 4e85358 commit 045df4c
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 10 deletions.
56 changes: 55 additions & 1 deletion peerconnection.go
Expand Up @@ -984,6 +984,7 @@ func (pc *PeerConnection) SetLocalDescription(desc SessionDescription) error {
weAnswer := desc.Type == SDPTypeAnswer
remoteDesc := pc.RemoteDescription()
if weAnswer && remoteDesc != nil {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, false)
if err := pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1143,6 +1144,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {

if isRenegotation {
if weOffer {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true)
if err = pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1172,6 +1174,7 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
// Start the networking in a new routine since it will block until
// the connection is actually established.
if weOffer {
_ = setRTPTransceiverCurrentDirection(&desc, currentTransceivers, true)
if err := pc.startRTPSenders(currentTransceivers); err != nil {
return err
}
Expand Down Expand Up @@ -1230,6 +1233,51 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece
}
}

func setRTPTransceiverCurrentDirection(answer *SessionDescription, currentTransceivers []*RTPTransceiver, weOffer bool) error {
currentTransceivers = append([]*RTPTransceiver{}, currentTransceivers...)
for _, media := range answer.parsed.MediaDescriptions {
midValue := getMidValue(media)
if midValue == "" {
return errPeerConnRemoteDescriptionWithoutMidValue
}

if media.MediaName.Media == mediaSectionApplication {
continue
}

var t *RTPTransceiver
t, currentTransceivers = findByMid(midValue, currentTransceivers)

if t == nil {
return fmt.Errorf("%w: %q", errPeerConnTranscieverMidNil, midValue)
}

direction := getPeerDirection(media)
if direction == RTPTransceiverDirection(Unknown) {
continue
}

// reverse direction if it was a remote answer
if weOffer {
switch direction {
case RTPTransceiverDirectionSendonly:
direction = RTPTransceiverDirectionRecvonly
case RTPTransceiverDirectionRecvonly:
// Pion will answer recvonly with a offer recvonly transceiver, so we should
// not change the direction to sendonly if we are the offerer, otherwise this
// tranceiver can't be reuse for AddTrack
if t.Direction() != RTPTransceiverDirectionRecvonly {
direction = RTPTransceiverDirectionSendonly
}
default:
}
}

t.setCurrentDirection(direction)
}
return nil
}

func runIfNewReceiver(
incomingTrack trackDetails,
transceivers []*RTPTransceiver,
Expand Down Expand Up @@ -1706,7 +1754,13 @@ func (pc *PeerConnection) AddTrack(track TrackLocal) (*RTPSender, error) {
pc.mu.Lock()
defer pc.mu.Unlock()
for _, t := range pc.rtpTransceivers {
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil {
currentDirection := t.getCurrentDirection()
// According to https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-addtrack, if the
// transceiver can be reused only if it's currentDirection never be sendrecv or sendonly.
// But that will cause sdp inflate. So we only check currentDirection's current value,
// that's worked for all browsers.
if !t.stopped && t.kind == track.Kind() && t.Sender() == nil &&
!(currentDirection == RTPTransceiverDirectionSendrecv || currentDirection == RTPTransceiverDirectionSendonly) {
sender, err := pc.api.NewRTPSender(track, pc.dtlsTransport)
if err == nil {
err = t.SetSender(sender, track)
Expand Down
57 changes: 52 additions & 5 deletions peerconnection_renegotiation_test.go
Expand Up @@ -128,14 +128,14 @@ func TestPeerConnection_Renegotiation_AddRecvonlyTransceiver(t *testing.T) {
pcOffer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
onTrackFiredFunc()
})
assert.NoError(t, signalPair(pcAnswer, pcOffer))
} else {
pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) {
onTrackFiredFunc()
})
assert.NoError(t, signalPair(pcOffer, pcAnswer))
}

assert.NoError(t, signalPair(pcOffer, pcAnswer))

sendVideoUntilDone(onTrackFired.Done(), t, []*TrackLocalStaticSample{localTrack})

closePairNow(t, pcOffer, pcAnswer)
Expand Down Expand Up @@ -380,6 +380,7 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {

offer, err = pcOffer.CreateOffer(nil)
assert.NoError(t, err)
assert.NoError(t, pcOffer.SetLocalDescription(offer))

assert.Equal(t, len(offer.parsed.MediaDescriptions), 2)

Expand All @@ -391,6 +392,11 @@ func TestPeerConnection_Transceiver_Mid(t *testing.T) {
pcOffer.ops.Done()
pcAnswer.ops.Done()

assert.NoError(t, pcAnswer.SetRemoteDescription(offer))
answer, err = pcAnswer.CreateAnswer(nil)
assert.NoError(t, err)
assert.NoError(t, pcOffer.SetRemoteDescription(answer))

track3, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion3")
require.NoError(t, err)

Expand Down Expand Up @@ -468,12 +474,12 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) {

require.NoError(t, pcOffer.RemoveTrack(sender1))

sender2, err := pcOffer.AddTrack(track2)
require.NoError(t, err)

require.NoError(t, signalPair(pcOffer, pcAnswer))
<-tracksClosed

sender2, err := pcOffer.AddTrack(track2)
require.NoError(t, err)
require.NoError(t, signalPair(pcOffer, pcAnswer))
transceivers = pcOffer.GetTransceivers()
require.Equal(t, 1, len(transceivers))
require.Equal(t, "0", transceivers[0].Mid())
Expand Down Expand Up @@ -1145,3 +1151,44 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) {
closePairNow(t, pcOffer, pcAnswer)
})
}

func TestPeerConnection_Regegotiation_ReuseTransceiver(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

pcOffer, pcAnswer, err := newPair()
if err != nil {
t.Fatal(err)
}

vp8Track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar")
assert.NoError(t, err)
sender, err := pcOffer.AddTrack(vp8Track)
assert.NoError(t, err)
assert.NoError(t, signalPair(pcOffer, pcAnswer))

assert.Equal(t, len(pcOffer.GetTransceivers()), 1)
assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly)
assert.NoError(t, pcOffer.RemoveTrack(sender))
assert.Equal(t, pcOffer.GetTransceivers()[0].getCurrentDirection(), RTPTransceiverDirectionSendonly)

// should not reuse tranceiver
vp8Track2, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "foo", "bar")
assert.NoError(t, err)
sender2, err := pcOffer.AddTrack(vp8Track2)
assert.NoError(t, err)
assert.Equal(t, len(pcOffer.GetTransceivers()), 2)
assert.NoError(t, signalPair(pcOffer, pcAnswer))
assert.True(t, sender2.rtpTransceiver == pcOffer.GetTransceivers()[1])

// should reuse first transceiver
sender, err = pcOffer.AddTrack(vp8Track)
assert.NoError(t, err)
assert.Equal(t, len(pcOffer.GetTransceivers()), 2)
assert.True(t, sender.rtpTransceiver == pcOffer.GetTransceivers()[0])

closePairNow(t, pcOffer, pcAnswer)
}
22 changes: 18 additions & 4 deletions rtptransceiver.go
Expand Up @@ -13,10 +13,11 @@ import (

// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid.
type RTPTransceiver struct {
mid atomic.Value // string
sender atomic.Value // *RTPSender
receiver atomic.Value // *RTPReceiver
direction atomic.Value // RTPTransceiverDirection
mid atomic.Value // string
sender atomic.Value // *RTPSender
receiver atomic.Value // *RTPReceiver
direction atomic.Value // RTPTransceiverDirection
currentDirection atomic.Value // RTPTransceiverDirection

codecs []RTPCodecParameters // User provided codecs via SetCodecPreferences

Expand All @@ -38,6 +39,7 @@ func newRTPTransceiver(
t.setReceiver(receiver)
t.setSender(sender)
t.setDirection(direction)
t.setCurrentDirection(RTPTransceiverDirection(Unknown))
return t
}

Expand Down Expand Up @@ -160,6 +162,7 @@ func (t *RTPTransceiver) Stop() error {
}

t.setDirection(RTPTransceiverDirectionInactive)
t.setCurrentDirection(RTPTransceiverDirectionInactive)
return nil
}

Expand All @@ -179,6 +182,17 @@ func (t *RTPTransceiver) setDirection(d RTPTransceiverDirection) {
t.direction.Store(d)
}

func (t *RTPTransceiver) setCurrentDirection(d RTPTransceiverDirection) {
t.currentDirection.Store(d)
}

func (t *RTPTransceiver) getCurrentDirection() RTPTransceiverDirection {
if v, ok := t.currentDirection.Load().(RTPTransceiverDirection); ok {
return v
}
return RTPTransceiverDirection(Unknown)
}

func (t *RTPTransceiver) setSendingTrack(track TrackLocal) error {
if err := t.Sender().ReplaceTrack(track); err != nil {
return err
Expand Down

0 comments on commit 045df4c

Please sign in to comment.