From c022bbabee433521dca8e183e3638abc28ad88f9 Mon Sep 17 00:00:00 2001 From: boromisp <2242020+boromisp@users.noreply.github.com> Date: Sun, 12 May 2024 20:36:43 +0200 Subject: [PATCH] Cancellable Query with AbortSignal option --- packages/pg/lib/client.js | 1 + packages/pg/lib/connection.js | 34 +++++- packages/pg/lib/query.js | 82 ++++++++++--- .../cancel-query-with-abort-signal-tests.js | 108 ++++++++++++++++++ 4 files changed, 209 insertions(+), 16 deletions(-) create mode 100644 packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js diff --git a/packages/pg/lib/client.js b/packages/pg/lib/client.js index 13071e88c..09aa8054b 100644 --- a/packages/pg/lib/client.js +++ b/packages/pg/lib/client.js @@ -51,6 +51,7 @@ class Client extends EventEmitter { keepAlive: c.keepAlive || false, keepAliveInitialDelayMillis: c.keepAliveInitialDelayMillis || 0, encoding: this.connectionParameters.client_encoding || 'utf8', + Promise: this._Promise, }) this.queryQueue = [] this.binary = c.binary || defaults.binary diff --git a/packages/pg/lib/connection.js b/packages/pg/lib/connection.js index de8e978a2..b07fc11e1 100644 --- a/packages/pg/lib/connection.js +++ b/packages/pg/lib/connection.js @@ -1,6 +1,5 @@ 'use strict' -var net = require('net') var EventEmitter = require('events').EventEmitter const { parse, serialize } = require('pg-protocol') @@ -43,11 +42,40 @@ class Connection extends EventEmitter { self._emitMessage = true } }) + + this._config = config + this._backendData = null + this._remote = null + } + + cancelWithClone() { + const config = this._config + const Promise = config.Promise || global.Promise + + return new Promise((resolve, reject) => { + const { processID, secretKey } = this._backendData + let { host, port, notIP } = this._remote + if (host && notIP && config.ssl && this.stream.remoteAddress) { + if (config.ssl === true) { + config.ssl = {} + } + config.ssl.servername = host + host = this.stream.remoteAddress + } + + const con = new Connection(config) + con + .on('connect', () => con.cancel(processID, secretKey)) + .on('error', reject) + .on('end', resolve) + .connect(port, host) + }) } connect(port, host) { var self = this + this._remote = { host, port } this._connecting = true this.stream.setNoDelay(true) this.stream.connect(port, host) @@ -108,6 +136,7 @@ class Connection extends EventEmitter { var net = require('net') if (net.isIP && net.isIP(host) === 0) { options.servername = host + self._remote.notIP = true } try { self.stream = getSecureStream(options) @@ -128,6 +157,9 @@ class Connection extends EventEmitter { this.emit('message', msg) } this.emit(eventName, msg) + if (msg.name === 'backendKeyData') { + this._backendData = msg + } }) } diff --git a/packages/pg/lib/query.js b/packages/pg/lib/query.js index fac4d86e3..3b23b0a00 100644 --- a/packages/pg/lib/query.js +++ b/packages/pg/lib/query.js @@ -5,6 +5,31 @@ const { EventEmitter } = require('events') const Result = require('./result') const utils = require('./utils') +function setupCancellation(cancelSignal, connection) { + let cancellation = null + + function cancelRequest() { + cancellation = connection.cancelWithClone().catch(() => { + // We could still have a cancel request in flight targeting this connection. + // Better safe than sorry? + connection.stream.destroy() + }) + } + + cancelSignal.addEventListener('abort', cancelRequest, { once: true }) + + return { + cleanup() { + if (cancellation) { + // Must wait out connection.cancelWithClone + return cancellation + } + cancelSignal.removeEventListener('abort', cancelRequest) + return Promise.resolve() + }, + } +} + class Query extends EventEmitter { constructor(config, values, callback) { super() @@ -29,6 +54,8 @@ class Query extends EventEmitter { // potential for multiple results this._results = this._result this._canceledDueToError = false + + this._cancelSignal = config.signal } requiresPreparation() { @@ -114,34 +141,53 @@ class Query extends EventEmitter { } } + _handleQueryComplete(fn) { + if (!this._cancellation) { + fn() + return + } + this._cancellation + .cleanup() + .then(fn) + .finally(() => { + this._cancellation = null + }) + } + handleError(err, connection) { // need to sync after error during a prepared statement if (this._canceledDueToError) { err = this._canceledDueToError this._canceledDueToError = false } - // if callback supplied do not emit error event as uncaught error - // events will bubble up to node process - if (this.callback) { - return this.callback(err) - } - this.emit('error', err) + + this._handleQueryComplete(() => { + // if callback supplied do not emit error event as uncaught error + // events will bubble up to node process + if (this.callback) { + return this.callback(err) + } + this.emit('error', err) + }) } handleReadyForQuery(con) { if (this._canceledDueToError) { return this.handleError(this._canceledDueToError, con) } - if (this.callback) { - try { - this.callback(null, this._results) - } catch (err) { - process.nextTick(() => { - throw err - }) + + this._handleQueryComplete(() => { + if (this.callback) { + try { + this.callback(null, this._results) + } catch (err) { + process.nextTick(() => { + throw err + }) + } } - } - this.emit('end', this._results) + this.emit('end', this._results) + }) } submit(connection) { @@ -155,6 +201,12 @@ class Query extends EventEmitter { if (this.values && !Array.isArray(this.values)) { return new Error('Query values must be an array') } + if (this._cancelSignal) { + if (this._cancelSignal.aborted) { + return this._cancelSignal.reason + } + this._cancellation = setupCancellation(this._cancelSignal, connection) + } if (this.requiresPreparation()) { this.prepare(connection) } else { diff --git a/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js b/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js new file mode 100644 index 000000000..bbd8e7a76 --- /dev/null +++ b/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js @@ -0,0 +1,108 @@ +var helper = require('./../test-helper') + +var pg = helper.pg +const Client = pg.Client +const DatabaseError = pg.DatabaseError + +if (!global.AbortController) { + // skip these tests on node < 15 + return +} + +const suite = new helper.Suite('query cancellation with abort signal') + +suite.test('query with signal succeeds if not aborted', function (done) { + const client = new Client() + const { signal } = new AbortController() + + client.connect( + assert.success(() => { + client.query( + new pg.Query({ text: 'select pg_sleep(0.1)', signal }), + assert.success((result) => { + assert.equal(result.rows[0].pg_sleep, '') + client.end(done) + }) + ) + }) + ) +}) + +suite.test('query with signal is not submitted if the signal is already aborted', function (done) { + const client = new Client() + const signal = AbortSignal.abort() + + let counter = 0 + + client.query( + new pg.Query({ text: 'INVALID SQL...' }), + assert.calls((err) => { + assert(err instanceof DatabaseError) + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'begin' }), + assert.success(() => { + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'INVALID SQL...', signal }), + assert.calls((err) => { + assert.equal(err.name, 'AbortError') + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'select 1' }), + assert.success(() => { + counter++ + assert.equal(counter, 4) + client.end(done) + }) + ) + + client.connect(assert.success(() => {})) +}) + +suite.test('query can be canceled with abort signal', function (done) { + const client = new Client() + const ac = new AbortController() + const { signal } = ac + + client.query( + new pg.Query({ text: 'SELECT pg_sleep(0.5)', signal }), + assert.calls((err) => { + assert(err instanceof DatabaseError) + assert(err.code === '57014') + client.end(done) + }) + ) + + client.connect( + assert.success(() => { + setTimeout(() => { + ac.abort() + }, 50) + }) + ) +}) + +suite.test('long abort signal timeout does not keep the query / connection going', function (done) { + const client = new Client() + const signal = AbortSignal.timeout(10_000) + + client.query( + new pg.Query({ text: 'SELECT pg_sleep(0.1)', signal }), + assert.success((result) => { + assert.equal(result.rows[0].pg_sleep, '') + client.end(done) + }) + ) + + client.connect(assert.success(() => {})) +})