diff --git a/packages/cubejs-api-gateway/src/sql-server.ts b/packages/cubejs-api-gateway/src/sql-server.ts index 81d1d4aff9e9a..a43b62b2131f6 100644 --- a/packages/cubejs-api-gateway/src/sql-server.ts +++ b/packages/cubejs-api-gateway/src/sql-server.ts @@ -7,6 +7,7 @@ import { Request as NativeRequest, LoadRequestMeta, } from '@cubejs-backend/native'; +import type { ShutdownMode } from '@cubejs-backend/native'; import { displayCLIWarning, getEnv } from '@cubejs-backend/shared'; import * as crypto from 'crypto'; @@ -347,7 +348,7 @@ export class SQLServer { // @todo Implement } - public async shutdown(): Promise { - await shutdownInterface(this.sqlInterfaceInstance!); + public async shutdown(mode: ShutdownMode): Promise { + await shutdownInterface(this.sqlInterfaceInstance!, mode); } } diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index ea8d42e30527a..19169921e7661 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -328,10 +328,12 @@ export const registerInterface = async (options: SQLInterfaceOptions): Promise => { +export type ShutdownMode = 'fast' | 'semifast' | 'smart'; + +export const shutdownInterface = async (instance: SqlInterfaceInstance, shutdownMode: ShutdownMode): Promise => { const native = loadNative(); - await native.shutdownInterface(instance); + await native.shutdownInterface(instance, shutdownMode); }; export const execSql = async (instance: SqlInterfaceInstance, sqlQuery: string, stream: any, securityContext?: any): Promise => { diff --git a/packages/cubejs-backend-native/src/node_export.rs b/packages/cubejs-backend-native/src/node_export.rs index 6bb56abd18af5..981753dfa54e1 100644 --- a/packages/cubejs-backend-native/src/node_export.rs +++ b/packages/cubejs-backend-native/src/node_export.rs @@ -1,4 +1,5 @@ use cubesql::compile::{convert_sql_to_cube_query, get_df_batches}; +use cubesql::config::processing_loop::ShutdownMode; use cubesql::config::ConfigObj; use cubesql::sql::{DatabaseProtocol, SessionManager}; use cubesql::transport::TransportService; @@ -123,6 +124,17 @@ fn register_interface(mut cx: FunctionContext) -> JsResult fn shutdown_interface(mut cx: FunctionContext) -> JsResult { let interface = cx.argument::>(0)?; + let js_shutdown_mode = cx.argument::(1)?; + let shutdown_mode = match js_shutdown_mode.value(&mut cx).as_str() { + "fast" => ShutdownMode::Fast, + "semifast" => ShutdownMode::SemiFast, + "smart" => ShutdownMode::Smart, + _ => { + return cx.throw_range_error::<&str, Handle>( + "ShutdownMode param must be 'fast', 'semifast', or 'smart'", + ); + } + }; let (deferred, promise) = cx.promise(); let channel = cx.channel(); @@ -131,7 +143,7 @@ fn shutdown_interface(mut cx: FunctionContext) -> JsResult { let runtime = tokio_runtime_node(&mut cx)?; runtime.spawn(async move { - match services.stop_processing_loops().await { + match services.stop_processing_loops(shutdown_mode).await { Ok(_) => { if let Err(err) = services.await_processing_loops().await { log::error!("Error during awaiting on shutdown: {}", err) diff --git a/packages/cubejs-backend-native/test/server.js b/packages/cubejs-backend-native/test/server.js index 10188d6d955fa..83ca37a70d389 100644 --- a/packages/cubejs-backend-native/test/server.js +++ b/packages/cubejs-backend-native/test/server.js @@ -138,7 +138,7 @@ const meta_fixture = require('./meta'); console.log('SIGINT signal'); try { - await native.shutdownInterface(server); + await native.shutdownInterface(server, 'fast'); } catch (e) { console.log(e); } finally { diff --git a/packages/cubejs-backend-native/test/sql.test.ts b/packages/cubejs-backend-native/test/sql.test.ts index 184d51209e6f7..8de9a16dbe5c5 100644 --- a/packages/cubejs-backend-native/test/sql.test.ts +++ b/packages/cubejs-backend-native/test/sql.test.ts @@ -296,7 +296,7 @@ describe('SQLInterface', () => { await connection.end(); } finally { - await native.shutdownInterface(instance); + await native.shutdownInterface(instance, 'fast'); } }); @@ -346,7 +346,7 @@ describe('SQLInterface', () => { expect(rows).toBe(100000); - await native.shutdownInterface(instance); + await native.shutdownInterface(instance, 'fast'); } else { expect(process.env.CUBESQL_STREAM_MODE).toBeFalsy(); } diff --git a/packages/cubejs-server/src/server.ts b/packages/cubejs-server/src/server.ts index eb03b65753d81..6759b526eac00 100644 --- a/packages/cubejs-server/src/server.ts +++ b/packages/cubejs-server/src/server.ts @@ -231,7 +231,7 @@ export class CubejsServer { if (this.sqlServer) { locks.push( - this.sqlServer.shutdown() + this.sqlServer.shutdown(graceful && (signal === 'SIGTERM') ? 'semifast' : 'fast') ); } diff --git a/packages/cubejs-server/src/server/container.ts b/packages/cubejs-server/src/server/container.ts index c0b0a03001d82..cd96ef5e999a0 100644 --- a/packages/cubejs-server/src/server/container.ts +++ b/packages/cubejs-server/src/server/container.ts @@ -443,6 +443,7 @@ export class ServerContainer { process.exit(1); } } else { + console.log(`Recevied ${signal} signal, terminating with process exit`); process.exit(0); } }); diff --git a/packages/cubejs-testing/package.json b/packages/cubejs-testing/package.json index 0306f0ea2ac11..91c208dda73d2 100644 --- a/packages/cubejs-testing/package.json +++ b/packages/cubejs-testing/package.json @@ -58,6 +58,8 @@ "smoke:crate:snapshot": "jest --verbose --updateSnapshot -i dist/test/smoke-crate.test.js", "smoke:firebolt": "jest --verbose -i dist/test/smoke-firebolt.test.js", "smoke:firebolt:snapshot": "jest --updateSnapshot --verbose -i dist/test/smoke-firebolt.test.js", + "smoke:graceful-shutdown": "jest --verbose -i dist/test/smoke-graceful-shutdown.test.js", + "smoke:graceful-shutdown:snapshot": "jest --updateSnapshot --verbose -i dist/test/smoke-graceful-shutdown.test.js", "smoke:lambda": "jest --verbose -i dist/test/smoke-lambda.test.js", "smoke:lambda:snapshot": "jest --updateSnapshot --verbose -i dist/test/smoke-lambda.test.js", "smoke:materialize": "jest --verbose -i dist/test/smoke-materialize.test.js", diff --git a/packages/cubejs-testing/src/birdbox.ts b/packages/cubejs-testing/src/birdbox.ts index 7ec1c4bd51df8..a2941b4d589f7 100644 --- a/packages/cubejs-testing/src/birdbox.ts +++ b/packages/cubejs-testing/src/birdbox.ts @@ -222,11 +222,17 @@ function prepareTestData(type: DriverType, schemas?: Schemas) { } } +// Some logic to kill Cube in stop is more precise if we know killCube is only used to send signals +// that get the process terminated. +type KillCubeSignal = 'SIGINT' | 'SIGTERM'; + /** * Birdbox object interface. */ export interface BirdBox { stop: () => Promise; + killCube: (signal: KillCubeSignal) => void; + onCubeExit: () => Promise; stdout: internal.Readable | null; configuration: { playgroundUrl: string; @@ -246,11 +252,21 @@ export async function startBirdBoxFromContainer( if (process.env.TEST_CUBE_HOST) { const host = process.env.TEST_CUBE_HOST || 'localhost'; const port = process.env.TEST_CUBE_PORT || '8888'; + const pid = process.env.TEST_CUBE_PID ? Number(process.env.TEST_CUBE_PID) : null; return { stop: async () => { process.stdout.write('[Birdbox] Closed\n'); }, + killCube: (signal: KillCubeSignal) => { + if (pid !== null) { + process.kill(pid, signal); + } else { + process.stdout.write(`[Birdbox] Cannot kill Cube instance running in TEST_CUBE_HOST mode without TEST_CUBE_PID defined\n`); + throw new Error('Attempted to use killCube while running with TEST_CUBE_HOST'); + } + }, + onCubeExit: (): Promise => Promise.reject(new Error('onCubeExit not implemented')), // TODO: Implement stdout: null, configuration: { playgroundUrl: `http://${host}:${port}`, @@ -409,6 +425,15 @@ export async function startBirdBoxFromContainer( process.stdout.write('[Birdbox] Closed\n'); } }, + killCube: (signal: KillCubeSignal) => { + process.stdout.write(`[Birdbox] killCube (with signal ${signal}) not implemented for containers\n`); + throw new Error('killCube not implemented for containers'); + }, + onCubeExit: (): Promise => { + const _ = 0; + return Promise.reject(new Error('onCubeExit not implemented for containers')); + // TODO: Implement. + }, configuration: { playgroundUrl: `http://${host}:${playgroundPort}`, apiUrl: `http://${host}:${port}/cubejs-api/v1`, @@ -559,6 +584,11 @@ export async function startBirdBoxFromCli( ...options.env, }; + let exitResolve: (code: number | null) => void; + const exitPromise = new Promise((res, _rej) => { + exitResolve = res; + }); + try { cli = spawn( options.useCubejsServerBinary @@ -589,6 +619,10 @@ export async function startBirdBoxFromCli( process.stdout.write(msg); }); } + cli.on('exit', (code, signal) => { + process.stdout.write(`[Birdbox] Child process '${cli.pid}' exited with 'exit' event code ${code}, signal ${signal}\n`); + exitResolve(code); + }); await pausePromise(10 * 1000); } catch (e) { process.stdout.write(`Error spawning cube: ${e}\n`); @@ -596,6 +630,7 @@ export async function startBirdBoxFromCli( db.stop(); } + let sentKillSignal = false; return { // @ts-expect-error stdout: cli.stdout, @@ -611,12 +646,30 @@ export async function startBirdBoxFromCli( process.stdout.write('[Birdbox] Done with DB\n'); } if (cli.pid) { - process.kill(-cli.pid, 'SIGINT'); + process.stdout.write(`[Birdbox] Killing process group '${cli.pid}'\n`); + // Here, normally, we kill the process group by passing -cli.pid (a negative value), but + // with killCube we just kill the main process, and then can't kill any process group -- + // maybe that test has poor cleanup actions. + try { + process.kill(-cli.pid, 'SIGINT'); + } catch (error) { + if (!sentKillSignal) { + throw error; + } + } } if (options.log === Log.PIPE) { process.stdout.write('[Birdbox] Closed\n'); } }, + killCube: (signal: KillCubeSignal) => { + process.stdout.write(`[Birdbox] Killing Cube (pid = '${cli.pid}') with signal ${signal}\n`); + if (cli.pid) { + process.kill(cli.pid, signal); + sentKillSignal = true; + } + }, + onCubeExit: (): Promise => exitPromise, configuration: { playgroundUrl: 'http://127.0.0.1:4000', apiUrl: 'http://127.0.0.1:4000/cubejs-api/v1', diff --git a/packages/cubejs-testing/test/__snapshots__/smoke-graceful-shutdown.test.ts.snap b/packages/cubejs-testing/test/__snapshots__/smoke-graceful-shutdown.test.ts.snap new file mode 100644 index 0000000000000..09c5f5e3ebd0f --- /dev/null +++ b/packages/cubejs-testing/test/__snapshots__/smoke-graceful-shutdown.test.ts.snap @@ -0,0 +1,52 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`graceful shutdown PgClient Graceful Shutdown Finishing Transaction: sql_orders 1`] = ` +Array [ + Object { + "cn": "2", + "status": "processed", + }, + Object { + "cn": "2", + "status": "new", + }, + Object { + "cn": "1", + "status": "shipped", + }, +] +`; + +exports[`graceful shutdown PgClient Graceful Shutdown SIGINT: sql_orders 1`] = ` +Array [ + Object { + "cn": "2", + "status": "processed", + }, + Object { + "cn": "2", + "status": "new", + }, + Object { + "cn": "1", + "status": "shipped", + }, +] +`; + +exports[`graceful shutdown PgClient Graceful Shutdown SIGTERM: sql_orders 1`] = ` +Array [ + Object { + "cn": "2", + "status": "processed", + }, + Object { + "cn": "2", + "status": "new", + }, + Object { + "cn": "1", + "status": "shipped", + }, +] +`; diff --git a/packages/cubejs-testing/test/smoke-graceful-shutdown.test.ts b/packages/cubejs-testing/test/smoke-graceful-shutdown.test.ts new file mode 100644 index 0000000000000..efe06cfe1ff44 --- /dev/null +++ b/packages/cubejs-testing/test/smoke-graceful-shutdown.test.ts @@ -0,0 +1,303 @@ +import { StartedTestContainer } from 'testcontainers'; +// eslint-disable-next-line import/no-extraneous-dependencies +import { afterAll, beforeAll, expect, jest } from '@jest/globals'; +import { Client as PgClient } from 'pg'; +import { PostgresDBRunner } from '@cubejs-backend/testing-shared'; +import { getBirdbox } from '../src'; +import { + DEFAULT_CONFIG, + JEST_AFTER_ALL_DEFAULT_TIMEOUT, + JEST_BEFORE_ALL_DEFAULT_TIMEOUT, +} from './smoke-tests'; + +describe('graceful shutdown', () => { + jest.setTimeout(60 * 5 * 1000); + let db: StartedTestContainer; + + // For when graceful shutdown is not supposed to timeout, vs. for when it is supposed + // to timeout. + const longGracefulTimeoutSecs = 30; + const shortGracefulTimeoutSecs = 1; + + const pgPort = 5656; // Make random? (Value and comment taken from smoke-cubesql.) + let connectionId = 0; + + // Since we use 'error' and 'end' events for some tests, it is necessary or wise to let the event + // loop spin around once before asserting. + const yieldImmediate = () => new Promise(setImmediate); + + function unconnectedPostgresClient(user: string, password: string) { + connectionId++; + const currentConnId = connectionId; + + console.debug(`[pg] new connection ${currentConnId}`); + + const conn = new PgClient({ + database: 'db', + port: pgPort, + host: 'localhost', + user, + password, + ssl: false, + }); + conn.on('end', () => { + console.debug(`[pg] end ${currentConnId}`); + }); + + return conn; + } + + const makeBirdbox = (gracefulTimeoutSecs: number) => getBirdbox( + 'postgres', + { + ...DEFAULT_CONFIG, + // + CUBESQL_LOG_LEVEL: 'trace', + // + CUBEJS_DB_TYPE: 'postgres', + CUBEJS_DB_HOST: db.getHost(), + CUBEJS_DB_PORT: `${db.getMappedPort(5432)}`, + CUBEJS_DB_NAME: 'test', + CUBEJS_DB_USER: 'test', + CUBEJS_DB_PASS: 'test', + // + CUBEJS_PG_SQL_PORT: `${pgPort}`, + CUBESQL_SQL_PUSH_DOWN: 'true', + CUBESQL_STREAM_MODE: 'true', + + CUBEJS_GRACEFUL_SHUTDOWN: gracefulTimeoutSecs.toString(), + }, + { + schemaDir: 'smoke/schema', + cubejsConfig: 'smoke/cube.js', + }, + ); + + beforeAll(async () => { + db = await PostgresDBRunner.startContainer({}); + }, JEST_BEFORE_ALL_DEFAULT_TIMEOUT); + + afterAll(async () => { + await db.stop(); + }, JEST_AFTER_ALL_DEFAULT_TIMEOUT); + + const clientless = async (signal: 'SIGTERM' | 'SIGINT') => { + const birdbox = await makeBirdbox(longGracefulTimeoutSecs); + try { + birdbox.killCube(signal); + const code = await birdbox.onCubeExit(); + expect(code).toEqual(0); + } finally { + await birdbox.stop(); + } + }; + + test('Clientless Graceful Shutdown SIGTERM', async () => { + await clientless('SIGTERM'); + }); + + test('Clientless Graceful Shutdown SIGINT', async () => { + await clientless('SIGINT'); + }); + + const betweenQueries = async (signal: 'SIGTERM' | 'SIGINT') => { + const birdbox = await makeBirdbox(longGracefulTimeoutSecs); + try { + const connection: PgClient = unconnectedPostgresClient('admin', 'admin_password'); + + let endResolve: () => void; + const endPromise = new Promise((res, _rej) => { + endResolve = res; + }); + await connection.connect(); + + connection.on('end', () => { endResolve(); }); + let logTerminationErrors = true; + let shutdownErrors = 0; + connection.on('error', (e: Error) => { + const err = e as any; + if (err.severity === 'FATAL' && err.code === '57P01') { + shutdownErrors += 1; + } else if (logTerminationErrors && err.message !== 'Connection terminated unexpectedly') { + console.log(err); + } + }); + try { + const res = await connection.query( + 'SELECT COUNT(*) as cn, "status" FROM Orders GROUP BY 2 ORDER BY cn DESC' + ); + expect(res.rows).toMatchSnapshot('sql_orders'); + + logTerminationErrors = false; + birdbox.killCube(signal); + const code = await birdbox.onCubeExit(); + expect(code).toEqual(0); + } finally { + // Normally the connection ends by server shutdown, and this .end() call returns + // a Promise which never gets fulfilled. + const _ = connection.end(); + await endPromise; + } + expect(shutdownErrors).toEqual(1); + } finally { + await birdbox.stop(); + } + }; + + test('PgClient Graceful Shutdown SIGTERM', async () => { + await betweenQueries('SIGTERM'); + }); + + test('PgClient Graceful Shutdown SIGINT', async () => { + await betweenQueries('SIGINT'); + }); + + const midTransaction = async (signal: 'SIGTERM' | 'SIGINT') => { + const birdbox = await makeBirdbox(signal === 'SIGTERM' ? shortGracefulTimeoutSecs : longGracefulTimeoutSecs); + try { + const connection: PgClient = unconnectedPostgresClient('admin', 'admin_password'); + + let endResolve: () => void; + const endPromise = new Promise((res, _rej) => { + endResolve = res; + }); + await connection.connect(); + + let connectionEnded = false; + connection.on('end', () => { + connectionEnded = true; + endResolve(); + }); + let logTerminationErrors = true; + let shutdownErrors = 0; + let expectedShutdownErrors: number; + connection.on('error', (e: Error) => { + const err = e as any; + if (err.severity === 'FATAL' && err.code === '57P01') { + shutdownErrors += 1; + } else if (logTerminationErrors && err.message !== 'Connection terminated unexpectedly') { + console.log(err); + } + }); + try { + const res = await connection.query( + 'BEGIN' + ); + expect(res.command).toEqual('BEGIN'); + + // Sanity check: our SQL api client connection is still open. (I mean, we haven't even + // killed Cube.) + await yieldImmediate(); + expect(connectionEnded).toBe(false); + + logTerminationErrors = false; + birdbox.killCube(signal); + const code = await birdbox.onCubeExit(); + + /* This test may be overspecifying -- we have no requirement that the exit code be non-zero + if graceful shutdown times out. But for testing purposes, it does provide a handy way to + determine which mechanism caused the server to shut down. */ + if (signal === 'SIGTERM') { + expectedShutdownErrors = 0; + expect(code).not.toEqual(0); + } else { + expectedShutdownErrors = 1; + expect(code).toEqual(0); + } + } finally { + // Normally the connection ends by server shutdown, and this .end() call returns + // a Promise which never gets fulfilled. So we sign up for and wait for the event. + const _ = connection.end(); + await endPromise; + } + + await yieldImmediate(); + expect(shutdownErrors).toEqual(expectedShutdownErrors); + } finally { + await birdbox.stop(); + } + }; + + test('PgClient Graceful Shutdown Mid-Transaction SIGTERM', async () => { + await midTransaction('SIGTERM'); + }); + + test('PgClient Graceful Shutdown Mid-Transaction SIGINT', async () => { + await midTransaction('SIGINT'); + }); + + const waitForTransaction = async (signal: 'SIGTERM') => { + const birdbox = await makeBirdbox(longGracefulTimeoutSecs); + try { + const connection: PgClient = unconnectedPostgresClient('admin', 'admin_password'); + + let endResolve: () => void; + const endPromise = new Promise((res, _rej) => { + endResolve = res; + }); + await connection.connect(); + + let connectionEnded = false; + connection.on('end', () => { + connectionEnded = true; + endResolve(); + }); + let logTerminationErrors = true; + let shutdownErrors = 0; + connection.on('error', (e: Error) => { + const err = e as any; + if (err.severity === 'FATAL' && err.code === '57P01') { + shutdownErrors += 1; + } else if (logTerminationErrors && err.message !== 'Connection terminated unexpectedly') { + console.log(err); + } + }); + try { + // 1. Begin a transaction + const beginRes = await connection.query( + 'BEGIN' + ); + expect(beginRes.command).toEqual('BEGIN'); + + // 2. Kill Cube with SIGTERM. + birdbox.killCube(signal); + + // 3. Run a query (because why not?). + const selectRes = await connection.query( + 'SELECT COUNT(*) as cn, "status" FROM Orders GROUP BY 2 ORDER BY cn DESC' + ); + expect(selectRes.rows).toMatchSnapshot('sql_orders'); + + // Our SQL api client connection is still open. + await yieldImmediate(); + expect(connectionEnded).toBe(false); + + logTerminationErrors = false; + + // 4. Commit the transaction (or rollback). + const commitRes = await connection.query( + 'COMMIT' + ); + expect(commitRes.command).toEqual('COMMIT'); + + // 5. Now wait for the Cube exit result. + const code = await birdbox.onCubeExit(); + expect(code).toEqual(0); + } finally { + // Normally the connection ends by server shutdown, and this .end() call returns + // a Promise which never gets fulfilled. So we sign up for and wait for the event. + const _ = connection.end(); + await endPromise; + } + + await yieldImmediate(); + expect(shutdownErrors).toEqual(1); + } finally { + await birdbox.stop(); + } + }; + + test('PgClient Graceful Shutdown Finishing Transaction', async () => { + await waitForTransaction('SIGTERM'); + }); +}); diff --git a/rust/cubesql/cubesql/src/bin/cubesqld.rs b/rust/cubesql/cubesql/src/bin/cubesqld.rs index 385dad0cd49ed..e3d33a2cd8536 100644 --- a/rust/cubesql/cubesql/src/bin/cubesqld.rs +++ b/rust/cubesql/cubesql/src/bin/cubesqld.rs @@ -1,5 +1,5 @@ use cubesql::{ - config::{Config, CubeServices}, + config::{processing_loop::ShutdownMode, Config, CubeServices}, telemetry::{LocalReporter, ReportingLogger}, }; @@ -58,7 +58,7 @@ async fn stop_on_ctrl_c(s: &Arc) { counter += 1; if counter == 1 { log::info!("Received Ctrl+C, shutting down."); - s.stop_processing_loops().await.ok(); + s.stop_processing_loops(ShutdownMode::Fast).await.ok(); } else if counter == 3 { log::info!("Received Ctrl+C 3 times, exiting immediately."); std::process::exit(130); // 130 is the default exit code when killed by a signal. diff --git a/rust/cubesql/cubesql/src/config/mod.rs b/rust/cubesql/cubesql/src/config/mod.rs index 7ff9cec17885c..7e13be76a52f4 100644 --- a/rust/cubesql/cubesql/src/config/mod.rs +++ b/rust/cubesql/cubesql/src/config/mod.rs @@ -4,7 +4,7 @@ pub mod processing_loop; use crate::{ config::{ injection::{DIService, Injector}, - processing_loop::ProcessingLoop, + processing_loop::{ProcessingLoop, ShutdownMode}, }, sql::{PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService}, transport::{HttpTransport, TransportService}, @@ -60,12 +60,15 @@ impl CubeServices { Ok(futures) } - pub async fn stop_processing_loops(&self) -> Result<(), CubeError> { + pub async fn stop_processing_loops( + &self, + shutdown_mode: ShutdownMode, + ) -> Result<(), CubeError> { if self.injector.has_service_typed::().await { self.injector .get_service_typed::() .await - .stop_processing() + .stop_processing(shutdown_mode) .await?; } diff --git a/rust/cubesql/cubesql/src/config/processing_loop.rs b/rust/cubesql/cubesql/src/config/processing_loop.rs index c8987623ec5f2..7a99d623e822a 100644 --- a/rust/cubesql/cubesql/src/config/processing_loop.rs +++ b/rust/cubesql/cubesql/src/config/processing_loop.rs @@ -1,9 +1,21 @@ use crate::CubeError; use async_trait::async_trait; +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ShutdownMode { + // Note that these values are ordered from least-urgent to most-urgent. + + // Postgres "Smart" mode leaves connections up until the client terminates them. + Smart, + // Shuts down connections when they have no pending operations. + SemiFast, + // Sends fatal error messages to clients and shuts down as soon as it can. Same as Postgres "Fast" mode. + Fast, +} + #[async_trait] pub trait ProcessingLoop: Send + Sync { async fn processing_loop(&self) -> Result<(), CubeError>; - async fn stop_processing(&self) -> Result<(), CubeError>; + async fn stop_processing(&self, mode: ShutdownMode) -> Result<(), CubeError>; } diff --git a/rust/cubesql/cubesql/src/error.rs b/rust/cubesql/cubesql/src/error.rs index 7f90a7fd9e784..64b0c6c9d5db1 100644 --- a/rust/cubesql/cubesql/src/error.rs +++ b/rust/cubesql/cubesql/src/error.rs @@ -219,8 +219,15 @@ impl From> for CubeError { } } -impl From> for CubeError { - fn from(v: tokio::sync::watch::error::SendError) -> Self { +impl + From>> + for CubeError +{ + fn from( + v: tokio::sync::watch::error::SendError< + Option, + >, + ) -> Self { CubeError::internal(v.to_string()) } } diff --git a/rust/cubesql/cubesql/src/sql/postgres/service.rs b/rust/cubesql/cubesql/src/sql/postgres/service.rs index a0b2489d08a0b..f523c28b30731 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/service.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/service.rs @@ -8,7 +8,7 @@ use tokio::{ use tokio_util::sync::CancellationToken; use crate::{ - config::processing_loop::ProcessingLoop, + config::processing_loop::{ProcessingLoop, ShutdownMode}, sql::{session::DatabaseProtocol, SessionManager}, telemetry::{ContextLogger, SessionLogger}, CubeError, @@ -19,8 +19,8 @@ use super::shim::AsyncPostgresShim; pub struct PostgresServer { // options address: String, - close_socket_rx: RwLock>, - close_socket_tx: watch::Sender, + close_socket_rx: RwLock>>, + close_socket_tx: watch::Sender>, // reference session_manager: Arc, } @@ -34,19 +34,40 @@ impl ProcessingLoop for PostgresServer { println!("🔗 Cube SQL (pg) is listening on {}", self.address); - let shim_cancellation_token = CancellationToken::new(); + let fast_shutdown_interruptor = CancellationToken::new(); + let semifast_shutdown_interruptor = CancellationToken::new(); let mut joinset = tokio::task::JoinSet::new(); + let mut active_shutdown_mode: Option = None; loop { let mut stop_receiver = self.close_socket_rx.write().await; let (socket, _) = tokio::select! { - res = stop_receiver.changed() => { - if res.is_err() || *stop_receiver.borrow() { - trace!("[pg] Stopping processing_loop via channel"); - - shim_cancellation_token.cancel(); - break; + _ = stop_receiver.changed() => { + let mode = *stop_receiver.borrow(); + if mode > active_shutdown_mode { + active_shutdown_mode = mode; + match active_shutdown_mode { + Some(ShutdownMode::Fast) => { + trace!("[pg] Stopping processing_loop via channel, fast mode"); + + fast_shutdown_interruptor.cancel(); + break; + } + Some(ShutdownMode::SemiFast) => { + trace!("[pg] Stopping processing_loop via channel, semifast mode"); + + semifast_shutdown_interruptor.cancel(); + break; + } + Some(ShutdownMode::Smart) => { + trace!("[pg] Stopping processing_loop via interruptor, smart mode"); + break; + } + None => { + unreachable!("mode compared greater than something; it can't be None"); + } + } } else { continue; } @@ -87,10 +108,12 @@ impl ProcessingLoop for PostgresServer { let connection_id = session.state.connection_id; let session_manager = self.session_manager.clone(); - let connection_interruptor = shim_cancellation_token.clone(); + let fast_shutdown_interruptor = fast_shutdown_interruptor.clone(); + let semifast_shutdown_interruptor = semifast_shutdown_interruptor.clone(); let join_handle: tokio::task::JoinHandle<()> = tokio::spawn(async move { let handler = AsyncPostgresShim::run_on( - connection_interruptor, + fast_shutdown_interruptor, + semifast_shutdown_interruptor, socket, session.clone(), logger.clone(), @@ -121,24 +144,64 @@ impl ProcessingLoop for PostgresServer { }); } + // Close the listening socket (so we _visibly_ stop accepting incoming connections) before + // we wait for the outstanding connection tasks finish. + std::mem::drop(listener); + // Now that we've had the stop signal, wait for outstanding connection tasks to finish // cleanly. - while let Some(_) = joinset.join_next().await { - // We do nothing here, same as the join_next() handler in the loop. + + loop { + let mut stop_receiver = self.close_socket_rx.write().await; + tokio::select! { + _ = stop_receiver.changed() => { + let mode = *stop_receiver.borrow(); + if mode > active_shutdown_mode { + active_shutdown_mode = mode; + match active_shutdown_mode { + Some(ShutdownMode::Fast) => { + trace!("[pg] Stopping processing_loop via channel: upgrading to fast mode"); + + fast_shutdown_interruptor.cancel(); + } + Some(ShutdownMode::SemiFast) => { + trace!("[pg] Stopping processing_loop via channel: upgrading to semifast mode"); + + semifast_shutdown_interruptor.cancel(); + } + _ => { + // Because of comparisons made, the smallest and 2nd smallest + // Option values are impossible. + unreachable!("impossible mode value, where mode={:?}", active_shutdown_mode); + } + } + } else { + continue; + } + } + res = joinset.join_next() => { + if let None = res { + break; + } else { + // We do nothing here, same as the other join_next() cleanup in the prior loop. + continue; + } + } + } } Ok(()) } - async fn stop_processing(&self) -> Result<(), CubeError> { - self.close_socket_tx.send(true)?; + async fn stop_processing(&self, mode: ShutdownMode) -> Result<(), CubeError> { + self.close_socket_tx.send(Some(mode))?; Ok(()) } } impl PostgresServer { pub fn new(address: String, session_manager: Arc) -> Arc { - let (close_socket_tx, close_socket_rx) = watch::channel(false); + let (close_socket_tx, close_socket_rx) = watch::channel(None::); Arc::new(Self { address, session_manager, diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index aeb507090450d..35d6636ae70d4 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -17,7 +17,7 @@ use crate::{ session::DatabaseProtocol, statement::{PostgresStatementParamsFinder, StatementPlaceholderReplacer}, types::CommandCompletion, - AuthContextRef, Session, StatusFlags, + AuthContextRef, Session, SessionState, StatusFlags, }, telemetry::ContextLogger, transport::SpanId, @@ -39,6 +39,7 @@ pub struct AsyncPostgresShim { socket: TcpStream, // If empty, this means socket is on a message boundary. partial_write_buf: bytes::BytesMut, + semifast_shutdown_interruptor: CancellationToken, // Extended query cursors: HashMap, portals: HashMap, @@ -226,13 +227,28 @@ impl From for ConnectionError { } impl AsyncPostgresShim { + async fn flush_and_write_admin_shutdown_fatal_message( + shim: &mut AsyncPostgresShim, + ) -> Result<(), ConnectionError> { + // We flush the partially written buf and add the fatal message -- it's another place's + // responsibility to impose a timeout and abort us. + shim.socket + .write_all_buf(&mut shim.partial_write_buf) + .await?; + shim.partial_write_buf = bytes::BytesMut::new(); + shim.write_admin_shutdown_fatal_message().await?; + return Ok(()); + } + pub async fn run_on( - shutdown_interruptor: CancellationToken, + fast_shutdown_interruptor: CancellationToken, + semifast_shutdown_interruptor: CancellationToken, socket: TcpStream, session: Arc, logger: Arc, ) -> Result<(), ConnectionError> { let mut shim = Self { + semifast_shutdown_interruptor, socket, partial_write_buf: bytes::BytesMut::new(), cursors: HashMap::new(), @@ -242,12 +258,8 @@ impl AsyncPostgresShim { }; let run_result = tokio::select! { - _ = shutdown_interruptor.cancelled() => { - // We flush the partially written buf and add the fatal message -- it's another - // place's responsibility to impose a timeout and abort us. - shim.socket.write_all_buf(&mut shim.partial_write_buf).await?; - shim.partial_write_buf = bytes::BytesMut::new(); - shim.write_admin_shutdown_fatal_message().await?; + _ = fast_shutdown_interruptor.cancelled() => { + Self::flush_and_write_admin_shutdown_fatal_message(&mut shim).await?; shim.socket.shutdown().await?; return Ok(()); } @@ -282,6 +294,16 @@ impl AsyncPostgresShim { } } + fn session_state_is_semifast_shutdownable(session_state: &SessionState) -> bool { + return !session_state.is_in_transaction() && !session_state.has_current_query(); + } + + fn is_semifast_shutdownable(&self) -> bool { + return self.cursors.is_empty() + && self.portals.is_empty() + && Self::session_state_is_semifast_shutdownable(&*self.session.state); + } + fn admin_shutdown_error() -> ConnectionError { ConnectionError::Protocol( ProtocolError::ErrorResponse { @@ -320,10 +342,21 @@ impl AsyncPostgresShim { // then reads and discards messages until a Sync is reached, then issues ReadyForQuery and returns to normal message processing. let mut tracked_error: Option = None; + // Clone here to avoid conflicting borrows of self in the tokio::select!. + let semifast_shutdown_interruptor = self.semifast_shutdown_interruptor.clone(); + loop { let mut doing_extended_query_message = false; + let semifast_shutdownable = self.is_semifast_shutdownable(); + + let message: protocol::FrontendMessage = tokio::select! { + true = async { semifast_shutdownable && { semifast_shutdown_interruptor.cancelled().await; true } } => { + return Self::flush_and_write_admin_shutdown_fatal_message(self).await; + } + message_result = buffer::read_message(&mut self.socket) => message_result? + }; - let result = match buffer::read_message(&mut self.socket).await? { + let result = match message { protocol::FrontendMessage::Query(body) => { let span_id = Self::new_span_id(body.query.clone()); let mut qtrace = Qtrace::new(&body.query); diff --git a/rust/cubesql/cubesql/src/sql/session.rs b/rust/cubesql/cubesql/src/sql/session.rs index 5b76643dcb903..c0fadc1c8930e 100644 --- a/rust/cubesql/cubesql/src/sql/session.rs +++ b/rust/cubesql/cubesql/src/sql/session.rs @@ -199,6 +199,18 @@ impl SessionState { } } + pub fn has_current_query(&self) -> bool { + let guard = self + .query + .read() + .expect("failed to unlock query for has_current_query"); + + match &*guard { + QueryState::None => false, + QueryState::Active { .. } => true, + } + } + pub fn end_query(&self) { let mut guard = self .query