Skip to content

Commit

Permalink
fix: Prevent concurrent WS socket write errors in the CDS client
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Jun 15, 2024
1 parent 904b352 commit 012f8c0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
5 changes: 4 additions & 1 deletion pkg/config/client/cds_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/gorilla/websocket"
stnrv1 "github.com/l7mp/stunner/pkg/apis/v1"
"github.com/l7mp/stunner/pkg/config/client/api"
"github.com/l7mp/stunner/pkg/config/util"
"github.com/pion/logging"
)

Expand Down Expand Up @@ -268,10 +269,12 @@ func poll(ctx context.Context, a CdsApi, ch chan<- *stnrv1.StunnerConfig) error
_, url := a.Endpoint()
a.Tracef("poll: trying to open connection to CDS server at %s", url)

conn, _, err := websocket.DefaultDialer.DialContext(ctx, url, makeHeader(url))
wc, _, err := websocket.DefaultDialer.DialContext(ctx, url, makeHeader(url))
if err != nil {
return err
}
// wrap with a locker to prevent concurrent writes
conn := util.NewConn(wc)
defer conn.Close() // this will close the poller goroutine

a.Infof("connection successfully opened to config discovery server at %s", url)
Expand Down
26 changes: 6 additions & 20 deletions pkg/config/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ import (

"github.com/gorilla/websocket"
stnrv1 "github.com/l7mp/stunner/pkg/apis/v1"
"github.com/l7mp/stunner/pkg/config/util"
)

type ClientConfigPatcher func(conf *stnrv1.StunnerConfig) (*stnrv1.StunnerConfig, error)

// Conn represents a client WebSocket connection.
type Conn struct {
*websocket.Conn
Filter ConfigFilter
patch ClientConfigPatcher
cancel context.CancelFunc
readLock, writeLock sync.Mutex // for writemessage
*util.Conn
Filter ConfigFilter
patch ClientConfigPatcher
cancel context.CancelFunc
}

// NewConn wraps a WebSocket connection.
func NewConn(conn *websocket.Conn, filter ConfigFilter, patch ClientConfigPatcher, cancel context.CancelFunc) *Conn {
return &Conn{
Conn: conn,
Conn: util.NewConn(conn),
Filter: filter,
patch: patch,
cancel: cancel,
Expand All @@ -35,20 +35,6 @@ func (c *Conn) Id() string {
return fmt.Sprintf("%s:%s", c.RemoteAddr().Network(), c.RemoteAddr().String())
}

// WriteMessage writes a message to the client connection with proper locking.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
c.writeLock.Lock()
defer c.writeLock.Unlock()
return c.Conn.WriteMessage(messageType, data)
}

// ReadMessage reads a message from the client connection with proper locking.
func (c *Conn) ReadMessage() (int, []byte, error) {
c.readLock.Lock()
defer c.readLock.Unlock()
return c.Conn.ReadMessage()
}

// ConnTrack represents the server's connection tracking table.
type ConnTrack struct {
conns []*Conn
Expand Down
34 changes: 34 additions & 0 deletions pkg/config/util/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package util

import (
"sync"

"github.com/gorilla/websocket"
)

// Conn represents a client WebSocket connection. An added lock guards the underlying connection
// from concurrent write to websocket connection errors.
type Conn struct {
*websocket.Conn
readLock, writeLock sync.Mutex // for writemessage

}

// NewConn wraps a WebSocket connection.
func NewConn(conn *websocket.Conn) *Conn {
return &Conn{Conn: conn}
}

// WriteMessage writes a message to the client connection with proper locking.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
c.writeLock.Lock()
defer c.writeLock.Unlock()
return c.Conn.WriteMessage(messageType, data)
}

// ReadMessage reads a message from the client connection with proper locking.
func (c *Conn) ReadMessage() (int, []byte, error) {
c.readLock.Lock()
defer c.readLock.Unlock()
return c.Conn.ReadMessage()
}

0 comments on commit 012f8c0

Please sign in to comment.