Skip to content

Commit

Permalink
PECO-1054 Expose Arrow batches to users, part three (#166)
Browse files Browse the repository at this point in the history
Added DBSqlRows and DBSQLArrowBatchIterator public interfaces. 
Added arrowRecordIterator which implements  DBSQLArrowBatchIterator. 
Moved closing the database operation from rows type into
resultPageIterator as well as properties that are only used by
resultPageIterator.
Added GetArrowBatches function to rows and arrowRowScanner types. 
Added HasNext function to BatchIterator and SparkArrowBatch interfaces. 
Added example for accessing Arrow batches and updated doc.go
  • Loading branch information
rcypher-databricks committed Sep 29, 2023
2 parents 5a3a210 + 2d88022 commit 6bb1879
Show file tree
Hide file tree
Showing 20 changed files with 1,854 additions and 99 deletions.
74 changes: 74 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,80 @@ Example usage:
See the documentation for dbsql/errors for more information.
# Retrieving Arrow Batches
The driver supports the ability to retrieve Apache Arrow record batches.
To work with record batches it is necessary to use sql.Conn.Raw() to access the underlying driver connection to retrieve a driver.Rows instance.
The driver exposes two public interfaces for working with record batches from the rows sub-package:
type Rows interface {
GetArrowBatches(context.Context) (ArrowBatchIterator, error)
}
type ArrowBatchIterator interface {
// Retrieve the next arrow.Record.
// Will return io.EOF if there are no more records
Next() (arrow.Record, error)
// Return true if the iterator contains more batches, false otherwise.
HasNext() bool
// Release any resources in use by the iterator.
Close()
}
The driver.Rows instance retrieved using Conn.Raw() can be converted to a Databricks Rows instance via a type assertion; GetArrowBatches() can then be called on it to retrieve a batch iterator.
If the ArrowBatchIterator is not closed it will leak resources, such as the underlying connection.
Calling code must call Release() on records returned by ArrowBatchIterator.Next().
Example usage:
import (
...
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
)
func main() {
...
db := sql.OpenDB(connector)
defer db.Close()
ctx := context.Background()
conn, _ := db.Conn(ctx)
defer conn.Close()
query := `select * from main.default.taxi_trip_data`
var rows driver.Rows
var err error
err = conn.Raw(func(d interface{}) error {
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
return err
})
if err != nil {
log.Fatalf("unable to run the query. err: %v", err)
}
defer rows.Close()
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(context.Background())
if err != nil {
log.Fatalf("unable to get arrow batches. err: %v", err)
}
defer batches.Close()
var iBatch, nRows int
for batches.HasNext() {
b, err := batches.Next()
if err != nil {
log.Fatalf("Failure retrieving batch. err: %v", err)
}
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
iBatch += 1
nRows += int(b.NumRows())
b.Release()
}
log.Printf("NRows: %v\n", nRows)
}
# Supported Data Types
==================================
Expand Down
139 changes: 139 additions & 0 deletions examples/arrrowbatches/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package main

import (
"context"
"database/sql"
"database/sql/driver"
"io"
"log"
"os"
"strconv"
"time"

"github.com/apache/arrow/go/v12/arrow"
dbsql "github.com/databricks/databricks-sql-go"
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
"github.com/joho/godotenv"
)

// main reads the Databricks connection settings from the environment, builds
// a connector from them, and runs both Arrow batch iteration examples against
// the resulting database handle.
func main() {
	// NOTE(review): godotenv.Load presumably populates the environment from a
	// local .env file before the DATABRICKS_* variables are read below.
	if loadErr := godotenv.Load(); loadErr != nil {
		log.Fatal(loadErr.Error())
	}

	// dbsqllog.SetLogLevel("debug")

	portNum, convErr := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if convErr != nil {
		log.Fatal(convErr.Error())
	}

	connector, connErr := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(portNum),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
		dbsql.WithMaxRows(10000),
	)
	if connErr != nil {
		// A failure here is not a connection error; it is a DSN parse error
		// or some other initialization problem.
		log.Fatal(connErr)
	}

	// Opening a driver typically will not attempt to connect to the database.
	database := sql.OpenDB(connector)
	defer database.Close()

	// Demonstrate both ways of walking the Arrow batch iterator.
	loopWithHasNext(database)
	loopWithNext(database)
}

// loopWithHasNext runs a query over the raw driver connection and walks the
// resulting Arrow batches with the HasNext/Next pair, logging the row count of
// each batch and the overall total.
//
// The batch iterator is closed on return and every record is released once it
// has been consumed, as the driver documentation requires. Any failure
// terminates the program via log.Fatalf.
func loopWithHasNext(db *sql.DB) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Check the error instead of discarding it: a nil conn would otherwise
	// panic in the deferred Close and in conn.Raw.
	conn, err := db.Conn(ctx)
	if err != nil {
		log.Fatalf("unable to get a connection. err: %v", err)
	}
	defer conn.Close()

	query := `select * from main.default.diamonds`

	// Go through the underlying driver connection so that a driver.Rows
	// (rather than *sql.Rows) is returned.
	var rows driver.Rows
	err = conn.Raw(func(d interface{}) error {
		rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
		return err
	})

	if err != nil {
		log.Fatalf("unable to run the query. err: %v", err)
	}
	defer rows.Close()

	ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel2()

	// Use the comma-ok form so an unexpected rows implementation produces a
	// clear message instead of a panic.
	arrowRows, ok := rows.(dbsqlrows.Rows)
	if !ok {
		log.Fatalf("driver rows of type %T do not support arrow batches", rows)
	}

	batches, err := arrowRows.GetArrowBatches(ctx2)
	if err != nil {
		log.Fatalf("unable to get arrow batches. err: %v", err)
	}
	// An unclosed iterator leaks resources such as the underlying connection.
	defer batches.Close()

	var iBatch, nRows int
	for batches.HasNext() {
		b, err := batches.Next()
		if err != nil {
			log.Fatalf("Failure retrieving batch. err: %v", err)
		}

		log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
		iBatch += 1
		nRows += int(b.NumRows())

		// The caller owns each record returned by Next and must release it.
		b.Release()
	}
	log.Printf("NRows: %v\n", nRows)
}

// loopWithNext runs a query over the raw driver connection and walks the
// resulting Arrow batches by calling Next until it returns an error, logging
// the row count of each batch and the overall total. An io.EOF from Next is
// reported as normal termination; any other error is logged as a failure.
//
// The batch iterator is closed on return and every record is released once it
// has been consumed, as the driver documentation requires.
func loopWithNext(db *sql.DB) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Check the error instead of discarding it: a nil conn would otherwise
	// panic in the deferred Close and in conn.Raw.
	conn, err := db.Conn(ctx)
	if err != nil {
		log.Fatalf("unable to get a connection. err: %v", err)
	}
	defer conn.Close()

	query := `select * from main.default.diamonds`

	// Go through the underlying driver connection so that a driver.Rows
	// (rather than *sql.Rows) is returned.
	var rows driver.Rows
	err = conn.Raw(func(d interface{}) error {
		rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
		return err
	})
	if err != nil {
		log.Fatalf("unable to run the query. err: %v", err)
	}
	defer rows.Close()

	ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel2()

	// Use the comma-ok form so an unexpected rows implementation produces a
	// clear message instead of a panic.
	arrowRows, ok := rows.(dbsqlrows.Rows)
	if !ok {
		log.Fatalf("driver rows of type %T do not support arrow batches", rows)
	}

	batches, err := arrowRows.GetArrowBatches(ctx2)
	if err != nil {
		log.Fatalf("unable to get arrow batches. err: %v", err)
	}
	// An unclosed iterator leaks resources such as the underlying connection.
	defer batches.Close()

	var iBatch, nRows int
	var b arrow.Record
	for b, err = batches.Next(); err == nil; b, err = batches.Next() {
		log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
		iBatch += 1
		nRows += int(b.NumRows())

		// The caller owns each record returned by Next and must release it.
		b.Release()
	}

	log.Printf("NRows: %v\n", nRows)
	// The loop always exits with a non-nil err; io.EOF marks the end of the
	// batches, anything else is a real failure.
	if err == io.EOF {
		log.Println("normal loop termination")
	} else {
		log.Printf("loop terminated with error: %v", err)
	}
}

0 comments on commit 6bb1879

Please sign in to comment.