Skip to content

Commit

Permalink
GODRIVER-2658 Better guard against nil pinned connections. (#1153)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis committed Jan 17, 2023
1 parent af212d0 commit 667bfcf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
37 changes: 23 additions & 14 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,10 @@ func (c initConnection) SupportsStreaming() bool {
}

// Connection implements the driver.Connection interface to allow reading and writing wire
// messages and the driver.Expirable interface to allow expiring.
// messages and the driver.Expirable interface to allow expiring. It wraps an underlying
// topology.connection to make it more goroutine-safe and nil-safe.
type Connection struct {
*connection
connection *connection
refCount int
cleanupPoolFn func()

Expand All @@ -602,7 +603,7 @@ func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
if c.connection == nil {
return ErrConnectionClosed
}
return c.writeWireMessage(ctx, wm)
return c.connection.writeWireMessage(ctx, wm)
}

// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
Expand All @@ -613,7 +614,7 @@ func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, e
if c.connection == nil {
return dst, ErrConnectionClosed
}
return c.readWireMessage(ctx, dst)
return c.connection.readWireMessage(ctx, dst)
}

// CompressWireMessage handles compressing the provided wire message using the underlying
Expand Down Expand Up @@ -656,7 +657,7 @@ func (c *Connection) Description() description.Server {
if c.connection == nil {
return description.Server{}
}
return c.desc
return c.connection.desc
}

// Close returns this connection to the connection pool. This method may not closeConnection the underlying
Expand All @@ -679,12 +680,12 @@ func (c *Connection) Expire() error {
return nil
}

_ = c.close()
_ = c.connection.close()
return c.cleanupReferences()
}

func (c *Connection) cleanupReferences() error {
err := c.pool.checkIn(c.connection)
err := c.connection.pool.checkIn(c.connection)
if c.cleanupPoolFn != nil {
c.cleanupPoolFn()
c.cleanupPoolFn = nil
Expand All @@ -709,14 +710,22 @@ func (c *Connection) ID() string {
if c.connection == nil {
return "<closed>"
}
return c.id
return c.connection.id
}

// ServerConnectionID returns the server connection ID of this connection.
func (c *Connection) ServerConnectionID() *int32 {
if c.connection == nil {
return nil
}
return c.connection.serverConnectionID
}

// Stale returns if the connection is stale.
func (c *Connection) Stale() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.pool.stale(c.connection)
return c.connection.pool.stale(c.connection)
}

// Address returns the address of this connection.
Expand All @@ -726,27 +735,27 @@ func (c *Connection) Address() address.Address {
if c.connection == nil {
return address.Address("0.0.0.0")
}
return c.addr
return c.connection.addr
}

// LocalAddress returns the local address of the connection
func (c *Connection) LocalAddress() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil || c.nc == nil {
if c.connection == nil || c.connection.nc == nil {
return address.Address("0.0.0.0")
}
return address.Address(c.nc.LocalAddr().String())
return address.Address(c.connection.nc.LocalAddr().String())
}

// PinToCursor updates this connection to reflect that it is pinned to a cursor.
func (c *Connection) PinToCursor() error {
return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor)
return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
}

// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
func (c *Connection) PinToTransaction() error {
return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction)
return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
}

func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
Expand Down
6 changes: 6 additions & 0 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,12 @@ func TestConnection(t *testing.T) {
if !cmp.Equal(got, want) {
t.Errorf("LocalAddresses do not match. got %v; want %v", got, want)
}

want = (*int32)(nil)
got = conn.ServerConnectionID()
if !cmp.Equal(got, want) {
t.Errorf("ServerConnectionIDs do not match. got %v; want %v", got, want)
}
})

t.Run("pinning", func(t *testing.T) {
Expand Down

0 comments on commit 667bfcf

Please sign in to comment.