Skip to content

Commit

Permalink
Merge branch 'cloudfetch' into main (#154)
Browse files Browse the repository at this point in the history
Supports executing queries with Cloud Fetch for increased performance
and caching.

Steps taken:
- Synced fork `mattdeekay` for both `cloudfetch` and `main` branches
- On `mattdeekay:main`, ran `git merge --squash cloudfetch`
- Resolved merge conflicts
- Fixed `cloudfetch_test.go` end-to-end test (renamed WithEnableCloudFetch ->
WithCloudFetch)
- Commit and create PR
- Fix `connector_test.go` to add cloud fetch (forgot to add earlier)
- Add link expiration test to `batchloader_test.go`
- Fix `arrowRows_test.go`
- `golangci-lint run`
  • Loading branch information
rcypher-databricks committed Aug 8, 2023
2 parents 65bde57 + b8c87f7 commit 7e079fd
Show file tree
Hide file tree
Showing 17 changed files with 1,254 additions and 275 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ You can set query timeout value by appending a `timeout` query parameter (in sec
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?timeout=1000&maxRows=1000
```
You can turn on Cloud Fetch to increase the performance of extracting large query results by fetching data in parallel via cloud storage (more info [here](https://www.databricks.com/blog/2021/08/11/how-we-achieved-high-bandwidth-connectivity-with-bi-tools.html)). To turn on Cloud Fetch, append `useCloudFetch=true`. You can also set the number of concurrently fetching goroutines by setting the `maxDownloadThreads` query parameter (default is 10):
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?useCloudFetch=true&maxDownloadThreads=3
```

### Connecting with a new Connector

Expand Down
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
GetDirectResults: &cli_service.TSparkGetDirectResults{
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
}

if c.cfg.UseArrowBatches {
Expand All @@ -295,6 +296,10 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
}
}

if c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
}

ctx = driverctx.NewContextWithConnId(ctx, c.id)
resp, err := c.client.ExecuteStatement(ctx, &req)

Expand Down
14 changes: 14 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,17 @@ func WithTransport(t http.RoundTripper) connOption {
c.Transport = t
}
}

// WithCloudFetch sets up the use of cloud fetch for query execution. Default is false.
// WithCloudFetch enables or disables the Cloud Fetch result-retrieval path
// for query execution. Default is false.
func WithCloudFetch(enabled bool) connOption {
	return func(cfg *config.Config) {
		cfg.UseCloudFetch = enabled
	}
}

// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
// WithMaxDownloadThreads sets the maximum number of concurrent download
// workers used by cloud fetch. Default is 10.
func WithMaxDownloadThreads(n int) connOption {
	return func(cfg *config.Config) {
		cfg.MaxDownloadThreads = n
	}
}
99 changes: 61 additions & 38 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,33 @@ func TestNewConnector(t *testing.T) {
WithSessionParams(sessionParams),
WithRetries(10, 3*time.Second, 60*time.Second),
WithTransport(roundTripper),
WithCloudFetch(true),
WithMaxDownloadThreads(15),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: true,
MaxDownloadThreads: 15,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand All @@ -75,18 +84,25 @@ func TestNewConnector(t *testing.T) {
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: false,
MaxDownloadThreads: 10,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.UserConfig = expectedUserConfig
Expand All @@ -109,18 +125,25 @@ func TestNewConnector(t *testing.T) {
WithHTTPPath(httpPath),
WithRetries(-1, 0, 0),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: false,
MaxDownloadThreads: 10,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand Down
4 changes: 4 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Supported optional connection parameters can be specified in param=value and inc
- maxRows: Sets up the max rows fetched per request. Default is 100000
- timeout: Adds timeout (in seconds) for the server query execution. Default is no timeout
- userAgentEntry: Used to identify partners. Set as a string with format <isv-name+product-name>
- useCloudFetch: Used to enable cloud fetch for the query execution. Default is false
- maxDownloadThreads: Sets up the max number of concurrent workers for cloud fetch. Default is 10
Supported optional session parameters can be specified in param=value and include:
Expand Down Expand Up @@ -79,6 +81,8 @@ Supported functional options include:
- WithSessionParams(<params_map> map[string]string): Sets up session parameters including "timezone" and "ansi_mode". Optional
- WithTimeout(<timeout> Duration). Adds timeout (in time.Duration) for the server query execution. Default is no timeout. Optional
- WithUserAgentEntry(<isv-name+product-name> string). Used to identify partners. Optional
  - WithCloudFetch(<use_cloud_fetch> bool). Used to enable cloud fetch for the query execution. Default is false. Optional
  - WithMaxDownloadThreads(<num_threads> int). Sets up the max number of concurrent workers for cloud fetch. Default is 10. Optional
# Query cancellation and timeout
Expand Down
5 changes: 5 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ const (

// Execution error messages (query failure)
ErrQueryExecution = "failed to execute query"
ErrLinkExpired = "link expired"
)

// InvalidDSNFormat builds the error message for a DSN parameter whose value
// cannot be parsed as the expected type.
func InvalidDSNFormat(param string, value string, expected string) string {
	// fmt.Sprint inserts no separators between adjacent string operands,
	// so this yields the same message as the original Sprintf form.
	return fmt.Sprint("invalid DSN: param ", param, " with value ", value, " is not of type ", expected)
}

// ErrInvalidOperationState builds the error message for an operation state
// the driver does not expect to encounter.
func ErrInvalidOperationState(state string) string {
	// All operands are strings, so fmt.Sprint adds no separators and the
	// output matches the original Sprintf-based message exactly.
	return fmt.Sprint("invalid operation state ", state, ". This should not have happened")
}
Expand Down
100 changes: 100 additions & 0 deletions examples/cloudfetch/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"context"
"database/sql"
"fmt"
dbsql "github.com/databricks/databricks-sql-go"
"log"
"os"
"strconv"
"time"
)

// row holds one scanned record of the example `stock_data` table; the field
// order must match the column order produced by the query in main.
type row struct {
	symbol string
	companyName string
	industry string
	date string
	open float64
	high float64
	low float64
	close float64
	volume int
	change float64
	changePercentage float64
	upTrend bool
	volatile bool
}

// runTest opens a connection (with Cloud Fetch enabled or disabled per
// withCloudFetch), executes query, and returns every scanned result row.
// Connection parameters are read from the DATABRICKS_* environment variables.
func runTest(withCloudFetch bool, query string) ([]row, error) {
	port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if err != nil {
		return nil, err
	}

	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(port),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
		dbsql.WithTimeout(10),
		dbsql.WithInitialNamespace("hive_metastore", "default"),
		dbsql.WithCloudFetch(withCloudFetch),
	)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		return nil, err
	}

	rows, err := db.QueryContext(context.Background(), query)
	if err != nil {
		// Check the error BEFORE touching rows: the original deferred
		// rows.Close() first (nil-pointer panic on failure) and then
		// returned a stale nil err instead of the query error.
		if err == sql.ErrNoRows {
			fmt.Println("not found")
		}
		return nil, err
	}
	defer rows.Close()

	var res []row
	for rows.Next() {
		r := row{}
		if err := rows.Scan(&r.symbol, &r.companyName, &r.industry, &r.date, &r.open, &r.high, &r.low, &r.close, &r.volume, &r.change, &r.changePercentage, &r.upTrend, &r.volatile); err != nil {
			fmt.Println(err)
			return nil, err
		}
		res = append(res, r)
	}
	// Surface any error encountered during iteration (e.g. a fetch failure
	// mid-stream), which rows.Next() signals only by returning false.
	return res, rows.Err()
}

// main runs the same query once with the local arrow-batch path and once with
// Cloud Fetch, then verifies both paths produced identical result sets.
func main() {
	query := "select * from stock_data where date is not null and volume is not null order by date, symbol limit 10000000"

	// Local arrow batch
	abRes, err := runTest(false, query)
	if err != nil {
		log.Fatal(err)
	}

	// Cloud fetch batch
	cfRes, err := runTest(true, query)
	if err != nil {
		log.Fatal(err)
	}

	// Compare lengths first: the original indexed cfRes by abRes's length,
	// which panics if cloud fetch returned fewer rows and silently ignores
	// any extra rows it returned.
	if len(abRes) != len(cfRes) {
		log.Fatalf("row count mismatch: %d vs %d", len(abRes), len(cfRes))
	}
	for i := range abRes {
		if abRes[i] != cfRes[i] {
			// log.Fatalf replaces the original log.Fatal(fmt.Sprintf(...)).
			log.Fatalf("not equal for row: %d", i)
		}
	}
}
Loading

0 comments on commit 7e079fd

Please sign in to comment.