Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion experimental/aitools/cmd/discover_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ For each table, returns:
sess.Set(middlewares.DatabricksClientKey, w)
ctx = session.WithSession(ctx, sess)

warehouseID, err := middlewares.GetWarehouseID(ctx)
warehouseID, err := middlewares.GetWarehouseID(ctx, true)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion experimental/aitools/cmd/get_default_warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Returns warehouse ID of the default warehouse. Use --output json to get the full
sess.Set(middlewares.DatabricksClientKey, w)
ctx = session.WithSession(ctx, sess)

warehouse, err := middlewares.GetWarehouseEndpoint(ctx)
warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion experimental/aitools/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Output includes the query results as JSON and row count.`,
sess.Set(middlewares.DatabricksClientKey, w)
ctx = session.WithSession(ctx, sess)

warehouseID, err := middlewares.GetWarehouseID(ctx)
warehouseID, err := middlewares.GetWarehouseID(ctx, true)
if err != nil {
return err
}
Expand Down
68 changes: 64 additions & 4 deletions experimental/aitools/lib/middlewares/warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/databricks/cli/experimental/aitools/lib/session"
"github.com/databricks/cli/libs/databrickscfg/cfgpickers"
"github.com/databricks/cli/libs/env"
"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/sql"
)

Expand Down Expand Up @@ -39,7 +40,9 @@ func loadWarehouseInBackground(ctx context.Context) {
sess.Set("warehouse_endpoint", warehouse)
}

func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) {
// GetWarehouseEndpoint returns the resolved warehouse endpoint.
// If autoStart is true and the warehouse is stopped, it will be started automatically.
func GetWarehouseEndpoint(ctx context.Context, autoStart bool) (*sql.EndpointInfo, error) {
sess, err := session.GetSession(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -68,23 +71,62 @@ func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) {
sess.Set("warehouse_endpoint", warehouse)
}

return warehouse.(*sql.EndpointInfo), nil
endpoint := warehouse.(*sql.EndpointInfo)

if autoStart && (endpoint.State == sql.StateStopped || endpoint.State == sql.StateStopping) {
endpoint, err = startWarehouse(ctx, endpoint.Id)
if err != nil {
return nil, err
}
sess.Set("warehouse_endpoint", endpoint)
}

return endpoint, nil
}

func GetWarehouseID(ctx context.Context) (string, error) {
warehouse, err := GetWarehouseEndpoint(ctx)
// GetWarehouseID returns the resolved warehouse ID.
// If autoStart is true and the warehouse is stopped, it will be started automatically.
func GetWarehouseID(ctx context.Context, autoStart bool) (string, error) {
warehouse, err := GetWarehouseEndpoint(ctx, autoStart)
if err != nil {
return "", err
}
return warehouse.Id, nil
}

func startWarehouse(ctx context.Context, id string) (*sql.EndpointInfo, error) {
w, err := GetDatabricksClient(ctx)
if err != nil {
return nil, fmt.Errorf("get databricks client: %w", err)
}
wait, err := w.Warehouses.Start(ctx, sql.StartRequest{Id: id})
if err != nil {
return nil, fmt.Errorf("start warehouse %s: %w", id, err)
}
resp, err := wait.Get()
if err != nil {
return nil, fmt.Errorf("wait for warehouse %s to start: %w", id, err)
}
return &sql.EndpointInfo{
Id: resp.Id,
Name: resp.Name,
State: resp.State,
}, nil
}

func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
w, err := GetDatabricksClient(ctx)
if err != nil {
return nil, fmt.Errorf("get databricks client: %w", err)
}
return resolveWarehouse(ctx, w)
}

// resolveWarehouse selects a warehouse using the following priority:
// 1. DATABRICKS_WAREHOUSE_ID env var
// 2. User's default warehouse override (CUSTOM type only)
// 3. Server-side default / first usable warehouse by state
func resolveWarehouse(ctx context.Context, w *databricks.WorkspaceClient) (*sql.EndpointInfo, error) {
// first resolve DATABRICKS_WAREHOUSE_ID env variable
warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID")
if warehouseID != "" {
Expand All @@ -101,5 +143,23 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) {
}, nil
}

// Check user's default warehouse override (set via the SQL UI or CLI).
// Only CUSTOM overrides are used; LAST_SELECTED requires UI state we don't have.
override, err := w.Warehouses.GetDefaultWarehouseOverride(ctx, sql.GetDefaultWarehouseOverrideRequest{
Name: "default-warehouse-overrides/me",
})
if err == nil && override.Type == sql.DefaultWarehouseOverrideTypeCustom && override.WarehouseId != "" {
warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{
Id: override.WarehouseId,
})
if err == nil && warehouse.State != sql.StateDeleted && warehouse.State != sql.StateDeleting {
return &sql.EndpointInfo{
Id: warehouse.Id,
Name: warehouse.Name,
State: warehouse.State,
}, nil
}
}

return cfgpickers.GetDefaultWarehouse(ctx, w)
}
2 changes: 1 addition & 1 deletion experimental/aitools/lib/providers/clitools/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// Discover provides workspace context and workflow guidance.
// Returns L1 (flow) always + L2 (target) for detected target types + L3 (skills) listing.
func Discover(ctx context.Context, workingDirectory string) (string, error) {
warehouse, err := middlewares.GetWarehouseEndpoint(ctx)
warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false)
if err != nil {
log.Debugf(ctx, "Failed to get default warehouse (non-fatal): %v", err)
warehouse = nil
Expand Down