Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,13 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
return nil, err
}
statementType := c.parser.DetectStatementType(query)
// DDL statements are not supported in QueryContext so fail early.
// DDL statements are not supported in QueryContext so use the execContext method for the execution.
if statementType.StatementType == parser.StatementTypeDdl {
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "QueryContext does not support DDL statements, use ExecContext instead"))
res, err := c.execContext(ctx, query, execOptions, args)
if err != nil {
return nil, err
}
return createDriverResultRows(res, execOptions), nil
}
var iter rowIterator
if c.tx == nil {
Expand Down
29 changes: 19 additions & 10 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,17 +567,26 @@ func TestDropDatabase(t *testing.T) {
func TestDDLUsingQueryContext(t *testing.T) {
t.Parallel()

db, _, teardown := setupTestDBConnection(t)
db, server, teardown := setupTestDBConnection(t)
defer teardown()
var expectedResponse = &emptypb.Empty{}
anyMsg, _ := anypb.New(expectedResponse)
server.TestDatabaseAdmin.SetResps([]proto.Message{
&longrunningpb.Operation{
Done: true,
Result: &longrunningpb.Operation_Response{Response: anyMsg},
Name: "test-operation",
},
})
ctx := context.Background()

// DDL statements should not use the query context.
_, err := db.QueryContext(ctx, "CREATE TABLE Foo (Bar STRING(100))")
if err == nil {
t.Fatal("expected error for DDL statement using QueryContext, got nil")
}
if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "QueryContext does not support DDL statements, use ExecContext instead"`; g != w {
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
// DDL statements should be able to use QueryContext.
if it, err := db.QueryContext(ctx, "CREATE TABLE Foo (Bar STRING(100))"); err != nil {
t.Fatal(err)
} else {
if it.Next() {
t.Fatalf("DDL should not return any rows")
}
}
}

Expand All @@ -598,7 +607,7 @@ func TestDDLUsingQueryContextInReadOnlyTx(t *testing.T) {
if err == nil {
t.Fatal("expected error for DDL statement using QueryContext in read-only transaction, got nil")
}
if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "QueryContext does not support DDL statements, use ExecContext instead"`; g != w {
if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "cannot execute DDL as part of a transaction"`; g != w {
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand All @@ -621,7 +630,7 @@ func TestDDLUsingQueryContextInReadWriteTransaction(t *testing.T) {
if err == nil {
t.Fatal("expected error for DDL statement using QueryContext in read-write transaction, got nil")
}
if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "QueryContext does not support DDL statements, use ExecContext instead"`; g != w {
if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "cannot execute DDL as part of a transaction"`; g != w {
t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand Down
94 changes: 94 additions & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,97 @@ func (r *rows) nextStats(dest []driver.Value) error {
dest[0] = r.it.ResultSetStats()
return nil
}

var _ driver.Rows = (*emptyRows)(nil)
var _ driver.RowsNextResultSet = (*emptyRows)(nil)
var emptyRowsMetadata = &sppb.ResultSetMetadata{
RowType: &sppb.StructType{
Fields: []*sppb.StructType_Field{{Name: "affected_rows", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}},
},
}
var emptyRowsStats = &sppb.ResultSetStats{}

type emptyRows struct {
currentResultSetType resultSetType
returnResultSetMetadata bool
returnResultSetStats bool

hasReturnedResultSetMetadata bool
hasReturnedResultSetStats bool
}

func createDriverResultRows(_ driver.Result, opts *ExecOptions) *emptyRows {
res := &emptyRows{
returnResultSetMetadata: opts.ReturnResultSetMetadata,
returnResultSetStats: opts.ReturnResultSetStats,
}
if !opts.ReturnResultSetMetadata {
res.currentResultSetType = resultSetTypeResults
}
return res
}

func (e *emptyRows) HasNextResultSet() bool {
if e.currentResultSetType == resultSetTypeMetadata && e.returnResultSetMetadata {
return true
}
if e.currentResultSetType == resultSetTypeResults && e.returnResultSetStats {
return true
}
return false
}

func (e *emptyRows) NextResultSet() error {
if !e.HasNextResultSet() {
return io.EOF
}
e.currentResultSetType++
return nil
}

func (e *emptyRows) Columns() []string {
switch e.currentResultSetType {
case resultSetTypeMetadata:
return []string{"metadata"}
case resultSetTypeResults:
return []string{"affected_rows"}
case resultSetTypeStats:
return []string{"stats"}
case resultSetTypeNoMoreResults:
return []string{}
}
return []string{}
}

func (e *emptyRows) Close() error {
return nil
}

func (e *emptyRows) Next(dest []driver.Value) error {
if e.currentResultSetType == resultSetTypeMetadata {
return e.nextMetadata(dest)
}
if e.currentResultSetType == resultSetTypeStats {
return e.nextStats(dest)
}

return io.EOF
}

func (e *emptyRows) nextMetadata(dest []driver.Value) error {
if e.hasReturnedResultSetMetadata {
return io.EOF
}
e.hasReturnedResultSetMetadata = true
dest[0] = emptyRowsMetadata
return nil
}

func (e *emptyRows) nextStats(dest []driver.Value) error {
if e.hasReturnedResultSetStats {
return io.EOF
}
e.hasReturnedResultSetStats = true
dest[0] = emptyRowsStats
return nil
}
70 changes: 70 additions & 0 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ package spannerdriver

import (
"database/sql/driver"
"errors"
"fmt"
"io"
"reflect"
"testing"

"cloud.google.com/go/spanner"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -153,3 +156,70 @@ func TestRows_Next_Unsupported(t *testing.T) {
t.Fatalf("expected error %q, but got %q", expectedError, err.Error())
}
}

func TestEmptyRows(t *testing.T) {
r := createDriverResultRows(&result{}, &ExecOptions{})

if g, w := r.Columns(), []string{"affected_rows"}; !cmp.Equal(g, w) {
t.Fatalf("columns mismatch\n Got: %v\nWant: %v", g, w)
}
if r.HasNextResultSet() {
t.Fatalf("unexpected next result set available")
}
}

func TestEmptyRowsWithMetadataAndStats(t *testing.T) {
r := createDriverResultRows(&result{}, &ExecOptions{ReturnResultSetMetadata: true, ReturnResultSetStats: true})

// The first result set should contain ResultSetMetadata.
if g, w := r.Columns(), []string{"metadata"}; !cmp.Equal(g, w) {
t.Fatalf("columns mismatch\n Got: %v\nWant: %v", g, w)
}
values := make([]driver.Value, 1)
if err := r.Next(values); err != nil {
t.Fatalf("unexpected error from Next: %v", err)
}
if g, w := reflect.TypeOf(values[0]), reflect.TypeOf(&sppb.ResultSetMetadata{}); g != w {
t.Fatalf("result set metadata type mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.Next(values), io.EOF; !errors.Is(g, w) {
t.Fatalf("next result set mismatch\n Got: %v\nWant: %v", g, w)
}

// The second result set should contain the actual data (which is empty).
if !r.HasNextResultSet() {
t.Fatalf("missing next result set")
}
if err := r.NextResultSet(); err != nil {
t.Fatalf("unexpected error from NextResultSet: %v", err)
}
if g, w := r.Columns(), []string{"affected_rows"}; !cmp.Equal(g, w) {
t.Fatalf("columns mismatch\n Got: %v\nWant: %v", g, w)
}
// There should be no data.
if g, w := r.Next(values), io.EOF; !errors.Is(g, w) {
t.Fatalf("next result set mismatch\n Got: %v\nWant: %v", g, w)
}

// The third result set should contain ResultSetStats.
if !r.HasNextResultSet() {
t.Fatalf("missing next result set")
}
if err := r.NextResultSet(); err != nil {
t.Fatalf("unexpected error from NextResultSet: %v", err)
}
if err := r.Next(values); err != nil {
t.Fatalf("unexpected error from Next: %v", err)
}
if g, w := reflect.TypeOf(values[0]), reflect.TypeOf(&sppb.ResultSetStats{}); g != w {
t.Fatalf("result set stats type mismatch\n Got: %v\nWant: %v", g, w)
}
if g, w := r.Next(values), io.EOF; !errors.Is(g, w) {
t.Fatalf("next result set mismatch\n Got: %v\nWant: %v", g, w)
}

// There should be no more result sets.
if r.HasNextResultSet() {
t.Fatalf("unexpected next result set available")
}
}
Loading