Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions experimental/aitools/cmd/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,17 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e
func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string) (*sql.StatementResponse, error) {
// Submit asynchronously to get the statement ID immediately for cancellation.
resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
WarehouseId: warehouseID,
Statement: statement,
WaitTimeout: "0s",
WarehouseId: warehouseID,
Statement: statement,
WaitTimeout: "0s",
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
})
if err != nil {
return nil, fmt.Errorf("execute statement: %w", err)
}

statementID := resp.StatementId

// Check if it completed immediately.
if isTerminalState(resp.Status) {
return resp, checkFailedState(resp.Status)
}

// Set up Ctrl+C: signal cancels the poll context, cleanup is unified below.
pollCtx, pollCancel := context.WithCancel(ctx)
defer pollCancel()
Expand Down Expand Up @@ -327,34 +323,59 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa
}
}()

pollResp, err := pollStatement(pollCtx, api, resp)
if err != nil {
if pollCtx.Err() != nil {
cancelStatement()
cmdio.LogString(ctx, "Query cancelled.")
return nil, root.ErrAlreadyPrinted
}
return nil, err
}

sp.Close()
if err := checkFailedState(pollResp.Status); err != nil {
return nil, err
}
return pollResp, nil
}

// pollStatement polls until the statement reaches a terminal state.
//
// On context cancellation it returns the context error WITHOUT cancelling the
// server-side statement. Callers that want server-side cancellation should
// invoke CancelExecution explicitly.
//
// If the input response is already in a terminal state, it is returned without
// further polling.
func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) {
if isTerminalState(resp.Status) {
return resp, nil
}

statementID := resp.StatementId
start := time.Now()

// Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped).
interval := pollIntervalInitial
for {
select {
case <-pollCtx.Done():
cancelStatement()
cmdio.LogString(ctx, "Query cancelled.")
return nil, root.ErrAlreadyPrinted
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(interval):
}

log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second))

pollResp, err := api.GetStatementByStatementId(pollCtx, statementID)
pollResp, err := api.GetStatementByStatementId(ctx, statementID)
if err != nil {
if pollCtx.Err() != nil {
cancelStatement()
cmdio.LogString(ctx, "Query cancelled.")
return nil, root.ErrAlreadyPrinted
if ctx.Err() != nil {
return nil, ctx.Err()
}
return nil, fmt.Errorf("poll statement status: %w", err)
}

if isTerminalState(pollResp.Status) {
sp.Close()
if err := checkFailedState(pollResp.Status); err != nil {
return nil, err
}
return &sql.StatementResponse{
StatementId: pollResp.StatementId,
Status: pollResp.Status,
Expand Down
105 changes: 104 additions & 1 deletion experimental/aitools/cmd/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aitools

import (
"context"
"errors"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -48,7 +49,9 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) {
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

mockAPI.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" && req.WaitTimeout == "0s"
return req.WarehouseId == "wh-123" && req.Statement == "SELECT 1" &&
req.WaitTimeout == "0s" &&
req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue
})).Return(&sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
Expand Down Expand Up @@ -154,6 +157,106 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) {
require.ErrorIs(t, err, root.ErrAlreadyPrinted)
}

func TestPollStatementImmediateTerminal(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

resp := &sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}},
Result: &sql.ResultData{DataArray: [][]string{{"1"}}},
}

pollResp, err := pollStatement(ctx, mockAPI, resp)
require.NoError(t, err)
assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State)
assert.Equal(t, "stmt-1", pollResp.StatementId)
}

func TestPollStatementTerminalFailureNotErrored(t *testing.T) {
// pollStatement returns the response without erroring on failed terminal
// states; callers (e.g. executeAndPoll) decide what to do via checkFailedState.
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

resp := &sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{
State: sql.StatementStateFailed,
Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"},
},
}

pollResp, err := pollStatement(ctx, mockAPI, resp)
require.NoError(t, err)
assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State)
}

func TestPollStatementEventualSuccess(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

initial := &sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStatePending},
}

mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStateRunning},
}, nil).Once()

mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStateSucceeded},
Result: &sql.ResultData{DataArray: [][]string{{"42"}}},
}, nil).Once()

pollResp, err := pollStatement(ctx, mockAPI, initial)
require.NoError(t, err)
assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State)
assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray)
}

func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) {
// The mock asserts (via t.Cleanup) that no unexpected calls are made.
// Specifically, pollStatement must NOT call CancelExecution on context
// cancellation; that is the caller's responsibility.
ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context()))
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

initial := &sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStatePending},
}

cancel()

pollResp, err := pollStatement(ctx, mockAPI, initial)
require.ErrorIs(t, err, context.Canceled)
assert.Nil(t, pollResp)
}

func TestPollStatementGetErrorPropagated(t *testing.T) {
ctx := cmdio.MockDiscard(t.Context())
mockAPI := mocksql.NewMockStatementExecutionInterface(t)

initial := &sql.StatementResponse{
StatementId: "stmt-1",
Status: &sql.StatementStatus{State: sql.StatementStatePending},
}

mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
Return(nil, errors.New("network unreachable")).Once()

pollResp, err := pollStatement(ctx, mockAPI, initial)
require.Error(t, err)
assert.Contains(t, err.Error(), "poll statement status")
assert.Contains(t, err.Error(), "network unreachable")
assert.Nil(t, pollResp)
}

func TestResolveWarehouseIDWithFlag(t *testing.T) {
ctx := t.Context()
id, err := resolveWarehouseID(ctx, nil, "explicit-id")
Expand Down