diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go
index 091222368d..418ab78e25 100644
--- a/experimental/aitools/cmd/discover_schema.go
+++ b/experimental/aitools/cmd/discover_schema.go
@@ -10,6 +10,7 @@ import (
 	"slices"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"syscall"
 
 	"github.com/databricks/cli/cmd/root"
@@ -161,47 +162,20 @@ CancelExecution before the command exits.`,
 
 			gate := newSQLGate(concurrency)
 
-			results := make([]string, len(args))
-			g := new(errgroup.Group)
-			for i, table := range args {
-				g.Go(func() error {
-					result, err := discoverTable(pollCtx, gate, w, warehouseID, table)
-					if err != nil {
-						results[i] = fmt.Sprintf("Error discovering %s: %v", table, err)
-					} else {
-						results[i] = result
-					}
-					// A failure on one table shouldn't abort the others.
-					return nil
-				})
-			}
-			_ = g.Wait()
+			output, anyFailed := runDiscoverSchemas(pollCtx, gate, w, warehouseID, args)
 
 			if pollCtx.Err() != nil {
 				cancelDiscoverInFlight(ctx, w.StatementExecution, gate.trackedIDs())
 				return root.ErrAlreadyPrinted
 			}
 
-			// format output with dividers for multiple tables
-			var output string
-			if len(results) == 1 {
-				output = results[0]
-			} else {
-				divider := strings.Repeat("-", 70)
-				var sb strings.Builder
-				for i, result := range results {
-					if i > 0 {
-						sb.WriteByte('\n')
-						sb.WriteString(divider)
-						sb.WriteByte('\n')
-					}
-					fmt.Fprintf(&sb, "TABLE: %s\n%s\n", args[i], divider)
-					sb.WriteString(result)
-				}
-				output = sb.String()
-			}
-
 			cmdio.LogString(ctx, output)
+			if anyFailed {
+				// The per-table error text is already part of `output`, which was
+				// logged just above. Returning ErrAlreadyPrinted reports failure to
+				// the caller without adding a new error message of its own.
+				return root.ErrAlreadyPrinted
+			}
 			return nil
 		},
 	}
@@ -211,6 +185,46 @@ CancelExecution before the command exits.`,
 	return cmd
 }
 
+// runDiscoverSchemas runs discoverTable for every entry in tables concurrently. It returns the rendered
+// output (the bare result for a single table, otherwise "TABLE: <name>" sections split by 70-dash dividers)
+// and a bool that is true if any table failed; failures are inlined as "Error discovering <table>: <err>".
+func runDiscoverSchemas(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceClient, warehouseID string, tables []string) (string, bool) {
+	results := make([]string, len(tables))
+	var anyFailed atomic.Bool
+	g := new(errgroup.Group)
+	for i, table := range tables {
+		g.Go(func() error {
+			result, err := discoverTable(ctx, gate, w, warehouseID, table)
+			if err != nil {
+				results[i] = fmt.Sprintf("Error discovering %s: %v", table, err)
+				anyFailed.Store(true)
+			} else {
+				results[i] = result
+			}
+			// Failures are recorded in results and anyFailed; the group never sees an error.
+			return nil
+		})
+	}
+	_ = g.Wait()
+
+	if len(tables) == 1 {
+		return results[0], anyFailed.Load()
+	}
+
+	divider := strings.Repeat("-", 70)
+	var sb strings.Builder
+	for i, result := range results {
+		if i > 0 {
+			sb.WriteByte('\n')
+			sb.WriteString(divider)
+			sb.WriteByte('\n')
+		}
+		fmt.Fprintf(&sb, "TABLE: %s\n%s\n", tables[i], divider)
+		sb.WriteString(result)
+	}
+	return sb.String(), anyFailed.Load()
+}
+
 // cancelDiscoverInFlight sends CancelExecution for every recorded statement_id.
 // Best effort: errors are logged but don't fail the user-visible exit.
 // Statements that already finished server-side return an error which we just
diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go
index fe6d86e799..b76004367c 100644
--- a/experimental/aitools/cmd/discover_schema_test.go
+++ b/experimental/aitools/cmd/discover_schema_test.go
@@ -324,3 +324,62 @@ func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing.
 	require.Error(t, err)
 	assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE")
 }
+
+func TestRunDiscoverSchemasFlagsTableFailureForExitCode(t *testing.T) {
+	// runDiscoverSchemas must report any per-table failure via the bool
+	// return so the caller can produce a non-zero exit. Without this signal,
+	// scripts and CI would have to parse stdout to detect failure, which is brittle.
+	ctx := cmdio.MockDiscard(t.Context())
+	mockAPI := mocksql.NewMockStatementExecutionInterface(t)
+
+	mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&dbsql.StatementResponse{
+		StatementId: "stmt-bad",
+		Status: &dbsql.StatementStatus{
+			State: dbsql.StatementStateFailed,
+			Error: &dbsql.ServiceError{ErrorCode: "TABLE_OR_VIEW_NOT_FOUND", Message: "no such table"},
+		},
+	}, nil).Once()
+
+	w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
+	output, anyFailed := runDiscoverSchemas(ctx, newSQLGate(8), w, "wh-1", []string{"main.public.missing"})
+
+	assert.True(t, anyFailed)
+	assert.Contains(t, output, "Error discovering main.public.missing")
+	assert.Contains(t, output, "TABLE_OR_VIEW_NOT_FOUND")
+}
+
+func TestRunDiscoverSchemasAllSucceedReturnsFalse(t *testing.T) {
+	ctx := cmdio.MockDiscard(t.Context())
+	mockAPI := mocksql.NewMockStatementExecutionInterface(t)
+
+	mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
+		return strings.HasPrefix(req.Statement, "DESCRIBE TABLE")
+	})).Return(&dbsql.StatementResponse{
+		StatementId: "stmt-desc",
+		Status:      &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
+		Result:      &dbsql.ResultData{DataArray: [][]string{{"id", "BIGINT", ""}}},
+	}, nil).Once() // DESCRIBE TABLE succeeds with a single row: id, BIGINT.
+
+	mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
+		return strings.HasPrefix(req.Statement, "SELECT *")
+	})).Return(&dbsql.StatementResponse{
+		StatementId: "stmt-sample",
+		Status:      &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
+	}, nil).Once() // SELECT * succeeds with no Result payload.
+
+	mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
+		return strings.Contains(req.Statement, "SUM(CASE WHEN")
+	})).Return(&dbsql.StatementResponse{
+		StatementId: "stmt-null",
+		Status:      &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
+		Manifest:    &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "id_nulls"}}}},
+		Result:      &dbsql.ResultData{DataArray: [][]string{{"7", "0"}}},
+	}, nil).Once() // SUM(CASE WHEN ...) statement succeeds with total_rows=7, id_nulls=0.
+
+	w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
+	output, anyFailed := runDiscoverSchemas(ctx, newSQLGate(8), w, "wh-1", []string{"main.public.orders"})
+
+	assert.False(t, anyFailed)
+	assert.Contains(t, output, "COLUMNS:")
+	assert.NotContains(t, output, "Error discovering")
+}