Skip to content

Commit

Permalink
fix: standardize returned errors (#32)
Browse files Browse the repository at this point in the history
Changes all errors that are created directly by the driver to Spanner errors.

Fixes #14
  • Loading branch information
olavloite authored Aug 30, 2021
1 parent 0ae00ad commit e780348
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 16 deletions.
7 changes: 4 additions & 3 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"strconv"
Expand All @@ -27,6 +26,8 @@ import (
"cloud.google.com/go/spanner"
"github.com/cloudspannerecosystem/go-sql-spanner/internal"
"google.golang.org/api/option"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
Expand Down Expand Up @@ -114,7 +115,7 @@ func extractConnectorParams(paramsString string) (map[string]string, error) {
}
keyValue := strings.SplitN(keyValueString, "=", 2)
if keyValue == nil || len(keyValue) != 2 {
return nil, fmt.Errorf("invalid connection property: %s", keyValueString)
return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid connection property: %s", keyValueString))
}
params[strings.ToLower(keyValue[0])] = keyValue[1]
}
Expand Down Expand Up @@ -326,7 +327,7 @@ func (c *conn) Begin() (driver.Tx, error) {

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if c.inTransaction() {
return nil, errors.New("already in a transaction")
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "already in a transaction"))
}

if opts.ReadOnly {
Expand Down
13 changes: 8 additions & 5 deletions internal/statement_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
package internal

import (
"fmt"
"strings"
"unicode"

"cloud.google.com/go/spanner"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var ddlStatements = map[string]bool{"CREATE": true, "DROP": true, "ALTER": true}
Expand Down Expand Up @@ -76,7 +79,7 @@ func removeCommentsAndTrim(sql string) (string, error) {
c := runes[index]
if isInQuoted {
if (c == '\n' || c == '\r') && !isTripleQuoted {
return "", fmt.Errorf("statement contains an unclosed literal: %s", sql)
return "", spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql))
} else if c == startQuote {
if lastCharWasEscapeChar {
lastCharWasEscapeChar = false
Expand Down Expand Up @@ -138,7 +141,7 @@ func removeCommentsAndTrim(sql string) (string, error) {
index++
}
if isInQuoted {
return "", fmt.Errorf("statement contains an unclosed literal: %s", sql)
return "", spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql))
}
trimmed := strings.TrimSpace(res.String())
if len(trimmed) > 0 && trimmed[len(trimmed)-1] == ';' {
Expand Down Expand Up @@ -197,7 +200,7 @@ func findParams(sql string) ([]string, error) {
c := runes[index]
if isInQuoted {
if (c == '\n' || c == '\r') && !isTripleQuoted {
return nil, fmt.Errorf("statement contains an unclosed literal: %s", sql)
return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql))
} else if c == startQuote {
if lastCharWasEscapeChar {
lastCharWasEscapeChar = false
Expand Down Expand Up @@ -248,7 +251,7 @@ func findParams(sql string) ([]string, error) {
index++
}
if isInQuoted {
return nil, fmt.Errorf("statement contains an unclosed literal: %s", sql)
return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "statement contains an unclosed literal: %s", sql))
}
return res, nil
}
Expand Down
24 changes: 24 additions & 0 deletions internal/statement_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package internal
import (
"testing"

"cloud.google.com/go/spanner"
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
)

func TestRemoveCommentsAndTrim(t *testing.T) {
Expand Down Expand Up @@ -660,3 +662,25 @@ func TestIsDdl(t *testing.T) {
}
}
}

func TestRemoveCommentsAndTrim_Errors(t *testing.T) {
_, err := removeCommentsAndTrim("SELECT 'Hello World FROM SomeTable")
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w)
}
_, err = removeCommentsAndTrim("SELECT 'Hello World\nFROM SomeTable")
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w)
}
}

func TestFindParams_Errors(t *testing.T) {
_, err := findParams("SELECT 'Hello World FROM SomeTable WHERE id=@id")
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w)
}
_, err = findParams("SELECT 'Hello World\nFROM SomeTable WHERE id=@id")
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch\nGot: %v\nWant: %v\n", g, w)
}
}
12 changes: 6 additions & 6 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ package spannerdriver
import (
"context"
"database/sql/driver"
"errors"
"fmt"

"cloud.google.com/go/spanner"
"github.com/cloudspannerecosystem/go-sql-spanner/internal"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type stmt struct {
Expand All @@ -39,15 +39,15 @@ func (s *stmt) NumInput() int {
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, fmt.Errorf("use ExecContext instead")
return nil, spanner.ToSpannerError(status.Errorf(codes.Unimplemented, "use ExecContext instead"))
}

func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.conn.ExecContext(ctx, s.query, args)
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, fmt.Errorf("use QueryContext instead")
return nil, spanner.ToSpannerError(status.Errorf(codes.Unimplemented, "use QueryContext instead"))
}

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
Expand Down Expand Up @@ -75,7 +75,7 @@ func prepareSpannerStmt(q string, args []driver.NamedValue) (spanner.Statement,
return spanner.Statement{}, err
}
if len(names) != len(args) {
return spanner.Statement{}, fmt.Errorf("got %v argument values, but found %v parameters in the sql string", len(args), len(names))
return spanner.Statement{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "got %v argument values, but found %v parameters in the sql string", len(args), len(names)))
}
ss := spanner.NewStatement(q)
for i, v := range args {
Expand All @@ -93,7 +93,7 @@ type result struct {
}

func (r *result) LastInsertId() (int64, error) {
return 0, errors.New("spanner doesn't autogenerate IDs")
return 0, spanner.ToSpannerError(status.Errorf(codes.Unimplemented, "Cloud Spanner does not support auto-generated ids"))
}

func (r *result) RowsAffected() (int64, error) {
Expand Down
3 changes: 1 addition & 2 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"bytes"
"context"
"encoding/gob"
"fmt"

"cloud.google.com/go/spanner"
sppb "google.golang.org/genproto/googleapis/spanner/v1"
Expand Down Expand Up @@ -87,7 +86,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement
}

func (tx *readOnlyTransaction) ExecContext(_ context.Context, stmt spanner.Statement) (int64, error) {
return 0, fmt.Errorf("read-only transactions cannot write")
return 0, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "read-only transactions cannot write"))
}

// ErrAbortedDueToConcurrentModification is returned by a read/write transaction
Expand Down

0 comments on commit e780348

Please sign in to comment.