diff --git a/checksum_row_iterator.go b/checksum_row_iterator.go index 623832a1..c6980a27 100644 --- a/checksum_row_iterator.go +++ b/checksum_row_iterator.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/spanner" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -51,10 +52,11 @@ type checksumRowIterator struct { *spanner.RowIterator metadata *sppb.ResultSetMetadata - ctx context.Context - tx *readWriteTransaction - stmt spanner.Statement - options spanner.QueryOptions + ctx context.Context + tx *readWriteTransaction + stmt spanner.Statement + stmtType parser.StatementType + options spanner.QueryOptions // nc (nextCount) indicates the number of times that next has been called // on the iterator. Next() will be called the same number of times during // a retry. @@ -253,10 +255,5 @@ func (it *checksumRowIterator) Metadata() (*sppb.ResultSetMetadata, error) { } func (it *checksumRowIterator) ResultSetStats() *sppb.ResultSetStats { - // TODO: The Spanner client library should offer an option to get the full - // ResultSetStats, instead of only the RowCount and QueryPlan. - return &sppb.ResultSetStats{ - RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowIterator.RowCount}, - QueryPlan: it.RowIterator.QueryPlan, - } + return createResultSetStats(it.RowIterator, it.stmtType) } diff --git a/conn.go b/conn.go index 7f9ece77..e506c86c 100644 --- a/conn.go +++ b/conn.go @@ -259,8 +259,8 @@ type conn struct { resetForRetry bool database string - execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator - execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) + execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator + execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (int64, error) @@ -831,9 +831,9 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec if err != nil { return nil, err } - statementType := c.parser.DetectStatementType(query) + statementInfo := c.parser.DetectStatementType(query) // DDL statements are not supported in QueryContext so use the execContext method for the execution. - if statementType.StatementType == parser.StatementTypeDdl { + if statementInfo.StatementType == parser.StatementTypeDdl { res, err := c.execContext(ctx, query, execOptions, args) if err != nil { return nil, err @@ -842,10 +842,10 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec } var iter rowIterator if c.tx == nil { - if statementType.StatementType == parser.StatementTypeDml { + if statementInfo.StatementType == parser.StatementTypeDml { // Use a read/write transaction to execute the statement. var commitResponse *spanner.CommitResponse - iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions) + iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, statementInfo, execOptions) if err != nil { return nil, err } @@ -858,13 +858,13 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec // The statement was either detected as being a query, or potentially not recognized at all. // In that case, just default to using a single-use read-only transaction and let Spanner // return an error if the statement is not suited for that type of transaction. - iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.ReadOnlyStaleness(), execOptions)} + iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, statementInfo, c.ReadOnlyStaleness(), execOptions), statementInfo.StatementType} } } else { if execOptions.PartitionedQueryOptions.PartitionQuery { return c.tx.partitionQuery(ctx, stmt, execOptions) } - iter, err = c.tx.Query(ctx, stmt, execOptions) + iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions) if err != nil { return nil, err } @@ -1273,7 +1273,7 @@ func (c *conn) rollback(ctx context.Context) error { return c.tx.Rollback() } -func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { +func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions) } @@ -1295,7 +1295,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex return r, nil } -func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) { +func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) { var result *wrappedRowIterator options.QueryOptions.LastStatement = true fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { @@ -1304,6 +1304,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s if err == iterator.Done { result = &wrappedRowIterator{ RowIterator: it, + stmtType: statementInfo.StatementType, noRows: true, } } else if err != nil { @@ -1312,6 +1313,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s } else { result = &wrappedRowIterator{ RowIterator: it, + stmtType: statementInfo.StatementType, firstRow: row, } } diff --git a/driver_test.go b/driver_test.go index 8e62281a..dc09940f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -630,7 +630,7 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) { logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), batch: &batch{tp: parser.BatchTypeDdl}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) { @@ -670,7 +670,7 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) { logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), batch: &batch{tp: parser.BatchTypeDml}, - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) { @@ -761,7 +761,7 @@ func TestConn_GetCommitResponseAfterAutocommitDml(t *testing.T) { parser: p, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}), - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) { @@ -800,7 +800,7 @@ func TestConn_GetCommitResponseAfterAutocommitQuery(t *testing.T) { parser: p, logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), - execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { + execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator { return &spanner.RowIterator{} }, execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index a8787d3b..ac9cbf79 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -5384,6 +5384,73 @@ func TestReturnResultSetStats(t *testing.T) { } } +func TestReturnResultSetStatsForQuery(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + query := "select id from singers where id=42598" + resultSet := testutil.CreateSingleColumnInt64ResultSet([]int64{42598}, "id") + _ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + }) + + rows, err := db.QueryContext(context.Background(), query, ExecOptions{ReturnResultSetStats: true}) + if err != nil { + t.Fatal(err) + } + defer func() { _ = rows.Close() }() + + // The first result set should contain the data. + for want := int64(42598); rows.Next(); want++ { + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(cols, []string{"id"}) { + t.Fatalf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"id"}) + } + var got int64 + err = rows.Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("value mismatch\nGot: %v\nWant: %v", got, want) + } + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + // The next result set should contain the stats. + if !rows.NextResultSet() { + t.Fatal("missing stats result set") + } + + // Get the stats. + if !rows.Next() { + t.Fatal("no stats rows") + } + var stats *sppb.ResultSetStats + if err := rows.Scan(&stats); err != nil { + t.Fatalf("failed to scan stats: %v", err) + } + // The stats should not contain any update count. + if stats.GetRowCount() != nil { + t.Fatalf("got update count for query") + } + if rows.Next() { + t.Fatal("more rows than expected") + } + + // There should be no more result sets. + if rows.NextResultSet() { + t.Fatal("more result sets than expected") + } +} + func TestReturnResultSetMetadataAndStats(t *testing.T) { t.Parallel() diff --git a/partitioned_query.go b/partitioned_query.go index 77fe8240..018f7b7d 100644 --- a/partitioned_query.go +++ b/partitioned_query.go @@ -22,6 +22,7 @@ import ( "io" "cloud.google.com/go/spanner" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -231,7 +232,7 @@ func (pq *PartitionedQuery) execute(ctx context.Context, index int) (*rows, erro return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid partition index: %d", index)) } spannerIter := pq.tx.Execute(ctx, pq.Partitions[index]) - iter := &readOnlyRowIterator{spannerIter} + iter := &readOnlyRowIterator{spannerIter, parser.StatementTypeQuery} return &rows{it: iter, decodeOption: pq.execOptions.DecodeOption}, nil } diff --git a/transaction.go b/transaction.go index dcc61d1b..de11fe0d 100644 --- a/transaction.go +++ b/transaction.go @@ -44,7 +44,7 @@ type contextTransaction interface { Commit() error Rollback() error resetForRetry(ctx context.Context) error - Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) + Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error) ExecContext(ctx context.Context, stmt spanner.Statement, statementInfo *parser.StatementInfo, options spanner.QueryOptions) (*result, error) @@ -67,6 +67,7 @@ var _ rowIterator = &readOnlyRowIterator{} type readOnlyRowIterator struct { *spanner.RowIterator + stmtType parser.StatementType } func (ri *readOnlyRowIterator) Next() (*spanner.Row, error) { @@ -82,12 +83,19 @@ func (ri *readOnlyRowIterator) Metadata() (*sppb.ResultSetMetadata, error) { } func (ri *readOnlyRowIterator) ResultSetStats() *sppb.ResultSetStats { + return createResultSetStats(ri.RowIterator, ri.stmtType) +} + +func createResultSetStats(it *spanner.RowIterator, stmtType parser.StatementType) *sppb.ResultSetStats { // TODO: The Spanner client library should offer an option to get the full // ResultSetStats, instead of only the RowCount and QueryPlan. - return &sppb.ResultSetStats{ - RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: ri.RowIterator.RowCount}, - QueryPlan: ri.RowIterator.QueryPlan, + stats := &sppb.ResultSetStats{ + QueryPlan: it.QueryPlan, + } + if stmtType == parser.StatementTypeDml { + stats.RowCount = &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowCount} } + return stats } type txResult int @@ -135,7 +143,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error { return nil } -func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) { +func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) { tx.logger.DebugContext(ctx, "Query", "stmt", stmt.SQL) if execOptions.PartitionedQueryOptions.AutoPartitionQuery { if tx.boTx == nil { @@ -152,7 +160,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement } return mi, nil } - return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil + return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil } func (tx *readOnlyTransaction) partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error) { @@ -456,7 +464,7 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error { // Query executes a query using the read/write transaction and returns a // rowIterator that will automatically retry the read/write transaction if the // transaction is aborted during the query or while iterating the returned rows. -func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) { +func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) { tx.logger.Debug("Query", "stmt", stmt.SQL) tx.active = true if err := tx.maybeRunAutoDmlBatch(ctx); err != nil { @@ -465,7 +473,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen // If internal retries have been disabled, we don't need to keep track of a // running checksum for all results that we have seen. if !tx.retryAborts() { - return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil + return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil } // If retries are enabled, we need to use a row iterator that will keep diff --git a/wrapped_row_iterator.go b/wrapped_row_iterator.go index 4e99cff2..6dd4e828 100644 --- a/wrapped_row_iterator.go +++ b/wrapped_row_iterator.go @@ -17,14 +17,17 @@ package spannerdriver import ( "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/parser" "google.golang.org/api/iterator" ) var _ rowIterator = &wrappedRowIterator{} +// wrappedRowIterator is used for DML statements that may or may not contain rows. type wrappedRowIterator struct { *spanner.RowIterator + stmtType parser.StatementType noRows bool firstRow *spanner.Row } @@ -49,8 +52,5 @@ func (ri *wrappedRowIterator) Metadata() (*spannerpb.ResultSetMetadata, error) { } func (ri *wrappedRowIterator) ResultSetStats() *spannerpb.ResultSetStats { - return &spannerpb.ResultSetStats{ - RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: ri.RowIterator.RowCount}, - QueryPlan: ri.RowIterator.QueryPlan, - } + return createResultSetStats(ri.RowIterator, ri.stmtType) }