Add CancelAndDrainContextWatcherHandler #2534
base: master
@@ -0,0 +1,131 @@
```go
package pgconn

import (
    "context"
    "time"

    "github.com/jackc/pgx/v5/pgconn/ctxwatch"
    "github.com/jackc/pgx/v5/pgproto3"
)

// CancelAndDrainContextWatcherHandler handles cancelled contexts by sending a cancel request to the server and then
// draining any pending SQLSTATE 57014 (query_canceled) with a single ";" round-trip. Unlike
// [CancelRequestContextWatcherHandler], no fixed sleep is used; the drain is deterministic.
type CancelAndDrainContextWatcherHandler struct {
    Conn *PgConn

    // DeadlineDelay is the network deadline set on the connection when the context
    // is cancelled, used as a fallback to unblock any blocked read. Defaults to 1s.
    DeadlineDelay time.Duration

    // DrainTimeout is the maximum time to spend draining a cancelled query's
    // in-flight results via the ";" round-trip. Defaults to 5s.
    DrainTimeout time.Duration

    cancelFinishedChan chan struct{}
    stopFn             context.CancelFunc
}

var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil)

func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration {
    if h.DeadlineDelay == 0 {
        return time.Second
    }
    return h.DeadlineDelay
}

func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration {
    if h.DrainTimeout == 0 {
        return 5 * time.Second
```
**Contributor:** …

**Author:** this is how I would do it if I owned this library, but I'm matching the existing style - there are no duration constants anywhere else.
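Judging from the reply, the suggestion appears to concern naming the default durations. A hypothetical sketch of that alternative (not part of this PR):

```go
// Hypothetical named defaults; the PR keeps literal durations to match the
// existing style in pgconn, which has no duration constants elsewhere.
const (
    defaultDeadlineDelay = time.Second
    defaultDrainTimeout  = 5 * time.Second
)
```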
```go
    }
    return h.DrainTimeout
}

// HandleCancel is called when the context is cancelled. It sets a net.Conn deadline
// as a fallback and sends a PostgreSQL cancel request in a goroutine.
func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
```
**Contributor:** why is …

**Author:** the ctx passed here is already cancelled; this is the same pattern used by the existing handler (Lines 2885 to 2887 in a5680bc).
```go
    h.cancelFinishedChan = make(chan struct{})
    cancelCtx, stop := context.WithCancel(context.Background())
    h.stopFn = stop

    deadline := time.Now().Add(h.deadlineDelay())
    h.Conn.conn.SetDeadline(deadline)

    doneCh := h.cancelFinishedChan
    go func() {
        defer close(doneCh)
        reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
        defer cancel()
        h.Conn.CancelRequest(reqCtx)
    }()
```
**Contributor:** Can this be replaced with a …

**Author:** this pattern matches the pre-existing handler, there's only 1 goroutine here, and no meaningful error to return.
```go
}

// HandleUnwatchAfterCancel is called after the cancelled query returns. It stops the cancel goroutine (if still
// running), clears the net.Conn deadline, and drains any in-flight cancel with a single ";" round-trip.
func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() {
    if h.stopFn != nil {
        h.stopFn()
    }
    if h.cancelFinishedChan != nil {
        <-h.cancelFinishedChan
```
**Contributor:** This can block indefinitely

**Author:**

```go
defer close(doneCh)                                          // doneCh here is h.cancelFinishedChan
reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)  // cancelCtx is cancelled by stopFn
```

so if stopFn gets called, this channel is closed, and if that doesn't happen for some reason, it's still closed by the deadline.
```go
    }
    h.Conn.conn.SetDeadline(time.Time{})
    h.cancelFinishedChan = nil
    h.stopFn = nil

    if !h.Conn.IsClosed() {
        ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout())
        defer cancel()
        h.Conn.execInternalForDrain(ctx)
    }
}

// queryCanceledSQLStateCode is SQLSTATE 57014 (query_canceled).
const queryCanceledSQLStateCode = "57014"

// execInternalForDrain sends a single ";" and reads until ReadyForQuery, absorbing any
// SQLSTATE 57014 (query_canceled). One round-trip is sufficient: PostgreSQL sets
// QueryCancelPending at most once per cancel signal, so at most one 57014 can arrive.
// On any failure the connection is asyncClosed.
//
// Called while the connection is still logically "busy" from pgconn's perspective
// (lock is held and contextWatcher.Unwatch has been called) but idle from the
// PostgreSQL server's perspective (ReadyForQuery was just received). This means
// it bypasses the normal lock/unlock and contextWatcher.Watch paths.
//
// The deadline from ctx is applied directly to the net.Conn.
func (pgConn *PgConn) execInternalForDrain(ctx context.Context) {
    if deadline, ok := ctx.Deadline(); ok {
        pgConn.conn.SetDeadline(deadline)
        defer pgConn.conn.SetDeadline(time.Time{})
    }

    pgConn.frontend.Send(&pgproto3.Query{String: ";"})
    if err := pgConn.frontend.Flush(); err != nil {
        pgConn.asyncClose()
        return
    }

    for {
        msg, err := pgConn.receiveMessage()
        if err != nil {
            pgConn.asyncClose()
            return
        }

        switch msg := msg.(type) {
        case *pgproto3.ReadyForQuery:
            return
        case *pgproto3.ErrorResponse:
            pgErr := ErrorResponseToPgError(msg)
            if pgErr.Code != queryCanceledSQLStateCode {
                pgConn.asyncClose()
                return
            }
            // 57014 absorbed - continue reading until ReadyForQuery
        case *pgproto3.EmptyQueryResponse:
            // Expected response for ";".
        }
    }
}
```
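For context, the handler is opted into per connection via the config's BuildContextWatcherHandler hook, exactly as the tests below do. A minimal sketch of that wiring (the connection string and the pg_sleep query are illustrative only):

```go
package main

import (
    "context"
    "log"
    "time"

    "github.com/jackc/pgx/v5/pgconn"
    "github.com/jackc/pgx/v5/pgconn/ctxwatch"
)

func main() {
    // Hypothetical DSN, for illustration only.
    config, err := pgconn.ParseConfig("postgres://localhost:5432/postgres")
    if err != nil {
        log.Fatal(err)
    }
    // Swap the default handler for the new cancel-and-drain handler.
    config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
        return &pgconn.CancelAndDrainContextWatcherHandler{Conn: conn}
    }

    conn, err := pgconn.ConnectConfig(context.Background(), config)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close(context.Background())

    // The cancelled query returns an error, but the connection is drained
    // and remains usable afterwards.
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()
    _, _ = conn.Exec(ctx, "select pg_sleep(10)").ReadAll()
    log.Println("closed after cancel?", conn.IsClosed())
}
```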
@@ -0,0 +1,261 @@
```go
package pgconn_test

import (
    "context"
    "fmt"
    "io"
    "os"
    "testing"
    "time"

    "github.com/jackc/pgx/v5/pgconn"
    "github.com/jackc/pgx/v5/pgconn/ctxwatch"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func buildCancelAndDrainConfig(t *testing.T) *pgconn.Config {
    t.Helper()
    config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
    require.NoError(t, err)
    config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
        return &pgconn.CancelAndDrainContextWatcherHandler{Conn: conn}
    }
    config.ConnectTimeout = 5 * time.Second
    return config
}

func TestCancelAndDrainContextWatcherHandler(t *testing.T) {
    t.Parallel()

    t.Run("connection reused after cancel", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        _, err = pgConn.Exec(ctx, "select pg_sleep(10)").ReadAll()
        require.Error(t, err)
        require.False(t, pgConn.IsClosed(), "connection should not be closed after cancel with drain handler")

        ensureConnValid(t, pgConn)
    })

    t.Run("no stale cancel bleed", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        for i := range 50 {
            func() {
                ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
                defer cancel()
                pgConn.Exec(ctx, "select pg_sleep(0.020)").ReadAll()
            }()

            if pgConn.IsClosed() {
                var err error
                pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
                require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
            }

            ensureConnValid(t, pgConn)
        }
    })

    t.Run("stress", func(t *testing.T) {
        t.Parallel()

        for i := range 10 {
            t.Run(fmt.Sprintf("goroutine_%d", i), func(t *testing.T) {
                t.Parallel()

                pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
                require.NoError(t, err)
                defer closeConn(t, pgConn)

                for j := range 20 {
                    func() {
                        ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
                        defer cancel()
                        pgConn.Exec(ctx, "select pg_sleep(0.010)").ReadAll()
                    }()

                    if pgConn.IsClosed() {
                        var err error
                        pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
                        require.NoError(t, err, "goroutine %d iteration %d: failed to reconnect", i, j)
                    }

                    ensureConnValid(t, pgConn)
                }
            })
        }
    })

    t.Run("ExecParams", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        rr := pgConn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil)
        rr.Read()
        _, err = rr.Close()
        assert.Error(t, err)

        if !pgConn.IsClosed() {
            ensureConnValid(t, pgConn)
        }
    })

    t.Run("CopyTo", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        _, err = pgConn.CopyTo(ctx, io.Discard, "COPY (SELECT pg_sleep(10)) TO STDOUT")
        assert.Error(t, err)

        if !pgConn.IsClosed() {
            ensureConnValid(t, pgConn)
        }
    })

    t.Run("CopyFrom", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        _, err = pgConn.Exec(context.Background(), "CREATE TEMP TABLE drain_test_copyfrom (id int)").ReadAll()
        require.NoError(t, err)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        pr, pw := io.Pipe()
        defer pr.Close()
        defer pw.Close()

        _, err = pgConn.CopyFrom(ctx, pr, "COPY drain_test_copyfrom FROM STDIN")
        assert.Error(t, err)

        if !pgConn.IsClosed() {
            ensureConnValid(t, pgConn)
        }
    })

    t.Run("Pipeline", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        pipeline := pgConn.StartPipeline(ctx)

        pipeline.SendQueryParams("select pg_sleep(10)", nil, nil, nil, nil)
        err = pipeline.Sync()
        require.NoError(t, err)

        pipeline.Close()

        require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled pipeline with drain handler")
        ensureConnValid(t, pgConn)
    })

    t.Run("Prepare", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        for i := range 20 {
            func() {
                ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
                defer cancel()
                pgConn.Prepare(ctx, "", "select pg_sleep(0.010)", nil)
            }()

            if pgConn.IsClosed() {
                var err error
                pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
                require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
            }

            ensureConnValid(t, pgConn)
        }
    })

    t.Run("Deallocate", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        for i := range 20 {
            _, err := pgConn.Prepare(context.Background(), "drain_dealloc_test", "select 1", nil)
            require.NoError(t, err, "iteration %d: prepare failed", i)

            func() {
                ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
                defer cancel()
                pgConn.Deallocate(ctx, "drain_dealloc_test")
            }()

            if pgConn.IsClosed() {
                var err error
                pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
                require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
            }

            ensureConnValid(t, pgConn)
        }
    })

    t.Run("WaitForNotification", func(t *testing.T) {
        t.Parallel()

        pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
        require.NoError(t, err)
        defer closeConn(t, pgConn)

        if pgConn.ParameterStatus("crdb_version") != "" {
            t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
        }

        _, err = pgConn.Exec(context.Background(), "LISTEN drain_test_channel").ReadAll()
        require.NoError(t, err)

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()

        err = pgConn.WaitForNotification(ctx)
        require.Error(t, err)

        require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled WaitForNotification with drain handler")
        ensureConnValid(t, pgConn)
    })
}
```
**Contributor:** Why not just use the context.Context created in HandleCancel? (e.g. shutdownCtx)

**Author:** I think you're suggesting something like this:

…

this would work fine but is more allocation and diverges from the pattern in the other handler... I don't personally see much reason to prefer it, but I'm not against making that change.

**Contributor:** Similar to what I said below, this should all be handled in pgxpool or puddle (somehow), and not done at the PgConn layer.