Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge branch 'cloudfetch' into main #154

Merged
merged 6 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ You can set query timeout value by appending a `timeout` query parameter (in sec
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?timeout=1000&maxRows=1000
```
You can turn on Cloud Fetch to increase the performance of extracting large query results by fetching data in parallel via cloud storage (more info [here](https://www.databricks.com/blog/2021/08/11/how-we-achieved-high-bandwidth-connectivity-with-bi-tools.html)). To turn on Cloud Fetch, append `useCloudFetch=true`. You can also set the maximum number of goroutines that fetch data concurrently by setting the `maxDownloadThreads` query parameter (default is 10):
```
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?useCloudFetch=true&maxDownloadThreads=3
```

### Connecting with a new Connector

Expand Down
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
GetDirectResults: &cli_service.TSparkGetDirectResults{
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
}

if c.cfg.UseArrowBatches {
Expand All @@ -295,6 +296,10 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
}
}

if c.cfg.UseCloudFetch {
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
}

ctx = driverctx.NewContextWithConnId(ctx, c.id)
resp, err := c.client.ExecuteStatement(ctx, &req)

Expand Down
14 changes: 14 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,17 @@ func WithTransport(t http.RoundTripper) connOption {
c.Transport = t
}
}

// WithCloudFetch enables or disables cloud fetch for query execution.
// The default is false.
func WithCloudFetch(useCloudFetch bool) connOption {
	return func(cfg *config.Config) {
		cfg.UseCloudFetch = useCloudFetch
	}
}

// WithMaxDownloadThreads sets the maximum number of concurrent download
// workers used by cloud fetch. The default is 10.
func WithMaxDownloadThreads(numThreads int) connOption {
	return func(cfg *config.Config) {
		cfg.MaxDownloadThreads = numThreads
	}
}
99 changes: 61 additions & 38 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,33 @@ func TestNewConnector(t *testing.T) {
WithSessionParams(sessionParams),
WithRetries(10, 3*time.Second, 60*time.Second),
WithTransport(roundTripper),
WithCloudFetch(true),
WithMaxDownloadThreads(15),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: true,
MaxDownloadThreads: 15,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
QueryTimeout: timeout,
Catalog: catalog,
Schema: schema,
UserAgentEntry: userAgentEntry,
SessionParams: sessionParams,
RetryMax: 10,
RetryWaitMin: 3 * time.Second,
RetryWaitMax: 60 * time.Second,
Transport: roundTripper,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand All @@ -75,18 +84,25 @@ func TestNewConnector(t *testing.T) {
WithAccessToken(accessToken),
WithHTTPPath(httpPath),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: false,
MaxDownloadThreads: 10,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: 4,
RetryWaitMin: 1 * time.Second,
RetryWaitMax: 30 * time.Second,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.UserConfig = expectedUserConfig
Expand All @@ -109,18 +125,25 @@ func TestNewConnector(t *testing.T) {
WithHTTPPath(httpPath),
WithRetries(-1, 0, 0),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: false,
MaxDownloadThreads: 10,
MaxFilesInMemory: 10,
MinTimeToExpiry: 0 * time.Second,
}
expectedUserConfig := config.UserConfig{
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
Host: host,
Port: port,
Protocol: "https",
AccessToken: accessToken,
Authenticator: &pat.PATAuth{AccessToken: accessToken},
HTTPPath: "/" + httpPath,
MaxRows: maxRows,
SessionParams: sessionParams,
RetryMax: -1,
RetryWaitMin: 0,
RetryWaitMax: 0,
CloudFetchConfig: expectedCloudFetchConfig,
}
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
Expand Down
4 changes: 4 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Supported optional connection parameters can be specified in param=value and inc
- maxRows: Sets up the max rows fetched per request. Default is 100000
- timeout: Adds timeout (in seconds) for the server query execution. Default is no timeout
- userAgentEntry: Used to identify partners. Set as a string with format <isv-name+product-name>
- useCloudFetch: Used to enable cloud fetch for the query execution. Default is false
- maxDownloadThreads: Sets up the max number of concurrent workers for cloud fetch. Default is 10

Supported optional session parameters can be specified in param=value and include:

Expand Down Expand Up @@ -79,6 +81,8 @@ Supported functional options include:
- WithSessionParams(<params_map> map[string]string): Sets up session parameters including "timezone" and "ansi_mode". Optional
- WithTimeout(<timeout> Duration). Adds timeout (in time.Duration) for the server query execution. Default is no timeout. Optional
- WithUserAgentEntry(<isv-name+product-name> string). Used to identify partners. Optional
  - WithCloudFetch(<use_cloud_fetch> bool). Used to enable cloud fetch for the query execution. Default is false. Optional
  - WithMaxDownloadThreads(<num_threads> int). Sets up the max number of concurrent workers for cloud fetch. Default is 10. Optional

# Query cancellation and timeout

Expand Down
5 changes: 5 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ const (

// Execution error messages (query failure)
ErrQueryExecution = "failed to execute query"
ErrLinkExpired = "link expired"
)

// InvalidDSNFormat builds the error message reported when a DSN query
// parameter's value cannot be parsed as its expected type.
func InvalidDSNFormat(param string, value string, expected string) string {
	const msgFormat = "invalid DSN: param %s with value %s is not of type %s"
	return fmt.Sprintf(msgFormat, param, value, expected)
}

// ErrInvalidOperationState builds the error message for an operation state
// the driver does not expect to observe.
func ErrInvalidOperationState(state string) string {
	const msgFormat = "invalid operation state %s. This should not have happened"
	return fmt.Sprintf(msgFormat, state)
}
Expand Down
100 changes: 100 additions & 0 deletions examples/cloudfetch/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"os"
	"strconv"
	"time"

	dbsql "github.com/databricks/databricks-sql-go"
)

// row holds one record of the example stock_data table. Field order matches
// the column order scanned in runTest, and the struct is comparable so main
// can check rows from the two fetch modes with !=.
type row struct {
	symbol           string  // stock ticker symbol
	companyName      string
	industry         string
	date             string
	open             float64 // daily OHLC prices
	high             float64
	low              float64
	close            float64
	volume           int
	change           float64 // absolute daily change
	changePercentage float64 // daily change as a percentage
	upTrend          bool
	volatile         bool
}

// runTest opens a connection with cloud fetch enabled or disabled, executes
// query, and scans every result row into a slice so the caller can compare
// the two fetch paths. Connection settings come from DATABRICKS_* env vars.
func runTest(withCloudFetch bool, query string) ([]row, error) {
	port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
	if err != nil {
		return nil, fmt.Errorf("parsing DATABRICKS_PORT: %w", err)
	}

	connector, err := dbsql.NewConnector(
		dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
		dbsql.WithPort(port),
		dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
		dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
		dbsql.WithTimeout(10),
		dbsql.WithInitialNamespace("hive_metastore", "default"),
		dbsql.WithCloudFetch(withCloudFetch),
	)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := db.PingContext(ctx); err != nil {
		return nil, err
	}

	// BUG FIX: the original deferred rows.Close() before checking the query
	// error (nil-pointer panic when the query fails) and then returned the
	// stale, nil outer `err`, so a failed query yielded (nil, nil). Note that
	// QueryContext never returns sql.ErrNoRows (only QueryRow does), so that
	// check was dead code and is dropped.
	rows, err := db.QueryContext(context.Background(), query)
	if err != nil {
		return nil, fmt.Errorf("executing query: %w", err)
	}
	defer rows.Close()

	var res []row
	for rows.Next() {
		var r row
		if err := rows.Scan(&r.symbol, &r.companyName, &r.industry, &r.date, &r.open, &r.high, &r.low, &r.close, &r.volume, &r.change, &r.changePercentage, &r.upTrend, &r.volatile); err != nil {
			return nil, fmt.Errorf("scanning row: %w", err)
		}
		res = append(res, r)
	}
	// rows.Err reports errors hit during iteration (e.g. a failed cloud
	// fetch download); the original ignored it and could silently truncate.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}

// main fetches the same large result set twice — once via local arrow
// batches and once via cloud fetch — and verifies both paths return
// identical rows.
func main() {
	query := "select * from stock_data where date is not null and volume is not null order by date, symbol limit 10000000"

	// Local arrow batch
	abRes, err := runTest(false, query)
	if err != nil {
		log.Fatal(err)
	}

	// Cloud fetch batch
	cfRes, err := runTest(true, query)
	if err != nil {
		log.Fatal(err)
	}

	// BUG FIX: compare lengths first; the original indexed cfRes using
	// abRes' length and would panic if cloud fetch returned fewer rows.
	if len(abRes) != len(cfRes) {
		log.Fatalf("row count mismatch: local=%d cloudfetch=%d", len(abRes), len(cfRes))
	}
	for i := range abRes {
		if abRes[i] != cfRes[i] {
			// log.Fatalf replaces the original log.Fatal(fmt.Sprintf(...)).
			log.Fatalf("not equal for row: %d", i)
		}
	}
}