diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index 8bc9f9ceed..cb86e7e024 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( // fallback to [column_decltype] if !stepped && stmt.read_only() { stepped = true; - let _ = conn.worker.step(*stmt).await; + let _ = conn.worker.step(stmt).await; } let mut ty = stmt.column_type_info(col); diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index 20206a4388..77f10db369 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -87,7 +87,7 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result Result Result( +async fn prepare<'a>( + worker: &mut StatementWorker, statements: &'a mut StatementCache, statement: &'a mut Option, query: &str, @@ -39,7 +40,7 @@ fn prepare<'a>( if exists { // as this statement has been executed before, we reset before continuing // this also causes any rows that are from the statement to be inflated - statement.reset(); + statement.reset(worker).await?; } Ok(statement) @@ -61,19 +62,25 @@ fn bind( /// A structure holding sqlite statement handle and resetting the /// statement when it is dropped. -struct StatementResetter { - handle: StatementHandle, +struct StatementResetter<'a> { + handle: Arc, + worker: &'a mut StatementWorker, } -impl StatementResetter { - fn new(handle: StatementHandle) -> Self { - Self { handle } +impl<'a> StatementResetter<'a> { + fn new(worker: &'a mut StatementWorker, handle: &Arc) -> Self { + Self { + worker, + handle: Arc::clone(handle), + } } } -impl Drop for StatementResetter { +impl Drop for StatementResetter<'_> { fn drop(&mut self) { - self.handle.reset(); + // this method is designed to eagerly send the reset command + // so we don't need to await or spawn it + let _ = self.worker.reset(&self.handle); } } @@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let stmt = prepare(statements, statement, sql, persistent)?; + let stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // is dropped. `StatementResetter` will reliably reset the // statement even if the stream returned from `fetch_many` // is dropped early. - let _resetter = StatementResetter::new(*stmt); + let resetter = StatementResetter::new(worker, stmt); // bind values to the statement num_arguments += bind(stmt, &arguments, num_arguments)?; @@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - let s = worker.step(*stmt).await?; + let s = resetter.worker.step(stmt).await?; match s { Either::Left(changes) => { @@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { Either::Right(()) => { let (row, weak_values_ref) = SqliteRow::current( - *stmt, + &stmt, columns, column_names ); @@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let virtual_stmt = prepare(statements, statement, sql, persistent)?; + let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -205,18 +212,18 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - match worker.step(*stmt).await? { + match worker.step(stmt).await? { Either::Left(_) => (), Either::Right(()) => { let (row, weak_values_ref) = - SqliteRow::current(*stmt, columns, column_names); + SqliteRow::current(stmt, columns, column_names); *last_row_values = Some(weak_values_ref); logger.increment_rows(); - virtual_stmt.reset(); + virtual_stmt.reset(worker).await?; return Ok(Some(row)); } } @@ -238,11 +245,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { handle: ref mut conn, ref mut statements, ref mut statement, + ref mut worker, .. } = self; // prepare statement object (or checkout from cache) - let statement = prepare(statements, statement, sql, true)?; + let statement = prepare(worker, statements, statement, sql, true).await?; let mut parameters = 0; let mut columns = None; diff --git a/sqlx-core/src/sqlite/connection/handle.rs b/sqlx-core/src/sqlite/connection/handle.rs index 6aa8f37667..bf59d41f5b 100644 --- a/sqlx-core/src/sqlite/connection/handle.rs +++ b/sqlx-core/src/sqlite/connection/handle.rs @@ -3,11 +3,23 @@ use std::ptr::NonNull; use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK}; use crate::sqlite::SqliteError; +use std::sync::Arc; /// Managed handle to the raw SQLite3 database handle. -/// The database handle will be closed when this is dropped. +/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist. #[derive(Debug)] -pub(crate) struct ConnectionHandle(pub(super) NonNull); +pub(crate) struct ConnectionHandle(Arc); + +/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own +/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()` +/// or `sqlite3_reset()`. +/// +/// Note that this does *not* actually give access to the database handle! +pub(crate) struct ConnectionHandleRef(Arc); + +// Wrapper for `*mut sqlite3` which finalizes the handle on-drop. +#[derive(Debug)] +struct HandleInner(NonNull); // A SQLite3 handle is safe to send between threads, provided not more than // one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is @@ -20,19 +32,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull); unsafe impl Send for ConnectionHandle {} +// SAFETY: `Arc` normally only implements `Send` where `T: Sync` because it allows +// concurrent access. +// +// However, in this case we're only using `Arc` to prevent the database handle from being +// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus +// should *not* actually provide access to the database handle. +unsafe impl Send for ConnectionHandleRef {} + impl ConnectionHandle { #[inline] pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self { - Self(NonNull::new_unchecked(ptr)) + Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr)))) } #[inline] pub(crate) fn as_ptr(&self) -> *mut sqlite3 { - self.0.as_ptr() + self.0 .0.as_ptr() + } + + #[inline] + pub(crate) fn to_ref(&self) -> ConnectionHandleRef { + ConnectionHandleRef(Arc::clone(&self.0)) } } -impl Drop for ConnectionHandle { +impl Drop for HandleInner { fn drop(&mut self) { unsafe { // https://sqlite.org/c3ref/close.html diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 92926beef4..c18add85b2 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -17,7 +17,7 @@ mod executor; mod explain; mod handle; -pub(crate) use handle::ConnectionHandle; +pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef}; /// A connection to a [Sqlite] database. pub struct SqliteConnection { diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 9f14ca58f0..f1e15f599d 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -23,7 +23,7 @@ pub struct SqliteRow { // IF the user drops the Row before iterating the stream (so // nearly all of our internal stream iterators), the executor moves on; otherwise, // it actually inflates this row with a list of owned sqlite3 values. - pub(crate) statement: StatementHandle, + pub(crate) statement: Arc, pub(crate) values: Arc>, pub(crate) num_values: usize, @@ -48,7 +48,7 @@ impl SqliteRow { // returns a weak reference to an atomic list where the executor should inflate if its going // to increment the statement with [step] pub(crate) fn current( - statement: StatementHandle, + statement: &Arc, columns: &Arc>, column_names: &Arc>, ) -> (Self, Weak>) { @@ -57,7 +57,7 @@ impl SqliteRow { let size = statement.column_count(); let row = Self { - statement, + statement: Arc::clone(statement), values, num_values: size, columns: Arc::clone(columns), diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index d1af117a7d..e796d48c5b 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,5 +1,6 @@ use std::ffi::c_void; use std::ffi::CStr; + use std::os::raw::{c_char, c_int}; use std::ptr; use std::ptr::NonNull; @@ -9,21 +10,22 @@ use std::str::{from_utf8, from_utf8_unchecked}; use libsqlite3_sys::{ sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, - sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes, - sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, - sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, - sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, - sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt, - sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK, - SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob, + sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name, + sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, + sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, + sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, sqlite3_finalize, sqlite3_reset, + sqlite3_sql, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, + sqlite3_value, SQLITE_DONE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW, SQLITE_TRANSIENT, + SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; -#[derive(Debug, Copy, Clone)] -pub(crate) struct StatementHandle(pub(super) NonNull); +#[derive(Debug)] +pub(crate) struct StatementHandle(NonNull); // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. @@ -32,6 +34,14 @@ unsafe impl Send for StatementHandle {} unsafe impl Sync for StatementHandle {} impl StatementHandle { + pub(super) fn new(ptr: NonNull) -> Self { + Self(ptr) + } + + pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt { + self.0.as_ptr() + } + #[inline] pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 { // O(c) access to the connection handle for this statement handle @@ -280,7 +290,25 @@ impl StatementHandle { Ok(from_utf8(self.column_blob(index))?) } - pub(crate) fn reset(&self) { - unsafe { sqlite3_reset(self.0.as_ptr()) }; + pub(crate) fn clear_bindings(&self) { + unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; + } +} +impl Drop for StatementHandle { + fn drop(&mut self) { + // SAFETY: we have exclusive access to the `StatementHandle` here + unsafe { + // https://sqlite.org/c3ref/finalize.html + let status = sqlite3_finalize(self.0.as_ptr()); + if status == SQLITE_MISUSE { + // Panic in case of detected misuse of SQLite API. + // + // sqlite3_finalize returns it at least in the + // case of detected double free, i.e. calling + // sqlite3_finalize on already finalized + // statement. + panic!("Detected sqlite3_finalize misuse."); + } + } } } diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 0063e06508..3da6d33d64 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,13 +3,12 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementWorker}; use crate::sqlite::{SqliteColumn, SqliteError, SqliteRow, SqliteValue}; use crate::HashMap; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ - sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset, - sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, + sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, }; use smallvec::SmallVec; use std::i32; @@ -31,7 +30,7 @@ pub(crate) struct VirtualStatement { // underlying sqlite handles for each inner statement // a SQL query string in SQLite is broken up into N statements // we use a [`SmallVec`] to optimize for the most likely case of a single statement - pub(crate) handles: SmallVec<[StatementHandle; 1]>, + pub(crate) handles: SmallVec<[Arc; 1]>, // each set of columns pub(crate) columns: SmallVec<[Arc>; 1]>, @@ -92,7 +91,7 @@ fn prepare( query.advance(n); if let Some(handle) = NonNull::new(statement_handle) { - return Ok(Some(StatementHandle(handle))); + return Ok(Some(StatementHandle::new(handle))); } } @@ -126,7 +125,7 @@ impl VirtualStatement { conn: &mut ConnectionHandle, ) -> Result< Option<( - &StatementHandle, + &Arc, &mut Arc>, &Arc>, &mut Option>>, @@ -159,7 +158,7 @@ impl VirtualStatement { column_names.insert(name, i); } - self.handles.push(statement); + self.handles.push(Arc::new(statement)); self.columns.push(Arc::new(columns)); self.column_names.push(Arc::new(column_names)); self.last_row_values.push(None); @@ -177,20 +176,20 @@ impl VirtualStatement { ))) } - pub(crate) fn reset(&mut self) { + pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> { self.index = 0; for (i, handle) in self.handles.iter().enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - unsafe { - // Reset A Prepared Statement Object - // https://www.sqlite.org/c3ref/reset.html - // https://www.sqlite.org/c3ref/clear_bindings.html - sqlite3_reset(handle.0.as_ptr()); - sqlite3_clear_bindings(handle.0.as_ptr()); - } + // Reset A Prepared Statement Object + // https://www.sqlite.org/c3ref/reset.html + // https://www.sqlite.org/c3ref/clear_bindings.html + worker.reset(handle).await?; + handle.clear_bindings(); } + + Ok(()) } } @@ -198,20 +197,6 @@ impl Drop for VirtualStatement { fn drop(&mut self) { for (i, handle) in self.handles.drain(..).enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - - unsafe { - // https://sqlite.org/c3ref/finalize.html - let status = sqlite3_finalize(handle.0.as_ptr()); - if status == SQLITE_MISUSE { - // Panic in case of detected misuse of SQLite API. - // - // sqlite3_finalize returns it at least in the - // case of detected double free, i.e. calling - // sqlite3_finalize on already finalized - // statement. - panic!("Detected sqlite3_finalize misuse."); - } - } } } } diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index 8b1d229978..b5a7cf88ed 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -3,9 +3,14 @@ use crate::sqlite::statement::StatementHandle; use crossbeam_channel::{unbounded, Sender}; use either::Either; use futures_channel::oneshot; -use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW}; +use std::sync::{Arc, Weak}; use std::thread; +use crate::sqlite::connection::ConnectionHandleRef; + +use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW}; +use std::future::Future; + // Each SQLite connection has a dedicated thread. // TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -18,31 +23,60 @@ pub(crate) struct StatementWorker { enum StatementWorkerCommand { Step { - statement: StatementHandle, + statement: Weak, tx: oneshot::Sender, Error>>, }, + Reset { + statement: Weak, + tx: oneshot::Sender<()>, + }, } impl StatementWorker { - pub(crate) fn new() -> Self { + pub(crate) fn new(conn: ConnectionHandleRef) -> Self { let (tx, rx) = unbounded(); thread::spawn(move || { for cmd in rx { match cmd { StatementWorkerCommand::Step { statement, tx } => { - let status = unsafe { sqlite3_step(statement.0.as_ptr()) }; + let statement = if let Some(statement) = statement.upgrade() { + statement + } else { + // statement is already finalized, the sender shouldn't be expecting a response + continue; + }; - let resp = match status { + // SAFETY: only the `StatementWorker` calls this function + let status = unsafe { sqlite3_step(statement.as_ptr()) }; + let result = match status { SQLITE_ROW => Ok(Either::Right(())), SQLITE_DONE => Ok(Either::Left(statement.changes())), _ => Err(statement.last_error().into()), }; - let _ = tx.send(resp); + let _ = tx.send(result); + } + StatementWorkerCommand::Reset { statement, tx } => { + if let Some(statement) = statement.upgrade() { + // SAFETY: this must be the only place we call `sqlite3_reset` + unsafe { sqlite3_reset(statement.as_ptr()) }; + + // `sqlite3_reset()` always returns either `SQLITE_OK` + // or the last error code for the statement, + // which should have already been handled; + // so it's assumed the return value is safe to ignore. + // + // https://www.sqlite.org/c3ref/reset.html + + let _ = tx.send(()); + } } } } + + // SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx` + drop(conn); }); Self { tx } @@ -50,14 +84,47 @@ impl StatementWorker { pub(crate) async fn step( &mut self, - statement: StatementHandle, + statement: &Arc, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.tx - .send(StatementWorkerCommand::Step { statement, tx }) + .send(StatementWorkerCommand::Step { + statement: Arc::downgrade(statement), + tx, + }) .map_err(|_| Error::WorkerCrashed)?; rx.await.map_err(|_| Error::WorkerCrashed)? } + + /// Send a command to the worker to execute `sqlite3_reset()` next. + /// + /// This method is written to execute the sending of the command eagerly so + /// you do not need to await the returned future unless you want to. + /// + /// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error + /// in the statement execution which should have already been handled from `step()`. + pub(crate) fn reset( + &mut self, + statement: &Arc, + ) -> impl Future> { + // execute the sending eagerly so we don't need to spawn the future + let (tx, rx) = oneshot::channel(); + + let send_res = self + .tx + .send(StatementWorkerCommand::Reset { + statement: Arc::downgrade(statement), + tx, + }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } } diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db index d693511079..a3d8d5cc29 100644 Binary files a/tests/sqlite/sqlite.db and b/tests/sqlite/sqlite.db differ diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 1334e493c1..e794f49a74 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -536,3 +536,34 @@ async fn it_resets_prepared_statement_after_fetch_many() -> anyhow::Result<()> { Ok(()) } + +// https://github.com/launchbadge/sqlx/issues/1300 +#[sqlx_macros::test] +async fn concurrent_resets_dont_segfault() { + use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions}; + use std::{str::FromStr, time::Duration}; + + let mut conn = SqliteConnectOptions::from_str(":memory:") + .unwrap() + .connect() + .await + .unwrap(); + + sqlx::query("CREATE TABLE stuff (name INTEGER, value INTEGER)") + .execute(&mut conn) + .await + .unwrap(); + + sqlx_rt::spawn(async move { + for i in 0..1000 { + sqlx::query("INSERT INTO stuff (name, value) VALUES (?, ?)") + .bind(i) + .bind(0) + .execute(&mut conn) + .await + .unwrap(); + } + }); + + sqlx_rt::sleep(Duration::from_millis(1)).await; +}