Skip to content

Commit

Permalink
fix: added ddl support to driver, changed tests to call driver rather…
Browse files Browse the repository at this point in the history
… than api directly
  • Loading branch information
sagebee committed Feb 1, 2021
1 parent fd70120 commit a9c4c8a
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 16 deletions.
75 changes: 71 additions & 4 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"os"
"regexp"
"time"

"cloud.google.com/go/spanner"
"github.com/rakyll/go-sql-driver-spanner/internal"
"google.golang.org/api/option"

adminapi "cloud.google.com/go/spanner/admin/database/apiv1"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
"google.golang.org/grpc"
)

const userAgent = "go-sql-driver-spanner/0.1"
Expand Down Expand Up @@ -78,17 +84,47 @@ func openDriverConn(ctx context.Context, d *Driver, name string) (driver.Conn, e
if err != nil {
return nil, err
}
return &conn{client: client}, nil

adminClient, err := CreateAdminClient(ctx)
return &conn{client: client, adminClient: adminClient, name: name}, nil
}

func CreateAdminClient(ctx context.Context) (*adminapi.DatabaseAdminClient, error) {

var adminClient *adminapi.DatabaseAdminClient
var err error

// Admin client will connect tp emulator if SPANNER_EMULATOR_HOST
// is set in the environment.
if spannerHost, ok := os.LookupEnv("SPANNER_EMULATOR_HOST"); ok {
adminClient, err = adminapi.NewDatabaseAdminClient(
ctx,
option.WithoutAuthentication(),
option.WithEndpoint(spannerHost),
option.WithGRPCDialOption(grpc.WithInsecure()))
if err != nil {
return nil, err
}
} else {
adminClient, err = adminapi.NewDatabaseAdminClient(ctx)
if err != nil {
return nil, err
}
}

return adminClient, nil
}

func (c *connector) Driver() driver.Driver {
return &Driver{}
}

type conn struct {
client *spanner.Client
roTx *spanner.ReadOnlyTransaction
rwTx *rwTx
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
roTx *spanner.ReadOnlyTransaction
rwTx *rwTx
name string
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
Expand All @@ -105,6 +141,27 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {

// Use admin API if DDL statement is provided.
ddl, err := IsDdlStatement(query)
if err != nil {
return nil, err
}

if ddl {
op, err := c.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: c.name,
Statements: []string{query},
})
if err != nil {
return nil, err
}
if err := op.Wait(ctx); err != nil {
return nil, err
}
return &result{rowsAffected: 0}, nil
}

if c.roTx != nil {
return nil, errors.New("cannot write in read-only transaction")
}
Expand All @@ -125,6 +182,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return &result{rowsAffected: rowsAffected}, nil
}

func IsDdlStatement(query string) (bool, error) {

matchddl, err := regexp.MatchString(`(?is)^\n*\s*(CREATE|DROP|ALTER)\s+.+$`, query)
if err != nil {
return false, err
}

return matchddl, nil
}

func (c *conn) Close() error {
c.client.Close()
return nil
Expand Down
235 changes: 223 additions & 12 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,237 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package spannerdriver_test
package spannerdriver

import (
"cloud.google.com/go/spanner"
"context"
"database/sql"
"log"
"os"
"reflect"
"testing"
)

spannerdriver "github.com/rakyll/go-sql-driver-spanner"
"google.golang.org/api/option"
var (
dsn string
)

func ExampleDriver() {
driver := &spannerdriver.Driver{
Options: []option.ClientOption{
option.WithCredentialsFile("/path/to/service-account-key.json"),
type Connector struct {
ctx context.Context
client *spanner.Client
}

func NewConnector() (*Connector, error) {

ctx := context.Background()

dataClient, err := spanner.NewClient(ctx, dsn)
if err != nil {
return nil, err
}

conn := &Connector{
ctx: ctx,
client: dataClient,
}
return conn, nil
}

func (c *Connector) Close() {
c.client.Close()
}

func init() {

var projectId, instanceId, databaseId string
var ok bool

// Get environment variables or set to default.
if instanceId, ok = os.LookupEnv("SPANNER_TEST_INSTANCE"); !ok {
instanceId = "test-instance"
}
if projectId, ok = os.LookupEnv("SPANNER_TEST_PROJECT"); !ok {
projectId = "test-project"
}
if databaseId, ok = os.LookupEnv("SPANNER_TEST_DBID"); !ok {
databaseId = "gotest"
}

// Derive data source name.
dsn = "projects/" + projectId + "/instances/" + instanceId + "/databases/" + databaseId
}

// Executes DML using the client library.
func ExecuteDMLClientLib(dml []string) error {

// Open client/
ctx := context.Background()
client, err := spanner.NewClient(ctx, dsn)
if err != nil {
return err
}
defer client.Close()

// Put strings into spanner.Statement structure.
var stmts []spanner.Statement
for _, line := range dml {
stmts = append(stmts, spanner.NewStatement(line))
}

// Execute statements.
_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
_, err := txn.BatchUpdate(ctx, stmts)
if err != nil {
return err
}
return nil
})

return err
}

func TestQueryContext(t *testing.T) {

// Open db.
ctx := context.Background()
db, err := sql.Open("spanner", dsn)
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Set up test table.
_, err = db.ExecContext(ctx, `CREATE TABLE TestQueryContext (
A STRING(1024),
B STRING(1024),
C STRING(1024)
) PRIMARY KEY (A)`)
if err != nil {
t.Fatal(err)
}

conn, err := NewConnector()
if err != nil {
t.Fatal(err)
}
defer conn.Close()

err = ExecuteDMLClientLib([]string{`INSERT INTO TestQueryContext (A, B, C)
VALUES ("a1", "b1", "c1"), ("a2", "b2", "c2") , ("a3", "b3", "c3") `})
if err != nil {
t.Fatal(err)
}

type testQueryContextRow struct {
A, B, C string
}

tests := []struct {
name string
input string
want []testQueryContextRow
wantErrorQuery bool
wantErrorScan bool
wantErrorClose bool
}{
{
name: "empty query",
wantErrorClose: true,
input: "",
want: []testQueryContextRow{},
},
{
name: "syntax error",
wantErrorClose: true,
input: "SELECT SELECT * FROM TestQueryContext",
want: []testQueryContextRow{},
},
{
name: "return nothing",
input: "SELECT * FROM TestQueryContext WHERE A = \"hihihi\"",
want: []testQueryContextRow{},
},
{
name: "select one tuple",
input: "SELECT * FROM TestQueryContext WHERE A = \"a1\"",
want: []testQueryContextRow{
{A: "a1", B: "b1", C: "c1"},
},
},
{
name: "select subset of tuples",
input: "SELECT * FROM TestQueryContext WHERE A = \"a1\" OR A = \"a2\"",
want: []testQueryContextRow{
{A: "a1", B: "b1", C: "c1"},
{A: "a2", B: "b2", C: "c2"},
},
},
{
name: "select subset of tuples with !=",
input: "SELECT * FROM TestQueryContext WHERE A != \"a3\"",
want: []testQueryContextRow{
{A: "a1", B: "b1", C: "c1"},
{A: "a2", B: "b2", C: "c2"},
},
},
{
name: "select entire table",
input: "SELECT * FROM TestQueryContext ORDER BY A",
want: []testQueryContextRow{
{A: "a1", B: "b1", C: "c1"},
{A: "a2", B: "b2", C: "c2"},
{A: "a3", B: "b3", C: "c3"},
},
},
{
name: "query non existant table",
wantErrorClose: true,
input: "SELECT * FROM TestQueryContexta", want: []testQueryContextRow{},
},
}
connector, err := driver.OpenConnector("projects/$PROJECT/instances/$INSTANCE/databases/$DATABASE")

// Run tests
for _, tc := range tests {

rows, err := db.QueryContext(ctx, tc.input)
if (err != nil) && (!tc.wantErrorQuery) {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
}
if (err == nil) && (tc.wantErrorQuery) {
t.Errorf("%s: expected query error but error was %v", tc.name, err)
}

got := []testQueryContextRow{}
for rows.Next() {
var curr testQueryContextRow
err := rows.Scan(&curr.A, &curr.B, &curr.C)
if (err != nil) && (!tc.wantErrorScan) {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
}
if (err == nil) && (tc.wantErrorScan) {
t.Errorf("%s: expected query error but error was %v", tc.name, err)
}

got = append(got, curr)
}

rows.Close()
err = rows.Err()
if (err != nil) && (!tc.wantErrorClose) {
t.Errorf("%s: unexpected query error: %v", tc.name, err)
}
if (err == nil) && (tc.wantErrorClose) {
t.Errorf("%s: expected query error but error was %v", tc.name, err)
}
if !reflect.DeepEqual(tc.want, got) {
t.Errorf("Test failed: %s. want: %v, got: %v", tc.name, tc.want, got)
}

}

// Drop table.
_, err = db.ExecContext(ctx, `DROP TABLE TestQueryContext`)
if err != nil {
log.Fatal(err)
t.Error(err)
}
db := sql.OpenDB(connector)
_ = db // Use db.
}

0 comments on commit a9c4c8a

Please sign in to comment.