Skip to content

Commit

Permalink
fix race condition on perfect negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohan Totting committed Aug 7, 2023
1 parent 98fd915 commit 9fcb2c5
Show file tree
Hide file tree
Showing 6 changed files with 299 additions and 63 deletions.
92 changes: 75 additions & 17 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,24 @@ type Client struct {
canAddCandidate bool
initialTracksCount int
isInRenegotiation bool
isInRemoteNegotiation bool
idleTimeoutContext context.Context
idleTimeoutCancel context.CancelFunc
mutex sync.RWMutex
peerConnection *webrtc.PeerConnection
pendingReceivedTracks map[string]*webrtc.TrackLocalStaticRTP
pendingPublishedTracks map[string]*webrtc.TrackLocalStaticRTP
pendingRemoteRenegotiation bool
publishedTracks map[string]*webrtc.TrackLocalStaticRTP
queue *queue
State int
sfu *SFU
onConnectionStateChanged func(webrtc.PeerConnectionState)
onConnectionStateChangedCallbacks []func(webrtc.PeerConnectionState)
OnIceCandidate func(context.Context, *webrtc.ICECandidate)
OnBeforeRenegotiation func(context.Context) bool
OnRenegotiation func(context.Context, webrtc.SessionDescription) webrtc.SessionDescription
OnRenegotiation func(context.Context, webrtc.SessionDescription) (webrtc.SessionDescription, error)
OnAllowedRemoteRenegotiation func()
onStopped func()
onTrack func(context.Context, *webrtc.TrackLocalStaticRTP)
options ClientOptions
Expand Down Expand Up @@ -103,19 +107,44 @@ func (c *Client) CompleteNegotiation(answer webrtc.SessionDescription) {
}
}

// ask allow negotiation is required before call negotiation to make sure there is no racing condition of negotiation between local and remote clients.
// ask if allowed for remote negotiation is required before call negotiation to make sure there is no racing condition of negotiation between local and remote clients.
// return false means the negotiation is in process, the requester must have a mechanism to repeat the request once it's done.
// requesting this must be followed by calling Negotate() to make sure the state is completed. Failed on called Negotiate() will cause the client to be in inconsistent state.
func (c *Client) IsAllowNegotiation() bool {
if c.isInRenegotiation {
c.pendingRemoteRenegotiation = true
return false
}

c.isInRenegotiation = true
c.isInRemoteNegotiation = true

return true
}

// SDP negotiation from remote client
func (c *Client) Negotiate(offer webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
answerChan := make(chan webrtc.SessionDescription)
errorChan := make(chan error)
c.queue.Push(negotiationQueue{
Client: c,
SDP: offer,
AnswerChan: answerChan,
ErrorChan: errorChan,
})

select {
case err := <-errorChan:
return nil, err
case answer := <-answerChan:
return &answer, nil
}
}

func (c *Client) negotiateQueuOp(offer webrtc.SessionDescription) (*webrtc.SessionDescription, error) {
c.isInRemoteNegotiation = true

currentTransceiverCount := len(c.peerConnection.GetTransceivers())

// Set the remote SessionDescription
err := c.peerConnection.SetRemoteDescription(offer)
if err != nil {
Expand Down Expand Up @@ -145,28 +174,37 @@ func (c *Client) Negotiate(offer webrtc.SessionDescription) (*webrtc.SessionDesc
}
}

c.initialTracksCount = len(c.peerConnection.GetTransceivers())
c.initialTracksCount = len(c.peerConnection.GetTransceivers()) - currentTransceiverCount

// send pending local candidates if any
c.sendPendingLocalCandidates()

c.pendingRemoteCandidates = nil

c.isInRenegotiation = false
c.isInRemoteNegotiation = false

// send pending local candidates if any
c.sendPendingLocalCandidates()
// call renegotiation that might delay because the remote client is doing renegotiation

return c.peerConnection.LocalDescription(), nil
}

func (c *Client) renegotiate() {
c.queue.Push(renegotiateQueue{
Client: c,
})
}

// The renegotiation can be in race condition when a client is renegotiating and new track is added to the client because another client is publishing to the room.
// We can block the renegotiation until the current renegotiation is finish, but it will block the negotiation process for a while.
func (c *Client) renegotiate() {
if c.GetType() == ClientTypeUpBridge && c.OnBeforeRenegotiation != nil && !c.OnBeforeRenegotiation(c.Context) {
log.Println("sfu: renegotiation is not allowed because the downbridge is doing renegotiation", c.ID)
func (c *Client) renegotiateQueuOp() {
c.NegotiationNeeded = true

if c.isInRemoteNegotiation {
log.Println("sfu: renegotiation is delayed because the remote client is doing negotiation", c.ID)

return
}

c.NegotiationNeeded = true

// no need to run another negotiation if it's already in progress, it will rerun because we mark the negotiationneeded to true
if c.isInRenegotiation {
return
Expand Down Expand Up @@ -197,7 +235,12 @@ func (c *Client) renegotiate() {
}

// this will be blocking until the renegotiation is done
answer := c.OnRenegotiation(c.Context, *c.peerConnection.LocalDescription())
answer, err := c.OnRenegotiation(c.Context, *c.peerConnection.LocalDescription())
if err != nil {
//TODO: when this happen, we need to close the client and ask the remote client to reconnect
log.Println("sfu: error on renegotiation ", err)
return
}

if answer.Type != webrtc.SDPTypeAnswer {
log.Println("sfu: error on renegotiation, the answer is not an answer type")
Expand All @@ -214,6 +257,20 @@ func (c *Client) renegotiate() {
c.isInRenegotiation = false
}

func (c *Client) allowRemoteRenegotiation() {
c.queue.Push(allowRemoteRenegotiationQueue{
Client: c,
})
}

// inform to remote client that it's allowed to do renegotiation through event
func (c *Client) allowRemoteRenegotiationQueuOp() {
if c.OnAllowedRemoteRenegotiation != nil {
c.isInRemoteNegotiation = true
go c.OnAllowedRemoteRenegotiation()
}
}

// return boolean if need a renegotiation after track added
func (c *Client) addTrack(track *webrtc.TrackLocalStaticRTP) bool {
// if the client is not connected, we wait until it's connected in go routine
Expand Down Expand Up @@ -279,9 +336,6 @@ func (c *Client) removePublishedTrack(streamID, trackID string) bool {
if track != nil && track.ID() == trackID && track.StreamID() == streamID {
if err := c.peerConnection.RemoveTrack(sender); err != nil {
log.Println("client: error remove track ", err)
} else {
log.Println("client: published track removed ", c.ID, streamID, trackID)
// go c.renegotiate()
}
}
}
Expand Down Expand Up @@ -319,7 +373,7 @@ func (c *Client) processPendingTracks() bool {
trackAdded = c.setClientTrack(track)
}

c.pendingReceivedTracks = nil
c.pendingReceivedTracks = make(map[string]*webrtc.TrackLocalStaticRTP)

return trackAdded
}
Expand Down Expand Up @@ -463,3 +517,7 @@ func (c *Client) GetTracks() map[string]*webrtc.TrackLocalStaticRTP {
func (c *Client) GetPeerConnection() *webrtc.PeerConnection {
return c.peerConnection
}

func (c *Client) errorHandler(err error) {
log.Println("client: error ", err)
}
2 changes: 2 additions & 0 deletions examples/http-websocket/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ <h1>HTTP WebSocket Example</h1>
await peerConnection.setRemoteDescription(msg.data);
} else if (msg.type == 'candidate') {
await peerConnection.addIceCandidate(msg.data);
} else if (msg.type == 'allow_renegotiation') {
// use this event to retry renegotation when doing renegotation was not allowed
}
} catch (error) {
console.log(error);
Expand Down
31 changes: 24 additions & 7 deletions examples/http-websocket/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"log"
"net/http"
"time"
Expand All @@ -24,10 +25,11 @@ type Respose struct {
}

const (
TypeOffer = "offer"
TypeAnswer = "answer"
TypeCandidate = "candidate"
TypeError = "error"
TypeOffer = "offer"
TypeAnswer = "answer"
TypeCandidate = "candidate"
TypeError = "error"
TypeAllowRenegotiation = "allow_renegotiation"
)

func main() {
Expand Down Expand Up @@ -110,7 +112,7 @@ func clientHandler(conn *websocket.Conn, messageChan chan Request, r *sfu.Room)

answerChan := make(chan webrtc.SessionDescription)

client.OnRenegotiation = func(ctx context.Context, offer webrtc.SessionDescription) webrtc.SessionDescription {
client.OnRenegotiation = func(ctx context.Context, offer webrtc.SessionDescription) (webrtc.SessionDescription, error) {
// SFU request a renegotiation, send the offer to client
log.Println("receive renegotiation offer from SFU")

Expand All @@ -133,13 +135,28 @@ func clientHandler(conn *websocket.Conn, messageChan chan Request, r *sfu.Room)
select {
case <-ctxTimeout.Done():
log.Println("timeout on renegotiation")
return webrtc.SessionDescription{}
return webrtc.SessionDescription{}, errors.New("timeout on renegotiation")
case answer := <-answerChan:
log.Println("received answer from client ", client.GetType(), client.ID)
return answer
return answer, nil
}
}

client.OnAllowedRemoteRenegotiation = func() {
// SFU allow a remote renegotiation
log.Println("receive allow remote renegotiation from SFU")

resp := Respose{
Status: true,
Type: TypeAllowRenegotiation,
Data: "ok",
}

respBytes, _ := json.Marshal(resp)

_, _ = conn.Write(respBytes)
}

client.OnIceCandidate = func(ctx context.Context, candidate *webrtc.ICECandidate) {
// SFU send an ICE candidate to client
resp := Respose{
Expand Down
69 changes: 69 additions & 0 deletions queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package sfu

import (
"context"

"github.com/pion/webrtc/v3"
)

type queue struct {
opChan chan interface{}
}

type negotiationQueue struct {
Client *Client
SDP webrtc.SessionDescription
AnswerChan chan webrtc.SessionDescription
ErrorChan chan error
}

type renegotiateQueue struct {
Client *Client
}

type allowRemoteRenegotiationQueue struct {
Client *Client
}

func NewQueue(ctx context.Context) *queue {
q := &queue{
opChan: make(chan interface{}),
}

go q.run(ctx)

return q
}

func (q *queue) Push(item interface{}) {
go func() {
q.opChan <- item
}()
}

func (q *queue) run(ctx context.Context) {
ctxx, cancel := context.WithCancel(ctx)
defer cancel()

for {
select {
case <-ctxx.Done():
return
case item := <-q.opChan:
switch opItem := item.(type) {
case negotiationQueue:
answer, err := opItem.Client.negotiateQueuOp(opItem.SDP)
if err != nil {
opItem.ErrorChan <- err
continue
}

opItem.AnswerChan <- *answer
case renegotiateQueue:
opItem.Client.renegotiateQueuOp()
case allowRemoteRenegotiationQueue:
opItem.Client.allowRemoteRenegotiationQueuOp()
}
}
}
}

0 comments on commit 9fcb2c5

Please sign in to comment.