Skip to content

Commit

Permalink
Renamed DBSQLRows and DBSQLArrowBatchIterator
Browse files Browse the repository at this point in the history
Renamed DBSQLRows and DBSQLArrowBatchIterator to Rows and ArrowBatchIterator by dropping the DBSQL prefix.
Updated example to use UC

Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
  • Loading branch information
rcypher-databricks committed Sep 27, 2023
1 parent 7ae207a commit 6930fa2
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 17 deletions.
10 changes: 5 additions & 5 deletions examples/arrrowbatches/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ func loopWithHasNext(db *sql.DB) {
conn, _ := db.Conn(ctx)
defer conn.Close()

query := `select * from hive_metastore.main.taxi_trip_data`
query := `select * from main.default.diamonds`

var rows driver.Rows
var err error
err = conn.Raw(func(d interface{}) error {
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
return err
})

if err != nil {
log.Fatalf("unable to run the query. err: %v", err)
}
Expand All @@ -74,7 +74,7 @@ func loopWithHasNext(db *sql.DB) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel2()

batches, err := rows.(dbsqlrows.DBSQLRows).GetArrowBatches(ctx2)
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(ctx2)
if err != nil {
log.Fatalf("unable to get arrow batches. err: %v", err)
}
Expand All @@ -100,7 +100,7 @@ func loopWithNext(db *sql.DB) {
conn, _ := db.Conn(ctx)
defer conn.Close()

query := `select * from hive_metastore.main.taxi_trip_data`
query := `select * from main.default.diamonds`

var rows driver.Rows
var err error
Expand All @@ -117,7 +117,7 @@ func loopWithNext(db *sql.DB) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel2()

batches, err := rows.(dbsqlrows.DBSQLRows).GetArrowBatches(ctx2)
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(ctx2)
if err != nil {
log.Fatalf("unable to get arrow batches. err: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/rows/arrowbased/arrowRecordIterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/databricks/databricks-sql-go/rows"
)

func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.DBSQLArrowBatchIterator {
func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator {
ari := arrowRecordIterator{
cfg: cfg,
batchIterator: bi,
Expand All @@ -36,7 +36,7 @@ type arrowRecordIterator struct {
arrowSchemaBytes []byte
}

var _ rows.DBSQLArrowBatchIterator = (*arrowRecordIterator)(nil)
var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)

// Retrieve the next arrow record
func (ri *arrowRecordIterator) Next() (arrow.Record, error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/rows/arrowbased/arrowRows.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (ars *arrowRowScanner) validateRowNumber(rowNumber int64) dbsqlerr.DBError
return nil
}

func (ars *arrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.DBSQLArrowBatchIterator, error) {
func (ars *arrowRowScanner) GetArrowBatches(ctx context.Context, cfg config.Config, rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error) {
ri := NewArrowRecordIterator(ctx, rpi, ars.batchIterator, ars.arrowSchemaBytes, cfg)
return ri, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/rows/columnbased/columnRows.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,6 @@ func (crs *columnRowScanner) value(tColumn *cli_service.TColumn, tColumnDesc *cl
func (crs *columnRowScanner) GetArrowBatches(
ctx context.Context,
cfg config.Config,
rpi rowscanner.ResultPageIterator) (dbsqlrows.DBSQLArrowBatchIterator, error) {
rpi rowscanner.ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error) {
return nil, dbsqlerr_int.NewDriverError(ctx, "databricks: result set is not in arrow format", nil)
}
4 changes: 2 additions & 2 deletions internal/rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var _ driver.RowsColumnTypeScanType = (*rows)(nil)
var _ driver.RowsColumnTypeDatabaseTypeName = (*rows)(nil)
var _ driver.RowsColumnTypeNullable = (*rows)(nil)
var _ driver.RowsColumnTypeLength = (*rows)(nil)
var _ dbsqlrows.DBSQLRows = (*rows)(nil)
var _ dbsqlrows.Rows = (*rows)(nil)

func NewRows(
connId string,
Expand Down Expand Up @@ -532,7 +532,7 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger {
return r.logger_
}

func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.DBSQLArrowBatchIterator, error) {
func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) {
// update context with correlationId and connectionId which will be used in logging and errors
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId)

Expand Down
4 changes: 2 additions & 2 deletions internal/rows/rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ func TestGetArrowBatches(t *testing.T) {
rows, err := NewRows("connId", "corrId", nil, client, cfg, executeStatementResp.DirectResults)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.DBSQLRows)
rows2, ok := rows.(dbsqlrows.Rows)
assert.True(t, ok)

rs, err2 := rows2.GetArrowBatches(context.Background())
Expand Down Expand Up @@ -872,7 +872,7 @@ func TestGetArrowBatches(t *testing.T) {
rows, err := NewRows("connId", "corrId", nil, client, cfg, nil)
assert.Nil(t, err)

rows2, ok := rows.(dbsqlrows.DBSQLRows)
rows2, ok := rows.(dbsqlrows.Rows)
assert.True(t, ok)

rs, err2 := rows2.GetArrowBatches(context.Background())
Expand Down
2 changes: 1 addition & 1 deletion internal/rows/rowscanner/rowScanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type RowScanner interface {
// Close any open resources
Close()

GetArrowBatches(ctx context.Context, cfg config.Config, rpi ResultPageIterator) (dbsqlrows.DBSQLArrowBatchIterator, error)
GetArrowBatches(ctx context.Context, cfg config.Config, rpi ResultPageIterator) (dbsqlrows.ArrowBatchIterator, error)
}

// Expected formats for TIMESTAMP and DATE types when represented by a string value
Expand Down
6 changes: 3 additions & 3 deletions rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import (
"github.com/apache/arrow/go/v12/arrow"
)

type DBSQLRows interface {
GetArrowBatches(context.Context) (DBSQLArrowBatchIterator, error)
type Rows interface {
GetArrowBatches(context.Context) (ArrowBatchIterator, error)
}

type DBSQLArrowBatchIterator interface {
type ArrowBatchIterator interface {
// Retrieve the next arrow.Record.
// Will return io.EOF if there are no more records
Next() (arrow.Record, error)
Expand Down

0 comments on commit 6930fa2

Please sign in to comment.