diff --git a/conn.go b/conn.go index 57fae247..f6ac23be 100644 --- a/conn.go +++ b/conn.go @@ -59,6 +59,9 @@ type SpannerConn interface { // RunBatch sends all batched DDL or DML statements to Spanner. This is a // no-op if no statements have been batched or if there is no active batch. RunBatch(ctx context.Context) error + // RunDmlBatch sends all batched DML statements to Spanner. This is a + // no-op if no statements have been batched or if there is no active DML batch. + RunDmlBatch(ctx context.Context) (SpannerResult, error) // AbortBatch aborts the current DDL or DML batch and discards all batched // statements. AbortBatch() error @@ -446,6 +449,18 @@ func (c *conn) RunBatch(ctx context.Context) error { return err } +func (c *conn) RunDmlBatch(ctx context.Context) (SpannerResult, error) { + res, err := c.runBatch(ctx) + if err != nil { + return nil, err + } + spannerRes, ok := res.(SpannerResult) + if !ok { + return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "not a DML batch")) + } + return spannerRes, nil +} + func (c *conn) AbortBatch() error { _, err := c.abortBatch() return err diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index cbdfa256..842855f3 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -369,3 +369,44 @@ func TestDDLUsingQueryContextInReadWriteTransaction(t *testing.T) { t.Fatalf("error mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestRunDmlBatch(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(conn) + if err := conn.Raw(func(driverConn interface{}) error { + spannerConn, _ := driverConn.(SpannerConn) + return spannerConn.StartBatchDML() + }); err != nil { + t.Fatal(err) + } + // Buffer two DML statements. + for range 2 { + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } + } + var res SpannerResult + if err := conn.Raw(func(driverConn interface{}) (err error) { + spannerConn, _ := driverConn.(SpannerConn) + res, err = spannerConn.RunDmlBatch(ctx) + return err + }); err != nil { + t.Fatal(err) + } + affected, err := res.BatchRowsAffected() + if err != nil { + t.Fatal(err) + } + if g, w := affected, []int64{testutil.UpdateBarSetFooRowCount, testutil.UpdateBarSetFooRowCount}; !reflect.DeepEqual(g, w) { + t.Fatalf("affected mismatch\n Got: %v\nWant: %v", g, w) + } +} diff --git a/examples/dml-batches/main.go b/examples/dml-batches/main.go index f5eeaa51..4273f886 100644 --- a/examples/dml-batches/main.go +++ b/examples/dml-batches/main.go @@ -106,11 +106,17 @@ func dmlBatch(projectId, instanceId, databaseId string) error { return fmt.Errorf("failed to insert: %v", err) } // Run the batch. This will apply all the batched DML statements to the database in one atomic operation. - if err := conn.Raw(func(driverConn interface{}) error { - return driverConn.(spannerdriver.SpannerConn).RunBatch(ctx) + var res spannerdriver.SpannerResult + if err := conn.Raw(func(driverConn interface{}) (err error) { + res, err = driverConn.(spannerdriver.SpannerConn).RunDmlBatch(ctx) + return err }); err != nil { return fmt.Errorf("failed to run DML batch: %v", err) } + // BatchRowsAffected returns a slice with the affected rows per DML statement in the batch. + affected, _ := res.BatchRowsAffected() + fmt.Printf("Affected rows: %v\n", affected) + if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM Singers").Scan(&c); err != nil { return fmt.Errorf("failed to get singers count: %v", err) }