Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions checksum_row_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"cloud.google.com/go/spanner"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/parser"
"google.golang.org/api/iterator"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -51,10 +52,11 @@ type checksumRowIterator struct {
*spanner.RowIterator
metadata *sppb.ResultSetMetadata

ctx context.Context
tx *readWriteTransaction
stmt spanner.Statement
options spanner.QueryOptions
ctx context.Context
tx *readWriteTransaction
stmt spanner.Statement
stmtType parser.StatementType
options spanner.QueryOptions
// nc (nextCount) indicates the number of times that next has been called
// on the iterator. Next() will be called the same number of times during
// a retry.
Expand Down Expand Up @@ -253,10 +255,5 @@ func (it *checksumRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
}

func (it *checksumRowIterator) ResultSetStats() *sppb.ResultSetStats {
// TODO: The Spanner client library should offer an option to get the full
// ResultSetStats, instead of only the RowCount and QueryPlan.
return &sppb.ResultSetStats{
RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowIterator.RowCount},
QueryPlan: it.RowIterator.QueryPlan,
}
return createResultSetStats(it.RowIterator, it.stmtType)
}
22 changes: 12 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ type conn struct {
resetForRetry bool
database string

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error)
execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, bound spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (int64, error)

Expand Down Expand Up @@ -831,9 +831,9 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
if err != nil {
return nil, err
}
statementType := c.parser.DetectStatementType(query)
statementInfo := c.parser.DetectStatementType(query)
// DDL statements are not supported in QueryContext so use the execContext method for the execution.
if statementType.StatementType == parser.StatementTypeDdl {
if statementInfo.StatementType == parser.StatementTypeDdl {
res, err := c.execContext(ctx, query, execOptions, args)
if err != nil {
return nil, err
Expand All @@ -842,10 +842,10 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
}
var iter rowIterator
if c.tx == nil {
if statementType.StatementType == parser.StatementTypeDml {
if statementInfo.StatementType == parser.StatementTypeDml {
// Use a read/write transaction to execute the statement.
var commitResponse *spanner.CommitResponse
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, statementInfo, execOptions)
if err != nil {
return nil, err
}
Expand All @@ -858,13 +858,13 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
// The statement was either detected as being a query, or potentially not recognized at all.
// In that case, just default to using a single-use read-only transaction and let Spanner
// return an error if the statement is not suited for that type of transaction.
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, c.ReadOnlyStaleness(), execOptions)}
iter = &readOnlyRowIterator{c.execSingleQuery(ctx, c.client, stmt, statementInfo, c.ReadOnlyStaleness(), execOptions), statementInfo.StatementType}
}
} else {
if execOptions.PartitionedQueryOptions.PartitionQuery {
return c.tx.partitionQuery(ctx, stmt, execOptions)
}
iter, err = c.tx.Query(ctx, stmt, execOptions)
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1273,7 +1273,7 @@ func (c *conn) rollback(ctx context.Context) error {
return c.tx.Rollback()
}

func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
}

Expand All @@ -1295,7 +1295,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
return r, nil
}

func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
var result *wrappedRowIterator
options.QueryOptions.LastStatement = true
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
Expand All @@ -1304,6 +1304,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
if err == iterator.Done {
result = &wrappedRowIterator{
RowIterator: it,
stmtType: statementInfo.StatementType,
noRows: true,
}
} else if err != nil {
Expand All @@ -1312,6 +1313,7 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
} else {
result = &wrappedRowIterator{
RowIterator: it,
stmtType: statementInfo.StatementType,
firstRow: row,
}
}
Expand Down
8 changes: 4 additions & 4 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ func TestConn_NonDdlStatementsInDdlBatch(t *testing.T) {
logger: noopLogger,
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
batch: &batch{tp: parser.BatchTypeDdl},
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
Expand Down Expand Up @@ -670,7 +670,7 @@ func TestConn_NonDmlStatementsInDmlBatch(t *testing.T) {
logger: noopLogger,
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
batch: &batch{tp: parser.BatchTypeDml},
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
Expand Down Expand Up @@ -761,7 +761,7 @@ func TestConn_GetCommitResponseAfterAutocommitDml(t *testing.T) {
parser: p,
logger: noopLogger,
state: createInitialConnectionState(connectionstate.TypeNonTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
Expand Down Expand Up @@ -800,7 +800,7 @@ func TestConn_GetCommitResponseAfterAutocommitQuery(t *testing.T) {
parser: p,
logger: noopLogger,
state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}),
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
execSingleQuery: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, tb spanner.TimestampBound, options *ExecOptions) *spanner.RowIterator {
return &spanner.RowIterator{}
},
execSingleDMLTransactional: func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *parser.StatementInfo, options *ExecOptions) (*result, *spanner.CommitResponse, error) {
Expand Down
67 changes: 67 additions & 0 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5384,6 +5384,73 @@ func TestReturnResultSetStats(t *testing.T) {
}
}

func TestReturnResultSetStatsForQuery(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()
query := "select id from singers where id=42598"
resultSet := testutil.CreateSingleColumnInt64ResultSet([]int64{42598}, "id")
_ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{
Type: testutil.StatementResultResultSet,
ResultSet: resultSet,
})

rows, err := db.QueryContext(context.Background(), query, ExecOptions{ReturnResultSetStats: true})
if err != nil {
t.Fatal(err)
}
defer func() { _ = rows.Close() }()

// The first result set should contain the data.
for want := int64(42598); rows.Next(); want++ {
cols, err := rows.Columns()
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(cols, []string{"id"}) {
t.Fatalf("cols mismatch\nGot: %v\nWant: %v", cols, []string{"id"})
}
var got int64
err = rows.Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != want {
t.Fatalf("value mismatch\nGot: %v\nWant: %v", got, want)
}
}
if rows.Err() != nil {
t.Fatal(rows.Err())
}

// The next result set should contain the stats.
if !rows.NextResultSet() {
t.Fatal("missing stats result set")
}

// Get the stats.
if !rows.Next() {
t.Fatal("no stats rows")
}
var stats *sppb.ResultSetStats
if err := rows.Scan(&stats); err != nil {
t.Fatalf("failed to scan stats: %v", err)
}
// The stats should not contain any update count.
if stats.GetRowCount() != nil {
t.Fatalf("got update count for query")
}
if rows.Next() {
t.Fatal("more rows than expected")
}

// There should be no more result sets.
if rows.NextResultSet() {
t.Fatal("more result sets than expected")
}
}

func TestReturnResultSetMetadataAndStats(t *testing.T) {
t.Parallel()

Expand Down
3 changes: 2 additions & 1 deletion partitioned_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"io"

"cloud.google.com/go/spanner"
"github.com/googleapis/go-sql-spanner/parser"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -231,7 +232,7 @@ func (pq *PartitionedQuery) execute(ctx context.Context, index int) (*rows, erro
return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid partition index: %d", index))
}
spannerIter := pq.tx.Execute(ctx, pq.Partitions[index])
iter := &readOnlyRowIterator{spannerIter}
iter := &readOnlyRowIterator{spannerIter, parser.StatementTypeQuery}
return &rows{it: iter, decodeOption: pq.execOptions.DecodeOption}, nil
}

Expand Down
24 changes: 16 additions & 8 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type contextTransaction interface {
Commit() error
Rollback() error
resetForRetry(ctx context.Context) error
Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error)
Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error)
partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error)
ExecContext(ctx context.Context, stmt spanner.Statement, statementInfo *parser.StatementInfo, options spanner.QueryOptions) (*result, error)

Expand All @@ -67,6 +67,7 @@ var _ rowIterator = &readOnlyRowIterator{}

type readOnlyRowIterator struct {
*spanner.RowIterator
stmtType parser.StatementType
}

func (ri *readOnlyRowIterator) Next() (*spanner.Row, error) {
Expand All @@ -82,12 +83,19 @@ func (ri *readOnlyRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
}

func (ri *readOnlyRowIterator) ResultSetStats() *sppb.ResultSetStats {
return createResultSetStats(ri.RowIterator, ri.stmtType)
}

func createResultSetStats(it *spanner.RowIterator, stmtType parser.StatementType) *sppb.ResultSetStats {
// TODO: The Spanner client library should offer an option to get the full
// ResultSetStats, instead of only the RowCount and QueryPlan.
return &sppb.ResultSetStats{
RowCount: &sppb.ResultSetStats_RowCountExact{RowCountExact: ri.RowIterator.RowCount},
QueryPlan: ri.RowIterator.QueryPlan,
stats := &sppb.ResultSetStats{
QueryPlan: it.QueryPlan,
}
if stmtType == parser.StatementTypeDml {
stats.RowCount = &sppb.ResultSetStats_RowCountExact{RowCountExact: it.RowCount}
}
return stats
}

type txResult int
Expand Down Expand Up @@ -135,7 +143,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error {
return nil
}

func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) {
func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) {
tx.logger.DebugContext(ctx, "Query", "stmt", stmt.SQL)
if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
if tx.boTx == nil {
Expand All @@ -152,7 +160,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement
}
return mi, nil
}
return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil
return &readOnlyRowIterator{tx.roTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil
}

func (tx *readOnlyTransaction) partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error) {
Expand Down Expand Up @@ -456,7 +464,7 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error {
// Query executes a query using the read/write transaction and returns a
// rowIterator that will automatically retry the read/write transaction if the
// transaction is aborted during the query or while iterating the returned rows.
func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (rowIterator, error) {
func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) {
tx.logger.Debug("Query", "stmt", stmt.SQL)
tx.active = true
if err := tx.maybeRunAutoDmlBatch(ctx); err != nil {
Expand All @@ -465,7 +473,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
// If internal retries have been disabled, we don't need to keep track of a
// running checksum for all results that we have seen.
if !tx.retryAborts() {
return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions)}, nil
return &readOnlyRowIterator{tx.rwTx.QueryWithOptions(ctx, stmt, execOptions.QueryOptions), stmtType}, nil
}

// If retries are enabled, we need to use a row iterator that will keep
Expand Down
8 changes: 4 additions & 4 deletions wrapped_row_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@ package spannerdriver
import (
"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/parser"
"google.golang.org/api/iterator"
)

var _ rowIterator = &wrappedRowIterator{}

// wrappedRowIterator is used for DML statements that may or may not contain rows.
type wrappedRowIterator struct {
*spanner.RowIterator

stmtType parser.StatementType
noRows bool
firstRow *spanner.Row
}
Expand All @@ -49,8 +52,5 @@ func (ri *wrappedRowIterator) Metadata() (*spannerpb.ResultSetMetadata, error) {
}

func (ri *wrappedRowIterator) ResultSetStats() *spannerpb.ResultSetStats {
return &spannerpb.ResultSetStats{
RowCount: &spannerpb.ResultSetStats_RowCountExact{RowCountExact: ri.RowIterator.RowCount},
QueryPlan: ri.RowIterator.QueryPlan,
}
return createResultSetStats(ri.RowIterator, ri.stmtType)
}
Loading