Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions experimental/aitools/cmd/discover_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"

"github.com/databricks/cli/cmd/root"
Expand Down Expand Up @@ -161,47 +162,20 @@ CancelExecution before the command exits.`,

gate := newSQLGate(concurrency)

results := make([]string, len(args))
g := new(errgroup.Group)
for i, table := range args {
g.Go(func() error {
result, err := discoverTable(pollCtx, gate, w, warehouseID, table)
if err != nil {
results[i] = fmt.Sprintf("Error discovering %s: %v", table, err)
} else {
results[i] = result
}
// A failure on one table shouldn't abort the others.
return nil
})
}
_ = g.Wait()
output, anyFailed := runDiscoverSchemas(pollCtx, gate, w, warehouseID, args)

if pollCtx.Err() != nil {
cancelDiscoverInFlight(ctx, w.StatementExecution, gate.trackedIDs())
return root.ErrAlreadyPrinted
}

// format output with dividers for multiple tables
var output string
if len(results) == 1 {
output = results[0]
} else {
divider := strings.Repeat("-", 70)
var sb strings.Builder
for i, result := range results {
if i > 0 {
sb.WriteByte('\n')
sb.WriteString(divider)
sb.WriteByte('\n')
}
fmt.Fprintf(&sb, "TABLE: %s\n%s\n", args[i], divider)
sb.WriteString(result)
}
output = sb.String()
}

cmdio.LogString(ctx, output)
if anyFailed {
// Per-table errors are already in `output`; ErrAlreadyPrinted
// gives a non-zero exit without re-printing them so scripts
// and CI can detect failure via the exit code.
return root.ErrAlreadyPrinted
}
return nil
},
}
Expand All @@ -211,6 +185,46 @@ CancelExecution before the command exits.`,
return cmd
}

// runDiscoverSchemas discovers schemas for tables concurrently and returns the
// rendered output. The bool is true if any table failed; per-table errors are
// inlined into the output so one bad table doesn't abort the others.
func runDiscoverSchemas(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceClient, warehouseID string, tables []string) (string, bool) {
results := make([]string, len(tables))
var anyFailed atomic.Bool
g := new(errgroup.Group)
for i, table := range tables {
g.Go(func() error {
result, err := discoverTable(ctx, gate, w, warehouseID, table)
if err != nil {
results[i] = fmt.Sprintf("Error discovering %s: %v", table, err)
anyFailed.Store(true)
} else {
results[i] = result
}
// A failure on one table shouldn't abort the others.
return nil
})
}
_ = g.Wait()

if len(tables) == 1 {
return results[0], anyFailed.Load()
}

divider := strings.Repeat("-", 70)
var sb strings.Builder
for i, result := range results {
if i > 0 {
sb.WriteByte('\n')
sb.WriteString(divider)
sb.WriteByte('\n')
}
fmt.Fprintf(&sb, "TABLE: %s\n%s\n", tables[i], divider)
sb.WriteString(result)
}
return sb.String(), anyFailed.Load()
}

// cancelDiscoverInFlight sends CancelExecution for every recorded statement_id.
// Best effort: errors are logged but don't fail the user-visible exit.
// Statements that already finished server-side return an error which we just
Expand Down
59 changes: 59 additions & 0 deletions experimental/aitools/cmd/discover_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,62 @@ func TestDiscoverSchemaInvalidTableNameRejectedBeforeWorkspaceClient(t *testing.
require.Error(t, err)
assert.Contains(t, err.Error(), "expected CATALOG.SCHEMA.TABLE")
}

func TestRunDiscoverSchemasFlagsTableFailureForExitCode(t *testing.T) {
// runDiscoverSchemas must report any per-table failure via the bool
// return so the caller can produce a non-zero exit. Without this signal
// scripts and CI parse stdout to detect failure, which is brittle.
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(&dbsql.StatementResponse{
StatementId: "stmt-bad",
Status: &dbsql.StatementStatus{
State: dbsql.StatementStateFailed,
Error: &dbsql.ServiceError{ErrorCode: "TABLE_OR_VIEW_NOT_FOUND", Message: "no such table"},
},
}, nil).Once()

w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
output, anyFailed := runDiscoverSchemas(ctx, newSQLGate(8), w, "wh-1", []string{"main.public.missing"})

assert.True(t, anyFailed)
assert.Contains(t, output, "Error discovering main.public.missing")
assert.Contains(t, output, "TABLE_OR_VIEW_NOT_FOUND")
}

func TestRunDiscoverSchemasAllSucceedReturnsFalse(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
return strings.HasPrefix(req.Statement, "DESCRIBE TABLE")
})).Return(&dbsql.StatementResponse{
StatementId: "stmt-desc",
Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
Result: &dbsql.ResultData{DataArray: [][]string{{"id", "BIGINT", ""}}},
}, nil).Once()

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
return strings.HasPrefix(req.Statement, "SELECT *")
})).Return(&dbsql.StatementResponse{
StatementId: "stmt-sample",
Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
}, nil).Once()

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req dbsql.ExecuteStatementRequest) bool {
return strings.Contains(req.Statement, "SUM(CASE WHEN")
})).Return(&dbsql.StatementResponse{
StatementId: "stmt-null",
Status: &dbsql.StatementStatus{State: dbsql.StatementStateSucceeded},
Manifest: &dbsql.ResultManifest{Schema: &dbsql.ResultSchema{Columns: []dbsql.ColumnInfo{{Name: "total_rows"}, {Name: "id_nulls"}}}},
Result: &dbsql.ResultData{DataArray: [][]string{{"7", "0"}}},
}, nil).Once()

w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
output, anyFailed := runDiscoverSchemas(ctx, newSQLGate(8), w, "wh-1", []string{"main.public.orders"})

assert.False(t, anyFailed)
assert.Contains(t, output, "COLUMNS:")
assert.NotContains(t, output, "Error discovering")
}