diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index e611764830..9d30e367c0 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -50,7 +50,7 @@ For each table, returns: sess.Set(middlewares.DatabricksClientKey, w) ctx = session.WithSession(ctx, sess) - warehouseID, err := middlewares.GetWarehouseID(ctx) + warehouseID, err := middlewares.GetWarehouseID(ctx, true) if err != nil { return err } diff --git a/experimental/aitools/cmd/get_default_warehouse.go b/experimental/aitools/cmd/get_default_warehouse.go index 010539d5de..7c8da37b98 100644 --- a/experimental/aitools/cmd/get_default_warehouse.go +++ b/experimental/aitools/cmd/get_default_warehouse.go @@ -43,7 +43,7 @@ Returns warehouse ID of the default warehouse. Use --output json to get the full sess.Set(middlewares.DatabricksClientKey, w) ctx = session.WithSession(ctx, sess) - warehouse, err := middlewares.GetWarehouseEndpoint(ctx) + warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false) if err != nil { return err } diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index d1075dfd9f..23fc3c5f2d 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -41,7 +41,7 @@ Output includes the query results as JSON and row count.`, sess.Set(middlewares.DatabricksClientKey, w) ctx = session.WithSession(ctx, sess) - warehouseID, err := middlewares.GetWarehouseID(ctx) + warehouseID, err := middlewares.GetWarehouseID(ctx, true) if err != nil { return err } diff --git a/experimental/aitools/lib/middlewares/warehouse.go b/experimental/aitools/lib/middlewares/warehouse.go index 78a62d793a..5ee3cd1387 100644 --- a/experimental/aitools/lib/middlewares/warehouse.go +++ b/experimental/aitools/lib/middlewares/warehouse.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/experimental/aitools/lib/session" "github.com/databricks/cli/libs/databrickscfg/cfgpickers" "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/sql" ) @@ -39,7 +40,9 @@ func loadWarehouseInBackground(ctx context.Context) { sess.Set("warehouse_endpoint", warehouse) } -func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) { +// GetWarehouseEndpoint returns the resolved warehouse endpoint. +// If autoStart is true and the warehouse is stopped, it will be started automatically. +func GetWarehouseEndpoint(ctx context.Context, autoStart bool) (*sql.EndpointInfo, error) { sess, err := session.GetSession(ctx) if err != nil { return nil, err @@ -68,23 +71,62 @@ func GetWarehouseEndpoint(ctx context.Context) (*sql.EndpointInfo, error) { sess.Set("warehouse_endpoint", warehouse) } - return warehouse.(*sql.EndpointInfo), nil + endpoint := warehouse.(*sql.EndpointInfo) + + if autoStart && (endpoint.State == sql.StateStopped || endpoint.State == sql.StateStopping) { + endpoint, err = startWarehouse(ctx, endpoint.Id) + if err != nil { + return nil, err + } + sess.Set("warehouse_endpoint", endpoint) + } + + return endpoint, nil } -func GetWarehouseID(ctx context.Context) (string, error) { - warehouse, err := GetWarehouseEndpoint(ctx) +// GetWarehouseID returns the resolved warehouse ID. +// If autoStart is true and the warehouse is stopped, it will be started automatically. +func GetWarehouseID(ctx context.Context, autoStart bool) (string, error) { + warehouse, err := GetWarehouseEndpoint(ctx, autoStart) if err != nil { return "", err } return warehouse.Id, nil } +func startWarehouse(ctx context.Context, id string) (*sql.EndpointInfo, error) { + w, err := GetDatabricksClient(ctx) + if err != nil { + return nil, fmt.Errorf("get databricks client: %w", err) + } + wait, err := w.Warehouses.Start(ctx, sql.StartRequest{Id: id}) + if err != nil { + return nil, fmt.Errorf("start warehouse %s: %w", id, err) + } + resp, err := wait.Get() + if err != nil { + return nil, fmt.Errorf("wait for warehouse %s to start: %w", id, err) + } + return &sql.EndpointInfo{ + Id: resp.Id, + Name: resp.Name, + State: resp.State, + }, nil +} + func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { w, err := GetDatabricksClient(ctx) if err != nil { return nil, fmt.Errorf("get databricks client: %w", err) } + return resolveWarehouse(ctx, w) +} +// resolveWarehouse selects a warehouse using the following priority: +// 1. DATABRICKS_WAREHOUSE_ID env var +// 2. User's default warehouse override (CUSTOM type only) +// 3. Server-side default / first usable warehouse by state +func resolveWarehouse(ctx context.Context, w *databricks.WorkspaceClient) (*sql.EndpointInfo, error) { // first resolve DATABRICKS_WAREHOUSE_ID env variable warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID") if warehouseID != "" { @@ -101,5 +143,23 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { }, nil } + // Check user's default warehouse override (set via the SQL UI or CLI). + // Only CUSTOM overrides are used; LAST_SELECTED requires UI state we don't have. + override, err := w.Warehouses.GetDefaultWarehouseOverride(ctx, sql.GetDefaultWarehouseOverrideRequest{ + Name: "default-warehouse-overrides/me", + }) + if err == nil && override.Type == sql.DefaultWarehouseOverrideTypeCustom && override.WarehouseId != "" { + warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{ + Id: override.WarehouseId, + }) + if err == nil && warehouse.State != sql.StateDeleted && warehouse.State != sql.StateDeleting { + return &sql.EndpointInfo{ + Id: warehouse.Id, + Name: warehouse.Name, + State: warehouse.State, + }, nil + } + } + return cfgpickers.GetDefaultWarehouse(ctx, w) } diff --git a/experimental/aitools/lib/providers/clitools/discover.go b/experimental/aitools/lib/providers/clitools/discover.go index 7b67b2a6fa..5c69ba16a7 100644 --- a/experimental/aitools/lib/providers/clitools/discover.go +++ b/experimental/aitools/lib/providers/clitools/discover.go @@ -16,7 +16,7 @@ import ( // Discover provides workspace context and workflow guidance. // Returns L1 (flow) always + L2 (target) for detected target types + L3 (skills) listing. func Discover(ctx context.Context, workingDirectory string) (string, error) { - warehouse, err := middlewares.GetWarehouseEndpoint(ctx) + warehouse, err := middlewares.GetWarehouseEndpoint(ctx, false) if err != nil { log.Debugf(ctx, "Failed to get default warehouse (non-fatal): %v", err) warehouse = nil