diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index dde0a1ece06d..3dff54465ff6 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -827,8 +827,8 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec if err != nil { return nil, err } - statementResult = statementResult.getResultSetWithTransactionSet(req.GetTransaction(), id) s.mu.Lock() + statementResult = statementResult.getResultSetWithTransactionSet(req.GetTransaction(), id) isPartitionedDml := s.partitionedDmlTransactions[string(id)] s.mu.Unlock() switch statementResult.Type { @@ -874,8 +874,8 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques if err != nil { return err } - statementResult.getResultSetWithTransactionSet(req.GetTransaction(), id) s.mu.Lock() + statementResult.getResultSetWithTransactionSet(req.GetTransaction(), id) isPartitionedDml := s.partitionedDmlTransactions[string(id)] s.mu.Unlock() switch statementResult.Type { @@ -948,7 +948,9 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb if err != nil { return nil, err } + s.mu.Lock() statementResult = statementResult.getResultSetWithTransactionSet(req.GetTransaction(), id) + s.mu.Unlock() switch statementResult.Type { case StatementResultError: resp.Status = &status.Status{Code: int32(gstatus.Code(statementResult.Err)), Message: statementResult.Err.Error()}