diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md index 571136538c..ec12ed10f7 100644 --- a/experimental/aitools/README.md +++ b/experimental/aitools/README.md @@ -10,12 +10,41 @@ Current commands: - `databricks experimental aitools tools query` - `databricks experimental aitools tools discover-schema` - `databricks experimental aitools tools get-default-warehouse` +- `databricks experimental aitools tools statement submit` +- `databricks experimental aitools tools statement get` +- `databricks experimental aitools tools statement status` +- `databricks experimental aitools tools statement cancel` Current behavior: - `skills install` installs Databricks skills for detected coding agents. - `install` is a compatibility alias for `skills install`. - `tools` exposes a small set of AI-oriented workspace helpers. +- `tools query` accepts a single SQL or multiple SQLs in one invocation. Pass + several positional arguments and/or repeat `--file` to run them in parallel + against the warehouse. Multi-query output is always JSON; control parallelism + with `--concurrency` (default 8). + + ```bash + databricks experimental aitools tools query \ + --warehouse <warehouse-id> --output json \ + "SELECT count(*) FROM samples.nyctaxi.trips" \ + "SELECT min(tpep_pickup_datetime), max(tpep_pickup_datetime) FROM samples.nyctaxi.trips" \ + "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1" + ``` + +- `tools statement` is a low-level lifecycle for asynchronous statements. + `submit` returns a `statement_id` immediately, `get` polls until terminal + and emits rows, `status` peeks without blocking, and `cancel` requests + termination. Ctrl+C on `get` stops polling but does NOT cancel the + server-side statement; use `cancel` for that. 
+ + ```bash + SID=$(databricks experimental aitools tools statement submit \ + --warehouse <warehouse-id> "SELECT pg_sleep(5)" | jq -r '.statement_id') + databricks experimental aitools tools statement status "$SID" + databricks experimental aitools tools statement get "$SID" + ``` Removed behavior: diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go new file mode 100644 index 0000000000..38ecea531e --- /dev/null +++ b/experimental/aitools/cmd/batch.go @@ -0,0 +1,215 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/service/sql" + "golang.org/x/sync/errgroup" +) + +// defaultBatchConcurrency caps in-flight statements when --concurrency is unset. +// Matches the default used by cmd/fs/cp.go for similar fan-out work. +const defaultBatchConcurrency = 8 + +// errInvalidBatchConcurrency is returned when --concurrency is set to a value +// that errgroup.SetLimit can't honor (0 deadlocks, negative removes the cap). +var errInvalidBatchConcurrency = errors.New("--concurrency must be at least 1") + +// batchResult is the per-statement payload emitted in batch mode JSON output. +// State is the server-reported terminal state. Error is set whenever the +// statement did not produce usable rows, regardless of state, so consumers +// can branch on `error == null` alone. +type batchResult struct { + SQL string `json:"sql"` + StatementID string `json:"statement_id,omitempty"` + State sql.StatementState `json:"state,omitempty"` + ElapsedMs int64 `json:"elapsed_ms"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +// batchResultError captures user-visible error info for a failed statement. 
+type batchResultError struct { + Message string `json:"message"` + ErrorCode string `json:"error_code,omitempty"` +} + +// executeBatch submits sqls against the warehouse in parallel, polls each to +// completion, and returns one batchResult per input in input order. +// +// Individual statement failures do not abort siblings; failures are encoded in +// the per-result Error field so callers can render partial results. +// +// On context cancellation (Ctrl+C or parent context), still-running statements +// are cancelled server-side via CancelExecution. Statements that finished +// before cancellation are left as-is. +func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult { + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling %d in-flight queries", len(sqls)) + pollCancel() + case <-pollCtx.Done(): + } + }() + + sp := cmdio.NewSpinner(pollCtx) + defer sp.Close() + sp.Update(fmt.Sprintf("Executing %d queries...", len(sqls))) + + var completed atomic.Int64 + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + go func() { + for { + select { + case <-pollCtx.Done(): + return + case <-ticker.C: + sp.Update(fmt.Sprintf("Executing %d queries... (%d/%d done)", len(sqls), completed.Load(), len(sqls))) + } + } + }() + + results := make([]batchResult, len(sqls)) + // Each goroutine writes to a distinct slot, safe without a mutex. + // We read after g.Wait(), establishing happens-before for all writes. 
+ statementIDs := make([]string, len(sqls)) + + g := new(errgroup.Group) + g.SetLimit(concurrency) + for i, sqlStr := range sqls { + g.Go(func() error { + results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i) + completed.Add(1) + return nil + }) + } + _ = g.Wait() + + // pollStatement is a pure helper that returns ctx.Err() on cancellation + // without touching the server. Sweep any not-yet-terminal statements here. + if pollCtx.Err() != nil { + cancelInFlight(ctx, api, statementIDs, results) + } + + return results +} + +// runOneBatchQuery submits one SQL, polls to completion, and returns its +// batchResult. All errors are encoded into the result; never returns an error. +func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult { + start := time.Now() + result := batchResult{SQL: sqlStr} + + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: sqlStr, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "submission cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + statementIDs[idx] = resp.StatementId + result.StatementID = resp.StatementId + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: err.Error()} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + if pollResp.Status != 
nil { + result.State = pollResp.Status.State + } + + if result.State != sql.StatementStateSucceeded { + result.Error = &batchResultError{} + if pollResp.Status != nil && pollResp.Status.Error != nil { + result.Error.Message = pollResp.Status.Error.Message + result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode) + } else { + result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State) + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + result.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)} + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + result.Rows = rows + result.ElapsedMs = time.Since(start).Milliseconds() + return result +} + +// cancelInFlight sends CancelExecution for every statement that didn't reach +// a terminal state server-side before context cancellation. Best effort: errors +// are logged at warn but don't fail the batch. +func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) { + var cancelled int + for i, sid := range statementIDs { + if sid == "" { + continue + } + switch results[i].State { + case sql.StatementStateSucceeded, sql.StatementStateFailed, sql.StatementStateClosed: + continue + case sql.StatementStateCanceled, sql.StatementStatePending, sql.StatementStateRunning: + // Either still running server-side, or our internal "canceled" + // marker meaning the goroutine bailed without telling the server. + // Either way, send CancelExecution. + } + // Detach from the inbound ctx (which is typically already cancelled by + // the time we reach this sweep): WithoutCancel keeps the caller's + // values but drops the cancellation signal so the cancel RPC actually + // reaches the warehouse instead of short-circuiting on ctx.Err(). 
+ cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) + } + cancel() + cancelled++ + } + if cancelled > 0 { + cmdio.LogString(ctx, fmt.Sprintf("Cancelled %d in-flight queries.", cancelled)) + } +} diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go new file mode 100644 index 0000000000..f6f468768f --- /dev/null +++ b/experimental/aitools/cmd/batch_test.go @@ -0,0 +1,243 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestRenderBatchJSON(t *testing.T) { + results := []batchResult{ + { + SQL: "SELECT 1", + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + ElapsedMs: 42, + Columns: []string{"n"}, + Rows: [][]string{{"1"}}, + }, + { + SQL: "SELECT bad_syntax", + StatementID: "stmt-2", + State: sql.StatementStateFailed, + ElapsedMs: 12, + Error: &batchResultError{ + Message: "near 'bad_syntax': syntax error", + ErrorCode: "SYNTAX_ERROR", + }, + }, + } + + var buf strings.Builder + err := renderBatchJSON(&buf, results) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, `"sql": "SELECT 1"`) + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "SUCCEEDED"`) + assert.Contains(t, output, `"elapsed_ms": 42`) + assert.Contains(t, output, `"columns": [`) + assert.Contains(t, output, `"rows": [`) + assert.Contains(t, output, `"sql": "SELECT bad_syntax"`) + assert.Contains(t, output, `"error": {`) + 
assert.Contains(t, output, `"error_code": "SYNTAX_ERROR"`) + // Trailing newline. + assert.True(t, strings.HasSuffix(output, "\n")) +} + +func TestExecuteBatchAllSucceed(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + sqls := []string{"SELECT 1", "SELECT 2", "SELECT 3"} + for i, sqlStr := range sqls { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{strconv.Itoa(i + 1)}}}, + }, nil).Once() + } + + results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d sql", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d state", i) + assert.Nil(t, r.Error, "result %d error", i) + assert.Equal(t, []string{"n"}, r.Columns, "result %d columns", i) + assert.Equal(t, [][]string{{strconv.Itoa(i + 1)}}, r.Rows, "result %d rows", i) + assert.NotEmpty(t, r.StatementID, "result %d statement_id", i) + } +} + +func TestExecuteBatchPartialFailure(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, 
nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT bad" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-bad", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + assert.Equal(t, sql.StatementStateSucceeded, results[0].State) + + require.NotNil(t, results[1].Error) + assert.Equal(t, sql.StatementStateFailed, results[1].State) + assert.Equal(t, "SYNTAX_ERROR", results[1].Error.ErrorCode) + assert.Contains(t, results[1].Error.Message, "syntax error") +} + +func TestExecuteBatchSubmissionFailure(t *testing.T) { + // ExecuteStatement transport error is encoded into the per-result error, + // not propagated up to abort siblings. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT good" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT broken" + })).Return(nil, errors.New("network unreachable")).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + require.NotNil(t, results[1].Error) + assert.Contains(t, results[1].Error.Message, "execute statement") + assert.Contains(t, results[1].Error.Message, "network unreachable") + assert.Empty(t, results[1].StatementID) +} + +func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) { + // Guards against a silent SDK default flip from CONTINUE to CANCEL. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WaitTimeout == "0s" && req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-x", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Times(2) + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8) + require.Len(t, results, 2) +} + +func TestExecuteBatchPreservesInputOrder(t *testing.T) { + // Index 0 is slow (PENDING then SUCCEEDED on first poll); 1 and 2 are + // immediate. Despite the staggered completion, results stay in input order. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 'slow'" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-slow").Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + for i, sqlStr := range []string{"SELECT 'fast1'", "SELECT 'fast2'"} { + sid := fmt.Sprintf("stmt-fast-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + } + + sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"} + results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d", i) + } +} + +func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { + // All statements are PENDING when the context is cancelled. cancelInFlight + // sweeps the in-flight set with CancelExecution. Each cancel RPC must + // carry a NON-cancelled context, otherwise the SDK short-circuits on + // ctx.Err() and never reaches the warehouse. 
+ ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + aliveCtx := mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }) + + for i, sqlStr := range []string{"q1", "q2", "q3"} { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().CancelExecution(aliveCtx, sql.CancelExecutionRequest{ + StatementId: sid, + }).Return(nil).Once() + } + + cancel() + + results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sql.StatementStateCanceled, r.State, "result %d state", i) + require.NotNil(t, r.Error, "result %d error", i) + } +} diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index fad77cd4d1..091222368d 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -4,22 +4,96 @@ import ( "context" "errors" "fmt" + "os" + "os/signal" "regexp" + "slices" "strings" + "sync" + "syscall" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/experimental/aitools/lib/middlewares" "github.com/databricks/cli/experimental/aitools/lib/session" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" dbsql "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) var sqlIdentifierRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) +// sqlGate caps in-flight SQL statements globally and records each statement_id +// so a Ctrl+C sweep can cancel 
anything still running server-side. The gate's +// concurrency limit applies across all probes (DESCRIBE, sample SELECT, null +// counts) and across all tables, so --concurrency really means "max statements +// in flight," not "max tables in flight." +type sqlGate struct { + sem chan struct{} + mu sync.Mutex + ids []string +} + +func newSQLGate(limit int) *sqlGate { + return &sqlGate{sem: make(chan struct{}, limit)} +} + +// run executes a SQL statement asynchronously, polls until terminal, and +// records the statement_id so it can be cancelled if the parent context is +// cancelled. Acquires a slot from the gate before submitting and releases it +// when polling completes (or the caller's context is cancelled). +func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { + // If the caller cancelled before we even tried, don't enter the select: + // when the gate has free slots both cases are ready and Go picks one + // pseudo-randomly. Without this early-out we'd occasionally submit a + // statement under a cancelled context. + if err := ctx.Err(); err != nil { + return nil, err + } + select { + case g.sem <- struct{}{}: + defer func() { <-g.sem }() + case <-ctx.Done(): + return nil, ctx.Err() + } + + resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + return nil, fmt.Errorf("execute statement: %w", err) + } + + g.mu.Lock() + g.ids = append(g.ids, resp.StatementId) + g.mu.Unlock() + + pollResp, err := pollStatement(ctx, w.StatementExecution, resp) + if err != nil { + return nil, err + } + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// trackedIDs returns a snapshot of statement_ids submitted through this gate. 
+func (g *sqlGate) trackedIDs() []string { + g.mu.Lock() + defer g.mu.Unlock() + return slices.Clone(g.ids) +} + func newDiscoverSchemaCmd() *cobra.Command { + var concurrency int + cmd := &cobra.Command{ Use: "discover-schema TABLE...", Short: "Discover schema for one or more tables", @@ -31,21 +105,33 @@ For each table, returns: - Column names and types - Sample data (5 rows) - Null counts per column -- Total row count`, +- Total row count + +Tables and probes (DESCRIBE, sample SELECT, null counts) all share a +single warehouse-statement budget. --concurrency (default 8) caps the +total number of statements in flight at any moment, regardless of how +many tables you pass in. + +On Ctrl+C, in-flight statements are cancelled server-side via +CancelExecution before the command exits.`, Example: ` databricks experimental aitools tools discover-schema samples.nyctaxi.trips databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2`, - Args: cobra.MinimumNArgs(1), - PreRunE: root.MustWorkspaceClient, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - w := cmdctx.WorkspaceClient(ctx) - - // validate table names: each part must be a safe SQL identifier + Args: cobra.MinimumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + // Reject malformed identifiers before any auth/profile work. 
for _, table := range args { if _, err := quoteTableName(table); err != nil { return err } } + return root.MustWorkspaceClient(cmd, args) + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) // set up session with client for middleware compatibility sess := session.NewSession() @@ -57,13 +143,43 @@ For each table, returns: return err } - var results []string - for _, table := range args { - result, err := discoverTable(ctx, w, warehouseID, table) - if err != nil { - result = fmt.Sprintf("Error discovering %s: %v", table, err) + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling in-flight discover-schema statements") + pollCancel() + case <-pollCtx.Done(): } - results = append(results, result) + }() + + gate := newSQLGate(concurrency) + + results := make([]string, len(args)) + g := new(errgroup.Group) + for i, table := range args { + g.Go(func() error { + result, err := discoverTable(pollCtx, gate, w, warehouseID, table) + if err != nil { + results[i] = fmt.Sprintf("Error discovering %s: %v", table, err) + } else { + results[i] = result + } + // A failure on one table shouldn't abort the others. 
+ return nil + }) + } + _ = g.Wait() + + if pollCtx.Err() != nil { + cancelDiscoverInFlight(ctx, w.StatementExecution, gate.trackedIDs()) + return root.ErrAlreadyPrinted } // format output with dividers for multiple tables @@ -90,20 +206,39 @@ For each table, returns: }, } + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum SQL statements in flight at once across all tables and probes") + return cmd } -func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) { - var sb strings.Builder +// cancelDiscoverInFlight sends CancelExecution for every recorded statement_id. +// Best effort: errors are logged but don't fail the user-visible exit. +// Statements that already finished server-side return an error which we just +// swallow at warn level; the alternative (per-statement state tracking) isn't +// worth the bookkeeping here. +func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInterface, ids []string) { + if len(ids) == 0 { + cmdio.LogString(ctx, "discover-schema cancelled.") + return + } + for _, id := range ids { + cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err) + } + cancel() + } + cmdio.LogString(ctx, fmt.Sprintf("discover-schema cancelled; sent CancelExecution for %d statement(s).", len(ids))) +} +func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) { quoted, err := quoteTableName(table) if err != nil { return "", err } // 1. 
describe table - get columns and types - describeSQL := "DESCRIBE TABLE " + quoted - descResp, err := executeSQL(ctx, w, warehouseID, describeSQL) + descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted) if err != nil { return "", fmt.Errorf("describe table: %w", err) } @@ -113,32 +248,55 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse return "", errors.New("no columns found") } + // 2 + 3. Sample data and null counts run in parallel; both depend only on + // the column list (already known) and not on each other. The gate (not + // errgroup) is what actually limits warehouse concurrency. + sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted) + + nullCountExprs := make([]string, len(columns)) + for i, col := range columns { + // Backticks inside an identifier are escaped by doubling them in + // Databricks/Delta SQL (`` ` `` → `` `` ``). Without this, a column + // name containing a backtick would terminate the quoted identifier + // mid-string and produce a PARSE_SYNTAX_ERROR. Sample-data uses + // SELECT * so the failure shows up only as a confusing + // "NULL COUNTS: Error - ..." line in the user-facing output. + escaped := strings.ReplaceAll(col, "`", "``") + nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", escaped, escaped) + } + nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", + strings.Join(nullCountExprs, ", "), quoted) + + var sampleResp, nullResp *dbsql.StatementResponse + var sampleErr, nullErr error + + g := new(errgroup.Group) + g.Go(func() error { + sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL) + return nil + }) + g.Go(func() error { + nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL) + return nil + }) + _ = g.Wait() + + // Assemble the output in the established order: columns, sample, null counts. 
+ var sb strings.Builder sb.WriteString("COLUMNS:\n") for i, col := range columns { fmt.Fprintf(&sb, " %s: %s\n", col, types[i]) } - // 2. sample data (5 rows) - sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted) - sampleResp, err := executeSQL(ctx, w, warehouseID, sampleSQL) - if err != nil { - fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", err) + if sampleErr != nil { + fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr) } else { sb.WriteString("\nSAMPLE DATA:\n") sb.WriteString(formatTableData(sampleResp)) } - // 3. null counts per column - nullCountExprs := make([]string, len(columns)) - for i, col := range columns { - nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", col, col) - } - nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", - strings.Join(nullCountExprs, ", "), quoted) - - nullResp, err := executeSQL(ctx, w, warehouseID, nullSQL) - if err != nil { - fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", err) + if nullErr != nil { + fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr) } else { sb.WriteString("\nNULL COUNTS:\n") sb.WriteString(formatNullCounts(nullResp, columns)) @@ -147,27 +305,6 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse return sb.String(), nil } -func executeSQL(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { - resp, err := w.StatementExecution.ExecuteAndWait(ctx, dbsql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "50s", - }) - if err != nil { - return nil, err - } - - if resp.Status != nil && resp.Status.State == dbsql.StatementStateFailed { - errMsg := "query failed" - if resp.Status.Error != nil { - errMsg = resp.Status.Error.Message - } - return nil, errors.New(errMsg) - } - - return resp, nil -} - func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) { if resp.Result == 
nil || resp.Result.DataArray == nil { return nil, nil diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go new file mode 100644 index 0000000000..fe6d86e799 --- /dev/null +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -0,0 +1,326 @@ +package aitools + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + dbsql "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestQuoteTableName(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErr string + }{ + {"valid", "main.public.orders", "`main`.`public`.`orders`", ""}, + {"underscores ok", "_a.b_c.d_e", "`_a`.`b_c`.`d_e`", ""}, + {"missing parts", "public.orders", "", "expected CATALOG.SCHEMA.TABLE"}, + {"too many parts", "a.b.c.d", "", "expected CATALOG.SCHEMA.TABLE"}, + {"injection in catalog", "a;DROP--.b.c", "", "invalid SQL identifier"}, + {"backtick in name", "a.b.c`d", "", "invalid SQL identifier"}, + {"empty part", "a..c", "", "invalid SQL identifier"}, + {"starts with digit", "1main.public.orders", "", "invalid SQL identifier"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := quoteTableName(tc.in) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestParseDescribeResultSkipsMetadataRows(t *testing.T) { + resp := &dbsql.StatementResponse{ + Result: &dbsql.ResultData{DataArray: [][]string{ + {"id", "BIGINT", ""}, + {"name", "STRING", ""}, + {"# Partition Information", "", ""}, + {"region", "STRING", ""}, + {"", 
"STRING", ""}, + }}, + } + + cols, types := parseDescribeResult(resp) + assert.Equal(t, []string{"id", "name", "region"}, cols) + assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types) +} + +func TestSQLGateRunPinsOnWaitTimeoutAndRecordsID(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == dbsql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-1", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(2) + + resp, err := gate.run(ctx, w, "wh-1", "SELECT 1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", resp.StatementId) + assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs()) +} + +func TestSQLGateRunPropagatesFailedState(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&dbsql.StatementResponse{ + StatementId: "stmt-1", + Status: &dbsql.StatementStatus{ + State: dbsql.StatementStateFailed, + Error: &dbsql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'oops'"}, + }, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(2) + + _, err := gate.run(ctx, w, "wh-1", "SELECT oops") + require.Error(t, err) + assert.Contains(t, err.Error(), "SYNTAX_ERROR") + // Even on failure, the id is recorded so a cancellation sweep can clean up. 
+ assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs()) +} + +func TestSQLGateRunWrapsTransportError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). + Return(nil, errors.New("network unreachable")).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(2) + + _, err := gate.run(ctx, w, "wh-1", "SELECT 1") + require.Error(t, err) + assert.Contains(t, err.Error(), "execute statement") + assert.Contains(t, err.Error(), "network unreachable") + assert.Empty(t, gate.trackedIDs(), "no id should be recorded when ExecuteStatement fails") +} + +func TestSQLGateRunRespectsCancelledContext(t *testing.T) { + // With ctx already cancelled, gate.run must not call any API method: + // it bails at the semaphore-acquire select. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + cancel() + + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(2) + + _, err := gate.run(ctx, w, "wh-1", "SELECT 1") + require.ErrorIs(t, err, context.Canceled) +} + +func TestDiscoverTableRunsSampleAndNullsConcurrently(t *testing.T) { + // Deterministic barrier: both probes must enter before either is allowed + // to leave. If gate.run/discoverTable serialized them, the first probe + // would time out and return an error, which would surface as + // "SAMPLE DATA: Error - " or "NULL COUNTS: Error - " in the output. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{ + {"id", "BIGINT", ""}, + {"name", "STRING", ""}, + }}, + }, nil).Once() + + const numProbes = 2 + var dispatched atomic.Int32 + release := make(chan struct{}) + closeRelease := sync.OnceFunc(func() { close(release) }) + + probe := func(ctx context.Context, req dbsql.ExecuteStatementRequest) (*dbsql.StatementResponse, error) { + if dispatched.Add(1) == numProbes { + closeRelease() + } + select { + case <-release: + case <-time.After(2 * time.Second): + return nil, errors.New("probe timeout: not running concurrently") + } + return &dbsql.StatementResponse{ + StatementId: "stmt-probe-" + req.Statement[:7], + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "x"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"0"}}}, + }, nil + } + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).RunAndReturn(probe).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "SUM(CASE WHEN") + })).RunAndReturn(probe).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(8) + + out, err := discoverTable(ctx, gate, w, "wh-1", "main.public.orders") + require.NoError(t, err) + assert.Equal(t, int32(numProbes), dispatched.Load(), "both probes should 
have entered concurrently") + assert.NotContains(t, out, "Error - ", "no probe should have surfaced an error") + assert.Contains(t, out, "COLUMNS:") + assert.Contains(t, out, "SAMPLE DATA:") + assert.Contains(t, out, "NULL COUNTS:") +} + +func TestDiscoverTableSampleErrorDoesNotAbortNullCounts(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{{"id", "BIGINT", ""}}}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-sample", + Status: &dbsql.StatementStatus{ + State: dbsql.StatementStateFailed, + Error: &dbsql.ServiceError{ErrorCode: "PERM", Message: "permission denied"}, + }, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "SUM(CASE WHEN") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-null", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "id_nulls"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"100", "0"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(8) + + out, err := discoverTable(ctx, gate, w, "wh-1", "main.public.orders") + require.NoError(t, err) + assert.Contains(t, out, "SAMPLE DATA: Error - ") + 
assert.Contains(t, out, "permission denied") + assert.Contains(t, out, "NULL COUNTS:") + assert.Contains(t, out, "total_rows: 100") +} + +func TestDiscoverTableEscapesBackticksInColumnNames(t *testing.T) { + // Databricks/Delta DDL allows backticks in column names via doubled- + // backtick escaping (e.g. CREATE TABLE t (`weird``col` STRING)). Without + // escaping in the null-counts SQL the embedded backtick would terminate + // the quoted identifier mid-string and produce a PARSE_SYNTAX_ERROR. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{ + {"weird`col", "STRING", ""}, + }}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-sample", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + }, nil).Once() + + // Null-counts SQL must escape the embedded backtick. Both the identifier + // and the alias positions must use the doubled form. 
+ mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "`weird``col`") && + strings.Contains(req.Statement, "`weird``col_nulls`") && + !strings.Contains(req.Statement, "`weird`col`") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-null", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "weird`col_nulls"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"5", "0"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(8) + + out, err := discoverTable(ctx, gate, w, "wh-1", "main.public.orders") + require.NoError(t, err) + assert.Contains(t, out, "weird`col") + assert.NotContains(t, out, "Error - ") +} + +func TestCancelDiscoverInFlightCallsAPIPerID(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + for _, id := range []string{"stmt-a", "stmt-b", "stmt-c"} { + mockAPI.EXPECT().CancelExecution(mock.Anything, dbsql.CancelExecutionRequest{ + StatementId: id, + }).Return(nil).Once() + } + + cancelDiscoverInFlight(ctx, mockAPI, []string{"stmt-a", "stmt-b", "stmt-c"}) +} + +func TestDiscoverSchemaConcurrencyRejection(t *testing.T) { + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"--concurrency", value, "main.public.orders"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } +} + +func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing.T) { + // PreRunE rejects malformed identifiers before MustWorkspaceClient runs, + // so the test passes without any workspace mocking. 
+ cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"not-three-parts"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE") +} diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7b95fdd4e2..7e9ae1d030 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -75,32 +75,47 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup func newQueryCmd() *cobra.Command { var warehouseID string - var filePath string + var filePaths []string var outputFormat string + var concurrency int cmd := &cobra.Command{ - Use: "query [SQL | file.sql]", + Use: "query [SQL | file.sql]...", Short: "Execute SQL against a Databricks warehouse", - Long: `Execute a SQL statement against a Databricks SQL warehouse and return results. + Long: `Execute one or more SQL statements against a Databricks SQL warehouse +and return results. -SQL can be provided as a positional argument, read from a file with --file, -or piped via stdin. If the positional argument ends in .sql and the file -exists, it is read as a SQL file automatically. +A single SQL can be provided as a positional argument, read from a file with +--file, or piped via stdin. If a positional argument ends in .sql and the +file exists, it is read as a SQL file automatically. + +Pass multiple positional arguments and/or repeat --file to run several +queries in parallel against the warehouse. Multi-query output is always +JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, +error} objects. Result order is: --file inputs first (in flag order), +then positional SQLs (in arg order). The exit code is non-zero if any +query failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. -Output is JSON in non-interactive contexts. 
In interactive terminals it renders -tables, and large results open an interactive table browser. Use --output csv -to export results as CSV.`, +For a single query, output is JSON in non-interactive contexts. In +interactive terminals it renders tables, and large results open an +interactive table browser. Use --output csv to export results as CSV.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql databricks experimental aitools tools query report.sql databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" + databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.MaximumNArgs(1), - PreRunE: root.MustWorkspaceClient, + Args: cobra.ArbitraryArgs, + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + return root.MustWorkspaceClient(cmd, args) + }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -124,19 +139,29 @@ to export results as CSV.`, return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat) } - w := cmdctx.WorkspaceClient(ctx) - - sqlStatement, err := resolveSQL(ctx, cmd, args, filePath) + sqls, err := resolveSQLs(ctx, cmd, args, filePaths) if err != nil { return err } + // Reject incompatible flag combinations before any API call so the + // user sees the real error instead of an auth/warehouse failure. 
+ if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON { + return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat) + } + + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement) + if len(sqls) > 1 { + return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency) + } + + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0]) if err != nil { return err } @@ -177,7 +202,8 @@ to export results as CSV.`, } cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") - cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)") + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") @@ -188,59 +214,85 @@ to export results as CSV.`, return cmd } -// resolveSQL determines the SQL statement to execute from the available input sources. -// Priority: --file flag > positional arg > stdin. -func resolveSQL(ctx context.Context, cmd *cobra.Command, args []string, filePath string) (string, error) { - var raw string +// resolveSQLs collects SQL statements from --file paths, positional args, and +// stdin. The returned slice preserves source order: --file paths first (in flag +// order), then positional args (in arg order), then stdin (only if no other +// source produced anything). 
Each SQL is run through cleanSQL. +func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) { + var raws []string - switch { - case filePath != "": - if len(args) > 0 { - return "", errors.New("cannot use both --file and a positional SQL argument") - } - data, err := os.ReadFile(filePath) + for _, path := range filePaths { + data, err := os.ReadFile(path) if err != nil { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file %s: %w", path, err) } - raw = string(data) + raws = append(raws, string(data)) + } - case len(args) > 0: + for _, arg := range args { // If the argument looks like a .sql file, try to read it. // Only fall through to literal SQL if the file doesn't exist. // Surface other errors (permission denied, etc.) directly. - if strings.HasSuffix(args[0], sqlFileExtension) { - data, err := os.ReadFile(args[0]) + if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) if err != nil && !errors.Is(err, os.ErrNotExist) { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file: %w", err) } if err == nil { - raw = string(data) - break + raws = append(raws, string(data)) + continue } } - raw = args[0] + raws = append(raws, arg) + } - default: - // No args: try reading from stdin if it's piped. + if len(raws) == 0 { + // No --file and no positional args: try reading from stdin if it's piped. // If stdin was overridden (e.g. cmd.SetIn in tests), always read from it. // Otherwise, only read if stdin is not a TTY (i.e. piped input). 
in := cmd.InOrStdin() _, isOsFile := in.(*os.File) if isOsFile && cmdio.IsPromptSupported(ctx) { - return "", errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") } data, err := io.ReadAll(in) if err != nil { - return "", fmt.Errorf("read stdin: %w", err) + return nil, fmt.Errorf("read stdin: %w", err) + } + raws = append(raws, string(data)) + } + + cleaned := make([]string, 0, len(raws)) + for i, raw := range raws { + c := cleanSQL(raw) + if c == "" { + if len(raws) == 1 { + return nil, errors.New("SQL statement is empty after removing comments and blank lines") + } + return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1) } - raw = string(data) + cleaned = append(cleaned, c) + } + return cleaned, nil +} + +// runBatch executes multiple SQL statements in parallel and renders the result +// as a JSON array. Returns root.ErrAlreadyPrinted (so the exit code is non-zero +// without an extra error message) when any statement failed; the failure detail +// is already encoded in the printed JSON. The caller is responsible for +// rejecting incompatible output formats before invoking this. +func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error { + results := executeBatch(ctx, api, warehouseID, sqls, concurrency) + if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil { + return err } - result := cleanSQL(raw) - if result == "" { - return "", errors.New("SQL statement is empty after removing comments and blank lines") + for _, r := range results { + if r.Error != nil { + return root.ErrAlreadyPrinted + } } - return result, nil + return nil } // resolveWarehouseID returns the warehouse ID to use for query execution. 
@@ -262,9 +314,10 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { return nil, fmt.Errorf("execute statement: %w", err) @@ -272,11 +325,6 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa statementID := resp.StatementId - // Check if it completed immediately. - if isTerminalState(resp.Status) { - return resp, checkFailedState(resp.Status) - } - // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -297,8 +345,11 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa // cancelStatement performs best-effort server-side cancellation. // Called on any poll exit due to context cancellation (signal or parent). cancelStatement := func() { - // Use the parent context (ctx), not the cancelled pollCtx. - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from any cancellation on the inbound ctx (the caller might + // have cancelled the parent before invoking this path): WithoutCancel + // preserves values but drops cancellation so the cancel RPC actually + // reaches the warehouse. 
+ cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) defer cancel() if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ StatementId: statementID, @@ -327,34 +378,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() + pollResp, err := pollStatement(pollCtx, api, resp) + if err != nil { + if pollCtx.Err() != nil { + cancelStatement() + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted + } + return nil, err + } + + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// pollStatement polls until the statement reaches a terminal state. +// +// On context cancellation it returns the context error WITHOUT cancelling the +// server-side statement. Callers that want server-side cancellation should +// invoke CancelExecution explicitly. +// +// If the input response is already in a terminal state, it is returned without +// further polling. +func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { + if isTerminalState(resp.Status) { + return resp, nil + } + + statementID := resp.StatementId + start := time.Now() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). 
interval := pollIntervalInitial for { select { - case <-pollCtx.Done(): - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + case <-ctx.Done(): + return nil, ctx.Err() case <-time.After(interval): } log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + pollResp, err := api.GetStatementByStatementId(ctx, statementID) if err != nil { - if pollCtx.Err() != nil { - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + if ctx.Err() != nil { + return nil, ctx.Err() } return nil, fmt.Errorf("poll statement status: %w", err) } if isTerminalState(pollResp.Status) { - sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } return &sql.StatementResponse{ StatementId: pollResp.StatementId, Status: pollResp.Status, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index aa33921c83..59de11d578 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "os" "path/filepath" "strings" @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { - return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue })).Return(&sql.StatementResponse{ StatementId: "stmt-1", Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, @@ -143,8 +146,12 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t 
*testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil) - // CancelExecution must be called when context is cancelled (not just on signal). - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + // CancelExecution must be called when context is cancelled (not just on + // signal). Assert the RPC's own ctx is NOT cancelled, otherwise the SDK + // would short-circuit on ctx.Err() and never reach the warehouse. + mockAPI.EXPECT().CancelExecution(mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }), sql.CancelExecutionRequest{ StatementId: "stmt-1", }).Return(nil).Once() @@ -154,6 +161,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { require.ErrorIs(t, err, root.ErrAlreadyPrinted) } +func TestPollStatementImmediateTerminal(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, "stmt-1", pollResp.StatementId) +} + +func TestPollStatementTerminalFailureNotErrored(t *testing.T) { + // pollStatement returns the response without erroring on failed terminal + // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, + }, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) +} + +func TestPollStatementEventualSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) +} + +func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { + // The mock asserts (via t.Cleanup) that no unexpected calls are made. + // Specifically, pollStatement must NOT call CancelExecution on context + // cancellation; that is the caller's responsibility. 
+ ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + cancel() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, pollResp) +} + +func TestPollStatementGetErrorPropagated(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). + Return(nil, errors.New("network unreachable")).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll statement status") + assert.Contains(t, err.Error(), "network unreachable") + assert.Nil(t, pollResp) +} + func TestResolveWarehouseIDWithFlag(t *testing.T) { ctx := t.Context() id, err := resolveWarehouseID(ctx, nil, "explicit-id") @@ -330,69 +437,95 @@ func TestPollingConstants(t *testing.T) { assert.Equal(t, 10*time.Second, cancelTimeout) } -// newTestCmd creates a minimal cobra.Command for testing resolveSQL. +// newTestCmd creates a minimal cobra.Command for testing resolveSQLs. 
func newTestCmd() *cobra.Command { return &cobra.Command{Use: "test"} } -func TestResolveSQLFromFileFlag(t *testing.T) { +func TestResolveSQLsFromFileFlag(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFromFileFlagWithComments(t *testing.T) { +func TestResolveSQLsFromFileFlagWithComments(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("-- header comment\nSELECT 1\n-- trailing"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1"}, result) +} + +func TestResolveSQLsMixedFileAndPositional(t *testing.T) { + // --file paths are emitted before positional args, in flag order. 
+ dir := t.TempDir() + path := filepath.Join(dir, "from-file.sql") + err := os.WriteFile(path, []byte("SELECT 'from file'"), 0o644) + require.NoError(t, err) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 'from arg'"}, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 'from file'", "SELECT 'from arg'"}, result) } -func TestResolveSQLFileFlagConflictsWithArg(t *testing.T) { +func TestResolveSQLsMultiplePositional(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1"}, "/some/file.sql") - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use both --file and a positional SQL argument") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, nil) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, result) +} + +func TestResolveSQLsMultipleFiles(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.sql") + pathB := filepath.Join(dir, "b.sql") + require.NoError(t, os.WriteFile(pathA, []byte("SELECT 'a'"), 0o644)) + require.NoError(t, os.WriteFile(pathB, []byte("SELECT 'b'"), 0o644)) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{pathA, pathB}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'a'", "SELECT 'b'"}, result) } -func TestResolveSQLFromPositionalArg(t *testing.T) { +func TestResolveSQLsFromPositionalArg(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 42", result) + assert.Equal(t, []string{"SELECT 42"}, result) } -func TestResolveSQLAutoDetectsSQLFile(t *testing.T) { 
+func TestResolveSQLsAutoDetectsSQLFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "report.sql") err := os.WriteFile(path, []byte("SELECT * FROM sales"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT * FROM sales", result) + assert.Equal(t, []string{"SELECT * FROM sales"}, result) } -func TestResolveSQLNonexistentSQLFileTreatedAsString(t *testing.T) { +func TestResolveSQLsNonexistentSQLFileTreatedAsString(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, nil) require.NoError(t, err) - assert.Equal(t, "nonexistent.sql", result) + assert.Equal(t, []string{"nonexistent.sql"}, result) } -func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { +func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "locked.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) @@ -404,47 +537,54 @@ func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { t.Cleanup(func() { _ = os.Chmod(path, 0o644) }) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } -func TestResolveSQLFromStdin(t *testing.T) { +func TestResolveSQLsFromStdin(t *testing.T) { cmd := newTestCmd() cmd.SetIn(strings.NewReader("SELECT 1 FROM stdin_test")) - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, nil) require.NoError(t, err) - 
assert.Equal(t, "SELECT 1 FROM stdin_test", result) + assert.Equal(t, []string{"SELECT 1 FROM stdin_test"}, result) } -func TestResolveSQLEmptyFileReturnsError(t *testing.T) { +func TestResolveSQLsEmptyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty.sql") err := os.WriteFile(path, []byte(""), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLCommentsOnlyFileReturnsError(t *testing.T) { +func TestResolveSQLsCommentsOnlyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "comments.sql") err := os.WriteFile(path, []byte("-- just a comment\n-- another"), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLMissingFileReturnsError(t *testing.T) { +func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "/nonexistent/path/query.sql") + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "SQL statement #2 is empty") +} + +func TestResolveSQLsMissingFileReturnsError(t *testing.T) { + cmd := newTestCmd() + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"}) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } @@ -458,6 +598,34 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") 
} +func TestQueryCommandBatchOutputRejection(t *testing.T) { + // Multi-query mode is JSON-only. text and csv are rejected with an + // actionable error before any API call. + for _, format := range []string{"text", "csv"} { + t.Run(format, func(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", format, "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") + }) + } +} + +func TestQueryCommandConcurrencyRejection(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; negative removes the cap entirely. + // Both surprise users, so PreRunE rejects anything <= 0. + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", value, "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil diff --git a/experimental/aitools/cmd/render.go b/experimental/aitools/cmd/render.go index 7727c37106..d0b62926c2 100644 --- a/experimental/aitools/cmd/render.go +++ b/experimental/aitools/cmd/render.go @@ -29,6 +29,17 @@ func extractColumns(manifest *sql.ResultManifest) []string { return columns } +// renderBatchJSON writes batch results as a JSON array. The array preserves +// input order and includes one object per submitted statement. +func renderBatchJSON(w io.Writer, results []batchResult) error { + output, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("marshal batch results: %w", err) + } + fmt.Fprintf(w, "%s\n", output) + return nil +} + // renderJSON writes query results as a parseable JSON array to stdout. // Row count is written to stderr so stdout remains valid JSON for piping. 
func renderJSON(w io.Writer, columns []string, rows [][]string) error { diff --git a/experimental/aitools/cmd/statement.go b/experimental/aitools/cmd/statement.go new file mode 100644 index 0000000000..e1c48a7ddb --- /dev/null +++ b/experimental/aitools/cmd/statement.go @@ -0,0 +1,77 @@ +package aitools + +import ( + "encoding/json" + "fmt" + "io" + + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +// statementInfo is the JSON shape emitted by every `tools statement` +// subcommand. Fields are populated as the subcommand has them. omitempty keeps +// the output tight: `submit` doesn't emit columns/rows, `cancel` doesn't emit a +// warehouse_id, etc. +type statementInfo struct { + StatementID string `json:"statement_id"` + State sql.StatementState `json:"state,omitempty"` + WarehouseID string `json:"warehouse_id,omitempty"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +func renderStatementInfo(w io.Writer, info statementInfo) error { + data, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("marshal statement info: %w", err) + } + fmt.Fprintf(w, "%s\n", data) + return nil +} + +// statementErrorFromStatus builds a batchResultError for any terminal non-success +// state (FAILED, CANCELED, CLOSED), populating it from the server's ServiceError +// when available and synthesizing a message when it isn't. Returns nil for +// SUCCEEDED, non-terminal states, and nil status. The synthesized fallback +// matters because the Statements API can hand back a non-success terminal state +// with `Error == nil`, and skill consumers should be able to branch on +// `error == null` alone instead of inspecting `state`. 
+func statementErrorFromStatus(status *sql.StatementStatus) *batchResultError { + if status == nil || !isTerminalState(status) || status.State == sql.StatementStateSucceeded { + return nil + } + out := &batchResultError{} + if status.Error != nil { + out.Message = status.Error.Message + out.ErrorCode = string(status.Error.ErrorCode) + } else { + out.Message = fmt.Sprintf("statement reached terminal state %s", status.State) + } + return out +} + +func newStatementCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "statement", + Short: "Manage SQL statement lifecycle (submit, get, status, cancel)", + Long: `Low-level command tree for asynchronous SQL execution. + +Use 'submit' to fire a statement and get its statement_id back, then +'get' to block on results, 'status' to peek without blocking, and +'cancel' to terminate. For "I want results now," use 'tools query' +instead. + +All subcommands emit a JSON object with the statement_id and state. +'get' adds columns and rows on success; any subcommand may emit an +error object when the server reports a non-success terminal state.`, + } + + cmd.AddCommand(newStatementSubmitCmd()) + cmd.AddCommand(newStatementGetCmd()) + cmd.AddCommand(newStatementStatusCmd()) + cmd.AddCommand(newStatementCancelCmd()) + + return cmd +} diff --git a/experimental/aitools/cmd/statement_cancel.go b/experimental/aitools/cmd/statement_cancel.go new file mode 100644 index 0000000000..1774b7abe6 --- /dev/null +++ b/experimental/aitools/cmd/statement_cancel.go @@ -0,0 +1,53 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementCancelCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cancel STATEMENT_ID", + Short: "Request cancellation of a running statement", + Long: `Send a cancellation request for the given statement_id. 
The Statements +API returns no body on cancel; this command optimistically reports +state=CANCELED on success. Use 'statement status' afterwards to confirm +the server-side state if you need certainty.`, + Example: ` databricks experimental aitools tools statement cancel 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := cancelStatementExecution(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// cancelStatementExecution issues CancelExecution and reports state=CANCELED on success. +// CancelExecution returns no body; the actual server-side state is verified +// asynchronously. Use 'statement status' to confirm if certainty is required. +func cancelStatementExecution(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + if err := api.CancelExecution(ctx, sql.CancelExecutionRequest{ + StatementId: statementID, + }); err != nil { + return statementInfo{}, fmt.Errorf("cancel statement: %w", err) + } + return statementInfo{ + StatementID: statementID, + State: sql.StatementStateCanceled, + }, nil +} diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go new file mode 100644 index 0000000000..617b5c274d --- /dev/null +++ b/experimental/aitools/cmd/statement_get.go @@ -0,0 +1,96 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementGetCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "get STATEMENT_ID", + Short: "Block until a previously submitted statement is terminal and emit its result", + Long: `Poll 
a statement_id until it reaches a terminal state, then emit
+columns and rows on success or an error object on failure.
+
+Ctrl+C stops polling but does NOT cancel the server-side statement.
+Use 'statement cancel <statement-id>' to terminate explicitly. (This differs from
+'tools query', which cancels server-side on Ctrl+C because the user
+invoked the synchronous path.)`,
+		Example: `  databricks experimental aitools tools statement get 01ef...`,
+		Args:    cobra.ExactArgs(1),
+		PreRunE: root.MustWorkspaceClient,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+			w := cmdctx.WorkspaceClient(ctx)
+			statementID := args[0]
+
+			info, err := getStatementResult(ctx, w.StatementExecution, statementID)
+			if err != nil {
+				return err
+			}
+
+			if err := renderStatementInfo(cmd.OutOrStdout(), info); err != nil {
+				return err
+			}
+
+			// Non-zero exit when the statement reached a non-success terminal
+			// state OR a chunk-fetch failure prevented assembling the rows.
+			// In both cases the failure detail is already in the JSON output.
+			if info.State != sql.StatementStateSucceeded || info.Error != nil {
+				return root.ErrAlreadyPrinted
+			}
+			return nil
+		},
+	}
+
+	return cmd
+}
+
+// getStatementResult polls a statement until terminal, then assembles a
+// statementInfo with rows on success or an error object on failure.
+//
+// Context cancellation propagates from pollStatement WITHOUT cancelling the
+// server-side statement (intentional: 'get' is a poll-only operation; use
+// 'cancel' to terminate explicitly).
+func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) {
+	// Fetch the current state first so pollStatement can short-circuit if
+	// the statement is already terminal.
+ resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + return statementInfo{}, err + } + + info := statementInfo{StatementID: pollResp.StatementId} + if pollResp.Status != nil { + info.State = pollResp.Status.State + } + info.Error = statementErrorFromStatus(pollResp.Status) + + if info.State == sql.StatementStateSucceeded { + info.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + // The query succeeded server-side but a later chunk fetch failed + // (network blip, throttling, transient 5xx). Surface this as a + // structured error on the same statementInfo so the caller still + // gets a parseable JSON response with the statement_id; RunE then + // signals exit-non-zero based on info.Error. + info.Error = &batchResultError{ + Message: fmt.Sprintf("fetch result rows: %v", err), + } + return info, nil + } + info.Rows = rows + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_status.go b/experimental/aitools/cmd/statement_status.go new file mode 100644 index 0000000000..9981f49aa6 --- /dev/null +++ b/experimental/aitools/cmd/statement_status.go @@ -0,0 +1,52 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status STATEMENT_ID", + Short: "Return the current state of a statement without polling", + Long: `Single GET against the Statements API. Use this to peek at progress +without blocking. 
For a blocking poll-until-terminal call, use +'statement get'.`, + Example: ` databricks experimental aitools tools statement status 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := getStatementStatus(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// getStatementStatus performs a single GET against the Statements API, no polling. +func getStatementStatus(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + info := statementInfo{StatementID: resp.StatementId} + if resp.Status != nil { + info.State = resp.Status.State + } + info.Error = statementErrorFromStatus(resp.Status) + return info, nil +} diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go new file mode 100644 index 0000000000..ac8bf424e5 --- /dev/null +++ b/experimental/aitools/cmd/statement_submit.go @@ -0,0 +1,94 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementSubmitCmd() *cobra.Command { + var warehouseID string + var filePath string + // resolved by PreRunE so input validation runs before any auth/profile + // work and the documented "validates input before WorkspaceClient" claim + // in the PR description is actually true. 
+	var sqlStatement string
+
+	cmd := &cobra.Command{
+		Use:   "submit [SQL | file.sql]",
+		Short: "Submit a SQL statement asynchronously and return its statement_id",
+		Long: `Submit a SQL statement to a Databricks SQL warehouse and return its
+statement_id immediately, without waiting for results.
+
+The statement keeps running server-side. Harvest results with
+'statement get <statement-id>', inspect with 'statement status <statement-id>', or stop
+with 'statement cancel <statement-id>'.`,
+		Example: `  databricks experimental aitools tools statement submit "SELECT pg_sleep(60)" --warehouse <warehouse-id>
+  databricks experimental aitools tools statement submit --file query.sql`,
+		Args: cobra.MaximumNArgs(1),
+		PreRunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+
+			var fps []string
+			if filePath != "" {
+				fps = []string{filePath}
+			}
+			sqls, err := resolveSQLs(ctx, cmd, args, fps)
+			if err != nil {
+				return err
+			}
+			if len(sqls) != 1 {
+				return errors.New("submit accepts exactly one SQL statement; pass multiple to 'query' for batch")
+			}
+			sqlStatement = sqls[0]
+
+			return root.MustWorkspaceClient(cmd, args)
+		},
+		RunE: func(cmd *cobra.Command, args []string) error {
+			ctx := cmd.Context()
+			w := cmdctx.WorkspaceClient(ctx)
+			wID, err := resolveWarehouseID(ctx, w, warehouseID)
+			if err != nil {
+				return err
+			}
+
+			info, err := submitStatement(ctx, w.StatementExecution, sqlStatement, wID)
+			if err != nil {
+				return err
+			}
+			return renderStatementInfo(cmd.OutOrStdout(), info)
+		},
+	}
+
+	cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution")
+	cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute")
+
+	return cmd
+}
+
+// submitStatement issues an asynchronous ExecuteStatement and returns the handle.
+func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string) (statementInfo, error) { + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + return statementInfo{}, fmt.Errorf("execute statement: %w", err) + } + + info := statementInfo{ + StatementID: resp.StatementId, + WarehouseID: warehouseID, + } + if resp.Status != nil { + info.State = resp.Status.State + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go new file mode 100644 index 0000000000..9c2264daf2 --- /dev/null +++ b/experimental/aitools/cmd/statement_test.go @@ -0,0 +1,352 @@ +package aitools + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSubmitStatementReturnsHandle(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WarehouseId == "wh-1" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + info, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, 
sql.StatementStatePending, info.State) + assert.Equal(t, "wh-1", info.WarehouseID) +} + +func TestSubmitStatementWrapsTransportError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). + Return(nil, errors.New("network unreachable")).Once() + + _, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "execute statement") + assert.Contains(t, err.Error(), "network unreachable") +} + +func TestGetStatementResultPolls(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, TotalChunkCount: 1}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns) + assert.Equal(t, [][]string{{"42"}}, info.Rows) + assert.Nil(t, info.Error) +} + +func TestGetStatementResultFailedStateReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: 
sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + assert.Nil(t, info.Rows) + require.NotNil(t, info.Error) + assert.Equal(t, "SYNTAX_ERROR", info.Error.ErrorCode) + assert.Contains(t, info.Error.Message, "syntax error") +} + +func TestGetStatementResultDoesNotCancelServerSideOnContextCancel(t *testing.T) { + // 'statement get' is a poll-only operation: ctx cancellation must NOT + // trigger CancelExecution. The mock asserts (via t.Cleanup) that no + // unexpected calls happen. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + cancel() + + _, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.ErrorIs(t, err, context.Canceled) +} + +func TestGetStatementResultChunkFetchFailureRendersPartialInfo(t *testing.T) { + // SUCCEEDED state but a later chunk fetch fails (network blip, throttle, + // 5xx). getStatementResult should surface this as a structured error on + // the same statementInfo so the caller still gets parseable JSON with the + // statement_id, instead of returning a raw Go error that RunE would + // discard along with the populated info. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{ + Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, + TotalChunkCount: 2, + }, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1). + Return(nil, errors.New("network blip")).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns, "columns from the initial response are still surfaced") + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "fetch result rows") + assert.Contains(t, info.Error.Message, "network blip") +} + +func TestGetStatementStatusSinglePoll(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateRunning, info.State) + assert.Nil(t, info.Error) +} + +func TestGetStatementStatusReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: 
&sql.ServiceError{
+				ErrorCode: "TIMEOUT",
+				Message:   "warehouse timed out",
+			},
+		},
+	}, nil).Once()
+
+	info, err := getStatementStatus(ctx, mockAPI, "stmt-1")
+	require.NoError(t, err)
+	assert.Equal(t, sql.StatementStateFailed, info.State)
+	require.NotNil(t, info.Error)
+	assert.Equal(t, "TIMEOUT", info.Error.ErrorCode)
+}
+
+func TestCancelStatementExecutionCallsAPI(t *testing.T) {
+	ctx := cmdio.MockDiscard(t.Context())
+	mockAPI := mocksql.NewMockStatementExecutionInterface(t)
+
+	mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{
+		StatementId: "stmt-1",
+	}).Return(nil).Once() // only CancelExecution is expected; the strict mock fails on any other call
+
+	info, err := cancelStatementExecution(ctx, mockAPI, "stmt-1")
+	require.NoError(t, err)
+	assert.Equal(t, "stmt-1", info.StatementID)
+	assert.Equal(t, sql.StatementStateCanceled, info.State) // state comes from the helper itself; no GetStatement call is mocked here
+}
+
+func TestCancelStatementExecutionWrapsAPIError(t *testing.T) {
+	ctx := cmdio.MockDiscard(t.Context())
+	mockAPI := mocksql.NewMockStatementExecutionInterface(t)
+
+	mockAPI.EXPECT().CancelExecution(mock.Anything, mock.Anything).
+		Return(errors.New("not found")).Once()
+
+	_, err := cancelStatementExecution(ctx, mockAPI, "stmt-1")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "cancel statement") // wrapping context added by the helper
+	assert.Contains(t, err.Error(), "not found")        // underlying API error text is preserved
+}
+
+func TestRenderStatementInfo(t *testing.T) {
+	tests := []struct {
+		name        string
+		info        statementInfo
+		mustHave    []string
+		mustNotHave []string
+	}{
+		{
+			name: "full payload renders every populated field",
+			info: statementInfo{
+				StatementID: "stmt-1",
+				State:       sql.StatementStateSucceeded,
+				WarehouseID: "wh-1",
+				Columns:     []string{"n"},
+				Rows:        [][]string{{"42"}},
+			},
+			mustHave: []string{
+				`"statement_id": "stmt-1"`,
+				`"state": "SUCCEEDED"`,
+				`"warehouse_id": "wh-1"`,
+				`"columns": [`,
+				`"rows": [`,
+			},
+		},
+		{
+			name: "cancel-style payload omits unset fields",
+			info: statementInfo{
+				StatementID: "stmt-1",
+				State:       sql.StatementStateCanceled,
+			},
+			mustHave: []string{
+				`"statement_id": "stmt-1"`,
+				`"state": "CANCELED"`,
+			},
+			mustNotHave: []string{`"warehouse_id"`, `"columns"`, `"rows"`, `"error"`},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			var buf strings.Builder
+			require.NoError(t, renderStatementInfo(&buf, tc.info))
+			out := buf.String()
+			for _, want := range tc.mustHave {
+				assert.Contains(t, out, want)
+			}
+			for _, missing := range tc.mustNotHave {
+				assert.NotContains(t, out, missing)
+			}
+			assert.True(t, strings.HasSuffix(out, "\n")) // rendered output must be newline-terminated
+		})
+	}
+}
+
+func TestStatementSubmitRejectsMultipleSQLsBeforeWorkspaceClient(t *testing.T) {
+	// The "exactly one SQL" check runs in PreRunE BEFORE MustWorkspaceClient,
+	// so a malformed invocation is rejected without any auth/profile work.
+	// The test relies on this ordering: it does not stub out PreRunE, so if
+	// validation moved back after MustWorkspaceClient the test would panic
+	// on a missing workspace client instead of returning the validation error.
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test.sql")
+	require.NoError(t, os.WriteFile(path, []byte("SELECT 1"), 0o644))
+
+	cmd := newStatementSubmitCmd()
+	cmd.SetArgs([]string{"--file", path, "SELECT 2"}) // one SQL via --file plus one positional = two SQLs
+	err := cmd.Execute()
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "exactly one")
+}
+
+func TestStatementErrorFromStatus(t *testing.T) {
+	tests := []struct {
+		name     string
+		status   *sql.StatementStatus
+		wantNil  bool
+		wantMsg  string
+		wantCode string
+	}{
+		{
+			name:    "nil status",
+			status:  nil,
+			wantNil: true,
+		},
+		{
+			name:    "succeeded never produces an error",
+			status:  &sql.StatementStatus{State: sql.StatementStateSucceeded},
+			wantNil: true,
+		},
+		{
+			name:    "running is not terminal",
+			status:  &sql.StatementStatus{State: sql.StatementStateRunning},
+			wantNil: true,
+		},
+		{
+			name:    "pending is not terminal",
+			status:  &sql.StatementStatus{State: sql.StatementStatePending},
+			wantNil: true,
+		},
+		{
+			name: "failed with backend error preserves both fields",
+			status: &sql.StatementStatus{
+				State: sql.StatementStateFailed,
+				Error: &sql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'bad'"},
+			},
+			wantMsg:  "near 'bad'",
+			wantCode: "SYNTAX_ERROR",
+		},
+		{
+			name:    "failed without backend error synthesizes message",
+			status:  &sql.StatementStatus{State: sql.StatementStateFailed},
+			wantMsg: "statement reached terminal state FAILED",
+		},
+		{
+			name:    "canceled without backend error synthesizes message",
+			status:  &sql.StatementStatus{State: sql.StatementStateCanceled},
+			wantMsg: "statement reached terminal state CANCELED",
+		},
+		{
+			name:    "closed without backend error synthesizes message",
+			status:  &sql.StatementStatus{State: sql.StatementStateClosed},
+			wantMsg: "statement reached terminal state CLOSED",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			got := statementErrorFromStatus(tc.status)
+			if tc.wantNil {
+				assert.Nil(t, got)
+				return
+			}
+			require.NotNil(t, got)
+			assert.Equal(t, tc.wantMsg, got.Message)
+			assert.Equal(t, tc.wantCode, got.ErrorCode) // empty wantCode asserts ErrorCode stays unset for synthesized errors
+		})
+	}
+}
diff --git a/experimental/aitools/cmd/tools.go b/experimental/aitools/cmd/tools.go
index b5dd306d21..22781f987f 100644
--- a/experimental/aitools/cmd/tools.go
+++ b/experimental/aitools/cmd/tools.go
@@ -15,6 +15,7 @@ func newToolsCmd() *cobra.Command {
 	cmd.AddCommand(newQueryCmd())
 	cmd.AddCommand(newDiscoverSchemaCmd())
 	cmd.AddCommand(newGetDefaultWarehouseCmd())
+	cmd.AddCommand(newStatementCmd())
 
 	return cmd
 }