From 79fc08074e05ebeb45212ffc8ad1a1332a87b35e Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:04:52 +0200 Subject: [PATCH 01/14] aitools: extract pollStatement helper and pin OnWaitTimeout Refactor `executeAndPoll` in `experimental/aitools/cmd/query.go` to extract a pure `pollStatement(ctx, api, resp)` helper. The helper polls until the statement reaches a terminal state and returns the response without any signal handling, spinner, or server-side cancellation; those concerns stay in `executeAndPoll` where they belong. Also pin `OnWaitTimeout: CONTINUE` explicitly on the `ExecuteStatement` call. The SDK default happens to be CONTINUE today, but relying on it is a hidden coupling: a server-side default flip would silently break the poll loop by killing the statement before our first GET. Behavior is unchanged for the existing `query` command. Follow-up PRs (parallel batch queries, statement lifecycle command tree) will reuse the helper. Co-authored-by: Isaac --- experimental/aitools/cmd/query.go | 63 ++++++++++----- experimental/aitools/cmd/query_test.go | 105 ++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 22 deletions(-) diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7b95fdd4e23..6c125bbcd6b 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -262,9 +262,10 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) { // Submit asynchronously to get the statement ID immediately for cancellation. 
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { return nil, fmt.Errorf("execute statement: %w", err) @@ -272,11 +273,6 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa statementID := resp.StatementId - // Check if it completed immediately. - if isTerminalState(resp.Status) { - return resp, checkFailedState(resp.Status) - } - // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -327,34 +323,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() + pollResp, err := pollStatement(pollCtx, api, resp) + if err != nil { + if pollCtx.Err() != nil { + cancelStatement() + cmdio.LogString(ctx, "Query cancelled.") + return nil, root.ErrAlreadyPrinted + } + return nil, err + } + + sp.Close() + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// pollStatement polls until the statement reaches a terminal state. +// +// On context cancellation it returns the context error WITHOUT cancelling the +// server-side statement. Callers that want server-side cancellation should +// invoke CancelExecution explicitly. +// +// If the input response is already in a terminal state, it is returned without +// further polling. +func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { + if isTerminalState(resp.Status) { + return resp, nil + } + + statementID := resp.StatementId + start := time.Now() + // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). 
interval := pollIntervalInitial for { select { - case <-pollCtx.Done(): - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + case <-ctx.Done(): + return nil, ctx.Err() case <-time.After(interval): } log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - pollResp, err := api.GetStatementByStatementId(pollCtx, statementID) + pollResp, err := api.GetStatementByStatementId(ctx, statementID) if err != nil { - if pollCtx.Err() != nil { - cancelStatement() - cmdio.LogString(ctx, "Query cancelled.") - return nil, root.ErrAlreadyPrinted + if ctx.Err() != nil { + return nil, ctx.Err() } return nil, fmt.Errorf("poll statement status: %w", err) } if isTerminalState(pollResp.Status) { - sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } return &sql.StatementResponse{ StatementId: pollResp.StatementId, Status: pollResp.Status, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index aa33921c83b..4bc06c1d63b 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "os" "path/filepath" "strings" @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { - return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s" + return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue })).Return(&sql.StatementResponse{ StatementId: "stmt-1", Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, @@ -154,6 +157,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t 
*testing.T) { require.ErrorIs(t, err, root.ErrAlreadyPrinted) } +func TestPollStatementImmediateTerminal(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, "stmt-1", pollResp.StatementId) +} + +func TestPollStatementTerminalFailureNotErrored(t *testing.T) { + // pollStatement returns the response without erroring on failed terminal + // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + resp := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, + }, + } + + pollResp, err := pollStatement(ctx, mockAPI, resp) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) +} + +func TestPollStatementEventualSuccess(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + 
StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) + assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) +} + +func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { + // The mock asserts (via t.Cleanup) that no unexpected calls are made. + // Specifically, pollStatement must NOT call CancelExecution on context + // cancellation; that is the caller's responsibility. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + cancel() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.ErrorIs(t, err, context.Canceled) + assert.Nil(t, pollResp) +} + +func TestPollStatementGetErrorPropagated(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + initial := &sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + } + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). 
+ Return(nil, errors.New("network unreachable")).Once() + + pollResp, err := pollStatement(ctx, mockAPI, initial) + require.Error(t, err) + assert.Contains(t, err.Error(), "poll statement status") + assert.Contains(t, err.Error(), "network unreachable") + assert.Nil(t, pollResp) +} + func TestResolveWarehouseIDWithFlag(t *testing.T) { ctx := t.Context() id, err := resolveWarehouseID(ctx, nil, "explicit-id") From 6b6128a7c10f802e5c67838ad1215ef6ec8fb985 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:37:40 +0200 Subject: [PATCH 02/14] aitools: run multiple SQL queries in parallel from one query invocation Allow `databricks experimental aitools tools query` to accept several SQLs in a single invocation and run them in parallel against the warehouse. Pass multiple positional arguments and/or repeat `--file` to fan out: databricks experimental aitools tools query \ --warehouse --output json \ "SELECT count(*) FROM t" \ "SELECT min(ts), max(ts) FROM t" \ "SELECT col, count(*) FROM t GROUP BY 1" Multi-query output is always a JSON array of one object per input, preserving input order. The shape is `{sql, statement_id, state, elapsed_ms, columns, rows, error}`. Individual statement failures don't abort siblings; each is encoded in the per-result `error` field, and the exit code is non-zero when any statement failed. A new `--concurrency` flag (default 8) caps in-flight statements. On Ctrl+C the still-running statements are cancelled server-side via CancelExecution before exit. Single-query behavior is unchanged. The previous restriction that forbade mixing `--file` and a positional SQL is lifted, since both now contribute to the batch. 
Co-authored-by: Isaac --- experimental/aitools/README.md | 12 ++ experimental/aitools/cmd/batch.go | 206 +++++++++++++++++++++ experimental/aitools/cmd/batch_test.go | 237 +++++++++++++++++++++++++ experimental/aitools/cmd/query.go | 132 +++++++++----- experimental/aitools/cmd/query_test.go | 113 ++++++++---- experimental/aitools/cmd/render.go | 11 ++ 6 files changed, 637 insertions(+), 74 deletions(-) create mode 100644 experimental/aitools/cmd/batch.go create mode 100644 experimental/aitools/cmd/batch_test.go diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md index 571136538c9..f645e4de51d 100644 --- a/experimental/aitools/README.md +++ b/experimental/aitools/README.md @@ -16,6 +16,18 @@ Current behavior: - `skills install` installs Databricks skills for detected coding agents. - `install` is a compatibility alias for `skills install`. - `tools` exposes a small set of AI-oriented workspace helpers. +- `tools query` accepts a single SQL or multiple SQLs in one invocation. Pass + several positional arguments and/or repeat `--file` to run them in parallel + against the warehouse. Multi-query output is always JSON; control parallelism + with `--concurrency` (default 8). 
+ + ```bash + databricks experimental aitools tools query \ + --warehouse --output json \ + "SELECT count(*) FROM samples.nyctaxi.trips" \ + "SELECT min(tpep_pickup_datetime), max(tpep_pickup_datetime) FROM samples.nyctaxi.trips" \ + "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1" + ``` Removed behavior: diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go new file mode 100644 index 00000000000..8965923c17c --- /dev/null +++ b/experimental/aitools/cmd/batch.go @@ -0,0 +1,206 @@ +package aitools + +import ( + "context" + "fmt" + "os" + "os/signal" + "sync/atomic" + "syscall" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/databricks/databricks-sdk-go/service/sql" + "golang.org/x/sync/errgroup" +) + +// defaultBatchConcurrency caps in-flight statements when --concurrency is unset. +// Matches the default used by cmd/fs/cp.go for similar fan-out work. +const defaultBatchConcurrency = 8 + +// batchResult is the per-statement payload emitted in batch mode JSON output. +// State is the server-reported terminal state. Error is set whenever the +// statement did not produce usable rows, regardless of state, so consumers +// can branch on `error == null` alone. +type batchResult struct { + SQL string `json:"sql"` + StatementID string `json:"statement_id,omitempty"` + State sql.StatementState `json:"state,omitempty"` + ElapsedMs int64 `json:"elapsed_ms"` + Columns []string `json:"columns,omitempty"` + Rows [][]string `json:"rows,omitempty"` + Error *batchResultError `json:"error,omitempty"` +} + +// batchResultError captures user-visible error info for a failed statement. +type batchResultError struct { + Message string `json:"message"` + ErrorCode string `json:"error_code,omitempty"` +} + +// executeBatch submits sqls against the warehouse in parallel, polls each to +// completion, and returns one batchResult per input in input order. 
+// +// Individual statement failures do not abort siblings; failures are encoded in +// the per-result Error field so callers can render partial results. +// +// On context cancellation (Ctrl+C or parent context), still-running statements +// are cancelled server-side via CancelExecution. Statements that finished +// before cancellation are left as-is. +func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) []batchResult { + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling %d in-flight queries", len(sqls)) + pollCancel() + case <-pollCtx.Done(): + } + }() + + sp := cmdio.NewSpinner(pollCtx) + defer sp.Close() + sp.Update(fmt.Sprintf("Executing %d queries...", len(sqls))) + + var completed atomic.Int64 + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + go func() { + for { + select { + case <-pollCtx.Done(): + return + case <-ticker.C: + sp.Update(fmt.Sprintf("Executing %d queries... (%d/%d done)", len(sqls), completed.Load(), len(sqls))) + } + } + }() + + results := make([]batchResult, len(sqls)) + // Each goroutine writes to a distinct slot, safe without a mutex. + // We read after g.Wait(), establishing happens-before for all writes. + statementIDs := make([]string, len(sqls)) + + g := new(errgroup.Group) + g.SetLimit(concurrency) + for i, sqlStr := range sqls { + g.Go(func() error { + results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, statementIDs, i) + completed.Add(1) + return nil + }) + } + _ = g.Wait() + + // pollStatement is a pure helper that returns ctx.Err() on cancellation + // without touching the server. Sweep any not-yet-terminal statements here. 
+ if pollCtx.Err() != nil { + cancelInFlight(ctx, api, statementIDs, results) + } + + return results +} + +// runOneBatchQuery submits one SQL, polls to completion, and returns its +// batchResult. All errors are encoded into the result; never returns an error. +func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, statementIDs []string, idx int) batchResult { + start := time.Now() + result := batchResult{SQL: sqlStr} + + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: sqlStr, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "submission cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + statementIDs[idx] = resp.StatementId + result.StatementID = resp.StatementId + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + if ctx.Err() != nil { + result.State = sql.StatementStateCanceled + result.Error = &batchResultError{Message: "cancelled"} + } else { + result.State = sql.StatementStateFailed + result.Error = &batchResultError{Message: err.Error()} + } + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + + if pollResp.Status != nil { + result.State = pollResp.Status.State + } + + if result.State != sql.StatementStateSucceeded { + result.Error = &batchResultError{} + if pollResp.Status != nil && pollResp.Status.Error != nil { + result.Error.Message = pollResp.Status.Error.Message + result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode) + } else { + result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State) + } + result.ElapsedMs = 
time.Since(start).Milliseconds() + return result + } + + result.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)} + result.ElapsedMs = time.Since(start).Milliseconds() + return result + } + result.Rows = rows + result.ElapsedMs = time.Since(start).Milliseconds() + return result +} + +// cancelInFlight sends CancelExecution for every statement that didn't reach +// a terminal state server-side before context cancellation. Best effort: errors +// are logged at warn but don't fail the batch. +func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) { + var cancelled int + for i, sid := range statementIDs { + if sid == "" { + continue + } + switch results[i].State { + case sql.StatementStateSucceeded, sql.StatementStateFailed, sql.StatementStateClosed: + continue + case sql.StatementStateCanceled, sql.StatementStatePending, sql.StatementStateRunning: + // Either still running server-side, or our internal "canceled" + // marker meaning the goroutine bailed without telling the server. + // Either way, send CancelExecution. 
+ } + cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) + } + cancel() + cancelled++ + } + if cancelled > 0 { + cmdio.LogString(ctx, fmt.Sprintf("Cancelled %d in-flight queries.", cancelled)) + } +} diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go new file mode 100644 index 00000000000..96235530f4d --- /dev/null +++ b/experimental/aitools/cmd/batch_test.go @@ -0,0 +1,237 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestRenderBatchJSON(t *testing.T) { + results := []batchResult{ + { + SQL: "SELECT 1", + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + ElapsedMs: 42, + Columns: []string{"n"}, + Rows: [][]string{{"1"}}, + }, + { + SQL: "SELECT bad_syntax", + StatementID: "stmt-2", + State: sql.StatementStateFailed, + ElapsedMs: 12, + Error: &batchResultError{ + Message: "near 'bad_syntax': syntax error", + ErrorCode: "SYNTAX_ERROR", + }, + }, + } + + var buf strings.Builder + err := renderBatchJSON(&buf, results) + require.NoError(t, err) + + output := buf.String() + assert.Contains(t, output, `"sql": "SELECT 1"`) + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "SUCCEEDED"`) + assert.Contains(t, output, `"elapsed_ms": 42`) + assert.Contains(t, output, `"columns": [`) + assert.Contains(t, output, `"rows": [`) + assert.Contains(t, output, `"sql": "SELECT bad_syntax"`) + assert.Contains(t, output, `"error": {`) + assert.Contains(t, 
output, `"error_code": "SYNTAX_ERROR"`) + // Trailing newline. + assert.True(t, strings.HasSuffix(output, "\n")) +} + +func TestExecuteBatchAllSucceed(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + sqls := []string{"SELECT 1", "SELECT 2", "SELECT 3"} + for i, sqlStr := range sqls { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{strconv.Itoa(i + 1)}}}, + }, nil).Once() + } + + results := executeBatch(ctx, mockAPI, "wh-123", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d sql", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d state", i) + assert.Nil(t, r.Error, "result %d error", i) + assert.Equal(t, []string{"n"}, r.Columns, "result %d columns", i) + assert.Equal(t, [][]string{{strconv.Itoa(i + 1)}}, r.Rows, "result %d rows", i) + assert.NotEmpty(t, r.StatementID, "result %d statement_id", i) + } +} + +func TestExecuteBatchPartialFailure(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}}, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + 
mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT bad" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-bad", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT 1", "SELECT bad"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + assert.Equal(t, sql.StatementStateSucceeded, results[0].State) + + require.NotNil(t, results[1].Error) + assert.Equal(t, sql.StatementStateFailed, results[1].State) + assert.Equal(t, "SYNTAX_ERROR", results[1].Error.ErrorCode) + assert.Contains(t, results[1].Error.Message, "syntax error") +} + +func TestExecuteBatchSubmissionFailure(t *testing.T) { + // ExecuteStatement transport error is encoded into the per-result error, + // not propagated up to abort siblings. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT good" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-good", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT broken" + })).Return(nil, errors.New("network unreachable")).Once() + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"SELECT good", "SELECT broken"}, 8) + + require.Len(t, results, 2) + assert.Nil(t, results[0].Error) + require.NotNil(t, results[1].Error) + assert.Contains(t, results[1].Error.Message, "execute statement") + assert.Contains(t, results[1].Error.Message, "network unreachable") + assert.Empty(t, results[1].StatementID) +} + +func TestExecuteBatchSetsOnWaitTimeoutContinue(t *testing.T) { + // Guards against a silent SDK default flip from CONTINUE to CANCEL. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WaitTimeout == "0s" && req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-x", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Times(2) + + results := executeBatch(ctx, mockAPI, "wh-123", []string{"q1", "q2"}, 8) + require.Len(t, results, 2) +} + +func TestExecuteBatchPreservesInputOrder(t *testing.T) { + // Index 0 is slow (PENDING then SUCCEEDED on first poll); 1 and 2 are + // immediate. Despite the staggered completion, results stay in input order. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 'slow'" + })).Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-slow").Return(&sql.StatementResponse{ + StatementId: "stmt-slow", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + + for i, sqlStr := range []string{"SELECT 'fast1'", "SELECT 'fast2'"} { + sid := fmt.Sprintf("stmt-fast-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + }, nil).Once() + } + + sqls := []string{"SELECT 'slow'", "SELECT 'fast1'", "SELECT 'fast2'"} + results := executeBatch(ctx, mockAPI, "wh-1", sqls, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sqls[i], r.SQL, "result %d", i) + assert.Equal(t, sql.StatementStateSucceeded, r.State, "result %d", i) + } +} + +func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { + // All statements are PENDING when the context is cancelled. cancelInFlight + // sweeps the in-flight set with CancelExecution. 
+ ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + for i, sqlStr := range []string{"q1", "q2", "q3"} { + sid := fmt.Sprintf("stmt-%d", i+1) + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.Statement == sqlStr + })).Return(&sql.StatementResponse{ + StatementId: sid, + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + StatementId: sid, + }).Return(nil).Once() + } + + cancel() + + results := executeBatch(ctx, mockAPI, "wh", []string{"q1", "q2", "q3"}, 8) + + require.Len(t, results, 3) + for i, r := range results { + assert.Equal(t, sql.StatementStateCanceled, r.State, "result %d state", i) + require.NotNil(t, r.Error, "result %d error", i) + } +} diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 6c125bbcd6b..b7e4d5ede34 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -75,31 +75,40 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup func newQueryCmd() *cobra.Command { var warehouseID string - var filePath string + var filePaths []string var outputFormat string + var concurrency int cmd := &cobra.Command{ - Use: "query [SQL | file.sql]", + Use: "query [SQL | file.sql]...", Short: "Execute SQL against a Databricks warehouse", - Long: `Execute a SQL statement against a Databricks SQL warehouse and return results. + Long: `Execute one or more SQL statements against a Databricks SQL warehouse +and return results. -SQL can be provided as a positional argument, read from a file with --file, -or piped via stdin. If the positional argument ends in .sql and the file -exists, it is read as a SQL file automatically. 
+A single SQL can be provided as a positional argument, read from a file with +--file, or piped via stdin. If a positional argument ends in .sql and the +file exists, it is read as a SQL file automatically. + +Pass multiple positional arguments and/or repeat --file to run several +queries in parallel against the warehouse. Multi-query output is always +JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, +error} objects in input order. The exit code is non-zero if any query +failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. -Output is JSON in non-interactive contexts. In interactive terminals it renders -tables, and large results open an interactive table browser. Use --output csv -to export results as CSV.`, +For a single query, output is JSON in non-interactive contexts. In +interactive terminals it renders tables, and large results open an +interactive table browser. Use --output csv to export results as CSV.`, Example: ` databricks experimental aitools tools query "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --warehouse abc123 "SELECT 1" databricks experimental aitools tools query --file report.sql databricks experimental aitools tools query report.sql databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" + databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.MaximumNArgs(1), + Args: cobra.ArbitraryArgs, PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() @@ -124,19 +133,29 @@ to export results as CSV.`, return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat) } - w := cmdctx.WorkspaceClient(ctx) - - sqlStatement, err := resolveSQL(ctx, cmd, 
args, filePath) + sqls, err := resolveSQLs(ctx, cmd, args, filePaths) if err != nil { return err } + // Reject incompatible flag combinations before any API call so the + // user sees the real error instead of an auth/warehouse failure. + if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON { + return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat) + } + + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqlStatement) + if len(sqls) > 1 { + return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, concurrency) + } + + resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0]) if err != nil { return err } @@ -177,7 +196,8 @@ to export results as CSV.`, } cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") - cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + cmd.Flags().StringSliceVarP(&filePaths, "file", "f", nil, "Path to a SQL file to execute (repeatable; pair with positional SQLs to run a batch)") + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") @@ -188,59 +208,85 @@ to export results as CSV.`, return cmd } -// resolveSQL determines the SQL statement to execute from the available input sources. -// Priority: --file flag > positional arg > stdin. 
-func resolveSQL(ctx context.Context, cmd *cobra.Command, args []string, filePath string) (string, error) { - var raw string +// resolveSQLs collects SQL statements from --file paths, positional args, and +// stdin. The returned slice preserves source order: --file paths first (in flag +// order), then positional args (in arg order), then stdin (only if no other +// source produced anything). Each SQL is run through cleanSQL. +func resolveSQLs(ctx context.Context, cmd *cobra.Command, args, filePaths []string) ([]string, error) { + var raws []string - switch { - case filePath != "": - if len(args) > 0 { - return "", errors.New("cannot use both --file and a positional SQL argument") - } - data, err := os.ReadFile(filePath) + for _, path := range filePaths { + data, err := os.ReadFile(path) if err != nil { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file %s: %w", path, err) } - raw = string(data) + raws = append(raws, string(data)) + } - case len(args) > 0: + for _, arg := range args { // If the argument looks like a .sql file, try to read it. // Only fall through to literal SQL if the file doesn't exist. // Surface other errors (permission denied, etc.) directly. - if strings.HasSuffix(args[0], sqlFileExtension) { - data, err := os.ReadFile(args[0]) + if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) if err != nil && !errors.Is(err, os.ErrNotExist) { - return "", fmt.Errorf("read SQL file: %w", err) + return nil, fmt.Errorf("read SQL file: %w", err) } if err == nil { - raw = string(data) - break + raws = append(raws, string(data)) + continue } } - raw = args[0] + raws = append(raws, arg) + } - default: - // No args: try reading from stdin if it's piped. + if len(raws) == 0 { + // No --file and no positional args: try reading from stdin if it's piped. // If stdin was overridden (e.g. cmd.SetIn in tests), always read from it. // Otherwise, only read if stdin is not a TTY (i.e. piped input). 
in := cmd.InOrStdin() _, isOsFile := in.(*os.File) if isOsFile && cmdio.IsPromptSupported(ctx) { - return "", errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") } data, err := io.ReadAll(in) if err != nil { - return "", fmt.Errorf("read stdin: %w", err) + return nil, fmt.Errorf("read stdin: %w", err) + } + raws = append(raws, string(data)) + } + + cleaned := make([]string, 0, len(raws)) + for i, raw := range raws { + c := cleanSQL(raw) + if c == "" { + if len(raws) == 1 { + return nil, errors.New("SQL statement is empty after removing comments and blank lines") + } + return nil, fmt.Errorf("SQL statement #%d is empty after removing comments and blank lines", i+1) } - raw = string(data) + cleaned = append(cleaned, c) } + return cleaned, nil +} - result := cleanSQL(raw) - if result == "" { - return "", errors.New("SQL statement is empty after removing comments and blank lines") +// runBatch executes multiple SQL statements in parallel and renders the result +// as a JSON array. Returns root.ErrAlreadyPrinted (so the exit code is non-zero +// without an extra error message) when any statement failed; the failure detail +// is already encoded in the printed JSON. The caller is responsible for +// rejecting incompatible output formats before invoking this. +func runBatch(ctx context.Context, cmd *cobra.Command, api sql.StatementExecutionInterface, warehouseID string, sqls []string, concurrency int) error { + results := executeBatch(ctx, api, warehouseID, sqls, concurrency) + if err := renderBatchJSON(cmd.OutOrStdout(), results); err != nil { + return err } - return result, nil + + for _, r := range results { + if r.Error != nil { + return root.ErrAlreadyPrinted + } + } + return nil } // resolveWarehouseID returns the warehouse ID to use for query execution. 
diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 4bc06c1d63b..a5d079acf8a 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -433,69 +433,95 @@ func TestPollingConstants(t *testing.T) { assert.Equal(t, 10*time.Second, cancelTimeout) } -// newTestCmd creates a minimal cobra.Command for testing resolveSQL. +// newTestCmd creates a minimal cobra.Command for testing resolveSQLs. func newTestCmd() *cobra.Command { return &cobra.Command{Use: "test"} } -func TestResolveSQLFromFileFlag(t *testing.T) { +func TestResolveSQLsFromFileFlag(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFromFileFlagWithComments(t *testing.T) { +func TestResolveSQLsFromFileFlagWithComments(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "query.sql") err := os.WriteFile(path, []byte("-- header comment\nSELECT 1\n-- trailing"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.NoError(t, err) - assert.Equal(t, "SELECT 1", result) + assert.Equal(t, []string{"SELECT 1"}, result) } -func TestResolveSQLFileFlagConflictsWithArg(t *testing.T) { +func TestResolveSQLsMixedFileAndPositional(t *testing.T) { + // --file paths are emitted before positional args, in flag order. 
+ dir := t.TempDir() + path := filepath.Join(dir, "from-file.sql") + err := os.WriteFile(path, []byte("SELECT 'from file'"), 0o644) + require.NoError(t, err) + cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1"}, "/some/file.sql") - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot use both --file and a positional SQL argument") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 'from arg'"}, []string{path}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'from file'", "SELECT 'from arg'"}, result) +} + +func TestResolveSQLsMultiplePositional(t *testing.T) { + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, nil) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 1", "SELECT 2", "SELECT 3"}, result) +} + +func TestResolveSQLsMultipleFiles(t *testing.T) { + dir := t.TempDir() + pathA := filepath.Join(dir, "a.sql") + pathB := filepath.Join(dir, "b.sql") + require.NoError(t, os.WriteFile(pathA, []byte("SELECT 'a'"), 0o644)) + require.NoError(t, os.WriteFile(pathB, []byte("SELECT 'b'"), 0o644)) + + cmd := newTestCmd() + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{pathA, pathB}) + require.NoError(t, err) + assert.Equal(t, []string{"SELECT 'a'", "SELECT 'b'"}, result) } -func TestResolveSQLFromPositionalArg(t *testing.T) { +func TestResolveSQLsFromPositionalArg(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 42"}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 42", result) + assert.Equal(t, []string{"SELECT 42"}, result) } -func TestResolveSQLAutoDetectsSQLFile(t *testing.T) { +func TestResolveSQLsAutoDetectsSQLFile(t *testing.T) { dir := t.TempDir() path := 
filepath.Join(dir, "report.sql") err := os.WriteFile(path, []byte("SELECT * FROM sales"), 0o644) require.NoError(t, err) cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.NoError(t, err) - assert.Equal(t, "SELECT * FROM sales", result) + assert.Equal(t, []string{"SELECT * FROM sales"}, result) } -func TestResolveSQLNonexistentSQLFileTreatedAsString(t *testing.T) { +func TestResolveSQLsNonexistentSQLFileTreatedAsString(t *testing.T) { cmd := newTestCmd() - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"nonexistent.sql"}, nil) require.NoError(t, err) - assert.Equal(t, "nonexistent.sql", result) + assert.Equal(t, []string{"nonexistent.sql"}, result) } -func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { +func TestResolveSQLsUnreadableSQLFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "locked.sql") err := os.WriteFile(path, []byte("SELECT 1"), 0o644) @@ -507,47 +533,54 @@ func TestResolveSQLUnreadableSQLFileReturnsError(t *testing.T) { t.Cleanup(func() { _ = os.Chmod(path, 0o644) }) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, []string{path}, "") + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{path}, nil) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } -func TestResolveSQLFromStdin(t *testing.T) { +func TestResolveSQLsFromStdin(t *testing.T) { cmd := newTestCmd() cmd.SetIn(strings.NewReader("SELECT 1 FROM stdin_test")) - result, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "") + result, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, nil) require.NoError(t, err) - assert.Equal(t, "SELECT 1 FROM stdin_test", result) + assert.Equal(t, []string{"SELECT 
1 FROM stdin_test"}, result) } -func TestResolveSQLEmptyFileReturnsError(t *testing.T) { +func TestResolveSQLsEmptyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "empty.sql") err := os.WriteFile(path, []byte(""), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLCommentsOnlyFileReturnsError(t *testing.T) { +func TestResolveSQLsCommentsOnlyFileReturnsError(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "comments.sql") err := os.WriteFile(path, []byte("-- just a comment\n-- another"), 0o644) require.NoError(t, err) cmd := newTestCmd() - _, err = resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, path) + _, err = resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{path}) require.Error(t, err) assert.Contains(t, err.Error(), "empty") } -func TestResolveSQLMissingFileReturnsError(t *testing.T) { +func TestResolveSQLsBatchEmptyAtIndexReturnsIndexedError(t *testing.T) { cmd := newTestCmd() - _, err := resolveSQL(cmdio.MockDiscard(t.Context()), cmd, nil, "/nonexistent/path/query.sql") + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, []string{"SELECT 1", "-- comment only", "SELECT 3"}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "SQL statement #2 is empty") +} + +func TestResolveSQLsMissingFileReturnsError(t *testing.T) { + cmd := newTestCmd() + _, err := resolveSQLs(cmdio.MockDiscard(t.Context()), cmd, nil, []string{"/nonexistent/path/query.sql"}) require.Error(t, err) assert.Contains(t, err.Error(), "read SQL file") } @@ -561,6 +594,24 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") } +func TestQueryCommandBatchTextOutputRejected(t *testing.T) { + cmd := newQueryCmd() 
+ cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", "text", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") +} + +func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", "csv", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil diff --git a/experimental/aitools/cmd/render.go b/experimental/aitools/cmd/render.go index 7727c37106c..d0b62926c20 100644 --- a/experimental/aitools/cmd/render.go +++ b/experimental/aitools/cmd/render.go @@ -29,6 +29,17 @@ func extractColumns(manifest *sql.ResultManifest) []string { return columns } +// renderBatchJSON writes batch results as a JSON array. The array preserves +// input order and includes one object per submitted statement. +func renderBatchJSON(w io.Writer, results []batchResult) error { + output, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("marshal batch results: %w", err) + } + fmt.Fprintf(w, "%s\n", output) + return nil +} + // renderJSON writes query results as a parseable JSON array to stdout. // Row count is written to stderr so stdout remains valid JSON for piping. func renderJSON(w io.Writer, columns []string, rows [][]string) error { From bc06013af476d113a5eb92ecfe77702cb9dd0e3f Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 15:49:30 +0200 Subject: [PATCH 03/14] aitools: validate --concurrency and document batch result order Address two findings from a cursor PR review: 1. --concurrency was passed straight into errgroup.SetLimit. A value of 0 deadlocks (errgroup refuses to add goroutines), and a negative value silently removes the cap. 
Add a PreRunE check that rejects anything <= 0 with errInvalidBatchConcurrency, matching the shape used by cmd/fs/cp.go for the same flag. 2. The Long help previously said multi-query results come back "in input order", which was ambiguous when --file and positional SQLs are mixed. The actual behavior (already covered by TestResolveSQLsMixedFileAndPositional) is: --file inputs first in flag order, then positional SQLs in arg order. Tighten the help text to state that contract precisely. Adds two unit tests that verify --concurrency 0 and -1 are rejected before any API call. Co-authored-by: Isaac --- experimental/aitools/cmd/batch.go | 5 +++++ experimental/aitools/cmd/query.go | 14 ++++++++++---- experimental/aitools/cmd/query_test.go | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index 8965923c17c..3f8fc3015bb 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -2,6 +2,7 @@ package aitools import ( "context" + "errors" "fmt" "os" "os/signal" @@ -19,6 +20,10 @@ import ( // Matches the default used by cmd/fs/cp.go for similar fan-out work. const defaultBatchConcurrency = 8 +// errInvalidBatchConcurrency is returned when --concurrency is set to a value +// that errgroup.SetLimit can't honor (0 deadlocks, negative removes the cap). +var errInvalidBatchConcurrency = errors.New("--concurrency must be at least 1") + // batchResult is the per-statement payload emitted in batch mode JSON output. // State is the server-reported terminal state. Error is set whenever the // statement did not produce usable rows, regardless of state, so consumers diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index b7e4d5ede34..afe544c0e26 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -92,8 +92,9 @@ file exists, it is read as a SQL file automatically. 
Pass multiple positional arguments and/or repeat --file to run several queries in parallel against the warehouse. Multi-query output is always JSON: an array of {sql, statement_id, state, elapsed_ms, columns, rows, -error} objects in input order. The exit code is non-zero if any query -failed. +error} objects. Result order is: --file inputs first (in flag order), +then positional SQLs (in arg order). The exit code is non-zero if any +query failed. The command auto-detects an available warehouse unless --warehouse is set or the DATABRICKS_WAREHOUSE_ID environment variable is configured. @@ -108,8 +109,13 @@ interactive table browser. Use --output csv to export results as CSV.`, databricks experimental aitools tools query --output csv "SELECT * FROM samples.nyctaxi.trips LIMIT 5" databricks experimental aitools tools query --output json "SELECT 1" "SELECT 2" "SELECT 3" echo "SELECT 1" | databricks experimental aitools tools query`, - Args: cobra.ArbitraryArgs, - PreRunE: root.MustWorkspaceClient, + Args: cobra.ArbitraryArgs, + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + return root.MustWorkspaceClient(cmd, args) + }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index a5d079acf8a..abd6ffe8341 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -612,6 +612,22 @@ func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { assert.Contains(t, err.Error(), "multiple queries require --output json") } +func TestQueryCommandConcurrencyZeroRejected(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; we reject it in PreRunE. 
+ cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", "0", "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + +func TestQueryCommandConcurrencyNegativeRejected(t *testing.T) { + // Negative removes the cap entirely in errgroup, which surprises users. + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", "-1", "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { cmd := newQueryCmd() cmd.PreRunE = nil From 4d15c81687a0b7971bbf00624d228546f2c646dd Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 21:36:58 +0200 Subject: [PATCH 04/14] aitools: fold redundant cobra-level rejection tests into table-driven cases MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two pairs of cobra-level tests were each testing one rejection code path with two flag values. Fold them into table-driven subtests so the shared assertion lives in one place: - TestQueryCommandBatchTextOutputRejected + ...CsvOutputRejected → TestQueryCommandBatchOutputRejection (text, csv subtests) - TestQueryCommandConcurrencyZeroRejected + ...NegativeRejected → TestQueryCommandConcurrencyRejection (0, -1 subtests) Same coverage, half the test functions. 
Co-authored-by: Isaac --- experimental/aitools/cmd/query_test.go | 54 ++++++++++++-------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index abd6ffe8341..e6bf9362fff 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -594,38 +594,32 @@ func TestQueryCommandUnsupportedOutputReturnsError(t *testing.T) { assert.Contains(t, err.Error(), "unsupported output format") } -func TestQueryCommandBatchTextOutputRejected(t *testing.T) { - cmd := newQueryCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{"--output", "text", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.Error(t, err) - assert.Contains(t, err.Error(), "multiple queries require --output json") -} - -func TestQueryCommandBatchCsvOutputRejected(t *testing.T) { - cmd := newQueryCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{"--output", "csv", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.Error(t, err) - assert.Contains(t, err.Error(), "multiple queries require --output json") -} - -func TestQueryCommandConcurrencyZeroRejected(t *testing.T) { - // errgroup.SetLimit(0) deadlocks; we reject it in PreRunE. - cmd := newQueryCmd() - cmd.SetArgs([]string{"--concurrency", "0", "--output", "json", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) +func TestQueryCommandBatchOutputRejection(t *testing.T) { + // Multi-query mode is JSON-only. text and csv are rejected with an + // actionable error before any API call. 
+ for _, format := range []string{"text", "csv"} { + t.Run(format, func(t *testing.T) { + cmd := newQueryCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--output", format, "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "multiple queries require --output json") + }) + } } -func TestQueryCommandConcurrencyNegativeRejected(t *testing.T) { - // Negative removes the cap entirely in errgroup, which surprises users. - cmd := newQueryCmd() - cmd.SetArgs([]string{"--concurrency", "-1", "--output", "json", "SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) +func TestQueryCommandConcurrencyRejection(t *testing.T) { + // errgroup.SetLimit(0) deadlocks; negative removes the cap entirely. + // Both surprise users, so PreRunE rejects anything <= 0. + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newQueryCmd() + cmd.SetArgs([]string{"--concurrency", value, "--output", "json", "SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } } func TestQueryCommandOutputFlagIsCaseInsensitive(t *testing.T) { From a1c5ca637443c2a3f1f2153d3c0b5785a3849eb8 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 28 Apr 2026 10:13:02 +0200 Subject: [PATCH 05/14] aitools: detach cancel-RPC ctx from cancelled parent Address Arseni's P2 finding on the batch PR. cancelInFlight (batch.go) and cancelStatement (query.go) used to derive the cancel-RPC ctx via context.WithTimeout(ctx, cancelTimeout). On the actual hot path (Ctrl+C or parent ctx cancelled), the inbound ctx is already cancelled by the time we reach the cancel sweep. The SDK then short-circuits on ctx.Err() and the cancel RPC never reaches the warehouse, leaving in-flight statements running server-side. Wrap with context.WithoutCancel(ctx) (Go 1.21+) so the timeout context keeps the caller's values but drops the cancellation signal. 
The cancel RPC now actually fires. Also tighten the existing tests: - TestExecuteBatchContextCancellationCancelsInFlight - TestExecuteAndPollCancelledContextCallsCancelExecution Both previously matched mock.Anything for the ctx argument, so they passed regardless of whether the bug was present. They now use mock.MatchedBy(c.Err() == nil) to assert the cancel-RPC ctx is alive. This is a regression guard; reverting the production fix makes the tests fail with "unexpected call" because the matcher no longer matches. Co-authored-by: Isaac --- experimental/aitools/cmd/batch.go | 6 +++++- experimental/aitools/cmd/batch_test.go | 10 ++++++++-- experimental/aitools/cmd/query.go | 7 +++++-- experimental/aitools/cmd/query_test.go | 8 ++++++-- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index 3f8fc3015bb..38ecea531e6 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -198,7 +198,11 @@ func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, st // marker meaning the goroutine bailed without telling the server. // Either way, send CancelExecution. } - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from the inbound ctx (which is typically already cancelled by + // the time we reach this sweep): WithoutCancel keeps the caller's + // values but drops the cancellation signal so the cancel RPC actually + // reaches the warehouse instead of short-circuiting on ctx.Err(). 
+ cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) } diff --git a/experimental/aitools/cmd/batch_test.go b/experimental/aitools/cmd/batch_test.go index 96235530f4d..f6f468768f9 100644 --- a/experimental/aitools/cmd/batch_test.go +++ b/experimental/aitools/cmd/batch_test.go @@ -207,10 +207,16 @@ func TestExecuteBatchPreservesInputOrder(t *testing.T) { func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { // All statements are PENDING when the context is cancelled. cancelInFlight - // sweeps the in-flight set with CancelExecution. + // sweeps the in-flight set with CancelExecution. Each cancel RPC must + // carry a NON-cancelled context, otherwise the SDK short-circuits on + // ctx.Err() and never reaches the warehouse. ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) mockAPI := mocksql.NewMockStatementExecutionInterface(t) + aliveCtx := mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }) + for i, sqlStr := range []string{"q1", "q2", "q3"} { sid := fmt.Sprintf("stmt-%d", i+1) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { @@ -220,7 +226,7 @@ func TestExecuteBatchContextCancellationCancelsInFlight(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil).Once() - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + mockAPI.EXPECT().CancelExecution(aliveCtx, sql.CancelExecutionRequest{ StatementId: sid, }).Return(nil).Once() } diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index afe544c0e26..7e9ae1d030d 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -345,8 +345,11 @@ func executeAndPoll(ctx context.Context, api 
sql.StatementExecutionInterface, wa // cancelStatement performs best-effort server-side cancellation. // Called on any poll exit due to context cancellation (signal or parent). cancelStatement := func() { - // Use the parent context (ctx), not the cancelled pollCtx. - cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + // Detach from any cancellation on the inbound ctx (the caller might + // have cancelled the parent before invoking this path): WithoutCancel + // preserves values but drops cancellation so the cancel RPC actually + // reaches the warehouse. + cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) defer cancel() if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ StatementId: statementID, diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index e6bf9362fff..59de11d578a 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -146,8 +146,12 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStatePending}, }, nil) - // CancelExecution must be called when context is cancelled (not just on signal). - mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + // CancelExecution must be called when context is cancelled (not just on + // signal). Assert the RPC's own ctx is NOT cancelled, otherwise the SDK + // would short-circuit on ctx.Err() and never reach the warehouse. 
+ mockAPI.EXPECT().CancelExecution(mock.MatchedBy(func(c context.Context) bool { + return c.Err() == nil + }), sql.CancelExecutionRequest{ StatementId: "stmt-1", }).Return(nil).Once() From 200116de13155574447f64c4819d846aec412088 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 16:01:12 +0200 Subject: [PATCH 06/14] aitools: add 'tools statement' lifecycle commands Adds a low-level command tree for asynchronous SQL statement management, complementing the synchronous 'tools query': databricks experimental aitools tools statement submit "SELECT ..." databricks experimental aitools tools statement get databricks experimental aitools tools statement status databricks experimental aitools tools statement cancel submit fires an ExecuteStatement with WaitTimeout=0s and OnWaitTimeout=CONTINUE, returning the statement_id immediately. get polls (via pollStatement from #5092) until terminal and emits rows on success or an error object on failure. status performs a single GET without polling. cancel sends CancelExecution. All four subcommands emit a uniform JSON shape {statement_id, state, warehouse_id, columns, rows, error} with omitempty so the payload only includes fields that subcommand has. Important UX nuance: 'statement get' Ctrl+C stops polling but does NOT cancel the server-side statement. Users that want server-side termination call 'statement cancel' explicitly. (This differs from 'tools query', which cancels server-side on Ctrl+C because the user invoked the synchronous path.) The pollStatement helper from #5092 is already designed to propagate ctx errors without touching the server, so 'get' inherits this behavior for free. 
Co-authored-by: Isaac --- experimental/aitools/README.md | 17 ++ experimental/aitools/cmd/statement.go | 56 ++++ experimental/aitools/cmd/statement_cancel.go | 53 ++++ experimental/aitools/cmd/statement_get.go | 92 +++++++ experimental/aitools/cmd/statement_status.go | 57 ++++ experimental/aitools/cmd/statement_submit.go | 86 ++++++ experimental/aitools/cmd/statement_test.go | 260 +++++++++++++++++++ experimental/aitools/cmd/tools.go | 1 + 8 files changed, 622 insertions(+) create mode 100644 experimental/aitools/cmd/statement.go create mode 100644 experimental/aitools/cmd/statement_cancel.go create mode 100644 experimental/aitools/cmd/statement_get.go create mode 100644 experimental/aitools/cmd/statement_status.go create mode 100644 experimental/aitools/cmd/statement_submit.go create mode 100644 experimental/aitools/cmd/statement_test.go diff --git a/experimental/aitools/README.md b/experimental/aitools/README.md index f645e4de51d..ec12ed10f7c 100644 --- a/experimental/aitools/README.md +++ b/experimental/aitools/README.md @@ -10,6 +10,10 @@ Current commands: - `databricks experimental aitools tools query` - `databricks experimental aitools tools discover-schema` - `databricks experimental aitools tools get-default-warehouse` +- `databricks experimental aitools tools statement submit` +- `databricks experimental aitools tools statement get` +- `databricks experimental aitools tools statement status` +- `databricks experimental aitools tools statement cancel` Current behavior: @@ -29,6 +33,19 @@ Current behavior: "SELECT vendor_id, count(*) FROM samples.nyctaxi.trips GROUP BY 1" ``` +- `tools statement` is a low-level lifecycle for asynchronous statements. + `submit` returns a `statement_id` immediately, `get` polls until terminal + and emits rows, `status` peeks without blocking, and `cancel` requests + termination. Ctrl+C on `get` stops polling but does NOT cancel the + server-side statement; use `cancel` for that. 
+
+  ```bash
+  SID=$(databricks experimental aitools tools statement submit \
+    --warehouse abc123 "SELECT pg_sleep(5)" | jq -r '.statement_id')
+  databricks experimental aitools tools statement status "$SID"
+  databricks experimental aitools tools statement get "$SID"
+  ```
+
 Removed behavior:
 
 - there is no MCP server under `experimental aitools`
diff --git a/experimental/aitools/cmd/statement.go b/experimental/aitools/cmd/statement.go
new file mode 100644
index 00000000000..3c2ba679e93
--- /dev/null
+++ b/experimental/aitools/cmd/statement.go
@@ -0,0 +1,56 @@
+package aitools
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"github.com/spf13/cobra"
+)
+
+// statementInfo is the JSON shape emitted by every `tools statement`
+// subcommand. Fields are populated as the subcommand has them. omitempty keeps
+// the output tight: `submit` doesn't emit columns/rows, `cancel` doesn't emit a
+// warehouse_id, etc.
+type statementInfo struct {
+	StatementID string             `json:"statement_id"`
+	State       sql.StatementState `json:"state,omitempty"`
+	WarehouseID string             `json:"warehouse_id,omitempty"`
+	Columns     []string           `json:"columns,omitempty"`
+	Rows        [][]string         `json:"rows,omitempty"`
+	Error       *batchResultError  `json:"error,omitempty"`
+}
+
+func renderStatementInfo(w io.Writer, info statementInfo) error {
+	data, err := json.MarshalIndent(info, "", "  ")
+	if err != nil {
+		return fmt.Errorf("marshal statement info: %w", err)
+	}
+	fmt.Fprintf(w, "%s\n", data)
+	return nil
+}
+
+func newStatementCmd() *cobra.Command {
+	cmd := &cobra.Command{
+		Use:   "statement",
+		Short: "Manage SQL statement lifecycle (submit, get, status, cancel)",
+		Long: `Low-level command tree for asynchronous SQL execution.
+
+Use 'submit' to fire a statement and get its statement_id back, then
+'get' to block on results, 'status' to peek without blocking, and
+'cancel' to terminate. For "I want results now," use 'tools query'
+instead.
+ +All subcommands emit a JSON object with the statement_id and state. +'get' adds columns and rows on success; any subcommand may emit an +error object when the server reports a non-success terminal state.`, + } + + cmd.AddCommand(newStatementSubmitCmd()) + cmd.AddCommand(newStatementGetCmd()) + cmd.AddCommand(newStatementStatusCmd()) + cmd.AddCommand(newStatementCancelCmd()) + + return cmd +} diff --git a/experimental/aitools/cmd/statement_cancel.go b/experimental/aitools/cmd/statement_cancel.go new file mode 100644 index 00000000000..1774b7abe6a --- /dev/null +++ b/experimental/aitools/cmd/statement_cancel.go @@ -0,0 +1,53 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementCancelCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cancel STATEMENT_ID", + Short: "Request cancellation of a running statement", + Long: `Send a cancellation request for the given statement_id. The Statements +API returns no body on cancel; this command optimistically reports +state=CANCELED on success. Use 'statement status' afterwards to confirm +the server-side state if you need certainty.`, + Example: ` databricks experimental aitools tools statement cancel 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := cancelStatementExecution(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// cancelStatementExecution issues CancelExecution and reports state=CANCELED on success. +// CancelExecution returns no body; the actual server-side state is verified +// asynchronously. 
Use 'statement status' to confirm if certainty is required. +func cancelStatementExecution(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + if err := api.CancelExecution(ctx, sql.CancelExecutionRequest{ + StatementId: statementID, + }); err != nil { + return statementInfo{}, fmt.Errorf("cancel statement: %w", err) + } + return statementInfo{ + StatementID: statementID, + State: sql.StatementStateCanceled, + }, nil +} diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go new file mode 100644 index 00000000000..61fe8d4a91a --- /dev/null +++ b/experimental/aitools/cmd/statement_get.go @@ -0,0 +1,92 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementGetCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "get STATEMENT_ID", + Short: "Block until a previously submitted statement is terminal and emit its result", + Long: `Poll a statement_id until it reaches a terminal state, then emit +columns and rows on success or an error object on failure. + +Ctrl+C stops polling but does NOT cancel the server-side statement. +Use 'statement cancel ' to terminate explicitly. 
(This differs from +'tools query', which cancels server-side on Ctrl+C because the user +invoked the synchronous path.)`, + Example: ` databricks experimental aitools tools statement get 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := getStatementResult(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + + if err := renderStatementInfo(cmd.OutOrStdout(), info); err != nil { + return err + } + + // Non-zero exit when the statement reached a non-success terminal + // state. The error info is already in the JSON output. + if info.State != sql.StatementStateSucceeded { + return root.ErrAlreadyPrinted + } + return nil + }, + } + + return cmd +} + +// getStatementResult polls a statement until terminal, then assembles a +// statementInfo with rows on success or an error object on failure. +// +// Context cancellation propagates from pollStatement WITHOUT cancelling the +// server-side statement (intentional: 'get' is a poll-only operation; use +// 'cancel' to terminate explicitly). +func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + // Fetch the current state first so pollStatement can short-circuit if + // the statement is already terminal. 
+ resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + pollResp, err := pollStatement(ctx, api, resp) + if err != nil { + return statementInfo{}, err + } + + info := statementInfo{StatementID: pollResp.StatementId} + if pollResp.Status != nil { + info.State = pollResp.Status.State + if pollResp.Status.Error != nil { + info.Error = &batchResultError{ + Message: pollResp.Status.Error.Message, + ErrorCode: string(pollResp.Status.Error.ErrorCode), + } + } + } + + if info.State == sql.StatementStateSucceeded { + info.Columns = extractColumns(pollResp.Manifest) + rows, err := fetchAllRows(ctx, api, pollResp) + if err != nil { + return info, err + } + info.Rows = rows + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_status.go b/experimental/aitools/cmd/statement_status.go new file mode 100644 index 00000000000..8475255a941 --- /dev/null +++ b/experimental/aitools/cmd/statement_status.go @@ -0,0 +1,57 @@ +package aitools + +import ( + "context" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementStatusCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "status STATEMENT_ID", + Short: "Return the current state of a statement without polling", + Long: `Single GET against the Statements API. Use this to peek at progress +without blocking. 
For a blocking poll-until-terminal call, use +'statement get'.`, + Example: ` databricks experimental aitools tools statement status 01ef...`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + w := cmdctx.WorkspaceClient(ctx) + statementID := args[0] + + info, err := getStatementStatus(ctx, w.StatementExecution, statementID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + return cmd +} + +// getStatementStatus performs a single GET against the Statements API, no polling. +func getStatementStatus(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { + resp, err := api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return statementInfo{}, fmt.Errorf("get statement: %w", err) + } + + info := statementInfo{StatementID: resp.StatementId} + if resp.Status != nil { + info.State = resp.Status.State + if resp.Status.Error != nil { + info.Error = &batchResultError{ + Message: resp.Status.Error.Message, + ErrorCode: string(resp.Status.Error.ErrorCode), + } + } + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go new file mode 100644 index 00000000000..56767d029de --- /dev/null +++ b/experimental/aitools/cmd/statement_submit.go @@ -0,0 +1,86 @@ +package aitools + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/spf13/cobra" +) + +func newStatementSubmitCmd() *cobra.Command { + var warehouseID string + var filePath string + + cmd := &cobra.Command{ + Use: "submit [SQL | file.sql]", + Short: "Submit a SQL statement asynchronously and return its statement_id", + Long: `Submit a SQL statement to a Databricks SQL warehouse and return its +statement_id 
immediately, without waiting for results. + +The statement keeps running server-side. Harvest results with +'statement get ', inspect with 'statement status ', or stop +with 'statement cancel '.`, + Example: ` databricks experimental aitools tools statement submit "SELECT pg_sleep(60)" --warehouse + databricks experimental aitools tools statement submit --file query.sql`, + Args: cobra.MaximumNArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + var fps []string + if filePath != "" { + fps = []string{filePath} + } + sqls, err := resolveSQLs(ctx, cmd, args, fps) + if err != nil { + return err + } + if len(sqls) != 1 { + return errors.New("submit accepts exactly one SQL statement; pass multiple to 'query' for batch") + } + + w := cmdctx.WorkspaceClient(ctx) + wID, err := resolveWarehouseID(ctx, w, warehouseID) + if err != nil { + return err + } + + info, err := submitStatement(ctx, w.StatementExecution, sqls[0], wID) + if err != nil { + return err + } + return renderStatementInfo(cmd.OutOrStdout(), info) + }, + } + + cmd.Flags().StringVarP(&warehouseID, "warehouse", "w", "", "SQL warehouse ID to use for execution") + cmd.Flags().StringVarP(&filePath, "file", "f", "", "Path to a SQL file to execute") + + return cmd +} + +// submitStatement issues an asynchronous ExecuteStatement and returns the handle. 
+func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string) (statementInfo, error) { + resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + return statementInfo{}, fmt.Errorf("execute statement: %w", err) + } + + info := statementInfo{ + StatementID: resp.StatementId, + WarehouseID: warehouseID, + } + if resp.Status != nil { + info.State = resp.Status.State + } + return info, nil +} diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go new file mode 100644 index 00000000000..419fa6e8249 --- /dev/null +++ b/experimental/aitools/cmd/statement_test.go @@ -0,0 +1,260 @@ +package aitools + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/cmdio" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestSubmitStatementReturnsHandle(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool { + return req.WarehouseId == "wh-1" && req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + info, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, 
sql.StatementStatePending, info.State) + assert.Equal(t, "wh-1", info.WarehouseID) +} + +func TestSubmitStatementWrapsTransportError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). + Return(nil, errors.New("network unreachable")).Once() + + _, err := submitStatement(ctx, mockAPI, "SELECT 1", "wh-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "execute statement") + assert.Contains(t, err.Error(), "network unreachable") +} + +func TestGetStatementResultPolls(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, TotalChunkCount: 1}, + Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns) + assert.Equal(t, [][]string{{"42"}}, info.Rows) + assert.Nil(t, info.Error) +} + +func TestGetStatementResultFailedStateReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: 
sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "SYNTAX_ERROR", + Message: "near 'bad': syntax error", + }, + }, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + assert.Nil(t, info.Rows) + require.NotNil(t, info.Error) + assert.Equal(t, "SYNTAX_ERROR", info.Error.ErrorCode) + assert.Contains(t, info.Error.Message, "syntax error") +} + +func TestGetStatementResultDoesNotCancelServerSideOnContextCancel(t *testing.T) { + // 'statement get' is a poll-only operation: ctx cancellation must NOT + // trigger CancelExecution. The mock asserts (via t.Cleanup) that no + // unexpected calls happen. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStatePending}, + }, nil).Once() + + cancel() + + _, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.ErrorIs(t, err, context.Canceled) +} + +func TestGetStatementStatusSinglePoll(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateRunning, info.State) + assert.Nil(t, info.Error) +} + +func TestGetStatementStatusReportsError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, 
"stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ + ErrorCode: "TIMEOUT", + Message: "warehouse timed out", + }, + }, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + require.NotNil(t, info.Error) + assert.Equal(t, "TIMEOUT", info.Error.ErrorCode) +} + +func TestCancelStatementExecutionCallsAPI(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{ + StatementId: "stmt-1", + }).Return(nil).Once() + + info, err := cancelStatementExecution(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", info.StatementID) + assert.Equal(t, sql.StatementStateCanceled, info.State) +} + +func TestCancelStatementExecutionWrapsAPIError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().CancelExecution(mock.Anything, mock.Anything). 
+ Return(errors.New("not found")).Once() + + _, err := cancelStatementExecution(ctx, mockAPI, "stmt-1") + require.Error(t, err) + assert.Contains(t, err.Error(), "cancel statement") + assert.Contains(t, err.Error(), "not found") +} + +func TestRenderStatementInfo(t *testing.T) { + info := statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + WarehouseID: "wh-1", + Columns: []string{"n"}, + Rows: [][]string{{"42"}}, + } + + var buf strings.Builder + require.NoError(t, renderStatementInfo(&buf, info)) + + output := buf.String() + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "SUCCEEDED"`) + assert.Contains(t, output, `"warehouse_id": "wh-1"`) + assert.Contains(t, output, `"columns": [`) + assert.Contains(t, output, `"rows": [`) + assert.True(t, strings.HasSuffix(output, "\n")) +} + +func TestRenderStatementInfoOmitsEmptyFields(t *testing.T) { + // Cancel-style payload: only statement_id + state. + info := statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateCanceled, + } + + var buf strings.Builder + require.NoError(t, renderStatementInfo(&buf, info)) + + output := buf.String() + assert.Contains(t, output, `"statement_id": "stmt-1"`) + assert.Contains(t, output, `"state": "CANCELED"`) + assert.NotContains(t, output, `"warehouse_id"`) + assert.NotContains(t, output, `"columns"`) + assert.NotContains(t, output, `"rows"`) + assert.NotContains(t, output, `"error"`) +} + +func TestStatementSubmitRejectsMultipleSQLs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.sql") + require.NoError(t, os.WriteFile(path, []byte("SELECT 1"), 0o644)) + + cmd := newStatementSubmitCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"--file", path, "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "exactly one") +} + +func TestStatementSubmitArgsBound(t *testing.T) { + // MaximumNArgs(1) means cobra rejects 2+ positionals at parse time. 
+ cmd := newStatementSubmitCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{"SELECT 1", "SELECT 2"}) + err := cmd.Execute() + require.Error(t, err) +} + +func TestStatementGetRequiresStatementID(t *testing.T) { + cmd := newStatementGetCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{}) + err := cmd.Execute() + require.Error(t, err) +} + +func TestStatementCancelRequiresStatementID(t *testing.T) { + cmd := newStatementCancelCmd() + cmd.PreRunE = nil + cmd.SetArgs([]string{}) + err := cmd.Execute() + require.Error(t, err) +} diff --git a/experimental/aitools/cmd/tools.go b/experimental/aitools/cmd/tools.go index b5dd306d210..22781f987f6 100644 --- a/experimental/aitools/cmd/tools.go +++ b/experimental/aitools/cmd/tools.go @@ -15,6 +15,7 @@ func newToolsCmd() *cobra.Command { cmd.AddCommand(newQueryCmd()) cmd.AddCommand(newDiscoverSchemaCmd()) cmd.AddCommand(newGetDefaultWarehouseCmd()) + cmd.AddCommand(newStatementCmd()) return cmd } From 422a41ccbdce1b2702f748c0433aff269463738a Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 16:30:36 +0200 Subject: [PATCH 07/14] aitools: always populate error for non-success terminal states Address a cursor PR review finding: 'statement get' and 'statement status' previously only set info.Error when pollResp.Status.Error was non-nil. The Statements API can return a non-success terminal state (FAILED, CANCELED, CLOSED) with no Error payload, so the JSON contract "emits rows on success or an error object on failure" wasn't actually guaranteed. Skill consumers couldn't branch on `error == null` alone: they had to also inspect `state`. Especially bad for 'get', which exits non-zero on non-success terminal states without giving the caller structured failure detail. Add a shared helper, statementErrorFromStatus, that returns a batchResultError for any terminal non-success state, populated from the SDK's ServiceError when present and synthesizing "statement reached terminal state X" when the backend doesn't supply one. 
Mirrors the pattern already used by runOneBatchQuery in batch.go, so the contract is uniform across batch and single-statement paths. Both 'get' and 'status' now use the helper. PENDING and RUNNING still emit no error (legitimately mid-flight). New tests: - table-driven coverage of statementErrorFromStatus across nil, succeeded, running, pending, failed-with-error, failed-no-error, canceled-no-error, closed-no-error - getStatementResult with CLOSED state and no Error - getStatementResult with FAILED state and no Error - getStatementStatus with FAILED state and no Error - getStatementStatus with RUNNING state confirms no error is set Co-authored-by: Isaac --- experimental/aitools/cmd/statement.go | 21 +++ experimental/aitools/cmd/statement_get.go | 7 +- experimental/aitools/cmd/statement_status.go | 7 +- experimental/aitools/cmd/statement_test.go | 138 +++++++++++++++++++ 4 files changed, 161 insertions(+), 12 deletions(-) diff --git a/experimental/aitools/cmd/statement.go b/experimental/aitools/cmd/statement.go index 3c2ba679e93..e1c48a7ddbe 100644 --- a/experimental/aitools/cmd/statement.go +++ b/experimental/aitools/cmd/statement.go @@ -31,6 +31,27 @@ func renderStatementInfo(w io.Writer, info statementInfo) error { return nil } +// statementErrorFromStatus builds a batchResultError for any terminal non-success +// state (FAILED, CANCELED, CLOSED), populating it from the server's ServiceError +// when available and synthesizing a message when it isn't. Returns nil for +// SUCCEEDED, non-terminal states, and nil status. The synthesized fallback +// matters because the Statements API can hand back a non-success terminal state +// with `Error == nil`, and skill consumers should be able to branch on +// `error == null` alone instead of inspecting `state`. 
+func statementErrorFromStatus(status *sql.StatementStatus) *batchResultError { + if status == nil || !isTerminalState(status) || status.State == sql.StatementStateSucceeded { + return nil + } + out := &batchResultError{} + if status.Error != nil { + out.Message = status.Error.Message + out.ErrorCode = string(status.Error.ErrorCode) + } else { + out.Message = fmt.Sprintf("statement reached terminal state %s", status.State) + } + return out +} + func newStatementCmd() *cobra.Command { cmd := &cobra.Command{ Use: "statement", diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go index 61fe8d4a91a..f01b6488fdb 100644 --- a/experimental/aitools/cmd/statement_get.go +++ b/experimental/aitools/cmd/statement_get.go @@ -72,13 +72,8 @@ func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface info := statementInfo{StatementID: pollResp.StatementId} if pollResp.Status != nil { info.State = pollResp.Status.State - if pollResp.Status.Error != nil { - info.Error = &batchResultError{ - Message: pollResp.Status.Error.Message, - ErrorCode: string(pollResp.Status.Error.ErrorCode), - } - } } + info.Error = statementErrorFromStatus(pollResp.Status) if info.State == sql.StatementStateSucceeded { info.Columns = extractColumns(pollResp.Manifest) diff --git a/experimental/aitools/cmd/statement_status.go b/experimental/aitools/cmd/statement_status.go index 8475255a941..9981f49aa63 100644 --- a/experimental/aitools/cmd/statement_status.go +++ b/experimental/aitools/cmd/statement_status.go @@ -46,12 +46,7 @@ func getStatementStatus(ctx context.Context, api sql.StatementExecutionInterface info := statementInfo{StatementID: resp.StatementId} if resp.Status != nil { info.State = resp.Status.State - if resp.Status.Error != nil { - info.Error = &batchResultError{ - Message: resp.Status.Error.Message, - ErrorCode: string(resp.Status.Error.ErrorCode), - } - } } + info.Error = statementErrorFromStatus(resp.Status) return info, nil 
} diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go index 419fa6e8249..95015c2260a 100644 --- a/experimental/aitools/cmd/statement_test.go +++ b/experimental/aitools/cmd/statement_test.go @@ -258,3 +258,141 @@ func TestStatementCancelRequiresStatementID(t *testing.T) { err := cmd.Execute() require.Error(t, err) } + +func TestStatementErrorFromStatus(t *testing.T) { + tests := []struct { + name string + status *sql.StatementStatus + wantNil bool + wantMsg string + wantCode string + }{ + { + name: "nil status", + status: nil, + wantNil: true, + }, + { + name: "succeeded never produces an error", + status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + wantNil: true, + }, + { + name: "running is not terminal", + status: &sql.StatementStatus{State: sql.StatementStateRunning}, + wantNil: true, + }, + { + name: "pending is not terminal", + status: &sql.StatementStatus{State: sql.StatementStatePending}, + wantNil: true, + }, + { + name: "failed with backend error preserves both fields", + status: &sql.StatementStatus{ + State: sql.StatementStateFailed, + Error: &sql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'bad'"}, + }, + wantMsg: "near 'bad'", + wantCode: "SYNTAX_ERROR", + }, + { + name: "failed without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateFailed}, + wantMsg: "statement reached terminal state FAILED", + }, + { + name: "canceled without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateCanceled}, + wantMsg: "statement reached terminal state CANCELED", + }, + { + name: "closed without backend error synthesizes message", + status: &sql.StatementStatus{State: sql.StatementStateClosed}, + wantMsg: "statement reached terminal state CLOSED", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := statementErrorFromStatus(tc.status) + if tc.wantNil { + assert.Nil(t, got) + 
return + } + require.NotNil(t, got) + assert.Equal(t, tc.wantMsg, got.Message) + assert.Equal(t, tc.wantCode, got.ErrorCode) + }) + } +} + +func TestGetStatementResultClosedTerminalSynthesizesError(t *testing.T) { + // Statement reached CLOSED with no Error payload from the server. The shared + // statementInfo contract guarantees a non-nil Error for any non-success + // terminal state so consumers can branch on `error == null` alone. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateClosed}, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateClosed, info.State) + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "CLOSED") + assert.Empty(t, info.Error.ErrorCode) + assert.Nil(t, info.Rows) +} + +func TestGetStatementResultFailedWithoutBackendErrorSynthesizesError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateFailed}, + }, nil).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "FAILED") +} + +func TestGetStatementStatusFailedWithoutBackendErrorSynthesizesError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: 
&sql.StatementStatus{State: sql.StatementStateFailed}, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateFailed, info.State) + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "FAILED") +} + +func TestGetStatementStatusRunningHasNoError(t *testing.T) { + // PENDING/RUNNING legitimately have no error; the contract only requires + // error population for terminal non-success states. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateRunning}, + }, nil).Once() + + info, err := getStatementStatus(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateRunning, info.State) + assert.Nil(t, info.Error) +} From f5e586d87219fb8c77d65234dcf8be80a48e918e Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 21:39:32 +0200 Subject: [PATCH 08/14] aitools: drop redundant statement-lifecycle tests; fold render shape tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review pass on the test suite found ~8 functions worth trimming without losing coverage: Drop (cobra built-ins, not our contract): - TestStatementSubmitArgsBound: tests cobra's MaximumNArgs(1) - TestStatementGetRequiresStatementID: tests cobra's ExactArgs(1) - TestStatementCancelRequiresStatementID: tests cobra's ExactArgs(1) Drop (already covered by TestStatementErrorFromStatus, the table-driven helper test added with the cursor-fix commit): - TestGetStatementResultClosedTerminalSynthesizesError - TestGetStatementResultFailedWithoutBackendErrorSynthesizesError - TestGetStatementStatusFailedWithoutBackendErrorSynthesizesError - TestGetStatementStatusRunningHasNoError Fold: - 
TestRenderStatementInfo + TestRenderStatementInfoOmitsEmptyFields → one table-driven TestRenderStatementInfo with the full and minimal cases as subtests. Kept the validation we actually wrote (TestStatementSubmitRejectsMultipleSQLs) and the wiring tests that pin distinct contracts (TestGetStatementResultPolls, TestGetStatementResultFailedStateReportsError, TestGetStatementResultDoesNotCancelServerSideOnContextCancel, TestGetStatementStatusSinglePoll, TestGetStatementStatusReportsError, the cancel pair, and submit pair). Co-authored-by: Isaac --- experimental/aitools/cmd/statement_test.go | 180 ++++++--------------- 1 file changed, 51 insertions(+), 129 deletions(-) diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go index 95015c2260a..69768d2a159 100644 --- a/experimental/aitools/cmd/statement_test.go +++ b/experimental/aitools/cmd/statement_test.go @@ -182,46 +182,63 @@ func TestCancelStatementExecutionWrapsAPIError(t *testing.T) { } func TestRenderStatementInfo(t *testing.T) { - info := statementInfo{ - StatementID: "stmt-1", - State: sql.StatementStateSucceeded, - WarehouseID: "wh-1", - Columns: []string{"n"}, - Rows: [][]string{{"42"}}, + tests := []struct { + name string + info statementInfo + mustHave []string + mustNotHave []string + }{ + { + name: "full payload renders every populated field", + info: statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateSucceeded, + WarehouseID: "wh-1", + Columns: []string{"n"}, + Rows: [][]string{{"42"}}, + }, + mustHave: []string{ + `"statement_id": "stmt-1"`, + `"state": "SUCCEEDED"`, + `"warehouse_id": "wh-1"`, + `"columns": [`, + `"rows": [`, + }, + }, + { + name: "cancel-style payload omits unset fields", + info: statementInfo{ + StatementID: "stmt-1", + State: sql.StatementStateCanceled, + }, + mustHave: []string{ + `"statement_id": "stmt-1"`, + `"state": "CANCELED"`, + }, + mustNotHave: []string{`"warehouse_id"`, `"columns"`, `"rows"`, `"error"`}, + }, } 
- var buf strings.Builder - require.NoError(t, renderStatementInfo(&buf, info)) - - output := buf.String() - assert.Contains(t, output, `"statement_id": "stmt-1"`) - assert.Contains(t, output, `"state": "SUCCEEDED"`) - assert.Contains(t, output, `"warehouse_id": "wh-1"`) - assert.Contains(t, output, `"columns": [`) - assert.Contains(t, output, `"rows": [`) - assert.True(t, strings.HasSuffix(output, "\n")) -} - -func TestRenderStatementInfoOmitsEmptyFields(t *testing.T) { - // Cancel-style payload: only statement_id + state. - info := statementInfo{ - StatementID: "stmt-1", - State: sql.StatementStateCanceled, + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf strings.Builder + require.NoError(t, renderStatementInfo(&buf, tc.info)) + out := buf.String() + for _, want := range tc.mustHave { + assert.Contains(t, out, want) + } + for _, missing := range tc.mustNotHave { + assert.NotContains(t, out, missing) + } + assert.True(t, strings.HasSuffix(out, "\n")) + }) } - - var buf strings.Builder - require.NoError(t, renderStatementInfo(&buf, info)) - - output := buf.String() - assert.Contains(t, output, `"statement_id": "stmt-1"`) - assert.Contains(t, output, `"state": "CANCELED"`) - assert.NotContains(t, output, `"warehouse_id"`) - assert.NotContains(t, output, `"columns"`) - assert.NotContains(t, output, `"rows"`) - assert.NotContains(t, output, `"error"`) } func TestStatementSubmitRejectsMultipleSQLs(t *testing.T) { + // The "exactly one SQL" check is something we wrote, so it earns a test. + // Cobra's own MaximumNArgs / ExactArgs enforcement is its own contract + // and is not asserted here. 
dir := t.TempDir() path := filepath.Join(dir, "test.sql") require.NoError(t, os.WriteFile(path, []byte("SELECT 1"), 0o644)) @@ -234,31 +251,6 @@ func TestStatementSubmitRejectsMultipleSQLs(t *testing.T) { assert.Contains(t, err.Error(), "exactly one") } -func TestStatementSubmitArgsBound(t *testing.T) { - // MaximumNArgs(1) means cobra rejects 2+ positionals at parse time. - cmd := newStatementSubmitCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{"SELECT 1", "SELECT 2"}) - err := cmd.Execute() - require.Error(t, err) -} - -func TestStatementGetRequiresStatementID(t *testing.T) { - cmd := newStatementGetCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{}) - err := cmd.Execute() - require.Error(t, err) -} - -func TestStatementCancelRequiresStatementID(t *testing.T) { - cmd := newStatementCancelCmd() - cmd.PreRunE = nil - cmd.SetArgs([]string{}) - err := cmd.Execute() - require.Error(t, err) -} - func TestStatementErrorFromStatus(t *testing.T) { tests := []struct { name string @@ -326,73 +318,3 @@ func TestStatementErrorFromStatus(t *testing.T) { }) } } - -func TestGetStatementResultClosedTerminalSynthesizesError(t *testing.T) { - // Statement reached CLOSED with no Error payload from the server. The shared - // statementInfo contract guarantees a non-nil Error for any non-success - // terminal state so consumers can branch on `error == null` alone. 
- ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateClosed}, - }, nil).Once() - - info, err := getStatementResult(ctx, mockAPI, "stmt-1") - require.NoError(t, err) - assert.Equal(t, sql.StatementStateClosed, info.State) - require.NotNil(t, info.Error) - assert.Contains(t, info.Error.Message, "CLOSED") - assert.Empty(t, info.Error.ErrorCode) - assert.Nil(t, info.Rows) -} - -func TestGetStatementResultFailedWithoutBackendErrorSynthesizesError(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateFailed}, - }, nil).Once() - - info, err := getStatementResult(ctx, mockAPI, "stmt-1") - require.NoError(t, err) - assert.Equal(t, sql.StatementStateFailed, info.State) - require.NotNil(t, info.Error) - assert.Contains(t, info.Error.Message, "FAILED") -} - -func TestGetStatementStatusFailedWithoutBackendErrorSynthesizesError(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateFailed}, - }, nil).Once() - - info, err := getStatementStatus(ctx, mockAPI, "stmt-1") - require.NoError(t, err) - assert.Equal(t, sql.StatementStateFailed, info.State) - require.NotNil(t, info.Error) - assert.Contains(t, info.Error.Message, "FAILED") -} - -func TestGetStatementStatusRunningHasNoError(t *testing.T) { - // PENDING/RUNNING legitimately have no error; the contract only requires - // 
error population for terminal non-success states. - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateRunning}, - }, nil).Once() - - info, err := getStatementStatus(ctx, mockAPI, "stmt-1") - require.NoError(t, err) - assert.Equal(t, sql.StatementStateRunning, info.State) - assert.Nil(t, info.Error) -} From 9b52b65bd178770add1e8f54a6dfcc1a38a35d71 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 28 Apr 2026 10:17:56 +0200 Subject: [PATCH 09/14] aitools: render partial JSON on chunk-fetch failure; reject submit input before auth Address two findings from Arseni's review. P2 (statement_get.go): getStatementResult used to return (info, err) when fetchAllRows failed after a SUCCEEDED state. RunE then discarded the populated info and surfaced only the raw Go error, so the user got an unstructured "fetch result chunk N: ..." string with no statement_id and no machine-readable error field. That contradicts the contract in the failed-terminal path two cases above, which renders JSON and returns root.ErrAlreadyPrinted. Now: on chunk-fetch failure, populate info.Error with the chunk-fetch message and return (info, nil). RunE renders the partial info as JSON and signals exit-non-zero based on info.Error != nil. The caller still gets statement_id and columns; the error field carries the failure detail. New test TestGetStatementResultChunkFetchFailureRendersPartialInfo locks this in. P3 (statement_submit.go): The PR description claims submit validates input before accessing WorkspaceClient. The code didn't actually deliver that: PreRunE was root.MustWorkspaceClient (auth/profile setup), then RunE did the resolveSQLs / "exactly one" checks. So a malformed invocation hit auth errors before ever surfacing the input error. 
Move resolveSQLs and the length check into a custom PreRunE that runs before root.MustWorkspaceClient, mirroring the pattern in query.go:113-118. The result is stashed in a closure variable (sqlStatement) for RunE to consume. Existing test TestStatementSubmitRejectsMultipleSQLs is renamed to ...BeforeWorkspaceClient and no longer needs to stub out PreRunE: the new ordering means a bad invocation gets the validation error without ever attempting workspace-client setup. Co-authored-by: Isaac --- experimental/aitools/cmd/statement_get.go | 15 +++++-- experimental/aitools/cmd/statement_submit.go | 16 ++++++-- experimental/aitools/cmd/statement_test.go | 42 +++++++++++++++++--- 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go index f01b6488fdb..617b5c274dd 100644 --- a/experimental/aitools/cmd/statement_get.go +++ b/experimental/aitools/cmd/statement_get.go @@ -39,8 +39,9 @@ invoked the synchronous path.)`, } // Non-zero exit when the statement reached a non-success terminal - // state. The error info is already in the JSON output. - if info.State != sql.StatementStateSucceeded { + // state OR a chunk-fetch failure prevented assembling the rows. + // In both cases the failure detail is already in the JSON output. + if info.State != sql.StatementStateSucceeded || info.Error != nil { return root.ErrAlreadyPrinted } return nil @@ -79,7 +80,15 @@ func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface info.Columns = extractColumns(pollResp.Manifest) rows, err := fetchAllRows(ctx, api, pollResp) if err != nil { - return info, err + // The query succeeded server-side but a later chunk fetch failed + // (network blip, throttling, transient 5xx). Surface this as a + // structured error on the same statementInfo so the caller still + // gets a parseable JSON response with the statement_id; RunE then + // signals exit-non-zero based on info.Error. 
+ info.Error = &batchResultError{ + Message: fmt.Sprintf("fetch result rows: %v", err), + } + return info, nil } info.Rows = rows } diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go index 56767d029de..ac8bf424e5f 100644 --- a/experimental/aitools/cmd/statement_submit.go +++ b/experimental/aitools/cmd/statement_submit.go @@ -14,6 +14,10 @@ import ( func newStatementSubmitCmd() *cobra.Command { var warehouseID string var filePath string + // resolved by PreRunE so input validation runs before any auth/profile + // work and the documented "validates input before WorkspaceClient" claim + // in the PR description is actually true. + var sqlStatement string cmd := &cobra.Command{ Use: "submit [SQL | file.sql]", @@ -26,9 +30,8 @@ The statement keeps running server-side. Harvest results with with 'statement cancel '.`, Example: ` databricks experimental aitools tools statement submit "SELECT pg_sleep(60)" --warehouse databricks experimental aitools tools statement submit --file query.sql`, - Args: cobra.MaximumNArgs(1), - PreRunE: root.MustWorkspaceClient, - RunE: func(cmd *cobra.Command, args []string) error { + Args: cobra.MaximumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() var fps []string @@ -42,14 +45,19 @@ with 'statement cancel '.`, if len(sqls) != 1 { return errors.New("submit accepts exactly one SQL statement; pass multiple to 'query' for batch") } + sqlStatement = sqls[0] + return root.MustWorkspaceClient(cmd, args) + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() w := cmdctx.WorkspaceClient(ctx) wID, err := resolveWarehouseID(ctx, w, warehouseID) if err != nil { return err } - info, err := submitStatement(ctx, w.StatementExecution, sqls[0], wID) + info, err := submitStatement(ctx, w.StatementExecution, sqlStatement, wID) if err != nil { return err } diff --git a/experimental/aitools/cmd/statement_test.go 
b/experimental/aitools/cmd/statement_test.go index 69768d2a159..9c2264daf2c 100644 --- a/experimental/aitools/cmd/statement_test.go +++ b/experimental/aitools/cmd/statement_test.go @@ -116,6 +116,37 @@ func TestGetStatementResultDoesNotCancelServerSideOnContextCancel(t *testing.T) require.ErrorIs(t, err, context.Canceled) } +func TestGetStatementResultChunkFetchFailureRendersPartialInfo(t *testing.T) { + // SUCCEEDED state but a later chunk fetch fails (network blip, throttle, + // 5xx). getStatementResult should surface this as a structured error on + // the same statementInfo so the caller still gets parseable JSON with the + // statement_id, instead of returning a raw Go error that RunE would + // discard along with the populated info. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ + StatementId: "stmt-1", + Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + Manifest: &sql.ResultManifest{ + Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "n"}}}, + TotalChunkCount: 2, + }, + Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + mockAPI.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1). 
+ Return(nil, errors.New("network blip")).Once() + + info, err := getStatementResult(ctx, mockAPI, "stmt-1") + require.NoError(t, err) + assert.Equal(t, sql.StatementStateSucceeded, info.State) + assert.Equal(t, []string{"n"}, info.Columns, "columns from the initial response are still surfaced") + require.NotNil(t, info.Error) + assert.Contains(t, info.Error.Message, "fetch result rows") + assert.Contains(t, info.Error.Message, "network blip") +} + func TestGetStatementStatusSinglePoll(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -235,16 +266,17 @@ func TestRenderStatementInfo(t *testing.T) { } } -func TestStatementSubmitRejectsMultipleSQLs(t *testing.T) { - // The "exactly one SQL" check is something we wrote, so it earns a test. - // Cobra's own MaximumNArgs / ExactArgs enforcement is its own contract - // and is not asserted here. +func TestStatementSubmitRejectsMultipleSQLsBeforeWorkspaceClient(t *testing.T) { + // The "exactly one SQL" check runs in PreRunE BEFORE MustWorkspaceClient, + // so a malformed invocation is rejected without any auth/profile work. + // The test relies on this ordering: it does not stub out PreRunE, so if + // validation moved back after MustWorkspaceClient the test would panic + // on a missing workspace client instead of returning the validation error. dir := t.TempDir() path := filepath.Join(dir, "test.sql") require.NoError(t, os.WriteFile(path, []byte("SELECT 1"), 0o644)) cmd := newStatementSubmitCmd() - cmd.PreRunE = nil cmd.SetArgs([]string{"--file", path, "SELECT 2"}) err := cmd.Execute() require.Error(t, err) From 80e5b46d109f88c2c7c26270d62c2b6a4a849736 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 18:07:23 +0200 Subject: [PATCH 10/14] aitools: parallelize discover-schema across tables and probes discover-schema previously walked tables sequentially and ran each table's three probes (DESCRIBE, sample SELECT, null counts) one after the other. 
For ai-dev-kit's data-exploration phase that meant warehouse-bound work was idle most of the time. Same root cause as the multi-query exploration latency that PR 2 fixed; same fix. Two layers of parallelism: 1. Tables fan out via errgroup with --concurrency (default 8). A failure on one table never aborts the others; it gets rendered inline as "Error discovering ...". 2. Within a table, DESCRIBE still runs first because the column list feeds the null-counts query. After DESCRIBE returns, the sample SELECT and null-counts probes run concurrently. The output text is assembled once both finish, preserving the existing column order (COLUMNS, SAMPLE DATA, NULL COUNTS). Switch executeSQL from the SDK's ExecuteAndWait helper to ExecuteStatement + pollStatement (the helper extracted in #5092). This brings discover-schema in line with query.go and statement.go: explicit OnWaitTimeout=CONTINUE on every call, and any future polling-helper improvement (e.g. signal handling) lands here for free. Failed states now flow through checkFailedState, which yields more specific error messages (e.g. "query failed: SYNTAX_ERROR ...") than the previous hand-rolled branch. The user-visible "SAMPLE DATA: Error - %v" / "NULL COUNTS: Error - %v" wrapping is unchanged. Add --concurrency validation matching the cmd/fs/cp.go and experimental/aitools/cmd/query.go pattern: PreRunE rejects values <= 0 with errInvalidBatchConcurrency. 
Tests added in discover_schema_test.go: - quoteTableName (table-driven across valid identifiers, missing parts, injection attempts, empty parts, leading-digit identifiers) - parseDescribeResult skipping metadata rows - executeSQL pins OnWaitTimeout=CONTINUE - executeSQL propagates server-reported FAILED state - executeSQL wraps transport errors - discoverTable: sample and null-count probes run concurrently after DESCRIBE (atomic peak-counter assertion) - discoverTable: a sample failure does not abort null counts - --concurrency 0 and -1 rejected at PreRunE time - invalid table name (not CATALOG.SCHEMA.TABLE) rejected at RunE validation before any API call Co-authored-by: Isaac --- experimental/aitools/cmd/discover_schema.go | 119 ++++++--- .../aitools/cmd/discover_schema_test.go | 236 ++++++++++++++++++ 2 files changed, 314 insertions(+), 41 deletions(-) create mode 100644 experimental/aitools/cmd/discover_schema_test.go diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index fad77cd4d17..46d8012c432 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -15,11 +15,14 @@ import ( "github.com/databricks/databricks-sdk-go" dbsql "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) var sqlIdentifierRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) func newDiscoverSchemaCmd() *cobra.Command { + var concurrency int + cmd := &cobra.Command{ Use: "discover-schema TABLE...", Short: "Discover schema for one or more tables", @@ -31,14 +34,22 @@ For each table, returns: - Column names and types - Sample data (5 rows) - Null counts per column -- Total row count`, +- Total row count + +Multiple tables are discovered in parallel against the warehouse, capped +by --concurrency (default 8). 
Within a single table, the sample-data and +null-counts probes also run in parallel after the column list is known.`, Example: ` databricks experimental aitools tools discover-schema samples.nyctaxi.trips databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2`, - Args: cobra.MinimumNArgs(1), - PreRunE: root.MustWorkspaceClient, + Args: cobra.MinimumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + if concurrency <= 0 { + return errInvalidBatchConcurrency + } + return root.MustWorkspaceClient(cmd, args) + }, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - w := cmdctx.WorkspaceClient(ctx) // validate table names: each part must be a safe SQL identifier for _, table := range args { @@ -47,6 +58,8 @@ For each table, returns: } } + w := cmdctx.WorkspaceClient(ctx) + // set up session with client for middleware compatibility sess := session.NewSession() sess.Set(middlewares.DatabricksClientKey, w) @@ -57,14 +70,22 @@ For each table, returns: return err } - var results []string - for _, table := range args { - result, err := discoverTable(ctx, w, warehouseID, table) - if err != nil { - result = fmt.Sprintf("Error discovering %s: %v", table, err) - } - results = append(results, result) + results := make([]string, len(args)) + g := new(errgroup.Group) + g.SetLimit(concurrency) + for i, table := range args { + g.Go(func() error { + result, err := discoverTable(ctx, w, warehouseID, table) + if err != nil { + results[i] = fmt.Sprintf("Error discovering %s: %v", table, err) + } else { + results[i] = result + } + // A failure on one table shouldn't abort the others. 
+ return nil + }) } + _ = g.Wait() // format output with dividers for multiple tables var output string @@ -90,12 +111,12 @@ For each table, returns: }, } + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight SQL statements when discovering multiple tables") + return cmd } func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) { - var sb strings.Builder - quoted, err := quoteTableName(table) if err != nil { return "", err @@ -113,32 +134,47 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse return "", errors.New("no columns found") } + // 2 + 3. Sample data and null counts run in parallel; both depend only on + // the column list (already known) and not on each other. + sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted) + + nullCountExprs := make([]string, len(columns)) + for i, col := range columns { + nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", col, col) + } + nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", + strings.Join(nullCountExprs, ", "), quoted) + + var sampleResp, nullResp *dbsql.StatementResponse + var sampleErr, nullErr error + + g := new(errgroup.Group) + g.Go(func() error { + sampleResp, sampleErr = executeSQL(ctx, w, warehouseID, sampleSQL) + return nil + }) + g.Go(func() error { + nullResp, nullErr = executeSQL(ctx, w, warehouseID, nullSQL) + return nil + }) + _ = g.Wait() + + // Assemble the output in the established order: columns, sample, null counts. + var sb strings.Builder sb.WriteString("COLUMNS:\n") for i, col := range columns { fmt.Fprintf(&sb, " %s: %s\n", col, types[i]) } - // 2. 
sample data (5 rows) - sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted) - sampleResp, err := executeSQL(ctx, w, warehouseID, sampleSQL) - if err != nil { - fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", err) + if sampleErr != nil { + fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr) } else { sb.WriteString("\nSAMPLE DATA:\n") sb.WriteString(formatTableData(sampleResp)) } - // 3. null counts per column - nullCountExprs := make([]string, len(columns)) - for i, col := range columns { - nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", col, col) - } - nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", - strings.Join(nullCountExprs, ", "), quoted) - - nullResp, err := executeSQL(ctx, w, warehouseID, nullSQL) - if err != nil { - fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", err) + if nullErr != nil { + fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr) } else { sb.WriteString("\nNULL COUNTS:\n") sb.WriteString(formatNullCounts(nullResp, columns)) @@ -148,24 +184,25 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse } func executeSQL(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { - resp, err := w.StatementExecution.ExecuteAndWait(ctx, dbsql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "50s", + resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue, }) if err != nil { - return nil, err + return nil, fmt.Errorf("execute statement: %w", err) } - if resp.Status != nil && resp.Status.State == dbsql.StatementStateFailed { - errMsg := "query failed" - if resp.Status.Error != nil { - errMsg = resp.Status.Error.Message - } - return nil, errors.New(errMsg) + pollResp, err 
:= pollStatement(ctx, w.StatementExecution, resp) + if err != nil { + return nil, err } - return resp, nil + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil } func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) { diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go new file mode 100644 index 00000000000..4a86982ed45 --- /dev/null +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -0,0 +1,236 @@ +package aitools + +import ( + "context" + "errors" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go" + mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" + dbsql "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestQuoteTableName(t *testing.T) { + tests := []struct { + name string + in string + want string + wantErr string + }{ + {"valid", "main.public.orders", "`main`.`public`.`orders`", ""}, + {"underscores ok", "_a.b_c.d_e", "`_a`.`b_c`.`d_e`", ""}, + {"missing parts", "public.orders", "", "expected CATALOG.SCHEMA.TABLE"}, + {"too many parts", "a.b.c.d", "", "expected CATALOG.SCHEMA.TABLE"}, + {"injection in catalog", "a;DROP--.b.c", "", "invalid SQL identifier"}, + {"backtick in name", "a.b.c`d", "", "invalid SQL identifier"}, + {"empty part", "a..c", "", "invalid SQL identifier"}, + {"starts with digit", "1main.public.orders", "", "invalid SQL identifier"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := quoteTableName(tc.in) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func 
TestParseDescribeResultSkipsMetadataRows(t *testing.T) { + resp := &dbsql.StatementResponse{ + Result: &dbsql.ResultData{DataArray: [][]string{ + {"id", "BIGINT", ""}, + {"name", "STRING", ""}, + {"# Partition Information", "", ""}, // metadata divider, skip + {"region", "STRING", ""}, + {"", "STRING", ""}, // empty col name, skip + }}, + } + + cols, types := parseDescribeResult(resp) + assert.Equal(t, []string{"id", "name", "region"}, cols) + assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types) +} + +func TestExecuteSQLUsesPollStatementAndPinsOnWaitTimeout(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return req.Statement == "SELECT 1" && + req.WaitTimeout == "0s" && + req.OnWaitTimeout == dbsql.ExecuteStatementRequestOnWaitTimeoutContinue + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-1", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{{"1"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + resp, err := executeSQL(ctx, w, "wh-1", "SELECT 1") + require.NoError(t, err) + assert.Equal(t, "stmt-1", resp.StatementId) +} + +func TestExecuteSQLPropagatesFailedState(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&dbsql.StatementResponse{ + StatementId: "stmt-1", + Status: &dbsql.StatementStatus{ + State: dbsql.StatementStateFailed, + Error: &dbsql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'oops'"}, + }, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + _, err := executeSQL(ctx, w, "wh-1", "SELECT oops") + require.Error(t, err) + assert.Contains(t, err.Error(), 
"SYNTAX_ERROR") +} + +func TestExecuteSQLWrapsTransportError(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything). + Return(nil, errors.New("network unreachable")).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + _, err := executeSQL(ctx, w, "wh-1", "SELECT 1") + require.Error(t, err) + assert.Contains(t, err.Error(), "execute statement") + assert.Contains(t, err.Error(), "network unreachable") +} + +func TestDiscoverTableRunsSampleAndNullsInParallel(t *testing.T) { + // After DESCRIBE returns, sample SELECT and null counts must run in + // parallel, not back-to-back. Each mocked probe blocks briefly so an + // atomic counter can observe peak in-flight calls. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + var inFlight, peak atomic.Int32 + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{ + {"id", "BIGINT", ""}, + {"name", "STRING", ""}, + }}, + }, nil).Once() + + probe := func(ctx context.Context, req dbsql.ExecuteStatementRequest) (*dbsql.StatementResponse, error) { + n := inFlight.Add(1) + for { + cur := peak.Load() + if n <= cur || peak.CompareAndSwap(cur, n) { + break + } + } + time.Sleep(50 * time.Millisecond) + inFlight.Add(-1) + return &dbsql.StatementResponse{ + StatementId: "stmt-probe", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "x"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"0"}}}, + }, nil + } + + 
mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).RunAndReturn(probe).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "SUM(CASE WHEN") + })).RunAndReturn(probe).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + out, err := discoverTable(ctx, w, "wh-1", "main.public.orders") + require.NoError(t, err) + + assert.GreaterOrEqual(t, peak.Load(), int32(2), "sample and null-count probes should run concurrently") + assert.Contains(t, out, "COLUMNS:") + assert.Contains(t, out, "SAMPLE DATA:") + assert.Contains(t, out, "NULL COUNTS:") +} + +func TestDiscoverTableSampleErrorDoesNotAbortNullCounts(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{{"id", "BIGINT", ""}}}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-sample", + Status: &dbsql.StatementStatus{ + State: dbsql.StatementStateFailed, + Error: &dbsql.ServiceError{ErrorCode: "PERM", Message: "permission denied"}, + }, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "SUM(CASE WHEN") + })).Return(&dbsql.StatementResponse{ + 
StatementId: "stmt-null", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "id_nulls"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"100", "0"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + out, err := discoverTable(ctx, w, "wh-1", "main.public.orders") + require.NoError(t, err) + assert.Contains(t, out, "SAMPLE DATA: Error - ") + assert.Contains(t, out, "permission denied") + assert.Contains(t, out, "NULL COUNTS:") + assert.Contains(t, out, "total_rows: 100") +} + +func TestDiscoverSchemaConcurrencyZeroRejected(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"--concurrency", "0", "main.public.orders"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + +func TestDiscoverSchemaConcurrencyNegativeRejected(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"--concurrency", "-1", "main.public.orders"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) +} + +func TestDiscoverSchemaInvalidTableNameRejected(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.PreRunE = nil // skip workspace client requirement + cmd.SetArgs([]string{"not-three-parts"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE") +} From ce3cdea47bc22f915980275853ef6a7126afcdcb Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 18:38:10 +0200 Subject: [PATCH 11/14] aitools: cap discover-schema statement concurrency globally and cancel on Ctrl+C Address two findings from a cursor PR review. 1. --concurrency previously capped table-level fan-out via errgroup.SetLimit, but each table issued up to two probes after DESCRIBE, so peak warehouse load was 2*concurrency rather than the advertised "max in-flight statements." 
A user setting --concurrency 1 to stay under a warehouse cap still saw two statements concurrently. Replace the table-level limit with a shared sqlGate (chan struct{} of capacity N + statement_id tracking) that wraps every executeSQL call. Now --concurrency really means "max statements in flight at any moment, across all tables and probes." Update the help text to match. 2. After switching from ExecuteAndWait to ExecuteStatement + pollStatement (PR 4, first commit), Ctrl+C left up to 2*concurrency statements running server-side because nothing called CancelExecution. Add the same cancellation discipline used in batch.go: the signal handler cancels a derived pollCtx, the gate records each statement_id post-submission, and on cancellation we sweep the recorded IDs via CancelExecution before returning root.ErrAlreadyPrinted. Also addressed: - Move table-name validation into PreRunE so a malformed identifier is rejected before MustWorkspaceClient runs (real CLI-lifecycle improvement, not just a test trick). - Replace the timing-based parallelism test with a deterministic barrier (atomic counter + sync.OnceFunc + channel close): both probes must arrive before either is allowed to leave; if they ran sequentially the first probe times out and surfaces an error. 
Tests reorganized: - sqlGate.run: pins OnWaitTimeout, propagates FAILED, wraps transport errors, records ids, respects cancelled context - cancelDiscoverInFlight: per-id calls, empty list is a no-op - discoverTable: deterministic concurrent-probes assertion; per-probe failure does not abort siblings - cobra-level: invalid table name and injection attempts rejected before any workspace client setup Co-authored-by: Isaac --- experimental/aitools/cmd/discover_schema.go | 168 +++++++++++++----- .../aitools/cmd/discover_schema_test.go | 121 ++++++++++--- 2 files changed, 219 insertions(+), 70 deletions(-) diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index 46d8012c432..0811bb6cbf2 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -4,14 +4,20 @@ import ( "context" "errors" "fmt" + "os" + "os/signal" "regexp" + "slices" "strings" + "sync" + "syscall" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/experimental/aitools/lib/middlewares" "github.com/databricks/cli/experimental/aitools/lib/session" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" dbsql "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -20,6 +26,64 @@ import ( var sqlIdentifierRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) +// sqlGate caps in-flight SQL statements globally and records each statement_id +// so a Ctrl+C sweep can cancel anything still running server-side. The gate's +// concurrency limit applies across all probes (DESCRIBE, sample SELECT, null +// counts) and across all tables, so --concurrency really means "max statements +// in flight," not "max tables in flight." 
+type sqlGate struct { + sem chan struct{} + mu sync.Mutex + ids []string +} + +func newSQLGate(limit int) *sqlGate { + return &sqlGate{sem: make(chan struct{}, limit)} +} + +// run executes a SQL statement asynchronously, polls until terminal, and +// records the statement_id so it can be cancelled if the parent context is +// cancelled. Acquires a slot from the gate before submitting and releases it +// when polling completes (or the caller's context is cancelled). +func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { + select { + case g.sem <- struct{}{}: + defer func() { <-g.sem }() + case <-ctx.Done(): + return nil, ctx.Err() + } + + resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{ + WarehouseId: warehouseID, + Statement: statement, + WaitTimeout: "0s", + OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue, + }) + if err != nil { + return nil, fmt.Errorf("execute statement: %w", err) + } + + g.mu.Lock() + g.ids = append(g.ids, resp.StatementId) + g.mu.Unlock() + + pollResp, err := pollStatement(ctx, w.StatementExecution, resp) + if err != nil { + return nil, err + } + if err := checkFailedState(pollResp.Status); err != nil { + return nil, err + } + return pollResp, nil +} + +// trackedIDs returns a snapshot of statement_ids submitted through this gate. +func (g *sqlGate) trackedIDs() []string { + g.mu.Lock() + defer g.mu.Unlock() + return slices.Clone(g.ids) +} + func newDiscoverSchemaCmd() *cobra.Command { var concurrency int @@ -36,9 +100,13 @@ For each table, returns: - Null counts per column - Total row count -Multiple tables are discovered in parallel against the warehouse, capped -by --concurrency (default 8). 
Within a single table, the sample-data and -null-counts probes also run in parallel after the column list is known.`, +Tables and probes (DESCRIBE, sample SELECT, null counts) all share a +single warehouse-statement budget. --concurrency (default 8) caps the +total number of statements in flight at any moment, regardless of how +many tables you pass in. + +On Ctrl+C, in-flight statements are cancelled server-side via +CancelExecution before the command exits.`, Example: ` databricks experimental aitools tools discover-schema samples.nyctaxi.trips databricks experimental aitools tools discover-schema catalog.schema.table1 catalog.schema.table2`, Args: cobra.MinimumNArgs(1), @@ -46,18 +114,16 @@ null-counts probes also run in parallel after the column list is known.`, if concurrency <= 0 { return errInvalidBatchConcurrency } - return root.MustWorkspaceClient(cmd, args) - }, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - - // validate table names: each part must be a safe SQL identifier + // Reject malformed identifiers before any auth/profile work. 
for _, table := range args { if _, err := quoteTableName(table); err != nil { return err } } - + return root.MustWorkspaceClient(cmd, args) + }, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() w := cmdctx.WorkspaceClient(ctx) // set up session with client for middleware compatibility @@ -70,12 +136,29 @@ null-counts probes also run in parallel after the column list is known.`, return err } + pollCtx, pollCancel := context.WithCancel(ctx) + defer pollCancel() + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + go func() { + select { + case <-sigCh: + log.Infof(ctx, "Received interrupt, cancelling in-flight discover-schema statements") + pollCancel() + case <-pollCtx.Done(): + } + }() + + gate := newSQLGate(concurrency) + results := make([]string, len(args)) g := new(errgroup.Group) - g.SetLimit(concurrency) for i, table := range args { g.Go(func() error { - result, err := discoverTable(ctx, w, warehouseID, table) + result, err := discoverTable(pollCtx, gate, w, warehouseID, table) if err != nil { results[i] = fmt.Sprintf("Error discovering %s: %v", table, err) } else { @@ -87,6 +170,11 @@ null-counts probes also run in parallel after the column list is known.`, } _ = g.Wait() + if pollCtx.Err() != nil { + cancelDiscoverInFlight(ctx, w.StatementExecution, gate.trackedIDs()) + return root.ErrAlreadyPrinted + } + // format output with dividers for multiple tables var output string if len(results) == 1 { @@ -111,20 +199,39 @@ null-counts probes also run in parallel after the column list is known.`, }, } - cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight SQL statements when discovering multiple tables") + cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum SQL statements in flight at once across all tables and probes") return cmd } -func discoverTable(ctx context.Context, w 
*databricks.WorkspaceClient, warehouseID, table string) (string, error) { +// cancelDiscoverInFlight sends CancelExecution for every recorded statement_id. +// Best effort: errors are logged but don't fail the user-visible exit. +// Statements that already finished server-side return an error which we just +// swallow at warn level; the alternative (per-statement state tracking) isn't +// worth the bookkeeping here. +func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInterface, ids []string) { + if len(ids) == 0 { + cmdio.LogString(ctx, "discover-schema cancelled.") + return + } + for _, id := range ids { + cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) + if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil { + log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err) + } + cancel() + } + cmdio.LogString(ctx, fmt.Sprintf("discover-schema cancelled; sent CancelExecution for %d statement(s).", len(ids))) +} + +func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceClient, warehouseID, table string) (string, error) { quoted, err := quoteTableName(table) if err != nil { return "", err } // 1. describe table - get columns and types - describeSQL := "DESCRIBE TABLE " + quoted - descResp, err := executeSQL(ctx, w, warehouseID, describeSQL) + descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted) if err != nil { return "", fmt.Errorf("describe table: %w", err) } @@ -135,7 +242,8 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse } // 2 + 3. Sample data and null counts run in parallel; both depend only on - // the column list (already known) and not on each other. + // the column list (already known) and not on each other. The gate (not + // errgroup) is what actually limits warehouse concurrency. 
sampleSQL := fmt.Sprintf("SELECT * FROM %s LIMIT 5", quoted) nullCountExprs := make([]string, len(columns)) @@ -150,11 +258,11 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse g := new(errgroup.Group) g.Go(func() error { - sampleResp, sampleErr = executeSQL(ctx, w, warehouseID, sampleSQL) + sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL) return nil }) g.Go(func() error { - nullResp, nullErr = executeSQL(ctx, w, warehouseID, nullSQL) + nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL) return nil }) _ = g.Wait() @@ -183,28 +291,6 @@ func discoverTable(ctx context.Context, w *databricks.WorkspaceClient, warehouse return sb.String(), nil } -func executeSQL(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { - resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", - OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue, - }) - if err != nil { - return nil, fmt.Errorf("execute statement: %w", err) - } - - pollResp, err := pollStatement(ctx, w.StatementExecution, resp) - if err != nil { - return nil, err - } - - if err := checkFailedState(pollResp.Status); err != nil { - return nil, err - } - return pollResp, nil -} - func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) { if resp.Result == nil || resp.Result.DataArray == nil { return nil, nil diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go index 4a86982ed45..a94eccbd938 100644 --- a/experimental/aitools/cmd/discover_schema_test.go +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "strings" + "sync" "sync/atomic" "testing" "time" @@ -53,9 +54,9 @@ func TestParseDescribeResultSkipsMetadataRows(t *testing.T) { Result: 
&dbsql.ResultData{DataArray: [][]string{ {"id", "BIGINT", ""}, {"name", "STRING", ""}, - {"# Partition Information", "", ""}, // metadata divider, skip + {"# Partition Information", "", ""}, {"region", "STRING", ""}, - {"", "STRING", ""}, // empty col name, skip + {"", "STRING", ""}, }}, } @@ -64,7 +65,7 @@ func TestParseDescribeResultSkipsMetadataRows(t *testing.T) { assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types) } -func TestExecuteSQLUsesPollStatementAndPinsOnWaitTimeout(t *testing.T) { +func TestSQLGateRunPinsOnWaitTimeoutAndRecordsID(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -79,12 +80,15 @@ func TestExecuteSQLUsesPollStatementAndPinsOnWaitTimeout(t *testing.T) { }, nil).Once() w := &databricks.WorkspaceClient{StatementExecution: mockAPI} - resp, err := executeSQL(ctx, w, "wh-1", "SELECT 1") + gate := newSQLGate(2) + + resp, err := gate.run(ctx, w, "wh-1", "SELECT 1") require.NoError(t, err) assert.Equal(t, "stmt-1", resp.StatementId) + assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs()) } -func TestExecuteSQLPropagatesFailedState(t *testing.T) { +func TestSQLGateRunPropagatesFailedState(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -97,12 +101,16 @@ func TestExecuteSQLPropagatesFailedState(t *testing.T) { }, nil).Once() w := &databricks.WorkspaceClient{StatementExecution: mockAPI} - _, err := executeSQL(ctx, w, "wh-1", "SELECT oops") + gate := newSQLGate(2) + + _, err := gate.run(ctx, w, "wh-1", "SELECT oops") require.Error(t, err) assert.Contains(t, err.Error(), "SYNTAX_ERROR") + // Even on failure, the id is recorded so a cancellation sweep can clean up. 
+ assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs()) } -func TestExecuteSQLWrapsTransportError(t *testing.T) { +func TestSQLGateRunWrapsTransportError(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t) @@ -110,20 +118,36 @@ func TestExecuteSQLWrapsTransportError(t *testing.T) { Return(nil, errors.New("network unreachable")).Once() w := &databricks.WorkspaceClient{StatementExecution: mockAPI} - _, err := executeSQL(ctx, w, "wh-1", "SELECT 1") + gate := newSQLGate(2) + + _, err := gate.run(ctx, w, "wh-1", "SELECT 1") require.Error(t, err) assert.Contains(t, err.Error(), "execute statement") assert.Contains(t, err.Error(), "network unreachable") + assert.Empty(t, gate.trackedIDs(), "no id should be recorded when ExecuteStatement fails") } -func TestDiscoverTableRunsSampleAndNullsInParallel(t *testing.T) { - // After DESCRIBE returns, sample SELECT and null counts must run in - // parallel, not back-to-back. Each mocked probe blocks briefly so an - // atomic counter can observe peak in-flight calls. - ctx := cmdio.MockDiscard(t.Context()) +func TestSQLGateRunRespectsCancelledContext(t *testing.T) { + // With ctx already cancelled, gate.run must not call any API method: + // it bails at the semaphore-acquire select. + ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) + cancel() + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(2) - var inFlight, peak atomic.Int32 + _, err := gate.run(ctx, w, "wh-1", "SELECT 1") + require.ErrorIs(t, err, context.Canceled) +} + +func TestDiscoverTableRunsSampleAndNullsConcurrently(t *testing.T) { + // Deterministic barrier: both probes must enter before either is allowed + // to leave. 
If gate.run/discoverTable serialized them, the first probe + // would time out and return an error, which would surface as + // "SAMPLE DATA: Error - " or "NULL COUNTS: Error - " in the output. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") @@ -136,18 +160,22 @@ func TestDiscoverTableRunsSampleAndNullsInParallel(t *testing.T) { }}, }, nil).Once() + const numProbes = 2 + var dispatched atomic.Int32 + release := make(chan struct{}) + closeRelease := sync.OnceFunc(func() { close(release) }) + probe := func(ctx context.Context, req dbsql.ExecuteStatementRequest) (*dbsql.StatementResponse, error) { - n := inFlight.Add(1) - for { - cur := peak.Load() - if n <= cur || peak.CompareAndSwap(cur, n) { - break - } + if dispatched.Add(1) == numProbes { + closeRelease() + } + select { + case <-release: + case <-time.After(2 * time.Second): + return nil, errors.New("probe timeout: not running concurrently") } - time.Sleep(50 * time.Millisecond) - inFlight.Add(-1) return &dbsql.StatementResponse{ - StatementId: "stmt-probe", + StatementId: "stmt-probe-" + req.Statement[:7], Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "x"}}}}, Result: &dbsql.ResultData{DataArray: [][]string{{"0"}}}, @@ -163,10 +191,12 @@ func TestDiscoverTableRunsSampleAndNullsInParallel(t *testing.T) { })).RunAndReturn(probe).Once() w := &databricks.WorkspaceClient{StatementExecution: mockAPI} - out, err := discoverTable(ctx, w, "wh-1", "main.public.orders") - require.NoError(t, err) + gate := newSQLGate(8) - assert.GreaterOrEqual(t, peak.Load(), int32(2), "sample and null-count probes should run concurrently") + out, err := discoverTable(ctx, gate, w, "wh-1", 
"main.public.orders") + require.NoError(t, err) + assert.Equal(t, int32(numProbes), dispatched.Load(), "both probes should have entered concurrently") + assert.NotContains(t, out, "Error - ", "no probe should have surfaced an error") assert.Contains(t, out, "COLUMNS:") assert.Contains(t, out, "SAMPLE DATA:") assert.Contains(t, out, "NULL COUNTS:") @@ -204,7 +234,9 @@ func TestDiscoverTableSampleErrorDoesNotAbortNullCounts(t *testing.T) { }, nil).Once() w := &databricks.WorkspaceClient{StatementExecution: mockAPI} - out, err := discoverTable(ctx, w, "wh-1", "main.public.orders") + gate := newSQLGate(8) + + out, err := discoverTable(ctx, gate, w, "wh-1", "main.public.orders") require.NoError(t, err) assert.Contains(t, out, "SAMPLE DATA: Error - ") assert.Contains(t, out, "permission denied") @@ -212,6 +244,28 @@ func TestDiscoverTableSampleErrorDoesNotAbortNullCounts(t *testing.T) { assert.Contains(t, out, "total_rows: 100") } +func TestCancelDiscoverInFlightCallsAPIPerID(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + for _, id := range []string{"stmt-a", "stmt-b", "stmt-c"} { + mockAPI.EXPECT().CancelExecution(mock.Anything, dbsql.CancelExecutionRequest{ + StatementId: id, + }).Return(nil).Once() + } + + cancelDiscoverInFlight(ctx, mockAPI, []string{"stmt-a", "stmt-b", "stmt-c"}) +} + +func TestCancelDiscoverInFlightHandlesEmptyList(t *testing.T) { + // Empty list = no API calls. Mock asserts (via t.Cleanup) that nothing + // unexpected happens. 
+ ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + cancelDiscoverInFlight(ctx, mockAPI, nil) +} + func TestDiscoverSchemaConcurrencyZeroRejected(t *testing.T) { cmd := newDiscoverSchemaCmd() cmd.SetArgs([]string{"--concurrency", "0", "main.public.orders"}) @@ -226,11 +280,20 @@ func TestDiscoverSchemaConcurrencyNegativeRejected(t *testing.T) { require.ErrorIs(t, err, errInvalidBatchConcurrency) } -func TestDiscoverSchemaInvalidTableNameRejected(t *testing.T) { +func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing.T) { + // PreRunE rejects malformed identifiers before MustWorkspaceClient runs, + // so the test passes without any workspace mocking. cmd := newDiscoverSchemaCmd() - cmd.PreRunE = nil // skip workspace client requirement cmd.SetArgs([]string{"not-three-parts"}) err := cmd.Execute() require.Error(t, err) assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE") } + +func TestDiscoverSchemaInjectionAttemptRejected(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"a;DROP--.b.c"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid SQL identifier") +} From 19350a437d2ad666aa62362c14a3df629b539336 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 21:40:39 +0200 Subject: [PATCH 12/14] aitools: drop redundant discover-schema tests; fold concurrency rejection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-review pass on the test suite found ~3 functions worth trimming: Drop: - TestDiscoverSchemaInjectionAttemptRejected: TestQuoteTableName already has an "injection in catalog" case; the cobra-level wiring is already tested by TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient with a different bad input. 
- TestCancelDiscoverInFlightHandlesEmptyList: just verifies "no API calls when list is empty"; the mock would fail loudly on any unexpected call, making this a tautology. Fold: - TestDiscoverSchemaConcurrencyZeroRejected + ...NegativeRejected → TestDiscoverSchemaConcurrencyRejection (0, -1 subtests). Co-authored-by: Isaac --- .../aitools/cmd/discover_schema_test.go | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go index a94eccbd938..6301eaa0c05 100644 --- a/experimental/aitools/cmd/discover_schema_test.go +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -257,27 +257,15 @@ func TestCancelDiscoverInFlightCallsAPIPerID(t *testing.T) { cancelDiscoverInFlight(ctx, mockAPI, []string{"stmt-a", "stmt-b", "stmt-c"}) } -func TestCancelDiscoverInFlightHandlesEmptyList(t *testing.T) { - // Empty list = no API calls. Mock asserts (via t.Cleanup) that nothing - // unexpected happens. 
- ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - cancelDiscoverInFlight(ctx, mockAPI, nil) -} - -func TestDiscoverSchemaConcurrencyZeroRejected(t *testing.T) { - cmd := newDiscoverSchemaCmd() - cmd.SetArgs([]string{"--concurrency", "0", "main.public.orders"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) -} - -func TestDiscoverSchemaConcurrencyNegativeRejected(t *testing.T) { - cmd := newDiscoverSchemaCmd() - cmd.SetArgs([]string{"--concurrency", "-1", "main.public.orders"}) - err := cmd.Execute() - require.ErrorIs(t, err, errInvalidBatchConcurrency) +func TestDiscoverSchemaConcurrencyRejection(t *testing.T) { + for _, value := range []string{"0", "-1"} { + t.Run(value, func(t *testing.T) { + cmd := newDiscoverSchemaCmd() + cmd.SetArgs([]string{"--concurrency", value, "main.public.orders"}) + err := cmd.Execute() + require.ErrorIs(t, err, errInvalidBatchConcurrency) + }) + } } func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing.T) { @@ -289,11 +277,3 @@ func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing. require.Error(t, err) assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE") } - -func TestDiscoverSchemaInjectionAttemptRejected(t *testing.T) { - cmd := newDiscoverSchemaCmd() - cmd.SetArgs([]string{"a;DROP--.b.c"}) - err := cmd.Execute() - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid SQL identifier") -} From d6de46db38e54762736f8e019e3bf97d626a6fa8 Mon Sep 17 00:00:00 2001 From: simon Date: Mon, 27 Apr 2026 22:21:06 +0200 Subject: [PATCH 13/14] aitools: fix sqlGate.run race when context is already cancelled sqlGate.run used to enter a select with two ready cases when the caller passed an already-cancelled context: the gate had free slots, so `g.sem <- struct{}{}` was ready, and `<-ctx.Done()` was also ready. 
Go picks pseudo-randomly between simultaneously-ready cases, so on roughly half of those calls we proceeded to submit a statement under a cancelled context. Added an early `ctx.Err()` check before the select. The flaky test TestSQLGateRunRespectsCancelledContext is deterministic now (verified with -count=20). Surfaced by rebasing PR 4 on top of the trimmed PR 3, which changed test execution conditions enough to flip the coin. Co-authored-by: Isaac --- experimental/aitools/cmd/discover_schema.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index 0811bb6cbf2..b3704495e06 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -46,6 +46,13 @@ func newSQLGate(limit int) *sqlGate { // cancelled. Acquires a slot from the gate before submitting and releases it // when polling completes (or the caller's context is cancelled). func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { + // If the caller cancelled before we even tried, don't enter the select: + // when the gate has free slots both cases are ready and Go picks one + // pseudo-randomly. Without this early-out we'd occasionally submit a + // statement under a cancelled context. + if err := ctx.Err(); err != nil { + return nil, err + } select { case g.sem <- struct{}{}: defer func() { <-g.sem }() From 82d2052487d9e29ec1d546251478aef67ca81805 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 28 Apr 2026 10:21:20 +0200 Subject: [PATCH 14/14] aitools: escape backticks in DESCRIBE-derived column names Address Arseni's P3 finding on the discover-schema PR. parseDescribeResult returned column names verbatim and discoverTable interpolated them into the null-counts SQL inside backtick-quoted identifier positions, e.g. 
SUM(CASE WHEN `` IS NULL THEN 1 ELSE 0 END) AS `_nulls` Databricks/Delta DDL allows column names containing backticks via doubled-backtick escaping (`weird``col`). Without escaping in the SQL we generate, an embedded backtick in the column name terminates the quoted identifier mid-string and produces a PARSE_SYNTAX_ERROR. Sample-data uses SELECT * so it succeeds, and the user sees only a confusing "NULL COUNTS: Error - ..." line that's easy to misattribute to the warehouse. Escape via strings.ReplaceAll(col, "`", "``") in both the identifier and the alias positions before interpolation. New test TestDiscoverTableEscapesBackticksInColumnNames pins the doubled form in both spots and asserts the no-error code path. Co-authored-by: Isaac --- experimental/aitools/cmd/discover_schema.go | 9 +++- .../aitools/cmd/discover_schema_test.go | 47 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index b3704495e06..091222368d9 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -255,7 +255,14 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl nullCountExprs := make([]string, len(columns)) for i, col := range columns { - nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", col, col) + // Backticks inside an identifier are escaped by doubling them in + // Databricks/Delta SQL (`` ` `` → `` `` ``). Without this, a column + // name containing a backtick would terminate the quoted identifier + // mid-string and produce a PARSE_SYNTAX_ERROR. Sample-data uses + // SELECT * so the failure shows up only as a confusing + // "NULL COUNTS: Error - ..." line in the user-facing output. 
+ escaped := strings.ReplaceAll(col, "`", "``") + nullCountExprs[i] = fmt.Sprintf("SUM(CASE WHEN `%s` IS NULL THEN 1 ELSE 0 END) AS `%s_nulls`", escaped, escaped) } nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", strings.Join(nullCountExprs, ", "), quoted) diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go index 6301eaa0c05..fe6d86e799f 100644 --- a/experimental/aitools/cmd/discover_schema_test.go +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -244,6 +244,53 @@ func TestDiscoverTableSampleErrorDoesNotAbortNullCounts(t *testing.T) { assert.Contains(t, out, "total_rows: 100") } +func TestDiscoverTableEscapesBackticksInColumnNames(t *testing.T) { + // Databricks/Delta DDL allows backticks in column names via doubled- + // backtick escaping (e.g. CREATE TABLE t (`weird``col` STRING)). Without + // escaping in the null-counts SQL the embedded backtick would terminate + // the quoted identifier mid-string and produce a PARSE_SYNTAX_ERROR. + ctx := cmdio.MockDiscard(t.Context()) + mockAPI := mocksql.NewMockStatementExecutionInterface(t) + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "DESCRIBE TABLE") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-desc", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Result: &dbsql.ResultData{DataArray: [][]string{ + {"weird`col", "STRING", ""}, + }}, + }, nil).Once() + + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.HasPrefix(req.Statement, "SELECT *") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-sample", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + }, nil).Once() + + // Null-counts SQL must escape the embedded backtick. 
Both the identifier + // and the alias positions must use the doubled form. + mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool { + return strings.Contains(req.Statement, "`weird``col`") && + strings.Contains(req.Statement, "`weird``col_nulls`") && + !strings.Contains(req.Statement, "`weird`col`") + })).Return(&dbsql.StatementResponse{ + StatementId: "stmt-null", + Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded}, + Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "weird`col_nulls"}}}}, + Result: &dbsql.ResultData{DataArray: [][]string{{"5", "0"}}}, + }, nil).Once() + + w := &databricks.WorkspaceClient{StatementExecution: mockAPI} + gate := newSQLGate(8) + + out, err := discoverTable(ctx, gate, w, "wh-1", "main.public.orders") + require.NoError(t, err) + assert.Contains(t, out, "weird`col") + assert.NotContains(t, out, "Error - ") +} + func TestCancelDiscoverInFlightCallsAPIPerID(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) mockAPI := mocksql.NewMockStatementExecutionInterface(t)