Skip to content

Commit

Permalink
Merge branch 'cloudfetch' into main
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
  • Loading branch information
mattdeekay committed Aug 4, 2023
1 parent 65bde57 commit 4be8ace
Show file tree
Hide file tree
Showing 16 changed files with 1,085 additions and 187 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ You can set query timeout value by appending a `timeout` query parameter (in sec
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?timeout=1000&maxRows=1000
```
You can turn on Cloud Fetch to increase the performance of extracting large query results by fetching data in parallel via cloud storage (more info [here](https://www.databricks.com/blog/2021/08/11/how-we-achieved-high-bandwidth-connectivity-with-bi-tools.html)). To turn on Cloud Fetch, append `useCloudFetch=true`. You can also set the number of concurrently fetching goroutines by setting the `maxDownloadThreads` query parameter (default is 10):
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?useCloudFetch=true&maxDownloadThreads=3
```

### Connecting with a new Connector

Expand Down
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
GetDirectResults: &cli_service.TSparkGetDirectResults{
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
}

if c.cfg.UseArrowBatches {
Expand All @@ -295,6 +296,10 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
}
}

if c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
}

ctx = driverctx.NewContextWithConnId(ctx, c.id)
resp, err := c.client.ExecuteStatement(ctx, &req)

Expand Down
14 changes: 14 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,17 @@ func WithTransport(t http.RoundTripper) connOption {
c.Transport = t
}
}

// WithCloudFetch sets up the use of cloud fetch for query execution. Default is false.
func WithCloudFetch(useCloudFetch bool) connOption {
return func(c *config.Config) {
c.UseCloudFetch = useCloudFetch
}
}

// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
func WithMaxDownloadThreads(numThreads int) connOption {
return func(c *config.Config) {
c.MaxDownloadThreads = numThreads
}
}
4 changes: 4 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Supported optional connection parameters can be specified in param=value and inc
- maxRows: Sets up the max rows fetched per request. Default is 100000
- timeout: Adds timeout (in seconds) for the server query execution. Default is no timeout
- userAgentEntry: Used to identify partners. Set as a string with format <isv-name+product-name>
- useCloudFetch: Used to enable cloud fetch for the query execution. Default is false
- maxDownloadThreads: Sets up the max number of concurrent workers for cloud fetch. Default is 10
Supported optional session parameters can be specified in param=value and include:
Expand Down Expand Up @@ -79,6 +81,8 @@ Supported functional options include:
- WithSessionParams(<params_map> map[string]string): Sets up session parameters including "timezone" and "ansi_mode". Optional
- WithTimeout(<timeout> Duration). Adds timeout (in time.Duration) for the server query execution. Default is no timeout. Optional
- WithUserAgentEntry(<isv-name+product-name> string). Used to identify partners. Optional
- WithCloudFetch (bool). Used to enable cloud fetch for the query execution. Default is false. Optional
- WithMaxDownloadThreads (<num_threads> int). Sets up the max number of concurrent workers for cloud fetch. Default is 10. Optional
# Query cancellation and timeout
Expand Down
5 changes: 5 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ const (

// Execution error messages (query failure)
ErrQueryExecution = "failed to execute query"
ErrLinkExpired = "link expired"
)

// InvalidDSNFormat builds the error message reported when a DSN query
// parameter cannot be parsed as its expected type.
func InvalidDSNFormat(param string, value string, expected string) string {
	return "invalid DSN: param " + param + " with value " + value + " is not of type " + expected
}

// ErrInvalidOperationState builds the error message for an operation state
// the driver does not expect to encounter.
func ErrInvalidOperationState(state string) string {
	msg := "invalid operation state " + state + ". This should not have happened"
	return msg
}
Expand Down
97 changes: 97 additions & 0 deletions examples/cloudfetch/cloudfetch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package main

import (
"context"
"database/sql"
"fmt"
dbsql "github.com/databricks/databricks-sql-go"
"github.com/stretchr/testify/assert"
"os"
"strconv"
"testing"
"time"
)

// row holds one result record of the example query over the stock_data
// table. The fields appear in the exact order they are scanned in
// runTest's rows.Scan call; field names suggest per-day stock metrics,
// but the actual column semantics come from the table — verify against
// the stock_data schema.
type row struct {
	symbol           string
	companyName      string
	industry         string
	date             string
	open             float64
	high             float64
	low              float64
	close            float64
	volume           int
	change           float64
	changePercentage float64
	upTrend          bool
	volatile         bool
}

// runTest connects to the workspace described by the DATABRICKS_*
// environment variables, executes query with cloud fetch enabled or
// disabled per withCloudFetch, and returns every scanned row.
//
// Fixes over the original:
//   - rows.Close() was deferred before the query error was checked, which
//     panics on a nil *Rows when the query fails;
//   - error branches returned the stale outer err (nil at that point)
//     instead of the query error;
//   - rows.Err() was never checked after the iteration loop.
func runTest(withCloudFetch bool, query string) ([]row, error) {
	port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if err != nil {
		return nil, err
	}

	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(port),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
		dbsql.WithTimeout(10),
		dbsql.WithInitialNamespace("hive_metastore", "default"),
		dbsql.WithCloudFetch(withCloudFetch),
	)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		return nil, err
	}

	rows, err := db.QueryContext(context.Background(), query)
	if err != nil {
		if err == sql.ErrNoRows {
			fmt.Println("not found")
		}
		return nil, err
	}
	// Only safe to defer once we know rows is non-nil.
	defer rows.Close()

	var res []row
	for rows.Next() {
		r := row{}
		if err := rows.Scan(&r.symbol, &r.companyName, &r.industry, &r.date, &r.open, &r.high, &r.low, &r.close, &r.volume, &r.change, &r.changePercentage, &r.upTrend, &r.volatile); err != nil {
			fmt.Println(err)
			return nil, err
		}
		res = append(res, r)
	}
	// Surface any error that terminated the iteration early.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}

// TestCloudFetch runs the same query once with the local arrow-batch path
// and once with cloud fetch, then verifies both return identical rows.
// It requires DATABRICKS_* environment variables and a stock_data table.
func TestCloudFetch(t *testing.T) {
	t.Run("Compare local batch to cloud fetch", func(t *testing.T) {
		query := "select * from stock_data where date is not null and volume is not null order by date, symbol limit 10000000"

		// Local arrow batch
		abRes, err := runTest(false, query)
		assert.NoError(t, err)

		// Cloud fetch batch
		cfRes, err := runTest(true, query)
		assert.NoError(t, err)

		// BUG FIX: the original indexed cfRes[i] bounded only by len(abRes),
		// which panics with index-out-of-range when cloud fetch returns
		// fewer rows. Assert equal lengths first and bail out on mismatch.
		if !assert.Equal(t, len(abRes), len(cfRes), "result row counts differ") {
			return
		}
		for i := range abRes {
			assert.Equal(t, abRes[i], cfRes[i], fmt.Sprintf("not equal for row: %d", i))
		}
	})
}
61 changes: 60 additions & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ type UserConfig struct {
RetryWaitMax time.Duration
RetryMax int
Transport http.RoundTripper
UseLz4Compression bool
CloudFetchConfig
}

// DeepCopy returns a true deep copy of UserConfig
Expand Down Expand Up @@ -138,6 +140,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
RetryWaitMax: ucfg.RetryWaitMax,
RetryMax: ucfg.RetryMax,
Transport: ucfg.Transport,
UseLz4Compression: ucfg.UseLz4Compression,
CloudFetchConfig: ucfg.CloudFetchConfig,
}
}

Expand Down Expand Up @@ -170,6 +174,8 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
if ucfg.RetryWaitMax == 0 {
ucfg.RetryWaitMax = 30 * time.Second
}
ucfg.UseLz4Compression = false
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()

return ucfg
}
Expand All @@ -194,7 +200,7 @@ func WithDefaults() *Config {

}

// ParseDSN constructs UserConfig by parsing DSN string supplied to `sql.Open()`
// ParseDSN constructs UserConfig and CloudFetchConfig by parsing DSN string supplied to `sql.Open()`
func ParseDSN(dsn string) (UserConfig, error) {
fullDSN := dsn
if !strings.HasPrefix(dsn, "https://") && !strings.HasPrefix(dsn, "http://") {
Expand Down Expand Up @@ -266,6 +272,25 @@ func ParseDSN(dsn string) (UserConfig, error) {
ucfg.Schema = params.Get("schema")
params.Del("schema")
}

// Cloud Fetch parameters
if params.Has("useCloudFetch") {
useCloudFetch, err := strconv.ParseBool(params.Get("useCloudFetch"))
if err != nil {
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat("useCloudFetch", params.Get("useCloudFetch"), "bool"), err)
}
ucfg.UseCloudFetch = useCloudFetch
}
params.Del("useCloudFetch")
if params.Has("maxDownloadThreads") {
numThreads, err := strconv.Atoi(params.Get("maxDownloadThreads"))
if err != nil {
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat("maxDownloadThreads", params.Get("maxDownloadThreads"), "int"), err)
}
ucfg.MaxDownloadThreads = numThreads
}
params.Del("maxDownloadThreads")

for k := range params {
if strings.ToLower(k) == "timezone" {
ucfg.Location, err = time.LoadLocation(params.Get("timezone"))
Expand Down Expand Up @@ -310,3 +335,37 @@ func (arrowConfig ArrowConfig) DeepCopy() ArrowConfig {
UseArrowNativeIntervalTypes: arrowConfig.UseArrowNativeIntervalTypes,
}
}

type CloudFetchConfig struct {
UseCloudFetch bool
MaxDownloadThreads int
MaxFilesInMemory int
MinTimeToExpiry time.Duration
}

func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
cfg.UseCloudFetch = false

if cfg.MaxDownloadThreads <= 0 {
cfg.MaxDownloadThreads = 10
}

if cfg.MaxFilesInMemory < 1 {
cfg.MaxFilesInMemory = 10
}

if cfg.MinTimeToExpiry < 0 {
cfg.MinTimeToExpiry = 0 * time.Second
}

return cfg
}

func (cfg CloudFetchConfig) DeepCopy() CloudFetchConfig {
return CloudFetchConfig{
UseCloudFetch: cfg.UseCloudFetch,
MaxDownloadThreads: cfg.MaxDownloadThreads,
MaxFilesInMemory: cfg.MaxFilesInMemory,
MinTimeToExpiry: cfg.MinTimeToExpiry,
}
}
Loading

0 comments on commit 4be8ace

Please sign in to comment.