Skip to content

Commit

Permalink
fix: Use atomic value for logger in peer (#1257)
Browse files Browse the repository at this point in the history
This caused many races where logs would escape the tests
by milliseconds. By using an atomic on the logger,
we can fix all of it!
  • Loading branch information
kylecarbs committed May 2, 2022
1 parent e531c09 commit 3176e10
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 41 deletions.
6 changes: 3 additions & 3 deletions peer/channel.go
Expand Up @@ -118,14 +118,14 @@ func (c *Channel) init() {
}
})
c.dc.OnClose(func() {
c.conn.opts.Logger.Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
c.conn.logger().Debug(context.Background(), "datachannel closing from OnClose", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
_ = c.closeWithError(ErrClosed)
})
c.dc.OnOpen(func() {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

c.conn.opts.Logger.Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
var err error
c.rwc, err = c.dc.Detach()
if err != nil {
Expand Down Expand Up @@ -289,7 +289,7 @@ func (c *Channel) closeWithError(err error) error {
return c.closeError
}

c.conn.opts.Logger.Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
c.conn.logger().Debug(context.Background(), "datachannel closing with error", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()), slog.Error(err))
if err == nil {
c.closeError = ErrClosed
} else {
Expand Down
75 changes: 37 additions & 38 deletions peer/conn.go
Expand Up @@ -63,7 +63,6 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
conn := &Conn{
pingChannelID: 1,
pingEchoChannelID: 2,
opts: opts,
rtc: rtc,
offerer: client,
closed: make(chan struct{}),
Expand All @@ -75,7 +74,9 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp
localCandidateChannel: make(chan webrtc.ICECandidateInit),
localSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
remoteSessionDescriptionChannel: make(chan webrtc.SessionDescription, 1),
settingEngine: opts.SettingEngine,
}
conn.loggerValue.Store(opts.Logger)
if client {
// If we're the client, we want to flip the echo and
// ping channel IDs so pings don't accidentally hit each other.
Expand All @@ -100,8 +101,7 @@ type ConnOptions struct {
// This struct wraps webrtc.PeerConnection to add bidirectional pings,
// concurrent-safe webrtc.DataChannel, and standardized errors for connection state.
type Conn struct {
rtc *webrtc.PeerConnection
opts *ConnOptions
rtc *webrtc.PeerConnection
// Determines whether this connection will send the offer or the answer.
offerer bool

Expand All @@ -127,6 +127,9 @@ type Conn struct {
negotiateMutex sync.Mutex
hasNegotiated bool

loggerValue atomic.Value
settingEngine webrtc.SettingEngine

pingChannelID uint16
pingEchoChannelID uint16

Expand All @@ -139,6 +142,14 @@ type Conn struct {
pingError error
}

func (c *Conn) logger() slog.Logger {
log, valid := c.loggerValue.Load().(slog.Logger)
if !valid {
return slog.Logger{}
}
return log
}

func (c *Conn) init() error {
// The negotiation needed callback can take a little bit to execute!
c.negotiateMutex.Lock()
Expand All @@ -152,7 +163,7 @@ func (c *Conn) init() error {
// Don't log more state changes if we've already closed.
return
default:
c.opts.Logger.Debug(context.Background(), "ice connection state updated",
c.logger().Debug(context.Background(), "ice connection state updated",
slog.F("state", iceConnectionState))

if iceConnectionState == webrtc.ICEConnectionStateClosed {
Expand All @@ -171,7 +182,7 @@ func (c *Conn) init() error {
// Don't log more state changes if we've already closed.
return
default:
c.opts.Logger.Debug(context.Background(), "ice gathering state updated",
c.logger().Debug(context.Background(), "ice gathering state updated",
slog.F("state", iceGatherState))

if iceGatherState == webrtc.ICEGathererStateClosed {
Expand All @@ -189,7 +200,7 @@ func (c *Conn) init() error {
if c.isClosed() {
return
}
c.opts.Logger.Debug(context.Background(), "rtc connection updated",
c.logger().Debug(context.Background(), "rtc connection updated",
slog.F("state", peerConnectionState))
}()

Expand Down Expand Up @@ -225,38 +236,25 @@ func (c *Conn) init() error {
// These functions need to check if the conn is closed, because they can be
// called after being closed.
c.rtc.OnSignalingStateChange(func(signalState webrtc.SignalingState) {
if c.isClosed() {
return
}
c.opts.Logger.Debug(context.Background(), "signaling state updated",
c.logger().Debug(context.Background(), "signaling state updated",
slog.F("state", signalState))
})
c.rtc.SCTP().Transport().OnStateChange(func(dtlsTransportState webrtc.DTLSTransportState) {
if c.isClosed() {
return
}
c.opts.Logger.Debug(context.Background(), "dtls transport state updated",
c.logger().Debug(context.Background(), "dtls transport state updated",
slog.F("state", dtlsTransportState))
})
c.rtc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(func(candidatePair *webrtc.ICECandidatePair) {
if c.isClosed() {
return
}
c.opts.Logger.Debug(context.Background(), "selected candidate pair changed",
c.logger().Debug(context.Background(), "selected candidate pair changed",
slog.F("local", candidatePair.Local), slog.F("remote", candidatePair.Remote))
})
c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) {
if c.isClosed() {
return
}

if iceCandidate == nil {
return
}
// Run this in a goroutine so we don't block pion/webrtc
// from continuing.
go func() {
c.opts.Logger.Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
select {
case <-c.closed:
break
Expand Down Expand Up @@ -287,7 +285,7 @@ func (c *Conn) init() error {
// negotiate is triggered when a connection is ready to be established.
// See trickle ICE for the expected exchange: https://webrtchacks.com/trickle-ice/
func (c *Conn) negotiate() {
c.opts.Logger.Debug(context.Background(), "negotiating")
c.logger().Debug(context.Background(), "negotiating")
// ICE candidates cannot be added until SessionDescriptions have been
// exchanged between peers.
if c.hasNegotiated {
Expand All @@ -311,23 +309,23 @@ func (c *Conn) negotiate() {
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
c.opts.Logger.Debug(context.Background(), "sending offer", slog.F("offer", offer))
c.logger().Debug(context.Background(), "sending offer", slog.F("offer", offer))
select {
case <-c.closed:
return
case c.localSessionDescriptionChannel <- offer:
}
c.opts.Logger.Debug(context.Background(), "sent offer")
c.logger().Debug(context.Background(), "sent offer")
}

var sessionDescription webrtc.SessionDescription
c.opts.Logger.Debug(context.Background(), "awaiting remote description...")
c.logger().Debug(context.Background(), "awaiting remote description...")
select {
case <-c.closed:
return
case sessionDescription = <-c.remoteSessionDescriptionChannel:
}
c.opts.Logger.Debug(context.Background(), "setting remote description")
c.logger().Debug(context.Background(), "setting remote description")

err := c.rtc.SetRemoteDescription(sessionDescription)
if err != nil {
Expand All @@ -350,13 +348,13 @@ func (c *Conn) negotiate() {
_ = c.CloseWithError(xerrors.Errorf("set local description: %w", err))
return
}
c.opts.Logger.Debug(context.Background(), "sending answer", slog.F("answer", answer))
c.logger().Debug(context.Background(), "sending answer", slog.F("answer", answer))
select {
case <-c.closed:
return
case c.localSessionDescriptionChannel <- answer:
}
c.opts.Logger.Debug(context.Background(), "sent answer")
c.logger().Debug(context.Background(), "sent answer")
}
}

Expand All @@ -373,7 +371,7 @@ func (c *Conn) AddRemoteCandidate(i webrtc.ICECandidateInit) {
if c.isClosed() {
return
}
c.opts.Logger.Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
c.logger().Debug(context.Background(), "accepting candidate", slog.F("candidate", i.Candidate))
err := c.rtc.AddICECandidate(i)
if err != nil {
if c.rtc.ConnectionState() == webrtc.PeerConnectionStateClosed {
Expand Down Expand Up @@ -482,7 +480,7 @@ func (c *Conn) Dial(ctx context.Context, label string, opts *ChannelOptions) (*C
}

func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOptions) (*Channel, error) {
c.opts.Logger.Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
c.logger().Debug(ctx, "creating data channel", slog.F("label", label), slog.F("opts", opts))
var id *uint16
if opts.ID != 0 {
id = &opts.ID
Expand Down Expand Up @@ -531,7 +529,7 @@ func (c *Conn) Ping() (time.Duration, error) {
if err != nil {
return 0, xerrors.Errorf("send ping: %w", err)
}
c.opts.Logger.Debug(context.Background(), "wrote ping",
c.logger().Debug(context.Background(), "wrote ping",
slog.F("connection_state", c.rtc.ConnectionState()))

pingDataReceived := make([]byte, pingDataLength)
Expand Down Expand Up @@ -568,12 +566,11 @@ func (c *Conn) isClosed() bool {
func (c *Conn) CloseWithError(err error) error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()

if c.isClosed() {
return c.closeError
}

c.opts.Logger.Debug(context.Background(), "closing conn with error", slog.Error(err))
c.logger().Debug(context.Background(), "closing conn with error", slog.Error(err))
if err == nil {
c.closeError = ErrClosed
} else {
Expand All @@ -591,19 +588,21 @@ func (c *Conn) CloseWithError(err error) error {
// Waiting for pion/webrtc to report closed state on both of these
// ensures no goroutine leaks.
if c.rtc.ConnectionState() != webrtc.PeerConnectionStateNew {
c.opts.Logger.Debug(context.Background(), "waiting for rtc connection close...")
c.logger().Debug(context.Background(), "waiting for rtc connection close...")
<-c.closedRTC
}
if c.rtc.ICEConnectionState() != webrtc.ICEConnectionStateNew {
c.opts.Logger.Debug(context.Background(), "waiting for ice connection close...")
c.logger().Debug(context.Background(), "waiting for ice connection close...")
<-c.closedICE
}

// Waits for all DataChannels to exit before officially labeling as closed.
// All logging, goroutines, and async functionality is cleaned up after this.
c.dcClosedWaitGroup.Wait()

c.opts.Logger.Debug(context.Background(), "closed")
c.logger().Debug(context.Background(), "closed")
// Disable logging!
c.loggerValue.Store(slog.Logger{})
close(c.closed)
return err
}

0 comments on commit 3176e10

Please sign in to comment.