diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index bed45ae5b7..5343934ccf 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -694,7 +694,6 @@ func (t *Topology) selectServerFromSubscription( case <-ctx.Done(): return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current} case current = <-subscriptionCh: - default: } suitable, err := t.selectServerFromDescription(current, srvSelector) diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index 403179606e..a4acb4732e 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -45,6 +45,18 @@ func compareErrors(err1, err2 error) bool { return true } +var _ description.ServerSelector = &mockServerSelector{} + +type mockServerSelector struct { + selectServerCalls atomic.Int64 +} + +func (m *mockServerSelector) SelectServer(description.Topology, []description.Server) ([]description.Server, error) { + m.selectServerCalls.Add(1) + + return nil, nil +} + func TestServerSelection(t *testing.T) { var selectFirst serverselector.Func = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) { if len(candidates) == 0 { @@ -263,6 +275,30 @@ func TestServerSelection(t *testing.T) { _, err = topo.SelectServer(context.Background(), &serverselector.Write{}) assert.Equal(t, ErrSubscribeAfterClosed, err, "expected error %v, got %v", ErrSubscribeAfterClosed, err) }) + t.Run("if no servers are suitable, block on topology updates", func(t *testing.T) { + // Create a connected Topology with no selectable servers. + topo, err := New(nil) + require.NoError(t, err) + atomic.StoreInt64(&topo.state, topologyConnected) + + mss := &mockServerSelector{} + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = topo.SelectServer(ctx, mss) + assert.ErrorIs(t, err, context.DeadlineExceeded, "expected context deadline exceeded error") + + // Expect SelectServer to be called twice: once for the fast path and + // once to select from the topology updates subscription. + // + // Note: The second call is due to Topology.Subscript pre-populating the + // channel with the current topology. It's not clear what the purpose of + // that behavior is. The main goal of this assertion is to make sure the + // subscription path blocks on updates and doesn't turn into a busy + // wait. + assert.Equal(t, int64(2), mss.selectServerCalls.Load(), "expected SelectServer to be called once") + }) } func TestSessionTimeout(t *testing.T) {